├── in2in
│   ├── __init__.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── hml3d_mean.npy
│   │   ├── hml3d_std.npy
│   │   ├── interhuman_mean.npy
│   │   ├── interhuman_std.npy
│   │   ├── preprocess.py
│   │   ├── configs.py
│   │   ├── plot.py
│   │   ├── paramUtil.py
│   │   ├── skeleton.py
│   │   ├── metrics.py
│   │   └── quaternion.py
│   ├── evaluation
│   │   ├── __init__.py
│   │   ├── utils.py
│   │   ├── models.py
│   │   └── datasets.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── utils
│   │   │   ├── __init__.py
│   │   │   ├── blocks.py
│   │   │   ├── utils.py
│   │   │   ├── layers.py
│   │   │   ├── cfg_sampler.py
│   │   │   └── losses.py
│   │   ├── dualmdm.py
│   │   ├── in2in.py
│   │   └── nets.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── dataloader.py
│   │   ├── humanml3d.py
│   │   └── interhuman.py
│   └── scripts
│       ├── eval
│       │   ├── DualMDM.py
│       │   └── interhuman.py
│       ├── infer.py
│       └── train.py
├── .gitignore
├── assets
│   └── cover.png
├── requirements.txt
├── configs
│   ├── eval.yaml
│   ├── train
│   │   ├── in2IN.yaml
│   │   └── individual.yaml
│   ├── infer.yaml
│   ├── models
│   │   ├── individual.yaml
│   │   ├── in2IN.yaml
│   │   └── DualMDM.yaml
│   └── datasets.yaml
├── setup.py
├── LICENSE
└── README.md
/in2in/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
--------------------------------------------------------------------------------
/in2in/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/in2in/evaluation/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/in2in/models/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/in2in/models/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/assets/cover.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pabloruizponce/in2IN/HEAD/assets/cover.png
--------------------------------------------------------------------------------
/in2in/utils/hml3d_mean.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pabloruizponce/in2IN/HEAD/in2in/utils/hml3d_mean.npy
--------------------------------------------------------------------------------
/in2in/utils/hml3d_std.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pabloruizponce/in2IN/HEAD/in2in/utils/hml3d_std.npy
--------------------------------------------------------------------------------
/in2in/utils/interhuman_mean.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pabloruizponce/in2IN/HEAD/in2in/utils/interhuman_mean.npy
--------------------------------------------------------------------------------
/in2in/utils/interhuman_std.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pabloruizponce/in2IN/HEAD/in2in/utils/interhuman_std.npy
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | tqdm
3 | lightning
4 | scipy
5 | matplotlib
6 | pillow
7 | yacs
8 | mmcv
9 | opencv-python
10 | tabulate
11 | termcolor
12 | smplx
13 | torch
14 | torchvision
15 | torchaudio
16 | pykeops
17 | git+https://github.com/openai/CLIP.git
--------------------------------------------------------------------------------
/configs/eval.yaml:
--------------------------------------------------------------------------------
1 | NAME: InterCLIP
2 | NUM_LAYERS: 8
3 | NUM_HEADS: 8
4 | DROPOUT: 0.1
5 | INPUT_DIM: 258
6 | LATENT_DIM: 1024
7 | FF_SIZE: 2048
8 | ACTIVATION: gelu
9 |
10 | MOTION_REP: global
11 | CHECKPOINT: checkpoints/evaluator/interclip.ckpt
12 | FINETUNE: False
13 | EXTENDED: True
14 |
--------------------------------------------------------------------------------
/configs/train/in2IN.yaml:
--------------------------------------------------------------------------------
1 | GENERAL:
2 | EXP_NAME: in2IN
3 | CHECKPOINT: ./checkpoints
4 | LOG_DIR: ./log
5 |
6 | TRAIN:
7 | LR: 1e-4
8 | WEIGHT_DECAY: 0.00002
9 | BATCH_SIZE: 32
10 | EPOCH: 2000
11 | STEP: 1000000
12 | LOG_STEPS: 10
13 | SAVE_EPOCH: 50
14 | RESUME:
15 | NUM_WORKERS: 4
16 | MODE: finetune
17 | LAST_EPOCH: 0
18 | LAST_ITER: 0
19 |
--------------------------------------------------------------------------------
/configs/train/individual.yaml:
--------------------------------------------------------------------------------
1 | GENERAL:
2 | EXP_NAME: DualMDMIndividual
3 | CHECKPOINT: ./checkpoints
4 | LOG_DIR: ./log
5 |
6 | TRAIN:
7 | LR: 1e-4
8 | WEIGHT_DECAY: 0.00002
9 | BATCH_SIZE: 64
10 | EPOCH: 2000
11 | STEP: 1000000
12 | LOG_STEPS: 10
13 | SAVE_EPOCH: 50
14 | RESUME:
15 | NUM_WORKERS: 8
16 | MODE: finetune
17 | LAST_EPOCH: 0
18 | LAST_ITER: 0
19 |
--------------------------------------------------------------------------------
/configs/infer.yaml:
--------------------------------------------------------------------------------
1 | GENERAL:
2 | EXP_NAME: INFERENCE
3 | CHECKPOINT: ./checkpoints
4 | LOG_DIR: ./log
5 |
6 | TRAIN:
7 | LR: 1e-4
8 | WEIGHT_DECAY: 0.00002
9 | BATCH_SIZE: 32
10 | EPOCH: 2000
11 | STEP: 1000000
12 | LOG_STEPS: 10
13 | SAVE_STEPS: 20000
14 | SAVE_EPOCH: 100
15 | RESUME:
16 | NUM_WORKERS: 2
17 | MODE: finetune
18 | LAST_EPOCH: 0
19 | LAST_ITER: 0
20 |
--------------------------------------------------------------------------------
/configs/models/individual.yaml:
--------------------------------------------------------------------------------
1 | NAME: InterGen
2 | NUM_LAYERS: 8
3 | NUM_HEADS: 8
4 | DROPOUT: 0.1
5 | INPUT_DIM: 262
6 | LATENT_DIM: 1024
7 | FF_SIZE: 2048
8 | ACTIVATION: gelu
9 | CHECKPOINT: checkpoints/DualMDM/epoch=1999-step=728000.pt
10 |
11 | DIFFUSION_STEPS: 1000
12 | BETA_SCHEDULER: cosine
13 | SAMPLER: uniform
14 |
15 | MOTION_REP: global
16 | FINETUNE: False
17 |
18 | TEXT_ENCODER: clip
19 | T_BAR: 700
20 |
21 | CONTROL: text
22 | STRATEGY: ddim50
23 | CFG_WEIGHT: 3.5
24 |
25 |
--------------------------------------------------------------------------------
/configs/models/in2IN.yaml:
--------------------------------------------------------------------------------
1 | NAME: InterGen
2 | NUM_LAYERS: 8
3 | NUM_HEADS: 8
4 | DROPOUT: 0.1
5 | INPUT_DIM: 262
6 | LATENT_DIM: 1024
7 | FF_SIZE: 2048
8 | ACTIVATION: gelu
9 | CHECKPOINT: checkpoints/in2IN/epoch=1999-step=352000.pt
10 |
11 | DIFFUSION_STEPS: 1000
12 | BETA_SCHEDULER: cosine
13 | SAMPLER: uniform
14 |
15 | MOTION_REP: global
16 | FINETUNE: False
17 |
18 | TEXT_ENCODER: clip
19 | T_BAR: 700
20 |
21 | CONTROL: text
22 | STRATEGY: ddim50
23 | CFG_WEIGHT: 3
24 | CFG_WEIGHT_INTERACTION: 3
25 | CFG_WEIGHT_INDIVIDUAL: 1
--------------------------------------------------------------------------------
/in2in/models/dualmdm.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | from in2in.models.in2in import in2IN
4 | import torch
5 |
6 | def load_DualMDM_model(model_cfg):
7 | """
8 |     Load the DualMDM model from its two checkpoints (interaction and individual).
9 |     :param model_cfg: Model configuration file
10 |     :return: DualMDM model (an in2IN instance in dual mode)
11 | """
12 | model = in2IN(model_cfg, mode="dual")
13 | print("Model Created")
14 | ckpt = torch.load(model_cfg.CHECKPOINT_INTERACTION)
15 | ckpt_individual = torch.load(model_cfg.CHECKPOINT_INDIVIDUAL)
16 | ckpt.update(ckpt_individual)
17 | model.load_state_dict(ckpt, strict=True)
18 |
19 | return model
20 |
--------------------------------------------------------------------------------
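A minimal loading sketch (not part of the repository): it only uses get_config and load_DualMDM_model as defined above, and assumes the two checkpoint files referenced in configs/models/DualMDM.yaml have been downloaded.

    # Hedged sketch: build the dual model from its interaction + individual checkpoints.
    from in2in.utils.configs import get_config
    from in2in.models.dualmdm import load_DualMDM_model

    model_cfg = get_config("configs/models/DualMDM.yaml")
    model = load_DualMDM_model(model_cfg)   # an in2IN instance in "dual" mode
    model.eval()
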
/configs/models/DualMDM.yaml:
--------------------------------------------------------------------------------
1 | NAME: InterGen
2 | NUM_LAYERS: 8
3 | NUM_HEADS: 8
4 | DROPOUT: 0.1
5 | INPUT_DIM: 262
6 | LATENT_DIM: 1024
7 | FF_SIZE: 2048
8 | ACTIVATION: gelu
9 | CHECKPOINT_INTERACTION: checkpoints/in2IN/epoch=1999-step=352000.pt
10 | CHECKPOINT_INDIVIDUAL: checkpoints/DualMDM/epoch=1999-step=728000.pt
11 |
12 | DIFFUSION_STEPS: 1000
13 | BETA_SCHEDULER: cosine
14 | SAMPLER: uniform
15 |
16 | MOTION_REP: global
17 | FINETUNE: False
18 |
19 | TEXT_ENCODER: clip
20 | T_BAR: 700
21 |
22 | CONTROL: text
23 | STRATEGY: ddim50
24 |
25 | CFG_WEIGHT_INDIVIDUAL: 2.5
26 | CFG_WEIGHT_INTERACTION: 3.5
27 | W_FUNC: exp
28 | W_VALUE: 0.0085
--------------------------------------------------------------------------------
/configs/datasets.yaml:
--------------------------------------------------------------------------------
1 | interhuman:
2 | NAME: interhuman
3 | DATA_ROOT: ./data/
4 | MOTION_REP: global
5 | MODE: train
6 | CACHE: True
7 | EXTENDED: True
8 |
9 | interhuman_val:
10 | NAME: interhuman
11 | DATA_ROOT: ./data/
12 | MOTION_REP: global
13 | MODE: val
14 | CACHE: True
15 | EXTENDED: True
16 |
17 | interhuman_test:
18 | NAME: interhuman
19 | DATA_ROOT: ./data/
20 | MOTION_REP: global
21 | MODE: test
22 | CACHE: True
23 | EXTENDED: True
24 |
25 | humanml3d:
26 | NAME: humanml3d
27 | DATA_ROOT: ./data/HumanML3D/
28 | MOTION_REP: global
29 | MODE: train
30 | CACHE: True
31 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import pkg_resources
4 | from setuptools import setup, find_packages
5 |
6 | setup(
7 | name="in2IN",
8 | version="1.0",
9 | description="",
10 | author="Pablo Ruiz Ponce",
11 | packages=find_packages(include=['in2in', 'in2in.*']),
12 | install_requires=[
13 | 'numpy',
14 | 'tqdm',
15 | 'lightning',
16 | 'scipy',
17 | 'matplotlib',
18 | 'pillow',
19 | 'yacs',
20 | 'mmcv',
21 | 'opencv-python',
22 | 'tabulate',
23 | 'termcolor',
24 | 'smplx',
25 | 'torch',
26 | 'torchvision',
27 | 'torchaudio',
28 | 'pykeops',
29 | ],
30 | dependency_links=[
31 | 'git+https://github.com/openai/CLIP.git',
32 | ],
33 | package_data={'': ['*.npy']},
34 | )
--------------------------------------------------------------------------------
/in2in/utils/preprocess.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from .utils import *
3 |
4 | FPS = 30
5 |
6 | def load_motion(file_path, min_length, swap=False):
7 | """
8 |     Load a motion from the original dataset with all the information needed to convert it to the InterHuman format
9 | :param file_path: path to the motion file
10 | :param min_length: minimum length of the motion
11 | :param swap: swap the left and right side of the motion
12 | """
13 | try:
14 | motion = np.load(file_path).astype(np.float32)
15 | except:
16 | print("error: ", file_path)
17 | return None, None
18 |
19 | # Reshape motion
20 | motion1 = motion[:, :22 * 3]
21 | motion2 = motion[:, 62 * 3:62 * 3 + 21 * 6]
22 | motion = np.concatenate([motion1, motion2], axis=1)
23 |
24 |     # If the motion is too short, return None.
25 | if motion.shape[0] < min_length:
26 | return None, None
27 |
28 | # Swap
29 | if swap:
30 | motion_swap = swap_left_right(motion, 22)
31 | else:
32 | motion_swap = None
33 |
34 | return motion, motion_swap
35 |
36 | def load_motion_hml3d(pos_file_path, rot_file_path , min_length):
37 | """
38 |     Load a motion from the HumanML3D dataset with all the information needed to convert it to the InterHuman format
39 | :param pos_file_path: path to the position file
40 | :param rot_file_path: path to the rotation file
41 | """
42 |
43 | # Try to extract motions from the original dataset
44 | try:
45 | pos_motion = np.load(pos_file_path).astype(np.float32)
46 | rot_motion = np.load(rot_file_path).astype(np.float32)
47 | except:
48 | print("error: ", pos_motion)
49 | return None, None
50 |
51 |     # Convert position from (LENGTH, JOINTS, 3) to (LENGTH, JOINTS*3)
52 | pos_motion = pos_motion[:,:22]
53 | pos_motion = pos_motion.reshape(pos_motion.shape[0], -1)[:-1,:]
54 |
55 | # Extract relative rotations from HumanML3D representation
56 | rot_motion = rot_motion[:,4+(21*3)+(22*3):4+(21*3)+(22*3)+(21*6)].reshape(rot_motion.shape[0], -1)
57 |
58 | # Concatenate position and rotation
59 | motion = np.concatenate([pos_motion, rot_motion], axis=1)
60 |
61 | if motion.shape[0] < min_length:
62 | return None, None
63 |
64 | return motion, None
--------------------------------------------------------------------------------
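A usage sketch for load_motion (the .npy path below is hypothetical; any raw InterHuman per-person motion file would do):

    # Hedged sketch: load one raw motion and its mirrored copy.
    from in2in.utils.preprocess import FPS, load_motion

    motion, motion_swap = load_motion("data/motions/0001/person1.npy",  # hypothetical path
                                      min_length=FPS * 2, swap=True)
    if motion is not None:
        print(motion.shape)   # (T, 22*3 + 21*6): joint positions + 6D joint rotations
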
/in2in/utils/configs.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Dict
3 | from yacs.config import CfgNode as CN
4 |
5 |
6 | def to_lower(x: Dict) -> Dict:
7 | """
8 | Convert all dictionary keys to lowercase
9 | Args:
10 | x (dict): Input dictionary
11 | Returns:
12 | dict: Output dictionary with all keys converted to lowercase
13 | """
14 | return {k.lower(): v for k, v in x.items()}
15 |
16 | _C = CN(new_allowed=True)
17 |
18 | def default_config() -> CN:
19 | """
20 | Get a yacs CfgNode object with the default config values.
21 | """
22 | # Return a clone so that the defaults will not be altered
23 | # This is for the "local variable" use pattern
24 | return _C.clone()
25 |
26 | def get_config(config_file: str, merge: bool = True) -> CN:
27 | """
28 | Read a config file and optionally merge it with the default config file.
29 | Args:
30 | config_file (str): Path to config file.
31 | merge (bool): Whether to merge with the default config or not.
32 | Returns:
33 | CfgNode: Config as a yacs CfgNode object.
34 | """
35 | if merge:
36 | cfg = default_config()
37 | else:
38 | cfg = CN(new_allowed=True)
39 | cfg.merge_from_file(config_file)
40 | cfg.freeze()
41 | return cfg
42 |
43 | def get_config_model(config_file: str, merge: bool = True, w_func:str = None, w_value:float = None) -> CN:
44 | """
45 |     Read a model config file, optionally merge it with the default config, and override the DualMDM weight schedule (w_func, w_value).
46 | Args:
47 | config_file (str): Path to config file.
48 | merge (bool): Whether to merge with the default config or not.
49 | Returns:
50 | CfgNode: Config as a yacs CfgNode object.
51 | """
52 | if merge:
53 | cfg = default_config()
54 | else:
55 | cfg = CN(new_allowed=True)
56 | cfg.merge_from_file(config_file)
57 |
58 |
59 | cfg.W_FUNC = w_func
60 | cfg.W_VALUE = w_value
61 | cfg.freeze()
62 | return cfg
63 |
64 | def dataset_config() -> CN:
65 | """
66 | Get dataset config file
67 | Returns:
68 | CfgNode: Dataset config as a yacs CfgNode object.
69 | """
70 | cfg = CN(new_allowed=True)
71 | config_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'datasets.yaml')
72 | cfg.merge_from_file(config_file)
73 | cfg.freeze()
74 | return cfg
75 |
76 |
--------------------------------------------------------------------------------
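A short usage sketch of the helpers above; the YAML paths refer to the configs/ directory of this repository, and the w_func/w_value override mirrors the DualMDM settings:

    # Hedged sketch: read a model config, and another one with an overridden weight schedule.
    from in2in.utils.configs import get_config, get_config_model

    model_cfg = get_config("configs/models/in2IN.yaml")
    print(model_cfg.NAME, model_cfg.LATENT_DIM)          # InterGen 1024

    dual_cfg = get_config_model("configs/models/DualMDM.yaml", w_func="exp", w_value=0.0085)
    print(dual_cfg.W_FUNC, dual_cfg.W_VALUE)             # exp 0.0085
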
/in2in/models/utils/blocks.py:
--------------------------------------------------------------------------------
1 | from .layers import *
2 |
3 | class TransformerBlock(nn.Module):
4 | def __init__(self,
5 | latent_dim=512,
6 | num_heads=8,
7 | ff_size=1024,
8 | dropout=0.,
9 | cond_abl=False,
10 | **kargs):
11 | super().__init__()
12 | self.latent_dim = latent_dim
13 | self.num_heads = num_heads
14 | self.dropout = dropout
15 | self.cond_abl = cond_abl
16 |
17 | self.sa_block = VanillaSelfAttention(latent_dim, num_heads, dropout)
18 | self.ca_block = VanillaCrossAttention(latent_dim, latent_dim, num_heads, dropout, latent_dim)
19 | self.ffn = FFN(latent_dim, ff_size, dropout, latent_dim)
20 |
21 | def forward(self, x, y, emb=None, key_padding_mask=None):
22 | h1 = self.sa_block(x, emb, key_padding_mask)
23 | h1 = h1 + x
24 | h2 = self.ca_block(h1, y, emb, key_padding_mask)
25 | h2 = h2 + h1
26 | out = self.ffn(h2, emb)
27 | out = out + h2
28 | return out
29 |
30 | class TransformerBlockDoubleCond(nn.Module):
31 | def __init__(self,
32 | mode,
33 | latent_dim=512,
34 | num_heads=8,
35 | ff_size=1024,
36 | dropout=0.,
37 | cond_abl=False,
38 | **kargs):
39 | super().__init__()
40 | self.latent_dim = latent_dim
41 | self.num_heads = num_heads
42 | self.dropout = dropout
43 | self.cond_abl = cond_abl
44 | self.mode = mode
45 | self.sa_block = VanillaSelfAttention(latent_dim, num_heads, dropout)
46 | self.ca_block = VanillaCrossAttention(latent_dim, latent_dim, num_heads, dropout, latent_dim)
47 | self.ffn = FFN(latent_dim, ff_size, dropout, latent_dim)
48 |
49 | def forward(self, x, y, emb=None, emb_interaction=None, key_padding_mask=None):
50 | h1 = self.sa_block(x, emb, key_padding_mask)
51 | h1 = h1 + x
52 |         # In the individual-only modes, we ignore the cross-attention layer
53 | if self.mode == "individual" or self.mode == "dual_individual":
54 | h2 = h1
55 | else:
56 | h2 = self.ca_block(h1, y, emb_interaction, key_padding_mask)
57 | h2 = h2 + h1
58 |
59 | # All modes have the FFN layer
60 | out = self.ffn(h2, emb)
61 | out = out + h2
62 | return out
63 |
64 |
--------------------------------------------------------------------------------
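A shape-contract sketch for TransformerBlock (toy sizes and random tensors, not part of the repository):

    # Hedged sketch: one denoiser block consumes the agent's own tokens (x), the other
    # agent's tokens (y) and a conditioning embedding (emb), and preserves the shape of x.
    import torch
    from in2in.models.utils.blocks import TransformerBlock

    block = TransformerBlock(latent_dim=64, num_heads=4, ff_size=128, dropout=0.1)
    x = torch.randn(2, 30, 64)     # (batch, frames, latent_dim)
    y = torch.randn(2, 30, 64)     # cross-attention keys/values from the other person
    emb = torch.randn(2, 64)       # timestep/text embedding consumed by AdaLN
    out = block(x, y, emb=emb)
    print(out.shape)               # torch.Size([2, 30, 64])
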
/in2in/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | import lightning.pytorch as pl
2 | import torch
3 | from .interhuman import InterHuman
4 | from .humanml3d import HumanML3D
5 |
6 | class DataModuleHML3D(pl.LightningDataModule):
7 | def __init__(self, cfg, batch_size, num_workers):
8 | """
9 |         Initialize LightningDataModule for HumanML3D training
10 |         Args:
11 |             cfg (CfgNode): Dataset configuration as a yacs CfgNode.
12 |             batch_size (int), num_workers (int): Dataloader batch size and worker count.
13 | """
14 | super().__init__()
15 | self.cfg = cfg
16 | self.batch_size = batch_size
17 | self.num_workers = num_workers
18 |
19 | def setup(self, stage = None):
20 | """
21 | Create train and validation datasets
22 | """
23 | if self.cfg.NAME == "humanml3d":
24 | self.train_dataset = HumanML3D(self.cfg)
25 | else:
26 | raise NotImplementedError
27 |
28 | def train_dataloader(self):
29 | """
30 | Return train dataloader
31 | """
32 | return torch.utils.data.DataLoader(
33 | self.train_dataset,
34 | batch_size=self.batch_size,
35 | num_workers=self.num_workers,
36 | pin_memory=False,
37 | shuffle=True,
38 | drop_last=True,
39 | )
40 |
41 | class DataModule(pl.LightningDataModule):
42 | def __init__(self, cfg, batch_size, num_workers):
43 | """
44 |         Initialize LightningDataModule for InterHuman training
45 |         Args:
46 |             cfg (CfgNode): Dataset configuration as a yacs CfgNode.
47 |             batch_size (int), num_workers (int): Dataloader batch size and worker count.
48 | """
49 | super().__init__()
50 | self.cfg = cfg
51 | self.batch_size = batch_size
52 | self.num_workers = num_workers
53 |
54 | def setup(self, stage = None):
55 | """
56 | Create train and validation datasets
57 | """
58 | if self.cfg.NAME == "interhuman":
59 | self.train_dataset = InterHuman(self.cfg)
60 | else:
61 | raise NotImplementedError
62 |
63 | def train_dataloader(self):
64 | """
65 | Return train dataloader
66 | """
67 | return torch.utils.data.DataLoader(
68 | self.train_dataset,
69 | batch_size=self.batch_size,
70 | num_workers=self.num_workers,
71 | pin_memory=False,
72 | shuffle=True,
73 | drop_last=True,
74 | )
75 |
--------------------------------------------------------------------------------
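A usage sketch for the InterHuman DataModule (assumes the dataset has been prepared under ./data/ as configured in configs/datasets.yaml):

    # Hedged sketch: build the training dataloader through the LightningDataModule.
    from in2in.utils.configs import get_config
    from in2in.datasets import DataModule

    data_cfg = get_config("configs/datasets.yaml").interhuman
    datamodule = DataModule(data_cfg, batch_size=32, num_workers=4)
    datamodule.setup()
    train_loader = datamodule.train_dataloader()
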
/in2in/datasets/dataloader.py:
--------------------------------------------------------------------------------
1 | import random
2 | import numpy as np
3 |
4 | from functools import partial
5 | from typing import Optional, Union
6 | from mmcv.runner import get_dist_info
7 | from torch.utils.data import DataLoader
8 | from torch.utils.data.dataset import Dataset
9 |
10 |
11 | def build_dataloader(dataset: Dataset,
12 | samples_per_gpu: int,
13 | workers_per_gpu: int,
14 | num_gpus: Optional[int] = 1,
15 | shuffle: Optional[bool] = True,
16 | round_up: Optional[bool] = True,
17 | seed: Optional[Union[int, None]] = None,
18 | persistent_workers: Optional[bool] = True,
19 | **kwargs):
20 | """Build PyTorch DataLoader.
21 |
22 | In distributed training, each GPU/process has a dataloader.
23 | In non-distributed training, there is only one dataloader for all GPUs.
24 |
25 | Args:
26 | dataset (:obj:`Dataset`): A PyTorch dataset.
27 | samples_per_gpu (int): Number of training samples on each GPU, i.e.,
28 | batch size of each GPU.
29 | workers_per_gpu (int): How many subprocesses to use for data loading
30 | for each GPU.
31 | num_gpus (int, optional): Number of GPUs. Only used in non-distributed
32 | training.
33 |         seed (int, optional): Seed used for deterministic worker seeding. Default: None.
34 | shuffle (bool, optional): Whether to shuffle the data at every epoch.
35 | Default: True.
36 | round_up (bool, optional): Whether to round up the length of dataset by
37 | adding extra samples to make it evenly divisible. Default: True.
38 | persistent_workers (bool): If True, the data loader will not shutdown
39 | the worker processes after a dataset has been consumed once.
40 | This allows to maintain the workers Dataset instances alive.
41 | The argument also has effect in PyTorch>=1.7.0.
42 | Default: True
43 | kwargs: any keyword argument to be used to initialize DataLoader
44 |
45 | Returns:
46 | DataLoader: A PyTorch dataloader.
47 | """
48 | rank, world_size = get_dist_info()
49 | sampler = None
50 | batch_size = num_gpus * samples_per_gpu
51 | num_workers = num_gpus * workers_per_gpu
52 |
53 | init_fn = partial(
54 | worker_init_fn, num_workers=num_workers, rank=rank,
55 | seed=seed) if seed is not None else None
56 |
57 | data_loader = DataLoader(
58 | dataset,
59 | batch_size=batch_size,
60 | sampler=sampler,
61 | num_workers=num_workers,
62 | pin_memory=False,
63 | shuffle=shuffle,
64 | worker_init_fn=init_fn,
65 | persistent_workers=persistent_workers,
66 | **kwargs)
67 |
68 |
69 |
70 | return data_loader
71 |
72 |
73 | def worker_init_fn(worker_id: int, num_workers: int, rank: int, seed: int):
74 | """Init random seed for each worker."""
75 | # The seed of each worker equals to
76 | # num_worker * rank + worker_id + user_seed
77 | worker_seed = num_workers * rank + worker_id + seed
78 | np.random.seed(worker_seed)
79 | random.seed(worker_seed)
80 |
--------------------------------------------------------------------------------
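A small usage sketch of build_dataloader with a stand-in dataset (mmcv, listed in requirements.txt, must be installed for get_dist_info):

    # Hedged sketch: the per-GPU batch size and worker count are multiplied by num_gpus.
    import torch
    from torch.utils.data import TensorDataset
    from in2in.datasets.dataloader import build_dataloader

    dataset = TensorDataset(torch.randn(128, 10))        # placeholder dataset
    loader = build_dataloader(dataset, samples_per_gpu=32, workers_per_gpu=2,
                              num_gpus=1, shuffle=True, seed=0)
    (batch,) = next(iter(loader))
    print(batch.shape)                                   # torch.Size([32, 10])
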
/in2in/models/utils/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from torch import nn
4 |
5 |
6 | class CosineWarmupScheduler(torch.optim.lr_scheduler._LRScheduler):
7 | def __init__(self, optimizer, warmup, max_iters, verbose=False):
8 | self.warmup = warmup
9 | self.max_num_iters = max_iters
10 | super().__init__(optimizer, verbose=verbose)
11 |
12 | def get_lr(self):
13 | lr_factor = self.get_lr_factor(epoch=self.last_epoch)
14 | return [base_lr * lr_factor for base_lr in self.base_lrs]
15 |
16 | def get_lr_factor(self, epoch):
17 | lr_factor = 0.5 * (1 + np.cos(np.pi * epoch / self.max_num_iters))
18 | if epoch <= self.warmup:
19 | lr_factor *= (epoch+1) * 1.0 / self.warmup
20 | return lr_factor
21 |
22 |
23 |
24 | class PositionalEncoding(nn.Module):
25 | def __init__(self, d_model, dropout=0.0, max_len=5000):
26 | super(PositionalEncoding, self).__init__()
27 | self.dropout = nn.Dropout(p=dropout)
28 |
29 | pe = torch.zeros(max_len, d_model)
30 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
31 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
32 | pe[:, 0::2] = torch.sin(position * div_term)
33 | pe[:, 1::2] = torch.cos(position * div_term)
34 |
35 | self.register_buffer('pe', pe)
36 |
37 | def forward(self, x):
38 | x = x + self.pe[:x.shape[1], :].unsqueeze(0)
39 | return self.dropout(x)
40 |
41 | class TimestepEmbedder(nn.Module):
42 | def __init__(self, latent_dim, sequence_pos_encoder):
43 | super().__init__()
44 | self.latent_dim = latent_dim
45 | self.sequence_pos_encoder = sequence_pos_encoder
46 |
47 | time_embed_dim = self.latent_dim
48 | self.time_embed = nn.Sequential(
49 | nn.Linear(self.latent_dim, time_embed_dim),
50 | nn.SiLU(),
51 | nn.Linear(time_embed_dim, time_embed_dim),
52 | )
53 |
54 | def forward(self, timesteps):
55 | return self.time_embed(self.sequence_pos_encoder.pe[timesteps])
56 |
57 |
58 | class IdentityEmbedder(nn.Module):
59 | def __init__(self, latent_dim, sequence_pos_encoder):
60 | super().__init__()
61 | self.latent_dim = latent_dim
62 | self.sequence_pos_encoder = sequence_pos_encoder
63 |
64 | time_embed_dim = self.latent_dim
65 | self.time_embed = nn.Sequential(
66 | nn.Linear(self.latent_dim, time_embed_dim),
67 | nn.SiLU(),
68 | nn.Linear(time_embed_dim, time_embed_dim),
69 | )
70 |
71 | def forward(self, timesteps):
72 | return self.time_embed(self.sequence_pos_encoder.pe[timesteps]).unsqueeze(1)
73 |
74 |
75 | def set_requires_grad(nets, requires_grad=False):
76 | """Set requies_grad for all the networks.
77 |
78 | Args:
79 | nets (nn.Module | list[nn.Module]): A list of networks or a single
80 | network.
81 | requires_grad (bool): Whether the networks require gradients or not
82 | """
83 | if not isinstance(nets, list):
84 | nets = [nets]
85 | for net in nets:
86 | if net is not None:
87 | for param in net.parameters():
88 | param.requires_grad = requires_grad
89 |
90 |
91 | def zero_module(module):
92 | """
93 | Zero out the parameters of a module and return it.
94 | """
95 | for p in module.parameters():
96 | p.detach().zero_()
97 | return module
98 |
--------------------------------------------------------------------------------
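A sketch of the CosineWarmupScheduler behaviour with a toy optimizer (not part of the repository): the learning rate ramps up linearly for `warmup` steps and then follows a half-cosine down to zero at `max_iters`.

    # Hedged sketch: warmup followed by cosine decay.
    import torch
    from in2in.models.utils.utils import CosineWarmupScheduler

    params = [torch.nn.Parameter(torch.zeros(1))]
    optimizer = torch.optim.AdamW(params, lr=1e-4)
    scheduler = CosineWarmupScheduler(optimizer, warmup=10, max_iters=100)

    for step in range(100):
        optimizer.step()
        scheduler.step()
        # scheduler.get_last_lr()[0] rises for 10 steps, then decays towards 0
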
/in2in/models/utils/layers.py:
--------------------------------------------------------------------------------
1 | from .utils import *
2 |
3 | class AdaLN(nn.Module):
4 | def __init__(self, latent_dim, embed_dim=None):
5 | super().__init__()
6 | if embed_dim is None:
7 | embed_dim = latent_dim
8 | self.emb_layers = nn.Sequential(
9 | # nn.Linear(embed_dim, latent_dim, bias=True),
10 | nn.SiLU(),
11 | zero_module(nn.Linear(embed_dim, 2 * latent_dim, bias=True)),
12 | )
13 | self.norm = nn.LayerNorm(latent_dim, elementwise_affine=False, eps=1e-6)
14 |
15 | def forward(self, h, emb):
16 | """
17 | h: B, T, D
18 | emb: B, D
19 | """
20 | # B, 1, 2D
21 | emb_out = self.emb_layers(emb)
22 | # scale: B, 1, D / shift: B, 1, D
23 | scale, shift = torch.chunk(emb_out, 2, dim=-1)
24 | h = self.norm(h) * (1 + scale[:, None]) + shift[:, None]
25 | return h
26 |
27 |
28 | class VanillaSelfAttention(nn.Module):
29 | def __init__(self, latent_dim, num_head, dropout, embed_dim=None):
30 | super().__init__()
31 | self.num_head = num_head
32 | self.norm = AdaLN(latent_dim, embed_dim)
33 | self.attention = nn.MultiheadAttention(latent_dim, num_head, dropout=dropout, batch_first=True,
34 | add_zero_attn=True)
35 |
36 | def forward(self, x, emb, key_padding_mask=None):
37 | """
38 | x: B, T, D
39 | """
40 | x_norm = self.norm(x, emb)
41 | y = self.attention(x_norm, x_norm, x_norm,
42 | attn_mask=None,
43 | key_padding_mask=key_padding_mask,
44 | need_weights=False)[0]
45 | return y
46 |
47 |
48 | class VanillaCrossAttention(nn.Module):
49 | def __init__(self, latent_dim, xf_latent_dim, num_head, dropout, embed_dim=None):
50 | super().__init__()
51 | self.num_head = num_head
52 | self.norm = AdaLN(latent_dim, embed_dim)
53 | self.xf_norm = AdaLN(xf_latent_dim, embed_dim)
54 | self.attention = nn.MultiheadAttention(latent_dim, num_head, kdim=xf_latent_dim, vdim=xf_latent_dim,
55 | dropout=dropout, batch_first=True, add_zero_attn=True)
56 |
57 | def forward(self, x, xf, emb, key_padding_mask=None):
58 | """
59 | x: B, T, D
60 | xf: B, N, L
61 | """
62 | x_norm = self.norm(x, emb)
63 | xf_norm = self.xf_norm(xf, emb)
64 | y = self.attention(x_norm, xf_norm, xf_norm,
65 | attn_mask=None,
66 | key_padding_mask=key_padding_mask,
67 | need_weights=False)[0]
68 | return y
69 |
70 |
71 | class FFN(nn.Module):
72 | def __init__(self, latent_dim, ffn_dim, dropout, embed_dim=None):
73 | super().__init__()
74 | self.norm = AdaLN(latent_dim, embed_dim)
75 | self.linear1 = nn.Linear(latent_dim, ffn_dim, bias=True)
76 | self.linear2 = zero_module(nn.Linear(ffn_dim, latent_dim, bias=True))
77 | self.activation = nn.GELU()
78 | self.dropout = nn.Dropout(dropout)
79 |
80 | def forward(self, x, emb=None):
81 | if emb is not None:
82 | x_norm = self.norm(x, emb)
83 | else:
84 | x_norm = x
85 | y = self.linear2(self.dropout(self.activation(self.linear1(x_norm))))
86 | return y
87 |
88 |
89 | class FinalLayer(nn.Module):
90 | def __init__(self, latent_dim, out_dim):
91 | super().__init__()
92 | self.linear = zero_module(nn.Linear(latent_dim, out_dim, bias=True))
93 |
94 | def forward(self, x):
95 | x = self.linear(x)
96 | return x
--------------------------------------------------------------------------------
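A shape sketch for AdaLN: it applies a parameter-free LayerNorm followed by a conditioning-dependent scale and shift; because the projection is zero-initialised, it starts out as a plain LayerNorm.

    # Hedged sketch with toy sizes.
    import torch
    from in2in.models.utils.layers import AdaLN

    ada = AdaLN(latent_dim=64, embed_dim=64)
    h = torch.randn(2, 30, 64)    # (batch, frames, latent_dim)
    emb = torch.randn(2, 64)      # conditioning embedding
    print(ada(h, emb).shape)      # torch.Size([2, 30, 64])
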
/in2in/utils/plot.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 | import matplotlib
4 | import matplotlib.pyplot as plt
5 | import mpl_toolkits.mplot3d.axes3d as p3
6 |
7 | from mpl_toolkits.mplot3d import Axes3D
8 | from mpl_toolkits.mplot3d.art3d import Poly3DCollection
9 | from matplotlib.animation import FuncAnimation, FFMpegFileWriter
10 | from tqdm import tqdm
11 |
12 |
13 | def plot_3d_motion(save_path, kinematic_tree, mp_joints, title, figsize=(10, 10), fps=120, radius=6, mode='interaction'):
14 | """
15 | Function to plot an interaction between two agents in 3D in matplotlib
16 | :param save_path: path to save the animation
17 | :param kinematic_tree: kinematic tree of the motion
18 | :param mp_joints: list of motion data for each agent
19 | :param title: title of the plot
20 | :param figsize: size of the figure
21 | :param fps: frames per second of the animation
22 | :param radius: radius of the plot
23 | :param mode: mode of the plot
24 | """
25 | matplotlib.use('Agg')
26 |
27 | # Define initial limits of the plot
28 | def init():
29 | ax.set_xlim3d([-radius / 4, radius / 4])
30 | ax.set_ylim3d([0, radius / 2])
31 | ax.set_zlim3d([0, radius / 2])
32 |         ax.grid(False)
33 |
34 |     # Function to plot the floor plane in the animation
35 | def plot_xzPlane(minx, maxx, miny, minz, maxz):
36 | verts = [
37 | [minx, miny, minz],
38 | [minx, miny, maxz],
39 | [maxx, miny, maxz],
40 | [maxx, miny, minz]
41 | ]
42 | xz_plane = Poly3DCollection([verts])
43 | xz_plane.set_facecolor((0.5, 0.5, 0.5, 0.5))
44 | ax.add_collection3d(xz_plane)
45 |
46 |
47 | # Create the figure and axis
48 | fig = plt.figure(figsize=figsize)
49 | ax = fig.add_subplot(111, projection='3d')
50 | init()
51 |
52 | # Offsets and colors
53 | mp_offset = list(range(-len(mp_joints)//2, len(mp_joints)//2, 1))
54 | colors = ['red', 'blue', 'black', 'red', 'blue',
55 | 'darkblue', 'darkblue', 'darkblue', 'darkblue', 'darkblue',
56 | 'darkred', 'darkred', 'darkred', 'darkred', 'darkred']
57 | mp_colors = [[colors[i]] * 15 for i in range(len(mp_offset))]
58 |
59 | # Store the data for each agent
60 | mp_data = []
61 | for i,joints in enumerate(mp_joints):
62 |
63 | data = joints.copy().reshape(len(joints), -1, 3)
64 |
65 | MINS = data.min(axis=0).min(axis=0)
66 | MAXS = data.max(axis=0).max(axis=0)
67 |
68 | height_offset = MINS[1]
69 | data[:, :, 1] -= height_offset
70 | trajec = data[:, 0, [0, 2]]
71 |
72 | mp_data.append({"joints":data,
73 | "MINS":MINS,
74 | "MAXS":MAXS,
75 | "trajec":trajec, })
76 |
77 | def update(index):
78 | """
79 | Update function for the matplotlib animation
80 | :param index: index of the frame
81 | """
82 | # Update the progress bar
83 | bar.update(1)
84 |
85 | # Clear the axis and setting initial parameters
86 | ax.clear()
87 | plt.axis('off')
88 | ax.view_init(elev=120, azim=-90)
89 | ax.dist = 7.5
90 | ax.set_xticklabels([])
91 | ax.set_yticklabels([])
92 | ax.set_zticklabels([])
93 |
94 | # Plot the floor
95 | plot_xzPlane(-3, 3, 0, -3, 3)
96 |
97 | # Plot each of the persons in the motion
98 | for pid,data in enumerate(mp_data):
99 | for i, (chain, color) in enumerate(zip(kinematic_tree, mp_colors[pid])):
100 | linewidth = 3.0
101 | ax.plot3D(data["joints"][index, chain, 0],
102 | data["joints"][index, chain, 1],
103 | data["joints"][index, chain, 2],
104 | linewidth=linewidth,
105 | color=color,
106 | alpha=1)
107 |
108 | # Generate animation
109 | frame_number = min([data.shape[0] for data in mp_joints])
110 | bar = tqdm(total=frame_number+1)
111 | ani = FuncAnimation(fig, update, frames=frame_number, interval=1000 / fps, repeat=False)
112 | ani.save(save_path, fps=fps)
113 | plt.close()
114 |
115 |
--------------------------------------------------------------------------------
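A usage sketch for plot_3d_motion with random joints (saving to .mp4 assumes ffmpeg is available to matplotlib's animation writer):

    # Hedged sketch: render a two-person sequence of 60 frames at 30 fps.
    import numpy as np
    from in2in.utils.plot import plot_3d_motion
    from in2in.utils.paramUtil import HML_KINEMATIC_CHAIN

    joints_p1 = np.random.rand(60, 22, 3)    # placeholder (frames, joints, xyz)
    joints_p2 = np.random.rand(60, 22, 3)
    plot_3d_motion("demo.mp4", HML_KINEMATIC_CHAIN, [joints_p1, joints_p2],
                   title="demo", fps=30)
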
/in2in/scripts/eval/DualMDM.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append(sys.path[0]+r"/../../../")
3 |
4 | import torch
5 | import random
6 | import argparse
7 | import numpy as np
8 |
9 | from in2in.utils.configs import get_config
10 | from in2in.models.dualmdm import load_DualMDM_model
11 | from in2in.utils.metrics import calculate_wasserstein
12 | from in2in.evaluation.utils import EvaluatorModelWrapper, get_dataset_motion_loader, get_motion_loader_DualMDM
13 |
14 | # Random Seed Configuration
15 | np.random.seed(0)
16 | random.seed(0)
17 |
18 | def calculate_individual_diversity(generated_motion_embeddings):
19 | s_int = generated_motion_embeddings[:num_times]
20 | s_ind = generated_motion_embeddings[num_times:]
21 |
22 | diff, corrs_1_to_2, corrs_2_to_1 = calculate_wasserstein(s_int, s_ind, max_iters=500, verbose=True)
23 | return diff
24 |
25 |
26 | def evaluation():
27 | metrics = {
28 | 'Individual Diversity': [],
29 | }
30 |
31 | for i in range(replication_times):
32 | print("Replication: ", i)
33 |
34 | individual_diversity = []
35 | eval_motion_loader = eval_motion_loader_getter()
36 |
37 | for bidx ,eval_data in enumerate(eval_motion_loader):
38 |
39 | generated_motions1, generated_motions2, motion1, motion2, motion_lens, text, text_individual1, text_individual2 = eval_data
40 |
41 |             # This indexing is needed for the motion embeddings to be computed correctly
42 | generated_motions1 = generated_motions1[0]
43 | generated_motions2 = generated_motions2[0]
44 | motion1 = motion1[0]
45 | motion2 = motion2[0]
46 | motion_lens = motion_lens[0]
47 |
48 | generated_motion_embeddings = eval_wrapper.get_motion_embeddings([
49 | "name",
50 | text,
51 | generated_motions1,
52 | generated_motions2,
53 | motion_lens,
54 | text_individual1,
55 | text_individual2
56 | ])
57 |
58 | motion_len = motion_lens[0].item()
59 | generated_motions1 = generated_motions1[:,:motion_len,:]
60 | generated_motions2 = generated_motions2[:,:motion_len,:]
61 | motion1 = motion1[:,:motion_len,:]
62 | motion2 = motion2[:,:motion_len,:]
63 |
64 | individual_diversity.append(calculate_individual_diversity(generated_motion_embeddings).cpu().numpy().tolist())
65 |
66 | metrics['Individual Diversity'].append(np.mean(individual_diversity))
67 | print("Individual Diversity: ", np.mean(individual_diversity))
68 |
69 | print("---- Final Metrics ----")
70 | print("Individual Diversity: ", np.mean(metrics['Individual Diversity']), ", std: ",np.std(metrics['Individual Diversity']))
71 |
72 | if __name__ == '__main__':
73 |
74 | # Configuration values
75 | num_samples = 100
76 | num_times = 32
77 | replication_times = 5
78 | batch_size = 1
79 |
80 | # Create the parser
81 |     parser = argparse.ArgumentParser(description="DualMDM individual diversity evaluation")
82 |
83 | # Add optional arguments
84 | parser.add_argument('--model', type=str, required=True, help='Model Configuration file')
85 | parser.add_argument('--evaluator', type=str, required=True, help='Evaluator Configuration file')
86 | parser.add_argument('--device', type=int, default=0, help='GPU device id')
87 |
88 | # Parse the arguments
89 | args = parser.parse_args()
90 |
91 | # Loading configuration files
92 | data_cfg = get_config("configs/datasets.yaml").interhuman_test
93 | model_cfg = get_config(args.model)
94 | evalmodel_cfg = get_config(args.evaluator)
95 |
96 | # Cuda configuration
97 | device = torch.device('cuda:%d' % args.device if torch.cuda.is_available() else 'cpu')
98 | torch.cuda.set_device(args.device)
99 |
100 | # Build and Load Model
101 | model = load_DualMDM_model(model_cfg)
102 |
103 | # Get Datasets and DataLoaders
104 | gt_loader, gt_dataset = get_dataset_motion_loader(data_cfg, batch_size, num_samples=num_samples)
105 | eval_motion_loader_getter = lambda: get_motion_loader_DualMDM(batch_size, model, gt_dataset, device, num_times)
106 |
107 | # Evaluator Model
108 | eval_wrapper = EvaluatorModelWrapper(evalmodel_cfg, device)
109 | evaluation()
110 |
--------------------------------------------------------------------------------
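An example invocation (run from the repository root; it assumes the InterHuman test data under ./data/ and the checkpoints referenced in the two config files are in place):

    python in2in/scripts/eval/DualMDM.py --model configs/models/DualMDM.yaml --evaluator configs/eval.yaml --device 0
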
/in2in/datasets/humanml3d.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import random
4 |
5 | from tqdm import tqdm
6 | from in2in.utils.utils import *
7 | from torch.utils import data
8 | from in2in.utils.preprocess import *
9 | from os.path import join as pjoin
10 |
11 | class HumanML3D(data.Dataset):
12 | """
13 | HumanML3D dataset
14 | """
15 | def __init__(self, opt):
16 |
17 | # Configuration variables
18 | self.opt = opt
19 | self.max_cond_length = 1
20 | self.min_cond_length = 1
21 | self.max_gt_length = 300
22 | self.min_gt_length = 15
23 | self.max_length = self.max_cond_length + self.max_gt_length -1
24 | self.min_length = self.min_cond_length + self.min_gt_length -1
25 | self.motion_rep = opt.MOTION_REP
26 | self.cache = opt.CACHE
27 |
28 | # Data structures
29 | self.motion_dict = {}
30 | self.data_list = []
31 | data_list = []
32 |
33 | # Load paths from the given split
34 | if self.opt.MODE == "train":
35 | try:
36 | data_list = open(os.path.join(opt.DATA_ROOT, "train.txt"), "r").readlines()
37 | except Exception as e:
38 | print(e)
39 | elif self.opt.MODE == "val":
40 | try:
41 | data_list = open(os.path.join(opt.DATA_ROOT, "val.txt"), "r").readlines()
42 | except Exception as e:
43 | print(e)
44 | elif self.opt.MODE == "test":
45 | try:
46 | data_list = open(os.path.join(opt.DATA_ROOT, "test.txt"), "r").readlines()
47 | except Exception as e:
48 | print(e)
49 |
50 |         # Shuffle paths
51 | random.shuffle(data_list)
52 |
53 | # Load data
54 | index = 0
55 | motion_path = pjoin(opt.DATA_ROOT, "interhuman/")
56 | for file in tqdm(os.listdir(motion_path)):
57 |
58 |             # Skip files outside the split; comment this out to use the whole dataset
59 | if file.split(".")[0]+"\n" not in data_list:
60 | continue
61 |
62 | motion_name = file.split(".")[0]
63 | motion_file_path = pjoin(motion_path, file)
64 | text_path = motion_file_path.replace("interhuman", "texts").replace("npy", "txt")
65 |
66 | # Load motion and text
67 | texts = [item.replace("\n", "") for item in open(text_path, "r").readlines()]
68 | motion1 = np.load(motion_file_path).astype(np.float32)
69 |
70 | # Check if the motion is too short
71 | if motion1.shape[0] < self.min_length:
72 | continue
73 |
74 | # Cache the motion if needed
75 | if self.cache:
76 | self.motion_dict[index] = motion1
77 | else:
78 | self.motion_dict[index] = motion_file_path
79 |
80 | self.data_list.append({
81 | "name": motion_name,
82 | "motion_id": index,
83 | "swap":False,
84 | "texts":texts
85 | })
86 |
87 | index += 1
88 |
89 | print("Total Dataset Size: ", len(self.data_list))
90 |
91 | def __len__(self):
92 | """
93 | Get the length of the dataset
94 | """
95 | return len(self.data_list)
96 |
97 | def __getitem__(self, item):
98 | """
99 | Get an item from the dataset
100 |         :param item: Index of the item to get
101 | """
102 |
103 | # Get the data from the dataset
104 | idx = item % self.__len__()
105 | data = self.data_list[idx]
106 | name = data["name"]
107 | motion_id = data["motion_id"]
108 |
109 | # Select a random text from the list
110 | text = random.choice(data["texts"]).strip().split('#')[0]
111 |
112 | # Load the motion
113 | if self.cache:
114 | full_motion1 = self.motion_dict[motion_id]
115 | else:
116 | file_path1 = self.motion_dict[motion_id]
117 | full_motion1 = np.load(file_path1).astype(np.float32)
118 |
119 |         # Get the motion length and select a random segment
120 | length = full_motion1.shape[0]
121 | if length > self.max_length:
122 | idx = random.choice(list(range(0, length - self.max_gt_length, 1)))
123 | gt_length = self.max_gt_length
124 | motion1 = full_motion1[idx:idx + gt_length]
125 | else:
126 | idx = 0
127 | gt_length = min(length, self.max_gt_length )
128 | motion1 = full_motion1[idx:idx + gt_length]
129 |
130 | # Check if the motion is too short and pad it
131 | gt_motion1 = motion1
132 | gt_length = len(gt_motion1)
133 | if gt_length < self.max_gt_length:
134 | padding_len = self.max_gt_length - gt_length
135 | D = gt_motion1.shape[1]
136 | padding_zeros = np.zeros((padding_len, D))
137 | gt_motion1 = np.concatenate((gt_motion1, padding_zeros), axis=0)
138 |
139 | # Return the data
140 | return name, text, gt_motion1, gt_length
141 |
142 |
--------------------------------------------------------------------------------
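A usage sketch for the dataset class (assumes the HumanML3D data has been prepared under ./data/HumanML3D/ with the layout expected above):

    # Hedged sketch: every item is (name, text, motion padded to 300 frames, true length).
    from in2in.utils.configs import get_config
    from in2in.datasets.humanml3d import HumanML3D

    cfg = get_config("configs/datasets.yaml").humanml3d
    dataset = HumanML3D(cfg)
    name, text, motion, length = dataset[0]
    print(name, text, motion.shape, length)   # motion.shape == (300, D)
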
/in2in/models/utils/cfg_sampler.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 |
5 | class ClassifierFreeSampleModel(nn.Module):
6 |
7 | def __init__(self, model, cfg_scale):
8 | super().__init__()
9 | self.model = model # model is the actual model to run
10 | self.s = cfg_scale
11 |
12 | def forward(self, x, timesteps, cond=None, mask=None):
13 | B, T, D = x.shape
14 |
15 | x_combined = torch.cat([x, x], dim=0)
16 | timesteps_combined = torch.cat([timesteps, timesteps], dim=0)
17 | if cond is not None:
18 | cond = torch.cat([cond, torch.zeros_like(cond)], dim=0)
19 | if mask is not None:
20 | mask = torch.cat([mask, mask], dim=0)
21 |
22 | out = self.model(x_combined, timesteps_combined, cond=cond, mask=mask)
23 |
24 | out_cond = out[:B]
25 | out_uncond = out[B:]
26 |
27 | cfg_out = self.s * out_cond + (1-self.s) * out_uncond
28 | return cfg_out
29 |
30 |
31 | class ClassifierFreeSampleModelMultiple(nn.Module):
32 | def __init__(self, model, cfg_scale, cfg_scale_interaction, cfg_scale_individuals):
33 | super().__init__()
34 | self.model = model
35 | self.s = cfg_scale
36 | self.s_interaction = cfg_scale_interaction
37 | self.s_individuals = cfg_scale_individuals
38 |
39 | def forward(self, x, timesteps, cond=None, mask=None):
40 | B, T, D = x.shape
41 |
42 | x_combined = torch.cat([x, x, x, x], dim=0)
43 | timesteps_combined = torch.cat([timesteps, timesteps, timesteps, timesteps], dim=0)
44 | if cond is not None:
45 |
46 | cond_full = cond
47 |
48 | cond_interaction = torch.zeros_like(cond)
49 | cond_interaction[:,:768] = cond[:,:768]
50 |
51 | cond_individuals = torch.zeros_like(cond)
52 | cond_individuals[:,768:] = cond[:,768:]
53 |
54 | cond = torch.cat([cond_full, cond_interaction, cond_individuals, torch.zeros_like(cond)], dim=0)
55 | if mask is not None:
56 | mask = torch.cat([mask, mask, mask, mask], dim=0)
57 |
58 | out = self.model(x_combined, timesteps_combined, cond=cond, mask=mask)
59 |
60 | out_cond = out[:B]
61 | out_cond_interaction = out[B:B*2]
62 | out_cond_individuals = out[B*2:B*3]
63 | out_uncond = out[B*3:]
64 |
65 | cfg_out = (self.s * out_cond) + (self.s_interaction * out_cond_interaction) + (self.s_individuals * out_cond_individuals) + ((1-(self.s+self.s_interaction+self.s_individuals)) * out_uncond)
66 | return cfg_out
67 |
68 |
69 | class ClassifierFreeSampleDualMDM(nn.Module):
70 |
71 | def __init__(self, m_individual, m_interaction, s_individual, s_interaction, s_composition_func, s_composition_value):
72 | super().__init__()
73 | self.m_individual = m_individual
74 | self.m_interaction = m_interaction
75 | self.s_individual = s_individual
76 | self.s_interaction = s_interaction
77 | self.s_composition = self.weight(s_composition_func, s_composition_value)
78 |
79 |
80 | def weight(self, func, value):
81 | print(f"Diffusion Weight Scheduler func: {func}, value: {value}")
82 |
83 | if func == "exp":
84 | return lambda x: np.exp(-value * (1000 - x))[0]
85 | elif func == "exp-inv":
86 | return lambda x: 1 - np.exp(-value * (1000 - x))[0]
87 | elif func == "lin":
88 | return lambda x: 1 - ((1000 - x) / 1000)[0]
89 | elif func == "const":
90 | return lambda x: value
91 | else:
92 | raise ValueError("Unknown function")
93 |
94 | def forward(self, x, timesteps, cond=None, mask=None):
95 | B, T, D = x.shape
96 |
97 | x_combined = torch.cat([x, x], dim=0)
98 | timesteps_combined = torch.cat([timesteps, timesteps], dim=0)
99 |
100 | if cond is not None:
101 | cond_combined = torch.cat([cond, torch.zeros_like(cond)], dim=0)
102 |
103 | if mask is not None:
104 | mask_combined = torch.cat([mask, mask], dim=0)
105 | else:
106 | mask_combined = None
107 |
108 | out_interaction = self.m_interaction(x_combined, timesteps_combined, cond=cond_combined, mask=mask_combined)
109 | out_individual = self.m_individual(x_combined, timesteps_combined, cond=cond_combined, mask=mask_combined)
110 |
111 | out_interaction_cond = out_interaction[:B]
112 | out_interaction_uncond = out_interaction[B:]
113 |
114 | out_individual_cond = out_individual[:B]
115 | out_individual_uncond = out_individual[B:]
116 |
117 | cfg_out_interaction = (out_interaction_uncond + self.s_interaction * (out_interaction_cond - out_interaction_uncond))
118 | cfg_out_individual = (out_individual_uncond + self.s_individual * (out_individual_cond - out_individual_uncond))
119 |
120 | w = self.s_composition(timesteps.cpu().numpy())
121 | cfg_out = cfg_out_interaction + w * (cfg_out_individual - cfg_out_interaction)
122 | return cfg_out
123 |
124 |
125 |
126 |
--------------------------------------------------------------------------------
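A worked sketch of the blending weight used by ClassifierFreeSampleDualMDM with the "exp" schedule from configs/models/DualMDM.yaml (W_VALUE = 0.0085): w(t) = exp(-value * (1000 - t)) is close to 1 at the noisiest timestep (t = 1000, the individual model dominates) and close to 0 near t = 0 (the interaction model dominates).

    # Hedged sketch: evaluate the composition weight at a few timesteps.
    import numpy as np

    value = 0.0085

    def w(t):
        return np.exp(-value * (1000 - t))[0]   # same array indexing as in weight()

    print(w(np.array([1000])))   # 1.0
    print(w(np.array([500])))    # ~0.014
    print(w(np.array([0])))      # ~2e-4
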
/in2in/scripts/infer.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append(sys.path[0] + r"/../../")
3 |
4 | from collections import OrderedDict
5 | import copy
6 | import os.path
7 | import argparse
8 | import torch
9 | from scipy.ndimage import gaussian_filter1d
10 | import numpy as np
11 |
12 | from lightning import LightningModule
13 | from in2in.utils.utils import MotionNormalizer, MotionNormalizerHML3D
14 | from in2in.utils.plot import plot_3d_motion
15 | from in2in.utils.paramUtil import HML_KINEMATIC_CHAIN
16 | from in2in.utils.configs import get_config
17 | from in2in.models.dualmdm import load_DualMDM_model
18 | from in2in.models.in2in import in2IN
19 |
20 | class LitGenModel(LightningModule):
21 | def __init__(self, model, cfg, save_folder, mode):
22 | super().__init__()
23 | # cfg init
24 | self.cfg = cfg
25 |
26 | self.automatic_optimization = False
27 | self.save_folder = save_folder
28 |
29 | # train model init
30 | self.model = model
31 |
32 | self.save_folder = os.path.join("results",save_folder)
33 | if not os.path.exists(self.save_folder):
34 | os.makedirs(self.save_folder)
35 |
36 | self.mode = mode
37 |
38 | if self.mode == "individual":
39 | self.normalizer = MotionNormalizerHML3D()
40 | else:
41 | self.normalizer = MotionNormalizer()
42 |
43 | def plot_t2m(self, mp_data, result_path, caption):
44 | mp_joint = []
45 |
46 | if self.mode == "individual":
47 | joint = mp_data.reshape(-1,22,3)
48 | mp_joint.append(joint)
49 | else:
50 | for i, data in enumerate(mp_data):
51 | if i == 0:
52 | joint = data[:,:22*3].reshape(-1,22,3)
53 | else:
54 | joint = data[:,:22*3].reshape(-1,22,3)
55 |
56 | mp_joint.append(joint)
57 |
58 | plot_3d_motion(result_path, HML_KINEMATIC_CHAIN, mp_joint, title=caption, fps=30)
59 |
60 |
61 | def generate_one_sample(self, prompt_interaction, prompt_individual1, prompt_individual2, name):
62 | self.model.eval()
63 | batch = OrderedDict({})
64 |
65 | batch["motion_lens"] = torch.zeros(1,1).long().cuda()
66 | batch["prompt_interaction"] = prompt_interaction
67 |
68 | if self.mode != "individual":
69 | batch["prompt_individual1"] = prompt_individual1
70 | batch["prompt_individual2"] = prompt_individual2
71 |
72 | window_size = 210
73 | motion_output = self.generate_loop(batch, window_size)
74 | result_path = f"{self.save_folder}/{name}.mp4"
75 |
76 | if self.mode == "individual":
77 | self.plot_t2m(motion_output,
78 | result_path,
79 | batch["prompt_interaction"])
80 | else:
81 | self.plot_t2m([motion_output[0], motion_output[1]],
82 | result_path,
83 | batch["prompt_interaction"])
84 |
85 | def generate_loop(self, batch, window_size):
86 | prompt_interaction = batch["prompt_interaction"]
87 |
88 | if self.mode != "individual":
89 | prompt_individual1 = batch["prompt_individual1"]
90 | prompt_individual2 = batch["prompt_individual2"]
91 |
92 | batch = copy.deepcopy(batch)
93 | batch["motion_lens"][:] = window_size
94 |
95 | batch["text"] = [prompt_interaction]
96 | if self.mode != "individual":
97 | batch["text_individual1"] = [prompt_individual1]
98 | batch["text_individual2"] = [prompt_individual2]
99 |
100 | batch = self.model.forward_test(batch)
101 |
102 | if self.mode == "individual":
103 | motion_output = batch["output"][0].reshape(-1, 262)
104 | motion_output = self.normalizer.backward(motion_output.cpu().detach().numpy())
105 | joints3d = motion_output[:,:22*3].reshape(-1,22,3)
106 | joints3d = gaussian_filter1d(joints3d, 1, axis=0, mode='nearest')
107 | return joints3d
108 |
109 | motion_output_both = batch["output"][0].reshape(batch["output"][0].shape[0], 2, -1)
110 | motion_output_both = self.normalizer.backward(motion_output_both.cpu().detach().numpy())
111 |
112 | sequences = [[], []]
113 | for j in range(2):
114 | motion_output = motion_output_both[:,j]
115 | joints3d = motion_output[:,:22*3].reshape(-1,22,3)
116 | joints3d = gaussian_filter1d(joints3d, 1, axis=0, mode='nearest')
117 | sequences[j].append(joints3d)
118 |
119 | sequences[0] = np.concatenate(sequences[0], axis=0)
120 | sequences[1] = np.concatenate(sequences[1], axis=0)
121 | return sequences
122 |
123 | if __name__ == '__main__':
124 |
125 | # Create the parser
126 |     parser = argparse.ArgumentParser(description="in2IN / DualMDM inference")
127 |
128 | # Add optional arguments
129 | parser.add_argument('--model', type=str, required=True, help='Model Configuration file')
130 | parser.add_argument('--infer', type=str, required=True, help='Infer Configuration file')
131 | parser.add_argument('--mode', type=str, required=True, help='Mode of the inference (individual, interaction, dual)')
132 | parser.add_argument('--out', type=str, required=True, help='Folder to save the results')
133 | parser.add_argument('--device', type=str, required=True, help='Device to run the model')
134 |
135 | parser.add_argument('--text_interaction', type=str, required=True, help='Interaction prompt')
136 | parser.add_argument('--text_individual1', type=str, required=False, help='Individual 1 prompt')
137 | parser.add_argument('--text_individual2', type=str, required=False, help='Individual 2 prompt')
138 | parser.add_argument('--name', type=str, required=True, help='Name of the output file')
139 |
140 | # Parse the arguments
141 | args = parser.parse_args()
142 |
143 | model_cfg = get_config(args.model)
144 | infer_cfg = get_config(args.infer)
145 |
146 | if args.mode == "dual":
147 | model = load_DualMDM_model(model_cfg)
148 | else:
149 | model = in2IN(model_cfg, args.mode)
150 | model.load_state_dict(torch.load(model_cfg.CHECKPOINT), strict=True)
151 |
152 | litmodel = LitGenModel(model, infer_cfg, args.out, mode=args.mode).to(torch.device("cuda:"+ args.device))
153 | litmodel.generate_one_sample(args.text_interaction, args.text_individual1, args.text_individual2, args.name)
154 |
155 |
--------------------------------------------------------------------------------
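An example invocation for interaction generation (run from the repository root on a CUDA machine; it assumes the checkpoint referenced in configs/models/in2IN.yaml is available, and the prompts are arbitrary examples):

    python in2in/scripts/infer.py --model configs/models/in2IN.yaml --infer configs/infer.yaml \
        --mode interaction --out demo --device 0 --name handshake \
        --text_interaction "two people greet each other with a handshake" \
        --text_individual1 "a person extends the right hand and shakes it" \
        --text_individual2 "a person extends the right hand and shakes it"
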
/in2in/models/in2in.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import clip
3 |
4 | from torch import nn
5 | from in2in.models.nets import in2INDiffusion
6 | from in2in.models.utils.utils import set_requires_grad
7 |
8 | class in2IN(nn.Module):
9 | def __init__(self, cfg, mode):
10 | super().__init__()
11 | self.cfg = cfg
12 | self.latent_dim = cfg.LATENT_DIM
13 |         # Mode can be "individual", "interaction", or "dual"
14 | self.mode = mode
15 |
16 | # DECODER (Denoiser)
17 | self.decoder = in2INDiffusion(cfg, mode, sampling_strategy=cfg.STRATEGY)
18 |
19 | # TEXT ENCODER (Trainable) - 1 FOR EACH MODEL #
20 | # INTERACTION
21 | if self.mode == "interaction" or self.mode == "dual":
22 | clipTransEncoderLayer_interaction = nn.TransformerEncoderLayer(
23 | d_model=768,
24 | nhead=8,
25 | dim_feedforward=2048,
26 | dropout=0.1,
27 | activation="gelu",
28 | batch_first=True)
29 |
30 | self.clipTransEncoder_interaction = nn.TransformerEncoder(
31 | clipTransEncoderLayer_interaction,
32 | num_layers=2)
33 |
34 | self.clip_ln_interaction = nn.LayerNorm(768)
35 |
36 | # INDIVIDUAL
37 | if self.mode == "individual" or self.mode == "dual":
38 | clipTransEncoderLayer_individual = nn.TransformerEncoderLayer(
39 | d_model=768,
40 | nhead=8,
41 | dim_feedforward=2048,
42 | dropout=0.1,
43 | activation="gelu",
44 | batch_first=True)
45 |
46 | self.clipTransEncoder_individual = nn.TransformerEncoder(
47 | clipTransEncoderLayer_individual,
48 | num_layers=2)
49 |
50 | self.clip_ln_individual = nn.LayerNorm(768)
51 |
52 |         # CLIP MODEL (not trainable)
53 | clip_model, _ = clip.load("ViT-L/14@336px", device="cpu", jit=False)
54 |
55 | self.token_embedding = clip_model.token_embedding
56 | self.clip_transformer = clip_model.transformer
57 | self.positional_embedding = clip_model.positional_embedding
58 | self.ln_final = clip_model.ln_final
59 | self.dtype = clip_model.dtype
60 |
61 | set_requires_grad(self.clip_transformer, False)
62 | set_requires_grad(self.token_embedding, False)
63 | set_requires_grad(self.ln_final, False)
64 |
65 | def compute_loss(self, batch):
66 | if self.mode == "dual":
67 | batch = self.text_process(batch, mode="interaction", out_name="cond_interaction")
68 | batch = self.text_process(batch, mode="interaction",text_name="text_individual1", out_name="cond_interaction_individual1")
69 | batch = self.text_process(batch, mode="interaction",text_name="text_individual2", out_name="cond_interaction_individual2")
70 | batch = self.text_process(batch, mode="individual",text_name="text_individual1", out_name="cond_individual_individual1")
71 | batch = self.text_process(batch, mode="individual",text_name="text_individual2", out_name="cond_individual_individual2")
72 | elif self.mode == "interaction" :
73 | batch = self.text_process(batch, mode="interaction", out_name="cond_interaction")
74 | batch = self.text_process(batch, mode="interaction",text_name="text_individual1", out_name="cond_interaction_individual1")
75 | batch = self.text_process(batch, mode="interaction",text_name="text_individual2", out_name="cond_interaction_individual2")
76 | elif self.mode == "individual":
77 | batch = self.text_process(batch, mode="individual", out_name="cond_individual_individual1")
78 |
79 | losses = self.decoder.compute_loss(batch)
80 | return losses["total"], losses
81 |
82 | def decode_motion(self, batch):
83 | batch.update(self.decoder(batch))
84 | return batch
85 |
86 | def forward(self, batch):
87 | return self.compute_loss(batch)
88 |
89 | def forward_test(self, batch):
90 | if self.mode == "dual":
91 | batch = self.text_process(batch, mode="interaction", out_name="cond_interaction")
92 | batch = self.text_process(batch, mode="interaction",text_name="text_individual1", out_name="cond_interaction_individual1")
93 | batch = self.text_process(batch, mode="interaction",text_name="text_individual2", out_name="cond_interaction_individual2")
94 | batch = self.text_process(batch, mode="individual",text_name="text_individual1", out_name="cond_individual_individual1")
95 | batch = self.text_process(batch, mode="individual",text_name="text_individual2", out_name="cond_individual_individual2")
96 | elif self.mode == "interaction" :
97 | batch = self.text_process(batch, mode="interaction", out_name="cond_interaction")
98 | batch = self.text_process(batch, mode="interaction",text_name="text_individual1", out_name="cond_interaction_individual1")
99 | batch = self.text_process(batch, mode="interaction",text_name="text_individual2", out_name="cond_interaction_individual2")
100 | elif self.mode == "individual":
101 | batch = self.text_process(batch, mode="individual", out_name="cond_individual_individual1")
102 |
103 | batch.update(self.decode_motion(batch))
104 | return batch
105 |
106 | def text_process(self, batch, mode, text_name="text", out_name="cond"):
107 | device = next(self.clip_transformer.parameters()).device
108 | raw_text = batch[text_name]
109 |
110 | with torch.no_grad():
111 |
112 | text = clip.tokenize(raw_text, truncate=True).to(device)
113 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
114 | pe_tokens = x + self.positional_embedding.type(self.dtype)
115 | x = pe_tokens.permute(1, 0, 2) # NLD -> LND
116 | x = self.clip_transformer(x)
117 | x = x.permute(1, 0, 2)
118 | clip_out = self.ln_final(x).type(self.dtype)
119 |
120 | if mode == "individual":
121 | out = self.clipTransEncoder_individual(clip_out)
122 | out = self.clip_ln_individual(out)
123 | elif mode == "interaction":
124 | out = self.clipTransEncoder_interaction(clip_out)
125 | out = self.clip_ln_interaction(out)
126 | else:
127 | raise ValueError("Mode not recognized")
128 |
129 | cond = out[torch.arange(x.shape[0]), text.argmax(dim=-1)]
130 | batch[out_name] = cond
131 |
132 | return batch
133 |
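# Note on text_process: the conditioning vector is taken from the transformer output at the
# end-of-text (EOT) token position. clip.tokenize assigns the EOT token the highest id in the
# vocabulary, so text.argmax(dim=-1) locates it for every prompt, mirroring CLIP's own pooling.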
--------------------------------------------------------------------------------
/in2in/utils/paramUtil.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | ### CONSTANTS ###
4 |
5 | HML_RAW_OFFSETS = np.array([[0,0,0],
6 | [1,0,0],
7 | [-1,0,0],
8 | [0,1,0],
9 | [0,-1,0],
10 | [0,-1,0],
11 | [0,1,0],
12 | [0,-1,0],
13 | [0,-1,0],
14 | [0,1,0],
15 | [0,0,1],
16 | [0,0,1],
17 | [0,1,0],
18 | [1,0,0],
19 | [-1,0,0],
20 | [0,0,1],
21 | [0,-1,0],
22 | [0,-1,0],
23 | [0,-1,0],
24 | [0,-1,0],
25 | [0,-1,0],
26 | [0,-1,0]])
27 | HML_KINEMATIC_CHAIN = [[0, 2, 5, 8, 11], [0, 1, 4, 7, 10], [0, 3, 6, 9, 12, 15], [9, 14, 17, 19, 21], [9, 13, 16, 18, 20]]
28 | HML_LEFT_HAND_CHAIN = [[20, 22, 23, 24], [20, 34, 35, 36], [20, 25, 26, 27], [20, 31, 32, 33], [20, 28, 29, 30]]
29 | HML_RIGHT_HAND_CHAIN = [[21, 43, 44, 45], [21, 46, 47, 48], [21, 40, 41, 42], [21, 37, 38, 39], [21, 49, 50, 51]]
30 | HML_TGT_SKEL_ID = '000021'
31 | HML_JOINT_NAMES = [
32 | 'pelvis',
33 | 'left_hip',
34 | 'right_hip',
35 | 'spine1',
36 | 'left_knee',
37 | 'right_knee',
38 | 'spine2',
39 | 'left_ankle',
40 | 'right_ankle',
41 | 'spine3',
42 | 'left_foot',
43 | 'right_foot',
44 | 'neck',
45 | 'left_collar',
46 | 'right_collar',
47 | 'head',
48 | 'left_shoulder',
49 | 'right_shoulder',
50 | 'left_elbow',
51 | 'right_elbow',
52 | 'left_wrist',
53 | 'right_wrist',
54 | ]
55 | NUM_HML_JOINTS = len(HML_JOINT_NAMES) # 22 SMPLH body joints
56 | HML_LOWER_BODY_JOINTS = [HML_JOINT_NAMES.index(name) for name in ['pelvis',
57 | 'left_hip',
58 | 'right_hip',
59 | 'left_knee',
60 | 'right_knee',
61 | 'left_ankle',
62 | 'right_ankle',
63 | 'left_foot',
64 | 'right_foot',]]
65 | SMPL_UPPER_BODY_JOINTS = [i for i in range(len(HML_JOINT_NAMES)) if i not in HML_LOWER_BODY_JOINTS]
66 | HML_ROOT_BINARY = np.array([True] + [False] * (NUM_HML_JOINTS-1))
67 | HML_ROOT_MASK = np.concatenate(([True]*(1+2+1),
68 | HML_ROOT_BINARY[1:].repeat(3),
69 | HML_ROOT_BINARY[1:].repeat(6),
70 | HML_ROOT_BINARY.repeat(3),
71 | [False] * 4))
72 | HML_ROOT_HORIZONTAL_MASK = np.concatenate(([True]*(1+2) + [False],
73 | np.zeros_like(HML_ROOT_BINARY[1:].repeat(3)),
74 | np.zeros_like(HML_ROOT_BINARY[1:].repeat(6)),
75 | np.zeros_like(HML_ROOT_BINARY.repeat(3)),
76 | [False] * 4))
77 | HML_LOWER_BODY_JOINTS_BINARY = np.array([i in HML_LOWER_BODY_JOINTS for i in range(NUM_HML_JOINTS)])
78 | HML_LOWER_BODY_MASK = np.concatenate(([True]*(1+2+1),
79 | HML_LOWER_BODY_JOINTS_BINARY[1:].repeat(3),
80 | HML_LOWER_BODY_JOINTS_BINARY[1:].repeat(6),
81 | HML_LOWER_BODY_JOINTS_BINARY.repeat(3),
82 | [True]*4))
83 | HML_UPPER_BODY_MASK = ~HML_LOWER_BODY_MASK
84 | HML_TRAJ_MASK = np.zeros_like(HML_ROOT_MASK)
85 | HML_TRAJ_MASK[1:3] = True
86 | NUM_HML_FEATS = 263
87 | L_IDX1, L_IDX2 = 5, 8 # Lower legs
88 | FID_R, FID_L = [8, 11], [7, 10] # Right/Left foot
89 | FACE_JOINT_INDX = [2, 1, 17, 16] # Face direction, r_hip, l_hip, sdr_r, sdr_l
90 | R_HIP, L_HIP = 2, 1 # right hip, left hip joint indices
91 | JOINTS_NUM = 22
92 |
93 | ### FUNCTIONS ###
94 |
95 | def expand_mask(mask, shape):
96 | """
97 | expands a mask of shape (num_feat, seq_len) to the requested shape (usually, (batch_size, num_feat, 1, seq_len))
98 | """
99 | _, num_feat, _, _ = shape
100 | return np.ones(shape) * mask.reshape((1, num_feat, 1, -1))
101 |
102 | def get_joints_mask(joint_names):
103 |     joints_mask = np.array([joint_name in joint_names for joint_name in HML_JOINT_NAMES])
104 |     mask = np.concatenate(([False]*(1+2+1),
105 |                         joints_mask[1:].repeat(3),
106 |                         np.zeros_like(joints_mask[1:].repeat(6)),
107 |                         np.zeros_like(joints_mask.repeat(3)),
108 |                         [False] * 4))
109 |     return mask
110 |
111 | def get_batch_joint_mask(shape, joint_names):
112 | return expand_mask(get_joints_mask(joint_names), shape)
113 |
114 | def get_in_between_mask(shape, lengths, prefix_end, suffix_end):
115 | mask = np.ones(shape) # True means use gt motion
116 | for i, length in enumerate(lengths):
117 | start_idx, end_idx = int(prefix_end * length), int(suffix_end * length)
118 | mask[i, :, :, start_idx: end_idx] = 0 # do inpainting in those frames
119 | return mask
120 |
121 | def get_prefix_mask(shape, prefix_length=20):
122 | _, num_feat, _, seq_len = shape
123 | prefix_mask = np.concatenate((np.ones((num_feat, prefix_length)), np.zeros((num_feat, seq_len - prefix_length))), axis=-1)
124 | return expand_mask(prefix_mask, shape)
125 |
126 | def get_inpainting_mask(mask_name, shape, **kwargs):
127 | mask_names = mask_name.split(',')
128 |
129 | mask = np.zeros(shape)
130 | if 'in_between' in mask_names:
131 | mask = np.maximum(mask, get_in_between_mask(shape, **kwargs))
132 |
133 | if 'root' in mask_names:
134 | mask = np.maximum(mask, expand_mask(HML_ROOT_MASK, shape))
135 |
136 | if 'root_horizontal' in mask_names:
137 | mask = np.maximum(mask, expand_mask(HML_ROOT_HORIZONTAL_MASK, shape))
138 |
139 | if 'prefix' in mask_names:
140 | mask = np.maximum(mask, get_prefix_mask(shape, **kwargs))
141 |
142 | if 'upper_body' in mask_names:
143 | mask = np.maximum(mask, expand_mask(HML_UPPER_BODY_MASK, shape))
144 |
145 | if 'lower_body' in mask_names:
146 | mask = np.maximum(mask, expand_mask(HML_LOWER_BODY_MASK, shape))
147 |
148 | return np.maximum(mask, get_batch_joint_mask(shape, mask_names))
149 |
150 |
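# Usage sketch (illustrative): special masks and joint names from HML_JOINT_NAMES can be
# combined in a single comma-separated string.
#   mask = get_inpainting_mask("root_horizontal,left_wrist", shape=(32, NUM_HML_FEATS, 1, 196))
#   # mask has shape (32, 263, 1, 196); entries equal to 1 mark the features to keep from the
#   # ground-truth motion (cf. the comment in get_in_between_mask).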
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | SOFTWARE LICENSE AGREEMENT
2 | ACADEMIC OR NON-PROFIT ORGANIZATION NONCOMMERCIAL RESEARCH USE ONLY
3 |
4 | BY USING OR DOWNLOADING THE SOFTWARE, YOU ARE AGREEING TO THE TERMS OF THIS LICENSE AGREEMENT. IF YOU DO NOT AGREE WITH THESE TERMS, YOU MAY NOT USE OR DOWNLOAD THE SOFTWARE.
5 |
6 | This is a license agreement ("Agreement") between your academic institution or non profit organization or self (called "Licensee" or "You" in this Agreement) and Universidad de Alicante (called "Licensor" in this Agreement). All rights not specifically granted to you in this Agreement are reserved for Licensor.
7 |
8 | RESERVATION OF OWNERSHIP AND GRANT OF LICENSE:
9 | Licensor retains exclusive ownership of any copy of the Software (as defined below) licensed under this Agreement and hereby grants to Licensee a personal, non-exclusive,
10 | non-transferable license to use the Software for noncommercial research purposes, without the right to sublicense, pursuant to the terms and conditions of this Agreement. As used in this Agreement, the term "Software" means (i) the actual copy of all or any portion of code for program routines made accessible to Licensee by Licensor pursuant to this Agreement, inclusive of backups, updates, and/or merged copies permitted hereunder or subsequently supplied by Licensor, including all or any file structures, programming instructions, user interfaces and screen formats and sequences as well as any and all documentation and instructions related to it, and (ii) all or any derivatives and/or modifications created or made by You to any of the items specified in (i).
11 |
12 | CONFIDENTIALITY: Licensee acknowledges that the Software is proprietary to Licensor, and as such, Licensee agrees to receive all such materials in confidence and use the Software only in accordance with the terms of this Agreement. Licensee agrees to use reasonable effort to protect the Software from unauthorized use, reproduction, distribution, or publication.
13 |
14 | COPYRIGHT: The Software is owned by Licensor and is protected by Spain copyright laws and applicable international treaties and/or conventions.
15 |
16 | PERMITTED USES: The Software may be used for your own noncommercial internal research purposes. You understand and agree that Licensor is not obligated to implement any suggestions and/or feedback you might provide regarding the Software, but to the extent Licensor does so, you are not entitled to any compensation related thereto.
17 |
18 | DERIVATIVES: You may create derivatives of or make modifications to the Software, however, You agree that all and any such derivatives and modifications will be owned by Licensor and become a part of the Software licensed to You under this Agreement. You may only use such derivatives and modifications for your own noncommercial internal research purposes, and you may not otherwise use, distribute or copy such derivatives and modifications in violation of this Agreement.
19 |
20 | BACKUPS: If Licensee is an organization, it may make that number of copies of the Software necessary for internal noncommercial use at a single site within its organization provided that all information appearing in or on the original labels, including the copyright and trademark notices are copied onto the labels of the copies.
21 |
22 | USES NOT PERMITTED: You may not distribute, copy or use the Software except as explicitly permitted herein. Licensee has not been granted any trademark license as part of this Agreement and may not use the name or mark "in2IN", "DualMDM", "Universidad de Alicante", or any renditions thereof without the prior written permission of Licensor.
23 |
24 | You may not sell, rent, lease, sublicense, lend, time-share or transfer, in whole or in part, or provide third parties access to prior or present versions (or any parts thereof) of the Software.
25 |
26 | ASSIGNMENT: You may not assign this Agreement or your rights hereunder without the prior written consent of Licensor. Any attempted assignment without such consent shall be null and void.
27 |
28 | TERM: The term of the license granted by this Agreement is from Licensee's acceptance of this Agreement by clicking "I Agree" below or by using the Software until terminated as provided below.
29 |
30 | The Agreement automatically terminates without notice if you fail to comply with any provision of this Agreement. Licensee may terminate this Agreement by ceasing using the Software. Upon any termination of this Agreement, Licensee will delete any and all copies of the Software. You agree that all provisions which operate to protect the proprietary rights of Licensor shall remain in force should breach occur and that the obligation of confidentiality described in this Agreement is binding in perpetuity and, as such, survives the term of the Agreement.
31 |
32 | FEE: Provided Licensee abides completely by the terms and conditions of this Agreement, there is no fee due to Licensor for Licensee's use of the Software in accordance with this Agreement.
33 |
34 | DISCLAIMER OF WARRANTIES: THE SOFTWARE IS PROVIDED "AS-IS" WITHOUT WARRANTY OF ANY KIND INCLUDING ANY WARRANTIES OF PERFORMANCE OR MERCHANTABILITY OR FITNESS FOR A PARTICULAR USE OR PURPOSE OR OF NON-INFRINGEMENT. LICENSEE BEARS ALL RISK RELATING TO QUALITY AND PERFORMANCE OF THE SOFTWARE AND RELATED MATERIALS.
35 |
36 | SUPPORT AND MAINTENANCE: No Software support or training by the Licensor is provided as part of this Agreement.
37 |
38 | EXCLUSIVE REMEDY AND LIMITATION OF LIABILITY: To the maximum extent permitted under applicable law, Licensor shall not be liable for direct, indirect, special, incidental, or consequential damages or lost profits related to Licensee's use of and/or inability to use the Software, even if Licensor is advised of the possibility of such damage.
39 |
40 | EXPORT REGULATION: Licensee agrees to comply with any and all applicable
41 | U.S. export control laws, regulations, and/or other laws related to embargoes and sanction programs administered by the Office of Foreign Assets Control.
42 |
43 | SEVERABILITY: If any provision(s) of this Agreement shall be held to be invalid, illegal, or unenforceable by a court or other tribunal of competent jurisdiction, the validity, legality and enforceability of the remaining provisions shall not in any way be affected or impaired thereby.
44 |
45 | NO IMPLIED WAIVERS: No failure or delay by Licensor in enforcing any right or remedy under this Agreement shall be construed as a waiver of any future or other exercise of such right or remedy by Licensor.
46 |
47 | GOVERNING LAW: This Agreement shall be construed and enforced in accordance with the laws of the Commonwealth of Spain without reference to conflict of laws principles. You consent to the personal jurisdiction of the courts of this County and waive their rights to venue outside of Alicante, Spain.
48 |
49 | ENTIRE AGREEMENT AND AMENDMENTS: This Agreement constitutes the sole and entire agreement between Licensee and Licensor as to the matter set forth herein and supersedes any previous agreements, understandings, and arrangements between the parties relating hereto.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # in2IN: Leveraging individual Information to Generate Human INteractions
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 | ## 🔎 About
13 |
14 |

15 |
16 |
17 | Generating human-human motion interactions conditioned on textual descriptions is useful in many areas, such as robotics, gaming, animation, and the metaverse. Alongside this utility comes the difficulty of modeling the high-dimensional inter-personal dynamics, and properly capturing the intra-personal diversity of interactions is also challenging. Current methods generate interactions with limited intra-person diversity due to the limitations of the available datasets and conditioning strategies. To address this, we introduce in2IN, a novel diffusion model for human-human motion generation that is conditioned not only on the textual description of the overall interaction but also on the individual descriptions of the actions performed by each person involved in the interaction. To train this model, we use a large language model to extend the InterHuman dataset with individual descriptions. As a result, in2IN achieves state-of-the-art performance on the InterHuman dataset. Furthermore, to increase the intra-personal diversity of existing interaction datasets, we propose DualMDM, a model composition technique that combines motions generated by in2IN with motions generated by a single-person motion prior pre-trained on HumanML3D. DualMDM generates motions with higher individual diversity and improves control over the intra-person dynamics while maintaining inter-personal coherence.
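As a rough intuition for the model composition step, the predictions of the interaction model and of the single-person prior can be blended at each denoising step. The snippet below is a minimal, illustrative sketch only; `compose_predictions`, the blending weight `w`, and the tensor shapes are placeholders, not the repository's DualMDM implementation in `in2in/models/dualmdm.py`.

```python
import torch

def compose_predictions(pred_interaction: torch.Tensor,
                        pred_individual: torch.Tensor,
                        w: float) -> torch.Tensor:
    """Blend the outputs of two denoisers with a weight w (illustrative only)."""
    return (1.0 - w) * pred_interaction + w * pred_individual

# Toy tensors standing in for per-person predictions: batch of 2, 300 frames, arbitrary feature size.
pred_inter = torch.randn(2, 300, 262)   # interaction-conditioned model (e.g. in2IN)
pred_indiv = torch.randn(2, 300, 262)   # single-person prior trained on HumanML3D
blended = compose_predictions(pred_inter, pred_indiv, w=0.3)
print(blended.shape)  # torch.Size([2, 300, 262])
```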
18 |
19 |
20 |
21 | ## 📌 News
22 | - [2024-06-04] Code, model weights, and additional training data are now available!
23 | - [2024-04-16] Our paper is available on [arXiv](https://arxiv.org/abs/2404.09988)
24 | - [2024-04-06] in2IN is now accepted at CVPR 2024 Workshop [HuMoGen](https://humogen.github.io)!
25 |
26 | ## 📝 TODO List
27 | - [x] Release code
28 | - [x] Release model weights
29 | - [x] Release individual descriptions from InterHuman dataset.
30 | - [ ] Release visualization code.
31 |
32 |
33 | ## 💻 Usage
34 | ### 🛠️ Installation
35 | 1. Clone the repo
36 | ```sh
37 | git clone https://github.com/pabloruizponce/in2IN.git
38 | ```
39 | 2. Install the requirements
40 | 1. Install the required Python packages
41 | ```sh
42 | pip install -r requirements.txt
43 | ```
44 | 2. Install ffmpeg
45 | ```sh
46 | sudo apt update
47 | sudo apt install ffmpeg
48 | ```
49 |
50 | > [!WARNING]
51 | > All the code has been tested with Ubuntu 22.04.3 LTS x86_64 using Python 3.12.2 and CUDA 12.3.1. If you have any issues, please open an issue.
52 | 3. Download the individual descriptions for the InterHuman dataset from [here](https://drive.google.com/drive/folders/14I3_BLu7ItWPNBWN8rMChOZiIkEhXrxH?usp=share_link) and place them in the `data` folder.
53 | > [!IMPORTANT]
54 | > The original InterHuman dataset is needed to run the code. You can download it from [here](https://github.com/tr3e/InterGen). If you use the dataset, please cite us and the original paper.
55 |
56 | ### 🕹️ Inference
57 | Download the model weights from [here](https://drive.google.com/drive/folders/14I3_BLu7ItWPNBWN8rMChOZiIkEhXrxH?usp=share_link) and place them in the `checkpoints` folder.
58 |
59 | ```sh
60 | python in2in/scripts/infer.py \
61 | --model configs/models/in2IN.yaml \
62 | --infer configs/infer.yaml \
63 | --mode interaction \
64 | --out results \
65 | --device 0 \
66 | --text_interaction "Interaction textual description" \
67 | --text_individual1 "Individual textual description" \
68 | --text_individual2 "Individual textual description" \
69 | --name "output_name" \
70 | ```
71 |
72 | > [!NOTE]
73 | > More information about the parameters can be found using the `--help` flag.
74 |
75 |
76 | ### 🏃🏻♂️ Training
77 |
78 | ```sh
79 | python in2in/scripts/train.py \
80 | --train configs/train/in2IN.yaml \
81 | --model configs/models/in2IN.yaml \
82 | --data configs/datasets.yaml \
83 | --mode interaction \
84 | --device 0 \
85 | ```
86 |
87 | ### 🎖️ Evaluation
88 | Download the evaluator model weights from [here](https://drive.google.com/drive/folders/14I3_BLu7ItWPNBWN8rMChOZiIkEhXrxH?usp=share_link) and place them in the `checkpoints` folder.
89 |
90 | #### Interaction Quality
91 | ```sh
92 | python in2in/scripts/eval/interhuman.py \
93 | --model configs/models/in2IN.yaml \
94 | --evaluator configs/eval.yaml \
95 | --mode [interaction, dual] \
96 | --out results \
97 | --device 0 \
98 | ```
99 |
100 | #### Individual Diversity
101 | ```sh
102 | python in2in/scripts/eval/DualMDM.py \
103 | --model configs/models/DualMDM.yaml \
104 | --evaluator configs/eval.yaml \
105 | --device 0 \
106 | ```
107 |
108 | ## 📚 Citation
109 |
110 | If you find our work helpful, please cite:
111 |
112 | ```bibtex
113 | @InProceedings{Ruiz-Ponce_2024_CVPR,
114 | author = {Ruiz-Ponce, Pablo and Barquero, German and Palmero, Cristina and Escalera, Sergio and Garc{\'\i}a-Rodr{\'\i}guez, Jos\'e},
115 | title = {in2IN: Leveraging Individual Information to Generate Human INteractions},
116 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR) Workshops},
117 | month = {June},
118 | year = {2024},
119 | pages = {1941-1951}
120 | }
121 | ```
122 |
123 | ## 🫶🏼 Acknowledgments
124 | - [InterGen](https://github.com/tr3e/InterGen) as we inherit a lot of code from them.
125 | - [MDM](https://github.com/GuyTevet/motion-diffusion-model) as we used their evaluation code for text-motion models.
126 | - [Diffusion Models Beat GANs on Image Synthesis](https://github.com/openai/guided-diffusion) as we used their Gaussian diffusion code as a base for our implementation.
127 |
--------------------------------------------------------------------------------
/in2in/scripts/train.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append(sys.path[0] + r"/../../")
3 |
4 | import os
5 | import time
6 | import wandb
7 | import torch
8 | import lightning.pytorch as pl
9 | from pytorch_lightning.loggers import WandbLogger
10 |
11 | import torch.optim as optim
12 | from collections import OrderedDict
13 | from in2in.datasets import DataModule, DataModuleHML3D
14 | from in2in.utils.configs import get_config
15 | from os.path import join as pjoin
16 | import argparse
17 | from in2in.models.utils.utils import CosineWarmupScheduler
18 | from in2in.utils.utils import print_current_loss
19 | from in2in.models.in2in import in2IN
20 |
21 |
22 | os.environ['PL_TORCH_DISTRIBUTED_BACKEND'] = 'nccl'
23 | from lightning.pytorch.strategies import DDPStrategy
24 | torch.set_float32_matmul_precision('medium')
25 |
26 | def list_of_ints(arg):
27 | return list(map(int, arg.split(',')))
28 |
29 | class LitTrainModel(pl.LightningModule):
30 | def __init__(self, model, cfg, mode):
31 | super().__init__()
32 |
33 | self.cfg = cfg
34 |         self.automatic_optimization = False  # optimization is handled manually in training_step
35 | self.save_root = pjoin(self.cfg.GENERAL.CHECKPOINT, self.cfg.GENERAL.EXP_NAME)
36 | self.model_dir = pjoin(self.save_root, 'model')
37 | self.meta_dir = pjoin(self.save_root, 'meta')
38 | self.log_dir = pjoin(self.save_root, 'log')
39 |
40 | os.makedirs(self.model_dir, exist_ok=True)
41 | os.makedirs(self.meta_dir, exist_ok=True)
42 | os.makedirs(self.log_dir, exist_ok=True)
43 |
44 | self.model = model
45 |
46 |         # Can be ["individual", "interaction", "dual"]
47 | self.mode = mode
48 |
49 |         # Save hyper-parameters to self.hparams; auto-logged by wandb
50 | self.save_hyperparameters(ignore=['model'])
51 |
52 | def _configure_optim(self):
53 |         optimizer = optim.AdamW(self.model.parameters(), lr=float(self.cfg.TRAIN.LR), weight_decay=self.cfg.TRAIN.WEIGHT_DECAY)
54 |         if self.mode == "individual":
55 |             # The individual (HumanML3D) model is trained without an LR scheduler
56 |             return [optimizer]
57 | else:
58 | scheduler = CosineWarmupScheduler(optimizer=optimizer, warmup=10, max_iters=self.cfg.TRAIN.EPOCH, verbose=True)
59 | return [optimizer], [scheduler]
60 |
61 | def configure_optimizers(self):
62 | return self._configure_optim()
63 |
64 | def forward(self, batch_data):
65 | if self.mode == "individual":
66 | name, text, motion1, motion_lens = batch_data
67 | motions = motion1.detach().float()
68 | elif self.mode == "interaction":
69 | name, text, motion1, motion2, motion_lens, text_individual1, text_individual2 = batch_data
70 | motion1 = motion1.detach().float()
71 | motion2 = motion2.detach().float()
72 | motions = torch.cat([motion1, motion2], dim=-1)
73 |
74 | B, T = motion1.shape[:2]
75 |
76 | batch = OrderedDict({})
77 | batch["text"] = text
78 |
79 | if self.mode == "interaction":
80 | batch["text_individual1"] = text_individual1
81 | batch["text_individual2"] = text_individual2
82 |
83 | batch["motions"] = motions.reshape(B, T, -1).type(torch.float32)
84 | batch["motion_lens"] = motion_lens.long()
85 |
86 | loss, loss_logs = self.model(batch)
87 | return loss, loss_logs
88 |
89 | def on_train_start(self):
90 | self.rank = 0
91 | self.world_size = 1
92 | self.start_time = time.time()
93 | self.it = self.cfg.TRAIN.LAST_ITER if self.cfg.TRAIN.LAST_ITER else 0
94 | self.epoch = self.cfg.TRAIN.LAST_EPOCH if self.cfg.TRAIN.LAST_EPOCH else 0
95 | self.logs = OrderedDict()
96 |
97 | print("Model Iterations", self.it)
98 | print("Model Epochs", self.epoch)
99 |
100 |
101 | def training_step(self, batch, batch_idx):
102 | loss, loss_logs = self.forward(batch)
103 | opt = self.optimizers()
104 | opt.zero_grad()
105 | self.manual_backward(loss)
106 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), 0.5)
107 | opt.step()
108 |
109 | return {"loss": loss,
110 | "loss_logs": loss_logs}
111 |
112 |
113 | def on_train_batch_end(self, outputs, batch, batch_idx):
114 | if outputs.get('skip_batch') or not outputs.get('loss_logs'):
115 | return
116 | for k, v in outputs['loss_logs'].items():
117 | if k not in self.logs:
118 | self.logs[k] = v.item()
119 | else:
120 | self.logs[k] += v.item()
121 |
122 | self.it += 1
123 | if self.it % self.cfg.TRAIN.LOG_STEPS == 0:
124 | mean_loss = OrderedDict({})
125 | for tag, value in self.logs.items():
126 | mean_loss[tag] = value / self.cfg.TRAIN.LOG_STEPS
127 | # log metrics to wandb
128 | self.log(tag, mean_loss[tag], on_step=True, on_epoch=False, prog_bar=True)
129 | self.logs = OrderedDict()
130 | print_current_loss(self.start_time, self.it, mean_loss,
131 | self.trainer.current_epoch,
132 | inner_iter=batch_idx,
133 | lr=self.trainer.optimizers[0].param_groups[0]['lr'])
134 |
135 |
136 | def on_train_epoch_end(self):
137 | if self.mode == "interaction":
138 | sch = self.lr_schedulers()
139 | if sch is not None:
140 | sch.step()
141 |
142 | def save(self, file_name):
143 | state = {}
144 | try:
145 | state['model'] = self.model.module.state_dict()
146 | except:
147 | state['model'] = self.model.state_dict()
148 | torch.save(state, file_name, _use_new_zipfile_serialization=False)
149 | return
150 |
151 |
152 | if __name__ == '__main__':
153 |
154 | # Create the parser
155 |     parser = argparse.ArgumentParser(description="in2IN training script")
156 |
157 | # Add arguments
158 | parser.add_argument('--train', type=str, required=True, help='Training Configuration file')
159 | parser.add_argument('--model' , type=str, required=True, help='Model Configuration file')
160 | parser.add_argument('--data', type=str, required=True, help='Data Configuration file')
161 | parser.add_argument('--mode', type=str, required=True, help='Model mode (individual or interaction)')
162 | parser.add_argument('--resume', type=str, required=False, help='Resume training from checkpoint')
163 | parser.add_argument('--device', type=list_of_ints, required=True, help='Device to run the training')
164 |
165 | # Parse the arguments
166 | args = parser.parse_args()
167 |
168 | model_cfg = get_config(args.model)
169 | train_cfg = get_config(args.train)
170 |
171 | # initialise the wandb logger and name your wandb project
172 | wandb_logger = WandbLogger(project='in2IN', name=train_cfg.GENERAL.EXP_NAME)
173 |
174 | if args.mode == "individual":
175 | data_cfg = get_config(args.data).humanml3d
176 | datamodule = DataModuleHML3D(data_cfg, train_cfg.TRAIN.BATCH_SIZE, train_cfg.TRAIN.NUM_WORKERS)
177 | elif args.mode == "interaction":
178 | data_cfg = get_config(args.data).interhuman
179 | datamodule = DataModule(data_cfg, train_cfg.TRAIN.BATCH_SIZE, train_cfg.TRAIN.NUM_WORKERS)
180 |
181 | model = in2IN(model_cfg, args.mode)
182 | litmodel = LitTrainModel(model, train_cfg, args.mode)
183 |
184 | # Checkpoint callback
185 | checkpoint_callback = pl.callbacks.ModelCheckpoint(dirpath=litmodel.model_dir,
186 | every_n_epochs=train_cfg.TRAIN.SAVE_EPOCH,
187 | save_top_k=-1)
188 |
189 | # Trainer
190 | trainer = pl.Trainer(
191 | default_root_dir=litmodel.model_dir,
192 | devices=args.device, accelerator='gpu',
193 | max_epochs=train_cfg.TRAIN.EPOCH,
194 | strategy=DDPStrategy(find_unused_parameters=True),
195 | precision='16-mixed',
196 | logger=wandb_logger,
197 | callbacks=[checkpoint_callback]
198 | )
199 |
200 | if args.resume:
201 | trainer.fit(model=litmodel, datamodule=datamodule, ckpt_path=args.resume)
202 | else:
203 | trainer.fit(model=litmodel, datamodule=datamodule)
--------------------------------------------------------------------------------
/in2in/evaluation/utils.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import numpy as np
3 | import torch
4 |
5 | from in2in.evaluation.datasets import EvaluationDatasetDualMDM, EvaluationDatasetInterHuman, MMGeneratedDatasetInterHuman
6 | from in2in.models import *
7 | from in2in.datasets import InterHuman
8 | from in2in.evaluation.models import InterCLIP
9 | from torch.utils.data import DataLoader
10 |
11 | def get_dataset_motion_loader(opt, batch_size, num_samples=-1):
12 | """
13 |     Get the ground-truth motion dataset together with its dataloader.
14 | :param opt: Configuration of the dataset.
15 | :param batch_size: Batch size of the dataloader.
16 |     :return: Dataloader (and dataset) of the ground-truth motions.
17 | """
18 | opt = copy.deepcopy(opt)
19 | if opt.NAME == 'interhuman':
20 | print('Loading dataset %s ...' % opt.NAME)
21 |
22 | dataset = InterHuman(opt, num_samples=num_samples)
23 | dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=0, drop_last=True, shuffle=True)
24 | else:
25 | raise KeyError('Dataset not Recognized !!')
26 |
27 | print('Ground Truth Dataset Loading Completed!!!')
28 | return dataloader, dataset
29 |
30 |
31 |
32 |
33 | def get_motion_loader_in2IN(batch_size, model, ground_truth_dataset, device, mm_num_samples, mm_num_repeats):
34 | """
35 |     Get the generated motion dataset with its dataloader, plus the MultiModality one.
36 | :param batch_size: Batch size of the dataloader.
37 | :param model: Model to generate the motions.
38 | :param ground_truth_dataset: Ground truth dataset.
39 | :param device: Device to run the model.
40 | :param mm_num_samples: Number of samples to generate for the MultiModality metric.
41 | :param mm_num_repeats: Number of repeats for each sample in the MultiModality metric.
42 | :return: Dataloader of the generated motion dataset and the MultiModality one.
43 | """
44 |
45 | dataset = EvaluationDatasetInterHuman(model, ground_truth_dataset, device, mm_num_samples=mm_num_samples, mm_num_repeats=mm_num_repeats)
46 | mm_dataset = MMGeneratedDatasetInterHuman(dataset)
47 |
48 | motion_loader = DataLoader(dataset, batch_size=batch_size, drop_last=True, num_workers=0, shuffle=True)
49 | mm_motion_loader = DataLoader(mm_dataset, batch_size=1, num_workers=0)
50 |
51 | print('Generated Dataset Loading Completed!!!')
52 |
53 | return motion_loader, mm_motion_loader
54 |
55 |
56 |
57 | def get_motion_loader_DualMDM(batch_size, model, ground_truth_dataset, device, num_repeats):
58 | """
59 |     Get the generated motion dataset with its dataloader.
60 | :param batch_size: Batch size of the dataloader.
61 | :param model: Model to generate the motions.
62 | :param ground_truth_dataset: Ground truth dataset.
63 | :param device: Device to run the model.
64 | :param num_repeats: Number of repeats for each sample.
65 | :return: Dataloader of the generated motion dataset.
66 | """
67 | dataset = EvaluationDatasetDualMDM(model, ground_truth_dataset, device, num_repeats=num_repeats)
68 | motion_loader = DataLoader(dataset, batch_size=batch_size, drop_last=True, num_workers=0, shuffle=True)
69 | return motion_loader
70 |
71 |
72 | def build_models(cfg):
73 | """
74 | Create and load the feature extractor model for the evaluation.
75 | :param cfg: Configuration of the model.
76 | :return: Feature extractor model for the evaluation.
77 | """
78 |
79 | model = InterCLIP(cfg)
80 |
81 | # Load the model from the checkpoint
82 | checkpoint = torch.load(cfg.CHECKPOINT, map_location="cpu")
83 | for k in list(checkpoint["state_dict"].keys()):
84 | if "model" in k:
85 | checkpoint["state_dict"][k.replace("model.", "")] = checkpoint["state_dict"].pop(k)
86 | model.load_state_dict(checkpoint["state_dict"], strict=True)
87 |
88 | return model
89 |
90 |
91 | class EvaluatorModelWrapper(object):
92 | """
93 | Wrapper of the model for the evaluation.
94 | The model will be used to extract features from the generated motions and the gt motions.
95 | """
96 | def __init__(self, cfg, device):
97 | """
98 | Initialization of the model.
99 | :param cfg: Configuration of the model.
100 | :param device: Device to run the model.
101 | """
102 | self.model = build_models(cfg)
103 | self.cfg = cfg
104 | self.device = device
105 | self.model = self.model.to(device)
106 | self.model.eval()
107 | self.extended = cfg.EXTENDED
108 |
109 |
110 | def get_co_embeddings(self, batch_data):
111 | """
112 | Get the embeddings of the text and the motions of a given batch of data.
113 | :param batch_data: Batch of data to extract the embeddings.
114 | :return: Embeddings of the text and the motions.
115 |         Note that the results do not follow the order of the inputs.
116 | """
117 | with torch.no_grad():
118 | # Extract data from the batch provided by the evaluation datasets
119 | if self.extended:
120 | name, text, motion1, motion2, motion_lens, text_individual1, text_individual2 = batch_data
121 | else:
122 | name, text, motion1, motion2, motion_lens = batch_data
123 |
124 | motion1 = motion1.detach().float() # .to(self.device)
125 | motion2 = motion2.detach().float() # .to(self.device)
126 | motions = torch.cat([motion1, motion2], dim=-1)
127 | motions = motions.detach().to(self.device).float()
128 |
129 | align_idx = np.argsort(motion_lens.data.tolist())[::-1].copy()
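            # Sort the batch by decreasing motion length; this re-ordering is why the
            # returned embeddings do not follow the order of the inputs.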
130 | motions = motions[align_idx]
131 | motion_lens = motion_lens[align_idx]
132 | text = list(text)
133 | if self.extended:
134 | text_individual1 = list(text_individual1)
135 | text_individual2 = list(text_individual2)
136 |
137 | # Create padding for the motions
138 | B, T = motions.shape[:2]
139 | cur_len = torch.LongTensor([min(T, m_len) for m_len in motion_lens]).to(self.device)
140 | padded_len = cur_len.max()
141 |
142 | # Create batch for feature prediction
143 | batch = {}
144 | batch["text"] = text
145 | batch["motions"] = motions.reshape(B, T, -1)[:, :padded_len]
146 | batch["motion_lens"] = motion_lens
147 | if self.extended:
148 | batch["text_individual1"] = text_individual1
149 | batch["text_individual2"] = text_individual2
150 |
151 | # Motion Encoding
152 | motion_embedding = self.model.encode_motion(batch)['motion_emb']
153 |
154 | # Text Encoding
155 | text_embedding = self.model.encode_text(batch)['text_emb'][align_idx]
156 |
157 | return text_embedding, motion_embedding
158 |
159 | def get_motion_embeddings(self, batch_data):
160 | """
161 | Get the embeddings of the motions of a given batch of data.
162 | :param batch_data: Batch of data to extract the embeddings.
163 | :return: Embeddings of the motions.
164 |         Note that the results do not follow the order of the inputs.
165 | """
166 | with torch.no_grad():
167 | # Extract data from the batch provided by the evaluation datasets
168 | if self.extended:
169 | name, text, motion1, motion2, motion_lens, text_individual1, text_individual2 = batch_data
170 | else:
171 | name, text, motion1, motion2, motion_lens = batch_data
172 |
173 | motion1 = motion1.detach().float() # .to(self.device)
174 | motion2 = motion2.detach().float() # .to(self.device)
175 | motions = torch.cat([motion1, motion2], dim=-1)
176 | motions = motions.detach().to(self.device).float()
177 |
178 | align_idx = np.argsort(motion_lens.data.tolist())[::-1].copy()
179 | motions = motions[align_idx]
180 | motion_lens = motion_lens[align_idx]
181 | text = list(text)
182 |
183 | # Create padding for the motions
184 | B, T = motions.shape[:2]
185 | cur_len = torch.LongTensor([min(T, m_len) for m_len in motion_lens]).to(self.device)
186 | padded_len = cur_len.max()
187 |
188 | # Create batch for feature prediction
189 | batch = {}
190 | batch["text"] = text
191 | batch["motions"] = motions.reshape(B, T, -1)[:, :padded_len]
192 | batch["motion_lens"] = motion_lens
193 | if self.extended:
194 | batch["text_individual1"] = text_individual1
195 | batch["text_individual2"] = text_individual2
196 |
197 | # Motion Encoding
198 | motion_embedding = self.model.encode_motion(batch)['motion_emb']
199 |
200 | return motion_embedding
201 |
--------------------------------------------------------------------------------
/in2in/evaluation/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import clip
4 | from in2in.models.utils.utils import PositionalEncoding, set_requires_grad
5 |
6 |
7 | from in2in.models import *
8 |
9 | class MotionEncoder(nn.Module):
10 | """
11 | Motion encoder module for feature extractor evaluation model
12 | """
13 | def __init__(self, cfg):
14 | """
15 | Initialize the motion encoder module
16 | :param cfg: model configuration file
17 | """
18 | super().__init__()
19 |
20 | # Model parameters
21 | self.cfg = cfg
22 | self.input_feats = cfg.INPUT_DIM
23 | self.latent_dim = cfg.LATENT_DIM
24 | self.ff_size = cfg.FF_SIZE
25 | self.num_layers = cfg.NUM_LAYERS
26 | self.num_heads = cfg.NUM_HEADS
27 | self.dropout = cfg.DROPOUT
28 | self.activation = cfg.ACTIVATION
29 |
30 | # Model architecture
31 | self.query_token = nn.Parameter(torch.randn(1, self.latent_dim))
32 | self.embed_motion = nn.Linear(self.input_feats*2, self.latent_dim)
33 | self.sequence_pos_encoder = PositionalEncoding(self.latent_dim, self.dropout, max_len=2000)
34 | seqTransEncoderLayer = nn.TransformerEncoderLayer(d_model=self.latent_dim,
35 | nhead=self.num_heads,
36 | dim_feedforward=self.ff_size,
37 | dropout=self.dropout,
38 | activation=self.activation,
39 | batch_first=True)
40 | self.transformer = nn.TransformerEncoder(seqTransEncoderLayer, num_layers=self.num_layers)
41 | self.out_ln = nn.LayerNorm(self.latent_dim)
42 | self.out = nn.Linear(self.latent_dim, 512)
43 |
44 |
45 | def forward(self, batch):
46 | """
47 | Forward pass of the motion encoder module
48 | :param batch: input batch
49 | :return batch: updated batch
50 | """
51 |
52 | x, mask = batch["motions"], batch["mask"]
53 | B, T, D = x.shape
54 |         x = x.reshape(B, T, 2, -1)[..., :-4].reshape(B, T, -1)  # split per person, drop the last 4 features (foot-contact labels), flatten back
55 |
56 | # Embedding
57 | x_emb = self.embed_motion(x)
58 | emb = torch.cat([self.query_token[torch.zeros(B, dtype=torch.long, device=x.device)][:,None], x_emb], dim=1)
59 |
60 | # Masking
61 | seq_mask = (mask>0.5)
62 | token_mask = torch.ones((B, 1), dtype=bool, device=x.device)
63 | valid_mask = torch.cat([token_mask, seq_mask], dim=1)
64 |
65 | # Positional encoder and transformer
66 | h = self.sequence_pos_encoder(emb)
67 | h = self.transformer(h, src_key_padding_mask=~valid_mask)
68 | h = self.out_ln(h)
69 | motion_emb = self.out(h[:,0])
70 | batch["motion_emb"] = motion_emb
71 |
72 | return batch
73 |
74 | class InterCLIP(nn.Module):
75 | """
76 | InterCLIP model for feature extractor evaluation
77 |     It is based on the CLIP text encoder and a MotionEncoder.
78 | """
79 | def __init__(self, cfg):
80 | """
81 | Initialize the InterCLIP model
82 | :param cfg: model configuration file
83 | """
84 | super().__init__()
85 |
86 | # Model parameters
87 | self.cfg = cfg
88 | self.latent_dim = cfg.LATENT_DIM
89 | self.motion_encoder = MotionEncoder(cfg)
90 | self.latent_dim = self.latent_dim
91 |
92 | # CLIP model
93 | clip_model, _ = clip.load("ViT-L/14@336px", device="cpu", jit=False)
94 | self.token_embedding = clip_model.token_embedding
95 | self.positional_embedding = clip_model.positional_embedding
96 | self.dtype = clip_model.dtype
97 | self.latent_scale = nn.Parameter(torch.Tensor([1]))
98 | set_requires_grad(self.token_embedding, False)
99 |
100 | # Additional text encoding layers
101 | textTransEncoderLayer = nn.TransformerEncoderLayer(
102 | d_model=768,
103 | nhead=8,
104 | dim_feedforward=cfg.FF_SIZE,
105 | dropout=0.1,
106 | activation="gelu",
107 | batch_first=True)
108 | self.textTransEncoder = nn.TransformerEncoder(
109 | textTransEncoderLayer,
110 | num_layers=8)
111 | self.text_ln = nn.LayerNorm(768)
112 | self.out = nn.Linear(768, 512)
113 |
114 | # Losses
115 | self.clip_training = "text_"
116 | self.l1_criterion = torch.nn.L1Loss(reduction='mean')
117 | self.loss_ce = nn.CrossEntropyLoss()
118 |
119 |
120 | def generate_src_mask(self, T, length):
121 | """
122 | Generate source mask for transformer
123 |         :param T: padded sequence length
124 |         :param length: tensor with the valid length of each sequence in the batch
125 | :return src_mask: source mask
126 | """
127 | B = length.shape[0]
128 | src_mask = torch.ones(B, T)
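        # A vectorized equivalent of the loop below:
        #   src_mask = (torch.arange(T, device=length.device)[None, :] < length[:, None]).float()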
129 | for i in range(B):
130 | for j in range(length[i], T):
131 | src_mask[i, j] = 0
132 | return src_mask
133 |
134 | def encode_motion(self, batch):
135 | """
136 | Encode motion features
137 | :param batch: input batch
138 | :return batch: updated batch
139 | """
140 | batch["mask"] = self.generate_src_mask(batch["motions"].shape[1], batch["motion_lens"]).to(batch["motions"].device)
141 | batch.update(self.motion_encoder(batch))
142 | batch["motion_emb"] = batch["motion_emb"] / batch["motion_emb"].norm(dim=-1, keepdim=True) * self.latent_scale
143 |
144 | return batch
145 |
146 | def encode_text(self, batch):
147 | """
148 | Encode text features
149 | :param batch: input batch
150 | :return batch: updated batch
151 | """
152 | device = next(self.parameters()).device
153 | raw_text = batch["text"]
154 |
155 | with torch.no_grad():
156 | text = clip.tokenize(raw_text, truncate=True).to(device)
157 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
158 | pe_tokens = x + self.positional_embedding.type(self.dtype)
159 |
160 | out = self.textTransEncoder(pe_tokens)
161 | out = self.text_ln(out)
162 |
163 | out = out[torch.arange(x.shape[0]), text.argmax(dim=-1)]
164 | out = self.out(out)
165 |
166 | # Normalize
167 | batch['text_emb'] = out
168 | batch["text_emb"] = batch["text_emb"] / batch["text_emb"].norm(dim=-1, keepdim=True) * self.latent_scale
169 |
170 | return batch
171 |
172 |
173 | def compute_loss(self, batch):
174 | """
175 | Wrapper for calculating the loss of the model
176 | :param batch: input batch
177 | """
178 |
179 | losses = {}
180 | losses["total"] = 0
181 |
182 | # Encode text and motion
183 | batch = self.encode_text(batch)
184 | batch = self.encode_motion(batch)
185 |
186 | # Compute clip losses
187 | mixed_clip_loss, clip_losses = self.compute_clip_losses(batch)
188 | losses.update(clip_losses)
189 | losses["total"] += mixed_clip_loss
190 |
191 | return losses["total"], losses
192 |
193 | def compute_clip_losses(self, batch):
194 | """
195 | Computing losses from the motion encoder and the text encoder
196 | :param batch: input batch
197 | :return mixed_clip_loss: mixed clip loss
198 | :return clip_losses: clip losses
199 | """
200 | mixed_clip_loss = 0.
201 | clip_losses = {}
202 |
203 | for d in self.clip_training.split('_')[:1]:
204 | if d == 'image':
205 | features = self.clip_model.encode_image(batch['images']).float() # preprocess is done in dataloader
206 | elif d == 'text':
207 | features = batch['text_emb']
208 | motion_features = batch['motion_emb']
209 |
210 | # Normalized features
211 | features_norm = features / features.norm(dim=-1, keepdim=True)
212 | motion_features_norm = motion_features / motion_features.norm(dim=-1, keepdim=True)
213 |
214 | # Compute logits
215 | logit_scale = self.latent_scale ** 2
216 | logits_per_motion = logit_scale * motion_features_norm @ features_norm.t()
217 | logits_per_d = logits_per_motion.t()
218 |
219 | batch_size = motion_features.shape[0]
220 | ground_truth = torch.arange(batch_size, dtype=torch.long, device=motion_features.device)
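                # Standard CLIP-style symmetric objective: motion i should match text i,
                # so the contrastive targets are simply the batch indices 0..B-1.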
221 |
222 | # Compute losses
223 | ce_from_motion_loss = self.loss_ce(logits_per_motion, ground_truth)
224 | ce_from_d_loss = self.loss_ce(logits_per_d, ground_truth)
225 | clip_mixed_loss = (ce_from_motion_loss + ce_from_d_loss) / 2.
226 |
227 | clip_losses[f'{d}_ce_from_d'] = ce_from_d_loss.item()
228 | clip_losses[f'{d}_ce_from_motion'] = ce_from_motion_loss.item()
229 | clip_losses[f'{d}_mixed_ce'] = clip_mixed_loss.item()
230 | mixed_clip_loss += clip_mixed_loss
231 |
232 | return mixed_clip_loss, clip_losses
233 |
234 | def forward(self, batch):
235 | """
236 | Forward pass of the InterCLIP model
237 | :param batch: input batch
238 | :return batch: updated batch
239 | """
240 | return self.compute_loss(batch)
241 |
242 |
243 |
--------------------------------------------------------------------------------
/in2in/datasets/interhuman.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import random
3 | import os
4 |
5 | from tqdm import tqdm
6 | from torch.utils import data
7 | from in2in.utils.preprocess import load_motion
8 | from os.path import join as pjoin
9 | from in2in.utils.quaternion import *
10 | from in2in.utils.utils import rigid_transform, process_motion_interhuman
11 |
12 | class InterHuman(data.Dataset):
13 | """
14 | InterHuman dataset
15 | """
16 |
17 | def __init__(self, opt, num_samples=-1):
18 |
19 | # Configuration variables
20 | self.opt = opt
21 | self.max_cond_length = 1
22 | self.min_cond_length = 1
23 | self.max_gt_length = 300
24 | self.min_gt_length = 15
25 | self.max_length = self.max_cond_length + self.max_gt_length -1
26 | self.min_length = self.min_cond_length + self.min_gt_length -1
27 | self.motion_rep = opt.MOTION_REP
28 | self.cache = opt.CACHE
29 | self.extended = opt.EXTENDED
30 |
31 | # Data structures
32 | self.motion_dict = {}
33 | self.data_list = []
34 | data_list = []
35 |
36 | # Load paths from the given split
37 | if self.opt.MODE == "train":
38 | try:
39 | data_list = open(os.path.join(opt.DATA_ROOT, "split/train.txt"), "r").readlines()
40 | except Exception as e:
41 | print(e)
42 | elif self.opt.MODE == "val":
43 | try:
44 | data_list = open(os.path.join(opt.DATA_ROOT, "split/val.txt"), "r").readlines()
45 | except Exception as e:
46 | print(e)
47 | elif self.opt.MODE == "test":
48 | try:
49 | data_list = open(os.path.join(opt.DATA_ROOT, "split/test.txt"), "r").readlines()
50 | except Exception as e:
51 | print(e)
52 |
53 |         # Shuffle paths
54 | random.shuffle(data_list)
55 |
56 | if num_samples > 0:
57 | data_list = data_list[:num_samples]
58 | print("Using only {} samples".format(num_samples))
59 |
60 | # Load data
61 | index = 0
62 | root = pjoin(opt.DATA_ROOT, "motions_processed/person1")
63 | for file in tqdm(os.listdir(root)):
64 |
65 |             # Comment out this check to use the whole dataset (ignoring the split file)
66 | if file.split(".")[0]+"\n" not in data_list:
67 | continue
68 |
69 | motion_name = file.split(".")[0]
70 | file_path_person1 = pjoin(root, file)
71 | file_path_person2 = pjoin(root.replace("person1", "person2"), file)
72 | text_path = file_path_person1.replace("motions_processed", "annots").replace("person1", "").replace("npy", "txt")
73 |
74 | # Load interaction texts and make the swaps
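            # In the swapped copy, "left"/"right" and "clockwise"/"counterclockwise" are
            # exchanged through a temporary "tmp" token so the text stays consistent with
            # the transformed motion.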
75 | texts = [item.replace("\n", "") for item in open(text_path, "r").readlines()]
76 | texts_swap = [item.replace("\n", "").replace("left", "tmp").replace("right", "left").replace("tmp", "right")
77 | .replace("clockwise", "tmp").replace("counterclockwise","clockwise").replace("tmp","counterclockwise") for item in texts]
78 |
79 |             # If using the extended version, load the individual descriptions of the motions
80 | if self.extended:
81 | text_path_individual1 = file_path_person1.replace("motions_processed", "annots_individual").replace("npy", "txt")
82 | text_path_individual2 = file_path_person2.replace("motions_processed", "annots_individual").replace("npy", "txt")
83 |
84 | if not os.path.exists(text_path_individual1):
85 | continue
86 | else:
87 | texts_individual1 = [item.replace("\n", "") for item in open(text_path_individual1, "r").readlines()]
88 | texts_individual2 = [item.replace("\n", "") for item in open(text_path_individual2, "r").readlines()]
89 |
90 | # Make the swaps of the individual descriptions
91 | texts_individual1_swap = [item.replace("\n", "").replace("left", "tmp").replace("right", "left").replace("tmp", "right")
92 | .replace("clockwise", "tmp").replace("counterclockwise","clockwise").replace("tmp","counterclockwise") for item in texts_individual2]
93 | texts_individual2_swap = [item.replace("\n", "").replace("left", "tmp").replace("right", "left").replace("tmp", "right")
94 | .replace("clockwise", "tmp").replace("counterclockwise","clockwise").replace("tmp","counterclockwise") for item in texts_individual1]
95 |
96 | # Load motion and check if it is too short and cache it if needed
97 | if self.cache:
98 | motion1, motion1_swap = load_motion(file_path_person1, self.min_length, swap=True)
99 | motion2, motion2_swap = load_motion(file_path_person2, self.min_length, swap=True)
100 | if motion1 is None:
101 | continue
102 |
103 | if self.cache:
104 | self.motion_dict[index] = [motion1, motion2]
105 | self.motion_dict[index+1] = [motion1_swap, motion2_swap]
106 | else:
107 | self.motion_dict[index] = [file_path_person1, file_path_person2]
108 | self.motion_dict[index + 1] = [file_path_person1, file_path_person2]
109 |
110 | # Fill data structures depending on the variable of the dataset used
111 | if self.extended:
112 | self.data_list.append({
113 | "name": motion_name,
114 | "motion_id": index,
115 | "swap":False,
116 | "texts":texts,
117 | "texts_individual1":texts_individual1_swap,
118 | "texts_individual2":texts_individual2_swap,
119 | })
120 |
121 | if opt.MODE == "train":
122 | self.data_list.append({
123 | "name": motion_name+"_swap",
124 | "motion_id": index+1,
125 | "swap": True,
126 | "texts": texts_swap,
127 | "texts_individual1":texts_individual1,
128 | "texts_individual2":texts_individual2,
129 | })
130 | else:
131 | self.data_list.append({
132 | "name": motion_name,
133 | "motion_id": index,
134 | "swap":False,
135 | "texts":texts
136 | })
137 |
138 | if opt.MODE == "train":
139 | self.data_list.append({
140 | "name": motion_name+"_swap",
141 | "motion_id": index+1,
142 | "swap": True,
143 | "texts": texts_swap,
144 | })
145 |
146 | index += 2
147 |
148 | print("Total Dataset Size: ", len(self.data_list))
149 |
150 | def __len__(self):
151 | """
152 | Get the length of the dataset
153 | """
154 | return len(self.data_list)
155 |
156 | def __getitem__(self, item):
157 | """
158 | Get an item from the dataset
159 |         :param item: Index of the item to get
160 | """
161 |
162 | # Get the data from the dataset
163 | idx = item % self.__len__()
164 | data = self.data_list[idx]
165 | name = data["name"]
166 | motion_id = data["motion_id"]
167 | swap = data["swap"]
168 |
169 | # Select a random text from the list and if extended also select the individual descriptions
170 | text = random.choice(data["texts"]).strip()
171 | if self.extended:
172 | text_individual1 = random.choice(data["texts_individual1"]).strip()
173 | text_individual2 = random.choice(data["texts_individual2"]).strip()
174 |
175 | # Load the motion
176 | if self.cache:
177 | full_motion1, full_motion2 = self.motion_dict[motion_id]
178 | else:
179 | file_path1, file_path2 = self.motion_dict[motion_id]
180 | motion1, motion1_swap = load_motion(file_path1, self.min_length, swap=swap)
181 | motion2, motion2_swap = load_motion(file_path2, self.min_length, swap=swap)
182 | if swap:
183 | full_motion1 = motion1_swap
184 | full_motion2 = motion2_swap
185 | else:
186 | full_motion1 = motion1
187 | full_motion2 = motion2
188 |
189 |         # Get the motion length and select a random segment
190 | length = full_motion1.shape[0]
191 | if length > self.max_length:
192 | idx = random.choice(list(range(0, length - self.max_gt_length, 1)))
193 | gt_length = self.max_gt_length
194 | motion1 = full_motion1[idx:idx + gt_length]
195 | motion2 = full_motion2[idx:idx + gt_length]
196 | else:
197 | idx = 0
198 | gt_length = min(length - idx, self.max_gt_length )
199 | motion1 = full_motion1[idx:idx + gt_length]
200 | motion2 = full_motion2[idx:idx + gt_length]
201 |
202 | # Swap the motions randomly
203 | if np.random.rand() > 0.5:
204 | motion1, motion2 = motion2, motion1
205 |
206 | # Process the motion
207 | motion1, root_quat_init1, root_pos_init1 = process_motion_interhuman(motion1, 0.001, 0, n_joints=22)
208 | motion2, root_quat_init2, root_pos_init2 = process_motion_interhuman(motion2, 0.001, 0, n_joints=22)
209 |
210 | # Rotate motion 2
211 | r_relative = qmul_np(root_quat_init2, qinv_np(root_quat_init1))
212 | angle = np.arctan2(r_relative[:, 2:3], r_relative[:, 0:1])
213 | xz = qrot_np(root_quat_init1, root_pos_init2 - root_pos_init1)[:, [0, 2]]
214 | relative = np.concatenate([angle, xz], axis=-1)[0]
215 | motion2 = rigid_transform(relative, motion2)
216 |
217 | gt_motion1 = motion1
218 | gt_motion2 = motion2
219 |
220 | # Check if the motion is too short and pad it
221 | gt_length = len(gt_motion1)
222 | if gt_length < self.max_gt_length:
223 | padding_len = self.max_gt_length - gt_length
224 | D = gt_motion1.shape[1]
225 | padding_zeros = np.zeros((padding_len, D))
226 | gt_motion1 = np.concatenate((gt_motion1, padding_zeros), axis=0)
227 | gt_motion2 = np.concatenate((gt_motion2, padding_zeros), axis=0)
228 |
229 | # Return the data
230 | if self.extended:
231 | return name, text, gt_motion1, gt_motion2, gt_length, text_individual1, text_individual2
232 | else:
233 | return name, text, gt_motion1, gt_motion2, gt_length
234 |
--------------------------------------------------------------------------------
/in2in/utils/skeleton.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from .paramUtil import FACE_JOINT_INDX, HML_KINEMATIC_CHAIN, HML_RAW_OFFSETS, L_IDX1, L_IDX2
4 | from .quaternion import *
5 | from scipy.ndimage import filters
6 | class Skeleton(object):
7 | """
8 | Skeleton class to handle the skeleton data such as
9 | offsets, kinematic tree, forward and inverse kinematics
10 | """
11 | def __init__(self, offset, kinematic_tree, device):
12 | """
13 | Initialize the Skeleton class
14 | :param offset: Offset of the skeleton that we are using
15 | :param kinematic_tree: Kinematic tree of the skeleton
16 | :param device: Device to run the skeleton
17 | """
18 | self.device = device
19 | self._raw_offset_np = offset.numpy()
20 | self._raw_offset = offset.clone().detach().to(device).float()
21 | self._kinematic_tree = kinematic_tree
22 | self._offset = None
23 | self._parents = [0] * len(self._raw_offset)
24 | self._parents[0] = -1
25 | for chain in self._kinematic_tree:
26 | for j in range(1, len(chain)):
27 | self._parents[chain[j]] = chain[j-1]
28 |
29 | def njoints(self):
30 | return len(self._raw_offset)
31 |
32 | def offset(self):
33 | return self._offset
34 |
35 | def set_offset(self, offsets):
36 | self._offset = offsets.clone().detach().to(self.device).float()
37 |
38 | def kinematic_tree(self):
39 | return self._kinematic_tree
40 |
41 | def parents(self):
42 | return self._parents
43 |
44 | # joints (batch_size, joints_num, 3)
45 | def get_offsets_joints_batch(self, joints):
46 | assert len(joints.shape) == 3
47 | _offsets = self._raw_offset.expand(joints.shape[0], -1, -1).clone()
48 | for i in range(1, self._raw_offset.shape[0]):
49 | _offsets[:, i] = torch.norm(joints[:, i] - joints[:, self._parents[i]], p=2, dim=1)[:, None] * _offsets[:, i]
50 |
51 | self._offset = _offsets.detach()
52 | return _offsets
53 |
54 | # joints (joints_num, 3)
55 | def get_offsets_joints(self, joints):
56 | assert len(joints.shape) == 2
57 | _offsets = self._raw_offset.clone()
58 | for i in range(1, self._raw_offset.shape[0]):
59 | # print(joints.shape)
60 | _offsets[i] = torch.norm(joints[i] - joints[self._parents[i]], p=2, dim=0) * _offsets[i]
61 |
62 | self._offset = _offsets.detach()
63 | return _offsets
64 |
65 | # face_joint_idx should follow the order of right hip, left hip, right shoulder, left shoulder
66 | # joints (batch_size, joints_num, 3)
67 | def inverse_kinematics_np(self, joints, face_joint_idx, smooth_forward=False):
68 | assert len(face_joint_idx) == 4
69 | # Get Forward Direction
70 | l_hip, r_hip, sdr_r, sdr_l = face_joint_idx
71 | across1 = joints[:, r_hip] - joints[:, l_hip]
72 | across2 = joints[:, sdr_r] - joints[:, sdr_l]
73 | across = across1 + across2
74 | across = across / np.sqrt((across**2).sum(axis=-1))[:, np.newaxis]
75 |
76 | # forward (batch_size, 3)
77 | forward = np.cross(np.array([[0, 1, 0]]), across, axis=-1)
78 | if smooth_forward:
79 | forward = filters.gaussian_filter1d(forward, 20, axis=0, mode='nearest')
80 | # forward (batch_size, 3)
81 | forward = forward / np.sqrt((forward**2).sum(axis=-1))[..., np.newaxis]
82 |
83 | # Get Root Rotation
84 | target = np.array([[0,0,1]]).repeat(len(forward), axis=0)
85 | root_quat = qbetween_np(forward, target)
86 |
87 | # Inverse Kinematics
88 | # quat_params (batch_size, joints_num, 4)
89 | quat_params = np.zeros(joints.shape[:-1] + (4,))
90 | root_quat[0] = np.array([[1.0, 0.0, 0.0, 0.0]])
91 | quat_params[:, 0] = root_quat
92 | for chain in self._kinematic_tree:
93 | R = root_quat
94 | for j in range(len(chain) - 1):
95 | # (batch, 3)
96 | u = self._raw_offset_np[chain[j+1]][np.newaxis,...].repeat(len(joints), axis=0)
97 | # (batch, 3)
98 | v = joints[:, chain[j+1]] - joints[:, chain[j]]
99 | v = v / np.sqrt((v**2).sum(axis=-1))[:, np.newaxis]
100 | rot_u_v = qbetween_np(u, v)
101 | R_loc = qmul_np(qinv_np(R), rot_u_v)
102 | quat_params[:,chain[j + 1], :] = R_loc
103 | R = qmul_np(R, R_loc)
104 |
105 | return quat_params
106 |
107 | # Be sure root joint is at the beginning of kinematic chains
108 | def forward_kinematics(self, quat_params, root_pos, skel_joints=None, do_root_R=True):
109 | # quat_params (batch_size, joints_num, 4)
110 | # joints (batch_size, joints_num, 3)
111 | # root_pos (batch_size, 3)
112 | if skel_joints is not None:
113 | offsets = self.get_offsets_joints_batch(skel_joints)
114 | if len(self._offset.shape) == 2:
115 | offsets = self._offset.expand(quat_params.shape[0], -1, -1)
116 | joints = torch.zeros(quat_params.shape[:-1] + (3,)).to(self.device)
117 | joints[:, 0] = root_pos
118 | for chain in self._kinematic_tree:
119 | if do_root_R:
120 | R = quat_params[:, 0]
121 | else:
122 | R = torch.tensor([[1.0, 0.0, 0.0, 0.0]]).expand(len(quat_params), -1).detach().to(self.device)
123 | for i in range(1, len(chain)):
124 | R = qmul(R, quat_params[:, chain[i]])
125 | offset_vec = offsets[:, chain[i]]
126 | joints[:, chain[i]] = qrot(R, offset_vec) + joints[:, chain[i-1]]
127 | return joints
128 |
129 | # Be sure root joint is at the beginning of kinematic chains
130 | def forward_kinematics_np(self, quat_params, root_pos, skel_joints=None, do_root_R=True):
131 | # quat_params (batch_size, joints_num, 4)
132 | # joints (batch_size, joints_num, 3)
133 | # root_pos (batch_size, 3)
134 | if skel_joints is not None:
135 | skel_joints = torch.from_numpy(skel_joints)
136 | offsets = self.get_offsets_joints_batch(skel_joints)
137 | if len(self._offset.shape) == 2:
138 | offsets = self._offset.expand(quat_params.shape[0], -1, -1)
139 | offsets = offsets.numpy()
140 | joints = np.zeros(quat_params.shape[:-1] + (3,))
141 | joints[:, 0] = root_pos
142 | for chain in self._kinematic_tree:
143 | if do_root_R:
144 | R = quat_params[:, 0]
145 | else:
146 | R = np.array([[1.0, 0.0, 0.0, 0.0]]).repeat(len(quat_params), axis=0)
147 | for i in range(1, len(chain)):
148 | R = qmul_np(R, quat_params[:, chain[i]])
149 | offset_vec = offsets[:, chain[i]]
150 | joints[:, chain[i]] = qrot_np(R, offset_vec) + joints[:, chain[i - 1]]
151 | return joints
152 |
153 | def forward_kinematics_cont6d_np(self, cont6d_params, root_pos, skel_joints=None, do_root_R=True):
154 | # cont6d_params (batch_size, joints_num, 6)
155 | # joints (batch_size, joints_num, 3)
156 | # root_pos (batch_size, 3)
157 | if skel_joints is not None:
158 | skel_joints = torch.from_numpy(skel_joints)
159 | offsets = self.get_offsets_joints_batch(skel_joints)
160 | if len(self._offset.shape) == 2:
161 | offsets = self._offset.expand(cont6d_params.shape[0], -1, -1)
162 | offsets = offsets.numpy()
163 | joints = np.zeros(cont6d_params.shape[:-1] + (3,))
164 | joints[:, 0] = root_pos
165 | for chain in self._kinematic_tree:
166 | if do_root_R:
167 | matR = cont6d_to_matrix_np(cont6d_params[:, 0])
168 | else:
169 | matR = np.eye(3)[np.newaxis, :].repeat(len(cont6d_params), axis=0)
170 | for i in range(1, len(chain)):
171 | matR = np.matmul(matR, cont6d_to_matrix_np(cont6d_params[:, chain[i]]))
172 | offset_vec = offsets[:, chain[i]][..., np.newaxis]
173 | joints[:, chain[i]] = np.matmul(matR, offset_vec).squeeze(-1) + joints[:, chain[i-1]]
174 | return joints
175 |
176 | def forward_kinematics_cont6d(self, cont6d_params, root_pos, skel_joints=None, do_root_R=True):
177 | # cont6d_params (batch_size, joints_num, 6)
178 | # joints (batch_size, joints_num, 3)
179 | # root_pos (batch_size, 3)
180 | if skel_joints is not None:
181 | # skel_joints = torch.from_numpy(skel_joints)
182 | offsets = self.get_offsets_joints_batch(skel_joints)
183 | if len(self._offset.shape) == 2:
184 | offsets = self._offset.expand(cont6d_params.shape[0], -1, -1)
185 | joints = torch.zeros(cont6d_params.shape[:-1] + (3,)).to(cont6d_params.device)
186 | joints[..., 0, :] = root_pos
187 | for chain in self._kinematic_tree:
188 | if do_root_R:
189 | matR = cont6d_to_matrix(cont6d_params[:, 0])
190 | else:
191 | matR = torch.eye(3).expand((len(cont6d_params), -1, -1)).detach().to(cont6d_params.device)
192 | for i in range(1, len(chain)):
193 | matR = torch.matmul(matR, cont6d_to_matrix(cont6d_params[:, chain[i]]))
194 | offset_vec = offsets[:, chain[i]].unsqueeze(-1)
195 | joints[:, chain[i]] = torch.matmul(matR, offset_vec).squeeze(-1) + joints[:, chain[i-1]]
196 | return joints
197 |
198 |
199 | def uniform_skeleton(positions, target_skeleton_path="data/motions_processed/person1/1.npy"):
200 | """
201 | Retarget a source skeleton motion to the bone lengths of a target skeleton
202 | :param positions: Source skeleton motion, array of shape (frames, joints_num, 3)
203 | :param target_skeleton_path: Path to an example motion that provides the target skeleton
204 | :return: New joint positions of the source motion with the target skeleton's bone lengths
205 | """
206 |
207 | # Target Skeleton
208 | n_raw_offsets = torch.from_numpy(HML_RAW_OFFSETS)
209 | kinematic_chain = HML_KINEMATIC_CHAIN
210 | example_data = np.load(target_skeleton_path)
211 | example_data = example_data.reshape(len(example_data), -1, 3)
212 | example_data = torch.from_numpy(example_data)
213 | target_skel = Skeleton(n_raw_offsets, kinematic_chain, 'cpu')
214 | target_offset = target_skel.get_offsets_joints(example_data[0])
215 |
216 | # Source Skeleton
217 | src_skel = Skeleton(n_raw_offsets, kinematic_chain, 'cpu')
218 | src_offset = src_skel.get_offsets_joints(torch.from_numpy(positions[0]))
219 | src_offset = src_offset.numpy()
220 | tgt_offset = target_offset.numpy()
221 |
222 | # Calculate Scale Ratio as the ratio of legs
223 | src_leg_len = np.abs(src_offset[L_IDX1]).max() + np.abs(src_offset[L_IDX2]).max()
224 | tgt_leg_len = np.abs(tgt_offset[L_IDX1]).max() + np.abs(tgt_offset[L_IDX2]).max()
225 |
226 | scale_rt = tgt_leg_len / src_leg_len
227 | src_root_pos = positions[:, 0]
228 | tgt_root_pos = src_root_pos * scale_rt
229 |
230 | # Inverse Kinematics
231 | quat_params = src_skel.inverse_kinematics_np(positions, FACE_JOINT_INDX)
232 |
233 | # Forward Kinematics
234 | src_skel.set_offset(target_offset)
235 | new_joints = src_skel.forward_kinematics_np(quat_params, tgt_root_pos)
236 | return new_joints
--------------------------------------------------------------------------------
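A minimal usage sketch for the Skeleton class above, driving forward kinematics with a toy 3-joint chain (the chain, offsets and rest pose are invented for illustration; the real pipeline instantiates the skeleton exactly as uniform_skeleton does, from HML_RAW_OFFSETS and HML_KINEMATIC_CHAIN):

import numpy as np
import torch
from in2in.utils.skeleton import Skeleton

# Toy skeleton: root -> joint 1 -> joint 2, unit offset directions along +Y.
raw_offsets = torch.tensor([[0, 0, 0],
                            [0, 1, 0],
                            [0, 1, 0]], dtype=torch.float32)
kinematic_chain = [[0, 1, 2]]
skel = Skeleton(raw_offsets, kinematic_chain, 'cpu')

# Rest pose (joints_num, 3) used to measure the bone lengths.
rest_pose = torch.tensor([[0.0, 0.0, 0.0],
                          [0.0, 0.5, 0.0],
                          [0.0, 1.0, 0.0]])
skel.get_offsets_joints(rest_pose)

# Identity rotations for a single frame: forward kinematics reproduces the rest pose.
quat_params = np.zeros((1, 3, 4), dtype=np.float32)
quat_params[..., 0] = 1.0  # w component of the identity quaternion
root_pos = np.zeros((1, 3), dtype=np.float32)
print(skel.forward_kinematics_np(quat_params, root_pos))  # joints at y = 0.0, 0.5, 1.0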
/in2in/scripts/eval/interhuman.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append(sys.path[0]+r"/../../../")
3 |
4 | import numpy as np
5 | import torch
6 |
7 | from datetime import datetime
8 | from in2in.evaluation.utils import get_dataset_motion_loader, get_motion_loader_in2IN, EvaluatorModelWrapper
9 | from in2in.utils.metrics import *
10 | from collections import OrderedDict
11 | from in2in.utils.plot import *
12 | from in2in.utils.utils import *
13 | from in2in.utils.configs import get_config
14 | from tqdm import tqdm
15 | from in2in.models.dualmdm import load_DualMDM_model
16 | from in2in.models.in2in import in2IN
17 |
18 | import argparse
19 |
20 | def evaluate_matching_score(motion_loaders, file):
21 | match_score_dict = OrderedDict({})
22 | R_precision_dict = OrderedDict({})
23 | activation_dict = OrderedDict({})
24 | print('========== Evaluating MM Distance ==========')
25 | for motion_loader_name, motion_loader in motion_loaders.items():
26 | all_motion_embeddings = []
27 | score_list = []
28 | all_size = 0
29 | mm_dist_sum = 0
30 | top_k_count = 0
31 | with torch.no_grad():
32 | for idx, batch in tqdm(enumerate(motion_loader)):
33 | text_embeddings, motion_embeddings = eval_wrapper.get_co_embeddings(batch)
34 | dist_mat = euclidean_distance_matrix(text_embeddings.cpu().numpy(),
35 | motion_embeddings.cpu().numpy())
36 | mm_dist_sum += dist_mat.trace()
37 |
38 | argsmax = np.argsort(dist_mat, axis=1)
39 | top_k_mat = calculate_top_k(argsmax, top_k=3)
40 | top_k_count += top_k_mat.sum(axis=0)
41 |
42 | all_size += text_embeddings.shape[0]
43 |
44 | all_motion_embeddings.append(motion_embeddings.cpu().numpy())
45 |
46 | all_motion_embeddings = np.concatenate(all_motion_embeddings, axis=0)
47 | mm_dist = mm_dist_sum / all_size
48 | R_precision = top_k_count / all_size
49 | match_score_dict[motion_loader_name] = mm_dist
50 | R_precision_dict[motion_loader_name] = R_precision
51 | activation_dict[motion_loader_name] = all_motion_embeddings
52 |
53 | print(f'---> [{motion_loader_name}] MM Distance: {mm_dist:.4f}')
54 | print(f'---> [{motion_loader_name}] MM Distance: {mm_dist:.4f}', file=file, flush=True)
55 |
56 | line = f'---> [{motion_loader_name}] R_precision: '
57 | for i in range(len(R_precision)):
58 | line += '(top %d): %.4f ' % (i+1, R_precision[i])
59 | print(line)
60 | print(line, file=file, flush=True)
61 |
62 | return match_score_dict, R_precision_dict, activation_dict
63 |
64 |
65 | def evaluate_fid(groundtruth_loader, activation_dict, file):
66 | eval_dict = OrderedDict({})
67 | gt_motion_embeddings = []
68 | print('========== Evaluating FID ==========')
69 | with torch.no_grad():
70 | for idx, batch in tqdm(enumerate(groundtruth_loader)):
71 | motion_embeddings = eval_wrapper.get_motion_embeddings(batch)
72 | gt_motion_embeddings.append(motion_embeddings.cpu().numpy())
73 | gt_motion_embeddings = np.concatenate(gt_motion_embeddings, axis=0)
74 | gt_mu, gt_cov = calculate_activation_statistics(gt_motion_embeddings)
75 |
76 | for model_name, motion_embeddings in activation_dict.items():
77 | mu, cov = calculate_activation_statistics(motion_embeddings)
78 | fid = calculate_frechet_distance(gt_mu, gt_cov, mu, cov)
79 | print(f'---> [{model_name}] FID: {fid:.4f}')
80 | print(f'---> [{model_name}] FID: {fid:.4f}', file=file, flush=True)
81 | eval_dict[model_name] = fid
82 | return eval_dict
83 |
84 |
85 | def evaluate_diversity(activation_dict, file):
86 | eval_dict = OrderedDict({})
87 | print('========== Evaluating Diversity ==========')
88 | for model_name, motion_embeddings in activation_dict.items():
89 | diversity = calculate_diversity(motion_embeddings, diversity_times)
90 | eval_dict[model_name] = diversity
91 | print(f'---> [{model_name}] Diversity: {diversity:.4f}')
92 | print(f'---> [{model_name}] Diversity: {diversity:.4f}', file=file, flush=True)
93 | return eval_dict
94 |
95 |
96 | def evaluate_multimodality(mm_motion_loaders, file):
97 | eval_dict = OrderedDict({})
98 | print('========== Evaluating MultiModality ==========')
99 | for model_name, mm_motion_loader in mm_motion_loaders.items():
100 | mm_motion_embeddings = []
101 | with torch.no_grad():
102 | for idx, batch in enumerate(mm_motion_loader):
103 | # (1, mm_replications, dim_pos)
104 | batch[2] = batch[2][0]
105 | batch[3] = batch[3][0]
106 | batch[4] = batch[4][0]
107 | motion_embeddings = eval_wrapper.get_motion_embeddings(batch)
108 | mm_motion_embeddings.append(motion_embeddings.unsqueeze(0))
109 | if len(mm_motion_embeddings) == 0:
110 | multimodality = 0
111 | else:
112 | mm_motion_embeddings = torch.cat(mm_motion_embeddings, dim=0).cpu().numpy()
113 | multimodality = calculate_multimodality(mm_motion_embeddings, mm_num_times)
114 | print(f'---> [{model_name}] Multimodality: {multimodality:.4f}')
115 | print(f'---> [{model_name}] Multimodality: {multimodality:.4f}', file=file, flush=True)
116 | eval_dict[model_name] = multimodality
117 | return eval_dict
118 |
119 |
120 | def get_metric_statistics(values):
121 | mean = np.mean(values, axis=0)
122 | std = np.std(values, axis=0)
123 | conf_interval = 1.96 * std / np.sqrt(replication_times)
124 | return mean, conf_interval
125 |
126 |
127 | def evaluation(log_file):
128 | with open(log_file, 'w') as f:
129 | all_metrics = OrderedDict({'MM Distance': OrderedDict({}),
130 | 'R_precision': OrderedDict({}),
131 | 'FID': OrderedDict({}),
132 | 'Diversity': OrderedDict({}),
133 | 'MultiModality': OrderedDict({})})
134 | for replication in range(replication_times):
135 | motion_loaders = {}
136 | mm_motion_loaders = {}
137 | motion_loaders['ground truth'] = gt_loader
138 | for motion_loader_name, motion_loader_getter in eval_motion_loaders.items():
139 | motion_loader, mm_motion_loader = motion_loader_getter()
140 | motion_loaders[motion_loader_name] = motion_loader
141 | mm_motion_loaders[motion_loader_name] = mm_motion_loader
142 |
143 | print(f'==================== Replication {replication} ====================')
144 | print(f'==================== Replication {replication} ====================', file=f, flush=True)
145 | print(f'Time: {datetime.now()}')
146 | print(f'Time: {datetime.now()}', file=f, flush=True)
147 | mat_score_dict, R_precision_dict, acti_dict = evaluate_matching_score(motion_loaders, f)
148 |
149 | print(f'Time: {datetime.now()}')
150 | print(f'Time: {datetime.now()}', file=f, flush=True)
151 | fid_score_dict = evaluate_fid(gt_loader, acti_dict, f)
152 |
153 | print(f'Time: {datetime.now()}')
154 | print(f'Time: {datetime.now()}', file=f, flush=True)
155 | div_score_dict = evaluate_diversity(acti_dict, f)
156 |
157 | print(f'Time: {datetime.now()}')
158 | print(f'Time: {datetime.now()}', file=f, flush=True)
159 | mm_score_dict = evaluate_multimodality(mm_motion_loaders, f)
160 |
161 | print(f'!!! DONE !!!')
162 | print(f'!!! DONE !!!', file=f, flush=True)
163 |
164 | for key, item in mat_score_dict.items():
165 | if key not in all_metrics['MM Distance']:
166 | all_metrics['MM Distance'][key] = [item]
167 | else:
168 | all_metrics['MM Distance'][key] += [item]
169 |
170 | for key, item in R_precision_dict.items():
171 | if key not in all_metrics['R_precision']:
172 | all_metrics['R_precision'][key] = [item]
173 | else:
174 | all_metrics['R_precision'][key] += [item]
175 |
176 | for key, item in fid_score_dict.items():
177 | if key not in all_metrics['FID']:
178 | all_metrics['FID'][key] = [item]
179 | else:
180 | all_metrics['FID'][key] += [item]
181 |
182 | for key, item in div_score_dict.items():
183 | if key not in all_metrics['Diversity']:
184 | all_metrics['Diversity'][key] = [item]
185 | else:
186 | all_metrics['Diversity'][key] += [item]
187 |
188 | for key, item in mm_score_dict.items():
189 | if key not in all_metrics['MultiModality']:
190 | all_metrics['MultiModality'][key] = [item]
191 | else:
192 | all_metrics['MultiModality'][key] += [item]
193 |
194 | for metric_name, metric_dict in all_metrics.items():
195 | print('========== %s Summary ==========' % metric_name)
196 | print('========== %s Summary ==========' % metric_name, file=f, flush=True)
197 |
198 | for model_name, values in metric_dict.items():
199 | mean, conf_interval = get_metric_statistics(np.array(values))
200 | if isinstance(mean, np.float64) or isinstance(mean, np.float32):
201 | print(f'---> [{model_name}] Mean: {mean:.4f} CInterval: {conf_interval:.4f}')
202 | print(f'---> [{model_name}] Mean: {mean:.4f} CInterval: {conf_interval:.4f}', file=f, flush=True)
203 | elif isinstance(mean, np.ndarray):
204 | line = f'---> [{model_name}]'
205 | for i in range(len(mean)):
206 | line += '(top %d) Mean: %.4f CInt: %.4f;' % (i+1, mean[i], conf_interval[i])
207 | print(line)
208 | print(line, file=f, flush=True)
209 |
210 |
211 | if __name__ == '__main__':
212 |
213 | # Create the parser
214 | parser = argparse.ArgumentParser(description="Evaluation of interaction generation models on the InterHuman test set")
215 |
216 | # Add command-line arguments
217 | parser.add_argument('--model', type=str, required=True, help='Model Configuration file')
218 | parser.add_argument('--evaluator', type=str, required=True, help='Evaluator Configuration file')
219 | parser.add_argument('--out', type=str, required=True, help='Out file')
220 | parser.add_argument('--device', type=int, default=0, help='GPU device id')
221 | parser.add_argument('--mode', type=str, required=True, help='Mode of the inference (interaction, dual)')
222 |
223 |
224 | # Parse the arguments
225 | args = parser.parse_args()
226 |
227 | mm_num_samples = 100
228 | mm_num_repeats = 30
229 | mm_num_times = 10
230 |
231 | diversity_times = 300
232 | replication_times = 10
233 |
234 | # batch_size is fixed to 96!!
235 | batch_size = 96
236 |
237 | data_cfg = get_config("configs/datasets.yaml").interhuman_test
238 |
239 | eval_motion_loaders = {}
240 | model_cfg = get_config( args.model)
241 | device = torch.device('cuda:%d' % args.device if torch.cuda.is_available() else 'cpu')
242 | torch.cuda.set_device(args.device)
243 |
244 | if args.mode == "dual":
245 | model = load_DualMDM_model(model_cfg)
246 | elif args.mode == "interaction":
247 | model = in2IN(model_cfg, args.mode)
248 | model.load_state_dict(torch.load(model_cfg.CHECKPOINT))
249 |
250 | eval_motion_loaders[model_cfg.NAME] = lambda: get_motion_loader_in2IN(
251 | batch_size,
252 | model,
253 | gt_dataset,
254 | device,
255 | mm_num_samples,
256 | mm_num_repeats
257 | )
258 |
259 | device = torch.device('cuda:%d' % args.device if torch.cuda.is_available() else 'cpu')
260 | gt_loader, gt_dataset = get_dataset_motion_loader(data_cfg, batch_size)
261 | evalmodel_cfg = get_config(args.evaluator)
262 | eval_wrapper = EvaluatorModelWrapper(evalmodel_cfg, device)
263 | log_file = args.out
264 | evaluation(log_file)
265 |
--------------------------------------------------------------------------------
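The summary statistics printed by evaluation above reduce each metric to a mean and a 95% confidence half-width over the replications; a minimal, self-contained sketch of that reduction (the per-replication FID values below are made up for illustration):

import numpy as np

replication_times = 10
fid_per_replication = np.array([5.1, 4.8, 5.3, 4.9, 5.0, 5.2, 4.7, 5.1, 5.0, 4.9])  # invented values

# Same reduction as get_metric_statistics: mean and 1.96 * std / sqrt(R).
mean = np.mean(fid_per_replication, axis=0)
conf_interval = 1.96 * np.std(fid_per_replication, axis=0) / np.sqrt(replication_times)
print(f'---> [example] Mean: {mean:.4f} CInterval: {conf_interval:.4f}')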
/in2in/utils/metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy import linalg
3 | from typing import Union
4 | import pykeops.torch as keops
5 | import torch
6 | import tqdm
7 |
8 | emb_scale = 6
9 |
10 | # (X - X_train)*(X - X_train) = -2X*X_train + X*X + X_train*X_train
11 | def euclidean_distance_matrix(matrix1, matrix2):
12 | """
13 | Params:
14 | -- matrix1: N1 x D
15 | -- matrix2: N2 x D
16 | Returns:
17 | -- dist: N1 x N2
18 | dist[i, j] == distance(matrix1[i], matrix2[j])
19 | """
20 | assert matrix1.shape[1] == matrix2.shape[1]
21 | d1 = -2 * np.dot(matrix1, matrix2.T) # shape (num_test, num_train)
22 | d2 = np.sum(np.square(matrix1), axis=1, keepdims=True) # shape (num_test, 1)
23 | d3 = np.sum(np.square(matrix2), axis=1) # shape (num_train, )
24 | dists = np.sqrt(d1 + d2 + d3) # broadcasting
25 | return dists
26 |
27 | def calculate_top_k(mat, top_k):
28 | size = mat.shape[0]
29 | gt_mat = np.expand_dims(np.arange(size), 1).repeat(size, 1)
30 | bool_mat = (mat == gt_mat)
31 | correct_vec = False
32 | top_k_list = []
33 | for i in range(top_k):
34 | # print(correct_vec, bool_mat[:, i])
35 | correct_vec = (correct_vec | bool_mat[:, i])
36 | # print(correct_vec)
37 | top_k_list.append(correct_vec[:, None])
38 | top_k_mat = np.concatenate(top_k_list, axis=1)
39 | return top_k_mat
40 |
41 |
42 | def calculate_R_precision(embedding1, embedding2, top_k, sum_all=False):
43 | dist_mat = euclidean_distance_matrix(embedding1, embedding2)
44 | argmax = np.argsort(dist_mat, axis=1)
45 | top_k_mat = calculate_top_k(argmax, top_k)
46 | if sum_all:
47 | return top_k_mat.sum(axis=0)
48 | else:
49 | return top_k_mat
50 |
51 |
52 | def calculate_matching_score(embedding1, embedding2, sum_all=False):
53 | assert len(embedding1.shape) == 2
54 | assert embedding1.shape[0] == embedding2.shape[0]
55 | assert embedding1.shape[1] == embedding2.shape[1]
56 |
57 | dist = linalg.norm(embedding1 - embedding2, axis=1)
58 | if sum_all:
59 | return dist.sum(axis=0)
60 | else:
61 | return dist
62 |
63 | def calculate_activation_statistics(activations):
64 | """
65 | Params:
66 | -- activation: num_samples x dim_feat
67 | Returns:
68 | -- mu: dim_feat
69 | -- sigma: dim_feat x dim_feat
70 | """
71 | activations = activations * emb_scale
72 | mu = np.mean(activations, axis=0)
73 | cov = np.cov(activations, rowvar=False)
74 | return mu, cov
75 |
76 |
77 | def calculate_diversity(activation, diversity_times):
78 | assert len(activation.shape) == 2
79 | assert activation.shape[0] > diversity_times
80 | num_samples = activation.shape[0]
81 |
82 | activation = activation * emb_scale
83 | first_indices = np.random.choice(num_samples, diversity_times, replace=False)
84 | second_indices = np.random.choice(num_samples, diversity_times, replace=False)
85 | dist = linalg.norm((activation[first_indices] - activation[second_indices])/2, axis=1)
86 | return dist.mean()
87 |
88 |
89 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
90 | """Numpy implementation of the Frechet Distance.
91 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
92 | and X_2 ~ N(mu_2, C_2) is
93 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
94 | Stable version by Dougal J. Sutherland.
95 | Params:
96 | -- mu1 : Numpy array containing the activations of a layer of the
97 | inception net (like returned by the function 'get_predictions')
98 | for generated samples.
99 | -- mu2   : The sample mean over activations, precalculated on a
100 | representative data set.
101 | -- sigma1: The covariance matrix over activations for generated samples.
102 | -- sigma2: The covariance matrix over activations, precalculated on a
103 | representative data set.
104 | Returns:
105 | -- : The Frechet Distance.
106 | """
107 |
108 | mu1 = np.atleast_1d(mu1)
109 | mu2 = np.atleast_1d(mu2)
110 |
111 | sigma1 = np.atleast_2d(sigma1)
112 | sigma2 = np.atleast_2d(sigma2)
113 |
114 | assert mu1.shape == mu2.shape, \
115 | 'Training and test mean vectors have different lengths'
116 | assert sigma1.shape == sigma2.shape, \
117 | 'Training and test covariances have different dimensions'
118 |
119 | diff = mu1 - mu2
120 |
121 | # Product might be almost singular
122 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
123 | if not np.isfinite(covmean).all():
124 | msg = ('fid calculation produces singular product; '
125 | 'adding %s to diagonal of cov estimates') % eps
126 | print(msg)
127 | offset = np.eye(sigma1.shape[0]) * eps
128 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
129 |
130 | # Numerical error might give slight imaginary component
131 | if np.iscomplexobj(covmean):
132 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
133 | m = np.max(np.abs(covmean.imag))
134 | raise ValueError('Imaginary component {}'.format(m))
135 | covmean = covmean.real
136 |
137 | tr_covmean = np.trace(covmean)
138 |
139 | return (diff.dot(diff) + np.trace(sigma1) +
140 | np.trace(sigma2) - 2 * tr_covmean)
141 |
142 |
143 | def calculate_multimodality(activation, multimodality_times):
144 | assert len(activation.shape) == 3
145 | assert activation.shape[1] > multimodality_times
146 | num_per_sent = activation.shape[1]
147 |
148 | first_dices = np.random.choice(num_per_sent, multimodality_times, replace=False)
149 | second_dices = np.random.choice(num_per_sent, multimodality_times, replace=False)
150 | dist = linalg.norm(activation[:, first_dices] - activation[:, second_dices], axis=2)
151 | return dist.mean()
152 |
153 | def calculate_wasserstein(x: torch.Tensor, y: torch.Tensor, p: float = 2,
154 | w_x: Union[torch.Tensor, None] = None,
155 | w_y: Union[torch.Tensor, None] = None,
156 | eps: float = 1e-3,
157 | max_iters: int = 100, stop_thresh: float = 1e-5,
158 | verbose=False):
159 | """
160 | Compute the Entropy-Regularized p-Wasserstein Distance between two d-dimensional point clouds
161 | using the Sinkhorn scaling algorithm. This code will use the GPU if you pass in GPU tensors.
162 | Note that this algorithm can be backpropped through
163 | (though this may be slow if using many iterations).
164 |
165 | :param x: A [n, d] tensor representing a d-dimensional point cloud with n points (one per row)
166 | :param y: A [m, d] tensor representing a d-dimensional point cloud with m points (one per row)
167 | :param p: Which norm to use. Must be an integer greater than 0.
168 | :param w_x: A [n,] shaped tensor of optional weights for the points x (None for uniform weights). Note that these must sum to the same value as w_y. Default is None.
169 | :param w_y: A [m,] shaped tensor of optional weights for the points y (None for uniform weights). Note that these must sum to the same value as w_x. Default is None.
170 | :param eps: The reciprocal of the sinkhorn entropy regularization parameter.
171 | :param max_iters: The maximum number of Sinkhorn iterations to perform.
172 | :param stop_thresh: Stop if the maximum change in the parameters is below this amount
173 | :param verbose: Print iterations
174 | :return: a triple (d, corrs_x_to_y, corr_y_to_x) where:
175 | * d is the approximate p-wasserstein distance between point clouds x and y
176 | * corrs_x_to_y is a [n,]-shaped tensor where corrs_x_to_y[i] is the index of the approximate correspondence in point cloud y of point x[i] (i.e. x[i] and y[corrs_x_to_y[i]] are a corresponding pair)
177 | * corrs_y_to_x is a [m,]-shaped tensor where corrs_y_to_x[j] is the index of the approximate correspondence in point cloud x of point y[j] (i.e. y[j] and x[corrs_y_to_x[j]] are a corresponding pair)
178 | """
179 |
180 | if not isinstance(p, int):
181 | raise TypeError(f"p must be an integer greater than 0, got {p}")
182 | if p <= 0:
183 | raise ValueError(f"p must be an integer greater than 0, got {p}")
184 |
185 | if eps <= 0:
186 | raise ValueError("Entropy regularization term eps must be > 0")
187 |
188 | if not isinstance(max_iters, int):
189 | raise TypeError(f"max_iters must be an integer > 0, got {max_iters}")
190 | if max_iters <= 0:
191 | raise ValueError(f"max_iters must be an integer > 0, got {max_iters}")
192 |
193 | if not isinstance(stop_thresh, float):
194 | raise TypeError(f"stop_thresh must be a float, got {stop_thresh}")
195 |
196 | if len(x.shape) != 2:
197 | raise ValueError(f"x must be an [n, d] tensor but got shape {x.shape}")
198 | if len(y.shape) != 2:
199 | raise ValueError(f"x must be an [m, d] tensor but got shape {y.shape}")
200 | if x.shape[1] != y.shape[1]:
201 | raise ValueError(f"x and y must match in the last dimension (i.e. x.shape=[n, d], "
202 | f"y.shape[m, d]) but got x.shape = {x.shape}, y.shape={y.shape}")
203 |
204 | if w_x is not None:
205 | if w_y is None:
206 | raise ValueError("If w_x is not None, w_y must also be not None")
207 | if len(w_x.shape) > 1:
208 | w_x = w_x.squeeze()
209 | if len(w_x.shape) != 1:
210 | raise ValueError(f"w_x must have shape [n,] or [n, 1] "
211 | f"where x.shape = [n, d], but got w_x.shape = {w_x.shape}")
212 | if w_x.shape[0] != x.shape[0]:
213 | raise ValueError(f"w_x must match the shape of x in dimension 0 but got "
214 | f"x.shape = {x.shape} and w_x.shape = {w_x.shape}")
215 | if w_y is not None:
216 | if w_x is None:
217 | raise ValueError("If w_y is not None, w_x must also be not None")
218 | if len(w_y.shape) > 1:
219 | w_y = w_y.squeeze()
220 | if len(w_y.shape) != 1:
221 | raise ValueError(f"w_y must have shape [n,] or [n, 1] "
222 | f"where x.shape = [n, d], but got w_y.shape = {w_y.shape}")
223 | if w_y.shape[0] != y.shape[0]:
224 | raise ValueError(f"w_y must match the shape of y in dimension 0 but got "
225 | f"y.shape = {y.shape} and w_y.shape = {w_y.shape}")
226 |
227 |
228 | # Distance matrix [n, m]
229 | x_i = keops.Vi(x) # [n, 1, d]
230 | y_j = keops.Vj(y) # [1, m, d]
231 | if p == 1:
232 | M_ij = ((x_i - y_j) ** p).abs().sum(dim=2) # [n, m]
233 | else:
234 | M_ij = ((x_i - y_j) ** p).sum(dim=2) ** (1.0 / p) # [n, m]
235 |
236 | # Weights [n,] and [m,]
237 | if w_x is None and w_y is None:
238 | w_x = torch.ones(x.shape[0]).to(x) / x.shape[0]
239 | w_y = torch.ones(y.shape[0]).to(x) / y.shape[0]
240 | w_y *= (w_x.shape[0] / w_y.shape[0])
241 |
242 | sum_w_x = w_x.sum().item()
243 | sum_w_y = w_y.sum().item()
244 | if abs(sum_w_x - sum_w_y) > 1e-5:
245 | raise ValueError(f"Weights w_x and w_y do not sum to the same value, "
246 | f"got w_x.sum() = {sum_w_x} and w_y.sum() = {sum_w_y} "
247 | f"(absolute difference = {abs(sum_w_x - sum_w_y)}")
248 |
249 | log_a = torch.log(w_x) # [n]
250 | log_b = torch.log(w_y) # [m]
251 |
252 | # Initialize the iteration with the change of variable
253 | u = torch.zeros_like(w_x)
254 | v = eps * torch.log(w_y)
255 |
256 | u_i = keops.Vi(u.unsqueeze(-1))
257 | v_j = keops.Vj(v.unsqueeze(-1))
258 |
259 | if verbose:
260 | pbar = tqdm.trange(max_iters)
261 | else:
262 | pbar = range(max_iters)
263 |
264 | for _ in pbar:
265 | u_prev = u
266 | v_prev = v
267 |
268 | summand_u = (-M_ij + v_j) / eps
269 | u = eps * (log_a - summand_u.logsumexp(dim=1).squeeze())
270 | u_i = keops.Vi(u.unsqueeze(-1))
271 |
272 | summand_v = (-M_ij + u_i) / eps
273 | v = eps * (log_b - summand_v.logsumexp(dim=0).squeeze())
274 | v_j = keops.Vj(v.unsqueeze(-1))
275 |
276 | max_err_u = torch.max(torch.abs(u_prev-u))
277 | max_err_v = torch.max(torch.abs(v_prev-v))
278 | if verbose:
279 | pbar.set_postfix({"Current Max Error": max(max_err_u, max_err_v).item()})
280 | if max_err_u < stop_thresh and max_err_v < stop_thresh:
281 | break
282 |
283 | P_ij = ((-M_ij + u_i + v_j) / eps).exp()
284 |
285 | approx_corr_1 = P_ij.argmax(dim=1).squeeze(-1)
286 | approx_corr_2 = P_ij.argmax(dim=0).squeeze(-1)
287 |
288 | if u.shape[0] > v.shape[0]:
289 | distance = (P_ij * M_ij).sum(dim=1).sum()
290 | else:
291 | distance = (P_ij * M_ij).sum(dim=0).sum()
292 | return distance, approx_corr_1, approx_corr_2
293 |
--------------------------------------------------------------------------------
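A small synthetic sanity check for the helpers above (all embeddings are random and purely illustrative; in the evaluation pipeline they come from the evaluator model; note that importing in2in.utils.metrics also pulls in pykeops):

import numpy as np
from in2in.utils.metrics import (calculate_activation_statistics,
                                 calculate_frechet_distance,
                                 euclidean_distance_matrix,
                                 calculate_top_k)

rng = np.random.default_rng(0)
gt_emb = rng.normal(size=(256, 64))             # stand-in for ground-truth motion embeddings
gen_emb = rng.normal(loc=0.1, size=(256, 64))   # stand-in for generated motion embeddings

# FID between the two embedding distributions.
gt_mu, gt_cov = calculate_activation_statistics(gt_emb)
mu, cov = calculate_activation_statistics(gen_emb)
print('FID:', calculate_frechet_distance(gt_mu, gt_cov, mu, cov))

# R-precision on paired text/motion embeddings: row i of the distance matrix
# should have its smallest entry at column i when the pair matches.
text_emb = rng.normal(size=(32, 64))
motion_emb = text_emb + 0.05 * rng.normal(size=(32, 64))
dist = euclidean_distance_matrix(text_emb, motion_emb)
r_precision = calculate_top_k(np.argsort(dist, axis=1), top_k=3).sum(axis=0) / 32
print('R-precision (top 1/2/3):', r_precision)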
/in2in/utils/quaternion.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2018-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | import torch
9 | import numpy as np
10 |
11 | _EPS4 = np.finfo(np.float32).eps * 4.0
12 |
13 | _FLOAT_EPS = np.finfo(np.float32).eps
14 |
15 | # PyTorch-backed implementations
16 | def qinv(q):
17 | assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)'
18 | mask = torch.ones_like(q)
19 | mask[..., 1:] = -mask[..., 1:]
20 | return q * mask
21 |
22 |
23 | def qinv_np(q):
24 | assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)'
25 | return qinv(torch.from_numpy(q).float()).numpy()
26 |
27 |
28 | def qnormalize(q):
29 | assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)'
30 | return q / torch.norm(q, dim=-1, keepdim=True)
31 |
32 |
33 | def qmul(q, r):
34 | """
35 | Multiply quaternion(s) q with quaternion(s) r.
36 | Expects two equally-sized tensors of shape (*, 4), where * denotes any number of dimensions.
37 | Returns q*r as a tensor of shape (*, 4).
38 | """
39 | assert q.shape[-1] == 4
40 | assert r.shape[-1] == 4
41 |
42 | original_shape = q.shape
43 |
44 | # Compute outer product
45 | terms = torch.bmm(r.view(-1, 4, 1), q.view(-1, 1, 4))
46 |
47 | w = terms[:, 0, 0] - terms[:, 1, 1] - terms[:, 2, 2] - terms[:, 3, 3]
48 | x = terms[:, 0, 1] + terms[:, 1, 0] - terms[:, 2, 3] + terms[:, 3, 2]
49 | y = terms[:, 0, 2] + terms[:, 1, 3] + terms[:, 2, 0] - terms[:, 3, 1]
50 | z = terms[:, 0, 3] - terms[:, 1, 2] + terms[:, 2, 1] + terms[:, 3, 0]
51 | return torch.stack((w, x, y, z), dim=1).view(original_shape)
52 |
53 |
54 | def qrot(q, v):
55 | """
56 | Rotate vector(s) v about the rotation described by quaternion(s) q.
57 | Expects a tensor of shape (*, 4) for q and a tensor of shape (*, 3) for v,
58 | where * denotes any number of dimensions.
59 | Returns a tensor of shape (*, 3).
60 | """
61 | assert q.shape[-1] == 4
62 | assert v.shape[-1] == 3
63 | assert q.shape[:-1] == v.shape[:-1]
64 |
65 | original_shape = list(v.shape)
66 | # print(q.shape)
67 | q = q.contiguous().view(-1, 4)
68 | v = v.contiguous().view(-1, 3)
69 |
70 | qvec = q[:, 1:]
71 | uv = torch.cross(qvec, v, dim=1)
72 | uuv = torch.cross(qvec, uv, dim=1)
73 | return (v + 2 * (q[:, :1] * uv + uuv)).view(original_shape)
74 |
75 |
76 | def qeuler(q, order, epsilon=0, deg=True):
77 | """
78 | Convert quaternion(s) q to Euler angles.
79 | Expects a tensor of shape (*, 4), where * denotes any number of dimensions.
80 | Returns a tensor of shape (*, 3).
81 | """
82 | assert q.shape[-1] == 4
83 |
84 | original_shape = list(q.shape)
85 | original_shape[-1] = 3
86 | q = q.view(-1, 4)
87 |
88 | q0 = q[:, 0]
89 | q1 = q[:, 1]
90 | q2 = q[:, 2]
91 | q3 = q[:, 3]
92 |
93 | if order == 'xyz':
94 | x = torch.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2))
95 | y = torch.asin(torch.clamp(2 * (q1 * q3 + q0 * q2), -1 + epsilon, 1 - epsilon))
96 | z = torch.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3))
97 | elif order == 'yzx':
98 | x = torch.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2 * (q1 * q1 + q3 * q3))
99 | y = torch.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2 * (q2 * q2 + q3 * q3))
100 | z = torch.asin(torch.clamp(2 * (q1 * q2 + q0 * q3), -1 + epsilon, 1 - epsilon))
101 | elif order == 'zxy':
102 | x = torch.asin(torch.clamp(2 * (q0 * q1 + q2 * q3), -1 + epsilon, 1 - epsilon))
103 | y = torch.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2 * (q1 * q1 + q2 * q2))
104 | z = torch.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2 * (q1 * q1 + q3 * q3))
105 | elif order == 'xzy':
106 | x = torch.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q3 * q3))
107 | y = torch.atan2(2 * (q0 * q2 + q1 * q3), 1 - 2 * (q2 * q2 + q3 * q3))
108 | z = torch.asin(torch.clamp(2 * (q0 * q3 - q1 * q2), -1 + epsilon, 1 - epsilon))
109 | elif order == 'yxz':
110 | x = torch.asin(torch.clamp(2 * (q0 * q1 - q2 * q3), -1 + epsilon, 1 - epsilon))
111 | y = torch.atan2(2 * (q1 * q3 + q0 * q2), 1 - 2 * (q1 * q1 + q2 * q2))
112 | z = torch.atan2(2 * (q1 * q2 + q0 * q3), 1 - 2 * (q1 * q1 + q3 * q3))
113 | elif order == 'zyx':
114 | x = torch.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2))
115 | y = torch.asin(torch.clamp(2 * (q0 * q2 - q1 * q3), -1 + epsilon, 1 - epsilon))
116 | z = torch.atan2(2 * (q0 * q3 + q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3))
117 | else:
118 | raise ValueError(f"Unknown Euler angle order: {order}")
119 |
120 | if deg:
121 | return torch.stack((x, y, z), dim=1).view(original_shape) * 180 / np.pi
122 | else:
123 | return torch.stack((x, y, z), dim=1).view(original_shape)
124 |
125 |
126 | # Numpy-backed implementations
127 | def qmul_np(q, r):
128 | q = torch.from_numpy(q).contiguous().float()
129 | r = torch.from_numpy(r).contiguous().float()
130 | return qmul(q, r).numpy()
131 |
132 |
133 | def qrot_np(q, v):
134 | q = torch.from_numpy(q).contiguous().float()
135 | v = torch.from_numpy(v).contiguous().float()
136 | return qrot(q, v).numpy()
137 |
138 |
139 | def qeuler_np(q, order, epsilon=0, use_gpu=False):
140 | if use_gpu:
141 | q = torch.from_numpy(q).cuda().float()
142 | return qeuler(q, order, epsilon).cpu().numpy()
143 | else:
144 | q = torch.from_numpy(q).contiguous().float()
145 | return qeuler(q, order, epsilon).numpy()
146 |
147 |
148 | def qfix(q):
149 | """
150 | Enforce quaternion continuity across the time dimension by selecting
151 | the representation (q or -q) with minimal distance (or, equivalently, maximal dot product)
152 | between two consecutive frames.
153 |
154 | Expects a tensor of shape (L, J, 4), where L is the sequence length and J is the number of joints.
155 | Returns a tensor of the same shape.
156 | """
157 | assert len(q.shape) == 3
158 | assert q.shape[-1] == 4
159 |
160 | result = q.copy()
161 | dot_products = np.sum(q[1:] * q[:-1], axis=2)
162 | mask = dot_products < 0
163 | mask = (np.cumsum(mask, axis=0) % 2).astype(bool)
164 | result[1:][mask] *= -1
165 | return result
166 |
167 |
168 | def euler2quat(e, order, deg=True):
169 | """
170 | Convert Euler angles to quaternions.
171 | """
172 | assert e.shape[-1] == 3
173 |
174 | original_shape = list(e.shape)
175 | original_shape[-1] = 4
176 |
177 | e = e.view(-1, 3)
178 |
179 | ## if euler angles in degrees
180 | if deg:
181 | e = e * np.pi / 180.
182 |
183 | x = e[:, 0]
184 | y = e[:, 1]
185 | z = e[:, 2]
186 |
187 | rx = torch.stack((torch.cos(x / 2), torch.sin(x / 2), torch.zeros_like(x), torch.zeros_like(x)), dim=1)
188 | ry = torch.stack((torch.cos(y / 2), torch.zeros_like(y), torch.sin(y / 2), torch.zeros_like(y)), dim=1)
189 | rz = torch.stack((torch.cos(z / 2), torch.zeros_like(z), torch.zeros_like(z), torch.sin(z / 2)), dim=1)
190 |
191 | result = None
192 | for coord in order:
193 | if coord == 'x':
194 | r = rx
195 | elif coord == 'y':
196 | r = ry
197 | elif coord == 'z':
198 | r = rz
199 | else:
200 | raise ValueError(f"Unknown rotation axis: {coord}")
201 | if result is None:
202 | result = r
203 | else:
204 | result = qmul(result, r)
205 |
206 | # Reverse antipodal representation to have a non-negative "w"
207 | if order in ['xyz', 'yzx', 'zxy']:
208 | result *= -1
209 |
210 | return result.view(original_shape)
211 |
212 |
213 | def expmap_to_quaternion(e):
214 | """
215 | Convert axis-angle rotations (aka exponential maps) to quaternions.
216 | Stable formula from "Practical Parameterization of Rotations Using the Exponential Map".
217 | Expects a tensor of shape (*, 3), where * denotes any number of dimensions.
218 | Returns a tensor of shape (*, 4).
219 | """
220 | assert e.shape[-1] == 3
221 |
222 | original_shape = list(e.shape)
223 | original_shape[-1] = 4
224 | e = e.reshape(-1, 3)
225 |
226 | theta = np.linalg.norm(e, axis=1).reshape(-1, 1)
227 | w = np.cos(0.5 * theta).reshape(-1, 1)
228 | xyz = 0.5 * np.sinc(0.5 * theta / np.pi) * e
229 | return np.concatenate((w, xyz), axis=1).reshape(original_shape)
230 |
231 |
232 | def euler_to_quaternion(e, order):
233 | """
234 | Convert Euler angles to quaternions.
235 | """
236 | assert e.shape[-1] == 3
237 |
238 | original_shape = list(e.shape)
239 | original_shape[-1] = 4
240 |
241 | e = e.reshape(-1, 3)
242 |
243 | x = e[:, 0]
244 | y = e[:, 1]
245 | z = e[:, 2]
246 |
247 | rx = np.stack((np.cos(x / 2), np.sin(x / 2), np.zeros_like(x), np.zeros_like(x)), axis=1)
248 | ry = np.stack((np.cos(y / 2), np.zeros_like(y), np.sin(y / 2), np.zeros_like(y)), axis=1)
249 | rz = np.stack((np.cos(z / 2), np.zeros_like(z), np.zeros_like(z), np.sin(z / 2)), axis=1)
250 |
251 | result = None
252 | for coord in order:
253 | if coord == 'x':
254 | r = rx
255 | elif coord == 'y':
256 | r = ry
257 | elif coord == 'z':
258 | r = rz
259 | else:
260 | raise ValueError(f"Unknown rotation axis: {coord}")
261 | if result is None:
262 | result = r
263 | else:
264 | result = qmul_np(result, r)
265 |
266 | # Reverse antipodal representation to have a non-negative "w"
267 | if order in ['xyz', 'yzx', 'zxy']:
268 | result *= -1
269 |
270 | return result.reshape(original_shape)
271 |
272 |
273 | def quaternion_to_matrix(quaternions):
274 | """
275 | Convert rotations given as quaternions to rotation matrices.
276 | Args:
277 | quaternions: quaternions with real part first,
278 | as tensor of shape (..., 4).
279 | Returns:
280 | Rotation matrices as tensor of shape (..., 3, 3).
281 | """
282 | r, i, j, k = torch.unbind(quaternions, -1)
283 | two_s = 2.0 / (quaternions * quaternions).sum(-1)
284 |
285 | o = torch.stack(
286 | (
287 | 1 - two_s * (j * j + k * k),
288 | two_s * (i * j - k * r),
289 | two_s * (i * k + j * r),
290 | two_s * (i * j + k * r),
291 | 1 - two_s * (i * i + k * k),
292 | two_s * (j * k - i * r),
293 | two_s * (i * k - j * r),
294 | two_s * (j * k + i * r),
295 | 1 - two_s * (i * i + j * j),
296 | ),
297 | -1,
298 | )
299 | return o.reshape(quaternions.shape[:-1] + (3, 3))
300 |
301 |
302 | def quaternion_to_matrix_np(quaternions):
303 | q = torch.from_numpy(quaternions).contiguous().float()
304 | return quaternion_to_matrix(q).numpy()
305 |
306 |
307 | def quaternion_to_cont6d_np(quaternions):
308 | rotation_mat = quaternion_to_matrix_np(quaternions)
309 | cont_6d = np.concatenate([rotation_mat[..., 0], rotation_mat[..., 1]], axis=-1)
310 | return cont_6d
311 |
312 |
313 | def quaternion_to_cont6d(quaternions):
314 | rotation_mat = quaternion_to_matrix(quaternions)
315 | cont_6d = torch.cat([rotation_mat[..., 0], rotation_mat[..., 1]], dim=-1)
316 | return cont_6d
317 |
318 |
319 | def cont6d_to_matrix(cont6d):
320 | assert cont6d.shape[-1] == 6, "The last dimension must be 6"
321 | x_raw = cont6d[..., 0:3]
322 | y_raw = cont6d[..., 3:6]
323 |
324 | x = x_raw / torch.norm(x_raw, dim=-1, keepdim=True)
325 | z = torch.cross(x, y_raw, dim=-1)
326 | z = z / torch.norm(z, dim=-1, keepdim=True)
327 |
328 | y = torch.cross(z, x, dim=-1)
329 |
330 | x = x[..., None]
331 | y = y[..., None]
332 | z = z[..., None]
333 |
334 | mat = torch.cat([x, y, z], dim=-1)
335 | return mat
336 |
337 |
338 | def cont6d_to_matrix_np(cont6d):
339 | q = torch.from_numpy(cont6d).contiguous().float()
340 | return cont6d_to_matrix(q).numpy()
341 |
342 |
343 | def qpow(q0, t, dtype=torch.float):
344 | ''' q0 : tensor of quaternions
345 | t: tensor of powers
346 | '''
347 | q0 = qnormalize(q0)
348 | theta0 = torch.acos(q0[..., 0])
349 |
350 | ## if theta0 is close to zero, add epsilon to avoid NaNs
351 | mask = ((theta0 <= 10e-10) & (theta0 >= -10e-10)).float()
352 | theta0 = (1 - mask) * theta0 + mask * 10e-10
353 | v0 = q0[..., 1:] / torch.sin(theta0).view(-1, 1)
354 |
355 | if isinstance(t, torch.Tensor):
356 | q = torch.zeros(t.shape + q0.shape)
357 | theta = t.view(-1, 1) * theta0.view(1, -1)
358 | else: ## if t is a number
359 | q = torch.zeros(q0.shape)
360 | theta = t * theta0
361 |
362 | q[..., 0] = torch.cos(theta)
363 | q[..., 1:] = v0 * torch.sin(theta).unsqueeze(-1)
364 |
365 | return q.to(dtype)
366 |
367 |
368 | def qslerp(q0, q1, t):
369 | '''
370 | q0: starting quaternion
371 | q1: ending quaternion
372 | t: array of points along the way
373 |
374 | Returns:
375 | Tensor of Slerps: t.shape + q0.shape
376 | '''
377 |
378 | q0 = qnormalize(q0)
379 | q1 = qnormalize(q1)
380 | q_ = qpow(qmul(q1, qinv(q0)), t)
381 |
382 | return qmul(q_,
383 | q0.contiguous().view(torch.Size([1] * len(t.shape)) + q0.shape).expand(t.shape + q0.shape).contiguous())
384 |
385 |
386 | def qbetween(v0, v1):
387 | '''
388 | find the quaternion used to rotate v0 to v1
389 | '''
390 | assert v0.shape[-1] == 3, 'v0 must be of the shape (*, 3)'
391 | assert v1.shape[-1] == 3, 'v1 must be of the shape (*, 3)'
392 |
393 | v = torch.cross(v0, v1)
394 | w = torch.sqrt((v0 ** 2).sum(dim=-1, keepdim=True) * (v1 ** 2).sum(dim=-1, keepdim=True)) + (v0 * v1).sum(dim=-1,
395 | keepdim=True)
396 | return qnormalize(torch.cat([w, v], dim=-1))
397 |
398 |
399 | def qbetween_np(v0, v1):
400 | '''
401 | find the quaternion used to rotate v0 to v1
402 | '''
403 | assert v0.shape[-1] == 3, 'v0 must be of the shape (*, 3)'
404 | assert v1.shape[-1] == 3, 'v1 must be of the shape (*, 3)'
405 |
406 | v0 = torch.from_numpy(v0).float()
407 | v1 = torch.from_numpy(v1).float()
408 | return qbetween(v0, v1).numpy()
409 |
410 |
411 | def lerp(p0, p1, t):
412 | if not isinstance(t, torch.Tensor):
413 | t = torch.Tensor([t])
414 |
415 | new_shape = t.shape + p0.shape
416 | new_view_t = t.shape + torch.Size([1] * len(p0.shape))
417 | new_view_p = torch.Size([1] * len(t.shape)) + p0.shape
418 | p0 = p0.view(new_view_p).expand(new_shape)
419 | p1 = p1.view(new_view_p).expand(new_shape)
420 | t = t.view(new_view_t).expand(new_shape)
421 |
422 | return p0 + t * (p1 - p0)
423 |
--------------------------------------------------------------------------------
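A quick self-contained check of the quaternion helpers above: qbetween(v0, v1) should return a unit quaternion that rotates v0 onto the direction of v1, and multiplying a unit quaternion by its inverse should give the identity quaternion:

import torch
from in2in.utils.quaternion import qbetween, qrot, qmul, qinv

v0 = torch.tensor([[0.0, 0.0, 1.0]])
v1 = torch.tensor([[1.0, 0.0, 0.0]])

q = qbetween(v0, v1)     # 90 degree rotation about +Y, roughly [0.7071, 0, 0.7071, 0]
print(qrot(q, v0))       # approximately [[1., 0., 0.]]
print(qmul(q, qinv(q)))  # approximately the identity quaternion [[1., 0., 0., 0.]]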
/in2in/models/utils/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from in2in.utils.utils import *
4 |
5 |
6 | class InterLoss(nn.Module):
7 | def __init__(self, recons_loss, nb_joints):
8 | super(InterLoss, self).__init__()
9 | self.nb_joints = nb_joints
10 | if recons_loss == 'l1':
11 | self.Loss = torch.nn.L1Loss(reduction='none')
12 | elif recons_loss == 'l2':
13 | self.Loss = torch.nn.MSELoss(reduction='none')
14 | elif recons_loss == 'l1_smooth':
15 | self.Loss = torch.nn.SmoothL1Loss(reduction='none')
16 |
17 | self.normalizer = MotionNormalizerTorch()
18 |
19 | self.weights = {}
20 | self.weights["RO"] = 0.01
21 | self.weights["JA"] = 3
22 | self.weights["DM"] = 3
23 |
24 | self.losses = {}
25 |
26 | def seq_masked_mse(self, prediction, target, mask):
27 | loss = self.Loss(prediction, target).mean(dim=-1, keepdim=True)
28 | loss = (loss * mask).sum() / (mask.sum() + 1.e-7)
29 | return loss
30 |
31 | def mix_masked_mse(self, prediction, target, mask, batch_mask, contact_mask=None, dm_mask=None):
32 | if dm_mask is not None:
33 | loss = (self.Loss(prediction, target) * dm_mask).sum(dim=-1, keepdim=True)/ (dm_mask.sum(dim=-1, keepdim=True) + 1.e-7)
34 | else:
35 | loss = self.Loss(prediction, target).mean(dim=-1, keepdim=True) # [b,t,p,4,1]
36 | if contact_mask is not None:
37 | loss = (loss[..., 0] * contact_mask).sum(dim=-1, keepdim=True) / (contact_mask.sum(dim=-1, keepdim=True) + 1.e-7)
38 | loss = (loss * mask).sum(dim=(-1, -2, -3)) / (mask.sum(dim=(-1, -2, -3)) + 1.e-7) # [b]
39 | loss = (loss * batch_mask).sum(dim=0) / (batch_mask.sum(dim=0) + 1.e-7)
40 |
41 | return loss
42 |
43 | def forward(self, motion_pred, motion_gt, mask, timestep_mask):
44 | B, T = motion_pred.shape[:2]
45 | self.losses["simple"] = self.seq_masked_mse(motion_pred, motion_gt, mask)
46 | target = self.normalizer.backward(motion_gt, global_rt=True)
47 | prediction = self.normalizer.backward(motion_pred, global_rt=True)
48 |
49 | self.pred_g_joints = prediction[..., :self.nb_joints * 3].reshape(B, T, -1, self.nb_joints, 3)
50 | self.tgt_g_joints = target[..., :self.nb_joints * 3].reshape(B, T, -1, self.nb_joints, 3)
51 |
52 | self.mask = mask
53 | self.timestep_mask = timestep_mask
54 |
55 | self.forward_distance_map(thresh=1)
56 | self.forward_joint_affinity(thresh=0.1)
57 | self.forward_relative_rot()
58 | self.accum_loss()
59 |
60 |
61 | def forward_relative_rot(self):
62 | r_hip, l_hip, sdr_r, sdr_l = FACE_JOINT_INDX
63 | across = self.pred_g_joints[..., r_hip, :] - self.pred_g_joints[..., l_hip, :]
64 | across = across / across.norm(dim=-1, keepdim=True)
65 | across_gt = self.tgt_g_joints[..., r_hip, :] - self.tgt_g_joints[..., l_hip, :]
66 | across_gt = across_gt / across_gt.norm(dim=-1, keepdim=True)
67 |
68 | y_axis = torch.zeros_like(across)
69 | y_axis[..., 1] = 1
70 |
71 | forward = torch.cross(y_axis, across, axis=-1)
72 | forward = forward / forward.norm(dim=-1, keepdim=True)
73 | forward_gt = torch.cross(y_axis, across_gt, axis=-1)
74 | forward_gt = forward_gt / forward_gt.norm(dim=-1, keepdim=True)
75 |
76 | pred_relative_rot = qbetween(forward[..., 0, :], forward[..., 1, :])
77 | tgt_relative_rot = qbetween(forward_gt[..., 0, :], forward_gt[..., 1, :])
78 |
79 | self.losses["RO"] = self.mix_masked_mse(pred_relative_rot[..., [0, 2]],
80 | tgt_relative_rot[..., [0, 2]],
81 | self.mask[..., 0, :], self.timestep_mask) * self.weights["RO"]
82 |
83 |
84 | def forward_distance_map(self, thresh):
85 | pred_g_joints = self.pred_g_joints.reshape(self.mask.shape[:-1] + (-1,))
86 | tgt_g_joints = self.tgt_g_joints.reshape(self.mask.shape[:-1] + (-1,))
87 |
88 | pred_g_joints1 = pred_g_joints[..., 0:1, :].reshape(-1, self.nb_joints, 3)
89 | pred_g_joints2 = pred_g_joints[..., 1:2, :].reshape(-1, self.nb_joints, 3)
90 | tgt_g_joints1 = tgt_g_joints[..., 0:1, :].reshape(-1, self.nb_joints, 3)
91 | tgt_g_joints2 = tgt_g_joints[..., 1:2, :].reshape(-1, self.nb_joints, 3)
92 |
93 | pred_distance_matrix = torch.cdist(pred_g_joints1.contiguous(), pred_g_joints2).reshape(
94 | self.mask.shape[:-2] + (1, -1,))
95 | tgt_distance_matrix = torch.cdist(tgt_g_joints1.contiguous(), tgt_g_joints2).reshape(
96 | self.mask.shape[:-2] + (1, -1,))
97 |
98 | distance_matrix_mask = (pred_distance_matrix < thresh).float()
99 |
100 | self.losses["DM"] = self.mix_masked_mse(pred_distance_matrix, tgt_distance_matrix,
101 | self.mask[..., 0:1, :],
102 | self.timestep_mask, dm_mask=distance_matrix_mask) * self.weights["DM"]
103 |
104 | def forward_joint_affinity(self, thresh):
105 | pred_g_joints = self.pred_g_joints.reshape(self.mask.shape[:-1] + (-1,))
106 | tgt_g_joints = self.tgt_g_joints.reshape(self.mask.shape[:-1] + (-1,))
107 |
108 | pred_g_joints1 = pred_g_joints[..., 0:1, :].reshape(-1, self.nb_joints, 3)
109 | pred_g_joints2 = pred_g_joints[..., 1:2, :].reshape(-1, self.nb_joints, 3)
110 | tgt_g_joints1 = tgt_g_joints[..., 0:1, :].reshape(-1, self.nb_joints, 3)
111 | tgt_g_joints2 = tgt_g_joints[..., 1:2, :].reshape(-1, self.nb_joints, 3)
112 |
113 | pred_distance_matrix = torch.cdist(pred_g_joints1.contiguous(), pred_g_joints2).reshape(
114 | self.mask.shape[:-2] + (1, -1,))
115 | tgt_distance_matrix = torch.cdist(tgt_g_joints1.contiguous(), tgt_g_joints2).reshape(
116 | self.mask.shape[:-2] + (1, -1,))
117 |
118 | distance_matrix_mask = (tgt_distance_matrix < thresh).float()
119 |
120 | self.losses["JA"] = self.mix_masked_mse(pred_distance_matrix, torch.zeros_like(tgt_distance_matrix),
121 | self.mask[..., 0:1, :],
122 | self.timestep_mask, dm_mask=distance_matrix_mask) * self.weights["JA"]
123 |
124 | def accum_loss(self):
125 | loss = 0
126 | for term in self.losses.keys():
127 | loss += self.losses[term]
128 | self.losses["total"] = loss
129 | return self.losses
130 |
131 |
132 |
133 | class GeometricLoss(nn.Module):
134 | def __init__(self, recons_loss, nb_joints, name, mode="interaction"):
135 | super(GeometricLoss, self).__init__()
136 | self.mode = mode
137 | self.name = name
138 | self.nb_joints = nb_joints
139 | if recons_loss == 'l1':
140 | self.Loss = torch.nn.L1Loss(reduction='none')
141 | elif recons_loss == 'l2':
142 | self.Loss = torch.nn.MSELoss(reduction='none')
143 | elif recons_loss == 'l1_smooth':
144 | self.Loss = torch.nn.SmoothL1Loss(reduction='none')
145 |
146 | if mode == "individual":
147 | self.normalizer = MotionNormalizerTorchHML3D()
148 | else:
149 | self.normalizer = MotionNormalizerTorch()
150 |
151 | self.fids = [7, 10, 8, 11]
152 |
153 | self.weights = {}
154 | self.weights["VEL"] = 30
155 | self.weights["BL"] = 10
156 | self.weights["FC"] = 30
157 | self.weights["POSE"] = 1
158 | self.weights["TR"] = 100
159 |
160 | self.losses = {}
161 |
162 | def seq_masked_mse(self, prediction, target, mask):
163 | loss = self.Loss(prediction, target).mean(dim=-1, keepdim=True)
164 | loss = (loss * mask).sum() / (mask.sum() + 1.e-7)
165 | return loss
166 |
167 | def mix_masked_mse(self, prediction, target, mask, batch_mask, contact_mask=None, dm_mask=None):
168 | if dm_mask is not None:
169 | loss = (self.Loss(prediction, target) * dm_mask).sum(dim=-1, keepdim=True)/ (dm_mask.sum(dim=-1, keepdim=True) + 1.e-7) # [b,t,p,4,1]
170 | else:
171 | loss = self.Loss(prediction, target).mean(dim=-1, keepdim=True) # [b,t,p,4,1]
172 | if contact_mask is not None:
173 | loss = (loss[..., 0] * contact_mask).sum(dim=-1, keepdim=True) / (contact_mask.sum(dim=-1, keepdim=True) + 1.e-7)
174 | loss = (loss * mask).sum(dim=(-1, -2)) / (mask.sum(dim=(-1, -2)) + 1.e-7) # [b]
175 | loss = (loss * batch_mask).sum(dim=0) / (batch_mask.sum(dim=0) + 1.e-7)
176 |
177 | return loss
178 |
179 | def forward(self, motion_pred, motion_gt, mask, timestep_mask):
180 | B, T = motion_pred.shape[:2]
181 |
182 | if self.mode == "individual":
183 | self.losses["simple"] = self.seq_masked_mse(motion_pred, motion_gt, mask)
184 |
185 | target = self.normalizer.backward(motion_gt, global_rt=True)
186 | prediction = self.normalizer.backward(motion_pred, global_rt=True)
187 |
188 | self.first_motion_pred = motion_pred[:, 0:1]
189 | self.first_motion_gt = motion_gt[:, 0:1]
190 |
191 | self.pred_g_joints = prediction[..., :self.nb_joints * 3].reshape(B, T, self.nb_joints, 3)
192 | self.tgt_g_joints = target[..., :self.nb_joints * 3].reshape(B, T, self.nb_joints, 3)
193 | self.mask = mask
194 | self.timestep_mask = timestep_mask
195 |
196 | if self.mode != "individual":
197 | self.forward_vel()
198 | self.forward_bone_length()
199 | self.forward_contact()
200 |
201 | self.accum_loss()
202 |
203 | def get_local_positions(self, positions, r_rot):
204 | '''Local pose'''
205 | positions[..., 0] -= positions[..., 0:1, 0]
206 | positions[..., 2] -= positions[..., 0:1, 2]
207 | '''All pose face Z+'''
208 | positions = qrot(r_rot[..., None, :].repeat(1, 1, positions.shape[-2], 1), positions)
209 | return positions
210 |
211 | def forward_local_pose(self):
212 | r_hip, l_hip, sdr_r, sdr_l = FACE_JOINT_INDX
213 |
214 | pred_g_joints = self.pred_g_joints.clone()
215 | tgt_g_joints = self.tgt_g_joints.clone()
216 |
217 | across = pred_g_joints[..., r_hip, :] - pred_g_joints[..., l_hip, :]
218 | across = across / across.norm(dim=-1, keepdim=True)
219 | across_gt = tgt_g_joints[..., r_hip, :] - tgt_g_joints[..., l_hip, :]
220 | across_gt = across_gt / across_gt.norm(dim=-1, keepdim=True)
221 |
222 | y_axis = torch.zeros_like(across)
223 | y_axis[..., 1] = 1
224 |
225 | forward = torch.cross(y_axis, across, axis=-1)
226 | forward = forward / forward.norm(dim=-1, keepdim=True)
227 | forward_gt = torch.cross(y_axis, across_gt, axis=-1)
228 | forward_gt = forward_gt / forward_gt.norm(dim=-1, keepdim=True)
229 |
230 | z_axis = torch.zeros_like(forward)
231 | z_axis[..., 2] = 1
232 | noise = torch.randn_like(z_axis) * 0.0001
233 | z_axis = z_axis + noise
234 | z_axis = z_axis / z_axis.norm(dim=-1, keepdim=True)
235 |
236 |
237 | pred_rot = qbetween(forward, z_axis)
238 | tgt_rot = qbetween(forward_gt, z_axis)
239 |
240 | B, T, J, D = self.pred_g_joints.shape
241 | pred_joints = self.get_local_positions(pred_g_joints, pred_rot).reshape(B, T, -1)
242 | tgt_joints = self.get_local_positions(tgt_g_joints, tgt_rot).reshape(B, T, -1)
243 |
244 | self.losses["POSE_"+self.name] = self.mix_masked_mse(pred_joints, tgt_joints, self.mask, self.timestep_mask) * self.weights["POSE"]
245 |
246 | def forward_vel(self):
247 | B, T = self.pred_g_joints.shape[:2]
248 |
249 | pred_vel = self.pred_g_joints[:, 1:] - self.pred_g_joints[:, :-1]
250 | tgt_vel = self.tgt_g_joints[:, 1:] - self.tgt_g_joints[:, :-1]
251 |
252 | pred_vel = pred_vel.reshape(pred_vel.shape[:-2] + (-1,))
253 | tgt_vel = tgt_vel.reshape(tgt_vel.shape[:-2] + (-1,))
254 |
255 | self.losses["VEL_"+self.name] = self.mix_masked_mse(pred_vel, tgt_vel, self.mask[:, :-1], self.timestep_mask) * self.weights["VEL"]
256 |
257 |
258 | def forward_contact(self):
259 |
260 | feet_vel = self.pred_g_joints[:, 1:, self.fids, :] - self.pred_g_joints[:, :-1, self.fids,:]
261 | feet_h = self.pred_g_joints[:, :-1, self.fids, 1]
262 |
263 | contact = self.foot_detect(feet_vel, feet_h, 0.001)
264 |
265 | self.losses["FC_"+self.name] = self.mix_masked_mse(feet_vel, torch.zeros_like(feet_vel), self.mask[:, :-1],
266 | self.timestep_mask,
267 | contact) * self.weights["FC"]
268 |
269 | def forward_bone_length(self):
270 | pred_g_joints = self.pred_g_joints
271 | tgt_g_joints = self.tgt_g_joints
272 | pred_bones = []
273 | tgt_bones = []
274 | for chain in HML_KINEMATIC_CHAIN:
275 | for i, joint in enumerate(chain[:-1]):
276 | pred_bone = (pred_g_joints[..., chain[i], :] - pred_g_joints[..., chain[i + 1], :]).norm(dim=-1,
277 | keepdim=True) # [B,T,P,1]
278 | tgt_bone = (tgt_g_joints[..., chain[i], :] - tgt_g_joints[..., chain[i + 1], :]).norm(dim=-1,
279 | keepdim=True)
280 | pred_bones.append(pred_bone)
281 | tgt_bones.append(tgt_bone)
282 |
283 | pred_bones = torch.cat(pred_bones, dim=-1)
284 | tgt_bones = torch.cat(tgt_bones, dim=-1)
285 |
286 | self.losses["BL_"+self.name] = self.mix_masked_mse(pred_bones, tgt_bones, self.mask, self.timestep_mask) * self.weights[
287 | "BL"]
288 |
289 |
290 | def forward_traj(self):
291 | B, T = self.pred_g_joints.shape[:2]
292 |
293 | pred_traj = self.pred_g_joints[..., 0, [0, 2]]
294 | tgt_g_traj = self.tgt_g_joints[..., 0, [0, 2]]
295 |
296 | self.losses["TR_"+self.name] = self.mix_masked_mse(pred_traj, tgt_g_traj, self.mask, self.timestep_mask) * self.weights["TR"]
297 |
298 |
299 | def accum_loss(self):
300 | loss = 0
301 | for term in self.losses.keys():
302 | loss += self.losses[term]
303 | self.losses[self.name] = loss
304 |
305 | def foot_detect(self, feet_vel, feet_h, thres):
306 | velfactor, heightfactor = torch.Tensor([thres, thres, thres, thres]).to(feet_vel.device), torch.Tensor(
307 | [0.12, 0.05, 0.12, 0.05]).to(feet_vel.device)
308 |
309 | feet_x = (feet_vel[..., 0]) ** 2
310 | feet_y = (feet_vel[..., 1]) ** 2
311 | feet_z = (feet_vel[..., 2]) ** 2
312 |
313 | contact = (((feet_x + feet_y + feet_z) < velfactor) & (feet_h < heightfactor)).float()
314 | return contact
--------------------------------------------------------------------------------
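A self-contained illustration of the masking pattern that seq_masked_mse above implements: per-element losses are averaged over the feature dimension and then weighted by a validity mask so padded frames do not contribute (all tensors below are dummy data):

import torch

B, T, D = 2, 8, 6
prediction = torch.randn(B, T, D)
target = torch.randn(B, T, D)

# 1 for valid frames, 0 for padding (the second sequence is only 5 frames long).
mask = torch.ones(B, T, 1)
mask[1, 5:] = 0

per_elem = torch.nn.MSELoss(reduction='none')(prediction, target)
loss = per_elem.mean(dim=-1, keepdim=True)         # [B, T, 1]
loss = (loss * mask).sum() / (mask.sum() + 1.e-7)  # masked mean over valid frames only
print(loss)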
/in2in/evaluation/datasets.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import random
3 | import torch
4 | import os
5 | import numpy as np
6 | from tqdm import tqdm
7 | from os.path import join as pjoin
8 | from torch.utils.data import Dataset, DataLoader
9 | from in2in.utils.utils import MotionNormalizer
10 |
11 | class EvaluationDatasetInterHuman(Dataset):
12 | """
13 | Evaluation Dataset of InterHuman.
14 | Motions are generated by the trained model to later be compared with the ground truth.
15 | """
16 | def __init__(self, model, dataset, device, mm_num_samples, mm_num_repeats):
17 | """
18 | Initialization of the dataset and generation of the motions.
19 | :param model: Model to generate the motions.
20 | :param dataset: Ground truth dataset.
21 | :param device: Device to run the model.
22 | :param mm_num_samples: Number of samples to generate for the MultiModality metric.
23 | :param mm_num_repeats: Number of repeats for each sample in the MultiModality metric.
24 | """
25 |
26 | # Configuration variables
27 | self.normalizer = MotionNormalizer()
28 | self.model = model.to(device)
29 | self.model.eval()
30 | dataloader = DataLoader(dataset, batch_size=1, num_workers=0, shuffle=True)
31 | self.max_length = dataset.max_length
32 | self.extended = dataset.extended
33 |
34 | # Indexes of the motions to generate
35 | idxs = list(range(len(dataset)))
36 | random.shuffle(idxs)
37 | mm_idxs = idxs[:mm_num_samples]
38 |
39 | # Data structures
40 | generated_motions = []
41 | mm_generated_motions = []
42 |
43 | with torch.no_grad():
44 | for i, data in tqdm(enumerate(dataloader)):
45 |
46 | # Get the data from the data loader
47 | if self.extended:
48 | name, text, motion1, motion2, motion_lens, text_individual1, text_individual2 = data
49 | else:
50 | name, text, motion1, motion2, motion_lens = data
51 |
 52 | # If the sample is selected for MultiModality (MM), duplicate its texts mm_num_repeats times
53 | batch = {}
54 | if i in mm_idxs:
55 | batch["text"] = list(text) * mm_num_repeats
56 | if self.extended:
57 | batch["text_individual1"] = list(text_individual1) * mm_num_repeats
58 | batch["text_individual2"] = list(text_individual2) * mm_num_repeats
59 | else:
60 | batch["text"] = list(text)
61 | if self.extended:
62 | batch["text_individual1"] = list(text_individual1)
63 | batch["text_individual2"] = list(text_individual2)
64 |
65 | batch["motion_lens"] = motion_lens
66 |
67 | # Predict the motions with the model to be evaluated
68 | batch = self.model.forward_test(batch)
69 | motions_output = batch["output"].reshape(batch["output"].shape[0], batch["output"].shape[1], 2, -1)
70 | motions_output = self.normalizer.backward(motions_output.cpu().detach().numpy())
71 |
 72 | # Pad the generated motions to max_length
73 | B,T = motions_output.shape[0], motions_output.shape[1]
74 | if T < self.max_length:
75 | padding_len = self.max_length - T
76 | D = motions_output.shape[-1]
77 | padding_zeros = np.zeros((B, padding_len, 2, D))
78 | motions_output = np.concatenate((motions_output, padding_zeros), axis=1)
79 | assert motions_output.shape[1] == self.max_length
80 |
81 | # Save the generated motions
82 | if self.extended:
83 | sub_dict = {'motion1': motions_output[0, :,0],
84 | 'motion2': motions_output[0, :,1],
85 | 'motion_lens': motion_lens[0],
86 | 'text': text[0],
87 | 'text_individual1': text_individual1[0],
88 | 'text_individual2': text_individual2[0]}
89 | generated_motions.append(sub_dict)
90 | if i in mm_idxs:
91 | mm_sub_dict = {'mm_motions': motions_output,
92 | 'motion_lens': motion_lens[0],
93 | 'text': text[0],
94 | 'text_individual1': text_individual1[0],
95 | 'text_individual2': text_individual2[0]}
96 | mm_generated_motions.append(mm_sub_dict)
97 | else:
98 | sub_dict = {'motion1': motions_output[0, :,0],
99 | 'motion2': motions_output[0, :,1],
100 | 'motion_lens': motion_lens[0],
101 | 'text': text[0]}
102 | generated_motions.append(sub_dict)
103 | if i in mm_idxs:
104 | mm_sub_dict = {'mm_motions': motions_output,
105 | 'motion_lens': motion_lens[0],
106 | 'text': text[0]}
107 | mm_generated_motions.append(mm_sub_dict)
108 |
109 |
110 | self.generated_motions = generated_motions
111 | self.mm_generated_motions = mm_generated_motions
112 |
113 | def __len__(self):
114 | """
115 | Get the length of the dataset.
116 | """
117 | return len(self.generated_motions)
118 |
119 | def __getitem__(self, item):
120 | """
121 | Get the item of the dataset.
122 | :param item: Index of the item.
123 | """
124 | data = self.generated_motions[item]
125 |
126 | # Return the data, if dataset extended also return individual descriptions
127 | if self.extended:
128 | motion1, motion2, motion_lens, text, text_individual1, text_individual2 = data['motion1'], data['motion2'], data['motion_lens'], data['text'], data['text_individual1'], data['text_individual2']
129 | return "generated", text, motion1, motion2, motion_lens, text_individual1, text_individual2
130 | else:
131 | motion1, motion2, motion_lens, text = data['motion1'], data['motion2'], data['motion_lens'], data['text']
132 | return "generated", text, motion1, motion2, motion_lens
133 |
134 |
135 | class MMGeneratedDatasetInterHuman(Dataset):
136 | """
137 | Dataset for the MultiModality metric.
138 | """
139 | def __init__(self, motion_dataset):
140 | """
141 | Initialization of the dataset.
142 | :param motion_dataset: EvaluationDataset of the generated motions.
143 | """
144 | self.dataset = motion_dataset.mm_generated_motions
145 | self.extended = motion_dataset.extended
146 |
147 | def __len__(self):
148 | """
149 | Get the length of the dataset.
150 | """
151 | return len(self.dataset)
152 |
153 | def __getitem__(self, item):
154 | """
155 | Get the item of the dataset.
156 | :param item: Index of the item.
157 | """
158 | data = self.dataset[item]
159 | mm_motions = data['mm_motions']
160 | motion_lens = data['motion_lens']
161 | mm_motions1 = mm_motions[:,:,0]
162 | mm_motions2 = mm_motions[:,:,1]
163 | text = data['text']
164 | motion_lens = np.array([motion_lens]*mm_motions1.shape[0])
165 |
166 | # If dataset extended also return individual descriptions
167 | if self.extended:
168 | text_individual1 = data['text_individual1']
169 | text_individual2 = data['text_individual2']
170 | return "mm_generated", text, mm_motions1, mm_motions2, motion_lens, text_individual1, text_individual2
171 | else:
172 | return "mm_generated", text, mm_motions1, mm_motions2, motion_lens
173 |
174 |
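    | # Rough usage sketch of the two classes above (illustrative only; the ground-truth
    | # dataset class and the metric values are assumptions, not taken from this file):
    | #   gt_set  = InterHumanDataset(...)                                  # hypothetical GT split
    | #   gen_set = EvaluationDatasetInterHuman(model, gt_set, device,
    | #                                         mm_num_samples=100, mm_num_repeats=30)
    | #   gen_loader = DataLoader(gen_set, batch_size=32, shuffle=False)
    | #   mm_loader  = DataLoader(MMGeneratedDatasetInterHuman(gen_set), batch_size=1)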
175 |
176 | class EvaluationDatasetDualMDM(Dataset):
177 | """
178 | Evaluation Dataset of DualMDM.
179 | Motions are generated by the trained model to later be compared with the ground truth.
180 | The dataset is generated by combining interaction descriptions from InterHuman with individual descriptions from HumanML3D.
181 | """
182 | def __init__(self, model, dataset, device, num_repeats):
183 | """
184 | Initialization of the dataset and generation of the motions.
185 | :param model: Model to generate the motions.
186 | :param dataset: Ground truth dataset.
187 | :param device: Device to run the model.
188 | :param num_repeats: Number of repeats for each sample.
189 | """
190 | self.normalizer = MotionNormalizer()
191 | self.model = model.to(device)
192 | self.model.eval()
193 | dataloader = DataLoader(dataset, batch_size=1, num_workers=0, shuffle=True)
194 | self.max_length = dataset.max_length
195 |
196 | # Paths to the individual descriptions from HumanML3D
197 | self.individual_text_path = "data/HumanML3D/texts"
198 | self.individual_text_files = os.listdir(self.individual_text_path)
199 |
200 | # Data structures
201 | generated_motions = []
202 |
203 | # Save the model's composition-weight parameters (restored for the DualMDM pass below)
204 | composition_weight_func = self.model.decoder.cfg_composition_weight_func
205 | composition_weight_value = self.model.decoder.cfg_composition_weight_value
206 |
207 | with torch.no_grad():
208 | for i, data in tqdm(enumerate(dataloader)):
209 |
210 | # Get the data from the gt InterHuman dataset
211 | name, text, motion1, motion2, motion_lens, text_individual1, text_individual2 = data
212 |
213 | batch = {}
214 | batch["motion_lens"] = motion_lens.repeat(num_repeats * 2)
215 | batch["text"] = list(text) * (num_repeats * 2)
216 | batch["text_individual1"] = list(text_individual1) * num_repeats
217 | batch["text_individual2"] = list(text_individual2) * num_repeats
218 |
219 | # Append num_repeats additional individual descriptions sampled from HumanML3D (second half of the batch)
220 | for j in range(num_repeats):
221 | # Select 2 random files from HumanML3D for the individual descriptions
222 | files = random.sample(self.individual_text_files, 2)
223 |
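    | # HumanML3D text files typically store one annotation per line in the form
    | # "caption#tokenised caption#start#end"; keeping field 0 recovers the raw caption.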
224 | with open(pjoin(self.individual_text_path, files[0]), "r") as f:
225 | hml3d_texts_individual1 = f.readlines()
226 | hml3d_text_individual1 = random.choice(hml3d_texts_individual1)
227 | hml3d_text_individual1 = hml3d_text_individual1.strip().split("#")[0]
228 | batch["text_individual1"].append(hml3d_text_individual1)
229 |
230 | with open(pjoin(self.individual_text_path, files[1]), "r") as f:
231 | hml3d_texts_individual2 = f.readlines()
232 | hml3d_text_individual2 = random.choice(hml3d_texts_individual2)
233 | hml3d_text_individual2 = hml3d_text_individual2.strip().split("#")[0]
234 | batch["text_individual2"].append(hml3d_text_individual2)
235 |
236 |
237 | # Generate interactions using the base interaction model (in2IN)
238 | batch_interaction = copy.deepcopy(batch)
239 | batch_interaction["motion_lens"] = batch_interaction["motion_lens"][:num_repeats]
240 | batch_interaction["text"] = batch_interaction["text"][:num_repeats]
241 | batch_interaction["text_individual1"] = batch_interaction["text_individual1"][:num_repeats]
242 | batch_interaction["text_individual2"] = batch_interaction["text_individual2"][:num_repeats]
243 | self.model.decoder.cfg_composition_weight_func = "const"
244 | self.model.decoder.cfg_composition_weight_value = 0
245 | batch_interaction = self.model.forward_test(batch_interaction)
246 | motions_output_interaction = batch_interaction["output"].reshape(batch_interaction["output"].shape[0], batch_interaction["output"].shape[1], 2, -1)
247 | motions_output_interaction = self.normalizer.backward(motions_output_interaction.cpu().detach().numpy())
248 |
249 | # Generate interactions using the composition of the interaction and individual models (DualMDM)
250 | batch_individual = copy.deepcopy(batch)
251 | batch_individual["motion_lens"] = batch_individual["motion_lens"][num_repeats:]
252 | batch_individual["text"] = batch_individual["text"][num_repeats:]
253 | batch_individual["text_individual1"] = batch_individual["text_individual1"][num_repeats:]
254 | batch_individual["text_individual2"] = batch_individual["text_individual2"][num_repeats:]
255 | self.model.decoder.cfg_composition_weight_func = composition_weight_func
256 | self.model.decoder.cfg_composition_weight_value = composition_weight_value
257 | batch_individual = self.model.forward_test(batch_individual)
258 | motions_output_individual = batch_individual["output"].reshape(batch_individual["output"].shape[0], batch_individual["output"].shape[1], 2, -1)
259 | motions_output_individual = self.normalizer.backward(motions_output_individual.cpu().detach().numpy())
260 |
261 | motions_output = np.concatenate((motions_output_interaction, motions_output_individual), axis=0)
262 | B,T = motions_output.shape[0], motions_output.shape[1]
263 |
264 | # Pad all generated motions to the maximum motion length of the dataset
265 | if T < self.max_length:
266 | padding_len = self.max_length - T
267 | D = motions_output.shape[-1]
268 | padding_zeros = np.zeros((B, padding_len, 2, D))
269 | motions_output = np.concatenate((motions_output, padding_zeros), axis=1)
270 | assert motions_output.shape[1] == self.max_length
271 |
272 | # Save the generated motions
273 | sub_dict = {'generated_motions': motions_output,
274 | 'motion1': motion1,
275 | 'motion2': motion2,
276 | 'motion_lens': batch["motion_lens"],
277 | 'text': batch["text"],
278 | 'text_individual1': batch["text_individual1"],
279 | 'text_individual2': batch["text_individual2"]}
280 |
281 | generated_motions.append(sub_dict)
282 |
283 |
284 | self.generated_motions = generated_motions
285 |
286 | def __len__(self):
287 | """
288 | Get the length of the dataset.
289 | """
290 | return len(self.generated_motions)
291 |
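    | # Each item packs all num_repeats * 2 generated sequences for one ground-truth
    | # pair: the first num_repeats come from the base interaction model (composition
    | # weight forced to 0) and the remaining num_repeats from the full DualMDM
    | # composition.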
292 | def __getitem__(self, item):
293 | """
294 | Get the item of the dataset.
295 | :param item: Index of the item.
296 | """
297 | data = self.generated_motions[item]
298 |
299 | motion_lens = data['motion_lens']
300 | text = data['text']
301 | text_individual1 = data['text_individual1']
302 | text_individual2 = data['text_individual2']
303 |
304 | generated_motions = data['generated_motions']
305 | generated_motions1 = generated_motions[:, :, 0, :]
306 | generated_motions2 = generated_motions[:, :, 1, :]
307 | motion1 = data['motion1']
308 | motion2 = data['motion2']
309 |
310 | return generated_motions1, generated_motions2, motion1, motion2, motion_lens, text, text_individual1, text_individual2
--------------------------------------------------------------------------------
/in2in/models/nets.py:
--------------------------------------------------------------------------------
1 | from in2in.models.utils.cfg_sampler import ClassifierFreeSampleDualMDM, ClassifierFreeSampleModel, ClassifierFreeSampleModelMultiple
2 | import torch
3 | import random
4 | import torch.nn as nn
5 |
6 |
7 | from in2in.models.utils.blocks import TransformerBlockDoubleCond
8 | from in2in.models.utils.gaussian_diffusion import LossType, ModelMeanType, ModelVarType, MotionDiffusion, create_named_schedule_sampler, get_named_beta_schedule, space_timesteps
9 | from in2in.models.utils.layers import FinalLayer
10 | from in2in.models.utils.utils import PositionalEncoding, TimestepEmbedder, zero_module
11 |
12 | class in2INDiffusion(nn.Module):
13 | # Mode can be "individual", "interaction" or "dual"
14 | def __init__(self, cfg, mode, sampling_strategy="ddim50"):
15 | super().__init__()
16 | self.cfg = cfg
17 | self.nfeats = cfg.INPUT_DIM
18 | self.latent_dim = cfg.LATENT_DIM
19 | self.ff_size = cfg.FF_SIZE
20 | self.num_layers = cfg.NUM_LAYERS
21 | self.num_heads = cfg.NUM_HEADS
22 | self.dropout = cfg.DROPOUT
23 | self.activation = cfg.ACTIVATION
24 | self.motion_rep = cfg.MOTION_REP
25 | self.diffusion_steps = cfg.DIFFUSION_STEPS
26 | self.beta_scheduler = cfg.BETA_SCHEDULER
27 | self.sampler = cfg.SAMPLER
28 | self.sampling_strategy = sampling_strategy
29 | self.mode = mode
30 |
31 | # Setting the classifier-free guidance weights
32 | if self.mode == "dual":
33 | self.cfg_weight_individual = cfg.CFG_WEIGHT_INDIVIDUAL
34 | self.cfg_weight_interaction = cfg.CFG_WEIGHT_INTERACTION
35 | self.cfg_composition_weight_func = cfg.W_FUNC
36 | self.cfg_composition_weight_value = cfg.W_VALUE
37 | elif self.mode == "interaction":
38 | self.cfg_weight = cfg.CFG_WEIGHT
39 | self.cfg_weight_individual = cfg.CFG_WEIGHT_INDIVIDUAL
40 | self.cfg_weight_interaction = cfg.CFG_WEIGHT_INTERACTION
41 | elif self.mode == "individual":
42 | self.cfg_weight = cfg.CFG_WEIGHT
43 |
44 | # Creating the denoising network(s)
45 | if self.mode =="dual":
46 | self.net_interaction = in2INDenoiser(self.nfeats,
47 | latent_dim=self.latent_dim,
48 | ff_size=self.ff_size,
49 | num_layers=self.num_layers,
50 | num_heads=self.num_heads,
51 | dropout=self.dropout,
52 | activation=self.activation,
53 | mode="dual_interaction")
54 |
55 | self.net_individual = in2INDenoiser(self.nfeats,
56 | latent_dim=self.latent_dim,
57 | ff_size=self.ff_size,
58 | num_layers=self.num_layers,
59 | num_heads=self.num_heads,
60 | dropout=self.dropout,
61 | activation=self.activation,
62 | mode="dual_individual")
63 | elif self.mode == "interaction":
64 | self.net_interaction = in2INDenoiser(self.nfeats,
65 | latent_dim=self.latent_dim,
66 | ff_size=self.ff_size,
67 | num_layers=self.num_layers,
68 | num_heads=self.num_heads,
69 | dropout=self.dropout,
70 | activation=self.activation,
71 | mode="interaction")
72 | elif self.mode == "individual":
73 | self.net_individual = in2INDenoiser(self.nfeats,
74 | latent_dim=self.latent_dim,
75 | ff_size=self.ff_size,
76 | num_layers=self.num_layers,
77 | num_heads=self.num_heads,
78 | dropout=self.dropout,
79 | activation=self.activation,
80 | mode="individual")
81 |
82 |
83 | self.diffusion_steps = self.diffusion_steps
84 | self.betas = get_named_beta_schedule(self.beta_scheduler, self.diffusion_steps)
85 | timestep_respacing = [self.diffusion_steps]
86 |
87 |
88 | self.diffusion = MotionDiffusion(
89 | use_timesteps=space_timesteps(self.diffusion_steps, timestep_respacing),
90 | betas=self.betas,
91 | motion_rep=self.motion_rep,
92 | model_mean_type=ModelMeanType.START_X,
93 | model_var_type=ModelVarType.FIXED_SMALL,
94 | loss_type=LossType.MSE,
95 | rescale_timesteps = False,
96 | mode=self.mode
97 | )
98 |
99 | self.sampler = create_named_schedule_sampler(self.sampler, self.diffusion)
100 |
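    | # Classifier-free guidance dropout: with probability `cond_mask_prob` a sample's
    | # conditioning vector is replaced by zeros (the null condition) so the denoiser
    | # also learns the unconditional distribution. Returns the masked condition and a
    | # per-sample indicator (1 = condition kept, 0 = dropped).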
101 | def mask_cond(self, cond, cond_mask_prob = 0.1, force_mask=False):
102 | bs = cond.shape[0]
103 | if force_mask:
104 | return torch.zeros_like(cond)
105 | elif cond_mask_prob > 0.:
106 | mask = torch.bernoulli(torch.ones(bs, device=cond.device) * cond_mask_prob).view([bs]+[1]*len(cond.shape[1:])) # 1-> use null_cond, 0-> use real cond
107 | return cond * (1. - mask), (1. - mask)
108 | else:
109 | return cond, None
110 |
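    | # Builds a (B, T, 2) mask with 1 for valid frames (j < length[i]) and 0 for
    | # padding, duplicated for both persons. An equivalent vectorised form would be,
    | # roughly (untested sketch, assuming `length` lives on the same device as the arange):
    | #   (torch.arange(T)[None, :, None] < length[:, None, None]).float().expand(B, T, 2)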
111 | def generate_src_mask(self, T, length):
112 | B = length.shape[0]
113 | src_mask = torch.ones(B, T, 2)
114 | for p in range(2):
115 | for i in range(B):
116 | for j in range(length[i], T):
117 | src_mask[i, j, p] = 0
118 | return src_mask
119 |
120 | def compute_loss(self, batch):
121 |
122 | if self.mode == "interaction":
123 | cond_interaction = batch["cond_interaction"]
124 | cond_interaction_individual1 = batch["cond_interaction_individual1"]
125 | cond_interaction_individual2 = batch["cond_interaction_individual2"]
126 | cond = torch.cat([cond_interaction, cond_interaction_individual1, cond_interaction_individual2], dim=1)
127 | elif self.mode == "individual":
128 | cond_individual_individual1 = batch["cond_individual_individual1"]
129 | cond = torch.cat([cond_individual_individual1], dim=1)
130 |
131 | x_start = batch["motions"]
132 | B,T = batch["motions"].shape[:2]
133 |
134 | if cond is not None:
135 | cond, cond_mask = self.mask_cond(cond, 0.1)
136 |
137 | seq_mask = self.generate_src_mask(batch["motions"].shape[1], batch["motion_lens"]).to(x_start.device)
138 |
139 | t, _ = self.sampler.sample(B, x_start.device)
140 |
141 | if self.mode == "interaction":
142 | model = self.net_interaction
143 | elif self.mode == "individual":
144 | model = self.net_individual
145 |
146 | output = self.diffusion.training_losses(
147 | model=model,
148 | x_start=x_start,
149 | t=t,
150 | mask=seq_mask,
151 | t_bar=self.cfg.T_BAR,
152 | cond_mask=cond_mask,
153 | model_kwargs={"mask":seq_mask,
154 | "cond":cond,
155 | },
156 | )
157 | return output
158 |
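    | # Sampling entry point: rebuilds the diffusion with the respaced schedule given by
    | # `sampling_strategy` (e.g. "ddim50" -> 50 DDIM steps) and wraps the denoiser(s)
    | # in the classifier-free sampler that matches the current mode before running
    | # ddim_sample_loop.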
159 | def forward(self, batch):
160 |
161 | if self.mode == "dual":
162 | cond_interaction = batch["cond_interaction"]
163 | cond_interaction_individual1 = batch["cond_interaction_individual1"]
164 | cond_interaction_individual2 = batch["cond_interaction_individual2"]
165 | cond_individual_individual1 = batch["cond_individual_individual1"]
166 | cond_individual_individual2 = batch["cond_individual_individual2"]
167 | cond = torch.cat([cond_interaction, cond_interaction_individual1, cond_interaction_individual2, cond_individual_individual1, cond_individual_individual2], dim=1)
168 | elif self.mode == "interaction":
169 | cond_interaction = batch["cond_interaction"]
170 | cond_interaction_individual1 = batch["cond_interaction_individual1"]
171 | cond_interaction_individual2 = batch["cond_interaction_individual2"]
172 | cond = torch.cat([cond_interaction, cond_interaction_individual1, cond_interaction_individual2], dim=1)
173 | elif self.mode == "individual":
174 | cond_individual_individual1 = batch["cond_individual_individual1"]
175 | cond = torch.cat([cond_individual_individual1], dim=1)
176 |
177 | B = cond.shape[0]
178 | T = batch["motion_lens"][0]
179 |
180 | timestep_respacing = self.sampling_strategy
181 | self.diffusion_test = MotionDiffusion(
182 | use_timesteps=space_timesteps(self.diffusion_steps, timestep_respacing),
183 | betas=self.betas,
184 | motion_rep=self.motion_rep,
185 | model_mean_type=ModelMeanType.START_X,
186 | model_var_type=ModelVarType.FIXED_SMALL,
187 | loss_type=LossType.MSE,
188 | rescale_timesteps = False,
189 | mode = self.mode
190 | )
191 |
192 | if self.mode == "dual":
193 | self.cfg_model = ClassifierFreeSampleDualMDM(self.net_individual, self.net_interaction, self.cfg_weight_individual, self.cfg_weight_interaction, self.cfg_composition_weight_func, self.cfg_composition_weight_value)
194 | output = self.diffusion_test.ddim_sample_loop(
195 | self.cfg_model,
196 | (B, T, self.nfeats*2),
197 | clip_denoised=False,
198 | progress=True,
199 | model_kwargs={
200 | "mask":None,
201 | "cond":cond,
202 | },
203 | x_start=None)
204 | elif self.mode == "interaction":
205 | self.cfg_model = ClassifierFreeSampleModelMultiple(self.net_interaction, self.cfg_weight, self.cfg_weight_interaction, self.cfg_weight_individual)
206 | output = self.diffusion_test.ddim_sample_loop(
207 | self.cfg_model,
208 | (B, T, self.nfeats*2),
209 | clip_denoised=False,
210 | progress=True,
211 | model_kwargs={
212 | "mask":None,
213 | "cond":cond,
214 | },
215 | x_start=None)
216 | elif self.mode == "individual":
217 | self.cfg_model = ClassifierFreeSampleModel(self.net_individual, self.cfg_weight)
218 | output = self.diffusion_test.ddim_sample_loop(
219 | self.cfg_model,
220 | (B, T, self.nfeats),
221 | clip_denoised=False,
222 | progress=True,
223 | model_kwargs={
224 | "mask":None,
225 | "cond":cond,
226 | },
227 | x_start=None)
228 |
229 |
230 | return {"output":output}
231 |
232 | class in2INDenoiser(nn.Module):
233 | def __init__(self,
234 | input_feats,
235 | mode,
236 | latent_dim=512,
237 | num_frames=240,
238 | ff_size=1024,
239 | num_layers=8,
240 | num_heads=8,
241 | dropout=0.1,
242 | activation="gelu",
243 | **kargs):
244 | super().__init__()
245 |
246 | self.num_frames = num_frames
247 | self.latent_dim = latent_dim
248 | self.ff_size = ff_size
249 | self.num_layers = num_layers
250 | self.num_heads = num_heads
251 | self.dropout = dropout
252 | self.activation = activation
253 | self.input_feats = input_feats
254 | self.time_embed_dim = latent_dim
255 | self.mode = mode
256 | self.text_emb_dim = 768
257 |
258 | self.sequence_pos_encoder = PositionalEncoding(self.latent_dim, dropout=0)
259 | self.embed_timestep = TimestepEmbedder(self.latent_dim, self.sequence_pos_encoder)
260 |
261 | # Input Embedding
262 | self.motion_embed = nn.Linear(self.input_feats, self.latent_dim)
263 | self.text_embed = nn.Linear(self.text_emb_dim, self.latent_dim)
264 |
265 | self.blocks = nn.ModuleList()
266 |
267 | for i in range(num_layers):
268 | self.blocks.append(TransformerBlockDoubleCond(num_heads=num_heads,latent_dim=latent_dim, dropout=dropout, ff_size=ff_size, mode=self.mode))
269 |
270 | # Output Module
271 | self.out = zero_module(FinalLayer(self.latent_dim, self.input_feats))
272 |
273 |
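    | # `cond` is a concatenation of 768-dim text embeddings along dim 1, sliced by mode:
    | #   interaction / dual_interaction: [0:768] interaction text, [768:1536] and
    | #                                   [1536:2304] the two individual texts
    | #   dual_individual:                [2304:3072] and [3072:3840] individual texts
    | #   individual:                     [0:768] the single individual text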
274 | def forward(self, x, timesteps, mask=None, cond=None):
275 | """
276 | x: (B, T, D) in "individual" mode, (B, T, 2*D) otherwise (both persons concatenated on the feature axis)
277 | """
278 | B, T = x.shape[0], x.shape[1]
279 | x_a = x[...,:self.input_feats]
280 |
281 | if self.mode != "individual":
282 | x_b = x[...,self.input_feats:]
283 |
284 | if mask is not None:
285 | mask = mask[...,0]
286 |
287 | if self.mode == "dual_interaction" or self.mode == "interaction":
288 | emb = self.embed_timestep(timesteps) + self.text_embed(cond[:,:768])
289 | emb_individual1 = self.embed_timestep(timesteps) + self.text_embed(cond[:,768:768*2])
290 | emb_individual2 = self.embed_timestep(timesteps) + self.text_embed(cond[:,768*2:768*3])
291 | elif self.mode == "dual_individual":
292 | emb_individual1 = self.embed_timestep(timesteps) + self.text_embed(cond[:,768*3:768*4])
293 | emb_individual2 = self.embed_timestep(timesteps) + self.text_embed(cond[:,768*4:])
294 | elif self.mode == "individual":
295 | emb_individual1 = self.embed_timestep(timesteps) + self.text_embed(cond[:,:768])
296 | else:
297 | raise ValueError("Mode not recognized")
298 |
299 | a_emb = self.motion_embed(x_a)
300 | h_a_prev = self.sequence_pos_encoder(a_emb)
301 |
302 | if self.mode != "individual":
303 | b_emb = self.motion_embed(x_b)
304 | h_b_prev = self.sequence_pos_encoder(b_emb)
305 |
306 | if mask is None:
307 | mask = torch.ones(B, T).to(x_a.device)
308 | key_padding_mask = ~(mask > 0.5)
309 |
310 | for i,block in enumerate(self.blocks):
311 | if self.mode == "interaction" or self.mode == "dual_interaction":
312 | h_a = block(h_a_prev, h_b_prev, emb_individual1, emb, key_padding_mask)
313 | h_b = block(h_b_prev, h_a_prev, emb_individual2, emb, key_padding_mask)
314 | elif self.mode == "dual_individual":
315 | h_a = block(h_a_prev, None, emb_individual1, None, key_padding_mask)
316 | h_b = block(h_b_prev, None, emb_individual2, None, key_padding_mask)
317 | elif self.mode == "individual":
318 | h_a = block(h_a_prev, None, emb_individual1, None, key_padding_mask)
319 | else:
320 | raise ValueError("Mode not recognized")
321 |
322 | h_a_prev = h_a
323 |
324 | if self.mode == "dual_interaction" or self.mode == "interaction":
325 | h_b_prev = h_b
326 |
327 |
328 | output_a = self.out(h_a)
329 |
330 | if self.mode == "individual":
331 | output = torch.cat([output_a], dim=-1)
332 | else:
333 | output_b = self.out(h_b)
334 | output = torch.cat([output_a, output_b], dim=-1)
335 |
336 | return output
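    | # Rough end-to-end sketch (hypothetical helper names; the config is assumed to be a
    | # dict-like object exposing the upper-case keys read in in2INDiffusion.__init__,
    | # e.g. loaded from configs/models/in2IN.yaml):
    | #   cfg   = load_model_config("configs/models/in2IN.yaml")   # hypothetical loader
    | #   model = in2INDiffusion(cfg, mode="interaction", sampling_strategy="ddim50")
    | #   out   = model({"cond_interaction": c, "cond_interaction_individual1": c1,
    | #                  "cond_interaction_individual2": c2, "motion_lens": lens})["output"]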
--------------------------------------------------------------------------------