├── in2in
│   ├── __init__.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── hml3d_mean.npy
│   │   ├── hml3d_std.npy
│   │   ├── interhuman_mean.npy
│   │   ├── interhuman_std.npy
│   │   ├── preprocess.py
│   │   ├── configs.py
│   │   ├── plot.py
│   │   ├── paramUtil.py
│   │   ├── skeleton.py
│   │   ├── metrics.py
│   │   └── quaternion.py
│   ├── evaluation
│   │   ├── __init__.py
│   │   ├── utils.py
│   │   ├── models.py
│   │   └── datasets.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── utils
│   │   │   ├── __init__.py
│   │   │   ├── blocks.py
│   │   │   ├── utils.py
│   │   │   ├── layers.py
│   │   │   ├── cfg_sampler.py
│   │   │   └── losses.py
│   │   ├── dualmdm.py
│   │   ├── in2in.py
│   │   └── nets.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── dataloader.py
│   │   ├── humanml3d.py
│   │   └── interhuman.py
│   └── scripts
│       ├── eval
│       │   ├── DualMDM.py
│       │   └── interhuman.py
│       ├── infer.py
│       └── train.py
├── .gitignore
├── assets
│   └── cover.png
├── requirements.txt
├── configs
│   ├── eval.yaml
│   ├── train
│   │   ├── in2IN.yaml
│   │   └── individual.yaml
│   ├── infer.yaml
│   ├── models
│   │   ├── individual.yaml
│   │   ├── in2IN.yaml
│   │   └── DualMDM.yaml
│   └── datasets.yaml
├── setup.py
├── LICENSE
└── README.md

/in2in/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------

/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
--------------------------------------------------------------------------------

/in2in/utils/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------

/in2in/evaluation/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------

/in2in/models/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------

/in2in/models/utils/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------

/assets/cover.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pabloruizponce/in2IN/HEAD/assets/cover.png
--------------------------------------------------------------------------------

/in2in/utils/hml3d_mean.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pabloruizponce/in2IN/HEAD/in2in/utils/hml3d_mean.npy
--------------------------------------------------------------------------------

/in2in/utils/hml3d_std.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pabloruizponce/in2IN/HEAD/in2in/utils/hml3d_std.npy
--------------------------------------------------------------------------------

/in2in/utils/interhuman_mean.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pabloruizponce/in2IN/HEAD/in2in/utils/interhuman_mean.npy
--------------------------------------------------------------------------------

/in2in/utils/interhuman_std.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pabloruizponce/in2IN/HEAD/in2in/utils/interhuman_std.npy
--------------------------------------------------------------------------------
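Note on the four .npy assets above: they appear to store per-dimension mean/std statistics for the HumanML3D and InterHuman motion representations. The snippet below is only an illustrative sketch of how such statistics are typically applied; the project's real normalization helpers (MotionNormalizer and MotionNormalizerHML3D, imported in in2in/scripts/infer.py) live in in2in/utils/utils.py, which is not included in this listing, and the class and method names here are placeholders.

# Illustrative sketch only -- not part of the repository.
import numpy as np

class SketchMotionNormalizer:
    """Placeholder normalizer built from the bundled mean/std .npy files."""
    def __init__(self, mean_path="in2in/utils/interhuman_mean.npy",
                 std_path="in2in/utils/interhuman_std.npy", eps=1e-8):
        self.mean = np.load(mean_path).astype(np.float32)
        self.std = np.load(std_path).astype(np.float32)
        self.eps = eps

    def forward(self, motion):
        # raw motion features -> normalized features fed to the model
        return (motion - self.mean) / (self.std + self.eps)

    def backward(self, motion):
        # normalized model output -> raw feature space (cf. normalizer.backward() in infer.py)
        return motion * (self.std + self.eps) + self.mean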
/requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | tqdm 3 | lightning 4 | scipy 5 | matplotlib 6 | pillow 7 | yacs 8 | mmcv 9 | opencv-python 10 | tabulate 11 | termcolor 12 | smplx 13 | torch 14 | torchvision 15 | torchaudio 16 | pykeops 17 | git+https://github.com/openai/CLIP.git -------------------------------------------------------------------------------- /configs/eval.yaml: -------------------------------------------------------------------------------- 1 | NAME: InterCLIP 2 | NUM_LAYERS: 8 3 | NUM_HEADS: 8 4 | DROPOUT: 0.1 5 | INPUT_DIM: 258 6 | LATENT_DIM: 1024 7 | FF_SIZE: 2048 8 | ACTIVATION: gelu 9 | 10 | MOTION_REP: global 11 | CHECKPOINT: checkpoints/evaluator/interclip.ckpt 12 | FINETUNE: False 13 | EXTENDED: True 14 | -------------------------------------------------------------------------------- /configs/train/in2IN.yaml: -------------------------------------------------------------------------------- 1 | GENERAL: 2 | EXP_NAME: in2IN 3 | CHECKPOINT: ./checkpoints 4 | LOG_DIR: ./log 5 | 6 | TRAIN: 7 | LR: 1e-4 8 | WEIGHT_DECAY: 0.00002 9 | BATCH_SIZE: 32 10 | EPOCH: 2000 11 | STEP: 1000000 12 | LOG_STEPS: 10 13 | SAVE_EPOCH: 50 14 | RESUME: 15 | NUM_WORKERS: 4 16 | MODE: finetune 17 | LAST_EPOCH: 0 18 | LAST_ITER: 0 19 | -------------------------------------------------------------------------------- /configs/train/individual.yaml: -------------------------------------------------------------------------------- 1 | GENERAL: 2 | EXP_NAME: DualMDMIndividual 3 | CHECKPOINT: ./checkpoints 4 | LOG_DIR: ./log 5 | 6 | TRAIN: 7 | LR: 1e-4 8 | WEIGHT_DECAY: 0.00002 9 | BATCH_SIZE: 64 10 | EPOCH: 2000 11 | STEP: 1000000 12 | LOG_STEPS: 10 13 | SAVE_EPOCH: 50 14 | RESUME: 15 | NUM_WORKERS: 8 16 | MODE: finetune 17 | LAST_EPOCH: 0 18 | LAST_ITER: 0 19 | -------------------------------------------------------------------------------- /configs/infer.yaml: -------------------------------------------------------------------------------- 1 | GENERAL: 2 | EXP_NAME: INFERENCE 3 | CHECKPOINT: ./checkpoints 4 | LOG_DIR: ./log 5 | 6 | TRAIN: 7 | LR: 1e-4 8 | WEIGHT_DECAY: 0.00002 9 | BATCH_SIZE: 32 10 | EPOCH: 2000 11 | STEP: 1000000 12 | LOG_STEPS: 10 13 | SAVE_STEPS: 20000 14 | SAVE_EPOCH: 100 15 | RESUME: 16 | NUM_WORKERS: 2 17 | MODE: finetune 18 | LAST_EPOCH: 0 19 | LAST_ITER: 0 20 | -------------------------------------------------------------------------------- /configs/models/individual.yaml: -------------------------------------------------------------------------------- 1 | NAME: InterGen 2 | NUM_LAYERS: 8 3 | NUM_HEADS: 8 4 | DROPOUT: 0.1 5 | INPUT_DIM: 262 6 | LATENT_DIM: 1024 7 | FF_SIZE: 2048 8 | ACTIVATION: gelu 9 | CHECKPOINT: checkpoints/DualMDM/epoch=1999-step=728000.pt 10 | 11 | DIFFUSION_STEPS: 1000 12 | BETA_SCHEDULER: cosine 13 | SAMPLER: uniform 14 | 15 | MOTION_REP: global 16 | FINETUNE: False 17 | 18 | TEXT_ENCODER: clip 19 | T_BAR: 700 20 | 21 | CONTROL: text 22 | STRATEGY: ddim50 23 | CFG_WEIGHT: 3.5 24 | 25 | -------------------------------------------------------------------------------- /configs/models/in2IN.yaml: -------------------------------------------------------------------------------- 1 | NAME: InterGen 2 | NUM_LAYERS: 8 3 | NUM_HEADS: 8 4 | DROPOUT: 0.1 5 | INPUT_DIM: 262 6 | LATENT_DIM: 1024 7 | FF_SIZE: 2048 8 | ACTIVATION: gelu 9 | CHECKPOINT: checkpoints/in2IN/epoch=1999-step=352000.pt 10 | 11 | DIFFUSION_STEPS: 1000 12 | BETA_SCHEDULER: cosine 13 | SAMPLER: uniform 14 | 15 | MOTION_REP: 
global 16 | FINETUNE: False 17 | 18 | TEXT_ENCODER: clip 19 | T_BAR: 700 20 | 21 | CONTROL: text 22 | STRATEGY: ddim50 23 | CFG_WEIGHT: 3 24 | CFG_WEIGHT_INTERACTION: 3 25 | CFG_WEIGHT_INDIVIDUAL: 1 -------------------------------------------------------------------------------- /in2in/models/dualmdm.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | from in2in.models.in2in import in2IN 4 | import torch 5 | 6 | def load_DualMDM_model(model_cfg): 7 | """ 8 | Load the I2I model with the 2 checkpoints 9 | :param model_cfg: Model Configuration file 10 | :return: I2I model 11 | """ 12 | model = in2IN(model_cfg, mode="dual") 13 | print("Model Created") 14 | ckpt = torch.load(model_cfg.CHECKPOINT_INTERACTION) 15 | ckpt_individual = torch.load(model_cfg.CHECKPOINT_INDIVIDUAL) 16 | ckpt.update(ckpt_individual) 17 | model.load_state_dict(ckpt, strict=True) 18 | 19 | return model 20 | -------------------------------------------------------------------------------- /configs/models/DualMDM.yaml: -------------------------------------------------------------------------------- 1 | NAME: InterGen 2 | NUM_LAYERS: 8 3 | NUM_HEADS: 8 4 | DROPOUT: 0.1 5 | INPUT_DIM: 262 6 | LATENT_DIM: 1024 7 | FF_SIZE: 2048 8 | ACTIVATION: gelu 9 | CHECKPOINT_INTERACTION: checkpoints/in2IN/epoch=1999-step=352000.pt 10 | CHECKPOINT_INDIVIDUAL: checkpoints/DualMDM/epoch=1999-step=728000.pt 11 | 12 | DIFFUSION_STEPS: 1000 13 | BETA_SCHEDULER: cosine 14 | SAMPLER: uniform 15 | 16 | MOTION_REP: global 17 | FINETUNE: False 18 | 19 | TEXT_ENCODER: clip 20 | T_BAR: 700 21 | 22 | CONTROL: text 23 | STRATEGY: ddim50 24 | 25 | CFG_WEIGHT_INDIVIDUAL: 2.5 26 | CFG_WEIGHT_INTERACTION: 3.5 27 | W_FUNC: exp 28 | W_VALUE: 0.0085 -------------------------------------------------------------------------------- /configs/datasets.yaml: -------------------------------------------------------------------------------- 1 | interhuman: 2 | NAME: interhuman 3 | DATA_ROOT: ./data/ 4 | MOTION_REP: global 5 | MODE: train 6 | CACHE: True 7 | EXTENDED: True 8 | 9 | interhuman_val: 10 | NAME: interhuman 11 | DATA_ROOT: ./data/ 12 | MOTION_REP: global 13 | MODE: val 14 | CACHE: True 15 | EXTENDED: True 16 | 17 | interhuman_test: 18 | NAME: interhuman 19 | DATA_ROOT: ./data/ 20 | MOTION_REP: global 21 | MODE: test 22 | CACHE: True 23 | EXTENDED: True 24 | 25 | humanml3d: 26 | NAME: humanml3d 27 | DATA_ROOT: ./data/HumanML3D/ 28 | MOTION_REP: global 29 | MODE: train 30 | CACHE: True 31 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pkg_resources 4 | from setuptools import setup, find_packages 5 | 6 | setup( 7 | name="in2IN", 8 | version="1.0", 9 | description="", 10 | author="Pablo Ruiz Ponce", 11 | packages=find_packages(include=['in2in', 'in2in.*']), 12 | install_requires=[ 13 | 'numpy', 14 | 'tqdm', 15 | 'lightning', 16 | 'scipy', 17 | 'matplotlib', 18 | 'pillow', 19 | 'yacs', 20 | 'mmcv', 21 | 'opencv-python', 22 | 'tabulate', 23 | 'termcolor', 24 | 'smplx', 25 | 'torch', 26 | 'torchvision', 27 | 'torchaudio', 28 | 'pykeops', 29 | ], 30 | dependency_links=[ 31 | 'git+https://github.com/openai/CLIP.git', 32 | ], 33 | package_data={'': ['*.npy']}, 34 | ) -------------------------------------------------------------------------------- /in2in/utils/preprocess.py: -------------------------------------------------------------------------------- 1 | import numpy 
as np
2 | from .utils import *
3 | 
4 | FPS = 30
5 | 
6 | def load_motion(file_path, min_length, swap=False):
7 |     """
8 |     Load motions from the original dataset with all the information needed to convert them to the InterHuman format
9 |     :param file_path: path to the motion file
10 |     :param min_length: minimum length of the motion
11 |     :param swap: swap the left and right side of the motion
12 |     """
13 |     try:
14 |         motion = np.load(file_path).astype(np.float32)
15 |     except:
16 |         print("error: ", file_path)
17 |         return None, None
18 | 
19 |     # Reshape motion
20 |     motion1 = motion[:, :22 * 3]
21 |     motion2 = motion[:, 62 * 3:62 * 3 + 21 * 6]
22 |     motion = np.concatenate([motion1, motion2], axis=1)
23 | 
24 |     # If the motion is too short, return None.
25 |     if motion.shape[0] < min_length:
26 |         return None, None
27 | 
28 |     # Swap
29 |     if swap:
30 |         motion_swap = swap_left_right(motion, 22)
31 |     else:
32 |         motion_swap = None
33 | 
34 |     return motion, motion_swap
35 | 
36 | def load_motion_hml3d(pos_file_path, rot_file_path, min_length):
37 |     """
38 |     Load motions from the HumanML3D dataset with all the information needed to convert them to the InterHuman format
39 |     :param pos_file_path: path to the position file
40 |     :param rot_file_path: path to the rotation file
41 |     """
42 | 
43 |     # Try to extract motions from the original dataset
44 |     try:
45 |         pos_motion = np.load(pos_file_path).astype(np.float32)
46 |         rot_motion = np.load(rot_file_path).astype(np.float32)
47 |     except:
48 |         print("error: ", pos_file_path)
49 |         return None, None
50 | 
51 |     # Convert position from (LENGTH, JOINTS, 3) to (LENGTH, JOINTS*3)
52 |     pos_motion = pos_motion[:,:22]
53 |     pos_motion = pos_motion.reshape(pos_motion.shape[0], -1)[:-1,:]
54 | 
55 |     # Extract relative rotations from HumanML3D representation
56 |     rot_motion = rot_motion[:,4+(21*3)+(22*3):4+(21*3)+(22*3)+(21*6)].reshape(rot_motion.shape[0], -1)
57 | 
58 |     # Concatenate position and rotation
59 |     motion = np.concatenate([pos_motion, rot_motion], axis=1)
60 | 
61 |     if motion.shape[0] < min_length:
62 |         return None, None
63 | 
64 |     return motion, None
--------------------------------------------------------------------------------

/in2in/utils/configs.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Dict
3 | from yacs.config import CfgNode as CN
4 | 
5 | 
6 | def to_lower(x: Dict) -> Dict:
7 |     """
8 |     Convert all dictionary keys to lowercase
9 |     Args:
10 |         x (dict): Input dictionary
11 |     Returns:
12 |         dict: Output dictionary with all keys converted to lowercase
13 |     """
14 |     return {k.lower(): v for k, v in x.items()}
15 | 
16 | _C = CN(new_allowed=True)
17 | 
18 | def default_config() -> CN:
19 |     """
20 |     Get a yacs CfgNode object with the default config values.
21 |     """
22 |     # Return a clone so that the defaults will not be altered
23 |     # This is for the "local variable" use pattern
24 |     return _C.clone()
25 | 
26 | def get_config(config_file: str, merge: bool = True) -> CN:
27 |     """
28 |     Read a config file and optionally merge it with the default config file.
29 |     Args:
30 |         config_file (str): Path to config file.
31 |         merge (bool): Whether to merge with the default config or not.
32 |     Returns:
33 |         CfgNode: Config as a yacs CfgNode object.
34 | """ 35 | if merge: 36 | cfg = default_config() 37 | else: 38 | cfg = CN(new_allowed=True) 39 | cfg.merge_from_file(config_file) 40 | cfg.freeze() 41 | return cfg 42 | 43 | def get_config_model(config_file: str, merge: bool = True, w_func:str = None, w_value:float = None) -> CN: 44 | """ 45 | Read a config file and optionally merge it with the default config file. 46 | Args: 47 | config_file (str): Path to config file. 48 | merge (bool): Whether to merge with the default config or not. 49 | Returns: 50 | CfgNode: Config as a yacs CfgNode object. 51 | """ 52 | if merge: 53 | cfg = default_config() 54 | else: 55 | cfg = CN(new_allowed=True) 56 | cfg.merge_from_file(config_file) 57 | 58 | 59 | cfg.W_FUNC = w_func 60 | cfg.W_VALUE = w_value 61 | cfg.freeze() 62 | return cfg 63 | 64 | def dataset_config() -> CN: 65 | """ 66 | Get dataset config file 67 | Returns: 68 | CfgNode: Dataset config as a yacs CfgNode object. 69 | """ 70 | cfg = CN(new_allowed=True) 71 | config_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'datasets.yaml') 72 | cfg.merge_from_file(config_file) 73 | cfg.freeze() 74 | return cfg 75 | 76 | -------------------------------------------------------------------------------- /in2in/models/utils/blocks.py: -------------------------------------------------------------------------------- 1 | from .layers import * 2 | 3 | class TransformerBlock(nn.Module): 4 | def __init__(self, 5 | latent_dim=512, 6 | num_heads=8, 7 | ff_size=1024, 8 | dropout=0., 9 | cond_abl=False, 10 | **kargs): 11 | super().__init__() 12 | self.latent_dim = latent_dim 13 | self.num_heads = num_heads 14 | self.dropout = dropout 15 | self.cond_abl = cond_abl 16 | 17 | self.sa_block = VanillaSelfAttention(latent_dim, num_heads, dropout) 18 | self.ca_block = VanillaCrossAttention(latent_dim, latent_dim, num_heads, dropout, latent_dim) 19 | self.ffn = FFN(latent_dim, ff_size, dropout, latent_dim) 20 | 21 | def forward(self, x, y, emb=None, key_padding_mask=None): 22 | h1 = self.sa_block(x, emb, key_padding_mask) 23 | h1 = h1 + x 24 | h2 = self.ca_block(h1, y, emb, key_padding_mask) 25 | h2 = h2 + h1 26 | out = self.ffn(h2, emb) 27 | out = out + h2 28 | return out 29 | 30 | class TransformerBlockDoubleCond(nn.Module): 31 | def __init__(self, 32 | mode, 33 | latent_dim=512, 34 | num_heads=8, 35 | ff_size=1024, 36 | dropout=0., 37 | cond_abl=False, 38 | **kargs): 39 | super().__init__() 40 | self.latent_dim = latent_dim 41 | self.num_heads = num_heads 42 | self.dropout = dropout 43 | self.cond_abl = cond_abl 44 | self.mode = mode 45 | self.sa_block = VanillaSelfAttention(latent_dim, num_heads, dropout) 46 | self.ca_block = VanillaCrossAttention(latent_dim, latent_dim, num_heads, dropout, latent_dim) 47 | self.ffn = FFN(latent_dim, ff_size, dropout, latent_dim) 48 | 49 | def forward(self, x, y, emb=None, emb_interaction=None, key_padding_mask=None): 50 | h1 = self.sa_block(x, emb, key_padding_mask) 51 | h1 = h1 + x 52 | # If only individual, we ingore the cross attention layers 53 | if self.mode == "individual" or self.mode == "dual_individual": 54 | h2 = h1 55 | else: 56 | h2 = self.ca_block(h1, y, emb_interaction, key_padding_mask) 57 | h2 = h2 + h1 58 | 59 | # All modes have the FFN layer 60 | out = self.ffn(h2, emb) 61 | out = out + h2 62 | return out 63 | 64 | -------------------------------------------------------------------------------- /in2in/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | import lightning.pytorch as pl 2 | import 
torch
3 | from .interhuman import InterHuman
4 | from .humanml3d import HumanML3D
5 | 
6 | class DataModuleHML3D(pl.LightningDataModule):
7 |     def __init__(self, cfg, batch_size, num_workers):
8 |         """
9 |         Initialize the LightningDataModule used for HumanML3D training
10 |         Args:
11 |             cfg (CfgNode): Config file as a yacs CfgNode containing necessary dataset info.
12 |             batch_size (int): Batch size of the train dataloader; num_workers (int): number of dataloader workers.
13 |         """
14 |         super().__init__()
15 |         self.cfg = cfg
16 |         self.batch_size = batch_size
17 |         self.num_workers = num_workers
18 | 
19 |     def setup(self, stage=None):
20 |         """
21 |         Create train and validation datasets
22 |         """
23 |         if self.cfg.NAME == "humanml3d":
24 |             self.train_dataset = HumanML3D(self.cfg)
25 |         else:
26 |             raise NotImplementedError
27 | 
28 |     def train_dataloader(self):
29 |         """
30 |         Return train dataloader
31 |         """
32 |         return torch.utils.data.DataLoader(
33 |             self.train_dataset,
34 |             batch_size=self.batch_size,
35 |             num_workers=self.num_workers,
36 |             pin_memory=False,
37 |             shuffle=True,
38 |             drop_last=True,
39 |         )
40 | 
41 | class DataModule(pl.LightningDataModule):
42 |     def __init__(self, cfg, batch_size, num_workers):
43 |         """
44 |         Initialize the LightningDataModule used for InterHuman training
45 |         Args:
46 |             cfg (CfgNode): Config file as a yacs CfgNode containing necessary dataset info.
47 |             batch_size (int): Batch size of the train dataloader; num_workers (int): number of dataloader workers.
48 |         """
49 |         super().__init__()
50 |         self.cfg = cfg
51 |         self.batch_size = batch_size
52 |         self.num_workers = num_workers
53 | 
54 |     def setup(self, stage=None):
55 |         """
56 |         Create train and validation datasets
57 |         """
58 |         if self.cfg.NAME == "interhuman":
59 |             self.train_dataset = InterHuman(self.cfg)
60 |         else:
61 |             raise NotImplementedError
62 | 
63 |     def train_dataloader(self):
64 |         """
65 |         Return train dataloader
66 |         """
67 |         return torch.utils.data.DataLoader(
68 |             self.train_dataset,
69 |             batch_size=self.batch_size,
70 |             num_workers=self.num_workers,
71 |             pin_memory=False,
72 |             shuffle=True,
73 |             drop_last=True,
74 |         )
75 | 
--------------------------------------------------------------------------------

/in2in/datasets/dataloader.py:
--------------------------------------------------------------------------------
1 | import random
2 | import numpy as np
3 | 
4 | from functools import partial
5 | from typing import Optional, Union
6 | from mmcv.runner import get_dist_info
7 | from torch.utils.data import DataLoader
8 | from torch.utils.data.dataset import Dataset
9 | 
10 | 
11 | def build_dataloader(dataset: Dataset,
12 |                      samples_per_gpu: int,
13 |                      workers_per_gpu: int,
14 |                      num_gpus: Optional[int] = 1,
15 |                      shuffle: Optional[bool] = True,
16 |                      round_up: Optional[bool] = True,
17 |                      seed: Optional[Union[int, None]] = None,
18 |                      persistent_workers: Optional[bool] = True,
19 |                      **kwargs):
20 |     """Build PyTorch DataLoader.
21 | 
22 |     In distributed training, each GPU/process has a dataloader.
23 |     In non-distributed training, there is only one dataloader for all GPUs.
24 | 
25 |     Args:
26 |         dataset (:obj:`Dataset`): A PyTorch dataset.
27 |         samples_per_gpu (int): Number of training samples on each GPU, i.e.,
28 |             batch size of each GPU.
29 |         workers_per_gpu (int): How many subprocesses to use for data loading
30 |             for each GPU.
31 |         num_gpus (int, optional): Number of GPUs. Only used in non-distributed
32 |             training.
33 |         seed (int, optional): Seed used to initialize each dataloader worker. Default: None.
34 |         shuffle (bool, optional): Whether to shuffle the data at every epoch.
35 |             Default: True.
36 | round_up (bool, optional): Whether to round up the length of dataset by 37 | adding extra samples to make it evenly divisible. Default: True. 38 | persistent_workers (bool): If True, the data loader will not shutdown 39 | the worker processes after a dataset has been consumed once. 40 | This allows to maintain the workers Dataset instances alive. 41 | The argument also has effect in PyTorch>=1.7.0. 42 | Default: True 43 | kwargs: any keyword argument to be used to initialize DataLoader 44 | 45 | Returns: 46 | DataLoader: A PyTorch dataloader. 47 | """ 48 | rank, world_size = get_dist_info() 49 | sampler = None 50 | batch_size = num_gpus * samples_per_gpu 51 | num_workers = num_gpus * workers_per_gpu 52 | 53 | init_fn = partial( 54 | worker_init_fn, num_workers=num_workers, rank=rank, 55 | seed=seed) if seed is not None else None 56 | 57 | data_loader = DataLoader( 58 | dataset, 59 | batch_size=batch_size, 60 | sampler=sampler, 61 | num_workers=num_workers, 62 | pin_memory=False, 63 | shuffle=shuffle, 64 | worker_init_fn=init_fn, 65 | persistent_workers=persistent_workers, 66 | **kwargs) 67 | 68 | 69 | 70 | return data_loader 71 | 72 | 73 | def worker_init_fn(worker_id: int, num_workers: int, rank: int, seed: int): 74 | """Init random seed for each worker.""" 75 | # The seed of each worker equals to 76 | # num_worker * rank + worker_id + user_seed 77 | worker_seed = num_workers * rank + worker_id + seed 78 | np.random.seed(worker_seed) 79 | random.seed(worker_seed) 80 | -------------------------------------------------------------------------------- /in2in/models/utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch import nn 4 | 5 | 6 | class CosineWarmupScheduler(torch.optim.lr_scheduler._LRScheduler): 7 | def __init__(self, optimizer, warmup, max_iters, verbose=False): 8 | self.warmup = warmup 9 | self.max_num_iters = max_iters 10 | super().__init__(optimizer, verbose=verbose) 11 | 12 | def get_lr(self): 13 | lr_factor = self.get_lr_factor(epoch=self.last_epoch) 14 | return [base_lr * lr_factor for base_lr in self.base_lrs] 15 | 16 | def get_lr_factor(self, epoch): 17 | lr_factor = 0.5 * (1 + np.cos(np.pi * epoch / self.max_num_iters)) 18 | if epoch <= self.warmup: 19 | lr_factor *= (epoch+1) * 1.0 / self.warmup 20 | return lr_factor 21 | 22 | 23 | 24 | class PositionalEncoding(nn.Module): 25 | def __init__(self, d_model, dropout=0.0, max_len=5000): 26 | super(PositionalEncoding, self).__init__() 27 | self.dropout = nn.Dropout(p=dropout) 28 | 29 | pe = torch.zeros(max_len, d_model) 30 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) 31 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model)) 32 | pe[:, 0::2] = torch.sin(position * div_term) 33 | pe[:, 1::2] = torch.cos(position * div_term) 34 | 35 | self.register_buffer('pe', pe) 36 | 37 | def forward(self, x): 38 | x = x + self.pe[:x.shape[1], :].unsqueeze(0) 39 | return self.dropout(x) 40 | 41 | class TimestepEmbedder(nn.Module): 42 | def __init__(self, latent_dim, sequence_pos_encoder): 43 | super().__init__() 44 | self.latent_dim = latent_dim 45 | self.sequence_pos_encoder = sequence_pos_encoder 46 | 47 | time_embed_dim = self.latent_dim 48 | self.time_embed = nn.Sequential( 49 | nn.Linear(self.latent_dim, time_embed_dim), 50 | nn.SiLU(), 51 | nn.Linear(time_embed_dim, time_embed_dim), 52 | ) 53 | 54 | def forward(self, timesteps): 55 | return 
self.time_embed(self.sequence_pos_encoder.pe[timesteps]) 56 | 57 | 58 | class IdentityEmbedder(nn.Module): 59 | def __init__(self, latent_dim, sequence_pos_encoder): 60 | super().__init__() 61 | self.latent_dim = latent_dim 62 | self.sequence_pos_encoder = sequence_pos_encoder 63 | 64 | time_embed_dim = self.latent_dim 65 | self.time_embed = nn.Sequential( 66 | nn.Linear(self.latent_dim, time_embed_dim), 67 | nn.SiLU(), 68 | nn.Linear(time_embed_dim, time_embed_dim), 69 | ) 70 | 71 | def forward(self, timesteps): 72 | return self.time_embed(self.sequence_pos_encoder.pe[timesteps]).unsqueeze(1) 73 | 74 | 75 | def set_requires_grad(nets, requires_grad=False): 76 | """Set requies_grad for all the networks. 77 | 78 | Args: 79 | nets (nn.Module | list[nn.Module]): A list of networks or a single 80 | network. 81 | requires_grad (bool): Whether the networks require gradients or not 82 | """ 83 | if not isinstance(nets, list): 84 | nets = [nets] 85 | for net in nets: 86 | if net is not None: 87 | for param in net.parameters(): 88 | param.requires_grad = requires_grad 89 | 90 | 91 | def zero_module(module): 92 | """ 93 | Zero out the parameters of a module and return it. 94 | """ 95 | for p in module.parameters(): 96 | p.detach().zero_() 97 | return module 98 | -------------------------------------------------------------------------------- /in2in/models/utils/layers.py: -------------------------------------------------------------------------------- 1 | from .utils import * 2 | 3 | class AdaLN(nn.Module): 4 | def __init__(self, latent_dim, embed_dim=None): 5 | super().__init__() 6 | if embed_dim is None: 7 | embed_dim = latent_dim 8 | self.emb_layers = nn.Sequential( 9 | # nn.Linear(embed_dim, latent_dim, bias=True), 10 | nn.SiLU(), 11 | zero_module(nn.Linear(embed_dim, 2 * latent_dim, bias=True)), 12 | ) 13 | self.norm = nn.LayerNorm(latent_dim, elementwise_affine=False, eps=1e-6) 14 | 15 | def forward(self, h, emb): 16 | """ 17 | h: B, T, D 18 | emb: B, D 19 | """ 20 | # B, 1, 2D 21 | emb_out = self.emb_layers(emb) 22 | # scale: B, 1, D / shift: B, 1, D 23 | scale, shift = torch.chunk(emb_out, 2, dim=-1) 24 | h = self.norm(h) * (1 + scale[:, None]) + shift[:, None] 25 | return h 26 | 27 | 28 | class VanillaSelfAttention(nn.Module): 29 | def __init__(self, latent_dim, num_head, dropout, embed_dim=None): 30 | super().__init__() 31 | self.num_head = num_head 32 | self.norm = AdaLN(latent_dim, embed_dim) 33 | self.attention = nn.MultiheadAttention(latent_dim, num_head, dropout=dropout, batch_first=True, 34 | add_zero_attn=True) 35 | 36 | def forward(self, x, emb, key_padding_mask=None): 37 | """ 38 | x: B, T, D 39 | """ 40 | x_norm = self.norm(x, emb) 41 | y = self.attention(x_norm, x_norm, x_norm, 42 | attn_mask=None, 43 | key_padding_mask=key_padding_mask, 44 | need_weights=False)[0] 45 | return y 46 | 47 | 48 | class VanillaCrossAttention(nn.Module): 49 | def __init__(self, latent_dim, xf_latent_dim, num_head, dropout, embed_dim=None): 50 | super().__init__() 51 | self.num_head = num_head 52 | self.norm = AdaLN(latent_dim, embed_dim) 53 | self.xf_norm = AdaLN(xf_latent_dim, embed_dim) 54 | self.attention = nn.MultiheadAttention(latent_dim, num_head, kdim=xf_latent_dim, vdim=xf_latent_dim, 55 | dropout=dropout, batch_first=True, add_zero_attn=True) 56 | 57 | def forward(self, x, xf, emb, key_padding_mask=None): 58 | """ 59 | x: B, T, D 60 | xf: B, N, L 61 | """ 62 | x_norm = self.norm(x, emb) 63 | xf_norm = self.xf_norm(xf, emb) 64 | y = self.attention(x_norm, xf_norm, xf_norm, 65 | 
attn_mask=None, 66 | key_padding_mask=key_padding_mask, 67 | need_weights=False)[0] 68 | return y 69 | 70 | 71 | class FFN(nn.Module): 72 | def __init__(self, latent_dim, ffn_dim, dropout, embed_dim=None): 73 | super().__init__() 74 | self.norm = AdaLN(latent_dim, embed_dim) 75 | self.linear1 = nn.Linear(latent_dim, ffn_dim, bias=True) 76 | self.linear2 = zero_module(nn.Linear(ffn_dim, latent_dim, bias=True)) 77 | self.activation = nn.GELU() 78 | self.dropout = nn.Dropout(dropout) 79 | 80 | def forward(self, x, emb=None): 81 | if emb is not None: 82 | x_norm = self.norm(x, emb) 83 | else: 84 | x_norm = x 85 | y = self.linear2(self.dropout(self.activation(self.linear1(x_norm)))) 86 | return y 87 | 88 | 89 | class FinalLayer(nn.Module): 90 | def __init__(self, latent_dim, out_dim): 91 | super().__init__() 92 | self.linear = zero_module(nn.Linear(latent_dim, out_dim, bias=True)) 93 | 94 | def forward(self, x): 95 | x = self.linear(x) 96 | return x -------------------------------------------------------------------------------- /in2in/utils/plot.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import matplotlib 4 | import matplotlib.pyplot as plt 5 | import mpl_toolkits.mplot3d.axes3d as p3 6 | 7 | from mpl_toolkits.mplot3d import Axes3D 8 | from mpl_toolkits.mplot3d.art3d import Poly3DCollection 9 | from matplotlib.animation import FuncAnimation, FFMpegFileWriter 10 | from tqdm import tqdm 11 | 12 | 13 | def plot_3d_motion(save_path, kinematic_tree, mp_joints, title, figsize=(10, 10), fps=120, radius=6, mode='interaction'): 14 | """ 15 | Function to plot an interaction between two agents in 3D in matplotlib 16 | :param save_path: path to save the animation 17 | :param kinematic_tree: kinematic tree of the motion 18 | :param mp_joints: list of motion data for each agent 19 | :param title: title of the plot 20 | :param figsize: size of the figure 21 | :param fps: frames per second of the animation 22 | :param radius: radius of the plot 23 | :param mode: mode of the plot 24 | """ 25 | matplotlib.use('Agg') 26 | 27 | # Define initial limits of the plot 28 | def init(): 29 | ax.set_xlim3d([-radius / 4, radius / 4]) 30 | ax.set_ylim3d([0, radius / 2]) 31 | ax.set_zlim3d([0, radius / 2]) 32 | ax.grid(b=False) 33 | 34 | # Funtion to plot a floor in the animation 35 | def plot_xzPlane(minx, maxx, miny, minz, maxz): 36 | verts = [ 37 | [minx, miny, minz], 38 | [minx, miny, maxz], 39 | [maxx, miny, maxz], 40 | [maxx, miny, minz] 41 | ] 42 | xz_plane = Poly3DCollection([verts]) 43 | xz_plane.set_facecolor((0.5, 0.5, 0.5, 0.5)) 44 | ax.add_collection3d(xz_plane) 45 | 46 | 47 | # Create the figure and axis 48 | fig = plt.figure(figsize=figsize) 49 | ax = fig.add_subplot(111, projection='3d') 50 | init() 51 | 52 | # Offsets and colors 53 | mp_offset = list(range(-len(mp_joints)//2, len(mp_joints)//2, 1)) 54 | colors = ['red', 'blue', 'black', 'red', 'blue', 55 | 'darkblue', 'darkblue', 'darkblue', 'darkblue', 'darkblue', 56 | 'darkred', 'darkred', 'darkred', 'darkred', 'darkred'] 57 | mp_colors = [[colors[i]] * 15 for i in range(len(mp_offset))] 58 | 59 | # Store the data for each agent 60 | mp_data = [] 61 | for i,joints in enumerate(mp_joints): 62 | 63 | data = joints.copy().reshape(len(joints), -1, 3) 64 | 65 | MINS = data.min(axis=0).min(axis=0) 66 | MAXS = data.max(axis=0).max(axis=0) 67 | 68 | height_offset = MINS[1] 69 | data[:, :, 1] -= height_offset 70 | trajec = data[:, 0, [0, 2]] 71 | 72 | 
mp_data.append({"joints":data,
73 |                         "MINS":MINS,
74 |                         "MAXS":MAXS,
75 |                         "trajec":trajec, })
76 | 
77 |     def update(index):
78 |         """
79 |         Update function for the matplotlib animation
80 |         :param index: index of the frame
81 |         """
82 |         # Update the progress bar
83 |         bar.update(1)
84 | 
85 |         # Clear the axis and set the initial parameters
86 |         ax.clear()
87 |         plt.axis('off')
88 |         ax.view_init(elev=120, azim=-90)
89 |         ax.dist = 7.5
90 |         ax.set_xticklabels([])
91 |         ax.set_yticklabels([])
92 |         ax.set_zticklabels([])
93 | 
94 |         # Plot the floor
95 |         plot_xzPlane(-3, 3, 0, -3, 3)
96 | 
97 |         # Plot each of the persons in the motion
98 |         for pid, data in enumerate(mp_data):
99 |             for i, (chain, color) in enumerate(zip(kinematic_tree, mp_colors[pid])):
100 |                 linewidth = 3.0
101 |                 ax.plot3D(data["joints"][index, chain, 0],
102 |                           data["joints"][index, chain, 1],
103 |                           data["joints"][index, chain, 2],
104 |                           linewidth=linewidth,
105 |                           color=color,
106 |                           alpha=1)
107 | 
108 |     # Generate animation
109 |     frame_number = min([data.shape[0] for data in mp_joints])
110 |     bar = tqdm(total=frame_number+1)
111 |     ani = FuncAnimation(fig, update, frames=frame_number, interval=1000 / fps, repeat=False)
112 |     ani.save(save_path, fps=fps)
113 |     plt.close()
114 | 
115 | 
--------------------------------------------------------------------------------

/in2in/scripts/eval/DualMDM.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append(sys.path[0]+r"/../../../")
3 | 
4 | import torch
5 | import random
6 | import argparse
7 | import numpy as np
8 | 
9 | from in2in.utils.configs import get_config
10 | from in2in.models.dualmdm import load_DualMDM_model
11 | from in2in.utils.metrics import calculate_wasserstein
12 | from in2in.evaluation.utils import EvaluatorModelWrapper, get_dataset_motion_loader, get_motion_loader_DualMDM
13 | 
14 | # Random Seed Configuration
15 | np.random.seed(0)
16 | random.seed(0)
17 | 
18 | def calculate_individual_diversity(generated_motion_embeddings):
19 |     s_int = generated_motion_embeddings[:num_times]
20 |     s_ind = generated_motion_embeddings[num_times:]
21 | 
22 |     diff, corrs_1_to_2, corrs_2_to_1 = calculate_wasserstein(s_int, s_ind, max_iters=500, verbose=True)
23 |     return diff
24 | 
25 | 
26 | def evaluation():
27 |     metrics = {
28 |         'Individual Diversity': [],
29 |     }
30 | 
31 |     for i in range(replication_times):
32 |         print("Replication: ", i)
33 | 
34 |         individual_diversity = []
35 |         eval_motion_loader = eval_motion_loader_getter()
36 | 
37 |         for bidx, eval_data in enumerate(eval_motion_loader):
38 | 
39 |             generated_motions1, generated_motions2, motion1, motion2, motion_lens, text, text_individual1, text_individual2 = eval_data
40 | 
41 |             # This is needed so that the motion embeddings can be computed correctly
42 |             generated_motions1 = generated_motions1[0]
43 |             generated_motions2 = generated_motions2[0]
44 |             motion1 = motion1[0]
45 |             motion2 = motion2[0]
46 |             motion_lens = motion_lens[0]
47 | 
48 |             generated_motion_embeddings = eval_wrapper.get_motion_embeddings([
49 |                 "name",
50 |                 text,
51 |                 generated_motions1,
52 |                 generated_motions2,
53 |                 motion_lens,
54 |                 text_individual1,
55 |                 text_individual2
56 |             ])
57 | 
58 |             motion_len = motion_lens[0].item()
59 |             generated_motions1 = generated_motions1[:,:motion_len,:]
60 |             generated_motions2 = generated_motions2[:,:motion_len,:]
61 |             motion1 = motion1[:,:motion_len,:]
62 |             motion2 = motion2[:,:motion_len,:]
63 | 
64 |             individual_diversity.append(calculate_individual_diversity(generated_motion_embeddings).cpu().numpy().tolist())
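            # Note: calculate_individual_diversity() above splits the embedding batch into
            # its first num_times samples and the remaining num_times samples and returns
            # their Wasserstein distance via calculate_wasserstein (defined in
            # in2in/utils/metrics.py, not included in this listing); the per-replication
            # mean is reported below as "Individual Diversity".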
65 | 66 | metrics['Individual Diversity'].append(np.mean(individual_diversity)) 67 | print("Individual Diversity: ", np.mean(individual_diversity)) 68 | 69 | print("---- Final Metrics ----") 70 | print("Individual Diversity: ", np.mean(metrics['Individual Diversity']), ", std: ",np.std(metrics['Individual Diversity'])) 71 | 72 | if __name__ == '__main__': 73 | 74 | # Configuration values 75 | num_samples = 100 76 | num_times = 32 77 | replication_times = 5 78 | batch_size = 1 79 | 80 | # Create the parser 81 | parser = argparse.ArgumentParser(description="Argparse example with optional arguments") 82 | 83 | # Add optional arguments 84 | parser.add_argument('--model', type=str, required=True, help='Model Configuration file') 85 | parser.add_argument('--evaluator', type=str, required=True, help='Evaluator Configuration file') 86 | parser.add_argument('--device', type=int, default=0, help='GPU device id') 87 | 88 | # Parse the arguments 89 | args = parser.parse_args() 90 | 91 | # Loading configuration files 92 | data_cfg = get_config("configs/datasets.yaml").interhuman_test 93 | model_cfg = get_config(args.model) 94 | evalmodel_cfg = get_config(args.evaluator) 95 | 96 | # Cuda configuration 97 | device = torch.device('cuda:%d' % args.device if torch.cuda.is_available() else 'cpu') 98 | torch.cuda.set_device(args.device) 99 | 100 | # Build and Load Model 101 | model = load_DualMDM_model(model_cfg) 102 | 103 | # Get Datasets and DataLoaders 104 | gt_loader, gt_dataset = get_dataset_motion_loader(data_cfg, batch_size, num_samples=num_samples) 105 | eval_motion_loader_getter = lambda: get_motion_loader_DualMDM(batch_size, model, gt_dataset, device, num_times) 106 | 107 | # Evaluator Model 108 | eval_wrapper = EvaluatorModelWrapper(evalmodel_cfg, device) 109 | evaluation() 110 | -------------------------------------------------------------------------------- /in2in/datasets/humanml3d.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import random 4 | 5 | from tqdm import tqdm 6 | from utils.utils import * 7 | from torch.utils import data 8 | from utils.preprocess import * 9 | from os.path import join as pjoin 10 | 11 | class HumanML3D(data.Dataset): 12 | """ 13 | HumanML3D dataset 14 | """ 15 | def __init__(self, opt): 16 | 17 | # Configuration variables 18 | self.opt = opt 19 | self.max_cond_length = 1 20 | self.min_cond_length = 1 21 | self.max_gt_length = 300 22 | self.min_gt_length = 15 23 | self.max_length = self.max_cond_length + self.max_gt_length -1 24 | self.min_length = self.min_cond_length + self.min_gt_length -1 25 | self.motion_rep = opt.MOTION_REP 26 | self.cache = opt.CACHE 27 | 28 | # Data structures 29 | self.motion_dict = {} 30 | self.data_list = [] 31 | data_list = [] 32 | 33 | # Load paths from the given split 34 | if self.opt.MODE == "train": 35 | try: 36 | data_list = open(os.path.join(opt.DATA_ROOT, "train.txt"), "r").readlines() 37 | except Exception as e: 38 | print(e) 39 | elif self.opt.MODE == "val": 40 | try: 41 | data_list = open(os.path.join(opt.DATA_ROOT, "val.txt"), "r").readlines() 42 | except Exception as e: 43 | print(e) 44 | elif self.opt.MODE == "test": 45 | try: 46 | data_list = open(os.path.join(opt.DATA_ROOT, "test.txt"), "r").readlines() 47 | except Exception as e: 48 | print(e) 49 | 50 | # Suffle paths 51 | random.shuffle(data_list) 52 | 53 | # Load data 54 | index = 0 55 | motion_path = pjoin(opt.DATA_ROOT, "interhuman/") 56 | for file in tqdm(os.listdir(motion_path)): 57 | 58 | # 
Comment if you want to use the whole dataset 59 | if file.split(".")[0]+"\n" not in data_list: 60 | continue 61 | 62 | motion_name = file.split(".")[0] 63 | motion_file_path = pjoin(motion_path, file) 64 | text_path = motion_file_path.replace("interhuman", "texts").replace("npy", "txt") 65 | 66 | # Load motion and text 67 | texts = [item.replace("\n", "") for item in open(text_path, "r").readlines()] 68 | motion1 = np.load(motion_file_path).astype(np.float32) 69 | 70 | # Check if the motion is too short 71 | if motion1.shape[0] < self.min_length: 72 | continue 73 | 74 | # Cache the motion if needed 75 | if self.cache: 76 | self.motion_dict[index] = motion1 77 | else: 78 | self.motion_dict[index] = motion_file_path 79 | 80 | self.data_list.append({ 81 | "name": motion_name, 82 | "motion_id": index, 83 | "swap":False, 84 | "texts":texts 85 | }) 86 | 87 | index += 1 88 | 89 | print("Total Dataset Size: ", len(self.data_list)) 90 | 91 | def __len__(self): 92 | """ 93 | Get the length of the dataset 94 | """ 95 | return len(self.data_list) 96 | 97 | def __getitem__(self, item): 98 | """ 99 | Get an item from the dataset 100 | param item: Index of the item to get 101 | """ 102 | 103 | # Get the data from the dataset 104 | idx = item % self.__len__() 105 | data = self.data_list[idx] 106 | name = data["name"] 107 | motion_id = data["motion_id"] 108 | 109 | # Select a random text from the list 110 | text = random.choice(data["texts"]).strip().split('#')[0] 111 | 112 | # Load the motion 113 | if self.cache: 114 | full_motion1 = self.motion_dict[motion_id] 115 | else: 116 | file_path1 = self.motion_dict[motion_id] 117 | full_motion1 = np.load(file_path1).astype(np.float32) 118 | 119 | # Get motion lenght and select a random segment 120 | length = full_motion1.shape[0] 121 | if length > self.max_length: 122 | idx = random.choice(list(range(0, length - self.max_gt_length, 1))) 123 | gt_length = self.max_gt_length 124 | motion1 = full_motion1[idx:idx + gt_length] 125 | else: 126 | idx = 0 127 | gt_length = min(length, self.max_gt_length ) 128 | motion1 = full_motion1[idx:idx + gt_length] 129 | 130 | # Check if the motion is too short and pad it 131 | gt_motion1 = motion1 132 | gt_length = len(gt_motion1) 133 | if gt_length < self.max_gt_length: 134 | padding_len = self.max_gt_length - gt_length 135 | D = gt_motion1.shape[1] 136 | padding_zeros = np.zeros((padding_len, D)) 137 | gt_motion1 = np.concatenate((gt_motion1, padding_zeros), axis=0) 138 | 139 | # Return the data 140 | return name, text, gt_motion1, gt_length 141 | 142 | -------------------------------------------------------------------------------- /in2in/models/utils/cfg_sampler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | class ClassifierFreeSampleModel(nn.Module): 6 | 7 | def __init__(self, model, cfg_scale): 8 | super().__init__() 9 | self.model = model # model is the actual model to run 10 | self.s = cfg_scale 11 | 12 | def forward(self, x, timesteps, cond=None, mask=None): 13 | B, T, D = x.shape 14 | 15 | x_combined = torch.cat([x, x], dim=0) 16 | timesteps_combined = torch.cat([timesteps, timesteps], dim=0) 17 | if cond is not None: 18 | cond = torch.cat([cond, torch.zeros_like(cond)], dim=0) 19 | if mask is not None: 20 | mask = torch.cat([mask, mask], dim=0) 21 | 22 | out = self.model(x_combined, timesteps_combined, cond=cond, mask=mask) 23 | 24 | out_cond = out[:B] 25 | out_uncond = out[B:] 26 | 27 | cfg_out = self.s * 
out_cond + (1-self.s) * out_uncond 28 | return cfg_out 29 | 30 | 31 | class ClassifierFreeSampleModelMultiple(nn.Module): 32 | def __init__(self, model, cfg_scale, cfg_scale_interaction, cfg_scale_individuals): 33 | super().__init__() 34 | self.model = model 35 | self.s = cfg_scale 36 | self.s_interaction = cfg_scale_interaction 37 | self.s_individuals = cfg_scale_individuals 38 | 39 | def forward(self, x, timesteps, cond=None, mask=None): 40 | B, T, D = x.shape 41 | 42 | x_combined = torch.cat([x, x, x, x], dim=0) 43 | timesteps_combined = torch.cat([timesteps, timesteps, timesteps, timesteps], dim=0) 44 | if cond is not None: 45 | 46 | cond_full = cond 47 | 48 | cond_interaction = torch.zeros_like(cond) 49 | cond_interaction[:,:768] = cond[:,:768] 50 | 51 | cond_individuals = torch.zeros_like(cond) 52 | cond_individuals[:,768:] = cond[:,768:] 53 | 54 | cond = torch.cat([cond_full, cond_interaction, cond_individuals, torch.zeros_like(cond)], dim=0) 55 | if mask is not None: 56 | mask = torch.cat([mask, mask, mask, mask], dim=0) 57 | 58 | out = self.model(x_combined, timesteps_combined, cond=cond, mask=mask) 59 | 60 | out_cond = out[:B] 61 | out_cond_interaction = out[B:B*2] 62 | out_cond_individuals = out[B*2:B*3] 63 | out_uncond = out[B*3:] 64 | 65 | cfg_out = (self.s * out_cond) + (self.s_interaction * out_cond_interaction) + (self.s_individuals * out_cond_individuals) + ((1-(self.s+self.s_interaction+self.s_individuals)) * out_uncond) 66 | return cfg_out 67 | 68 | 69 | class ClassifierFreeSampleDualMDM(nn.Module): 70 | 71 | def __init__(self, m_individual, m_interaction, s_individual, s_interaction, s_composition_func, s_composition_value): 72 | super().__init__() 73 | self.m_individual = m_individual 74 | self.m_interaction = m_interaction 75 | self.s_individual = s_individual 76 | self.s_interaction = s_interaction 77 | self.s_composition = self.weight(s_composition_func, s_composition_value) 78 | 79 | 80 | def weight(self, func, value): 81 | print(f"Diffusion Weight Scheduler func: {func}, value: {value}") 82 | 83 | if func == "exp": 84 | return lambda x: np.exp(-value * (1000 - x))[0] 85 | elif func == "exp-inv": 86 | return lambda x: 1 - np.exp(-value * (1000 - x))[0] 87 | elif func == "lin": 88 | return lambda x: 1 - ((1000 - x) / 1000)[0] 89 | elif func == "const": 90 | return lambda x: value 91 | else: 92 | raise ValueError("Unknown function") 93 | 94 | def forward(self, x, timesteps, cond=None, mask=None): 95 | B, T, D = x.shape 96 | 97 | x_combined = torch.cat([x, x], dim=0) 98 | timesteps_combined = torch.cat([timesteps, timesteps], dim=0) 99 | 100 | if cond is not None: 101 | cond_combined = torch.cat([cond, torch.zeros_like(cond)], dim=0) 102 | 103 | if mask is not None: 104 | mask_combined = torch.cat([mask, mask], dim=0) 105 | else: 106 | mask_combined = None 107 | 108 | out_interaction = self.m_interaction(x_combined, timesteps_combined, cond=cond_combined, mask=mask_combined) 109 | out_individual = self.m_individual(x_combined, timesteps_combined, cond=cond_combined, mask=mask_combined) 110 | 111 | out_interaction_cond = out_interaction[:B] 112 | out_interaction_uncond = out_interaction[B:] 113 | 114 | out_individual_cond = out_individual[:B] 115 | out_individual_uncond = out_individual[B:] 116 | 117 | cfg_out_interaction = (out_interaction_uncond + self.s_interaction * (out_interaction_cond - out_interaction_uncond)) 118 | cfg_out_individual = (out_individual_uncond + self.s_individual * (out_individual_cond - out_individual_uncond)) 119 | 120 | w = 
self.s_composition(timesteps.cpu().numpy()) 121 | cfg_out = cfg_out_interaction + w * (cfg_out_individual - cfg_out_interaction) 122 | return cfg_out 123 | 124 | 125 | 126 | -------------------------------------------------------------------------------- /in2in/scripts/infer.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append(sys.path[0] + r"/../../") 3 | 4 | from collections import OrderedDict 5 | import copy 6 | import os.path 7 | import argparse 8 | import torch 9 | from scipy.ndimage import gaussian_filter1d 10 | import numpy as np 11 | 12 | from lightning import LightningModule 13 | from in2in.utils.utils import MotionNormalizer, MotionNormalizerHML3D 14 | from in2in.utils.plot import plot_3d_motion 15 | from in2in.utils.paramUtil import HML_KINEMATIC_CHAIN 16 | from in2in.utils.configs import get_config 17 | from in2in.models.dualmdm import load_DualMDM_model 18 | from in2in.models.in2in import in2IN 19 | 20 | class LitGenModel(LightningModule): 21 | def __init__(self, model, cfg, save_folder, mode): 22 | super().__init__() 23 | # cfg init 24 | self.cfg = cfg 25 | 26 | self.automatic_optimization = False 27 | self.save_folder = save_folder 28 | 29 | # train model init 30 | self.model = model 31 | 32 | self.save_folder = os.path.join("results",save_folder) 33 | if not os.path.exists(self.save_folder): 34 | os.makedirs(self.save_folder) 35 | 36 | self.mode = mode 37 | 38 | if self.mode == "individual": 39 | self.normalizer = MotionNormalizerHML3D() 40 | else: 41 | self.normalizer = MotionNormalizer() 42 | 43 | def plot_t2m(self, mp_data, result_path, caption): 44 | mp_joint = [] 45 | 46 | if self.mode == "individual": 47 | joint = mp_data.reshape(-1,22,3) 48 | mp_joint.append(joint) 49 | else: 50 | for i, data in enumerate(mp_data): 51 | if i == 0: 52 | joint = data[:,:22*3].reshape(-1,22,3) 53 | else: 54 | joint = data[:,:22*3].reshape(-1,22,3) 55 | 56 | mp_joint.append(joint) 57 | 58 | plot_3d_motion(result_path, HML_KINEMATIC_CHAIN, mp_joint, title=caption, fps=30) 59 | 60 | 61 | def generate_one_sample(self, prompt_interaction, prompt_individual1, prompt_individual2, name): 62 | self.model.eval() 63 | batch = OrderedDict({}) 64 | 65 | batch["motion_lens"] = torch.zeros(1,1).long().cuda() 66 | batch["prompt_interaction"] = prompt_interaction 67 | 68 | if self.mode != "individual": 69 | batch["prompt_individual1"] = prompt_individual1 70 | batch["prompt_individual2"] = prompt_individual2 71 | 72 | window_size = 210 73 | motion_output = self.generate_loop(batch, window_size) 74 | result_path = f"{self.save_folder}/{name}.mp4" 75 | 76 | if self.mode == "individual": 77 | self.plot_t2m(motion_output, 78 | result_path, 79 | batch["prompt_interaction"]) 80 | else: 81 | self.plot_t2m([motion_output[0], motion_output[1]], 82 | result_path, 83 | batch["prompt_interaction"]) 84 | 85 | def generate_loop(self, batch, window_size): 86 | prompt_interaction = batch["prompt_interaction"] 87 | 88 | if self.mode != "individual": 89 | prompt_individual1 = batch["prompt_individual1"] 90 | prompt_individual2 = batch["prompt_individual2"] 91 | 92 | batch = copy.deepcopy(batch) 93 | batch["motion_lens"][:] = window_size 94 | 95 | batch["text"] = [prompt_interaction] 96 | if self.mode != "individual": 97 | batch["text_individual1"] = [prompt_individual1] 98 | batch["text_individual2"] = [prompt_individual2] 99 | 100 | batch = self.model.forward_test(batch) 101 | 102 | if self.mode == "individual": 103 | motion_output = 
batch["output"][0].reshape(-1, 262) 104 | motion_output = self.normalizer.backward(motion_output.cpu().detach().numpy()) 105 | joints3d = motion_output[:,:22*3].reshape(-1,22,3) 106 | joints3d = gaussian_filter1d(joints3d, 1, axis=0, mode='nearest') 107 | return joints3d 108 | 109 | motion_output_both = batch["output"][0].reshape(batch["output"][0].shape[0], 2, -1) 110 | motion_output_both = self.normalizer.backward(motion_output_both.cpu().detach().numpy()) 111 | 112 | sequences = [[], []] 113 | for j in range(2): 114 | motion_output = motion_output_both[:,j] 115 | joints3d = motion_output[:,:22*3].reshape(-1,22,3) 116 | joints3d = gaussian_filter1d(joints3d, 1, axis=0, mode='nearest') 117 | sequences[j].append(joints3d) 118 | 119 | sequences[0] = np.concatenate(sequences[0], axis=0) 120 | sequences[1] = np.concatenate(sequences[1], axis=0) 121 | return sequences 122 | 123 | if __name__ == '__main__': 124 | 125 | # Create the parser 126 | parser = argparse.ArgumentParser(description="Argparse example with optional arguments") 127 | 128 | # Add optional arguments 129 | parser.add_argument('--model', type=str, required=True, help='Model Configuration file') 130 | parser.add_argument('--infer', type=str, required=True, help='Infer Configuration file') 131 | parser.add_argument('--mode', type=str, required=True, help='Mode of the inference (individual, interaction, dual)') 132 | parser.add_argument('--out', type=str, required=True, help='Folder to save the results') 133 | parser.add_argument('--device', type=str, required=True, help='Device to run the model') 134 | 135 | parser.add_argument('--text_interaction', type=str, required=True, help='Interaction prompt') 136 | parser.add_argument('--text_individual1', type=str, required=False, help='Individual 1 prompt') 137 | parser.add_argument('--text_individual2', type=str, required=False, help='Individual 2 prompt') 138 | parser.add_argument('--name', type=str, required=True, help='Name of the output file') 139 | 140 | # Parse the arguments 141 | args = parser.parse_args() 142 | 143 | model_cfg = get_config(args.model) 144 | infer_cfg = get_config(args.infer) 145 | 146 | if args.mode == "dual": 147 | model = load_DualMDM_model(model_cfg) 148 | else: 149 | model = in2IN(model_cfg, args.mode) 150 | model.load_state_dict(torch.load(model_cfg.CHECKPOINT), strict=True) 151 | 152 | litmodel = LitGenModel(model, infer_cfg, args.out, mode=args.mode).to(torch.device("cuda:"+ args.device)) 153 | litmodel.generate_one_sample(args.text_interaction, args.text_individual1, args.text_individual2, args.name) 154 | 155 | -------------------------------------------------------------------------------- /in2in/models/in2in.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import clip 3 | 4 | from torch import nn 5 | from in2in.models.nets import in2INDiffusion 6 | from in2in.models.utils.utils import set_requires_grad 7 | 8 | class in2IN(nn.Module): 9 | def __init__(self, cfg, mode): 10 | super().__init__() 11 | self.cfg = cfg 12 | self.latent_dim = cfg.LATENT_DIM 13 | # Mode can be individual interaction or dual 14 | self.mode = mode 15 | 16 | # DECODER (Denoiser) 17 | self.decoder = in2INDiffusion(cfg, mode, sampling_strategy=cfg.STRATEGY) 18 | 19 | # TEXT ENCODER (Trainable) - 1 FOR EACH MODEL # 20 | # INTERACTION 21 | if self.mode == "interaction" or self.mode == "dual": 22 | clipTransEncoderLayer_interaction = nn.TransformerEncoderLayer( 23 | d_model=768, 24 | nhead=8, 25 | dim_feedforward=2048, 26 | 
dropout=0.1, 27 | activation="gelu", 28 | batch_first=True) 29 | 30 | self.clipTransEncoder_interaction = nn.TransformerEncoder( 31 | clipTransEncoderLayer_interaction, 32 | num_layers=2) 33 | 34 | self.clip_ln_interaction = nn.LayerNorm(768) 35 | 36 | # INDIVIDUAL 37 | if self.mode == "individual" or self.mode == "dual": 38 | clipTransEncoderLayer_individual = nn.TransformerEncoderLayer( 39 | d_model=768, 40 | nhead=8, 41 | dim_feedforward=2048, 42 | dropout=0.1, 43 | activation="gelu", 44 | batch_first=True) 45 | 46 | self.clipTransEncoder_individual = nn.TransformerEncoder( 47 | clipTransEncoderLayer_individual, 48 | num_layers=2) 49 | 50 | self.clip_ln_individual = nn.LayerNorm(768) 51 | 52 | # CLIP MODEL (No trainable) 53 | clip_model, _ = clip.load("ViT-L/14@336px", device="cpu", jit=False) 54 | 55 | self.token_embedding = clip_model.token_embedding 56 | self.clip_transformer = clip_model.transformer 57 | self.positional_embedding = clip_model.positional_embedding 58 | self.ln_final = clip_model.ln_final 59 | self.dtype = clip_model.dtype 60 | 61 | set_requires_grad(self.clip_transformer, False) 62 | set_requires_grad(self.token_embedding, False) 63 | set_requires_grad(self.ln_final, False) 64 | 65 | def compute_loss(self, batch): 66 | if self.mode == "dual": 67 | batch = self.text_process(batch, mode="interaction", out_name="cond_interaction") 68 | batch = self.text_process(batch, mode="interaction",text_name="text_individual1", out_name="cond_interaction_individual1") 69 | batch = self.text_process(batch, mode="interaction",text_name="text_individual2", out_name="cond_interaction_individual2") 70 | batch = self.text_process(batch, mode="individual",text_name="text_individual1", out_name="cond_individual_individual1") 71 | batch = self.text_process(batch, mode="individual",text_name="text_individual2", out_name="cond_individual_individual2") 72 | elif self.mode == "interaction" : 73 | batch = self.text_process(batch, mode="interaction", out_name="cond_interaction") 74 | batch = self.text_process(batch, mode="interaction",text_name="text_individual1", out_name="cond_interaction_individual1") 75 | batch = self.text_process(batch, mode="interaction",text_name="text_individual2", out_name="cond_interaction_individual2") 76 | elif self.mode == "individual": 77 | batch = self.text_process(batch, mode="individual", out_name="cond_individual_individual1") 78 | 79 | losses = self.decoder.compute_loss(batch) 80 | return losses["total"], losses 81 | 82 | def decode_motion(self, batch): 83 | batch.update(self.decoder(batch)) 84 | return batch 85 | 86 | def forward(self, batch): 87 | return self.compute_loss(batch) 88 | 89 | def forward_test(self, batch): 90 | if self.mode == "dual": 91 | batch = self.text_process(batch, mode="interaction", out_name="cond_interaction") 92 | batch = self.text_process(batch, mode="interaction",text_name="text_individual1", out_name="cond_interaction_individual1") 93 | batch = self.text_process(batch, mode="interaction",text_name="text_individual2", out_name="cond_interaction_individual2") 94 | batch = self.text_process(batch, mode="individual",text_name="text_individual1", out_name="cond_individual_individual1") 95 | batch = self.text_process(batch, mode="individual",text_name="text_individual2", out_name="cond_individual_individual2") 96 | elif self.mode == "interaction" : 97 | batch = self.text_process(batch, mode="interaction", out_name="cond_interaction") 98 | batch = self.text_process(batch, mode="interaction",text_name="text_individual1", 
out_name="cond_interaction_individual1") 99 | batch = self.text_process(batch, mode="interaction",text_name="text_individual2", out_name="cond_interaction_individual2") 100 | elif self.mode == "individual": 101 | batch = self.text_process(batch, mode="individual", out_name="cond_individual_individual1") 102 | 103 | batch.update(self.decode_motion(batch)) 104 | return batch 105 | 106 | def text_process(self, batch, mode, text_name="text", out_name="cond"): 107 | device = next(self.clip_transformer.parameters()).device 108 | raw_text = batch[text_name] 109 | 110 | with torch.no_grad(): 111 | 112 | text = clip.tokenize(raw_text, truncate=True).to(device) 113 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] 114 | pe_tokens = x + self.positional_embedding.type(self.dtype) 115 | x = pe_tokens.permute(1, 0, 2) # NLD -> LND 116 | x = self.clip_transformer(x) 117 | x = x.permute(1, 0, 2) 118 | clip_out = self.ln_final(x).type(self.dtype) 119 | 120 | if mode == "individual": 121 | out = self.clipTransEncoder_individual(clip_out) 122 | out = self.clip_ln_individual(out) 123 | elif mode == "interaction": 124 | out = self.clipTransEncoder_interaction(clip_out) 125 | out = self.clip_ln_interaction(out) 126 | else: 127 | raise ValueError("Mode not recognized") 128 | 129 | cond = out[torch.arange(x.shape[0]), text.argmax(dim=-1)] 130 | batch[out_name] = cond 131 | 132 | return batch 133 | -------------------------------------------------------------------------------- /in2in/utils/paramUtil.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | ### CONSTANTS ### 4 | 5 | HML_RAW_OFFSETS = np.array([[0,0,0], 6 | [1,0,0], 7 | [-1,0,0], 8 | [0,1,0], 9 | [0,-1,0], 10 | [0,-1,0], 11 | [0,1,0], 12 | [0,-1,0], 13 | [0,-1,0], 14 | [0,1,0], 15 | [0,0,1], 16 | [0,0,1], 17 | [0,1,0], 18 | [1,0,0], 19 | [-1,0,0], 20 | [0,0,1], 21 | [0,-1,0], 22 | [0,-1,0], 23 | [0,-1,0], 24 | [0,-1,0], 25 | [0,-1,0], 26 | [0,-1,0]]) 27 | HML_KINEMATIC_CHAIN = [[0, 2, 5, 8, 11], [0, 1, 4, 7, 10], [0, 3, 6, 9, 12, 15], [9, 14, 17, 19, 21], [9, 13, 16, 18, 20]] 28 | HML_LEFT_HAND_CHAIN = [[20, 22, 23, 24], [20, 34, 35, 36], [20, 25, 26, 27], [20, 31, 32, 33], [20, 28, 29, 30]] 29 | HML_RIGHT_HAND_CHAIN = [[21, 43, 44, 45], [21, 46, 47, 48], [21, 40, 41, 42], [21, 37, 38, 39], [21, 49, 50, 51]] 30 | HML_TGT_SKEL_ID = '000021' 31 | HML_JOINT_NAMES = [ 32 | 'pelvis', 33 | 'left_hip', 34 | 'right_hip', 35 | 'spine1', 36 | 'left_knee', 37 | 'right_knee', 38 | 'spine2', 39 | 'left_ankle', 40 | 'right_ankle', 41 | 'spine3', 42 | 'left_foot', 43 | 'right_foot', 44 | 'neck', 45 | 'left_collar', 46 | 'right_collar', 47 | 'head', 48 | 'left_shoulder', 49 | 'right_shoulder', 50 | 'left_elbow', 51 | 'right_elbow', 52 | 'left_wrist', 53 | 'right_wrist', 54 | ] 55 | NUM_HML_JOINTS = len(HML_JOINT_NAMES) # 22 SMPLH body joints 56 | HML_LOWER_BODY_JOINTS = [HML_JOINT_NAMES.index(name) for name in ['pelvis', 57 | 'left_hip', 58 | 'right_hip', 59 | 'left_knee', 60 | 'right_knee', 61 | 'left_ankle', 62 | 'right_ankle', 63 | 'left_foot', 64 | 'right_foot',]] 65 | SMPL_UPPER_BODY_JOINTS = [i for i in range(len(HML_JOINT_NAMES)) if i not in HML_LOWER_BODY_JOINTS] 66 | HML_ROOT_BINARY = np.array([True] + [False] * (NUM_HML_JOINTS-1)) 67 | HML_ROOT_MASK = np.concatenate(([True]*(1+2+1), 68 | HML_ROOT_BINARY[1:].repeat(3), 69 | HML_ROOT_BINARY[1:].repeat(6), 70 | HML_ROOT_BINARY.repeat(3), 71 | [False] * 4)) 72 | HML_ROOT_HORIZONTAL_MASK = np.concatenate(([True]*(1+2) + 
[False], 73 | np.zeros_like(HML_ROOT_BINARY[1:].repeat(3)), 74 | np.zeros_like(HML_ROOT_BINARY[1:].repeat(6)), 75 | np.zeros_like(HML_ROOT_BINARY.repeat(3)), 76 | [False] * 4)) 77 | HML_LOWER_BODY_JOINTS_BINARY = np.array([i in HML_LOWER_BODY_JOINTS for i in range(NUM_HML_JOINTS)]) 78 | HML_LOWER_BODY_MASK = np.concatenate(([True]*(1+2+1), 79 | HML_LOWER_BODY_JOINTS_BINARY[1:].repeat(3), 80 | HML_LOWER_BODY_JOINTS_BINARY[1:].repeat(6), 81 | HML_LOWER_BODY_JOINTS_BINARY.repeat(3), 82 | [True]*4)) 83 | HML_UPPER_BODY_MASK = ~HML_LOWER_BODY_MASK 84 | HML_TRAJ_MASK = np.zeros_like(HML_ROOT_MASK) 85 | HML_TRAJ_MASK[1:3] = True 86 | NUM_HML_FEATS = 263 87 | L_IDX1, L_IDX2 = 5, 8 # Lower legs 88 | FID_R, FID_L = [8, 11], [7, 10] # Right/Left foot 89 | FACE_JOINT_INDX = [2, 1, 17, 16] # Face direction, r_hip, l_hip, sdr_r, sdr_l 90 | R_HIP, L_HIP = 2, 1 # l_hip, r_hip 91 | JOINTS_NUM = 22 92 | 93 | ### FUNCTIONS ### 94 | 95 | def expand_mask(mask, shape): 96 | """ 97 | expands a mask of shape (num_feat, seq_len) to the requested shape (usually, (batch_size, num_feat, 1, seq_len)) 98 | """ 99 | _, num_feat, _, _ = shape 100 | return np.ones(shape) * mask.reshape((1, num_feat, 1, -1)) 101 | 102 | def get_joints_mask(join_names): 103 | joins_mask = np.array([joint_name in join_names for joint_name in HML_JOINT_NAMES]) 104 | mask = np.concatenate(([False]*(1+2+1), 105 | joins_mask[1:].repeat(3), 106 | np.zeros_like(joins_mask[1:].repeat(6)), 107 | np.zeros_like(joins_mask.repeat(3)), 108 | [False] * 4)) 109 | return mask 110 | 111 | def get_batch_joint_mask(shape, joint_names): 112 | return expand_mask(get_joints_mask(joint_names), shape) 113 | 114 | def get_in_between_mask(shape, lengths, prefix_end, suffix_end): 115 | mask = np.ones(shape) # True means use gt motion 116 | for i, length in enumerate(lengths): 117 | start_idx, end_idx = int(prefix_end * length), int(suffix_end * length) 118 | mask[i, :, :, start_idx: end_idx] = 0 # do inpainting in those frames 119 | return mask 120 | 121 | def get_prefix_mask(shape, prefix_length=20): 122 | _, num_feat, _, seq_len = shape 123 | prefix_mask = np.concatenate((np.ones((num_feat, prefix_length)), np.zeros((num_feat, seq_len - prefix_length))), axis=-1) 124 | return expand_mask(prefix_mask, shape) 125 | 126 | def get_inpainting_mask(mask_name, shape, **kwargs): 127 | mask_names = mask_name.split(',') 128 | 129 | mask = np.zeros(shape) 130 | if 'in_between' in mask_names: 131 | mask = np.maximum(mask, get_in_between_mask(shape, **kwargs)) 132 | 133 | if 'root' in mask_names: 134 | mask = np.maximum(mask, expand_mask(HML_ROOT_MASK, shape)) 135 | 136 | if 'root_horizontal' in mask_names: 137 | mask = np.maximum(mask, expand_mask(HML_ROOT_HORIZONTAL_MASK, shape)) 138 | 139 | if 'prefix' in mask_names: 140 | mask = np.maximum(mask, get_prefix_mask(shape, **kwargs)) 141 | 142 | if 'upper_body' in mask_names: 143 | mask = np.maximum(mask, expand_mask(HML_UPPER_BODY_MASK, shape)) 144 | 145 | if 'lower_body' in mask_names: 146 | mask = np.maximum(mask, expand_mask(HML_LOWER_BODY_MASK, shape)) 147 | 148 | return np.maximum(mask, get_batch_joint_mask(shape, mask_names)) 149 | 150 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | SOFTWARE LICENSE AGREEMENT 2 | ACADEMIC OR NON-PROFIT ORGANIZATION NONCOMMERCIAL RESEARCH USE ONLY 3 | 4 | BY USING OR DOWNLOADING THE SOFTWARE, YOU ARE AGREEING TO THE TERMS OF THIS LICENSE AGREEMENT. 
IF YOU DO NOT AGREE WITH THESE TERMS, YOU MAY NOT USE OR DOWNLOAD THE SOFTWARE. 5 | 6 | This is a license agreement ("Agreement") between your academic institution or non profit organization or self (called "Licensee" or "You" in this Agreement) and Universidad de Alicante (called "Licensor" in this Agreement). All rights not specifically granted to you in this Agreement are reserved for Licensor. 7 | 8 | RESERVATION OF OWNERSHIP AND GRANT OF LICENSE: 9 | Licensor retains exclusive ownership of any copy of the Software (as defined below) licensed under this Agreement and hereby grants to Licensee a personal, non-exclusive, 10 | non-transferable license to use the Software for noncommercial research purposes, without the right to sublicense, pursuant to the terms and conditions of this Agreement. As used in this Agreement, the term "Software" means (i) the actual copy of all or any portion of code for program routines made accessible to Licensee by Licensor pursuant to this Agreement, inclusive of backups, updates, and/or merged copies permitted hereunder or subsequently supplied by Licensor, including all or any file structures, programming instructions, user interfaces and screen formats and sequences as well as any and all documentation and instructions related to it, and (ii) all or any derivatives and/or modifications created or made by You to any of the items specified in (i). 11 | 12 | CONFIDENTIALITY: Licensee acknowledges that the Software is proprietary to Licensor, and as such, Licensee agrees to receive all such materials in confidence and use the Software only in accordance with the terms of this Agreement. Licensee agrees to use reasonable effort to protect the Software from unauthorized use, reproduction, distribution, or publication. 13 | 14 | COPYRIGHT: The Software is owned by Licensor and is protected by Spain copyright laws and applicable international treaties and/or conventions. 15 | 16 | PERMITTED USES: The Software may be used for your own noncommercial internal research purposes. You understand and agree that Licensor is not obligated to implement any suggestions and/or feedback you might provide regarding the Software, but to the extent Licensor does so, you are not entitled to any compensation related thereto. 17 | 18 | DERIVATIVES: You may create derivatives of or make modifications to the Software, however, You agree that all and any such derivatives and modifications will be owned by Licensor and become a part of the Software licensed to You under this Agreement. You may only use such derivatives and modifications for your own noncommercial internal research purposes, and you may not otherwise use, distribute or copy such derivatives and modifications in violation of this Agreement. 19 | 20 | BACKUPS: If Licensee is an organization, it may make that number of copies of the Software necessary for internal noncommercial use at a single site within its organization provided that all information appearing in or on the original labels, including the copyright and trademark notices are copied onto the labels of the copies. 21 | 22 | USES NOT PERMITTED: You may not distribute, copy or use the Software except as explicitly permitted herein. Licensee has not been granted any trademark license as part of this Agreement and may not use the name or mark "in2IN", "DualMDM", "Universidad de Alicante", or any renditions thereof without the prior written permission of Licensor. 
23 | 24 | You may not sell, rent, lease, sublicense, lend, time-share or transfer, in whole or in part, or provide third parties access to prior or present versions (or any parts thereof) of the Software. 25 | 26 | ASSIGNMENT: You may not assign this Agreement or your rights hereunder without the prior written consent of Licensor. Any attempted assignment without such consent shall be null and void. 27 | 28 | TERM: The term of the license granted by this Agreement is from Licensee's acceptance of this Agreement by clicking "I Agree" below or by using the Software until terminated as provided below. 29 | 30 | The Agreement automatically terminates without notice if you fail to comply with any provision of this Agreement. Licensee may terminate this Agreement by ceasing using the Software. Upon any termination of this Agreement, Licensee will delete any and all copies of the Software. You agree that all provisions which operate to protect the proprietary rights of Licensor shall remain in force should breach occur and that the obligation of confidentiality described in this Agreement is binding in perpetuity and, as such, survives the term of the Agreement. 31 | 32 | FEE: Provided Licensee abides completely by the terms and conditions of this Agreement, there is no fee due to Licensor for Licensee's use of the Software in accordance with this Agreement. 33 | 34 | DISCLAIMER OF WARRANTIES: THE SOFTWARE IS PROVIDED "AS-IS" WITHOUT WARRANTY OF ANY KIND INCLUDING ANY WARRANTIES OF PERFORMANCE OR MERCHANTABILITY OR FITNESS FOR A PARTICULAR USE OR PURPOSE OR OF NON-INFRINGEMENT. LICENSEE BEARS ALL RISK RELATING TO QUALITY AND PERFORMANCE OF THE SOFTWARE AND RELATED MATERIALS. 35 | 36 | SUPPORT AND MAINTENANCE: No Software support or training by the Licensor is provided as part of this Agreement. 37 | 38 | EXCLUSIVE REMEDY AND LIMITATION OF LIABILITY: To the maximum extent permitted under applicable law, Licensor shall not be liable for direct, indirect, special, incidental, or consequential damages or lost profits related to Licensee's use of and/or inability to use the Software, even if Licensor is advised of the possibility of such damage. 39 | 40 | EXPORT REGULATION: Licensee agrees to comply with any and all applicable 41 | U.S. export control laws, regulations, and/or other laws related to embargoes and sanction programs administered by the Office of Foreign Assets Control. 42 | 43 | SEVERABILITY: If any provision(s) of this Agreement shall be held to be invalid, illegal, or unenforceable by a court or other tribunal of competent jurisdiction, the validity, legality and enforceability of the remaining provisions shall not in any way be affected or impaired thereby. 44 | 45 | NO IMPLIED WAIVERS: No failure or delay by Licensor in enforcing any right or remedy under this Agreement shall be construed as a waiver of any future or other exercise of such right or remedy by Licensor. 46 | 47 | GOVERNING LAW: This Agreement shall be construed and enforced in accordance with the laws of the Commonwealth of Spain without reference to conflict of laws principles. You consent to the personal jurisdiction of the courts of this County and waive their rights to venue outside of Alicante, Spain. 48 | 49 | ENTIRE AGREEMENT AND AMENDMENTS: This Agreement constitutes the sole and entire agreement between Licensee and Licensor as to the matter set forth herein and supersedes any previous agreements, understandings, and arrangements between the parties relating hereto. 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

# in2IN: Leveraging Individual Information to Generate Human INteractions

Project page · arXiv

## 🔎 About
Generating human-human motion interactions conditioned on textual descriptions is useful in many areas, such as robotics, gaming, animation, and the metaverse. Alongside this utility, however, comes the difficulty of modeling the high-dimensional inter-personal dynamics, and properly capturing the intra-personal diversity of interactions poses further challenges. Current methods generate interactions with limited diversity in the intra-person dynamics due to the limitations of the available datasets and conditioning strategies. To address this, we introduce in2IN, a novel diffusion model for human-human motion generation that is conditioned not only on the textual description of the overall interaction but also on individual descriptions of the actions performed by each person involved. To train this model, we use a large language model to extend the InterHuman dataset with individual descriptions. As a result, in2IN achieves state-of-the-art performance on the InterHuman dataset. Furthermore, to increase the intra-personal diversity of existing interaction datasets, we propose DualMDM, a model composition technique that combines motions generated by in2IN with motions generated by a single-person motion prior pre-trained on HumanML3D. DualMDM generates motions with higher individual diversity and improves control over the intra-person dynamics while maintaining inter-personal coherence.

## 📌 News
- [2024-06-04] Code, model weights, and additional training data are now available!
- [2024-04-16] Our paper is available on [arXiv](https://arxiv.org/abs/2404.09988).
- [2024-04-06] in2IN has been accepted at the CVPR 2024 Workshop [HuMoGen](https://humogen.github.io)!

## 📝 TODO List
- [x] Release code
- [x] Release model weights
- [x] Release individual descriptions for the InterHuman dataset
- [ ] Release visualization code

## 💻 Usage
### 🛠️ Installation
1. Clone the repo
   ```sh
   git clone https://github.com/pabloruizponce/in2IN.git
   ```
2. Install the requirements
   1. Download the required libraries
      ```sh
      pip install -r requirements.txt
      ```
   2. Install ffmpeg
      ```sh
      sudo apt update
      sudo apt install ffmpeg
      ```

   > [!WARNING]
   > All the code has been tested on Ubuntu 22.04.3 LTS x86_64 with Python 3.12.2 and CUDA 12.3.1. If you run into problems, please open an issue.
3. Download the individual descriptions of the InterHuman dataset from [here](https://drive.google.com/drive/folders/14I3_BLu7ItWPNBWN8rMChOZiIkEhXrxH?usp=share_link) and place them in the `data` folder.

   > [!IMPORTANT]
   > The original InterHuman dataset is needed to run the code. You can download it from [here](https://github.com/tr3e/InterGen). If you use the dataset, please cite us and the original paper.

### 🕹️ Inference
Download the model weights from [here](https://drive.google.com/drive/folders/14I3_BLu7ItWPNBWN8rMChOZiIkEhXrxH?usp=share_link) and place them in the `checkpoints` folder.
```sh
python in2in/scripts/infer.py \
    --model configs/models/in2IN.yaml \
    --infer configs/infer.yaml \
    --mode interaction \
    --out results \
    --device 0 \
    --text_interaction "Interaction textual description" \
    --text_individual1 "Individual textual description" \
    --text_individual2 "Individual textual description" \
    --name "output_name"
```

> [!NOTE]
> More information about the parameters can be found using the `--help` flag.

### 🏃🏻‍♂️ Training

```sh
python in2in/scripts/train.py \
    --train configs/train/in2IN.yaml \
    --model configs/models/in2IN.yaml \
    --data configs/datasets.yaml \
    --mode interaction \
    --device 0
```

### 🎖️ Evaluation
Download the evaluator model weights from [here](https://drive.google.com/drive/folders/14I3_BLu7ItWPNBWN8rMChOZiIkEhXrxH?usp=share_link) and place them in the `checkpoints` folder.

#### Interaction Quality
```sh
python in2in/scripts/eval/interhuman.py \
    --model configs/models/in2IN.yaml \
    --evaluator configs/eval.yaml \
    --mode [interaction, dual] \
    --out results \
    --device 0
```

#### Individual Diversity
```sh
python in2in/scripts/eval/DualMDM.py \
    --model configs/models/DualMDM.yaml \
    --evaluator configs/eval.yaml \
    --device 0
```

## 📚 Citation

If you find our work helpful, please cite:

```bibtex
@InProceedings{Ruiz-Ponce_2024_CVPR,
    author    = {Ruiz-Ponce, Pablo and Barquero, German and Palmero, Cristina and Escalera, Sergio and Garc{\'\i}a-Rodr{\'\i}guez, Jos\'e},
    title     = {in2IN: Leveraging Individual Information to Generate Human INteractions},
    booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR) Workshops},
    month     = {June},
    year      = {2024},
    pages     = {1941-1951}
}
```

## 🫶🏼 Acknowledgments
- [InterGen](https://github.com/tr3e/InterGen), from which we inherit a large part of our code.
- [MDM](https://github.com/GuyTevet/motion-diffusion-model), whose evaluation code for text-motion models we used.
- [Diffusion Models Beat GANs on Image Synthesis](https://github.com/openai/guided-diffusion), whose Gaussian diffusion code we used as the base of our implementation.
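For quick experiments, the model can also be built directly from Python instead of going through the CLI scripts. The sketch below is illustrative and not shipped with the repository; it only mirrors the model construction that `in2in/scripts/train.py` performs with `get_config` and the `in2IN` class, and it deliberately leaves out checkpoint loading, since checkpoint file names depend on the weights you download.

```python
# Illustrative sketch (not part of the repository): builds the interaction model
# from the shipped configs, reusing the same calls as in2in/scripts/train.py.
from in2in.utils.configs import get_config
from in2in.models.in2in import in2IN

model_cfg = get_config("configs/models/in2IN.yaml")  # model hyper-parameters
model = in2IN(model_cfg, "interaction")              # same mode values as --mode
model.eval()                                         # inference-only usage
```

The CLI scripts above remain the supported entry points; this snippet is only meant to show how the YAML configs map onto the model class.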
127 | -------------------------------------------------------------------------------- /in2in/scripts/train.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append(sys.path[0] + r"/../../") 3 | 4 | import os 5 | import time 6 | import wandb 7 | import torch 8 | import lightning.pytorch as pl 9 | from pytorch_lightning.loggers import WandbLogger 10 | 11 | import torch.optim as optim 12 | from collections import OrderedDict 13 | from datasets import DataModule, DataModuleHML3D 14 | from in2in.utils.configs import get_config 15 | from os.path import join as pjoin 16 | import argparse 17 | from in2in.models.utils.utils import CosineWarmupScheduler 18 | from in2in.utils.utils import print_current_loss 19 | from in2in.models.in2in import in2IN 20 | 21 | 22 | os.environ['PL_TORCH_DISTRIBUTED_BACKEND'] = 'nccl' 23 | from lightning.pytorch.strategies import DDPStrategy 24 | torch.set_float32_matmul_precision('medium') 25 | 26 | def list_of_ints(arg): 27 | return list(map(int, arg.split(','))) 28 | 29 | class LitTrainModel(pl.LightningModule): 30 | def __init__(self, model, cfg, mode): 31 | super().__init__() 32 | 33 | self.cfg = cfg 34 | self.automatic_optimization = False 35 | self.save_root = pjoin(self.cfg.GENERAL.CHECKPOINT, self.cfg.GENERAL.EXP_NAME) 36 | self.model_dir = pjoin(self.save_root, 'model') 37 | self.meta_dir = pjoin(self.save_root, 'meta') 38 | self.log_dir = pjoin(self.save_root, 'log') 39 | 40 | os.makedirs(self.model_dir, exist_ok=True) 41 | os.makedirs(self.meta_dir, exist_ok=True) 42 | os.makedirs(self.log_dir, exist_ok=True) 43 | 44 | self.model = model 45 | 46 | # Can be ["individual", "iteraction", "dual"] 47 | self.mode = mode 48 | 49 | # save hyper-parameters to self.hparamsm auto-logged by wandb 50 | self.save_hyperparameters(ignore=['model']) 51 | 52 | def _configure_optim(self): 53 | optimizer = optim.AdamW(self.model.parameters(), lr=float(self.cfg.TRAIN.LR), weight_decay=self.cfg.TRAIN.WEIGHT_DECAY) 54 | if self.mode == "individual": 55 | optimizer = optim.AdamW(self.model.parameters(), lr=float(self.cfg.TRAIN.LR), weight_decay=self.cfg.TRAIN.WEIGHT_DECAY) 56 | return [optimizer] 57 | else: 58 | scheduler = CosineWarmupScheduler(optimizer=optimizer, warmup=10, max_iters=self.cfg.TRAIN.EPOCH, verbose=True) 59 | return [optimizer], [scheduler] 60 | 61 | def configure_optimizers(self): 62 | return self._configure_optim() 63 | 64 | def forward(self, batch_data): 65 | if self.mode == "individual": 66 | name, text, motion1, motion_lens = batch_data 67 | motions = motion1.detach().float() 68 | elif self.mode == "interaction": 69 | name, text, motion1, motion2, motion_lens, text_individual1, text_individual2 = batch_data 70 | motion1 = motion1.detach().float() 71 | motion2 = motion2.detach().float() 72 | motions = torch.cat([motion1, motion2], dim=-1) 73 | 74 | B, T = motion1.shape[:2] 75 | 76 | batch = OrderedDict({}) 77 | batch["text"] = text 78 | 79 | if self.mode == "interaction": 80 | batch["text_individual1"] = text_individual1 81 | batch["text_individual2"] = text_individual2 82 | 83 | batch["motions"] = motions.reshape(B, T, -1).type(torch.float32) 84 | batch["motion_lens"] = motion_lens.long() 85 | 86 | loss, loss_logs = self.model(batch) 87 | return loss, loss_logs 88 | 89 | def on_train_start(self): 90 | self.rank = 0 91 | self.world_size = 1 92 | self.start_time = time.time() 93 | self.it = self.cfg.TRAIN.LAST_ITER if self.cfg.TRAIN.LAST_ITER else 0 94 | self.epoch = self.cfg.TRAIN.LAST_EPOCH if 
self.cfg.TRAIN.LAST_EPOCH else 0 95 | self.logs = OrderedDict() 96 | 97 | print("Model Iterations", self.it) 98 | print("Model Epochs", self.epoch) 99 | 100 | 101 | def training_step(self, batch, batch_idx): 102 | loss, loss_logs = self.forward(batch) 103 | opt = self.optimizers() 104 | opt.zero_grad() 105 | self.manual_backward(loss) 106 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), 0.5) 107 | opt.step() 108 | 109 | return {"loss": loss, 110 | "loss_logs": loss_logs} 111 | 112 | 113 | def on_train_batch_end(self, outputs, batch, batch_idx): 114 | if outputs.get('skip_batch') or not outputs.get('loss_logs'): 115 | return 116 | for k, v in outputs['loss_logs'].items(): 117 | if k not in self.logs: 118 | self.logs[k] = v.item() 119 | else: 120 | self.logs[k] += v.item() 121 | 122 | self.it += 1 123 | if self.it % self.cfg.TRAIN.LOG_STEPS == 0: 124 | mean_loss = OrderedDict({}) 125 | for tag, value in self.logs.items(): 126 | mean_loss[tag] = value / self.cfg.TRAIN.LOG_STEPS 127 | # log metrics to wandb 128 | self.log(tag, mean_loss[tag], on_step=True, on_epoch=False, prog_bar=True) 129 | self.logs = OrderedDict() 130 | print_current_loss(self.start_time, self.it, mean_loss, 131 | self.trainer.current_epoch, 132 | inner_iter=batch_idx, 133 | lr=self.trainer.optimizers[0].param_groups[0]['lr']) 134 | 135 | 136 | def on_train_epoch_end(self): 137 | if self.mode == "interaction": 138 | sch = self.lr_schedulers() 139 | if sch is not None: 140 | sch.step() 141 | 142 | def save(self, file_name): 143 | state = {} 144 | try: 145 | state['model'] = self.model.module.state_dict() 146 | except: 147 | state['model'] = self.model.state_dict() 148 | torch.save(state, file_name, _use_new_zipfile_serialization=False) 149 | return 150 | 151 | 152 | if __name__ == '__main__': 153 | 154 | # Create the parser 155 | parser = argparse.ArgumentParser(description="Argparse example with optional arguments") 156 | 157 | # Add arguments 158 | parser.add_argument('--train', type=str, required=True, help='Training Configuration file') 159 | parser.add_argument('--model' , type=str, required=True, help='Model Configuration file') 160 | parser.add_argument('--data', type=str, required=True, help='Data Configuration file') 161 | parser.add_argument('--mode', type=str, required=True, help='Model mode (individual or interaction)') 162 | parser.add_argument('--resume', type=str, required=False, help='Resume training from checkpoint') 163 | parser.add_argument('--device', type=list_of_ints, required=True, help='Device to run the training') 164 | 165 | # Parse the arguments 166 | args = parser.parse_args() 167 | 168 | model_cfg = get_config(args.model) 169 | train_cfg = get_config(args.train) 170 | 171 | # initialise the wandb logger and name your wandb project 172 | wandb_logger = WandbLogger(project='in2IN', name=train_cfg.GENERAL.EXP_NAME) 173 | 174 | if args.mode == "individual": 175 | data_cfg = get_config(args.data).humanml3d 176 | datamodule = DataModuleHML3D(data_cfg, train_cfg.TRAIN.BATCH_SIZE, train_cfg.TRAIN.NUM_WORKERS) 177 | elif args.mode == "interaction": 178 | data_cfg = get_config(args.data).interhuman 179 | datamodule = DataModule(data_cfg, train_cfg.TRAIN.BATCH_SIZE, train_cfg.TRAIN.NUM_WORKERS) 180 | 181 | model = in2IN(model_cfg, args.mode) 182 | litmodel = LitTrainModel(model, train_cfg, args.mode) 183 | 184 | # Checkpoint callback 185 | checkpoint_callback = pl.callbacks.ModelCheckpoint(dirpath=litmodel.model_dir, 186 | every_n_epochs=train_cfg.TRAIN.SAVE_EPOCH, 187 | save_top_k=-1) 188 | 189 | 
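# NOTE: with save_top_k=-1 every checkpoint is kept; Lightning writes one file
# every TRAIN.SAVE_EPOCH epochs under <GENERAL.CHECKPOINT>/<GENERAL.EXP_NAME>/model/.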
# Trainer 190 | trainer = pl.Trainer( 191 | default_root_dir=litmodel.model_dir, 192 | devices=args.device, accelerator='gpu', 193 | max_epochs=train_cfg.TRAIN.EPOCH, 194 | strategy=DDPStrategy(find_unused_parameters=True), 195 | precision='16-mixed', 196 | logger=wandb_logger, 197 | callbacks=[checkpoint_callback] 198 | ) 199 | 200 | if args.resume: 201 | trainer.fit(model=litmodel, datamodule=datamodule, ckpt_path=args.resume) 202 | else: 203 | trainer.fit(model=litmodel, datamodule=datamodule) -------------------------------------------------------------------------------- /in2in/evaluation/utils.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import numpy as np 3 | import torch 4 | 5 | from in2in.evaluation.datasets import EvaluationDatasetDualMDM, EvaluationDatasetInterHuman, MMGeneratedDatasetInterHuman 6 | from in2in.models import * 7 | from in2in.datasets import InterHuman 8 | from in2in.evaluation.models import InterCLIP 9 | from torch.utils.data import DataLoader 10 | 11 | def get_dataset_motion_loader(opt, batch_size, num_samples=-1): 12 | """ 13 | Get the ground truth dataset of motions with his given dataloader. 14 | :param opt: Configuration of the dataset. 15 | :param batch_size: Batch size of the dataloader. 16 | :return: Dataloader of the motion datase. 17 | """ 18 | opt = copy.deepcopy(opt) 19 | if opt.NAME == 'interhuman': 20 | print('Loading dataset %s ...' % opt.NAME) 21 | 22 | dataset = InterHuman(opt, num_samples=num_samples) 23 | dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=0, drop_last=True, shuffle=True) 24 | else: 25 | raise KeyError('Dataset not Recognized !!') 26 | 27 | print('Ground Truth Dataset Loading Completed!!!') 28 | return dataloader, dataset 29 | 30 | 31 | 32 | 33 | def get_motion_loader_in2IN(batch_size, model, ground_truth_dataset, device, mm_num_samples, mm_num_repeats): 34 | """ 35 | Get the generated dataset of motions with his given dataloader and the MultiModality one. 36 | :param batch_size: Batch size of the dataloader. 37 | :param model: Model to generate the motions. 38 | :param ground_truth_dataset: Ground truth dataset. 39 | :param device: Device to run the model. 40 | :param mm_num_samples: Number of samples to generate for the MultiModality metric. 41 | :param mm_num_repeats: Number of repeats for each sample in the MultiModality metric. 42 | :return: Dataloader of the generated motion dataset and the MultiModality one. 43 | """ 44 | 45 | dataset = EvaluationDatasetInterHuman(model, ground_truth_dataset, device, mm_num_samples=mm_num_samples, mm_num_repeats=mm_num_repeats) 46 | mm_dataset = MMGeneratedDatasetInterHuman(dataset) 47 | 48 | motion_loader = DataLoader(dataset, batch_size=batch_size, drop_last=True, num_workers=0, shuffle=True) 49 | mm_motion_loader = DataLoader(mm_dataset, batch_size=1, num_workers=0) 50 | 51 | print('Generated Dataset Loading Completed!!!') 52 | 53 | return motion_loader, mm_motion_loader 54 | 55 | 56 | 57 | def get_motion_loader_DualMDM(batch_size, model, ground_truth_dataset, device, num_repeats): 58 | """ 59 | Get the generated dataset of motions with his given dataloader 60 | :param batch_size: Batch size of the dataloader. 61 | :param model: Model to generate the motions. 62 | :param ground_truth_dataset: Ground truth dataset. 63 | :param device: Device to run the model. 64 | :param num_repeats: Number of repeats for each sample. 65 | :return: Dataloader of the generated motion dataset. 
66 | """ 67 | dataset = EvaluationDatasetDualMDM(model, ground_truth_dataset, device, num_repeats=num_repeats) 68 | motion_loader = DataLoader(dataset, batch_size=batch_size, drop_last=True, num_workers=0, shuffle=True) 69 | return motion_loader 70 | 71 | 72 | def build_models(cfg): 73 | """ 74 | Create and load the feature extractor model for the evaluation. 75 | :param cfg: Configuration of the model. 76 | :return: Feature extractor model for the evaluation. 77 | """ 78 | 79 | model = InterCLIP(cfg) 80 | 81 | # Load the model from the checkpoint 82 | checkpoint = torch.load(cfg.CHECKPOINT, map_location="cpu") 83 | for k in list(checkpoint["state_dict"].keys()): 84 | if "model" in k: 85 | checkpoint["state_dict"][k.replace("model.", "")] = checkpoint["state_dict"].pop(k) 86 | model.load_state_dict(checkpoint["state_dict"], strict=True) 87 | 88 | return model 89 | 90 | 91 | class EvaluatorModelWrapper(object): 92 | """ 93 | Wrapper of the model for the evaluation. 94 | The model will be used to extract features from the generated motions and the gt motions. 95 | """ 96 | def __init__(self, cfg, device): 97 | """ 98 | Initialization of the model. 99 | :param cfg: Configuration of the model. 100 | :param device: Device to run the model. 101 | """ 102 | self.model = build_models(cfg) 103 | self.cfg = cfg 104 | self.device = device 105 | self.model = self.model.to(device) 106 | self.model.eval() 107 | self.extended = cfg.EXTENDED 108 | 109 | 110 | def get_co_embeddings(self, batch_data): 111 | """ 112 | Get the embeddings of the text and the motions of a given batch of data. 113 | :param batch_data: Batch of data to extract the embeddings. 114 | :return: Embeddings of the text and the motions. 115 | Please note that the results does not following the order of inputs 116 | """ 117 | with torch.no_grad(): 118 | # Extract data from the batch provided by the evaluation datasets 119 | if self.extended: 120 | name, text, motion1, motion2, motion_lens, text_individual1, text_individual2 = batch_data 121 | else: 122 | name, text, motion1, motion2, motion_lens = batch_data 123 | 124 | motion1 = motion1.detach().float() # .to(self.device) 125 | motion2 = motion2.detach().float() # .to(self.device) 126 | motions = torch.cat([motion1, motion2], dim=-1) 127 | motions = motions.detach().to(self.device).float() 128 | 129 | align_idx = np.argsort(motion_lens.data.tolist())[::-1].copy() 130 | motions = motions[align_idx] 131 | motion_lens = motion_lens[align_idx] 132 | text = list(text) 133 | if self.extended: 134 | text_individual1 = list(text_individual1) 135 | text_individual2 = list(text_individual2) 136 | 137 | # Create padding for the motions 138 | B, T = motions.shape[:2] 139 | cur_len = torch.LongTensor([min(T, m_len) for m_len in motion_lens]).to(self.device) 140 | padded_len = cur_len.max() 141 | 142 | # Create batch for feature prediction 143 | batch = {} 144 | batch["text"] = text 145 | batch["motions"] = motions.reshape(B, T, -1)[:, :padded_len] 146 | batch["motion_lens"] = motion_lens 147 | if self.extended: 148 | batch["text_individual1"] = text_individual1 149 | batch["text_individual2"] = text_individual2 150 | 151 | # Motion Encoding 152 | motion_embedding = self.model.encode_motion(batch)['motion_emb'] 153 | 154 | # Text Encoding 155 | text_embedding = self.model.encode_text(batch)['text_emb'][align_idx] 156 | 157 | return text_embedding, motion_embedding 158 | 159 | def get_motion_embeddings(self, batch_data): 160 | """ 161 | Get the embeddings of the motions of a given batch of data. 
162 | :param batch_data: Batch of data to extract the embeddings. 163 | :return: Embeddings of the motions. 164 | Please note that the results does not following the order of inputs 165 | """ 166 | with torch.no_grad(): 167 | # Extract data from the batch provided by the evaluation datasets 168 | if self.extended: 169 | name, text, motion1, motion2, motion_lens, text_individual1, text_individual2 = batch_data 170 | else: 171 | name, text, motion1, motion2, motion_lens = batch_data 172 | 173 | motion1 = motion1.detach().float() # .to(self.device) 174 | motion2 = motion2.detach().float() # .to(self.device) 175 | motions = torch.cat([motion1, motion2], dim=-1) 176 | motions = motions.detach().to(self.device).float() 177 | 178 | align_idx = np.argsort(motion_lens.data.tolist())[::-1].copy() 179 | motions = motions[align_idx] 180 | motion_lens = motion_lens[align_idx] 181 | text = list(text) 182 | 183 | # Create padding for the motions 184 | B, T = motions.shape[:2] 185 | cur_len = torch.LongTensor([min(T, m_len) for m_len in motion_lens]).to(self.device) 186 | padded_len = cur_len.max() 187 | 188 | # Create batch for feature prediction 189 | batch = {} 190 | batch["text"] = text 191 | batch["motions"] = motions.reshape(B, T, -1)[:, :padded_len] 192 | batch["motion_lens"] = motion_lens 193 | if self.extended: 194 | batch["text_individual1"] = text_individual1 195 | batch["text_individual2"] = text_individual2 196 | 197 | # Motion Encoding 198 | motion_embedding = self.model.encode_motion(batch)['motion_emb'] 199 | 200 | return motion_embedding 201 | -------------------------------------------------------------------------------- /in2in/evaluation/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import clip 4 | from in2in.models.utils.utils import PositionalEncoding, set_requires_grad 5 | 6 | 7 | from models import * 8 | 9 | class MotionEncoder(nn.Module): 10 | """ 11 | Motion encoder module for feature extractor evaluation model 12 | """ 13 | def __init__(self, cfg): 14 | """ 15 | Initialize the motion encoder module 16 | :param cfg: model configuration file 17 | """ 18 | super().__init__() 19 | 20 | # Model parameters 21 | self.cfg = cfg 22 | self.input_feats = cfg.INPUT_DIM 23 | self.latent_dim = cfg.LATENT_DIM 24 | self.ff_size = cfg.FF_SIZE 25 | self.num_layers = cfg.NUM_LAYERS 26 | self.num_heads = cfg.NUM_HEADS 27 | self.dropout = cfg.DROPOUT 28 | self.activation = cfg.ACTIVATION 29 | 30 | # Model architecture 31 | self.query_token = nn.Parameter(torch.randn(1, self.latent_dim)) 32 | self.embed_motion = nn.Linear(self.input_feats*2, self.latent_dim) 33 | self.sequence_pos_encoder = PositionalEncoding(self.latent_dim, self.dropout, max_len=2000) 34 | seqTransEncoderLayer = nn.TransformerEncoderLayer(d_model=self.latent_dim, 35 | nhead=self.num_heads, 36 | dim_feedforward=self.ff_size, 37 | dropout=self.dropout, 38 | activation=self.activation, 39 | batch_first=True) 40 | self.transformer = nn.TransformerEncoder(seqTransEncoderLayer, num_layers=self.num_layers) 41 | self.out_ln = nn.LayerNorm(self.latent_dim) 42 | self.out = nn.Linear(self.latent_dim, 512) 43 | 44 | 45 | def forward(self, batch): 46 | """ 47 | Forward pass of the motion encoder module 48 | :param batch: input batch 49 | :return batch: updated batch 50 | """ 51 | 52 | x, mask = batch["motions"], batch["mask"] 53 | B, T, D = x.shape 54 | x = x.reshape(B, T, 2, -1)[..., :-4].reshape(B, T, -1) 55 | 56 | # Embedding 57 | x_emb = 
self.embed_motion(x) 58 | emb = torch.cat([self.query_token[torch.zeros(B, dtype=torch.long, device=x.device)][:,None], x_emb], dim=1) 59 | 60 | # Masking 61 | seq_mask = (mask>0.5) 62 | token_mask = torch.ones((B, 1), dtype=bool, device=x.device) 63 | valid_mask = torch.cat([token_mask, seq_mask], dim=1) 64 | 65 | # Positional encoder and transformer 66 | h = self.sequence_pos_encoder(emb) 67 | h = self.transformer(h, src_key_padding_mask=~valid_mask) 68 | h = self.out_ln(h) 69 | motion_emb = self.out(h[:,0]) 70 | batch["motion_emb"] = motion_emb 71 | 72 | return batch 73 | 74 | class InterCLIP(nn.Module): 75 | """ 76 | InterCLIP model for feature extractor evaluation 77 | It is based in clip model and MotionEncoder 78 | """ 79 | def __init__(self, cfg): 80 | """ 81 | Initialize the InterCLIP model 82 | :param cfg: model configuration file 83 | """ 84 | super().__init__() 85 | 86 | # Model parameters 87 | self.cfg = cfg 88 | self.latent_dim = cfg.LATENT_DIM 89 | self.motion_encoder = MotionEncoder(cfg) 90 | self.latent_dim = self.latent_dim 91 | 92 | # CLIP model 93 | clip_model, _ = clip.load("ViT-L/14@336px", device="cpu", jit=False) 94 | self.token_embedding = clip_model.token_embedding 95 | self.positional_embedding = clip_model.positional_embedding 96 | self.dtype = clip_model.dtype 97 | self.latent_scale = nn.Parameter(torch.Tensor([1])) 98 | set_requires_grad(self.token_embedding, False) 99 | 100 | # Additional text encoding layers 101 | textTransEncoderLayer = nn.TransformerEncoderLayer( 102 | d_model=768, 103 | nhead=8, 104 | dim_feedforward=cfg.FF_SIZE, 105 | dropout=0.1, 106 | activation="gelu", 107 | batch_first=True) 108 | self.textTransEncoder = nn.TransformerEncoder( 109 | textTransEncoderLayer, 110 | num_layers=8) 111 | self.text_ln = nn.LayerNorm(768) 112 | self.out = nn.Linear(768, 512) 113 | 114 | # Losses 115 | self.clip_training = "text_" 116 | self.l1_criterion = torch.nn.L1Loss(reduction='mean') 117 | self.loss_ce = nn.CrossEntropyLoss() 118 | 119 | 120 | def generate_src_mask(self, T, length): 121 | """ 122 | Generate source mask for transformer 123 | :param T: sequence length 124 | :param length: sequence length 125 | :return src_mask: source mask 126 | """ 127 | B = length.shape[0] 128 | src_mask = torch.ones(B, T) 129 | for i in range(B): 130 | for j in range(length[i], T): 131 | src_mask[i, j] = 0 132 | return src_mask 133 | 134 | def encode_motion(self, batch): 135 | """ 136 | Encode motion features 137 | :param batch: input batch 138 | :return batch: updated batch 139 | """ 140 | batch["mask"] = self.generate_src_mask(batch["motions"].shape[1], batch["motion_lens"]).to(batch["motions"].device) 141 | batch.update(self.motion_encoder(batch)) 142 | batch["motion_emb"] = batch["motion_emb"] / batch["motion_emb"].norm(dim=-1, keepdim=True) * self.latent_scale 143 | 144 | return batch 145 | 146 | def encode_text(self, batch): 147 | """ 148 | Encode text features 149 | :param batch: input batch 150 | :return batch: updated batch 151 | """ 152 | device = next(self.parameters()).device 153 | raw_text = batch["text"] 154 | 155 | with torch.no_grad(): 156 | text = clip.tokenize(raw_text, truncate=True).to(device) 157 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] 158 | pe_tokens = x + self.positional_embedding.type(self.dtype) 159 | 160 | out = self.textTransEncoder(pe_tokens) 161 | out = self.text_ln(out) 162 | 163 | out = out[torch.arange(x.shape[0]), text.argmax(dim=-1)] 164 | out = self.out(out) 165 | 166 | # Normalize 167 | 
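# (same treatment as the motion branch in encode_motion: L2-normalise the
#  embedding, then rescale it by the learnable latent_scale)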
batch['text_emb'] = out 168 | batch["text_emb"] = batch["text_emb"] / batch["text_emb"].norm(dim=-1, keepdim=True) * self.latent_scale 169 | 170 | return batch 171 | 172 | 173 | def compute_loss(self, batch): 174 | """ 175 | Wrapper for calculating the loss of the model 176 | :param batch: input batch 177 | """ 178 | 179 | losses = {} 180 | losses["total"] = 0 181 | 182 | # Encode text and motion 183 | batch = self.encode_text(batch) 184 | batch = self.encode_motion(batch) 185 | 186 | # Compute clip losses 187 | mixed_clip_loss, clip_losses = self.compute_clip_losses(batch) 188 | losses.update(clip_losses) 189 | losses["total"] += mixed_clip_loss 190 | 191 | return losses["total"], losses 192 | 193 | def compute_clip_losses(self, batch): 194 | """ 195 | Computing losses from the motion encoder and the text encoder 196 | :param batch: input batch 197 | :return mixed_clip_loss: mixed clip loss 198 | :return clip_losses: clip losses 199 | """ 200 | mixed_clip_loss = 0. 201 | clip_losses = {} 202 | 203 | for d in self.clip_training.split('_')[:1]: 204 | if d == 'image': 205 | features = self.clip_model.encode_image(batch['images']).float() # preprocess is done in dataloader 206 | elif d == 'text': 207 | features = batch['text_emb'] 208 | motion_features = batch['motion_emb'] 209 | 210 | # Normalized features 211 | features_norm = features / features.norm(dim=-1, keepdim=True) 212 | motion_features_norm = motion_features / motion_features.norm(dim=-1, keepdim=True) 213 | 214 | # Compute logits 215 | logit_scale = self.latent_scale ** 2 216 | logits_per_motion = logit_scale * motion_features_norm @ features_norm.t() 217 | logits_per_d = logits_per_motion.t() 218 | 219 | batch_size = motion_features.shape[0] 220 | ground_truth = torch.arange(batch_size, dtype=torch.long, device=motion_features.device) 221 | 222 | # Compute losses 223 | ce_from_motion_loss = self.loss_ce(logits_per_motion, ground_truth) 224 | ce_from_d_loss = self.loss_ce(logits_per_d, ground_truth) 225 | clip_mixed_loss = (ce_from_motion_loss + ce_from_d_loss) / 2. 
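# Symmetric CLIP-style objective: cross-entropy over the motion->text logits and
# their transpose (text->motion), both against the diagonal targets, then averaged.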
226 | 227 | clip_losses[f'{d}_ce_from_d'] = ce_from_d_loss.item() 228 | clip_losses[f'{d}_ce_from_motion'] = ce_from_motion_loss.item() 229 | clip_losses[f'{d}_mixed_ce'] = clip_mixed_loss.item() 230 | mixed_clip_loss += clip_mixed_loss 231 | 232 | return mixed_clip_loss, clip_losses 233 | 234 | def forward(self, batch): 235 | """ 236 | Forward pass of the InterCLIP model 237 | :param batch: input batch 238 | :return batch: updated batch 239 | """ 240 | return self.compute_loss(batch) 241 | 242 | 243 | -------------------------------------------------------------------------------- /in2in/datasets/interhuman.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | import os 4 | 5 | from tqdm import tqdm 6 | from torch.utils import data 7 | from utils.preprocess import load_motion 8 | from os.path import join as pjoin 9 | from utils.quaternion import * 10 | from utils.utils import rigid_transform, process_motion_interhuman 11 | 12 | class InterHuman(data.Dataset): 13 | """ 14 | InterHuman dataset 15 | """ 16 | 17 | def __init__(self, opt, num_samples=-1): 18 | 19 | # Configuration variables 20 | self.opt = opt 21 | self.max_cond_length = 1 22 | self.min_cond_length = 1 23 | self.max_gt_length = 300 24 | self.min_gt_length = 15 25 | self.max_length = self.max_cond_length + self.max_gt_length -1 26 | self.min_length = self.min_cond_length + self.min_gt_length -1 27 | self.motion_rep = opt.MOTION_REP 28 | self.cache = opt.CACHE 29 | self.extended = opt.EXTENDED 30 | 31 | # Data structures 32 | self.motion_dict = {} 33 | self.data_list = [] 34 | data_list = [] 35 | 36 | # Load paths from the given split 37 | if self.opt.MODE == "train": 38 | try: 39 | data_list = open(os.path.join(opt.DATA_ROOT, "split/train.txt"), "r").readlines() 40 | except Exception as e: 41 | print(e) 42 | elif self.opt.MODE == "val": 43 | try: 44 | data_list = open(os.path.join(opt.DATA_ROOT, "split/val.txt"), "r").readlines() 45 | except Exception as e: 46 | print(e) 47 | elif self.opt.MODE == "test": 48 | try: 49 | data_list = open(os.path.join(opt.DATA_ROOT, "split/test.txt"), "r").readlines() 50 | except Exception as e: 51 | print(e) 52 | 53 | # Suffle paths 54 | random.shuffle(data_list) 55 | 56 | if num_samples > 0: 57 | data_list = data_list[:num_samples] 58 | print("Using only {} samples".format(num_samples)) 59 | 60 | # Load data 61 | index = 0 62 | root = pjoin(opt.DATA_ROOT, "motions_processed/person1") 63 | for file in tqdm(os.listdir(root)): 64 | 65 | # Comment if you want to use the whole dataset 66 | if file.split(".")[0]+"\n" not in data_list: 67 | continue 68 | 69 | motion_name = file.split(".")[0] 70 | file_path_person1 = pjoin(root, file) 71 | file_path_person2 = pjoin(root.replace("person1", "person2"), file) 72 | text_path = file_path_person1.replace("motions_processed", "annots").replace("person1", "").replace("npy", "txt") 73 | 74 | # Load interaction texts and make the swaps 75 | texts = [item.replace("\n", "") for item in open(text_path, "r").readlines()] 76 | texts_swap = [item.replace("\n", "").replace("left", "tmp").replace("right", "left").replace("tmp", "right") 77 | .replace("clockwise", "tmp").replace("counterclockwise","clockwise").replace("tmp","counterclockwise") for item in texts] 78 | 79 | # If using extended version, load individual desciptions of the motions 80 | if self.extended: 81 | text_path_individual1 = file_path_person1.replace("motions_processed", "annots_individual").replace("npy", "txt") 82 | 
text_path_individual2 = file_path_person2.replace("motions_processed", "annots_individual").replace("npy", "txt") 83 | 84 | if not os.path.exists(text_path_individual1): 85 | continue 86 | else: 87 | texts_individual1 = [item.replace("\n", "") for item in open(text_path_individual1, "r").readlines()] 88 | texts_individual2 = [item.replace("\n", "") for item in open(text_path_individual2, "r").readlines()] 89 | 90 | # Make the swaps of the individual descriptions 91 | texts_individual1_swap = [item.replace("\n", "").replace("left", "tmp").replace("right", "left").replace("tmp", "right") 92 | .replace("clockwise", "tmp").replace("counterclockwise","clockwise").replace("tmp","counterclockwise") for item in texts_individual2] 93 | texts_individual2_swap = [item.replace("\n", "").replace("left", "tmp").replace("right", "left").replace("tmp", "right") 94 | .replace("clockwise", "tmp").replace("counterclockwise","clockwise").replace("tmp","counterclockwise") for item in texts_individual1] 95 | 96 | # Load motion and check if it is too short and cache it if needed 97 | if self.cache: 98 | motion1, motion1_swap = load_motion(file_path_person1, self.min_length, swap=True) 99 | motion2, motion2_swap = load_motion(file_path_person2, self.min_length, swap=True) 100 | if motion1 is None: 101 | continue 102 | 103 | if self.cache: 104 | self.motion_dict[index] = [motion1, motion2] 105 | self.motion_dict[index+1] = [motion1_swap, motion2_swap] 106 | else: 107 | self.motion_dict[index] = [file_path_person1, file_path_person2] 108 | self.motion_dict[index + 1] = [file_path_person1, file_path_person2] 109 | 110 | # Fill data structures depending on the variable of the dataset used 111 | if self.extended: 112 | self.data_list.append({ 113 | "name": motion_name, 114 | "motion_id": index, 115 | "swap":False, 116 | "texts":texts, 117 | "texts_individual1":texts_individual1_swap, 118 | "texts_individual2":texts_individual2_swap, 119 | }) 120 | 121 | if opt.MODE == "train": 122 | self.data_list.append({ 123 | "name": motion_name+"_swap", 124 | "motion_id": index+1, 125 | "swap": True, 126 | "texts": texts_swap, 127 | "texts_individual1":texts_individual1, 128 | "texts_individual2":texts_individual2, 129 | }) 130 | else: 131 | self.data_list.append({ 132 | "name": motion_name, 133 | "motion_id": index, 134 | "swap":False, 135 | "texts":texts 136 | }) 137 | 138 | if opt.MODE == "train": 139 | self.data_list.append({ 140 | "name": motion_name+"_swap", 141 | "motion_id": index+1, 142 | "swap": True, 143 | "texts": texts_swap, 144 | }) 145 | 146 | index += 2 147 | 148 | print("Total Dataset Size: ", len(self.data_list)) 149 | 150 | def __len__(self): 151 | """ 152 | Get the length of the dataset 153 | """ 154 | return len(self.data_list) 155 | 156 | def __getitem__(self, item): 157 | """ 158 | Get an item from the dataset 159 | param item: Index of the item to get 160 | """ 161 | 162 | # Get the data from the dataset 163 | idx = item % self.__len__() 164 | data = self.data_list[idx] 165 | name = data["name"] 166 | motion_id = data["motion_id"] 167 | swap = data["swap"] 168 | 169 | # Select a random text from the list and if extended also select the individual descriptions 170 | text = random.choice(data["texts"]).strip() 171 | if self.extended: 172 | text_individual1 = random.choice(data["texts_individual1"]).strip() 173 | text_individual2 = random.choice(data["texts_individual2"]).strip() 174 | 175 | # Load the motion 176 | if self.cache: 177 | full_motion1, full_motion2 = self.motion_dict[motion_id] 178 | else: 179 | 
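# Motions were not cached at construction time, so load them from disk on the fly
# (the swapped variant returned by load_motion backs the "*_swap" entries).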
file_path1, file_path2 = self.motion_dict[motion_id] 180 | motion1, motion1_swap = load_motion(file_path1, self.min_length, swap=swap) 181 | motion2, motion2_swap = load_motion(file_path2, self.min_length, swap=swap) 182 | if swap: 183 | full_motion1 = motion1_swap 184 | full_motion2 = motion2_swap 185 | else: 186 | full_motion1 = motion1 187 | full_motion2 = motion2 188 | 189 | # Get motion lenght and select a random segment 190 | length = full_motion1.shape[0] 191 | if length > self.max_length: 192 | idx = random.choice(list(range(0, length - self.max_gt_length, 1))) 193 | gt_length = self.max_gt_length 194 | motion1 = full_motion1[idx:idx + gt_length] 195 | motion2 = full_motion2[idx:idx + gt_length] 196 | else: 197 | idx = 0 198 | gt_length = min(length - idx, self.max_gt_length ) 199 | motion1 = full_motion1[idx:idx + gt_length] 200 | motion2 = full_motion2[idx:idx + gt_length] 201 | 202 | # Swap the motions randomly 203 | if np.random.rand() > 0.5: 204 | motion1, motion2 = motion2, motion1 205 | 206 | # Process the motion 207 | motion1, root_quat_init1, root_pos_init1 = process_motion_interhuman(motion1, 0.001, 0, n_joints=22) 208 | motion2, root_quat_init2, root_pos_init2 = process_motion_interhuman(motion2, 0.001, 0, n_joints=22) 209 | 210 | # Rotate motion 2 211 | r_relative = qmul_np(root_quat_init2, qinv_np(root_quat_init1)) 212 | angle = np.arctan2(r_relative[:, 2:3], r_relative[:, 0:1]) 213 | xz = qrot_np(root_quat_init1, root_pos_init2 - root_pos_init1)[:, [0, 2]] 214 | relative = np.concatenate([angle, xz], axis=-1)[0] 215 | motion2 = rigid_transform(relative, motion2) 216 | 217 | gt_motion1 = motion1 218 | gt_motion2 = motion2 219 | 220 | # Check if the motion is too short and pad it 221 | gt_length = len(gt_motion1) 222 | if gt_length < self.max_gt_length: 223 | padding_len = self.max_gt_length - gt_length 224 | D = gt_motion1.shape[1] 225 | padding_zeros = np.zeros((padding_len, D)) 226 | gt_motion1 = np.concatenate((gt_motion1, padding_zeros), axis=0) 227 | gt_motion2 = np.concatenate((gt_motion2, padding_zeros), axis=0) 228 | 229 | # Return the data 230 | if self.extended: 231 | return name, text, gt_motion1, gt_motion2, gt_length, text_individual1, text_individual2 232 | else: 233 | return name, text, gt_motion1, gt_motion2, gt_length 234 | -------------------------------------------------------------------------------- /in2in/utils/skeleton.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from .paramUtil import FACE_JOINT_INDX, HML_KINEMATIC_CHAIN, HML_RAW_OFFSETS, L_IDX1, L_IDX2 4 | from .quaternion import * 5 | from scipy.ndimage import filters 6 | class Skeleton(object): 7 | """ 8 | Skeleton class to handle the skeleton data such as 9 | offsets, kinematic tree, forward and inverse kinematics 10 | """ 11 | def __init__(self, offset, kinematic_tree, device): 12 | """ 13 | Initialize the Skeleton class 14 | :param offset: Offset of the skeleton that we are using 15 | :param kinematic_tree: Kinematic tree of the skeleton 16 | :param device: Device to run the skeleton 17 | """ 18 | self.device = device 19 | self._raw_offset_np = offset.numpy() 20 | self._raw_offset = offset.clone().detach().to(device).float() 21 | self._kinematic_tree = kinematic_tree 22 | self._offset = None 23 | self._parents = [0] * len(self._raw_offset) 24 | self._parents[0] = -1 25 | for chain in self._kinematic_tree: 26 | for j in range(1, len(chain)): 27 | self._parents[chain[j]] = chain[j-1] 28 | 29 | def 
njoints(self): 30 | return len(self._raw_offset) 31 | 32 | def offset(self): 33 | return self._offset 34 | 35 | def set_offset(self, offsets): 36 | self._offset = offsets.clone().detach().to(self.device).float() 37 | 38 | def kinematic_tree(self): 39 | return self._kinematic_tree 40 | 41 | def parents(self): 42 | return self._parents 43 | 44 | # joints (batch_size, joints_num, 3) 45 | def get_offsets_joints_batch(self, joints): 46 | assert len(joints.shape) == 3 47 | _offsets = self._raw_offset.expand(joints.shape[0], -1, -1).clone() 48 | for i in range(1, self._raw_offset.shape[0]): 49 | _offsets[:, i] = torch.norm(joints[:, i] - joints[:, self._parents[i]], p=2, dim=1)[:, None] * _offsets[:, i] 50 | 51 | self._offset = _offsets.detach() 52 | return _offsets 53 | 54 | # joints (joints_num, 3) 55 | def get_offsets_joints(self, joints): 56 | assert len(joints.shape) == 2 57 | _offsets = self._raw_offset.clone() 58 | for i in range(1, self._raw_offset.shape[0]): 59 | # print(joints.shape) 60 | _offsets[i] = torch.norm(joints[i] - joints[self._parents[i]], p=2, dim=0) * _offsets[i] 61 | 62 | self._offset = _offsets.detach() 63 | return _offsets 64 | 65 | # face_joint_idx should follow the order of right hip, left hip, right shoulder, left shoulder 66 | # joints (batch_size, joints_num, 3) 67 | def inverse_kinematics_np(self, joints, face_joint_idx, smooth_forward=False): 68 | assert len(face_joint_idx) == 4 69 | # Get Forward Direction 70 | l_hip, r_hip, sdr_r, sdr_l = face_joint_idx 71 | across1 = joints[:, r_hip] - joints[:, l_hip] 72 | across2 = joints[:, sdr_r] - joints[:, sdr_l] 73 | across = across1 + across2 74 | across = across / np.sqrt((across**2).sum(axis=-1))[:, np.newaxis] 75 | 76 | # forward (batch_size, 3) 77 | forward = np.cross(np.array([[0, 1, 0]]), across, axis=-1) 78 | if smooth_forward: 79 | forward = filters.gaussian_filter1d(forward, 20, axis=0, mode='nearest') 80 | # forward (batch_size, 3) 81 | forward = forward / np.sqrt((forward**2).sum(axis=-1))[..., np.newaxis] 82 | 83 | # Get Root Rotation 84 | target = np.array([[0,0,1]]).repeat(len(forward), axis=0) 85 | root_quat = qbetween_np(forward, target) 86 | 87 | # Inverse Kinematics 88 | # quat_params (batch_size, joints_num, 4) 89 | quat_params = np.zeros(joints.shape[:-1] + (4,)) 90 | root_quat[0] = np.array([[1.0, 0.0, 0.0, 0.0]]) 91 | quat_params[:, 0] = root_quat 92 | for chain in self._kinematic_tree: 93 | R = root_quat 94 | for j in range(len(chain) - 1): 95 | # (batch, 3) 96 | u = self._raw_offset_np[chain[j+1]][np.newaxis,...].repeat(len(joints), axis=0) 97 | # (batch, 3) 98 | v = joints[:, chain[j+1]] - joints[:, chain[j]] 99 | v = v / np.sqrt((v**2).sum(axis=-1))[:, np.newaxis] 100 | rot_u_v = qbetween_np(u, v) 101 | R_loc = qmul_np(qinv_np(R), rot_u_v) 102 | quat_params[:,chain[j + 1], :] = R_loc 103 | R = qmul_np(R, R_loc) 104 | 105 | return quat_params 106 | 107 | # Be sure root joint is at the beginning of kinematic chains 108 | def forward_kinematics(self, quat_params, root_pos, skel_joints=None, do_root_R=True): 109 | # quat_params (batch_size, joints_num, 4) 110 | # joints (batch_size, joints_num, 3) 111 | # root_pos (batch_size, 3) 112 | if skel_joints is not None: 113 | offsets = self.get_offsets_joints_batch(skel_joints) 114 | if len(self._offset.shape) == 2: 115 | offsets = self._offset.expand(quat_params.shape[0], -1, -1) 116 | joints = torch.zeros(quat_params.shape[:-1] + (3,)).to(self.device) 117 | joints[:, 0] = root_pos 118 | for chain in self._kinematic_tree: 119 | if do_root_R: 120 | R = 
quat_params[:, 0] 121 | else: 122 | R = torch.tensor([[1.0, 0.0, 0.0, 0.0]]).expand(len(quat_params), -1).detach().to(self.device) 123 | for i in range(1, len(chain)): 124 | R = qmul(R, quat_params[:, chain[i]]) 125 | offset_vec = offsets[:, chain[i]] 126 | joints[:, chain[i]] = qrot(R, offset_vec) + joints[:, chain[i-1]] 127 | return joints 128 | 129 | # Be sure root joint is at the beginning of kinematic chains 130 | def forward_kinematics_np(self, quat_params, root_pos, skel_joints=None, do_root_R=True): 131 | # quat_params (batch_size, joints_num, 4) 132 | # joints (batch_size, joints_num, 3) 133 | # root_pos (batch_size, 3) 134 | if skel_joints is not None: 135 | skel_joints = torch.from_numpy(skel_joints) 136 | offsets = self.get_offsets_joints_batch(skel_joints) 137 | if len(self._offset.shape) == 2: 138 | offsets = self._offset.expand(quat_params.shape[0], -1, -1) 139 | offsets = offsets.numpy() 140 | joints = np.zeros(quat_params.shape[:-1] + (3,)) 141 | joints[:, 0] = root_pos 142 | for chain in self._kinematic_tree: 143 | if do_root_R: 144 | R = quat_params[:, 0] 145 | else: 146 | R = np.array([[1.0, 0.0, 0.0, 0.0]]).repeat(len(quat_params), axis=0) 147 | for i in range(1, len(chain)): 148 | R = qmul_np(R, quat_params[:, chain[i]]) 149 | offset_vec = offsets[:, chain[i]] 150 | joints[:, chain[i]] = qrot_np(R, offset_vec) + joints[:, chain[i - 1]] 151 | return joints 152 | 153 | def forward_kinematics_cont6d_np(self, cont6d_params, root_pos, skel_joints=None, do_root_R=True): 154 | # cont6d_params (batch_size, joints_num, 6) 155 | # joints (batch_size, joints_num, 3) 156 | # root_pos (batch_size, 3) 157 | if skel_joints is not None: 158 | skel_joints = torch.from_numpy(skel_joints) 159 | offsets = self.get_offsets_joints_batch(skel_joints) 160 | if len(self._offset.shape) == 2: 161 | offsets = self._offset.expand(cont6d_params.shape[0], -1, -1) 162 | offsets = offsets.numpy() 163 | joints = np.zeros(cont6d_params.shape[:-1] + (3,)) 164 | joints[:, 0] = root_pos 165 | for chain in self._kinematic_tree: 166 | if do_root_R: 167 | matR = cont6d_to_matrix_np(cont6d_params[:, 0]) 168 | else: 169 | matR = np.eye(3)[np.newaxis, :].repeat(len(cont6d_params), axis=0) 170 | for i in range(1, len(chain)): 171 | matR = np.matmul(matR, cont6d_to_matrix_np(cont6d_params[:, chain[i]])) 172 | offset_vec = offsets[:, chain[i]][..., np.newaxis] 173 | joints[:, chain[i]] = np.matmul(matR, offset_vec).squeeze(-1) + joints[:, chain[i-1]] 174 | return joints 175 | 176 | def forward_kinematics_cont6d(self, cont6d_params, root_pos, skel_joints=None, do_root_R=True): 177 | # cont6d_params (batch_size, joints_num, 6) 178 | # joints (batch_size, joints_num, 3) 179 | # root_pos (batch_size, 3) 180 | if skel_joints is not None: 181 | # skel_joints = torch.from_numpy(skel_joints) 182 | offsets = self.get_offsets_joints_batch(skel_joints) 183 | if len(self._offset.shape) == 2: 184 | offsets = self._offset.expand(cont6d_params.shape[0], -1, -1) 185 | joints = torch.zeros(cont6d_params.shape[:-1] + (3,)).to(cont6d_params.device) 186 | joints[..., 0, :] = root_pos 187 | for chain in self._kinematic_tree: 188 | if do_root_R: 189 | matR = cont6d_to_matrix(cont6d_params[:, 0]) 190 | else: 191 | matR = torch.eye(3).expand((len(cont6d_params), -1, -1)).detach().to(cont6d_params.device) 192 | for i in range(1, len(chain)): 193 | matR = torch.matmul(matR, cont6d_to_matrix(cont6d_params[:, chain[i]])) 194 | offset_vec = offsets[:, chain[i]].unsqueeze(-1) 195 | joints[:, chain[i]] = torch.matmul(matR, 
offset_vec).squeeze(-1) + joints[:, chain[i-1]] 196 | return joints 197 | 198 | 199 | def uniform_skeleton(positions, target_skeleton_path="data/motions_processed/person1/1.npy"): 200 | """ 201 | Create a uniform skeleton from the source skeleton to the target skeleton 202 | :param positions: Source Skeleton motion 203 | :param target_skeleton_path: Target Skeleton path of the motion 204 | :return: New joints of the source skeleton in the target skeleton format 205 | """ 206 | 207 | # Target Skeleton 208 | n_raw_offsets = torch.from_numpy(HML_RAW_OFFSETS) 209 | kinematic_chain = HML_KINEMATIC_CHAIN 210 | example_data = np.load(target_skeleton_path) 211 | example_data = example_data.reshape(len(example_data), -1, 3) 212 | example_data = torch.from_numpy(example_data) 213 | target_skel = Skeleton(n_raw_offsets, kinematic_chain, 'cpu') 214 | target_offset = target_skel.get_offsets_joints(example_data[0]) 215 | 216 | # Source Skeleton 217 | src_skel = Skeleton(n_raw_offsets, kinematic_chain, 'cpu') 218 | src_offset = src_skel.get_offsets_joints(torch.from_numpy(positions[0])) 219 | src_offset = src_offset.numpy() 220 | tgt_offset = target_offset.numpy() 221 | 222 | # Calculate Scale Ratio as the ratio of legs 223 | src_leg_len = np.abs(src_offset[L_IDX1]).max() + np.abs(src_offset[L_IDX2]).max() 224 | tgt_leg_len = np.abs(tgt_offset[L_IDX1]).max() + np.abs(tgt_offset[L_IDX2]).max() 225 | 226 | scale_rt = tgt_leg_len / src_leg_len 227 | src_root_pos = positions[:, 0] 228 | tgt_root_pos = src_root_pos * scale_rt 229 | 230 | # Inverse Kinematics 231 | quat_params = src_skel.inverse_kinematics_np(positions, FACE_JOINT_INDX) 232 | 233 | # Forward Kinematics 234 | src_skel.set_offset(target_offset) 235 | new_joints = src_skel.forward_kinematics_np(quat_params, tgt_root_pos) 236 | return new_joints -------------------------------------------------------------------------------- /in2in/scripts/eval/interhuman.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append(sys.path[0]+r"/../../../") 3 | 4 | import numpy as np 5 | import torch 6 | 7 | from datetime import datetime 8 | from in2in.evaluation.utils import get_dataset_motion_loader, get_motion_loader_in2IN, EvaluatorModelWrapper 9 | from in2in.utils.metrics import * 10 | from collections import OrderedDict 11 | from in2in.utils.plot import * 12 | from in2in.utils.utils import * 13 | from in2in.utils.configs import get_config 14 | from tqdm import tqdm 15 | from in2in.models.dualmdm import load_DualMDM_model 16 | from in2in.models.in2in import in2IN 17 | 18 | import argparse 19 | 20 | def evaluate_matching_score(motion_loaders, file): 21 | match_score_dict = OrderedDict({}) 22 | R_precision_dict = OrderedDict({}) 23 | activation_dict = OrderedDict({}) 24 | print('========== Evaluating MM Distance ==========') 25 | for motion_loader_name, motion_loader in motion_loaders.items(): 26 | all_motion_embeddings = [] 27 | score_list = [] 28 | all_size = 0 29 | mm_dist_sum = 0 30 | top_k_count = 0 31 | with torch.no_grad(): 32 | for idx, batch in tqdm(enumerate(motion_loader)): 33 | text_embeddings, motion_embeddings = eval_wrapper.get_co_embeddings(batch) 34 | dist_mat = euclidean_distance_matrix(text_embeddings.cpu().numpy(), 35 | motion_embeddings.cpu().numpy()) 36 | mm_dist_sum += dist_mat.trace() 37 | 38 | argsmax = np.argsort(dist_mat, axis=1) 39 | top_k_mat = calculate_top_k(argsmax, top_k=3) 40 | top_k_count += top_k_mat.sum(axis=0) 41 | 42 | all_size += text_embeddings.shape[0] 43 | 44 
| all_motion_embeddings.append(motion_embeddings.cpu().numpy()) 45 | 46 | all_motion_embeddings = np.concatenate(all_motion_embeddings, axis=0) 47 | mm_dist = mm_dist_sum / all_size 48 | R_precision = top_k_count / all_size 49 | match_score_dict[motion_loader_name] = mm_dist 50 | R_precision_dict[motion_loader_name] = R_precision 51 | activation_dict[motion_loader_name] = all_motion_embeddings 52 | 53 | print(f'---> [{motion_loader_name}] MM Distance: {mm_dist:.4f}') 54 | print(f'---> [{motion_loader_name}] MM Distance: {mm_dist:.4f}', file=file, flush=True) 55 | 56 | line = f'---> [{motion_loader_name}] R_precision: ' 57 | for i in range(len(R_precision)): 58 | line += '(top %d): %.4f ' % (i+1, R_precision[i]) 59 | print(line) 60 | print(line, file=file, flush=True) 61 | 62 | return match_score_dict, R_precision_dict, activation_dict 63 | 64 | 65 | def evaluate_fid(groundtruth_loader, activation_dict, file): 66 | eval_dict = OrderedDict({}) 67 | gt_motion_embeddings = [] 68 | print('========== Evaluating FID ==========') 69 | with torch.no_grad(): 70 | for idx, batch in tqdm(enumerate(groundtruth_loader)): 71 | motion_embeddings = eval_wrapper.get_motion_embeddings(batch) 72 | gt_motion_embeddings.append(motion_embeddings.cpu().numpy()) 73 | gt_motion_embeddings = np.concatenate(gt_motion_embeddings, axis=0) 74 | gt_mu, gt_cov = calculate_activation_statistics(gt_motion_embeddings) 75 | 76 | for model_name, motion_embeddings in activation_dict.items(): 77 | mu, cov = calculate_activation_statistics(motion_embeddings) 78 | fid = calculate_frechet_distance(gt_mu, gt_cov, mu, cov) 79 | print(f'---> [{model_name}] FID: {fid:.4f}') 80 | print(f'---> [{model_name}] FID: {fid:.4f}', file=file, flush=True) 81 | eval_dict[model_name] = fid 82 | return eval_dict 83 | 84 | 85 | def evaluate_diversity(activation_dict, file): 86 | eval_dict = OrderedDict({}) 87 | print('========== Evaluating Diversity ==========') 88 | for model_name, motion_embeddings in activation_dict.items(): 89 | diversity = calculate_diversity(motion_embeddings, diversity_times) 90 | eval_dict[model_name] = diversity 91 | print(f'---> [{model_name}] Diversity: {diversity:.4f}') 92 | print(f'---> [{model_name}] Diversity: {diversity:.4f}', file=file, flush=True) 93 | return eval_dict 94 | 95 | 96 | def evaluate_multimodality(mm_motion_loaders, file): 97 | eval_dict = OrderedDict({}) 98 | print('========== Evaluating MultiModality ==========') 99 | for model_name, mm_motion_loader in mm_motion_loaders.items(): 100 | mm_motion_embeddings = [] 101 | with torch.no_grad(): 102 | for idx, batch in enumerate(mm_motion_loader): 103 | # (1, mm_replications, dim_pos) 104 | batch[2] = batch[2][0] 105 | batch[3] = batch[3][0] 106 | batch[4] = batch[4][0] 107 | motion_embedings = eval_wrapper.get_motion_embeddings(batch) 108 | mm_motion_embeddings.append(motion_embedings.unsqueeze(0)) 109 | if len(mm_motion_embeddings) == 0: 110 | multimodality = 0 111 | else: 112 | mm_motion_embeddings = torch.cat(mm_motion_embeddings, dim=0).cpu().numpy() 113 | multimodality = calculate_multimodality(mm_motion_embeddings, mm_num_times) 114 | print(f'---> [{model_name}] Multimodality: {multimodality:.4f}') 115 | print(f'---> [{model_name}] Multimodality: {multimodality:.4f}', file=file, flush=True) 116 | eval_dict[model_name] = multimodality 117 | return eval_dict 118 | 119 | 120 | def get_metric_statistics(values): 121 | mean = np.mean(values, axis=0) 122 | std = np.std(values, axis=0) 123 | conf_interval = 1.96 * std / np.sqrt(replication_times) 124 | 
return mean, conf_interval 125 | 126 | 127 | def evaluation(log_file): 128 | with open(log_file, 'w') as f: 129 | all_metrics = OrderedDict({'MM Distance': OrderedDict({}), 130 | 'R_precision': OrderedDict({}), 131 | 'FID': OrderedDict({}), 132 | 'Diversity': OrderedDict({}), 133 | 'MultiModality': OrderedDict({})}) 134 | for replication in range(replication_times): 135 | motion_loaders = {} 136 | mm_motion_loaders = {} 137 | motion_loaders['ground truth'] = gt_loader 138 | for motion_loader_name, motion_loader_getter in eval_motion_loaders.items(): 139 | motion_loader, mm_motion_loader = motion_loader_getter() 140 | motion_loaders[motion_loader_name] = motion_loader 141 | mm_motion_loaders[motion_loader_name] = mm_motion_loader 142 | 143 | print(f'==================== Replication {replication} ====================') 144 | print(f'==================== Replication {replication} ====================', file=f, flush=True) 145 | print(f'Time: {datetime.now()}') 146 | print(f'Time: {datetime.now()}', file=f, flush=True) 147 | mat_score_dict, R_precision_dict, acti_dict = evaluate_matching_score(motion_loaders, f) 148 | 149 | print(f'Time: {datetime.now()}') 150 | print(f'Time: {datetime.now()}', file=f, flush=True) 151 | fid_score_dict = evaluate_fid(gt_loader, acti_dict, f) 152 | 153 | print(f'Time: {datetime.now()}') 154 | print(f'Time: {datetime.now()}', file=f, flush=True) 155 | div_score_dict = evaluate_diversity(acti_dict, f) 156 | 157 | print(f'Time: {datetime.now()}') 158 | print(f'Time: {datetime.now()}', file=f, flush=True) 159 | mm_score_dict = evaluate_multimodality(mm_motion_loaders, f) 160 | 161 | print(f'!!! DONE !!!') 162 | print(f'!!! DONE !!!', file=f, flush=True) 163 | 164 | for key, item in mat_score_dict.items(): 165 | if key not in all_metrics['MM Distance']: 166 | all_metrics['MM Distance'][key] = [item] 167 | else: 168 | all_metrics['MM Distance'][key] += [item] 169 | 170 | for key, item in R_precision_dict.items(): 171 | if key not in all_metrics['R_precision']: 172 | all_metrics['R_precision'][key] = [item] 173 | else: 174 | all_metrics['R_precision'][key] += [item] 175 | 176 | for key, item in fid_score_dict.items(): 177 | if key not in all_metrics['FID']: 178 | all_metrics['FID'][key] = [item] 179 | else: 180 | all_metrics['FID'][key] += [item] 181 | 182 | for key, item in div_score_dict.items(): 183 | if key not in all_metrics['Diversity']: 184 | all_metrics['Diversity'][key] = [item] 185 | else: 186 | all_metrics['Diversity'][key] += [item] 187 | 188 | for key, item in mm_score_dict.items(): 189 | if key not in all_metrics['MultiModality']: 190 | all_metrics['MultiModality'][key] = [item] 191 | else: 192 | all_metrics['MultiModality'][key] += [item] 193 | 194 | for metric_name, metric_dict in all_metrics.items(): 195 | print('========== %s Summary ==========' % metric_name) 196 | print('========== %s Summary ==========' % metric_name, file=f, flush=True) 197 | 198 | for model_name, values in metric_dict.items(): 199 | mean, conf_interval = get_metric_statistics(np.array(values)) 200 | if isinstance(mean, np.float64) or isinstance(mean, np.float32): 201 | print(f'---> [{model_name}] Mean: {mean:.4f} CInterval: {conf_interval:.4f}') 202 | print(f'---> [{model_name}] Mean: {mean:.4f} CInterval: {conf_interval:.4f}', file=f, flush=True) 203 | elif isinstance(mean, np.ndarray): 204 | line = f'---> [{model_name}]' 205 | for i in range(len(mean)): 206 | line += '(top %d) Mean: %.4f CInt: %.4f;' % (i+1, mean[i], conf_interval[i]) 207 | print(line) 208 | print(line, file=f, 
flush=True) 209 | 210 | 211 | if __name__ == '__main__': 212 | 213 | # Create the parser 214 | parser = argparse.ArgumentParser(description="Argparse example with optional arguments") 215 | 216 | # Add optional arguments 217 | parser.add_argument('--model', type=str, required=True, help='Model Configuration file') 218 | parser.add_argument('--evaluator', type=str, required=True, help='Evaluator Configuration file') 219 | parser.add_argument('--out', type=str, required=True, help='Out file') 220 | parser.add_argument('--device', type=int, default=0, help='GPU device id') 221 | parser.add_argument('--mode', type=str, required=True, help='Mode of the inference (interaction, dual)') 222 | 223 | 224 | # Parse the arguments 225 | args = parser.parse_args() 226 | 227 | mm_num_samples = 100 228 | mm_num_repeats = 30 229 | mm_num_times = 10 230 | 231 | diversity_times = 300 232 | replication_times = 10 233 | 234 | # batch_size is fixed to 96!! 235 | batch_size = 96 236 | 237 | data_cfg = get_config("configs/datasets.yaml").interhuman_test 238 | 239 | eval_motion_loaders = {} 240 | model_cfg = get_config( args.model) 241 | device = torch.device('cuda:%d' % args.device if torch.cuda.is_available() else 'cpu') 242 | torch.cuda.set_device(args.device) 243 | 244 | if args.mode == "dual": 245 | model = load_DualMDM_model(model_cfg) 246 | elif args.mode == "interaction": 247 | model = in2IN(model_cfg, args.mode) 248 | model.load_state_dict(torch.load(model_cfg.CHECKPOINT)) 249 | 250 | eval_motion_loaders[model_cfg.NAME] = lambda: get_motion_loader_in2IN( 251 | batch_size, 252 | model, 253 | gt_dataset, 254 | device, 255 | mm_num_samples, 256 | mm_num_repeats 257 | ) 258 | 259 | device = torch.device('cuda:%d' % args.device if torch.cuda.is_available() else 'cpu') 260 | gt_loader, gt_dataset = get_dataset_motion_loader(data_cfg, batch_size) 261 | evalmodel_cfg = get_config(args.evaluator) 262 | eval_wrapper = EvaluatorModelWrapper(evalmodel_cfg, device) 263 | log_file = args.out 264 | evaluation(log_file) 265 | -------------------------------------------------------------------------------- /in2in/utils/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy import linalg 3 | from typing import Union 4 | import pykeops.torch as keops 5 | import torch 6 | import tqdm 7 | 8 | emb_scale = 6 9 | 10 | # (X - X_train)*(X - X_train) = -2X*X_train + X*X + X_train*X_train 11 | def euclidean_distance_matrix(matrix1, matrix2): 12 | """ 13 | Params: 14 | -- matrix1: N1 x D 15 | -- matrix2: N2 x D 16 | Returns: 17 | -- dist: N1 x N2 18 | dist[i, j] == distance(matrix1[i], matrix2[j]) 19 | """ 20 | assert matrix1.shape[1] == matrix2.shape[1] 21 | d1 = -2 * np.dot(matrix1, matrix2.T) # shape (num_test, num_train) 22 | d2 = np.sum(np.square(matrix1), axis=1, keepdims=True) # shape (num_test, 1) 23 | d3 = np.sum(np.square(matrix2), axis=1) # shape (num_train, ) 24 | dists = np.sqrt(d1 + d2 + d3) # broadcasting 25 | return dists 26 | 27 | def calculate_top_k(mat, top_k): 28 | size = mat.shape[0] 29 | gt_mat = np.expand_dims(np.arange(size), 1).repeat(size, 1) 30 | bool_mat = (mat == gt_mat) 31 | correct_vec = False 32 | top_k_list = [] 33 | for i in range(top_k): 34 | # print(correct_vec, bool_mat[:, i]) 35 | correct_vec = (correct_vec | bool_mat[:, i]) 36 | # print(correct_vec) 37 | top_k_list.append(correct_vec[:, None]) 38 | top_k_mat = np.concatenate(top_k_list, axis=1) 39 | return top_k_mat 40 | 41 | 42 | def calculate_R_precision(embedding1, 
embedding2, top_k, sum_all=False): 43 | dist_mat = euclidean_distance_matrix(embedding1, embedding2) 44 | argmax = np.argsort(dist_mat, axis=1) 45 | top_k_mat = calculate_top_k(argmax, top_k) 46 | if sum_all: 47 | return top_k_mat.sum(axis=0) 48 | else: 49 | return top_k_mat 50 | 51 | 52 | def calculate_matching_score(embedding1, embedding2, sum_all=False): 53 | assert len(embedding1.shape) == 2 54 | assert embedding1.shape[0] == embedding2.shape[0] 55 | assert embedding1.shape[1] == embedding2.shape[1] 56 | 57 | dist = linalg.norm(embedding1 - embedding2, axis=1) 58 | if sum_all: 59 | return dist.sum(axis=0) 60 | else: 61 | return dist 62 | 63 | def calculate_activation_statistics(activations): 64 | """ 65 | Params: 66 | -- activation: num_samples x dim_feat 67 | Returns: 68 | -- mu: dim_feat 69 | -- sigma: dim_feat x dim_feat 70 | """ 71 | activations = activations * emb_scale 72 | mu = np.mean(activations, axis=0) 73 | cov = np.cov(activations, rowvar=False) 74 | return mu, cov 75 | 76 | 77 | def calculate_diversity(activation, diversity_times): 78 | assert len(activation.shape) == 2 79 | assert activation.shape[0] > diversity_times 80 | num_samples = activation.shape[0] 81 | 82 | activation = activation * emb_scale 83 | first_indices = np.random.choice(num_samples, diversity_times, replace=False) 84 | second_indices = np.random.choice(num_samples, diversity_times, replace=False) 85 | dist = linalg.norm((activation[first_indices] - activation[second_indices])/2, axis=1) 86 | return dist.mean() 87 | 88 | 89 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): 90 | """Numpy implementation of the Frechet Distance. 91 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) 92 | and X_2 ~ N(mu_2, C_2) is 93 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). 94 | Stable version by Dougal J. Sutherland. 95 | Params: 96 | -- mu1 : Numpy array containing the activations of a layer of the 97 | inception net (like returned by the function 'get_predictions') 98 | for generated samples. 99 | -- mu2 : The sample mean over activations, precalculated on an 100 | representative data set. 101 | -- sigma1: The covariance matrix over activations for generated samples. 102 | -- sigma2: The covariance matrix over activations, precalculated on an 103 | representative data set. 104 | Returns: 105 | -- : The Frechet Distance. 
106 | """ 107 | 108 | mu1 = np.atleast_1d(mu1) 109 | mu2 = np.atleast_1d(mu2) 110 | 111 | sigma1 = np.atleast_2d(sigma1) 112 | sigma2 = np.atleast_2d(sigma2) 113 | 114 | assert mu1.shape == mu2.shape, \ 115 | 'Training and test mean vectors have different lengths' 116 | assert sigma1.shape == sigma2.shape, \ 117 | 'Training and test covariances have different dimensions' 118 | 119 | diff = mu1 - mu2 120 | 121 | # Product might be almost singular 122 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) 123 | if not np.isfinite(covmean).all(): 124 | msg = ('fid calculation produces singular product; ' 125 | 'adding %s to diagonal of cov estimates') % eps 126 | print(msg) 127 | offset = np.eye(sigma1.shape[0]) * eps 128 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 129 | 130 | # Numerical error might give slight imaginary component 131 | if np.iscomplexobj(covmean): 132 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 133 | m = np.max(np.abs(covmean.imag)) 134 | raise ValueError('Imaginary component {}'.format(m)) 135 | covmean = covmean.real 136 | 137 | tr_covmean = np.trace(covmean) 138 | 139 | return (diff.dot(diff) + np.trace(sigma1) + 140 | np.trace(sigma2) - 2 * tr_covmean) 141 | 142 | 143 | def calculate_multimodality(activation, multimodality_times): 144 | assert len(activation.shape) == 3 145 | assert activation.shape[1] > multimodality_times 146 | num_per_sent = activation.shape[1] 147 | 148 | first_dices = np.random.choice(num_per_sent, multimodality_times, replace=False) 149 | second_dices = np.random.choice(num_per_sent, multimodality_times, replace=False) 150 | dist = linalg.norm(activation[:, first_dices] - activation[:, second_dices], axis=2) 151 | return dist.mean() 152 | 153 | def calculate_wasserstein(x: torch.Tensor, y: torch.Tensor, p: float = 2, 154 | w_x: Union[torch.Tensor, None] = None, 155 | w_y: Union[torch.Tensor, None] = None, 156 | eps: float = 1e-3, 157 | max_iters: int = 100, stop_thresh: float = 1e-5, 158 | verbose=False): 159 | """ 160 | Compute the Entropy-Regularized p-Wasserstein Distance between two d-dimensional point clouds 161 | using the Sinkhorn scaling algorithm. This code will use the GPU if you pass in GPU tensors. 162 | Note that this algorithm can be backpropped through 163 | (though this may be slow if using many iterations). 164 | 165 | :param x: A [n, d] tensor representing a d-dimensional point cloud with n points (one per row) 166 | :param y: A [m, d] tensor representing a d-dimensional point cloud with m points (one per row) 167 | :param p: Which norm to use. Must be an integer greater than 0. 168 | :param w_x: A [n,] shaped tensor of optional weights for the points x (None for uniform weights). Note that these must sum to the same value as w_y. Default is None. 169 | :param w_y: A [m,] shaped tensor of optional weights for the points y (None for uniform weights). Note that these must sum to the same value as w_y. Default is None. 170 | :param eps: The reciprocal of the sinkhorn entropy regularization parameter. 171 | :param max_iters: The maximum number of Sinkhorn iterations to perform. 
172 | :param stop_thresh: Stop if the maximum change in the parameters is below this amount 173 | :param verbose: Print iterations 174 | :return: a triple (d, corrs_x_to_y, corr_y_to_x) where: 175 | * d is the approximate p-wasserstein distance between point clouds x and y 176 | * corrs_x_to_y is a [n,]-shaped tensor where corrs_x_to_y[i] is the index of the approximate correspondence in point cloud y of point x[i] (i.e. x[i] and y[corrs_x_to_y[i]] are a corresponding pair) 177 | * corrs_y_to_x is a [m,]-shaped tensor where corrs_y_to_x[i] is the index of the approximate correspondence in point cloud x of point y[j] (i.e. y[j] and x[corrs_y_to_x[j]] are a corresponding pair) 178 | """ 179 | 180 | if not isinstance(p, int): 181 | raise TypeError(f"p must be an integer greater than 0, got {p}") 182 | if p <= 0: 183 | raise ValueError(f"p must be an integer greater than 0, got {p}") 184 | 185 | if eps <= 0: 186 | raise ValueError("Entropy regularization term eps must be > 0") 187 | 188 | if not isinstance(p, int): 189 | raise TypeError(f"max_iters must be an integer > 0, got {max_iters}") 190 | if max_iters <= 0: 191 | raise ValueError(f"max_iters must be an integer > 0, got {max_iters}") 192 | 193 | if not isinstance(stop_thresh, float): 194 | raise TypeError(f"stop_thresh must be a float, got {stop_thresh}") 195 | 196 | if len(x.shape) != 2: 197 | raise ValueError(f"x must be an [n, d] tensor but got shape {x.shape}") 198 | if len(y.shape) != 2: 199 | raise ValueError(f"x must be an [m, d] tensor but got shape {y.shape}") 200 | if x.shape[1] != y.shape[1]: 201 | raise ValueError(f"x and y must match in the last dimension (i.e. x.shape=[n, d], " 202 | f"y.shape[m, d]) but got x.shape = {x.shape}, y.shape={y.shape}") 203 | 204 | if w_x is not None: 205 | if w_y is None: 206 | raise ValueError("If w_x is not None, w_y must also be not None") 207 | if len(w_x.shape) > 1: 208 | w_x = w_x.squeeze() 209 | if len(w_x.shape) != 1: 210 | raise ValueError(f"w_x must have shape [n,] or [n, 1] " 211 | f"where x.shape = [n, d], but got w_x.shape = {w_x.shape}") 212 | if w_x.shape[0] != x.shape[0]: 213 | raise ValueError(f"w_x must match the shape of x in dimension 0 but got " 214 | f"x.shape = {x.shape} and w_x.shape = {w_x.shape}") 215 | if w_y is not None: 216 | if w_x is None: 217 | raise ValueError("If w_y is not None, w_x must also be not None") 218 | if len(w_y.shape) > 1: 219 | w_y = w_y.squeeze() 220 | if len(w_y.shape) != 1: 221 | raise ValueError(f"w_y must have shape [n,] or [n, 1] " 222 | f"where x.shape = [n, d], but got w_y.shape = {w_y.shape}") 223 | if w_x.shape[0] != x.shape[0]: 224 | raise ValueError(f"w_y must match the shape of y in dimension 0 but got " 225 | f"y.shape = {y.shape} and w_y.shape = {w_y.shape}") 226 | 227 | 228 | # Distance matrix [n, m] 229 | x_i = keops.Vi(x) # [n, 1, d] 230 | y_j = keops.Vj(y) # [i, m, d] 231 | if p == 1: 232 | M_ij = ((x_i - y_j) ** p).abs().sum(dim=2) # [n, m] 233 | else: 234 | M_ij = ((x_i - y_j) ** p).sum(dim=2) ** (1.0 / p) # [n, m] 235 | 236 | # Weights [n,] and [m,] 237 | if w_x is None and w_y is None: 238 | w_x = torch.ones(x.shape[0]).to(x) / x.shape[0] 239 | w_y = torch.ones(y.shape[0]).to(x) / y.shape[0] 240 | w_y *= (w_x.shape[0] / w_y.shape[0]) 241 | 242 | sum_w_x = w_x.sum().item() 243 | sum_w_y = w_y.sum().item() 244 | if abs(sum_w_x - sum_w_y) > 1e-5: 245 | raise ValueError(f"Weights w_x and w_y do not sum to the same value, " 246 | f"got w_x.sum() = {sum_w_x} and w_y.sum() = {sum_w_y} " 247 | f"(absolute difference = {abs(sum_w_x - 
sum_w_y)}") 248 | 249 | log_a = torch.log(w_x) # [n] 250 | log_b = torch.log(w_y) # [m] 251 | 252 | # Initialize the iteration with the change of variable 253 | u = torch.zeros_like(w_x) 254 | v = eps * torch.log(w_y) 255 | 256 | u_i = keops.Vi(u.unsqueeze(-1)) 257 | v_j = keops.Vj(v.unsqueeze(-1)) 258 | 259 | if verbose: 260 | pbar = tqdm.trange(max_iters) 261 | else: 262 | pbar = range(max_iters) 263 | 264 | for _ in pbar: 265 | u_prev = u 266 | v_prev = v 267 | 268 | summand_u = (-M_ij + v_j) / eps 269 | u = eps * (log_a - summand_u.logsumexp(dim=1).squeeze()) 270 | u_i = keops.Vi(u.unsqueeze(-1)) 271 | 272 | summand_v = (-M_ij + u_i) / eps 273 | v = eps * (log_b - summand_v.logsumexp(dim=0).squeeze()) 274 | v_j = keops.Vj(v.unsqueeze(-1)) 275 | 276 | max_err_u = torch.max(torch.abs(u_prev-u)) 277 | max_err_v = torch.max(torch.abs(v_prev-v)) 278 | if verbose: 279 | pbar.set_postfix({"Current Max Error": max(max_err_u, max_err_v).item()}) 280 | if max_err_u < stop_thresh and max_err_v < stop_thresh: 281 | break 282 | 283 | P_ij = ((-M_ij + u_i + v_j) / eps).exp() 284 | 285 | approx_corr_1 = P_ij.argmax(dim=1).squeeze(-1) 286 | approx_corr_2 = P_ij.argmax(dim=0).squeeze(-1) 287 | 288 | if u.shape[0] > v.shape[0]: 289 | distance = (P_ij * M_ij).sum(dim=1).sum() 290 | else: 291 | distance = (P_ij * M_ij).sum(dim=0).sum() 292 | return distance, approx_corr_1, approx_corr_2 293 | -------------------------------------------------------------------------------- /in2in/utils/quaternion.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import torch 9 | import numpy as np 10 | 11 | _EPS4 = np.finfo(np.float32).eps * 4.0 12 | 13 | _FLOAT_EPS = np.finfo(np.float32).eps 14 | 15 | # PyTorch-backed implementations 16 | def qinv(q): 17 | assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)' 18 | mask = torch.ones_like(q) 19 | mask[..., 1:] = -mask[..., 1:] 20 | return q * mask 21 | 22 | 23 | def qinv_np(q): 24 | assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)' 25 | return qinv(torch.from_numpy(q).float()).numpy() 26 | 27 | 28 | def qnormalize(q): 29 | assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)' 30 | return q / torch.norm(q, dim=-1, keepdim=True) 31 | 32 | 33 | def qmul(q, r): 34 | """ 35 | Multiply quaternion(s) q with quaternion(s) r. 36 | Expects two equally-sized tensors of shape (*, 4), where * denotes any number of dimensions. 37 | Returns q*r as a tensor of shape (*, 4). 38 | """ 39 | assert q.shape[-1] == 4 40 | assert r.shape[-1] == 4 41 | 42 | original_shape = q.shape 43 | 44 | # Compute outer product 45 | terms = torch.bmm(r.view(-1, 4, 1), q.view(-1, 1, 4)) 46 | 47 | w = terms[:, 0, 0] - terms[:, 1, 1] - terms[:, 2, 2] - terms[:, 3, 3] 48 | x = terms[:, 0, 1] + terms[:, 1, 0] - terms[:, 2, 3] + terms[:, 3, 2] 49 | y = terms[:, 0, 2] + terms[:, 1, 3] + terms[:, 2, 0] - terms[:, 3, 1] 50 | z = terms[:, 0, 3] - terms[:, 1, 2] + terms[:, 2, 1] + terms[:, 3, 0] 51 | return torch.stack((w, x, y, z), dim=1).view(original_shape) 52 | 53 | 54 | def qrot(q, v): 55 | """ 56 | Rotate vector(s) v about the rotation described by quaternion(s) q. 57 | Expects a tensor of shape (*, 4) for q and a tensor of shape (*, 3) for v, 58 | where * denotes any number of dimensions. 59 | Returns a tensor of shape (*, 3). 
60 | """ 61 | assert q.shape[-1] == 4 62 | assert v.shape[-1] == 3 63 | assert q.shape[:-1] == v.shape[:-1] 64 | 65 | original_shape = list(v.shape) 66 | # print(q.shape) 67 | q = q.contiguous().view(-1, 4) 68 | v = v.contiguous().view(-1, 3) 69 | 70 | qvec = q[:, 1:] 71 | uv = torch.cross(qvec, v, dim=1) 72 | uuv = torch.cross(qvec, uv, dim=1) 73 | return (v + 2 * (q[:, :1] * uv + uuv)).view(original_shape) 74 | 75 | 76 | def qeuler(q, order, epsilon=0, deg=True): 77 | """ 78 | Convert quaternion(s) q to Euler angles. 79 | Expects a tensor of shape (*, 4), where * denotes any number of dimensions. 80 | Returns a tensor of shape (*, 3). 81 | """ 82 | assert q.shape[-1] == 4 83 | 84 | original_shape = list(q.shape) 85 | original_shape[-1] = 3 86 | q = q.view(-1, 4) 87 | 88 | q0 = q[:, 0] 89 | q1 = q[:, 1] 90 | q2 = q[:, 2] 91 | q3 = q[:, 3] 92 | 93 | if order == 'xyz': 94 | x = torch.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2)) 95 | y = torch.asin(torch.clamp(2 * (q1 * q3 + q0 * q2), -1 + epsilon, 1 - epsilon)) 96 | z = torch.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3)) 97 | elif order == 'yzx': 98 | x = torch.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2 * (q1 * q1 + q3 * q3)) 99 | y = torch.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2 * (q2 * q2 + q3 * q3)) 100 | z = torch.asin(torch.clamp(2 * (q1 * q2 + q0 * q3), -1 + epsilon, 1 - epsilon)) 101 | elif order == 'zxy': 102 | x = torch.asin(torch.clamp(2 * (q0 * q1 + q2 * q3), -1 + epsilon, 1 - epsilon)) 103 | y = torch.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2 * (q1 * q1 + q2 * q2)) 104 | z = torch.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2 * (q1 * q1 + q3 * q3)) 105 | elif order == 'xzy': 106 | x = torch.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q3 * q3)) 107 | y = torch.atan2(2 * (q0 * q2 + q1 * q3), 1 - 2 * (q2 * q2 + q3 * q3)) 108 | z = torch.asin(torch.clamp(2 * (q0 * q3 - q1 * q2), -1 + epsilon, 1 - epsilon)) 109 | elif order == 'yxz': 110 | x = torch.asin(torch.clamp(2 * (q0 * q1 - q2 * q3), -1 + epsilon, 1 - epsilon)) 111 | y = torch.atan2(2 * (q1 * q3 + q0 * q2), 1 - 2 * (q1 * q1 + q2 * q2)) 112 | z = torch.atan2(2 * (q1 * q2 + q0 * q3), 1 - 2 * (q1 * q1 + q3 * q3)) 113 | elif order == 'zyx': 114 | x = torch.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2)) 115 | y = torch.asin(torch.clamp(2 * (q0 * q2 - q1 * q3), -1 + epsilon, 1 - epsilon)) 116 | z = torch.atan2(2 * (q0 * q3 + q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3)) 117 | else: 118 | raise 119 | 120 | if deg: 121 | return torch.stack((x, y, z), dim=1).view(original_shape) * 180 / np.pi 122 | else: 123 | return torch.stack((x, y, z), dim=1).view(original_shape) 124 | 125 | 126 | # Numpy-backed implementations 127 | def qmul_np(q, r): 128 | q = torch.from_numpy(q).contiguous().float() 129 | r = torch.from_numpy(r).contiguous().float() 130 | return qmul(q, r).numpy() 131 | 132 | 133 | def qrot_np(q, v): 134 | q = torch.from_numpy(q).contiguous().float() 135 | v = torch.from_numpy(v).contiguous().float() 136 | return qrot(q, v).numpy() 137 | 138 | 139 | def qeuler_np(q, order, epsilon=0, use_gpu=False): 140 | if use_gpu: 141 | q = torch.from_numpy(q).cuda().float() 142 | return qeuler(q, order, epsilon).cpu().numpy() 143 | else: 144 | q = torch.from_numpy(q).contiguous().float() 145 | return qeuler(q, order, epsilon).numpy() 146 | 147 | 148 | def qfix(q): 149 | """ 150 | Enforce quaternion continuity across the time dimension by selecting 151 | the representation (q or -q) with minimal distance (or, equivalently, maximal dot product) 152 | between two consecutive 
frames. 153 | 154 | Expects a tensor of shape (L, J, 4), where L is the sequence length and J is the number of joints. 155 | Returns a tensor of the same shape. 156 | """ 157 | assert len(q.shape) == 3 158 | assert q.shape[-1] == 4 159 | 160 | result = q.copy() 161 | dot_products = np.sum(q[1:] * q[:-1], axis=2) 162 | mask = dot_products < 0 163 | mask = (np.cumsum(mask, axis=0) % 2).astype(bool) 164 | result[1:][mask] *= -1 165 | return result 166 | 167 | 168 | def euler2quat(e, order, deg=True): 169 | """ 170 | Convert Euler angles to quaternions. 171 | """ 172 | assert e.shape[-1] == 3 173 | 174 | original_shape = list(e.shape) 175 | original_shape[-1] = 4 176 | 177 | e = e.view(-1, 3) 178 | 179 | ## if euler angles in degrees 180 | if deg: 181 | e = e * np.pi / 180. 182 | 183 | x = e[:, 0] 184 | y = e[:, 1] 185 | z = e[:, 2] 186 | 187 | rx = torch.stack((torch.cos(x / 2), torch.sin(x / 2), torch.zeros_like(x), torch.zeros_like(x)), dim=1) 188 | ry = torch.stack((torch.cos(y / 2), torch.zeros_like(y), torch.sin(y / 2), torch.zeros_like(y)), dim=1) 189 | rz = torch.stack((torch.cos(z / 2), torch.zeros_like(z), torch.zeros_like(z), torch.sin(z / 2)), dim=1) 190 | 191 | result = None 192 | for coord in order: 193 | if coord == 'x': 194 | r = rx 195 | elif coord == 'y': 196 | r = ry 197 | elif coord == 'z': 198 | r = rz 199 | else: 200 | raise 201 | if result is None: 202 | result = r 203 | else: 204 | result = qmul(result, r) 205 | 206 | # Reverse antipodal representation to have a non-negative "w" 207 | if order in ['xyz', 'yzx', 'zxy']: 208 | result *= -1 209 | 210 | return result.view(original_shape) 211 | 212 | 213 | def expmap_to_quaternion(e): 214 | """ 215 | Convert axis-angle rotations (aka exponential maps) to quaternions. 216 | Stable formula from "Practical Parameterization of Rotations Using the Exponential Map". 217 | Expects a tensor of shape (*, 3), where * denotes any number of dimensions. 218 | Returns a tensor of shape (*, 4). 219 | """ 220 | assert e.shape[-1] == 3 221 | 222 | original_shape = list(e.shape) 223 | original_shape[-1] = 4 224 | e = e.reshape(-1, 3) 225 | 226 | theta = np.linalg.norm(e, axis=1).reshape(-1, 1) 227 | w = np.cos(0.5 * theta).reshape(-1, 1) 228 | xyz = 0.5 * np.sinc(0.5 * theta / np.pi) * e 229 | return np.concatenate((w, xyz), axis=1).reshape(original_shape) 230 | 231 | 232 | def euler_to_quaternion(e, order): 233 | """ 234 | Convert Euler angles to quaternions. 235 | """ 236 | assert e.shape[-1] == 3 237 | 238 | original_shape = list(e.shape) 239 | original_shape[-1] = 4 240 | 241 | e = e.reshape(-1, 3) 242 | 243 | x = e[:, 0] 244 | y = e[:, 1] 245 | z = e[:, 2] 246 | 247 | rx = np.stack((np.cos(x / 2), np.sin(x / 2), np.zeros_like(x), np.zeros_like(x)), axis=1) 248 | ry = np.stack((np.cos(y / 2), np.zeros_like(y), np.sin(y / 2), np.zeros_like(y)), axis=1) 249 | rz = np.stack((np.cos(z / 2), np.zeros_like(z), np.zeros_like(z), np.sin(z / 2)), axis=1) 250 | 251 | result = None 252 | for coord in order: 253 | if coord == 'x': 254 | r = rx 255 | elif coord == 'y': 256 | r = ry 257 | elif coord == 'z': 258 | r = rz 259 | else: 260 | raise 261 | if result is None: 262 | result = r 263 | else: 264 | result = qmul_np(result, r) 265 | 266 | # Reverse antipodal representation to have a non-negative "w" 267 | if order in ['xyz', 'yzx', 'zxy']: 268 | result *= -1 269 | 270 | return result.reshape(original_shape) 271 | 272 | 273 | def quaternion_to_matrix(quaternions): 274 | """ 275 | Convert rotations given as quaternions to rotation matrices. 
276 | Args: 277 | quaternions: quaternions with real part first, 278 | as tensor of shape (..., 4). 279 | Returns: 280 | Rotation matrices as tensor of shape (..., 3, 3). 281 | """ 282 | r, i, j, k = torch.unbind(quaternions, -1) 283 | two_s = 2.0 / (quaternions * quaternions).sum(-1) 284 | 285 | o = torch.stack( 286 | ( 287 | 1 - two_s * (j * j + k * k), 288 | two_s * (i * j - k * r), 289 | two_s * (i * k + j * r), 290 | two_s * (i * j + k * r), 291 | 1 - two_s * (i * i + k * k), 292 | two_s * (j * k - i * r), 293 | two_s * (i * k - j * r), 294 | two_s * (j * k + i * r), 295 | 1 - two_s * (i * i + j * j), 296 | ), 297 | -1, 298 | ) 299 | return o.reshape(quaternions.shape[:-1] + (3, 3)) 300 | 301 | 302 | def quaternion_to_matrix_np(quaternions): 303 | q = torch.from_numpy(quaternions).contiguous().float() 304 | return quaternion_to_matrix(q).numpy() 305 | 306 | 307 | def quaternion_to_cont6d_np(quaternions): 308 | rotation_mat = quaternion_to_matrix_np(quaternions) 309 | cont_6d = np.concatenate([rotation_mat[..., 0], rotation_mat[..., 1]], axis=-1) 310 | return cont_6d 311 | 312 | 313 | def quaternion_to_cont6d(quaternions): 314 | rotation_mat = quaternion_to_matrix(quaternions) 315 | cont_6d = torch.cat([rotation_mat[..., 0], rotation_mat[..., 1]], dim=-1) 316 | return cont_6d 317 | 318 | 319 | def cont6d_to_matrix(cont6d): 320 | assert cont6d.shape[-1] == 6, "The last dimension must be 6" 321 | x_raw = cont6d[..., 0:3] 322 | y_raw = cont6d[..., 3:6] 323 | 324 | x = x_raw / torch.norm(x_raw, dim=-1, keepdim=True) 325 | z = torch.cross(x, y_raw, dim=-1) 326 | z = z / torch.norm(z, dim=-1, keepdim=True) 327 | 328 | y = torch.cross(z, x, dim=-1) 329 | 330 | x = x[..., None] 331 | y = y[..., None] 332 | z = z[..., None] 333 | 334 | mat = torch.cat([x, y, z], dim=-1) 335 | return mat 336 | 337 | 338 | def cont6d_to_matrix_np(cont6d): 339 | q = torch.from_numpy(cont6d).contiguous().float() 340 | return cont6d_to_matrix(q).numpy() 341 | 342 | 343 | def qpow(q0, t, dtype=torch.float): 344 | ''' q0 : tensor of quaternions 345 | t: tensor of powers 346 | ''' 347 | q0 = qnormalize(q0) 348 | theta0 = torch.acos(q0[..., 0]) 349 | 350 | ## if theta0 is close to zero, add epsilon to avoid NaNs 351 | mask = (theta0 <= 10e-10) * (theta0 >= -10e-10) 352 | theta0 = (1 - mask) * theta0 + mask * 10e-10 353 | v0 = q0[..., 1:] / torch.sin(theta0).view(-1, 1) 354 | 355 | if isinstance(t, torch.Tensor): 356 | q = torch.zeros(t.shape + q0.shape) 357 | theta = t.view(-1, 1) * theta0.view(1, -1) 358 | else: ## if t is a number 359 | q = torch.zeros(q0.shape) 360 | theta = t * theta0 361 | 362 | q[..., 0] = torch.cos(theta) 363 | q[..., 1:] = v0 * torch.sin(theta).unsqueeze(-1) 364 | 365 | return q.to(dtype) 366 | 367 | 368 | def qslerp(q0, q1, t): 369 | ''' 370 | q0: starting quaternion 371 | q1: ending quaternion 372 | t: array of points along the way 373 | 374 | Returns: 375 | Tensor of Slerps: t.shape + q0.shape 376 | ''' 377 | 378 | q0 = qnormalize(q0) 379 | q1 = qnormalize(q1) 380 | q_ = qpow(qmul(q1, qinv(q0)), t) 381 | 382 | return qmul(q_, 383 | q0.contiguous().view(torch.Size([1] * len(t.shape)) + q0.shape).expand(t.shape + q0.shape).contiguous()) 384 | 385 | 386 | def qbetween(v0, v1): 387 | ''' 388 | find the quaternion used to rotate v0 to v1 389 | ''' 390 | assert v0.shape[-1] == 3, 'v0 must be of the shape (*, 3)' 391 | assert v1.shape[-1] == 3, 'v1 must be of the shape (*, 3)' 392 | 393 | v = torch.cross(v0, v1) 394 | w = torch.sqrt((v0 ** 2).sum(dim=-1, keepdim=True) * (v1 ** 2).sum(dim=-1, 
keepdim=True)) + (v0 * v1).sum(dim=-1, 395 | keepdim=True) 396 | return qnormalize(torch.cat([w, v], dim=-1)) 397 | 398 | 399 | def qbetween_np(v0, v1): 400 | ''' 401 | find the quaternion used to rotate v0 to v1 402 | ''' 403 | assert v0.shape[-1] == 3, 'v0 must be of the shape (*, 3)' 404 | assert v1.shape[-1] == 3, 'v1 must be of the shape (*, 3)' 405 | 406 | v0 = torch.from_numpy(v0).float() 407 | v1 = torch.from_numpy(v1).float() 408 | return qbetween(v0, v1).numpy() 409 | 410 | 411 | def lerp(p0, p1, t): 412 | if not isinstance(t, torch.Tensor): 413 | t = torch.Tensor([t]) 414 | 415 | new_shape = t.shape + p0.shape 416 | new_view_t = t.shape + torch.Size([1] * len(p0.shape)) 417 | new_view_p = torch.Size([1] * len(t.shape)) + p0.shape 418 | p0 = p0.view(new_view_p).expand(new_shape) 419 | p1 = p1.view(new_view_p).expand(new_shape) 420 | t = t.view(new_view_t).expand(new_shape) 421 | 422 | return p0 + t * (p1 - p0) 423 | -------------------------------------------------------------------------------- /in2in/models/utils/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from in2in.utils.utils import * 4 | 5 | 6 | class InterLoss(nn.Module): 7 | def __init__(self, recons_loss, nb_joints): 8 | super(InterLoss, self).__init__() 9 | self.nb_joints = nb_joints 10 | if recons_loss == 'l1': 11 | self.Loss = torch.nn.L1Loss(reduction='none') 12 | elif recons_loss == 'l2': 13 | self.Loss = torch.nn.MSELoss(reduction='none') 14 | elif recons_loss == 'l1_smooth': 15 | self.Loss = torch.nn.SmoothL1Loss(reduction='none') 16 | 17 | self.normalizer = MotionNormalizerTorch() 18 | 19 | self.weights = {} 20 | self.weights["RO"] = 0.01 21 | self.weights["JA"] = 3 22 | self.weights["DM"] = 3 23 | 24 | self.losses = {} 25 | 26 | def seq_masked_mse(self, prediction, target, mask): 27 | loss = self.Loss(prediction, target).mean(dim=-1, keepdim=True) 28 | loss = (loss * mask).sum() / (mask.sum() + 1.e-7) 29 | return loss 30 | 31 | def mix_masked_mse(self, prediction, target, mask, batch_mask, contact_mask=None, dm_mask=None): 32 | if dm_mask is not None: 33 | loss = (self.Loss(prediction, target) * dm_mask).sum(dim=-1, keepdim=True)/ (dm_mask.sum(dim=-1, keepdim=True) + 1.e-7) 34 | else: 35 | loss = self.Loss(prediction, target).mean(dim=-1, keepdim=True) # [b,t,p,4,1] 36 | if contact_mask is not None: 37 | loss = (loss[..., 0] * contact_mask).sum(dim=-1, keepdim=True) / (contact_mask.sum(dim=-1, keepdim=True) + 1.e-7) 38 | loss = (loss * mask).sum(dim=(-1, -2, -3)) / (mask.sum(dim=(-1, -2, -3)) + 1.e-7) # [b] 39 | loss = (loss * batch_mask).sum(dim=0) / (batch_mask.sum(dim=0) + 1.e-7) 40 | 41 | return loss 42 | 43 | def forward(self, motion_pred, motion_gt, mask, timestep_mask): 44 | B, T = motion_pred.shape[:2] 45 | self.losses["simple"] = self.seq_masked_mse(motion_pred, motion_gt, mask) 46 | target = self.normalizer.backward(motion_gt, global_rt=True) 47 | prediction = self.normalizer.backward(motion_pred, global_rt=True) 48 | 49 | self.pred_g_joints = prediction[..., :self.nb_joints * 3].reshape(B, T, -1, self.nb_joints, 3) 50 | self.tgt_g_joints = target[..., :self.nb_joints * 3].reshape(B, T, -1, self.nb_joints, 3) 51 | 52 | self.mask = mask 53 | self.timestep_mask = timestep_mask 54 | 55 | self.forward_distance_map(thresh=1) 56 | self.forward_joint_affinity(thresh=0.1) 57 | self.forward_relatvie_rot() 58 | self.accum_loss() 59 | 60 | 61 | def forward_relatvie_rot(self): 62 | r_hip, l_hip, sdr_r, sdr_l = 
FACE_JOINT_INDX 63 | across = self.pred_g_joints[..., r_hip, :] - self.pred_g_joints[..., l_hip, :] 64 | across = across / across.norm(dim=-1, keepdim=True) 65 | across_gt = self.tgt_g_joints[..., r_hip, :] - self.tgt_g_joints[..., l_hip, :] 66 | across_gt = across_gt / across_gt.norm(dim=-1, keepdim=True) 67 | 68 | y_axis = torch.zeros_like(across) 69 | y_axis[..., 1] = 1 70 | 71 | forward = torch.cross(y_axis, across, axis=-1) 72 | forward = forward / forward.norm(dim=-1, keepdim=True) 73 | forward_gt = torch.cross(y_axis, across_gt, axis=-1) 74 | forward_gt = forward_gt / forward_gt.norm(dim=-1, keepdim=True) 75 | 76 | pred_relative_rot = qbetween(forward[..., 0, :], forward[..., 1, :]) 77 | tgt_relative_rot = qbetween(forward_gt[..., 0, :], forward_gt[..., 1, :]) 78 | 79 | self.losses["RO"] = self.mix_masked_mse(pred_relative_rot[..., [0, 2]], 80 | tgt_relative_rot[..., [0, 2]], 81 | self.mask[..., 0, :], self.timestep_mask) * self.weights["RO"] 82 | 83 | 84 | def forward_distance_map(self, thresh): 85 | pred_g_joints = self.pred_g_joints.reshape(self.mask.shape[:-1] + (-1,)) 86 | tgt_g_joints = self.tgt_g_joints.reshape(self.mask.shape[:-1] + (-1,)) 87 | 88 | pred_g_joints1 = pred_g_joints[..., 0:1, :].reshape(-1, self.nb_joints, 3) 89 | pred_g_joints2 = pred_g_joints[..., 1:2, :].reshape(-1, self.nb_joints, 3) 90 | tgt_g_joints1 = tgt_g_joints[..., 0:1, :].reshape(-1, self.nb_joints, 3) 91 | tgt_g_joints2 = tgt_g_joints[..., 1:2, :].reshape(-1, self.nb_joints, 3) 92 | 93 | pred_distance_matrix = torch.cdist(pred_g_joints1.contiguous(), pred_g_joints2).reshape( 94 | self.mask.shape[:-2] + (1, -1,)) 95 | tgt_distance_matrix = torch.cdist(tgt_g_joints1.contiguous(), tgt_g_joints2).reshape( 96 | self.mask.shape[:-2] + (1, -1,)) 97 | 98 | distance_matrix_mask = (pred_distance_matrix < thresh).float() 99 | 100 | self.losses["DM"] = self.mix_masked_mse(pred_distance_matrix, tgt_distance_matrix, 101 | self.mask[..., 0:1, :], 102 | self.timestep_mask, dm_mask=distance_matrix_mask) * self.weights["DM"] 103 | 104 | def forward_joint_affinity(self, thresh): 105 | pred_g_joints = self.pred_g_joints.reshape(self.mask.shape[:-1] + (-1,)) 106 | tgt_g_joints = self.tgt_g_joints.reshape(self.mask.shape[:-1] + (-1,)) 107 | 108 | pred_g_joints1 = pred_g_joints[..., 0:1, :].reshape(-1, self.nb_joints, 3) 109 | pred_g_joints2 = pred_g_joints[..., 1:2, :].reshape(-1, self.nb_joints, 3) 110 | tgt_g_joints1 = tgt_g_joints[..., 0:1, :].reshape(-1, self.nb_joints, 3) 111 | tgt_g_joints2 = tgt_g_joints[..., 1:2, :].reshape(-1, self.nb_joints, 3) 112 | 113 | pred_distance_matrix = torch.cdist(pred_g_joints1.contiguous(), pred_g_joints2).reshape( 114 | self.mask.shape[:-2] + (1, -1,)) 115 | tgt_distance_matrix = torch.cdist(tgt_g_joints1.contiguous(), tgt_g_joints2).reshape( 116 | self.mask.shape[:-2] + (1, -1,)) 117 | 118 | distance_matrix_mask = (tgt_distance_matrix < thresh).float() 119 | 120 | self.losses["JA"] = self.mix_masked_mse(pred_distance_matrix, torch.zeros_like(tgt_distance_matrix), 121 | self.mask[..., 0:1, :], 122 | self.timestep_mask, dm_mask=distance_matrix_mask) * self.weights["JA"] 123 | 124 | def accum_loss(self): 125 | loss = 0 126 | for term in self.losses.keys(): 127 | loss += self.losses[term] 128 | self.losses["total"] = loss 129 | return self.losses 130 | 131 | 132 | 133 | class GeometricLoss(nn.Module): 134 | def __init__(self, recons_loss, nb_joints, name, mode="interaction"): 135 | super(GeometricLoss, self).__init__() 136 | self.mode = mode 137 | self.name = name 138 | 
self.nb_joints = nb_joints 139 | if recons_loss == 'l1': 140 | self.Loss = torch.nn.L1Loss(reduction='none') 141 | elif recons_loss == 'l2': 142 | self.Loss = torch.nn.MSELoss(reduction='none') 143 | elif recons_loss == 'l1_smooth': 144 | self.Loss = torch.nn.SmoothL1Loss(reduction='none') 145 | 146 | if mode == "individual": 147 | self.normalizer = MotionNormalizerTorchHML3D() 148 | else: 149 | self.normalizer = MotionNormalizerTorch() 150 | 151 | self.fids = [7, 10, 8, 11] 152 | 153 | self.weights = {} 154 | self.weights["VEL"] = 30 155 | self.weights["BL"] = 10 156 | self.weights["FC"] = 30 157 | self.weights["POSE"] = 1 158 | self.weights["TR"] = 100 159 | 160 | self.losses = {} 161 | 162 | def seq_masked_mse(self, prediction, target, mask): 163 | loss = self.Loss(prediction, target).mean(dim=-1, keepdim=True) 164 | loss = (loss * mask).sum() / (mask.sum() + 1.e-7) 165 | return loss 166 | 167 | def mix_masked_mse(self, prediction, target, mask, batch_mask, contact_mask=None, dm_mask=None): 168 | if dm_mask is not None: 169 | loss = (self.Loss(prediction, target) * dm_mask).sum(dim=-1, keepdim=True)/ (dm_mask.sum(dim=-1, keepdim=True) + 1.e-7) # [b,t,p,4,1] 170 | else: 171 | loss = self.Loss(prediction, target).mean(dim=-1, keepdim=True) # [b,t,p,4,1] 172 | if contact_mask is not None: 173 | loss = (loss[..., 0] * contact_mask).sum(dim=-1, keepdim=True) / (contact_mask.sum(dim=-1, keepdim=True) + 1.e-7) 174 | loss = (loss * mask).sum(dim=(-1, -2)) / (mask.sum(dim=(-1, -2)) + 1.e-7) # [b] 175 | loss = (loss * batch_mask).sum(dim=0) / (batch_mask.sum(dim=0) + 1.e-7) 176 | 177 | return loss 178 | 179 | def forward(self, motion_pred, motion_gt, mask, timestep_mask): 180 | B, T = motion_pred.shape[:2] 181 | 182 | if self.mode == "individual": 183 | self.losses["simple"] = self.seq_masked_mse(motion_pred, motion_gt, mask) 184 | 185 | target = self.normalizer.backward(motion_gt, global_rt=True) 186 | prediction = self.normalizer.backward(motion_pred, global_rt=True) 187 | 188 | self.first_motion_pred =motion_pred[:,0:1] 189 | self.first_motion_gt =motion_gt[:,0:1] 190 | 191 | self.pred_g_joints = prediction[..., :self.nb_joints * 3].reshape(B, T, self.nb_joints, 3) 192 | self.tgt_g_joints = target[..., :self.nb_joints * 3].reshape(B, T, self.nb_joints, 3) 193 | self.mask = mask 194 | self.timestep_mask = timestep_mask 195 | 196 | if self.mode != "individual": 197 | self.forward_vel() 198 | self.forward_bone_length() 199 | self.forward_contact() 200 | 201 | self.accum_loss() 202 | 203 | def get_local_positions(self, positions, r_rot): 204 | '''Local pose''' 205 | positions[..., 0] -= positions[..., 0:1, 0] 206 | positions[..., 2] -= positions[..., 0:1, 2] 207 | '''All pose face Z+''' 208 | positions = qrot(r_rot[..., None, :].repeat(1, 1, positions.shape[-2], 1), positions) 209 | return positions 210 | 211 | def forward_local_pose(self): 212 | r_hip, l_hip, sdr_r, sdr_l = FACE_JOINT_INDX 213 | 214 | pred_g_joints = self.pred_g_joints.clone() 215 | tgt_g_joints = self.tgt_g_joints.clone() 216 | 217 | across = pred_g_joints[..., r_hip, :] - pred_g_joints[..., l_hip, :] 218 | across = across / across.norm(dim=-1, keepdim=True) 219 | across_gt = tgt_g_joints[..., r_hip, :] - tgt_g_joints[..., l_hip, :] 220 | across_gt = across_gt / across_gt.norm(dim=-1, keepdim=True) 221 | 222 | y_axis = torch.zeros_like(across) 223 | y_axis[..., 1] = 1 224 | 225 | forward = torch.cross(y_axis, across, axis=-1) 226 | forward = forward / forward.norm(dim=-1, keepdim=True) 227 | forward_gt = torch.cross(y_axis, 
across_gt, axis=-1) 228 | forward_gt = forward_gt / forward_gt.norm(dim=-1, keepdim=True) 229 | 230 | z_axis = torch.zeros_like(forward) 231 | z_axis[..., 2] = 1 232 | noise = torch.randn_like(z_axis) *0.0001 233 | z_axis = z_axis+noise 234 | z_axis = z_axis / z_axis.norm(dim=-1, keepdim=True) 235 | 236 | 237 | pred_rot = qbetween(forward, z_axis) 238 | tgt_rot = qbetween(forward_gt, z_axis) 239 | 240 | B, T, J, D = self.pred_g_joints.shape 241 | pred_joints = self.get_local_positions(pred_g_joints, pred_rot).reshape(B, T, -1) 242 | tgt_joints = self.get_local_positions(tgt_g_joints, tgt_rot).reshape(B, T, -1) 243 | 244 | self.losses["POSE_"+self.name] = self.mix_masked_mse(pred_joints, tgt_joints, self.mask, self.timestep_mask) * self.weights["POSE"] 245 | 246 | def forward_vel(self): 247 | B, T = self.pred_g_joints.shape[:2] 248 | 249 | pred_vel = self.pred_g_joints[:, 1:] - self.pred_g_joints[:, :-1] 250 | tgt_vel = self.tgt_g_joints[:, 1:] - self.tgt_g_joints[:, :-1] 251 | 252 | pred_vel = pred_vel.reshape(pred_vel.shape[:-2] + (-1,)) 253 | tgt_vel = tgt_vel.reshape(tgt_vel.shape[:-2] + (-1,)) 254 | 255 | self.losses["VEL_"+self.name] = self.mix_masked_mse(pred_vel, tgt_vel, self.mask[:, :-1], self.timestep_mask) * self.weights["VEL"] 256 | 257 | 258 | def forward_contact(self): 259 | 260 | feet_vel = self.pred_g_joints[:, 1:, self.fids, :] - self.pred_g_joints[:, :-1, self.fids,:] 261 | feet_h = self.pred_g_joints[:, :-1, self.fids, 1] 262 | 263 | contact = self.foot_detect(feet_vel, feet_h, 0.001) 264 | 265 | self.losses["FC_"+self.name] = self.mix_masked_mse(feet_vel, torch.zeros_like(feet_vel), self.mask[:, :-1], 266 | self.timestep_mask, 267 | contact) * self.weights["FC"] 268 | 269 | def forward_bone_length(self): 270 | pred_g_joints = self.pred_g_joints 271 | tgt_g_joints = self.tgt_g_joints 272 | pred_bones = [] 273 | tgt_bones = [] 274 | for chain in HML_KINEMATIC_CHAIN: 275 | for i, joint in enumerate(chain[:-1]): 276 | pred_bone = (pred_g_joints[..., chain[i], :] - pred_g_joints[..., chain[i + 1], :]).norm(dim=-1, 277 | keepdim=True) # [B,T,P,1] 278 | tgt_bone = (tgt_g_joints[..., chain[i], :] - tgt_g_joints[..., chain[i + 1], :]).norm(dim=-1, 279 | keepdim=True) 280 | pred_bones.append(pred_bone) 281 | tgt_bones.append(tgt_bone) 282 | 283 | pred_bones = torch.cat(pred_bones, dim=-1) 284 | tgt_bones = torch.cat(tgt_bones, dim=-1) 285 | 286 | self.losses["BL_"+self.name] = self.mix_masked_mse(pred_bones, tgt_bones, self.mask, self.timestep_mask) * self.weights[ 287 | "BL"] 288 | 289 | 290 | def forward_traj(self): 291 | B, T = self.pred_g_joints.shape[:2] 292 | 293 | pred_traj = self.pred_g_joints[..., 0, [0, 2]] 294 | tgt_g_traj = self.tgt_g_joints[..., 0, [0, 2]] 295 | 296 | self.losses["TR_"+self.name] = self.mix_masked_mse(pred_traj, tgt_g_traj, self.mask, self.timestep_mask) * self.weights["TR"] 297 | 298 | 299 | def accum_loss(self): 300 | loss = 0 301 | for term in self.losses.keys(): 302 | loss += self.losses[term] 303 | self.losses[self.name] = loss 304 | 305 | def foot_detect(self, feet_vel, feet_h, thres): 306 | velfactor, heightfactor = torch.Tensor([thres, thres, thres, thres]).to(feet_vel.device), torch.Tensor( 307 | [0.12, 0.05, 0.12, 0.05]).to(feet_vel.device) 308 | 309 | feet_x = (feet_vel[..., 0]) ** 2 310 | feet_y = (feet_vel[..., 1]) ** 2 311 | feet_z = (feet_vel[..., 2]) ** 2 312 | 313 | contact = (((feet_x + feet_y + feet_z) < velfactor) & (feet_h < heightfactor)).float() 314 | return contact 
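
As a quick reference for the contact term above, the following is a minimal, self-contained sketch (not part of losses.py) of the heuristic that GeometricLoss.foot_detect applies: a foot joint is treated as in contact when its squared per-frame displacement stays below the velocity threshold and its height stays below the per-joint limit. The toy tensor shapes and random data are assumptions for illustration only; the thresholds match the ones used in foot_detect.

import torch

# Toy global positions for the four foot joints used above (fids = [7, 10, 8, 11]):
# shape (batch, frames, 4 foot joints, xyz). Illustrative data, not real motion.
B, T = 2, 8
feet = torch.rand(B, T, 4, 3) * 0.05

feet_vel = feet[:, 1:] - feet[:, :-1]   # per-frame displacement, (B, T-1, 4, 3)
feet_h = feet[:, :-1, :, 1]             # joint heights along the up (y) axis, (B, T-1, 4)

# Same thresholds as foot_detect: squared speed below `thres` and height below
# a per-joint limit mark the joint as being in ground contact.
thres = 0.001
velfactor = torch.tensor([thres] * 4)
heightfactor = torch.tensor([0.12, 0.05, 0.12, 0.05])

sq_speed = (feet_vel ** 2).sum(dim=-1)  # (B, T-1, 4)
contact = ((sq_speed < velfactor) & (feet_h < heightfactor)).float()
print(contact.shape)                    # torch.Size([2, 7, 4])

Because the detection runs on the predicted joints, the FC loss can then push the corresponding foot velocities toward zero, penalizing foot sliding in the generated motion.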
-------------------------------------------------------------------------------- /in2in/evaluation/datasets.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import random 3 | import torch 4 | import os 5 | import numpy as np 6 | from tqdm import tqdm 7 | from os.path import join as pjoin 8 | from torch.utils.data import Dataset, DataLoader 9 | from in2in.utils.utils import MotionNormalizer 10 | 11 | class EvaluationDatasetInterHuman(Dataset): 12 | """ 13 | Evaluation Dataset of InterHuman. 14 | Motions are generated by the trained model to later be compared with the ground truth. 15 | """ 16 | def __init__(self, model, dataset, device, mm_num_samples, mm_num_repeats): 17 | """ 18 | Initialization of the dataset and generation of the motions. 19 | :param model: Model to generate the motions. 20 | :param dataset: Ground truth dataset. 21 | :param device: Device to run the model. 22 | :param mm_num_samples: Number of samples to generate for the MultiModality metric. 23 | :param mm_num_repeats: Number of repeats for each sample in the MultiModality metric. 24 | """ 25 | 26 | # Configuration variables 27 | self.normalizer = MotionNormalizer() 28 | self.model = model.to(device) 29 | self.model.eval() 30 | dataloader = DataLoader(dataset, batch_size=1, num_workers=0, shuffle=True) 31 | self.max_length = dataset.max_length 32 | self.extended = dataset.extended 33 | 34 | # Indexes of the motions to generate 35 | idxs = list(range(len(dataset))) 36 | random.shuffle(idxs) 37 | mm_idxs = idxs[:mm_num_samples] 38 | 39 | # Data structures 40 | generated_motions = [] 41 | mm_generated_motions = [] 42 | 43 | with torch.no_grad(): 44 | for i, data in tqdm(enumerate(dataloader)): 45 | 46 | # Get the data from the data loader 47 | if self.extended: 48 | name, text, motion1, motion2, motion_lens, text_individual1, text_individual2 = data 49 | else: 50 | name, text, motion1, motion2, motion_lens = data 51 | 52 | # If data into MM list, duplicate them mm_num_repeats times 53 | batch = {} 54 | if i in mm_idxs: 55 | batch["text"] = list(text) * mm_num_repeats 56 | if self.extended: 57 | batch["text_individual1"] = list(text_individual1) * mm_num_repeats 58 | batch["text_individual2"] = list(text_individual2) * mm_num_repeats 59 | else: 60 | batch["text"] = list(text) 61 | if self.extended: 62 | batch["text_individual1"] = list(text_individual1) 63 | batch["text_individual2"] = list(text_individual2) 64 | 65 | batch["motion_lens"] = motion_lens 66 | 67 | # Predict the motions with the model to be evaluated 68 | batch = self.model.forward_test(batch) 69 | motions_output = batch["output"].reshape(batch["output"].shape[0], batch["output"].shape[1], 2, -1) 70 | motions_output = self.normalizer.backward(motions_output.cpu().detach().numpy()) 71 | 72 | # Padding the motions to the max_length 73 | B,T = motions_output.shape[0], motions_output.shape[1] 74 | if T < self.max_length: 75 | padding_len = self.max_length - T 76 | D = motions_output.shape[-1] 77 | padding_zeros = np.zeros((B, padding_len, 2, D)) 78 | motions_output = np.concatenate((motions_output, padding_zeros), axis=1) 79 | assert motions_output.shape[1] == self.max_length 80 | 81 | # Save the generated motions 82 | if self.extended: 83 | sub_dict = {'motion1': motions_output[0, :,0], 84 | 'motion2': motions_output[0, :,1], 85 | 'motion_lens': motion_lens[0], 86 | 'text': text[0], 87 | 'text_individual1': text_individual1[0], 88 | 'text_individual2': text_individual2[0]} 89 | 
generated_motions.append(sub_dict) 90 | if i in mm_idxs: 91 | mm_sub_dict = {'mm_motions': motions_output, 92 | 'motion_lens': motion_lens[0], 93 | 'text': text[0], 94 | 'text_individual1': text_individual1[0], 95 | 'text_individual2': text_individual2[0]} 96 | mm_generated_motions.append(mm_sub_dict) 97 | else: 98 | sub_dict = {'motion1': motions_output[0, :,0], 99 | 'motion2': motions_output[0, :,1], 100 | 'motion_lens': motion_lens[0], 101 | 'text': text[0]} 102 | generated_motions.append(sub_dict) 103 | if i in mm_idxs: 104 | mm_sub_dict = {'mm_motions': motions_output, 105 | 'motion_lens': motion_lens[0], 106 | 'text': text[0]} 107 | mm_generated_motions.append(mm_sub_dict) 108 | 109 | 110 | self.generated_motions = generated_motions 111 | self.mm_generated_motions = mm_generated_motions 112 | 113 | def __len__(self): 114 | """ 115 | Get the length of the dataset. 116 | """ 117 | return len(self.generated_motions) 118 | 119 | def __getitem__(self, item): 120 | """ 121 | Get the item of the dataset. 122 | :param item: Index of the item. 123 | """ 124 | data = self.generated_motions[item] 125 | 126 | # Return the data, if dataset extended also return individual descriptions 127 | if self.extended: 128 | motion1, motion2, motion_lens, text, text_individual1, text_individual2 = data['motion1'], data['motion2'], data['motion_lens'], data['text'], data['text_individual1'], data['text_individual2'] 129 | return "generated", text, motion1, motion2, motion_lens, text_individual1, text_individual2 130 | else: 131 | motion1, motion2, motion_lens, text = data['motion1'], data['motion2'], data['motion_lens'], data['text'] 132 | return "generated", text, motion1, motion2, motion_lens 133 | 134 | 135 | class MMGeneratedDatasetInterHuman(Dataset): 136 | """ 137 | Dataset for the MultiModality metric. 138 | """ 139 | def __init__(self, motion_dataset): 140 | """ 141 | Initialization of the dataset. 142 | :param motion_dataset: EvaluationDataset of the generated motions. 143 | """ 144 | self.dataset = motion_dataset.mm_generated_motions 145 | self.extended = motion_dataset.extended 146 | 147 | def __len__(self): 148 | """ 149 | Get the length of the dataset. 150 | """ 151 | return len(self.dataset) 152 | 153 | def __getitem__(self, item): 154 | """ 155 | Get the item of the dataset. 156 | :param item: Index of the item. 157 | """ 158 | data = self.dataset[item] 159 | mm_motions = data['mm_motions'] 160 | motion_lens = data['motion_lens'] 161 | mm_motions1 = mm_motions[:,:,0] 162 | mm_motions2 = mm_motions[:,:,1] 163 | text = data['text'] 164 | motion_lens = np.array([motion_lens]*mm_motions1.shape[0]) 165 | 166 | # If dataset extended also return individual descriptions 167 | if self.extended: 168 | text_individual1 = data['text_individual1'] 169 | text_individual2 = data['text_individual2'] 170 | return "mm_generated", text, mm_motions1, mm_motions2, motion_lens, text_individual1, text_individual2 171 | else: 172 | return "mm_generated", text, mm_motions1, mm_motions2, motion_lens 173 | 174 | 175 | 176 | class EvaluationDatasetDualMDM(Dataset): 177 | """ 178 | Evaluation Dataset of DualMDM. 179 | Motions are generated by the trained model to later be compared with the ground truth. 180 | The dataset is generated by combining interaction descriptions from InterHuman with individual descriptions from HumanML3D 181 | """ 182 | def __init__(self, model, dataset, device, num_repeats): 183 | """ 184 | Initialization of the dataset and generation of the motions. 
185 | :param model: Model to generate the motions. 186 | :param dataset: Ground truth dataset. 187 | :param device: Device to run the model. 188 | :param num_repeats: Number of repeats for each sample. 189 | """ 190 | self.normalizer = MotionNormalizer() 191 | self.model = model.to(device) 192 | self.model.eval() 193 | dataloader = DataLoader(dataset, batch_size=1, num_workers=0, shuffle=True) 194 | self.max_length = dataset.max_length 195 | 196 | # Paths to the individual descriptions from HumanML3D 197 | self.individual_text_path = "data/HumanML3D/texts" 198 | self.individual_text_files = os.listdir(self.individual_text_path) 199 | 200 | # Data structures 201 | generated_motions = [] 202 | 203 | # Model parameters 204 | composition_weight_func = self.model.decoder.cfg_composition_weight_func 205 | composition_weight_value = self.model.decoder.cfg_composition_weight_value 206 | 207 | with torch.no_grad(): 208 | for i, data in tqdm(enumerate(dataloader)): 209 | 210 | # Get the data from the gt InterHuman dataset 211 | name, text, motion1, motion2, motion_lens, text_individual1, text_individual2 = data 212 | 213 | batch = {} 214 | batch["motion_lens"] = motion_lens.repeat(num_repeats * 2) 215 | batch["text"] = list(text) * (num_repeats * 2) 216 | batch["text_individual1"] = list(text_individual1) * num_repeats 217 | batch["text_individual2"] = list(text_individual2) * num_repeats 218 | 219 | # Append individual textual descriptions sampled from HumanML3D 220 | for j in range(num_repeats): 221 | # Select 2 random files from HumanML3D for the individual descriptions 222 | files = random.sample(self.individual_text_files, 2) 223 | 224 | with open(pjoin(self.individual_text_path, files[0]), "r") as f: 225 | hml3d_texts_individual1 = f.readlines() 226 | hml3d_text_individual1 = random.choice(hml3d_texts_individual1) 227 | hml3d_text_individual1 = hml3d_text_individual1.strip().split("#")[0] 228 | batch["text_individual1"].append(hml3d_text_individual1) 229 | 230 | with open(pjoin(self.individual_text_path, files[1]), "r") as f: 231 | hml3d_texts_individual2 = f.readlines() 232 | hml3d_text_individual2 = random.choice(hml3d_texts_individual2) 233 | hml3d_text_individual2 = hml3d_text_individual2.strip().split("#")[0] 234 | batch["text_individual2"].append(hml3d_text_individual2) 235 | 236 | 237 | # Generate interactions using the base interaction model (in2IN) 238 | batch_interaction = copy.deepcopy(batch) 239 | batch_interaction["motion_lens"] = batch_interaction["motion_lens"][:num_repeats] 240 | batch_interaction["text"] = batch_interaction["text"][:num_repeats] 241 | batch_interaction["text_individual1"] = batch_interaction["text_individual1"][:num_repeats] 242 | batch_interaction["text_individual2"] = batch_interaction["text_individual2"][:num_repeats] 243 | self.model.decoder.cfg_composition_weight_func = "const" 244 | self.model.decoder.cfg_composition_weight_value = 0 245 | batch_interaction = self.model.forward_test(batch_interaction) 246 | motions_output_interaction = batch_interaction["output"].reshape(batch_interaction["output"].shape[0], batch_interaction["output"].shape[1], 2, -1) 247 | motions_output_interaction = self.normalizer.backward(motions_output_interaction.cpu().detach().numpy()) 248 | 249 | # Generate interactions using the composition of the interaction and individual models (DualMDM) 250 | batch_individual = copy.deepcopy(batch) 251 | batch_individual["motion_lens"] = batch_individual["motion_lens"][num_repeats:] 252 | batch_individual["text"] = 
batch_individual["text"][num_repeats:] 253 | batch_individual["text_individual1"] = batch_individual["text_individual1"][num_repeats:] 254 | batch_individual["text_individual2"] = batch_individual["text_individual2"][num_repeats:] 255 | self.model.decoder.cfg_composition_weight_func = composition_weight_func 256 | self.model.decoder.cfg_composition_weight_value = composition_weight_value 257 | batch_individual = self.model.forward_test(batch_individual) 258 | motions_output_individual = batch_individual["output"].reshape(batch_individual["output"].shape[0], batch_individual["output"].shape[1], 2, -1) 259 | motions_output_individual = self.normalizer.backward(motions_output_individual.cpu().detach().numpy()) 260 | 261 | motions_output = np.concatenate((motions_output_interaction, motions_output_individual), axis=0) 262 | B,T = motions_output.shape[0], motions_output.shape[1] 263 | 264 | # Padding all the generated motions to the biggest length of the dataset 265 | if T < self.max_length: 266 | padding_len = self.max_length - T 267 | D = motions_output.shape[-1] 268 | padding_zeros = np.zeros((B, padding_len, 2, D)) 269 | motions_output = np.concatenate((motions_output, padding_zeros), axis=1) 270 | assert motions_output.shape[1] == self.max_length 271 | 272 | # Save the generated motions 273 | sub_dict = {'generated_motions': motions_output, 274 | 'motion1': motion1, 275 | 'motion2': motion2, 276 | 'motion_lens': batch["motion_lens"], 277 | 'text': batch["text"], 278 | 'text_individual1': batch["text_individual1"], 279 | 'text_individual2': batch["text_individual2"]} 280 | 281 | generated_motions.append(sub_dict) 282 | 283 | 284 | self.generated_motions = generated_motions 285 | 286 | def __len__(self): 287 | """ 288 | Get the length of the dataset. 289 | """ 290 | return len(self.generated_motions) 291 | 292 | def __getitem__(self, item): 293 | """ 294 | Get the item of the dataset. 295 | :param item: Index of the item. 
296 | """ 297 | data = self.generated_motions[item] 298 | 299 | motion_lens = data['motion_lens'] 300 | text = data['text'] 301 | text_individual1 = data['text_individual1'] 302 | text_individual2 = data['text_individual2'] 303 | 304 | generated_motions = data['generated_motions'] 305 | generated_motions1 = generated_motions[:, :, 0, :] 306 | generated_motions2 = generated_motions[:, :, 1, :] 307 | motion1 = data['motion1'] 308 | motion2 = data['motion2'] 309 | 310 | return generated_motions1, generated_motions2, motion1, motion2, motion_lens, text, text_individual1, text_individual2 -------------------------------------------------------------------------------- /in2in/models/nets.py: -------------------------------------------------------------------------------- 1 | from in2in.models.utils.cfg_sampler import ClassifierFreeSampleDualMDM, ClassifierFreeSampleModel, ClassifierFreeSampleModelMultiple 2 | import torch 3 | import random 4 | import torch.nn as nn 5 | 6 | 7 | from in2in.models.utils.blocks import TransformerBlockDoubleCond 8 | from in2in.models.utils.gaussian_diffusion import LossType, ModelMeanType, ModelVarType, MotionDiffusion, create_named_schedule_sampler, get_named_beta_schedule, space_timesteps 9 | from in2in.models.utils.layers import FinalLayer 10 | from in2in.models.utils.utils import PositionalEncoding, TimestepEmbedder, zero_module 11 | 12 | class in2INDiffusion(nn.Module): 13 | # Mode can be individual, interaction, or dual 14 | def __init__(self, cfg, mode, sampling_strategy="ddim50"): 15 | super().__init__() 16 | self.cfg = cfg 17 | self.nfeats = cfg.INPUT_DIM 18 | self.latent_dim = cfg.LATENT_DIM 19 | self.ff_size = cfg.FF_SIZE 20 | self.num_layers = cfg.NUM_LAYERS 21 | self.num_heads = cfg.NUM_HEADS 22 | self.dropout = cfg.DROPOUT 23 | self.activation = cfg.ACTIVATION 24 | self.motion_rep = cfg.MOTION_REP 25 | self.diffusion_steps = cfg.DIFFUSION_STEPS 26 | self.beta_scheduler = cfg.BETA_SCHEDULER 27 | self.sampler = cfg.SAMPLER 28 | self.sampling_strategy = sampling_strategy 29 | self.mode = mode 30 | 31 | # Setting classifier-free guidance and composition weights 32 | if self.mode == "dual": 33 | self.cfg_weight_individual = cfg.CFG_WEIGHT_INDIVIDUAL 34 | self.cfg_weight_interaction = cfg.CFG_WEIGHT_INTERACTION 35 | self.cfg_composition_weight_func = cfg.W_FUNC 36 | self.cfg_composition_weight_value = cfg.W_VALUE 37 | elif self.mode == "interaction": 38 | self.cfg_weight = cfg.CFG_WEIGHT 39 | self.cfg_weight_individual = cfg.CFG_WEIGHT_INDIVIDUAL 40 | self.cfg_weight_interaction = cfg.CFG_WEIGHT_INTERACTION 41 | elif self.mode == "individual": 42 | self.cfg_weight = cfg.CFG_WEIGHT 43 | 44 | # Creating the denoiser networks 45 | if self.mode == "dual": 46 | self.net_interaction = in2INDenoiser(self.nfeats, 47 | latent_dim=self.latent_dim, 48 | ff_size=self.ff_size, 49 | num_layers=self.num_layers, 50 | num_heads=self.num_heads, 51 | dropout=self.dropout, 52 | activation=self.activation, 53 | mode="dual_interaction") 54 | 55 | self.net_individual = in2INDenoiser(self.nfeats, 56 | latent_dim=self.latent_dim, 57 | ff_size=self.ff_size, 58 | num_layers=self.num_layers, 59 | num_heads=self.num_heads, 60 | dropout=self.dropout, 61 | activation=self.activation, 62 | mode="dual_individual") 63 | elif self.mode == "interaction": 64 | self.net_interaction = in2INDenoiser(self.nfeats, 65 | latent_dim=self.latent_dim, 66 | ff_size=self.ff_size, 67 | num_layers=self.num_layers, 68 | num_heads=self.num_heads, 69 | dropout=self.dropout, 70 | activation=self.activation, 71 | mode="interaction") 72 | elif self.mode == "individual": 
73 | self.net_individual = in2INDenoiser(self.nfeats, 74 | latent_dim=self.latent_dim, 75 | ff_size=self.ff_size, 76 | num_layers=self.num_layers, 77 | num_heads=self.num_heads, 78 | dropout=self.dropout, 79 | activation=self.activation, 80 | mode="individual") 81 | 82 | 83 | self.diffusion_steps = self.diffusion_steps 84 | self.betas = get_named_beta_schedule(self.beta_scheduler, self.diffusion_steps) 85 | timestep_respacing=[self.diffusion_steps] 86 | 87 | 88 | self.diffusion = MotionDiffusion( 89 | use_timesteps=space_timesteps(self.diffusion_steps, timestep_respacing), 90 | betas=self.betas, 91 | motion_rep=self.motion_rep, 92 | model_mean_type=ModelMeanType.START_X, 93 | model_var_type=ModelVarType.FIXED_SMALL, 94 | loss_type=LossType.MSE, 95 | rescale_timesteps = False, 96 | mode=self.mode 97 | ) 98 | 99 | self.sampler = create_named_schedule_sampler(self.sampler, self.diffusion) 100 | 101 | def mask_cond(self, cond, cond_mask_prob = 0.1, force_mask=False): 102 | bs = cond.shape[0] 103 | if force_mask: 104 | return torch.zeros_like(cond) 105 | elif cond_mask_prob > 0.: 106 | mask = torch.bernoulli(torch.ones(bs, device=cond.device) * cond_mask_prob).view([bs]+[1]*len(cond.shape[1:])) # 1-> use null_cond, 0-> use real cond 107 | return cond * (1. - mask), (1. - mask) 108 | else: 109 | return cond, None 110 | 111 | def generate_src_mask(self, T, length): 112 | B = length.shape[0] 113 | src_mask = torch.ones(B, T, 2) 114 | for p in range(2): 115 | for i in range(B): 116 | for j in range(length[i], T): 117 | src_mask[i, j, p] = 0 118 | return src_mask 119 | 120 | def compute_loss(self, batch): 121 | 122 | if self.mode == "interaction": 123 | cond_interaction = batch["cond_interaction"] 124 | cond_interaction_individual1 = batch["cond_interaction_individual1"] 125 | cond_interaction_individual2 = batch["cond_interaction_individual2"] 126 | cond = torch.cat([cond_interaction, cond_interaction_individual1, cond_interaction_individual2], dim=1) 127 | elif self.mode == "individual": 128 | cond_individual_individual1 = batch["cond_individual_individual1"] 129 | cond = torch.cat([cond_individual_individual1], dim=1) 130 | 131 | x_start = batch["motions"] 132 | B,T = batch["motions"].shape[:2] 133 | 134 | if cond is not None: 135 | cond, cond_mask = self.mask_cond(cond, 0.1) 136 | 137 | seq_mask = self.generate_src_mask(batch["motions"].shape[1], batch["motion_lens"]).to(x_start.device) 138 | 139 | t, _ = self.sampler.sample(B, x_start.device) 140 | 141 | if self.mode == "interaction": 142 | model = self.net_interaction 143 | elif self.mode == "individual": 144 | model = self.net_individual 145 | 146 | output = self.diffusion.training_losses( 147 | model=model, 148 | x_start=x_start, 149 | t=t, 150 | mask=seq_mask, 151 | t_bar=self.cfg.T_BAR, 152 | cond_mask=cond_mask, 153 | model_kwargs={"mask":seq_mask, 154 | "cond":cond, 155 | }, 156 | ) 157 | return output 158 | 159 | def forward(self, batch): 160 | 161 | if self.mode == "dual": 162 | cond_interaction = batch["cond_interaction"] 163 | cond_interaction_individual1 = batch["cond_interaction_individual1"] 164 | cond_interaction_individual2 = batch["cond_interaction_individual2"] 165 | cond_individual_individual1 = batch["cond_individual_individual1"] 166 | cond_individual_individual2 = batch["cond_individual_individual2"] 167 | cond = torch.cat([cond_interaction, cond_interaction_individual1, cond_interaction_individual2, cond_individual_individual1, cond_individual_individual2], dim=1) 168 | elif self.mode == "interaction": 169 | 
cond_interaction = batch["cond_interaction"] 170 | cond_interaction_individual1 = batch["cond_interaction_individual1"] 171 | cond_interaction_individual2 = batch["cond_interaction_individual2"] 172 | cond = torch.cat([cond_interaction, cond_interaction_individual1, cond_interaction_individual2], dim=1) 173 | elif self.mode == "individual": 174 | cond_individual_individual1 = batch["cond_individual_individual1"] 175 | cond = torch.cat([cond_individual_individual1], dim=1) 176 | 177 | B = cond.shape[0] 178 | T = batch["motion_lens"][0] 179 | 180 | timestep_respacing= self.sampling_strategy 181 | self.diffusion_test = MotionDiffusion( 182 | use_timesteps=space_timesteps(self.diffusion_steps, timestep_respacing), 183 | betas=self.betas, 184 | motion_rep=self.motion_rep, 185 | model_mean_type=ModelMeanType.START_X, 186 | model_var_type=ModelVarType.FIXED_SMALL, 187 | loss_type=LossType.MSE, 188 | rescale_timesteps = False, 189 | mode = self.mode 190 | ) 191 | 192 | if self.mode == "dual": 193 | self.cfg_model = ClassifierFreeSampleDualMDM(self.net_individual, self.net_interaction, self.cfg_weight_individual, self.cfg_weight_interaction, self.cfg_composition_weight_func, self.cfg_composition_weight_value) 194 | output = self.diffusion_test.ddim_sample_loop( 195 | self.cfg_model, 196 | (B, T, self.nfeats*2), 197 | clip_denoised=False, 198 | progress=True, 199 | model_kwargs={ 200 | "mask":None, 201 | "cond":cond, 202 | }, 203 | x_start=None) 204 | elif self.mode == "interaction": 205 | self.cfg_model = ClassifierFreeSampleModelMultiple(self.net_interaction, self.cfg_weight, self.cfg_weight_interaction, self.cfg_weight_individual) 206 | output = self.diffusion_test.ddim_sample_loop( 207 | self.cfg_model, 208 | (B, T, self.nfeats*2), 209 | clip_denoised=False, 210 | progress=True, 211 | model_kwargs={ 212 | "mask":None, 213 | "cond":cond, 214 | }, 215 | x_start=None) 216 | elif self.mode == "individual": 217 | self.cfg_model = ClassifierFreeSampleModel(self.net_individual, self.cfg_weight) 218 | output = self.diffusion_test.ddim_sample_loop( 219 | self.cfg_model, 220 | (B, T, self.nfeats), 221 | clip_denoised=False, 222 | progress=True, 223 | model_kwargs={ 224 | "mask":None, 225 | "cond":cond, 226 | }, 227 | x_start=None) 228 | 229 | 230 | return {"output":output} 231 | 232 | class in2INDenoiser(nn.Module): 233 | def __init__(self, 234 | input_feats, 235 | mode, 236 | latent_dim=512, 237 | num_frames=240, 238 | ff_size=1024, 239 | num_layers=8, 240 | num_heads=8, 241 | dropout=0.1, 242 | activation="gelu", 243 | **kargs): 244 | super().__init__() 245 | 246 | self.num_frames = num_frames 247 | self.latent_dim = latent_dim 248 | self.ff_size = ff_size 249 | self.num_layers = num_layers 250 | self.num_heads = num_heads 251 | self.dropout = dropout 252 | self.activation = activation 253 | self.input_feats = input_feats 254 | self.time_embed_dim = latent_dim 255 | self.mode = mode 256 | self.text_emb_dim = 768 257 | 258 | self.sequence_pos_encoder = PositionalEncoding(self.latent_dim, dropout=0) 259 | self.embed_timestep = TimestepEmbedder(self.latent_dim, self.sequence_pos_encoder) 260 | 261 | # Input Embedding 262 | self.motion_embed = nn.Linear(self.input_feats, self.latent_dim) 263 | self.text_embed = nn.Linear(self.text_emb_dim, self.latent_dim) 264 | 265 | self.blocks = nn.ModuleList() 266 | 267 | for i in range(num_layers): 268 | self.blocks.append(TransformerBlockDoubleCond(num_heads=num_heads,latent_dim=latent_dim, dropout=dropout, ff_size=ff_size, mode=self.mode)) 269 | 270 | # Output Module 
271 | self.out = zero_module(FinalLayer(self.latent_dim, self.input_feats)) 272 | 273 | 274 | def forward(self, x, timesteps, mask=None, cond=None): 275 | """ 276 | x: B, T, D 277 | """ 278 | B, T = x.shape[0], x.shape[1] 279 | x_a = x[...,:self.input_feats] 280 | 281 | if self.mode != "individual": 282 | x_b = x[...,self.input_feats:] 283 | 284 | if mask is not None: 285 | mask = mask[...,0] 286 | 287 | if self.mode == "dual_interaction" or self.mode == "interaction": 288 | emb = self.embed_timestep(timesteps) + self.text_embed(cond[:,:768]) 289 | emb_individual1 = self.embed_timestep(timesteps) + self.text_embed(cond[:,768:768*2]) 290 | emb_individual2 = self.embed_timestep(timesteps) + self.text_embed(cond[:,768*2:768*3]) 291 | elif self.mode == "dual_individual": 292 | emb_individual1 = self.embed_timestep(timesteps) + self.text_embed(cond[:,768*3:768*4]) 293 | emb_individual2 = self.embed_timestep(timesteps) + self.text_embed(cond[:,768*4:]) 294 | elif self.mode == "individual": 295 | emb_individual1 = self.embed_timestep(timesteps) + self.text_embed(cond[:,:768]) 296 | else: 297 | raise ValueError("Mode not recognized") 298 | 299 | a_emb = self.motion_embed(x_a) 300 | h_a_prev = self.sequence_pos_encoder(a_emb) 301 | 302 | if self.mode != "individual": 303 | b_emb = self.motion_embed(x_b) 304 | h_b_prev = self.sequence_pos_encoder(b_emb) 305 | 306 | if mask is None: 307 | mask = torch.ones(B, T).to(x_a.device) 308 | key_padding_mask = ~(mask > 0.5) 309 | 310 | for i,block in enumerate(self.blocks): 311 | if self.mode == "interaction" or self.mode == "dual_interaction": 312 | h_a = block(h_a_prev, h_b_prev, emb_individual1, emb, key_padding_mask) 313 | h_b = block(h_b_prev, h_a_prev, emb_individual2, emb, key_padding_mask) 314 | elif self.mode == "dual_individual": 315 | h_a = block(h_a_prev, None, emb_individual1, None, key_padding_mask) 316 | h_b = block(h_b_prev, None, emb_individual2, None, key_padding_mask) 317 | elif self.mode == "individual": 318 | h_a = block(h_a_prev, None, emb_individual1, None, key_padding_mask) 319 | else: 320 | raise ValueError("Mode not recognized") 321 | 322 | h_a_prev = h_a 323 | 324 | if self.mode == "dual_interaction" or self.mode == "interaction": 325 | h_b_prev = h_b 326 | 327 | 328 | output_a = self.out(h_a) 329 | 330 | if self.mode == "individual": 331 | output = torch.cat([output_a], dim=-1) 332 | else: 333 | output_b = self.out(h_b) 334 | output = torch.cat([output_a, output_b], dim=-1) 335 | 336 | return output --------------------------------------------------------------------------------
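For reference, the snippet below is a minimal usage sketch of the evaluation datasets defined in /in2in/evaluation/datasets.py. It assumes that `model` is a trained interaction model exposing forward_test() and that `gt_dataset` is an InterHuman-style dataset providing the max_length and extended attributes; the MultiModality sample counts are illustrative values, not taken from the repository scripts.

import torch
from torch.utils.data import DataLoader
from in2in.evaluation.datasets import EvaluationDatasetInterHuman, MMGeneratedDatasetInterHuman

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Generate motions with the model under evaluation and wrap them for the metrics.
# `model` and `gt_dataset` are assumed to be built elsewhere (e.g. by the evaluation
# scripts); the MultiModality counts below are illustrative only.
eval_dataset = EvaluationDatasetInterHuman(
    model=model,            # trained model exposing forward_test()
    dataset=gt_dataset,     # ground-truth InterHuman dataset
    device=device,
    mm_num_samples=100,     # samples reserved for the MultiModality metric
    mm_num_repeats=30,      # generations per MultiModality sample
)
mm_dataset = MMGeneratedDatasetInterHuman(eval_dataset)

# Standard PyTorch loaders over the generated motions.
eval_loader = DataLoader(eval_dataset, batch_size=32, shuffle=False)
mm_loader = DataLoader(mm_dataset, batch_size=1, shuffle=False)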