├── common
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-38.pyc
│   │   └── quaternion.cpython-38.pyc
│   └── quaternion.py
├── utils
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── utils.cpython-38.pyc
│   │   ├── __init__.cpython-38.pyc
│   │   ├── metrics.cpython-38.pyc
│   │   ├── plot_script.cpython-38.pyc
│   │   ├── preprocess.cpython-38.pyc
│   │   ├── quaternion.cpython-38.pyc
│   │   └── rotation_conversions.cpython-38.pyc
│   ├── preprocess.py
│   ├── paramUtil.py
│   ├── plot_script.py
│   ├── metrics.py
│   ├── utils.py
│   ├── quaternion.py
│   └── rotation_conversions.py
├── tools
│   ├── __init__.py
│   ├── train.py
│   └── eval.py
├── assets
│   ├── model.png
│   ├── teaser.png
│   └── demo_teaser.mp4
├── data
│   ├── global_std.npy
│   └── global_mean.npy
├── models
│   ├── __pycache__
│   │   ├── nets.cpython-38.pyc
│   │   ├── utils.cpython-38.pyc
│   │   ├── blocks.cpython-38.pyc
│   │   ├── layers.cpython-38.pyc
│   │   ├── losses.cpython-38.pyc
│   │   ├── __init__.cpython-38.pyc
│   │   ├── intergen.cpython-38.pyc
│   │   ├── cfg_sampler.cpython-38.pyc
│   │   └── gaussian_diffusion.cpython-38.pyc
│   ├── __init__.py
│   ├── cfg_sampler.py
│   ├── blocks.py
│   ├── utils.py
│   ├── layers.py
│   ├── intergen.py
│   ├── nets.py
│   └── losses.py
├── configs
│   ├── __pycache__
│   │   └── __init__.cpython-38.pyc
│   ├── eval_model.yaml
│   ├── infer.yaml
│   ├── train_single.yaml
│   ├── train_inter.yaml
│   ├── model_single.yaml
│   ├── model_inter.yaml
│   ├── datasets_inter.yaml
│   ├── datasets_single.yaml
│   └── __init__.py
├── datasets
│   ├── __pycache__
│   │   ├── __init__.cpython-38.pyc
│   │   ├── evaluator.cpython-38.pyc
│   │   ├── interhuman.cpython-38.pyc
│   │   └── evaluator_models.cpython-38.pyc
│   ├── __init__.py
│   ├── evaluator_models.py
│   ├── dataloader.py
│   └── evaluator.py
├── test.sh
├── train_inter.sh
├── train_single.sh
├── prepare
│   └── download_evaluation_model.sh
├── requirements.txt
└── README.md
/common/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tools/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/assets/model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VankouF/FreeMotion-Codes/HEAD/assets/model.png
--------------------------------------------------------------------------------
/assets/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VankouF/FreeMotion-Codes/HEAD/assets/teaser.png
--------------------------------------------------------------------------------
/data/global_std.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VankouF/FreeMotion-Codes/HEAD/data/global_std.npy
--------------------------------------------------------------------------------
/data/global_mean.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VankouF/FreeMotion-Codes/HEAD/data/global_mean.npy
--------------------------------------------------------------------------------
/assets/demo_teaser.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VankouF/FreeMotion-Codes/HEAD/assets/demo_teaser.mp4
--------------------------------------------------------------------------------
/models/__pycache__/nets.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VankouF/FreeMotion-Codes/HEAD/models/__pycache__/nets.cpython-38.pyc
--------------------------------------------------------------------------------
/models/__pycache__/utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VankouF/FreeMotion-Codes/HEAD/models/__pycache__/utils.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VankouF/FreeMotion-Codes/HEAD/utils/__pycache__/utils.cpython-38.pyc
--------------------------------------------------------------------------------
/models/__pycache__/blocks.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VankouF/FreeMotion-Codes/HEAD/models/__pycache__/blocks.cpython-38.pyc
--------------------------------------------------------------------------------
/models/__pycache__/layers.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VankouF/FreeMotion-Codes/HEAD/models/__pycache__/layers.cpython-38.pyc
--------------------------------------------------------------------------------
/models/__pycache__/losses.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VankouF/FreeMotion-Codes/HEAD/models/__pycache__/losses.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VankouF/FreeMotion-Codes/HEAD/utils/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/metrics.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VankouF/FreeMotion-Codes/HEAD/utils/__pycache__/metrics.cpython-38.pyc
--------------------------------------------------------------------------------
/common/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VankouF/FreeMotion-Codes/HEAD/common/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/common/__pycache__/quaternion.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VankouF/FreeMotion-Codes/HEAD/common/__pycache__/quaternion.cpython-38.pyc
--------------------------------------------------------------------------------
/configs/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VankouF/FreeMotion-Codes/HEAD/configs/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VankouF/FreeMotion-Codes/HEAD/datasets/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/models/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VankouF/FreeMotion-Codes/HEAD/models/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/models/__pycache__/intergen.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VankouF/FreeMotion-Codes/HEAD/models/__pycache__/intergen.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/plot_script.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VankouF/FreeMotion-Codes/HEAD/utils/__pycache__/plot_script.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/preprocess.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VankouF/FreeMotion-Codes/HEAD/utils/__pycache__/preprocess.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/quaternion.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VankouF/FreeMotion-Codes/HEAD/utils/__pycache__/quaternion.cpython-38.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/evaluator.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VankouF/FreeMotion-Codes/HEAD/datasets/__pycache__/evaluator.cpython-38.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/interhuman.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VankouF/FreeMotion-Codes/HEAD/datasets/__pycache__/interhuman.cpython-38.pyc
--------------------------------------------------------------------------------
/models/__pycache__/cfg_sampler.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VankouF/FreeMotion-Codes/HEAD/models/__pycache__/cfg_sampler.cpython-38.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/evaluator_models.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VankouF/FreeMotion-Codes/HEAD/datasets/__pycache__/evaluator_models.cpython-38.pyc
--------------------------------------------------------------------------------
/models/__pycache__/gaussian_diffusion.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VankouF/FreeMotion-Codes/HEAD/models/__pycache__/gaussian_diffusion.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/rotation_conversions.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VankouF/FreeMotion-Codes/HEAD/utils/__pycache__/rotation_conversions.cpython-38.pyc
--------------------------------------------------------------------------------
/test.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 python tools/eval.py --model_config configs/model_inter.yaml --dataset_config configs/datasets_inter.yaml --evalmodel_config configs/eval_model.yaml
2 |
3 |
--------------------------------------------------------------------------------
/train_inter.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 python tools/train.py --model_config configs/model_inter.yaml --dataset_config configs/datasets_inter.yaml --train_config configs/train_inter.yaml
2 |
3 |
--------------------------------------------------------------------------------
/train_single.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 python tools/train.py --model_config configs/model_single.yaml --dataset_config configs/datasets_single.yaml --train_config configs/train_single.yaml
2 |
3 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .gaussian_diffusion import GaussianDiffusion
2 | from .nets import *
3 | from .intergen import *
4 | from .blocks import *
5 | from .cfg_sampler import *
6 | from .utils import *
--------------------------------------------------------------------------------
/configs/eval_model.yaml:
--------------------------------------------------------------------------------
1 | NAME: InterCLIP
2 | NUM_LAYERS: 8
3 | NUM_HEADS: 8
4 | DROPOUT: 0.1
5 | INPUT_DIM: 258
6 | LATENT_DIM: 1024
7 | FF_SIZE: 2048
8 | ACTIVATION: gelu
9 |
10 | MOTION_REP: global
11 | FINETUNE: False
12 |
--------------------------------------------------------------------------------
/prepare/download_evaluation_model.sh:
--------------------------------------------------------------------------------
1 | mkdir -p ./eval_model/
2 | cd ./eval_model/
3 | echo "The pretrained evaluation model will be stored in the 'eval_model' folder\n"
4 | # InterHuman
5 | echo "Downloading the evaluation model..."
6 | gdown https://drive.google.com/uc?id=1bJv5lTP7otJleaBYZ2byjru_k_wCsGvH
7 | echo "Evaluation model downloaded"
--------------------------------------------------------------------------------
/configs/infer.yaml:
--------------------------------------------------------------------------------
1 | GENERAL:
2 | EXP_NAME: IG-S-8
3 | CHECKPOINT: ./checkpoints
4 | LOG_DIR: ./log
5 |
6 | TRAIN:
7 | LR: 1e-4
8 | WEIGHT_DECAY: 0.00002
9 | BATCH_SIZE: 1
10 | EPOCH: 2000
11 | STEP: 1000000
12 | LOG_STEPS: 10
13 | SAVE_STEPS: 20000
14 | SAVE_EPOCH: 100
15 | RESUME: #checkpoints/IG-S/8/model/epoch=99-step=17600.ckpt
16 | NUM_WORKERS: 2
17 | MODE: finetune
18 | LAST_EPOCH: 0
19 | LAST_ITER: 0
20 |
--------------------------------------------------------------------------------
/configs/train_single.yaml:
--------------------------------------------------------------------------------
1 | GENERAL:
2 | EXP_NAME: train_single
3 | CHECKPOINT: ./checkpoints
4 | LOG_DIR: ./log
5 |
6 | TRAIN:
7 | LR: 1e-4
8 | WEIGHT_DECAY: 0.00002
9 | BATCH_SIZE: 80
10 | EPOCH: 2500
11 | STEP: 1000000
12 | LOG_STEPS: 10
13 | SAVE_STEPS: 20000
14 | SAVE_EPOCH: 50
15 | SAVE_TOPK: -1
16 | RESUME:
17 | NUM_WORKERS: 32
18 | MODE: finetune
19 | LAST_EPOCH: 0
20 | LAST_ITER: 0
21 | FROM_PRETRAIN:
22 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | --extra-index-url https://download.pytorch.org/whl/cu117
2 | numpy~=1.23.3
3 | tqdm~=4.65.0
4 | lightning==1.9.1
5 | scipy~=1.4.1
6 | matplotlib==3.2.0
7 | pillow~=7.2.0
8 | yacs~=0.1.8
9 | mmcv~=1.6.2
10 | opencv-python~=4.5.3.56
11 | tabulate~=0.8.9
12 | termcolor~=1.1.0
13 | smplx~=0.1.28
14 | torch==1.13.1+cu117
15 | torchvision==0.14.1+cu117
16 | torchaudio==0.13.1
17 | tensorboard==2.14.0
18 | git+https://github.com/openai/CLIP.git
--------------------------------------------------------------------------------
/configs/train_inter.yaml:
--------------------------------------------------------------------------------
1 | GENERAL:
2 | EXP_NAME: train_inter
3 | CHECKPOINT: ./checkpoints
4 | LOG_DIR: ./log
5 |
6 | TRAIN:
7 | LR: 1e-4
8 | WEIGHT_DECAY: 0.00002
9 | BATCH_SIZE: 30
10 | EPOCH: 1000
11 | STEP: 1000000
12 | LOG_STEPS: 10
13 | SAVE_STEPS: 20000
14 | SAVE_EPOCH: 50
15 | SAVE_TOPK: -1
16 | RESUME:
17 | NUM_WORKERS: 32
18 | MODE: finetune
19 | LAST_EPOCH: 0
20 | LAST_ITER: 0
21 | FROM_PRETRAIN: ./checkpoints/train_single/epoch999.ckpt
22 |
--------------------------------------------------------------------------------
/configs/model_single.yaml:
--------------------------------------------------------------------------------
1 | NAME: InterGenSpatialControl
2 | ARCHI: single # single # single, multi
3 | NUM_LAYERS: 8
4 | NUM_HEADS: 8
5 | DROPOUT: 0.1
6 | INPUT_DIM: 262
7 | LATENT_DIM: 1024
8 | FF_SIZE: 2048
9 | ACTIVATION: gelu
10 | CHECKPOINT:
11 |
12 | DIFFUSION_STEPS: 1000
13 | BETA_SCHEDULER: cosine
14 | SAMPLER: uniform
15 |
16 | MOTION_REP: global
17 | FINETUNE: False
18 |
19 | TEXT_ENCODER: clip
20 | T_BAR: 700
21 |
22 | CONTROL: text
23 | STRATEGY: ddim50
24 | CFG_WEIGHT: 3.5
25 |
26 | GENERATE_NUM: 2
27 |
28 |
--------------------------------------------------------------------------------
/configs/model_inter.yaml:
--------------------------------------------------------------------------------
1 | NAME: InterGenSpatialControl
2 | ARCHI: multi # single # single, multi
3 | NUM_LAYERS: 8
4 | NUM_HEADS: 8
5 | DROPOUT: 0.1
6 | INPUT_DIM: 262
7 | LATENT_DIM: 1024
8 | FF_SIZE: 2048
9 | ACTIVATION: gelu
10 | CHECKPOINT: ./checkpoints/epoch999.ckpt
11 |
12 | DIFFUSION_STEPS: 1000
13 | BETA_SCHEDULER: cosine
14 | SAMPLER: uniform
15 |
16 | MOTION_REP: global
17 | FINETUNE: False
18 |
19 | TEXT_ENCODER: clip
20 | T_BAR: 700
21 |
22 | CONTROL: text
23 | STRATEGY: ddim50
24 | CFG_WEIGHT: 3.5
25 |
26 | GENERATE_NUM: 2
27 |
28 |
--------------------------------------------------------------------------------
/configs/datasets_inter.yaml:
--------------------------------------------------------------------------------
1 | interhuman:
2 | NAME: interhuman
3 | DATA_ROOT: ./data
4 | MOTION_REP: global
5 | MODE: train
6 | CACHE: True
7 | SINGLE_TEXT_DESCRIPTION: True
8 |
9 | interhuman_val:
10 | NAME: interhuman
11 | DATA_ROOT: ./data
12 | MOTION_REP: global
13 | MODE: val
14 | CACHE: True
15 | SINGLE_TEXT_DESCRIPTION: True
16 |
17 | interhuman_test:
18 | NAME: interhuman
19 | DATA_ROOT: ./data
20 | MOTION_REP: global
21 | MODE: test
22 | CACHE: True
23 | SINGLE_TEXT_DESCRIPTION: True
24 |
--------------------------------------------------------------------------------
/configs/datasets_single.yaml:
--------------------------------------------------------------------------------
1 | interhuman:
2 | NAME: singlehuman
3 | DATA_ROOT: ./data
4 | MOTION_REP: global
5 | MODE: train
6 | CACHE: True
7 | SINGLE_TEXT_DESCRIPTION: True
8 |
9 | interhuman_val:
10 | NAME: singlehuman
11 | DATA_ROOT: ./data
12 | MOTION_REP: global
13 | MODE: val
14 | CACHE: True
15 | SINGLE_TEXT_DESCRIPTION: True
16 |
17 | interhuman_test:
18 | NAME: singlehuman
19 | DATA_ROOT: ./data
20 | MOTION_REP: global
21 | MODE: test
22 | CACHE: True
23 | SINGLE_TEXT_DESCRIPTION: True
24 |
--------------------------------------------------------------------------------
/utils/preprocess.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from utils.utils import *
3 |
4 | FPS = 30
5 |
6 | def load_motion(file_path, min_length, swap=False):
7 |
8 | try:
9 | motion = np.load(file_path).astype(np.float32)
10 |     except Exception:
11 |         print("error: ", file_path)
12 |         return None, None
13 | 
14 |     motion1 = motion[:, :22 * 3]  # 22*3: 3D positions for the 22 joints
15 |     motion2 = motion[:, 62 * 3:62 * 3 + 21 * 6]  # 21*6: 6D rotation representation for 21 joints
16 | motion = np.concatenate([motion1, motion2], axis=1)
17 |
18 | if motion.shape[0] < min_length:
19 | return None, None
20 | if swap:
21 | motion_swap = swap_left_right(motion, 22)
22 | else:
23 | motion_swap = None
24 | return motion, motion_swap
--------------------------------------------------------------------------------
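A minimal usage sketch for `load_motion` (not part of the repo). The `.npy` path is a placeholder for one processed clip under `./data/`, and the shape comments follow the slicing above.

```python
from utils.preprocess import load_motion

# Hypothetical path; replace with a real clip from ./data/motions_processed/.
motion, motion_swap = load_motion("path/to/clip.npy", min_length=30, swap=True)
if motion is not None:
    # per frame: 22*3 joint positions followed by 21*6 6D rotations
    print(motion.shape)       # (T, 192)
    print(motion_swap.shape)  # left/right-mirrored copy, same shape
```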
/models/cfg_sampler.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | class ClassifierFreeSampleModel(nn.Module):
4 |
5 | def __init__(self, model, cfg_scale):
6 | super().__init__()
7 | self.model = model # model is the actual model to run
8 | self.s = cfg_scale
9 |
10 | def forward(self, x, timesteps, cond=None, mask=None, motion_guidance = None, spatial_condition=None, **kwargs):
11 | B, T, D = x.shape
12 |
13 | x_combined = torch.cat([x, x], dim=0)
14 | timesteps_combined = torch.cat([timesteps, timesteps], dim=0)
15 | if cond is not None:
16 | cond = torch.cat([cond, torch.zeros_like(cond)], dim=0)
17 | if mask is not None:
18 | mask = torch.cat([mask, mask], dim=0)
19 | if motion_guidance is not None:
20 | motion_guidance = torch.cat([motion_guidance, motion_guidance], dim=0)
21 | if spatial_condition is not None:
22 | spatial_condition = torch.cat([spatial_condition, spatial_condition], dim=0)
23 | out = self.model(x_combined, timesteps_combined, cond=cond, mask=mask, motion_guidance=motion_guidance, spatial_condition=spatial_condition, **kwargs)
24 |
25 | out_cond = out[:B]
26 | out_uncond = out[B:]
27 |
28 | cfg_out = self.s * out_cond + (1-self.s) *out_uncond
29 | return cfg_out
30 |
--------------------------------------------------------------------------------
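A usage sketch (not from the repo) of the classifier-free guidance wrapper. `DummyDenoiser` is a hypothetical stand-in with the same forward signature as the real diffusion network; `cfg_scale=3.5` mirrors `CFG_WEIGHT` and 262 mirrors `INPUT_DIM` in the model configs.

```python
import torch
import torch.nn as nn
from models.cfg_sampler import ClassifierFreeSampleModel

class DummyDenoiser(nn.Module):
    """Stand-in denoiser; only the forward signature matters for this sketch."""
    def __init__(self, dim=262):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x, timesteps, cond=None, mask=None,
                motion_guidance=None, spatial_condition=None, **kwargs):
        return self.proj(x)

sampler = ClassifierFreeSampleModel(DummyDenoiser(), cfg_scale=3.5)
x = torch.randn(2, 16, 262)        # B, T, D noisy motion
t = torch.randint(0, 1000, (2,))   # diffusion timesteps
cond = torch.randn(2, 768)         # text embedding; zeroed internally for the uncond pass
out = sampler(x, t, cond=cond)     # s * cond_pred + (1 - s) * uncond_pred
print(out.shape)                   # torch.Size([2, 16, 262])
```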
/configs/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Dict
3 | from yacs.config import CfgNode as CN
4 |
5 |
6 | def to_lower(x: Dict) -> Dict:
7 | """
8 | Convert all dictionary keys to lowercase
9 | Args:
10 | x (dict): Input dictionary
11 | Returns:
12 | dict: Output dictionary with all keys converted to lowercase
13 | """
14 | return {k.lower(): v for k, v in x.items()}
15 |
16 | _C = CN(new_allowed=True)
17 |
18 | def default_config() -> CN:
19 | """
20 | Get a yacs CfgNode object with the default config values.
21 | """
22 | # Return a clone so that the defaults will not be altered
23 | # This is for the "local variable" use pattern
24 | return _C.clone()
25 |
26 | def get_config(config_file: str, merge: bool = True) -> CN:
27 | """
28 | Read a config file and optionally merge it with the default config file.
29 | Args:
30 | config_file (str): Path to config file.
31 | merge (bool): Whether to merge with the default config or not.
32 | Returns:
33 | CfgNode: Config as a yacs CfgNode object.
34 | """
35 | if merge:
36 | cfg = default_config()
37 | else:
38 | cfg = CN(new_allowed=True)
39 | cfg.merge_from_file(config_file)
40 | cfg.freeze()
41 | return cfg
42 |
43 | def dataset_config() -> CN:
44 | """
45 | Get dataset config file
46 | Returns:
47 | CfgNode: Dataset config as a yacs CfgNode object.
48 | """
49 | cfg = CN(new_allowed=True)
50 | config_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'datasets.yaml')
51 | cfg.merge_from_file(config_file)
52 | cfg.freeze()
53 | return cfg
54 |
55 |
--------------------------------------------------------------------------------
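A sketch (not from the repo) of how these helpers load the YAML files shipped in `configs/`, assuming the working directory is the repo root. The printed values correspond to keys shown elsewhere in this dump.

```python
from configs import get_config

model_cfg = get_config("configs/model_inter.yaml")
train_cfg = get_config("configs/train_inter.yaml")
data_cfg = get_config("configs/datasets_inter.yaml")

print(model_cfg.NAME, model_cfg.LATENT_DIM)  # InterGenSpatialControl 1024
print(train_cfg.TRAIN.BATCH_SIZE)            # 30
print(data_cfg.interhuman.MODE)              # train
```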
/utils/paramUtil.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | # Define the kinematic tree for the skeletal structure
4 | kit_kinematic_chain = [[0, 11, 12, 13, 14, 15], [0, 16, 17, 18, 19, 20], [0, 1, 2, 3, 4], [3, 5, 6, 7], [3, 8, 9, 10]]
5 |
6 | kit_raw_offsets = np.array(
7 | [
8 | [0, 0, 0],
9 | [0, 1, 0],
10 | [0, 1, 0],
11 | [0, 1, 0],
12 | [0, 1, 0],
13 | [1, 0, 0],
14 | [0, -1, 0],
15 | [0, -1, 0],
16 | [-1, 0, 0],
17 | [0, -1, 0],
18 | [0, -1, 0],
19 | [1, 0, 0],
20 | [0, -1, 0],
21 | [0, -1, 0],
22 | [0, 0, 1],
23 | [0, 0, 1],
24 | [-1, 0, 0],
25 | [0, -1, 0],
26 | [0, -1, 0],
27 | [0, 0, 1],
28 | [0, 0, 1]
29 | ]
30 | )
31 |
32 | t2m_raw_offsets = np.array([[0,0,0],
33 | [1,0,0],
34 | [-1,0,0],
35 | [0,1,0],
36 | [0,-1,0],
37 | [0,-1,0],
38 | [0,1,0],
39 | [0,-1,0],
40 | [0,-1,0],
41 | [0,1,0],
42 | [0,0,1],
43 | [0,0,1],
44 | [0,1,0],
45 | [1,0,0],
46 | [-1,0,0],
47 | [0,0,1],
48 | [0,-1,0],
49 | [0,-1,0],
50 | [0,-1,0],
51 | [0,-1,0],
52 | [0,-1,0],
53 | [0,-1,0]])
54 |
55 | t2m_kinematic_chain = [[0, 2, 5, 8, 11], [0, 1, 4, 7, 10], [0, 3, 6, 9, 12, 15], [9, 14, 17, 19, 21], [9, 13, 16, 18, 20]]
56 | t2m_left_hand_chain = [[20, 22, 23, 24], [20, 34, 35, 36], [20, 25, 26, 27], [20, 31, 32, 33], [20, 28, 29, 30]]
57 | t2m_right_hand_chain = [[21, 43, 44, 45], [21, 46, 47, 48], [21, 40, 41, 42], [21, 37, 38, 39], [21, 49, 50, 51]]
58 |
59 |
60 | kit_tgt_skel_id = '03950'
61 |
62 | t2m_tgt_skel_id = '000021'
63 |
64 |
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | import lightning.pytorch as pl
2 | import torch
3 | from .interhuman import (
4 | InterHumanDataset,
5 | SingleHumanDataset,
6 | InterHumanPipelineInferDataset,
7 | )
8 | from datasets.evaluator import (
9 | EvaluatorModelWrapper,
10 | EvaluationDataset,
11 | get_dataset_motion_loader,
12 | get_motion_loader)
13 |
14 |
15 | __all__ = [
16 | 'InterHumanDataset',
17 | 'InterHumanPipelineInferDataset',
18 | 'get_dataset_motion_loader', 'get_motion_loader'
19 | ]
20 |
21 | def build_loader(cfg, data_cfg):
22 | # setup data
23 | if data_cfg.NAME == "interhuman":
24 | train_dataset = InterHumanDataset(data_cfg)
25 | else:
26 | raise NotImplementedError
27 |
28 | loader = torch.utils.data.DataLoader(
29 | train_dataset,
30 | batch_size=cfg.BATCH_SIZE,
31 | num_workers=1,
32 | pin_memory=False,
33 | shuffle=True,
34 | drop_last=True,
35 | )
36 |
37 | return loader
38 |
39 | class DataModule(pl.LightningDataModule):
40 | def __init__(self, cfg, batch_size, num_workers):
41 | """
42 |         Initialize the LightningDataModule used for training
43 |         Args:
44 |             cfg (CfgNode): Dataset config as a yacs CfgNode containing the necessary dataset info.
45 |             batch_size (int) / num_workers (int): DataLoader settings.
46 | """
47 | super().__init__()
48 | self.cfg = cfg
49 | self.batch_size = batch_size
50 | self.num_workers = num_workers
51 |
52 | def setup(self, stage = None):
53 | """
54 | Create train and validation datasets
55 | """
56 |
57 | if self.cfg.NAME == "interhuman":
58 | self.train_dataset = InterHumanDataset(self.cfg)
59 | elif self.cfg.NAME == 'singlehuman':
60 | self.train_dataset = SingleHumanDataset(self.cfg)
61 | else:
62 | raise NotImplementedError
63 |
64 | def train_dataloader(self):
65 | """
66 | Return train dataloader
67 | """
68 | return torch.utils.data.DataLoader(
69 | self.train_dataset,
70 | batch_size=self.batch_size,
71 | num_workers=self.num_workers,
72 | pin_memory=False,
73 | shuffle=True,
74 | drop_last=True,
75 | )
76 |
--------------------------------------------------------------------------------
/models/blocks.py:
--------------------------------------------------------------------------------
1 | from .layers import *
2 |
3 |
4 | class TransformerBlock(nn.Module):
5 | def __init__(self,
6 | latent_dim=512,
7 | num_heads=8,
8 | ff_size=1024,
9 | dropout=0.,
10 | cond_abl=False,
11 | **kargs):
12 | super().__init__()
13 | self.latent_dim = latent_dim
14 | self.num_heads = num_heads
15 | self.dropout = dropout
16 | self.cond_abl = cond_abl
17 |
18 | self.sa_block = VanillaSelfAttention(latent_dim, num_heads, dropout)
19 | self.ca_block = VanillaCrossAttention(latent_dim, latent_dim, num_heads, dropout, latent_dim)
20 | self.ffn = FFN(latent_dim, ff_size, dropout, latent_dim)
21 |
22 | def forward(self, x, y, emb=None, key_padding_mask=None):
23 | h1 = self.sa_block(x, emb, key_padding_mask)
24 | h1 = h1 + x
25 | h2 = self.ca_block(h1, y, emb, key_padding_mask)
26 | h2 = h2 + h1
27 | out = self.ffn(h2, emb)
28 | out = out + h2
29 | return out
30 |
31 |
32 | class TransformerMotionGuidanceBlock(nn.Module):
33 | def __init__(self,
34 | latent_dim=512,
35 | num_heads=8,
36 | ff_size=1024,
37 | dropout=0.,
38 | cond_abl=False,
39 | **kargs):
40 | super().__init__()
41 | self.latent_dim = latent_dim
42 | self.num_heads = num_heads
43 | self.dropout = dropout
44 | self.cond_abl = cond_abl
45 |
46 | self.sa_block = VanillaSelfAttention(latent_dim, num_heads, dropout)
47 | self.condition_sa_block = VanillaSelfAttention(latent_dim, num_heads, dropout, latent_dim)
48 |
49 | self.ffn = FFN(latent_dim, ff_size, dropout, latent_dim)
50 |
51 | def forward(self, x, T=300, emb=None, key_padding_mask=None):
52 |
53 | x_a = x[:,:T,...]
54 | key_padding_mask_a = key_padding_mask[:,:T]
55 |
56 | # self_att first
57 | h1 = self.sa_block(x_a, emb, key_padding_mask_a)
58 | h1 = h1 + x_a
59 |
60 | h1 = torch.cat([h1, x[:,T:,...]], dim=1)
61 |
62 | # add motion guidance to att again
63 | h2 = self.condition_sa_block(h1, emb, key_padding_mask)
64 | h2 = h2 + h1
65 |
66 | out = self.ffn(h2, emb)
67 | out = out + h2
68 |
69 | return out
70 |
--------------------------------------------------------------------------------
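A shape-check sketch (not from the repo) for `TransformerBlock`: `x` self-attends, then cross-attends to `y`, and every sub-layer is modulated through AdaLN by the embedding `emb`. All shapes are illustrative.

```python
import torch
from models.blocks import TransformerBlock

block = TransformerBlock(latent_dim=1024, num_heads=8, ff_size=2048, dropout=0.1)
x = torch.randn(2, 32, 1024)   # B, T, D  motion tokens
y = torch.randn(2, 32, 1024)   # B, N, D  tokens attended to by the cross-attention
emb = torch.randn(2, 1024)     # B, D     conditioning embedding for AdaLN
out = block(x, y, emb=emb)
print(out.shape)               # torch.Size([2, 32, 1024])
```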
/models/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import numpy as np
4 |
5 |
6 | class CosineWarmupScheduler(torch.optim.lr_scheduler._LRScheduler):
7 | def __init__(self, optimizer, warmup, max_iters, verbose=False):
8 | self.warmup = warmup
9 | self.max_num_iters = max_iters
10 | super().__init__(optimizer, verbose=verbose)
11 |
12 | def get_lr(self):
13 | lr_factor = self.get_lr_factor(epoch=self.last_epoch)
14 | return [base_lr * lr_factor for base_lr in self.base_lrs]
15 |
16 | def get_lr_factor(self, epoch):
17 | lr_factor = 0.5 * (1 + np.cos(np.pi * epoch / self.max_num_iters))
18 | if epoch <= self.warmup:
19 | lr_factor *= (epoch+1) * 1.0 / self.warmup
20 | return lr_factor
21 |
22 |
23 |
24 | class PositionalEncoding(nn.Module):
25 | def __init__(self, d_model, dropout=0.0, max_len=5000):
26 | super(PositionalEncoding, self).__init__()
27 | self.dropout = nn.Dropout(p=dropout)
28 |
29 | pe = torch.zeros(max_len, d_model)
30 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
31 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
32 | pe[:, 0::2] = torch.sin(position * div_term)
33 | pe[:, 1::2] = torch.cos(position * div_term)
34 | # pe = pe.unsqueeze(0)#.transpose(0, 1)
35 |
36 | self.register_buffer('pe', pe)
37 |
38 | def forward(self, x):
39 | # not used in the final model
40 | x = x + self.pe[:x.shape[1], :].unsqueeze(0)
41 | return self.dropout(x)
42 |
43 |
44 | class TimestepEmbedder(nn.Module):
45 | def __init__(self, latent_dim, sequence_pos_encoder):
46 | super().__init__()
47 | self.latent_dim = latent_dim
48 | self.sequence_pos_encoder = sequence_pos_encoder
49 |
50 | time_embed_dim = self.latent_dim
51 | self.time_embed = nn.Sequential(
52 | nn.Linear(self.latent_dim, time_embed_dim),
53 | nn.SiLU(),
54 | nn.Linear(time_embed_dim, time_embed_dim),
55 | )
56 |
57 | def forward(self, timesteps):
58 | return self.time_embed(self.sequence_pos_encoder.pe[timesteps])
59 |
60 |
61 | class IdentityEmbedder(nn.Module):
62 | def __init__(self, latent_dim, sequence_pos_encoder):
63 | super().__init__()
64 | self.latent_dim = latent_dim
65 | self.sequence_pos_encoder = sequence_pos_encoder
66 |
67 | time_embed_dim = self.latent_dim
68 | self.time_embed = nn.Sequential(
69 | nn.Linear(self.latent_dim, time_embed_dim),
70 | nn.SiLU(),
71 | nn.Linear(time_embed_dim, time_embed_dim),
72 | )
73 |
74 | def forward(self, timesteps):
75 | return self.time_embed(self.sequence_pos_encoder.pe[timesteps]).unsqueeze(1)
76 |
77 |
78 | def set_requires_grad(nets, requires_grad=False):
79 | """Set requies_grad for all the networks.
80 |
81 | Args:
82 | nets (nn.Module | list[nn.Module]): A list of networks or a single
83 | network.
84 | requires_grad (bool): Whether the networks require gradients or not
85 | """
86 | if not isinstance(nets, list):
87 | nets = [nets]
88 | for net in nets:
89 | if net is not None:
90 | for param in net.parameters():
91 | param.requires_grad = requires_grad
92 |
93 |
94 | def zero_module(module):
95 | """
96 | Zero out the parameters of a module and return it.
97 | """
98 | for p in module.parameters():
99 | p.detach().zero_()
100 | return module
101 |
--------------------------------------------------------------------------------
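An illustrative sketch (not from the repo) of `CosineWarmupScheduler`: the learning-rate factor ramps up linearly for `warmup` epochs and then follows a half-cosine decay over `max_iters`, matching how `tools/train.py` constructs it (warmup=10, max_iters=EPOCH).

```python
import torch
from models.utils import CosineWarmupScheduler

opt = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=1e-4)
sched = CosineWarmupScheduler(opt, warmup=10, max_iters=1000)
for epoch in range(5):
    opt.step()                 # one epoch of training would happen here
    sched.step()
    print(epoch, sched.get_last_lr())
```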
/models/layers.py:
--------------------------------------------------------------------------------
1 | from .utils import *
2 |
3 | class AdaLN(nn.Module):
4 |
5 | def __init__(self, latent_dim, embed_dim=None):
6 | super().__init__()
7 | if embed_dim is None:
8 | embed_dim = latent_dim
9 | self.emb_layers = nn.Sequential(
10 | # nn.Linear(embed_dim, latent_dim, bias=True),
11 | nn.SiLU(),
12 | zero_module(nn.Linear(embed_dim, 2 * latent_dim, bias=True)),
13 | )
14 | self.norm = nn.LayerNorm(latent_dim, elementwise_affine=False, eps=1e-6)
15 |
16 | def forward(self, h, emb):
17 | """
18 | h: B, T, D
19 | emb: B, D
20 | """
21 |         # emb_out: B, 2D
22 |         emb_out = self.emb_layers(emb)
23 |         # scale: B, D / shift: B, D (broadcast to B, 1, D below)
24 | scale, shift = torch.chunk(emb_out, 2, dim=-1)
25 | h = self.norm(h) * (1 + scale[:, None]) + shift[:, None]
26 | return h
27 |
28 |
29 | class VanillaSelfAttention(nn.Module):
30 |
31 | def __init__(self, latent_dim, num_head, dropout, embed_dim=None):
32 | super().__init__()
33 | self.num_head = num_head
34 | self.norm = AdaLN(latent_dim, embed_dim)
35 | self.attention = nn.MultiheadAttention(latent_dim, num_head, dropout=dropout, batch_first=True,
36 | add_zero_attn=True)
37 |
38 | def forward(self, x, emb, key_padding_mask=None):
39 | """
40 | x: B, T, D
41 | """
42 | x_norm = self.norm(x, emb)
43 | y = self.attention(x_norm, x_norm, x_norm,
44 | attn_mask=None,
45 | key_padding_mask=key_padding_mask,
46 | need_weights=False)[0]
47 | return y
48 |
49 |
50 | class VanillaCrossAttention(nn.Module):
51 |
52 | def __init__(self, latent_dim, xf_latent_dim, num_head, dropout, embed_dim=None):
53 | super().__init__()
54 | self.num_head = num_head
55 | self.norm = AdaLN(latent_dim, embed_dim)
56 | self.xf_norm = AdaLN(xf_latent_dim, embed_dim)
57 | self.attention = nn.MultiheadAttention(latent_dim, num_head, kdim=xf_latent_dim, vdim=xf_latent_dim,
58 | dropout=dropout, batch_first=True, add_zero_attn=True)
59 |
60 | def forward(self, x, xf, emb, key_padding_mask=None):
61 | """
62 | x: B, T, D
63 | xf: B, N, L
64 | """
65 | x_norm = self.norm(x, emb)
66 | xf_norm = self.xf_norm(xf, emb)
67 | y = self.attention(x_norm, xf_norm, xf_norm,
68 | attn_mask=None,
69 | key_padding_mask=key_padding_mask,
70 | need_weights=False)[0]
71 | return y
72 |
73 |
74 | class FFN(nn.Module):
75 | def __init__(self, latent_dim, ffn_dim, dropout, embed_dim=None):
76 | super().__init__()
77 | self.norm = AdaLN(latent_dim, embed_dim)
78 | self.linear1 = nn.Linear(latent_dim, ffn_dim, bias=True)
79 | self.linear2 = zero_module(nn.Linear(ffn_dim, latent_dim, bias=True))
80 | self.activation = nn.GELU()
81 | self.dropout = nn.Dropout(dropout)
82 |
83 | def forward(self, x, emb=None):
84 | if emb is not None:
85 | x_norm = self.norm(x, emb)
86 | else:
87 | x_norm = x
88 | y = self.linear2(self.dropout(self.activation(self.linear1(x_norm))))
89 | return y
90 |
91 |
92 | class FinalLayer(nn.Module):
93 | def __init__(self, latent_dim, out_dim):
94 | super().__init__()
95 | self.linear = zero_module(nn.Linear(latent_dim, out_dim, bias=True))
96 |
97 | def forward(self, x):
98 | x = self.linear(x)
99 | return x
--------------------------------------------------------------------------------
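A small sketch (not from the repo) of the key property of `AdaLN`: the modulation branch is zero-initialized via `zero_module`, so at initialization the module reduces to its plain, affine-free LayerNorm.

```python
import torch
from models.layers import AdaLN

ada = AdaLN(latent_dim=1024, embed_dim=1024)
h = torch.randn(2, 32, 1024)   # B, T, D features
emb = torch.randn(2, 1024)     # B, D conditioning embedding
out = ada(h, emb)
print(out.shape)                          # torch.Size([2, 32, 1024])
print(torch.allclose(out, ada.norm(h)))   # True before any training (scale=shift=0)
```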
/README.md:
--------------------------------------------------------------------------------
1 |
2 | Official repo for FreeMotion
3 |
4 |
5 |
6 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 | ## Intro FreeMotion
31 |
32 | Text-to-motion synthesis is a crucial task in computer vision. Existing methods are limited in their universality: they are tailored to single-person or two-person scenarios and cannot be applied to generate motions for more individuals. To achieve number-free motion synthesis, this paper reconsiders motion generation and proposes to unify single- and multi-person motion through the conditional motion distribution. Furthermore, a generation module and an interaction module are designed for our FreeMotion framework to decouple the process of conditional motion generation and, in turn, support number-free motion synthesis. In addition, based on our framework, existing single-person spatial control methods can be seamlessly integrated, achieving precise control of multi-person motion. Extensive experiments demonstrate the superior performance of our method and its capability to infer single-person and multi-person motions simultaneously.
33 |
34 |
35 |
36 | ## ☑️ Todo List
37 |
38 | - [x] Release the FreeMotion training code.
39 | - [x] Release the FreeMotion evaluation code.
40 | - [x] Release the separate_annots dataset.
41 | - [ ] Release the inference code.
42 | - [ ] Release the FreeMotion checkpoints.
43 |
44 | ## Quick Start
45 |
46 |
48 |
49 | ### 1. Conda environment
50 |
51 | ```
52 | conda create python=3.8 --name freemotion
53 | conda activate freemotion
54 | ```
55 |
56 | ```
57 | pip install -r requirements.txt
58 | ```
59 |
60 | ### 2. Download the Text-to-Motion Evaluation Model
61 |
62 | ```
63 | bash prepare/download_evaluation_model.sh
64 | ```
65 |
66 |
67 | ## Train and Evaluate Your Own Models
68 |
69 | ### 1. Prepare the datasets
70 |
71 | > Download the InterHuman data from the [InterGen webpage](https://tr3e.github.io/intergen-page/) and put it into ./data/.
72 |
73 | > Download the data from [our webpage](https://vankouf.github.io/FreeMotion/) and put it into ./data/.
74 |
75 | #### Data Structure
76 | ```sh
77 |
78 | ./annots             // Natural-language annotations; each file consists of three sentences.
79 | ./motions            // Raw motion data standardized in SMPL format, similar to AMASS.
80 | ./motions_processed  // Processed motion data: joint positions and 6D rotations for the 22-joint SMPL kinematic structure.
81 | ./split              // Train/val/test split.
82 | ./separate_annots    // Annotations for each person's motion.
83 | ```
84 |
85 | ### 2. Train
86 |
87 | #### Stage 1: train generation module
88 | ```
89 | sh train_single.sh
90 | ```
91 |
92 | #### Stage 2: train interaction module
93 | ```
94 | sh train_inter.sh
95 | ```
96 |
97 | ### 3. Evaluation
98 |
99 | ```
100 | sh test.sh
101 | ```
102 |
103 | ## 📖 Citation
104 |
105 | If you find our code or paper helpful, please consider citing:
106 |
107 | ```bibtex
108 | @article{fan2024freemotion,
109 | title={FreeMotion: A Unified Framework for Number-free Text-to-Motion Synthesis},
110 | author={Ke Fan and Junshu Tang and Weijian Cao and Ran Yi and Moran Li and Jingyu Gong and Jiangning Zhang and Yabiao Wang and Chengjie Wang and Lizhuang Ma},
111 | year={2024},
112 | eprint={2405.15763},
113 | archivePrefix={arXiv},
114 | primaryClass={cs.CV}
115 | }
116 |
117 | ```
118 |
119 | ## Acknowledgments
120 |
121 | Thanks to [InterGen](https://github.com/tr3e/InterGen) and [MotionGPT](https://github.com/OpenMotionLab/MotionGPT); our code partially borrows from them.
122 |
123 | ## Licenses
124 | 
This work is licensed under a Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License.
--------------------------------------------------------------------------------
/datasets/evaluator_models.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import clip
6 |
7 | from models import *
8 |
9 |
10 | loss_ce = nn.CrossEntropyLoss()
11 | class InterCLIP(nn.Module):
12 | def __init__(self, cfg):
13 | super().__init__()
14 | self.cfg = cfg
15 | self.latent_dim = cfg.LATENT_DIM
16 | self.motion_encoder = MotionEncoder(cfg)
17 |
18 | self.latent_dim = self.latent_dim
19 |
20 | clip_model, _ = clip.load("ViT-L/14@336px", device="cpu", jit=False)
21 | self.token_embedding = clip_model.token_embedding
22 | self.positional_embedding = clip_model.positional_embedding
23 | self.dtype = clip_model.dtype
24 | self.latent_scale = nn.Parameter(torch.Tensor([1]))
25 |
26 | set_requires_grad(self.token_embedding, False)
27 |
28 | textTransEncoderLayer = nn.TransformerEncoderLayer(
29 | d_model=768,
30 | nhead=8,
31 | dim_feedforward=cfg.FF_SIZE,
32 | dropout=0.1,
33 | activation="gelu",
34 | batch_first=True)
35 | self.textTransEncoder = nn.TransformerEncoder(
36 | textTransEncoderLayer,
37 | num_layers=8)
38 | self.text_ln = nn.LayerNorm(768)
39 | self.out = nn.Linear(768, 512)
40 |
41 | self.clip_training = "text_"
42 | self.l1_criterion = torch.nn.L1Loss(reduction='mean')
43 |
44 | def compute_loss(self, batch):
45 | losses = {}
46 | losses["total"] = 0
47 |
48 | # compute clip losses
49 | batch = self.encode_text(batch)
50 | batch = self.encode_motion(batch)
51 |
52 | mixed_clip_loss, clip_losses = self.compute_clip_losses(batch)
53 | losses.update(clip_losses)
54 | losses["total"] += mixed_clip_loss
55 |
56 | return losses["total"], losses
57 |
58 | def forward(self, batch):
59 | return self.compute_loss(batch)
60 |
61 | def compute_clip_losses(self, batch):
62 | mixed_clip_loss = 0.
63 | clip_losses = {}
64 |
65 | if 1:
66 | for d in self.clip_training.split('_')[:1]:
67 | if d == 'image':
68 | features = self.clip_model.encode_image(batch['images']).float() # preprocess is done in dataloader
69 | elif d == 'text':
70 | features = batch['text_emb']
71 | motion_features = batch['motion_emb']
72 | # normalized features
73 | features_norm = features / features.norm(dim=-1, keepdim=True)
74 | motion_features_norm = motion_features / motion_features.norm(dim=-1, keepdim=True)
75 |
76 | logit_scale = self.latent_scale ** 2
77 | logits_per_motion = logit_scale * motion_features_norm @ features_norm.t()
78 | logits_per_d = logits_per_motion.t()
79 |
80 | batch_size = motion_features.shape[0]
81 | ground_truth = torch.arange(batch_size, dtype=torch.long, device=motion_features.device)
82 |
83 | ce_from_motion_loss = loss_ce(logits_per_motion, ground_truth)
84 | ce_from_d_loss = loss_ce(logits_per_d, ground_truth)
85 | clip_mixed_loss = (ce_from_motion_loss + ce_from_d_loss) / 2.
86 |
87 | clip_losses[f'{d}_ce_from_d'] = ce_from_d_loss.item()
88 | clip_losses[f'{d}_ce_from_motion'] = ce_from_motion_loss.item()
89 | clip_losses[f'{d}_mixed_ce'] = clip_mixed_loss.item()
90 | mixed_clip_loss += clip_mixed_loss
91 |
92 | return mixed_clip_loss, clip_losses
93 |
94 | def generate_src_mask(self, T, length):
95 | B = length.shape[0]
96 | src_mask = torch.ones(B, T)
97 | for i in range(B):
98 | for j in range(length[i], T):
99 | src_mask[i, j] = 0
100 | return src_mask
101 |
102 | def encode_motion(self, batch):
103 | batch["mask"] = self.generate_src_mask(batch["motions"].shape[1], batch["motion_lens"]).to(batch["motions"].device)
104 | batch.update(self.motion_encoder(batch))
105 | batch["motion_emb"] = batch["motion_emb"] / batch["motion_emb"].norm(dim=-1, keepdim=True) * self.latent_scale
106 |
107 | return batch
108 |
109 | def encode_text(self, batch):
110 | device = next(self.parameters()).device
111 | raw_text = batch["text"]
112 |
113 | with torch.no_grad():
114 | text = clip.tokenize(raw_text, truncate=True).to(device)
115 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
116 | pe_tokens = x + self.positional_embedding.type(self.dtype)
117 |
118 | out = self.textTransEncoder(pe_tokens)
119 | out = self.text_ln(out)
120 |
121 | out = out[torch.arange(x.shape[0]), text.argmax(dim=-1)]
122 | out = self.out(out)
123 |
124 | batch['text_emb'] = out
125 | batch["text_emb"] = batch["text_emb"] / batch["text_emb"].norm(dim=-1, keepdim=True) * self.latent_scale
126 |
127 | return batch
128 |
--------------------------------------------------------------------------------
/datasets/dataloader.py:
--------------------------------------------------------------------------------
1 | import platform
2 | import random
3 | from functools import partial
4 | from typing import Optional, Union
5 |
6 | import numpy as np
7 | from mmcv.runner import get_dist_info
8 | from mmcv.utils import Registry, build_from_cfg
9 | from torch.utils.data import DataLoader
10 | from torch.utils.data.dataset import Dataset
11 |
12 | import torch
13 | from torch.utils.data import DistributedSampler as _DistributedSampler
14 |
15 |
16 | class DistributedSampler(_DistributedSampler):
17 |
18 | def __init__(self,
19 | dataset,
20 | num_replicas=None,
21 | rank=None,
22 | shuffle=True,
23 | round_up=True):
24 | super().__init__(dataset, num_replicas=num_replicas, rank=rank)
25 | self.shuffle = shuffle
26 | self.round_up = round_up
27 | if self.round_up:
28 | self.total_size = self.num_samples * self.num_replicas
29 | else:
30 | self.total_size = len(self.dataset)
31 |
32 | def __iter__(self):
33 | # deterministically shuffle based on epoch
34 | if self.shuffle:
35 | g = torch.Generator()
36 | g.manual_seed(self.epoch)
37 | indices = torch.randperm(len(self.dataset), generator=g).tolist()
38 | else:
39 | indices = torch.arange(len(self.dataset)).tolist()
40 |
41 | # add extra samples to make it evenly divisible
42 | if self.round_up:
43 | indices = (
44 | indices *
45 | int(self.total_size / len(indices) + 1))[:self.total_size]
46 | assert len(indices) == self.total_size
47 |
48 | # subsample
49 | indices = indices[self.rank:self.total_size:self.num_replicas]
50 | if self.round_up:
51 | assert len(indices) == self.num_samples
52 |
53 | return iter(indices)
54 |
55 |
56 | def build_dataloader(dataset: Dataset,
57 | samples_per_gpu: int,
58 | workers_per_gpu: int,
59 | num_gpus: Optional[int] = 1,
60 | dist: Optional[bool] = True,
61 | shuffle: Optional[bool] = True,
62 | round_up: Optional[bool] = True,
63 | seed: Optional[Union[int, None]] = None,
64 | persistent_workers: Optional[bool] = True,
65 | **kwargs):
66 | """Build PyTorch DataLoader.
67 |
68 | In distributed training, each GPU/process has a dataloader.
69 | In non-distributed training, there is only one dataloader for all GPUs.
70 |
71 | Args:
72 | dataset (:obj:`Dataset`): A PyTorch dataset.
73 | samples_per_gpu (int): Number of training samples on each GPU, i.e.,
74 | batch size of each GPU.
75 | workers_per_gpu (int): How many subprocesses to use for data loading
76 | for each GPU.
77 | num_gpus (int, optional): Number of GPUs. Only used in non-distributed
78 | training.
79 | dist (bool, optional): Distributed training/test or not. Default: True.
80 | shuffle (bool, optional): Whether to shuffle the data at every epoch.
81 | Default: True.
82 | round_up (bool, optional): Whether to round up the length of dataset by
83 | adding extra samples to make it evenly divisible. Default: True.
84 | persistent_workers (bool): If True, the data loader will not shutdown
85 | the worker processes after a dataset has been consumed once.
86 | This allows to maintain the workers Dataset instances alive.
87 | The argument also has effect in PyTorch>=1.7.0.
88 | Default: True
89 | kwargs: any keyword argument to be used to initialize DataLoader
90 |
91 | Returns:
92 | DataLoader: A PyTorch dataloader.
93 | """
94 | rank, world_size = get_dist_info()
95 | if dist:
96 | sampler = DistributedSampler(
97 | dataset, world_size, rank, shuffle=shuffle, round_up=round_up)
98 | shuffle = False
99 | batch_size = samples_per_gpu
100 | num_workers = workers_per_gpu
101 | else:
102 | sampler = None
103 | batch_size = num_gpus * samples_per_gpu
104 | num_workers = num_gpus * workers_per_gpu
105 |
106 | init_fn = partial(
107 | worker_init_fn, num_workers=num_workers, rank=rank,
108 | seed=seed) if seed is not None else None
109 |
110 | data_loader = DataLoader(
111 | dataset,
112 | batch_size=batch_size,
113 | sampler=sampler,
114 | num_workers=num_workers,
115 | pin_memory=False,
116 | shuffle=shuffle,
117 | worker_init_fn=init_fn,
118 | persistent_workers=persistent_workers,
119 | **kwargs)
120 |
121 |
122 |
123 | return data_loader
124 |
125 |
126 | def worker_init_fn(worker_id: int, num_workers: int, rank: int, seed: int):
127 | """Init random seed for each worker."""
128 | # The seed of each worker equals to
129 | # num_worker * rank + worker_id + user_seed
130 | worker_seed = num_workers * rank + worker_id + seed
131 | np.random.seed(worker_seed)
132 | random.seed(worker_seed)
133 |
--------------------------------------------------------------------------------
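A non-distributed usage sketch (not from the repo) for `build_dataloader`. The toy `TensorDataset` is a stand-in for any map-style dataset; `dist=False` skips the `DistributedSampler`, and `persistent_workers=False` is passed because no worker processes are spawned here.

```python
import torch
from torch.utils.data import TensorDataset
from datasets.dataloader import build_dataloader

toy = TensorDataset(torch.randn(100, 262))
loader = build_dataloader(toy, samples_per_gpu=8, workers_per_gpu=0, num_gpus=1,
                          dist=False, shuffle=True, seed=0, persistent_workers=False)
(batch,) = next(iter(loader))
print(batch.shape)   # torch.Size([8, 262])
```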
/utils/plot_script.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 | import matplotlib
4 | import matplotlib.pyplot as plt
5 | from mpl_toolkits.mplot3d import Axes3D
6 | from matplotlib.animation import FuncAnimation, FFMpegFileWriter
7 | from mpl_toolkits.mplot3d.art3d import Poly3DCollection
8 | import mpl_toolkits.mplot3d.axes3d as p3
9 | # import cv2
10 |
11 |
12 | def list_cut_average(ll, intervals):
13 | if intervals == 1:
14 | return ll
15 |
16 | bins = math.ceil(len(ll) * 1.0 / intervals)
17 | ll_new = []
18 | for i in range(bins):
19 | l_low = intervals * i
20 | l_high = l_low + intervals
21 | l_high = l_high if l_high < len(ll) else len(ll)
22 | ll_new.append(np.mean(ll[l_low:l_high]))
23 | return ll_new
24 |
25 |
26 | def plot_3d_motion(save_path, kinematic_tree, mp_joints, title, figsize=(10, 10), fps=120, radius=4, hints=None):
27 |
28 | matplotlib.use('Agg')
29 |
30 | title_sp = title.split(' ')
31 | if len(title_sp) > 20:
32 | title = '\n'.join([' '.join(title_sp[:10]), ' '.join(title_sp[10:20]), ' '.join(title_sp[20:])])
33 | elif len(title_sp) > 10:
34 | title = '\n'.join([' '.join(title_sp[:10]), ' '.join(title_sp[10:])])
35 |
36 | def init():
37 | ax.set_xlim3d([-radius / 4, radius / 4])
38 | ax.set_ylim3d([0, radius / 2])
39 | ax.set_zlim3d([0, radius / 2])
40 |
41 | fig.suptitle(title, fontsize=20)
42 | ax.grid(b=False)
43 |
44 | def plot_xzPlane(minx, maxx, miny, minz, maxz):
45 | ## Plot a plane XZ
46 | verts = [
47 | [minx, miny, minz],
48 | [minx, miny, maxz],
49 | [maxx, miny, maxz],
50 | [maxx, miny, minz]
51 | ]
52 | xz_plane = Poly3DCollection([verts])
53 | xz_plane.set_facecolor((0.5, 0.5, 0.5, 0.5))
54 | ax.add_collection3d(xz_plane)
55 |
56 | # if hint is not None:
57 | # mask = hint.sum(-1) != 0
58 | # hint = hint[mask]
59 |
60 | hints_new = []
61 | for hint in hints:
62 | if min(hint.shape) != 0:
63 | mask = hint.sum(-1) != 0
64 | hint = hint[mask]
65 | hints_new.append(hint)
66 | hints = hints_new
67 |
68 | fig = plt.figure(figsize=figsize)
69 | ax = p3.Axes3D(fig)
70 | init()
71 |
72 | mp_data = []
73 | frame_number = min([data.shape[0] for data in mp_joints])
74 |
75 | colors = ['red', 'green', 'black', 'blue', 'red',
76 | 'darkblue', 'darkblue', 'darkblue', 'darkblue', 'darkblue',
77 | 'darkred', 'darkred', 'darkred', 'black', 'blue']
78 |
79 | mp_offset = list(range(-len(mp_joints)//2, len(mp_joints)//2, 1))
80 | mp_colors = [[colors[i]] * 15 for i in range(len(mp_offset))]
81 |
82 | for i,joints in enumerate(mp_joints):
83 |
84 | # (seq_len, joints_num, 3)
85 | data = joints.copy().reshape(len(joints), -1, 3)
86 |
87 | MINS = data.min(axis=0).min(axis=0)
88 | MAXS = data.max(axis=0).max(axis=0)
89 |
90 | height_offset = MINS[1]
91 | data[:, :, 1] -= height_offset
92 | trajec = data[:, 0, [0, 2]]
93 |
94 | # if hint is not None:
95 | # hint[..., 1] -= height_offset
96 |
97 | hints_new = []
98 | for hint in hints:
99 | if min(hint.shape) != 0:
100 | hint[..., 1] -= height_offset
101 | hints_new.append(hint)
102 | hints = hints_new
103 |
104 | # data[:, :, 0] -= data[0:1, 0:1, 0]
105 | # data[:, :, 0] += mp_offset[i]
106 | #
107 | # data[:, :, 2] -= data[0:1, 0:1, 2]
108 | mp_data.append({"joints":data,
109 | "MINS":MINS,
110 | "MAXS":MAXS,
111 | "trajec":trajec, })
112 |
113 | # print(trajec.shape)
114 |
115 | def update(index):
116 | # print(index)
117 | ax.lines = []
118 | ax.collections = []
119 | ax.view_init(elev=120, azim=-90)
120 | ax.dist = 15#7.5
121 | # ax =
122 | plot_xzPlane(-3, 3, 0, -3, 3)
123 |
124 | for pid,data in enumerate(mp_data):
125 | for i, (chain, color) in enumerate(zip(kinematic_tree, mp_colors[pid])):
126 | # print(color)
127 | if i < 5:
128 | linewidth = 2.0
129 | else:
130 | linewidth = 1.0
131 |
132 | ax.plot3D(data["joints"][index, chain, 0], data["joints"][index, chain, 1], data["joints"][index, chain, 2], linewidth=linewidth,
133 | color=color)
134 |
135 | # if hint is not None:
136 | # ax.plot3D(hint[:index+1, 0], hint[:index+1, 1], hint[:index+1, 2], color="blue", linewidth=3.0)
137 |
138 | for idx, hint in enumerate(hints):
139 | if min(hint.shape) != 0:
140 | ax.plot3D(hint[:index+1, 0], hint[:index+1, 1], hint[:index+1, 2], color=colors[-1-idx], linewidth=3.0)
141 |
142 | plt.axis('off')
143 | ax.set_xticklabels([])
144 | ax.set_yticklabels([])
145 | ax.set_zticklabels([])
146 |
147 | ani = FuncAnimation(fig, update, frames=frame_number, interval=1000 / fps, repeat=False)
148 |
149 | # writer = FFMpegFileWriter(fps=fps)
150 | ani.save(save_path, fps=fps)
151 | plt.close()
152 |
--------------------------------------------------------------------------------
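A rendering sketch (not from the repo) for `plot_3d_motion` with random joints. It assumes ffmpeg and the pinned matplotlib==3.2.0 are installed; empty hint arrays are passed because the function iterates over `hints` unconditionally, and `demo.mp4` is a placeholder output path.

```python
import numpy as np
from utils.paramUtil import t2m_kinematic_chain
from utils.plot_script import plot_3d_motion

T = 60
person_a = np.random.rand(T, 22, 3).astype(np.float32)   # (T, 22 joints, xyz)
person_b = np.random.rand(T, 22, 3).astype(np.float32)
no_hint = np.zeros((0, 3), dtype=np.float32)              # no spatial hint for either person
plot_3d_motion("demo.mp4", t2m_kinematic_chain, [person_a, person_b],
               title="random motion", fps=30, hints=[no_hint, no_hint])
```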
/models/intergen.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Mapping, Union, List
2 |
3 | import torch
4 | import clip
5 |
6 | from torch import nn
7 | from models import *
8 | from collections import OrderedDict
9 |
10 | class InterGenSpatialControlNet(nn.Module):
11 | def __init__(self, cfg):
12 | super().__init__()
13 | self.cfg = cfg
14 | self.latent_dim = cfg.LATENT_DIM
15 | self.decoder = InterDiffusionSpatialControlNet(cfg, sampling_strategy=cfg.STRATEGY)
16 | self.archi = cfg.ARCHI
17 | clip_model, _ = clip.load("ViT-L/14@336px", device="cpu", jit=False)
18 | self.token_embedding = clip_model.token_embedding
19 | self.clip_transformer = clip_model.transformer
20 | self.positional_embedding = clip_model.positional_embedding
21 | self.ln_final = clip_model.ln_final
22 | self.dtype = clip_model.dtype
23 |
24 | set_requires_grad(self.clip_transformer, False)
25 | set_requires_grad(self.token_embedding, False)
26 | set_requires_grad(self.ln_final, False)
27 |
28 | clipTransEncoderLayer = nn.TransformerEncoderLayer(
29 | d_model=768,
30 | nhead=8,
31 | dim_feedforward=2048,
32 | dropout=0.1,
33 | activation="gelu",
34 | batch_first=True)
35 | self.clipTransEncoder = nn.TransformerEncoder(
36 | clipTransEncoderLayer,
37 | num_layers=2)
38 | self.clip_ln = nn.LayerNorm(768)
39 |
40 | self.positional_embedding.requires_grad = False
41 |
42 | set_requires_grad(self.clipTransEncoder, False)
43 | set_requires_grad(self.clip_ln, False)
44 |
45 |
46 | def compute_loss(self, batch):
47 |
48 | batch = self.text_process(batch)
49 | losses = self.decoder.compute_loss(batch)
50 | return losses["total"], losses
51 |
52 | def decode_motion(self, batch):
53 | batch.update(self.decoder(batch)) # batch['output'].shape = [1, 210, 524]
54 | return batch
55 |
56 | def forward(self, batch):
57 | return self.compute_loss(batch)
58 |
59 | def forward_test(self, batch): # batch: 'motion_lens', 'prompt', 'text'=['prompt']
60 |
61 | batch = self.text_process(batch)
62 | batch.update(self.decode_motion(batch))
63 | return batch
64 |
65 | def text_process(self, batch):
66 | device = next(self.clip_transformer.parameters()).device
67 |
68 | if "text" in batch and batch["text"] is not None and not isinstance(batch["text"], torch.Tensor):
69 |
70 | raw_text = batch["text"]
71 |
72 | with torch.no_grad():
73 |
74 |                 text = clip.tokenize(raw_text, truncate=True).to(device) # [batch_size, n_ctx]=[1,77]
75 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]=[1,77,768]
76 | pe_tokens = x + self.positional_embedding.type(self.dtype)
77 | x = pe_tokens.permute(1, 0, 2) # NLD -> LND [n_ctx, batch_size, d_model]=[77,1,768]
78 | x = self.clip_transformer(x)
79 | x = x.permute(1, 0, 2)
80 | clip_out = self.ln_final(x).type(self.dtype) # [batch_size, n_ctx, d_model]=[1,77,768]
81 |
82 | out = self.clipTransEncoder(clip_out) # [batch_size, n_ctx, d_model]=[1,77,768]
83 | out = self.clip_ln(out)
84 |
85 | cond = out[torch.arange(x.shape[0]), text.argmax(dim=-1)]
86 | batch["cond"] = cond
87 |
88 | if "text_multi_person" in batch and batch["text_multi_person"] is not None and not isinstance(batch["text_multi_person"], torch.Tensor):
89 | raw_text = batch["text_multi_person"]
90 |
91 | with torch.no_grad():
92 |
93 |                 text = clip.tokenize(raw_text, truncate=True).to(device) # [batch_size, n_ctx]=[1,77]
94 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]=[1,77,768]
95 | pe_tokens = x + self.positional_embedding.type(self.dtype)
96 | x = pe_tokens.permute(1, 0, 2) # NLD -> LND [n_ctx, batch_size, d_model]=[77,1,768]
97 | x = self.clip_transformer(x)
98 | x = x.permute(1, 0, 2)
99 | clip_out = self.ln_final(x).type(self.dtype) # [batch_size, n_ctx, d_model]=[1,77,768]
100 |
101 | out = self.clipTransEncoder(clip_out) # [batch_size, n_ctx, d_model]=[1,77,768]
102 | out = self.clip_ln(out)
103 |
104 | cond = out[torch.arange(x.shape[0]), text.argmax(dim=-1)]
105 |
106 | if "cond" in batch:
107 | batch["cond"] = batch["cond"] + cond
108 | else:
109 | batch["cond"] = cond
110 |
111 | return batch
112 |
113 | def load_state_dict(self, state_dict: Union[List[Mapping[str, Any]],Mapping[str, Any]], strict: bool = True):
114 |
115 | if self.archi == 'single':
116 | return super().load_state_dict(state_dict, strict=strict)
117 |
118 | new_state_dict = OrderedDict()
119 | for name, value in state_dict.items():
120 | if 'decoder.net.net' in name:
121 | name_new = name.replace('decoder.net.net', 'decoder.net.control_branch')
122 | new_state_dict[name_new] = value
123 | new_state_dict[name] = value
124 | else:
125 | new_state_dict[name] = value
126 | return super().load_state_dict(new_state_dict, strict=strict)
127 |
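# --- Editor's sketch (not part of the original file): minimal usage of the inference path
# --- defined above. `model` is assumed to be an instance of this class (e.g. the
# --- InterGenSpatialControlNet built in tools/train.py); the batch keys follow the comment
# --- on forward_test and may need to be adapted to the actual decoder.
def _example_forward_test_usage(model):
    import torch

    batch = {
        "text": ["two people shake hands"],              # per-person prompt, encoded by text_process via CLIP
        "text_multi_person": ["they greet each other"],  # optional interaction prompt; its embedding is added to cond
        "motion_lens": torch.tensor([210]),              # requested motion length in frames
    }
    model.eval()
    with torch.no_grad():
        out = model.forward_test(batch)   # fills batch["cond"] and then batch["output"]
    return out["output"]                  # e.g. shape [1, 210, D] as noted in decode_motion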
--------------------------------------------------------------------------------
/tools/train.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append(sys.path[0] + r"/../")
3 | import logging
4 | import torch
5 | import lightning.pytorch as pl
6 | import torch.optim as optim
7 | from collections import OrderedDict
8 | from datasets import DataModule
9 | from configs import get_config
10 | from os.path import join as pjoin
11 | from torch.utils.tensorboard import SummaryWriter
12 | from models import *
13 | import os, time
14 | import argparse
15 |
16 | os.environ['PL_TORCH_DISTRIBUTED_BACKEND'] = 'nccl'
17 | from lightning.pytorch.strategies import DDPStrategy
18 | torch.set_float32_matmul_precision('medium')
19 |
20 | class LitTrainModel(pl.LightningModule):
21 | def __init__(self, model, cfg):
22 | super().__init__()
23 | # cfg init
24 | self.cfg = cfg
25 | self.mode = cfg.TRAIN.MODE
26 |
27 | self.automatic_optimization = False
28 |
29 | self.save_root = pjoin(self.cfg.GENERAL.CHECKPOINT, self.cfg.GENERAL.EXP_NAME)
30 | self.model_dir = pjoin(self.save_root, 'model')
31 | self.meta_dir = pjoin(self.save_root, 'meta')
32 | self.log_dir = pjoin(self.save_root, 'log')
33 |
34 | os.makedirs(self.model_dir, exist_ok=True)
35 | os.makedirs(self.meta_dir, exist_ok=True)
36 | os.makedirs(self.log_dir, exist_ok=True)
37 |
38 | self.model = model
39 |
40 | self.writer = SummaryWriter(self.log_dir)
41 |
42 | logging.basicConfig(filename=os.path.join(self.log_dir,'train.log'), filemode="w", format="%(asctime)s %(name)s:%(levelname)s:%(message)s", datefmt="%d-%m-%Y %H:%M:%S", level=logging.DEBUG)
43 |
44 | def _configure_optim(self):
45 |
46 | optimizer = optim.AdamW([p for p in self.model.parameters() if p.requires_grad], lr=float(self.cfg.TRAIN.LR), weight_decay=self.cfg.TRAIN.WEIGHT_DECAY)
47 |
48 | scheduler = CosineWarmupScheduler(optimizer=optimizer, warmup=10, max_iters=self.cfg.TRAIN.EPOCH, verbose=True)
49 | return [optimizer], [scheduler]
50 |
51 | def configure_optimizers(self):
52 | return self._configure_optim()
53 |
54 | def forward(self, batch_data):
55 |
56 | name, text, text_multi_person, motion1, motion2, motion_lens, spatial_condition = batch_data
57 |
58 | motion1 = motion1.detach().float()
59 |
60 | batch = OrderedDict({})
61 | if min(motion2.shape) == 0: # for single human training
62 | motions = motion1
63 | elif min(spatial_condition.shape) == 0: # for pure two-person training
64 | motion2 = motion2.detach().float()
65 | motions = torch.cat([motion1, motion2], dim=-1)
66 |
67 |
68 | B, T = motion1.shape[:2]
69 |
70 | batch["motions"] = motions.reshape(B, T, -1).type(torch.float32)
71 | batch["motion_lens"] = motion_lens.long()
72 | batch["person_num"] = motions.shape[-1] // motion1.shape[-1]
73 |
74 | if isinstance(text, torch.Tensor):
75 | batch["text"] = None
76 | else:
77 | batch["text"] = text
78 |
79 | if isinstance(text_multi_person, torch.Tensor):
80 | batch["text_multi_person"] = None
81 | else:
82 | batch["text_multi_person"] = text_multi_person
83 |
84 | loss, loss_logs = self.model(batch)
85 | return loss, loss_logs
86 |
87 | def on_train_start(self):
88 | self.rank = 0
89 | self.world_size = 1
90 | self.start_time = time.time()
91 | self.it = self.cfg.TRAIN.LAST_ITER if self.cfg.TRAIN.LAST_ITER else 0
92 | self.epoch = self.cfg.TRAIN.LAST_EPOCH if self.cfg.TRAIN.LAST_EPOCH else 0
93 | self.logs = OrderedDict()
94 |
95 |
96 | def training_step(self, batch, batch_idx):
97 | loss, loss_logs = self.forward(batch)
98 | opt = self.optimizers()
99 | opt.zero_grad()
100 | self.manual_backward(loss)
101 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), 0.5)
102 | opt.step()
103 |
104 | return {"loss": loss,
105 | "loss_logs": loss_logs}
106 |
107 |
108 | def on_train_batch_end(self, outputs, batch, batch_idx):
109 | if outputs.get('skip_batch') or not outputs.get('loss_logs'):
110 | return
111 | for k, v in outputs['loss_logs'].items():
112 | if k not in self.logs:
113 | self.logs[k] = v.item()
114 | else:
115 | self.logs[k] += v.item()
116 |
117 | self.it += 1
118 | if self.it % self.cfg.TRAIN.LOG_STEPS == 0 and self.device.index == 0:
119 | mean_loss = OrderedDict({})
120 | for tag, value in self.logs.items():
121 | mean_loss[tag] = value / self.cfg.TRAIN.LOG_STEPS
122 | self.writer.add_scalar(tag, mean_loss[tag], self.it)
123 | self.logs = OrderedDict()
124 | print_current_loss(self.start_time, self.it, mean_loss,
125 | self.trainer.current_epoch,
126 | inner_iter=batch_idx,
127 | lr=self.trainer.optimizers[0].param_groups[0]['lr'])
128 |
129 | def on_train_epoch_end(self):
130 | # pass
131 | sch = self.lr_schedulers()
132 | if sch is not None:
133 | sch.step()
134 |
135 | def save(self, file_name):
136 | state = {}
137 | try:
138 | state['model'] = self.model.module.state_dict()
139 | except AttributeError: # plain module (no DDP/DataParallel .module wrapper)
140 | state['model'] = self.model.state_dict()
141 | torch.save(state, file_name, _use_new_zipfile_serialization=False)
142 | return
143 |
144 |
145 | def build_models(cfg):
146 | model = InterGenSpatialControlNet(cfg)
147 | return model
148 |
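# --- Editor's sketch (not part of the original file): the FROM_PRETRAIN loading below strips the
# --- Lightning "model." prefix by mutating the checkpoint dict in place; a non-mutating variant
# --- of the same idea might look like this (helper name is hypothetical).
def _strip_lightning_prefix(state_dict, prefix="model."):
    from collections import OrderedDict
    out = OrderedDict()
    for k, v in state_dict.items():
        # keys saved by LitTrainModel look like "model.decoder...."; drop the wrapper prefix
        out[k[len(prefix):] if k.startswith(prefix) else k] = v
    return out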
149 |
150 | if __name__ == '__main__':
151 |
152 | parser = argparse.ArgumentParser(description='Process configs.')
153 | parser.add_argument('--model_config', type=str, help='model config')
154 | parser.add_argument('--dataset_config', type=str, help='dataset config')
155 | parser.add_argument('--train_config', type=str, help='train config')
156 |
157 | args = parser.parse_args()
158 |
159 | print(os.getcwd())
160 | model_cfg = get_config(args.model_config)
161 | train_cfg = get_config(args.train_config)
162 | data_cfg = get_config(args.dataset_config).interhuman
163 |
164 | datamodule = DataModule(data_cfg, train_cfg.TRAIN.BATCH_SIZE, train_cfg.TRAIN.NUM_WORKERS)
165 | model = build_models(model_cfg)
166 |
167 | if train_cfg.TRAIN.FROM_PRETRAIN:
168 | ckpt = torch.load(train_cfg.TRAIN.FROM_PRETRAIN, map_location="cpu")
169 | for k in list(ckpt["state_dict"].keys()):
170 | if "model" in k:
171 | ckpt["state_dict"][k.replace("model.", "")] = ckpt["state_dict"].pop(k)
172 | model.load_state_dict(ckpt["state_dict"], strict=False)
173 | print("checkpoint state loaded!")
174 |
175 | if train_cfg.TRAIN.RESUME:
176 | resume_path = train_cfg.TRAIN.RESUME
177 | else:
178 | resume_path = None
179 |
180 | litmodel = LitTrainModel(model, train_cfg)
181 |
182 | checkpoint_callback = pl.callbacks.ModelCheckpoint(dirpath=litmodel.model_dir,
183 | every_n_epochs=train_cfg.TRAIN.SAVE_EPOCH,
184 | save_top_k = train_cfg.TRAIN.SAVE_TOPK,
185 | )
186 | trainer = pl.Trainer(
187 | default_root_dir=litmodel.model_dir,
188 | devices="auto", accelerator='gpu',
189 | max_epochs=train_cfg.TRAIN.EPOCH,
190 | strategy=DDPStrategy(find_unused_parameters=True),
191 | precision=32,
192 | callbacks=[checkpoint_callback],
193 | detect_anomaly=True
194 | )
195 |
196 | trainer.fit(model=litmodel, datamodule=datamodule, ckpt_path=resume_path)
197 |
--------------------------------------------------------------------------------
/utils/metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy import linalg
3 | from scipy.ndimage import uniform_filter1d
4 |
5 | emb_scale = 6
6 |
7 | # (X - X_train)*(X - X_train) = -2X*X_train + X*X + X_train*X_train
8 | def euclidean_distance_matrix(matrix1, matrix2):
9 | """
10 | Params:
11 | -- matrix1: N1 x D
12 | -- matrix2: N2 x D
13 | Returns:
14 | -- dist: N1 x N2
15 | dist[i, j] == distance(matrix1[i], matrix2[j])
16 | """
17 | assert matrix1.shape[1] == matrix2.shape[1]
18 | d1 = -2 * np.dot(matrix1, matrix2.T) # shape (num_test, num_train)
19 | d2 = np.sum(np.square(matrix1), axis=1, keepdims=True) # shape (num_test, 1)
20 | d3 = np.sum(np.square(matrix2), axis=1) # shape (num_train, )
21 | dists = np.sqrt(d1 + d2 + d3) # broadcasting
22 | return dists
23 |
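# --- Editor's sketch (not in the original file): quick check that the expansion used above,
# --- ||x - y||^2 = ||x||^2 + ||y||^2 - 2 x.y, matches a direct pairwise computation.
def _example_euclidean_distance_matrix():
    import numpy as np
    rng = np.random.default_rng(0)
    a, b = rng.standard_normal((5, 8)), rng.standard_normal((7, 8))
    fast = euclidean_distance_matrix(a, b)                        # (5, 7)
    slow = np.linalg.norm(a[:, None, :] - b[None, :, :], axis=-1)
    assert np.allclose(fast, slow, atol=1e-6)
    return fast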
24 | def calculate_top_k(mat, top_k):
25 | size = mat.shape[0]
26 | gt_mat = np.expand_dims(np.arange(size), 1).repeat(size, 1)
27 | bool_mat = (mat == gt_mat)
28 | correct_vec = False
29 | top_k_list = []
30 | for i in range(top_k):
31 | # print(correct_vec, bool_mat[:, i])
32 | correct_vec = (correct_vec | bool_mat[:, i])
33 | # print(correct_vec)
34 | top_k_list.append(correct_vec[:, None])
35 | top_k_mat = np.concatenate(top_k_list, axis=1)
36 | return top_k_mat
37 |
38 |
39 | def calculate_R_precision(embedding1, embedding2, top_k, sum_all=False):
40 | dist_mat = euclidean_distance_matrix(embedding1, embedding2)
41 | argmax = np.argsort(dist_mat, axis=1)
42 | top_k_mat = calculate_top_k(argmax, top_k)
43 | if sum_all:
44 | return top_k_mat.sum(axis=0)
45 | else:
46 | return top_k_mat
47 |
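# --- Editor's sketch (not in the original file): R-precision on a toy batch. When each motion
# --- embedding is a small perturbation of its text embedding, the ground-truth pair is always
# --- the nearest neighbour, so the top-1 count equals the batch size.
def _example_r_precision():
    import numpy as np
    rng = np.random.default_rng(0)
    text_emb = rng.standard_normal((32, 16))
    motion_emb = text_emb + 1e-3 * rng.standard_normal((32, 16))  # nearly matched pairs
    top_k = calculate_R_precision(text_emb, motion_emb, top_k=3, sum_all=True)
    assert top_k[0] == 32
    return top_k / 32.0   # R-precision per k, as reported in tools/eval.py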
48 |
49 | def calculate_matching_score(embedding1, embedding2, sum_all=False):
50 | assert len(embedding1.shape) == 2
51 | assert embedding1.shape[0] == embedding2.shape[0]
52 | assert embedding1.shape[1] == embedding2.shape[1]
53 |
54 | dist = linalg.norm(embedding1 - embedding2, axis=1)
55 | if sum_all:
56 | return dist.sum(axis=0)
57 | else:
58 | return dist
59 |
60 |
61 |
62 | def calculate_activation_statistics(activations):
63 | """
64 | Params:
65 | -- activation: num_samples x dim_feat
66 | Returns:
67 | -- mu: dim_feat
68 | -- sigma: dim_feat x dim_feat
69 | """
70 | activations = activations * emb_scale
71 | mu = np.mean(activations, axis=0)
72 | cov = np.cov(activations, rowvar=False)
73 | return mu, cov
74 |
75 |
76 | def calculate_diversity(activation, diversity_times):
77 | assert len(activation.shape) == 2
78 | assert activation.shape[0] > diversity_times
79 | num_samples = activation.shape[0]
80 |
81 | activation = activation * emb_scale
82 | first_indices = np.random.choice(num_samples, diversity_times, replace=False)
83 | second_indices = np.random.choice(num_samples, diversity_times, replace=False)
84 | dist = linalg.norm((activation[first_indices] - activation[second_indices])/2, axis=1)
85 | return dist.mean()
86 |
87 |
88 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
89 | """Numpy implementation of the Frechet Distance.
90 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
91 | and X_2 ~ N(mu_2, C_2) is
92 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
93 | Stable version by Dougal J. Sutherland.
94 | Params:
95 | -- mu1 : Numpy array containing the activations of a layer of the
96 | inception net (like returned by the function 'get_predictions')
97 | for generated samples.
98 | -- mu2 : The sample mean over activations, precalculated on an
99 | representative data set.
100 | -- sigma1: The covariance matrix over activations for generated samples.
101 | -- sigma2: The covariance matrix over activations, precalculated on an
102 | representative data set.
103 | Returns:
104 | -- : The Frechet Distance.
105 | """
106 |
107 | mu1 = np.atleast_1d(mu1)
108 | mu2 = np.atleast_1d(mu2)
109 |
110 | sigma1 = np.atleast_2d(sigma1)
111 | sigma2 = np.atleast_2d(sigma2)
112 |
113 | assert mu1.shape == mu2.shape, \
114 | 'Training and test mean vectors have different lengths'
115 | assert sigma1.shape == sigma2.shape, \
116 | 'Training and test covariances have different dimensions'
117 |
118 | diff = mu1 - mu2
119 |
120 | # Product might be almost singular
121 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
122 | if not np.isfinite(covmean).all():
123 | msg = ('fid calculation produces singular product; '
124 | 'adding %s to diagonal of cov estimates') % eps
125 | print(msg)
126 | offset = np.eye(sigma1.shape[0]) * eps
127 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
128 |
129 | # Numerical error might give slight imaginary component
130 | if np.iscomplexobj(covmean):
131 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
132 | m = np.max(np.abs(covmean.imag))
133 | raise ValueError('Imaginary component {}'.format(m))
134 | covmean = covmean.real
135 |
136 | tr_covmean = np.trace(covmean)
137 |
138 | return (diff.dot(diff) + np.trace(sigma1) +
139 | np.trace(sigma2) - 2 * tr_covmean)
140 |
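# --- Editor's sketch (not in the original file): typical FID usage with the helpers above.
# --- Identical activation sets give a (numerically) zero distance; a mean shift increases it.
def _example_fid():
    import numpy as np
    rng = np.random.default_rng(0)
    gt_act = rng.standard_normal((256, 32))
    gen_act = gt_act + 0.5                      # a crude "generated" set with a mean shift
    mu_gt, cov_gt = calculate_activation_statistics(gt_act)
    mu_gen, cov_gen = calculate_activation_statistics(gen_act)
    assert abs(calculate_frechet_distance(mu_gt, cov_gt, mu_gt, cov_gt)) < 1e-3
    return calculate_frechet_distance(mu_gt, cov_gt, mu_gen, cov_gen)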
141 |
142 | def calculate_multimodality(activation, multimodality_times):
143 | assert len(activation.shape) == 3
144 | assert activation.shape[1] > multimodality_times
145 | num_per_sent = activation.shape[1]
146 |
147 | first_dices = np.random.choice(num_per_sent, multimodality_times, replace=False)
148 | second_dices = np.random.choice(num_per_sent, multimodality_times, replace=False)
149 | dist = linalg.norm(activation[:, first_dices] - activation[:, second_dices], axis=2)
150 | return dist.mean()
151 |
152 | def calculate_trajectory_error(dist_error, mean_err_traj, mask, strict=True):
153 | ''' dist_error shape [5]: error for each keypoint in metres
154 | Two thresholds: 20 cm and 50 cm.
155 | If the mean error over the sequence exceeds the threshold, the trajectory fails.
156 | return: traj_fail(0.2), traj_fail(0.5), all_kps_fail(0.2), all_kps_fail(0.5), all_mean_err.
157 | All metrics are already averaged.
158 | '''
159 | # mean_err_traj = dist_error.mean(1)
160 | if strict:
161 | # Traj fails if any of the key frame fails
162 | traj_fail_02 = 1.0 - (dist_error <= 0.2).all()
163 | traj_fail_05 = 1.0 - (dist_error <= 0.5).all()
164 | else:
165 | # Traj fails if the mean error of all keyframes more than the threshold
166 | traj_fail_02 = (mean_err_traj > 0.2)
167 | traj_fail_05 = (mean_err_traj > 0.5)
168 | all_fail_02 = (dist_error > 0.2).sum() / mask.sum()
169 | all_fail_05 = (dist_error > 0.5).sum() / mask.sum()
170 |
171 | # out = {"traj_fail_02": traj_fail_02,
172 | # "traj_fail_05": traj_fail_05,
173 | # "all_fail_02": all_fail_02,
174 | # "all_fail_05": all_fail_05,
175 | # "all_mean_err": dist_error.mean()}
176 | return np.array([traj_fail_02, traj_fail_05, all_fail_02, all_fail_05, dist_error.sum() / mask.sum()])
177 |
178 |
179 | def calculate_trajectory_diversity(trajectories, lengths):
180 | ''' Standard deviation of point locations in the trajectories
181 | Args:
182 | trajectories: [bs, rep, 196, 2]
183 | lengths: [bs]
184 | '''
185 | # [32, 2, 196, 2 (xz)]
186 | # mean_trajs = trajectories.mean(1, keepdims=True)
187 | # dist_to_mean = np.linalg.norm(trajectories - mean_trajs, axis=3)
188 | def traj_div(traj, length):
189 | # traj [rep, 196, 2]
190 | # length (int)
191 | traj = traj[:, :length, :]
192 | # point_var = traj.var(axis=0, keepdims=True).mean()
193 | # point_var = np.sqrt(point_var)
194 | # return point_var
195 |
196 | mean_traj = traj.mean(axis=0, keepdims=True)
197 | dist = np.sqrt(((traj - mean_traj)**2).sum(axis=2))
198 | rms_dist = np.sqrt((dist**2).mean())
199 | return rms_dist
200 |
201 | div = []
202 | for i in range(len(trajectories)):
203 | div.append(traj_div(trajectories[i], lengths[i]))
204 | return np.array(div).mean()
205 |
206 |
207 | def calculate_skating_ratio(motions):
208 | thresh_height = 0.05 # 5 cm
209 | fps = 20.0
210 | thresh_vel = 0.50 # 0.50 m/s
211 | avg_window = 5 # frames
212 |
213 | batch_size = motions.shape[0]
214 | # 10 left, 11 right foot. XZ plane, y up
215 | # motions [bs, 22, 3, max_len]
216 | verts_feet = motions[:, [10, 11], :, :].detach().cpu().numpy() # [bs, 2, 3, max_len]
217 | verts_feet_plane_vel = np.linalg.norm(verts_feet[:, :, [0, 2], 1:] - verts_feet[:, :, [0, 2], :-1], axis=2) * fps # [bs, 2, max_len-1]
218 | # [bs, 2, max_len-1]
219 | vel_avg = uniform_filter1d(verts_feet_plane_vel, axis=-1, size=avg_window, mode='constant', origin=0)
220 |
221 | verts_feet_height = verts_feet[:, :, 1, :] # [bs, 2, max_len]
222 | # Feet are in contact if they stay below the height threshold in adjacent frames
223 | feet_contact = np.logical_and((verts_feet_height[:, :, :-1] < thresh_height), (verts_feet_height[:, :, 1:] < thresh_height)) # [bs, 2, max_len - 1]
224 | # skate velocity
225 | skate_vel = feet_contact * vel_avg
226 |
227 | # a frame counts as skating if the foot is in contact and its instantaneous velocity exceeds the threshold
228 | skating = np.logical_and(feet_contact, (verts_feet_plane_vel > thresh_vel))
229 | # and the velocity averaged over the window must also exceed the threshold
230 | skating = np.logical_and(skating, (vel_avg > thresh_vel))
231 |
232 | # A frame is skating if either foot slides
233 | skating = np.logical_or(skating[:, 0, :], skating[:, 1, :]) # [bs, max_len -1]
234 | skating_ratio = np.sum(skating, axis=1) / skating.shape[1]
235 |
236 | return skating_ratio, skate_vel
237 |
238 | # verts_feet_gt = markers_got[:, [16, 47], :].detach().cpu().numpy() # [119, 2, 3] heels
239 | # verts_feet_horizon_vel_gt = np.linalg.norm(verts_feet_gt[1:, :, :-1] - verts_feet_gt[:-1, :, :-1], axis=-1) * 30
240 |
241 | # verts_feet_height_gt = verts_feet_gt[:, :, -1][0:-1] # [118,2]
242 | # min_z = markers_gt[:, :, 2].min().detach().cpu().numpy()
243 | # verts_feet_height_gt = verts_feet_height_gt - min_z
244 |
245 | # skating_gt = (verts_feet_horizon_vel_gt > thresh_vel) * (verts_feet_height_gt < thresh_height)
246 | # skating_gt = np.sum(np.logival_and(skating_gt[:, 0], skating_gt[:, 1])) / 118
247 | # skating_gt_list.append(skating_gt)
248 |
249 |
250 | def calculate_skating_ratio_kit(motions):
251 | thresh_height = 0.05 # 5 cm
252 | fps = 20.0
253 | thresh_vel = 0.50 # 0.50 m/s
254 | avg_window = 5 # frames
255 |
256 | batch_size = motions.shape[0]
257 | # 15 left, 20 right foot. XZ plane, y up
258 | # motions [bs, 22, 3, max_len]
259 | verts_feet = motions[:, [15, 20], :, :].detach().cpu().numpy() # [bs, 2, 3, max_len]
260 | verts_feet_plane_vel = np.linalg.norm(verts_feet[:, :, [0, 2], 1:] - verts_feet[:, :, [0, 2], :-1], axis=2) * fps # [bs, 2, max_len-1]
261 | # [bs, 2, max_len-1]
262 | vel_avg = uniform_filter1d(verts_feet_plane_vel, axis=-1, size=avg_window, mode='constant', origin=0)
263 |
264 | verts_feet_height = verts_feet[:, :, 1, :] # [bs, 2, max_len]
265 | # Feet are in contact if they stay below the height threshold in adjacent frames
266 | feet_contact = np.logical_and((verts_feet_height[:, :, :-1] < thresh_height), (verts_feet_height[:, :, 1:] < thresh_height)) # [bs, 2, max_len - 1]
267 | # skate velocity
268 | skate_vel = feet_contact * vel_avg
269 |
270 | # a frame counts as skating if the foot is in contact and its instantaneous velocity exceeds the threshold
271 | skating = np.logical_and(feet_contact, (verts_feet_plane_vel > thresh_vel))
272 | # and the velocity averaged over the window must also exceed the threshold
273 | skating = np.logical_and(skating, (vel_avg > thresh_vel))
274 |
275 | # A frame is skating if either foot slides
276 | skating = np.logical_or(skating[:, 0, :], skating[:, 1, :]) # [bs, max_len -1]
277 | skating_ratio = np.sum(skating, axis=1) / skating.shape[1]
278 |
279 | return skating_ratio, skate_vel
280 |
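# --- Editor's sketch (not in the original file): the two skating-ratio helpers above expect joint
# --- positions as a torch tensor of shape [bs, joints, 3, max_len] with y up and metres as units.
# --- A static pose has zero foot velocity, hence a skating ratio of 0.
def _example_skating_ratio():
    import numpy as np
    import torch
    motions = torch.zeros(2, 22, 3, 40)        # two static sequences of 40 frames
    motions[:, :, 1, :] = 0.01                 # feet resting just above the floor (below thresh_height)
    ratio, _ = calculate_skating_ratio(motions)
    assert np.allclose(ratio, 0.0)
    return ratio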
281 |
282 | def control_l2(motion, hint, hint_mask):
283 | # motion: b, seq, 22, 3
284 | # hint: b, seq, 22, 1
285 | loss = np.linalg.norm((motion - hint) * hint_mask, axis=-1)
286 | # loss = loss.sum() / hint_mask.sum()
287 | return loss
--------------------------------------------------------------------------------
/tools/eval.py:
--------------------------------------------------------------------------------
1 |
2 | import sys
3 | sys.path.append(sys.path[0]+r"/../")
4 | import numpy as np
5 | import torch
6 |
7 | from datetime import datetime
8 | from datasets import get_dataset_motion_loader, get_motion_loader
9 | from models import *
10 | # from models_intergen2 import *
11 |
12 | from utils.metrics import *
13 | from datasets import EvaluatorModelWrapper
14 | from collections import OrderedDict
15 | from utils.plot_script import *
16 | from utils.utils import *
17 | from configs import get_config
18 | from os.path import join as pjoin
19 | from tqdm import tqdm
20 | import argparse
21 |
22 | os.environ['WORLD_SIZE'] = '1'
23 | os.environ['RANK'] = '0'
24 | os.environ['MASTER_ADDR'] = 'localhost'
25 | os.environ['MASTER_PORT'] = '12345'
26 | torch.multiprocessing.set_sharing_strategy('file_system')
27 |
28 | def build_models(cfg):
29 | model = InterGenSpatialControlNet(cfg)
30 | return model
31 |
32 | def evaluate_matching_score(motion_loaders, file):
33 | match_score_dict = OrderedDict({})
34 | R_precision_dict = OrderedDict({})
35 | activation_dict = OrderedDict({})
36 | # print(motion_loaders.keys())
37 | print('========== Evaluating MM Distance ==========')
38 | for motion_loader_name, motion_loader in motion_loaders.items():
39 | all_motion_embeddings = []
40 | score_list = []
41 | all_size = 0
42 | mm_dist_sum = 0
43 | top_k_count = 0
44 | # print(motion_loader_name)
45 | with torch.no_grad():
46 | for idx, batch in tqdm(enumerate(motion_loader)):
47 | text_embeddings, motion_embeddings = eval_wrapper.get_co_embeddings(batch)
48 | # print(text_embeddings.shape)
49 | # print(motion_embeddings.shape)
50 |
51 | dist_mat = euclidean_distance_matrix(text_embeddings.cpu().numpy(),
52 | motion_embeddings.cpu().numpy())
53 | # print(dist_mat.shape)
54 | mm_dist_sum += dist_mat.trace()
55 |
56 | argsmax = np.argsort(dist_mat, axis=1)
57 | # print(argsmax.shape)
58 |
59 | top_k_mat = calculate_top_k(argsmax, top_k=3)
60 | top_k_count += top_k_mat.sum(axis=0)
61 |
62 | all_size += text_embeddings.shape[0]
63 |
64 | all_motion_embeddings.append(motion_embeddings.cpu().numpy())
65 |
66 | all_motion_embeddings = np.concatenate(all_motion_embeddings, axis=0)
67 | mm_dist = mm_dist_sum / all_size
68 | R_precision = top_k_count / all_size
69 | match_score_dict[motion_loader_name] = mm_dist
70 | R_precision_dict[motion_loader_name] = R_precision
71 | activation_dict[motion_loader_name] = all_motion_embeddings
72 |
73 | print(f'---> [{motion_loader_name}] MM Distance: {mm_dist:.4f}')
74 | print(f'---> [{motion_loader_name}] MM Distance: {mm_dist:.4f}', file=file, flush=True)
75 |
76 | line = f'---> [{motion_loader_name}] R_precision: '
77 | for i in range(len(R_precision)):
78 | line += '(top %d): %.4f ' % (i+1, R_precision[i])
79 | print(line)
80 | print(line, file=file, flush=True)
81 |
82 | return match_score_dict, R_precision_dict, activation_dict
83 |
84 |
85 | def evaluate_fid(groundtruth_loader, activation_dict, file):
86 | eval_dict = OrderedDict({})
87 | gt_motion_embeddings = []
88 | print('========== Evaluating FID ==========')
89 | with torch.no_grad():
90 | for idx, batch in tqdm(enumerate(groundtruth_loader)):
91 | motion_embeddings = eval_wrapper.get_motion_embeddings(batch)
92 | gt_motion_embeddings.append(motion_embeddings.cpu().numpy())
93 | gt_motion_embeddings = np.concatenate(gt_motion_embeddings, axis=0)
94 | gt_mu, gt_cov = calculate_activation_statistics(gt_motion_embeddings)
95 |
96 | # print(gt_mu)
97 | for model_name, motion_embeddings in activation_dict.items():
98 | mu, cov = calculate_activation_statistics(motion_embeddings)
99 | # print(mu)
100 | fid = calculate_frechet_distance(gt_mu, gt_cov, mu, cov)
101 | print(f'---> [{model_name}] FID: {fid:.4f}')
102 | print(f'---> [{model_name}] FID: {fid:.4f}', file=file, flush=True)
103 | eval_dict[model_name] = fid
104 | return eval_dict
105 |
106 |
107 | def evaluate_diversity(activation_dict, file):
108 | eval_dict = OrderedDict({})
109 | print('========== Evaluating Diversity ==========')
110 | for model_name, motion_embeddings in activation_dict.items():
111 | diversity = calculate_diversity(motion_embeddings, diversity_times)
112 | eval_dict[model_name] = diversity
113 | print(f'---> [{model_name}] Diversity: {diversity:.4f}')
114 | print(f'---> [{model_name}] Diversity: {diversity:.4f}', file=file, flush=True)
115 | return eval_dict
116 |
117 |
118 | def evaluate_multimodality(mm_motion_loaders, file):
119 | eval_dict = OrderedDict({})
120 | print('========== Evaluating MultiModality ==========')
121 | for model_name, mm_motion_loader in mm_motion_loaders.items():
122 | mm_motion_embeddings = []
123 | with torch.no_grad():
124 | for idx, batch in enumerate(mm_motion_loader):
125 |
126 | batch[-4] = batch[-4][0]
127 | batch[-3] = batch[-3][0]
128 | batch[-2] = batch[-2][0]
129 | motion_embeddings = eval_wrapper.get_motion_embeddings(batch)
130 | mm_motion_embeddings.append(motion_embeddings.unsqueeze(0))
131 | if len(mm_motion_embeddings) == 0:
132 | multimodality = 0
133 | else:
134 | mm_motion_embeddings = torch.cat(mm_motion_embeddings, dim=0).cpu().numpy()
135 | multimodality = calculate_multimodality(mm_motion_embeddings, mm_num_times)
136 | print(f'---> [{model_name}] Multimodality: {multimodality:.4f}')
137 | print(f'---> [{model_name}] Multimodality: {multimodality:.4f}', file=file, flush=True)
138 | eval_dict[model_name] = multimodality
139 | return eval_dict
140 |
141 |
142 | def get_metric_statistics(values):
143 | mean = np.mean(values, axis=0)
144 | std = np.std(values, axis=0)
145 | conf_interval = 1.96 * std / np.sqrt(replication_times)
146 | return mean, conf_interval
147 |
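# --- Editor's sketch (not in the original file): get_metric_statistics reports the mean and a
# --- 95% confidence interval, 1.96 * std / sqrt(replication_times), where replication_times is
# --- the module-level global set in the __main__ block below. The values here are hypothetical.
def _example_metric_statistics():
    import numpy as np
    fid_per_replication = np.array([5.1, 4.8, 5.3, 5.0, 4.9])  # hypothetical per-replication FIDs
    mean = fid_per_replication.mean()
    conf = 1.96 * fid_per_replication.std() / np.sqrt(len(fid_per_replication))
    return mean, conf   # what get_metric_statistics would return if replication_times == 5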
148 |
149 | def evaluation(log_file):
150 | with open(log_file, 'w') as f:
151 | all_metrics = OrderedDict({'MM Distance': OrderedDict({}),
152 | 'R_precision': OrderedDict({}),
153 | 'FID': OrderedDict({}),
154 | 'Diversity': OrderedDict({}),
155 | 'MultiModality': OrderedDict({})
156 | }
157 | )
158 | for replication in range(replication_times):
159 | motion_loaders = {}
160 | mm_motion_loaders = {}
161 | motion_loaders['ground truth'] = gt_loader
162 | for motion_loader_name, motion_loader_getter in eval_motion_loaders.items():
163 | motion_loader, mm_motion_loader = motion_loader_getter()
164 | motion_loaders[motion_loader_name] = motion_loader
165 | mm_motion_loaders[motion_loader_name] = mm_motion_loader
166 |
167 | print(f'==================== Replication {replication} ====================')
168 | print(f'==================== Replication {replication} ====================', file=f, flush=True)
169 | print(f'Time: {datetime.now()}')
170 | print(f'Time: {datetime.now()}', file=f, flush=True)
171 | mat_score_dict, R_precision_dict, acti_dict = evaluate_matching_score(motion_loaders, f)
172 |
173 | print(f'Time: {datetime.now()}')
174 | print(f'Time: {datetime.now()}', file=f, flush=True)
175 | fid_score_dict = evaluate_fid(gt_loader, acti_dict, f)
176 |
177 | print(f'Time: {datetime.now()}')
178 | print(f'Time: {datetime.now()}', file=f, flush=True)
179 | div_score_dict = evaluate_diversity(acti_dict, f)
180 |
181 | print(f'Time: {datetime.now()}')
182 | print(f'Time: {datetime.now()}', file=f, flush=True)
183 | mm_score_dict = evaluate_multimodality(mm_motion_loaders, f)
184 |
185 | print(f'!!! DONE !!!')
186 | print(f'!!! DONE !!!', file=f, flush=True)
187 |
188 | for key, item in mat_score_dict.items():
189 | if key not in all_metrics['MM Distance']:
190 | all_metrics['MM Distance'][key] = [item]
191 | else:
192 | all_metrics['MM Distance'][key] += [item]
193 |
194 | for key, item in R_precision_dict.items():
195 | if key not in all_metrics['R_precision']:
196 | all_metrics['R_precision'][key] = [item]
197 | else:
198 | all_metrics['R_precision'][key] += [item]
199 |
200 | for key, item in fid_score_dict.items():
201 | if key not in all_metrics['FID']:
202 | all_metrics['FID'][key] = [item]
203 | else:
204 | all_metrics['FID'][key] += [item]
205 |
206 | for key, item in div_score_dict.items():
207 | if key not in all_metrics['Diversity']:
208 | all_metrics['Diversity'][key] = [item]
209 | else:
210 | all_metrics['Diversity'][key] += [item]
211 |
212 | for key, item in mm_score_dict.items():
213 | if key not in all_metrics['MultiModality']:
214 | all_metrics['MultiModality'][key] = [item]
215 | else:
216 | all_metrics['MultiModality'][key] += [item]
217 |
218 |
219 | # print(all_metrics['Diversity'])
220 | for metric_name, metric_dict in all_metrics.items():
221 | print('========== %s Summary ==========' % metric_name)
222 | print('========== %s Summary ==========' % metric_name, file=f, flush=True)
223 |
224 | for model_name, values in metric_dict.items():
225 | # print(metric_name, model_name)
226 | mean, conf_interval = get_metric_statistics(np.array(values))
227 | # print(mean, mean.dtype)
228 | if isinstance(mean, np.float64) or isinstance(mean, np.float32):
229 | print(f'---> [{model_name}] Mean: {mean:.4f} CInterval: {conf_interval:.4f}')
230 | print(f'---> [{model_name}] Mean: {mean:.4f} CInterval: {conf_interval:.4f}', file=f, flush=True)
231 | elif isinstance(mean, np.ndarray):
232 | line = f'---> [{model_name}]'
233 | for i in range(len(mean)):
234 | line += '(top %d) Mean: %.4f CInt: %.4f;' % (i+1, mean[i], conf_interval[i])
235 | print(line)
236 | print(line, file=f, flush=True)
237 |
238 |
239 | if __name__ == '__main__':
240 |
241 | parser = argparse.ArgumentParser(description='Process configs.')
242 | parser.add_argument('--model_config', type=str, help='model config')
243 | parser.add_argument('--dataset_config', type=str, help='dataset config')
244 | parser.add_argument('--evalmodel_config', type=str, help='evaluation model config')
245 |
246 | args = parser.parse_args()
247 |
248 | mm_num_samples = 100
249 | mm_num_repeats = 30
250 | mm_num_times = 10
251 |
252 | diversity_times = 300
253 | replication_times = 20
254 |
255 | # batch_size is fixed to 96!!
256 | batch_size = 96
257 |
258 | data_cfg = get_config(args.dataset_config).interhuman_test
259 | cfg_path_list = [args.model_config]
260 |
261 |
262 | normalizer = MotionNormalizerTorch()
263 |
264 | spatial_mean = normalizer.motion_mean[:66]
265 | spatial_std = normalizer.motion_std[:66]
266 |
267 | eval_motion_loaders = {}
268 | for cfg_path in cfg_path_list:
269 | model_cfg = get_config(cfg_path)
270 |
271 | device = torch.device('cuda:%d' % 0 if torch.cuda.is_available() else 'cpu')
272 | torch.cuda.set_device(0)
273 | model = build_models(model_cfg)
274 |
275 | if model_cfg.CHECKPOINT and isinstance(model_cfg.CHECKPOINT, list):
276 |
277 | ckpts = []
278 | for ckpt_path in model_cfg.CHECKPOINT:
279 | ckpt = torch.load(ckpt_path, map_location="cpu")
280 | for k in list(ckpt["state_dict"].keys()):
281 | if "model" in k:
282 | ckpt["state_dict"][k.replace("model.", "")] = ckpt["state_dict"].pop(k)
283 | ckpts.append(ckpt['state_dict'])
284 | model.load_state_dict(ckpts, strict=True)
285 | print("checkpoint state loaded!")
286 | elif model_cfg.CHECKPOINT and isinstance(model_cfg.CHECKPOINT,str):
287 |
288 | checkpoint = torch.load(model_cfg.CHECKPOINT, map_location=torch.device("cpu"))
289 | for k in list(checkpoint["state_dict"].keys()):
290 | if "model" in k:
291 | checkpoint["state_dict"][k.replace("model.", "")] = checkpoint["state_dict"].pop(k)
292 | model.load_state_dict(checkpoint['state_dict'], strict=True)
293 |
294 | eval_motion_loaders[model_cfg.NAME] = lambda model=model: get_motion_loader( # bind model now to avoid late binding in the loop
295 | batch_size,
296 | model,
297 | gt_dataset,
298 | device,
299 | mm_num_samples,
300 | mm_num_repeats
301 | )
302 |
303 | device = torch.device('cuda:%d' % 0 if torch.cuda.is_available() else 'cpu')
304 | gt_loader, gt_dataset = get_dataset_motion_loader(data_cfg, batch_size)
305 | evalmodel_cfg = get_config(args.evalmodel_config)
306 | eval_wrapper = EvaluatorModelWrapper(evalmodel_cfg, device)
307 |
308 | log_file = f'evaluation.log'
309 | evaluation(log_file)
310 |
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import torch
4 | # import cv2
5 | from PIL import Image
6 | import math
7 | import time
8 | import matplotlib.pyplot as plt
9 | from scipy.ndimage import gaussian_filter
10 | from common.quaternion import *
11 |
12 | from utils.rotation_conversions import *
13 |
14 | import logging
15 |
16 | face_joint_indx = [2,1,17,16]
17 | fid_l = [7,10]
18 | fid_r = [8,11]
19 |
20 |
21 | def swap_left_right_position(data):
22 | assert len(data.shape) == 3 and data.shape[-1] == 3
23 | data = data.copy()
24 | data[..., 0] *= -1
25 | right_chain = [2, 5, 8, 11, 14, 17, 19, 21]
26 | left_chain = [1, 4, 7, 10, 13, 16, 18, 20]
27 | left_hand_chain = [22, 23, 24, 34, 35, 36, 25, 26, 27, 31, 32, 33, 28, 29, 30, 52, 53, 54, 55, 56]
28 | right_hand_chain = [43, 44, 45, 46, 47, 48, 40, 41, 42, 37, 38, 39, 49, 50, 51, 57, 58, 59, 60, 61]
29 |
30 | tmp = data[:, right_chain]
31 | data[:, right_chain] = data[:, left_chain]
32 | data[:, left_chain] = tmp
33 | if data.shape[1] > 24:
34 | tmp = data[:, right_hand_chain]
35 | data[:, right_hand_chain] = data[:, left_hand_chain]
36 | data[:, left_hand_chain] = tmp
37 | return data
38 |
39 | def swap_left_right_rot(data):
40 | assert len(data.shape) == 3 and data.shape[-1] == 6
41 | data = data.copy()
42 |
43 | data[..., [1,2,4]] *= -1
44 |
45 | right_chain = np.array([2, 5, 8, 11, 14, 17, 19, 21])-1
46 | left_chain = np.array([1, 4, 7, 10, 13, 16, 18, 20])-1
47 | left_hand_chain = np.array([22, 23, 24, 34, 35, 36, 25, 26, 27, 31, 32, 33, 28, 29, 30,])-1
48 | right_hand_chain = np.array([43, 44, 45, 46, 47, 48, 40, 41, 42, 37, 38, 39, 49, 50, 51,])-1
49 |
50 | tmp = data[:, right_chain]
51 | data[:, right_chain] = data[:, left_chain]
52 | data[:, left_chain] = tmp
53 | if data.shape[1] > 24:
54 | tmp = data[:, right_hand_chain]
55 | data[:, right_hand_chain] = data[:, left_hand_chain]
56 | data[:, left_hand_chain] = tmp
57 | return data
58 |
59 |
60 | def swap_left_right(data, n_joints):
61 | T = data.shape[0]
62 | new_data = data.copy()
63 | positions = new_data[..., :3*n_joints].reshape(T, n_joints, 3)
64 | rotations = new_data[..., 3*n_joints:].reshape(T, -1, 6)
65 |
66 | positions = swap_left_right_position(positions)
67 | rotations = swap_left_right_rot(rotations)
68 |
69 | new_data = np.concatenate([positions.reshape(T, -1), rotations.reshape(T, -1)], axis=-1)
70 | return new_data
71 |
72 |
73 | def rigid_transform(relative, data): # apply the given relative rigid transform (a rotation about the y-axis plus an xz translation) to the global positions and global velocities in the data
74 |
75 | global_positions = data[..., :22 * 3].reshape(data.shape[:-1] + (22, 3))
76 | global_vel = data[..., 22 * 3:22 * 6].reshape(data.shape[:-1] + (22, 3))
77 |
78 | relative_rot = relative[0]
79 | relative_t = relative[1:3]
80 | relative_r_rot_quat = np.zeros(global_positions.shape[:-1] + (4,))
81 | relative_r_rot_quat[..., 0] = np.cos(relative_rot)
82 | relative_r_rot_quat[..., 2] = np.sin(relative_rot)
83 | global_positions = qrot_np(qinv_np(relative_r_rot_quat), global_positions)
84 | global_positions[..., [0, 2]] += relative_t
85 | data[..., :22 * 3] = global_positions.reshape(data.shape[:-1] + (-1,))
86 | global_vel = qrot_np(qinv_np(relative_r_rot_quat), global_vel)
87 | data[..., 22 * 3:22 * 6] = global_vel.reshape(data.shape[:-1] + (-1,))
88 |
89 | return data
90 |
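# --- Editor's sketch (not in the original file): rigid_transform rotates the 22 global joint
# --- positions and velocities about the y-axis (driven by relative[0] through a y-axis quaternion)
# --- and shifts them in the xz plane by relative[1:3]. Minimal call under those assumptions:
def _example_rigid_transform():
    import numpy as np
    data = np.zeros((4, 262))                     # 4 frames of the 262-dim representation
    data[:, 0:3] = [1.0, 0.0, 0.0]                # root joint at x = 1 in every frame
    relative = np.array([0.0, 2.0, -1.0])         # no rotation, translate x by +2 and z by -1
    out = rigid_transform(relative, data.copy())
    assert np.allclose(out[0, 0:3], [3.0, 0.0, -1.0])
    return out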
91 |
92 | class MotionNormalizer():
93 | def __init__(self):
94 | mean = np.load("./data/global_mean.npy")
95 | std = np.load("./data/global_std.npy")
96 |
97 | self.motion_mean = mean # (262,)
98 | self.motion_std = std # (262,)
99 |
100 |
101 | def forward(self, x):
102 |
103 | x = (x - self.motion_mean) / self.motion_std
104 | return x
105 |
106 | def backward(self, x):
107 | x = x * self.motion_std + self.motion_mean
108 | return x
109 |
110 |
111 |
112 | class MotionNormalizerTorch():
113 | def __init__(self):
114 | mean = np.load("./data/global_mean.npy") # shape=(262,)
115 | std = np.load("./data/global_std.npy")
116 |
117 | self.motion_mean = torch.from_numpy(mean).float()
118 | self.motion_std = torch.from_numpy(std).float()
119 |
120 |
121 | def forward(self, x):
122 | device = x.device
123 | x = x.clone()
124 | x = (x - self.motion_mean.to(device)) / self.motion_std.to(device)
125 | return x
126 |
127 | def backward(self, x, global_rt=False):
128 | device = x.device
129 | x = x.clone()
130 | x = x * self.motion_std.to(device) + self.motion_mean.to(device)
131 | return x
132 |
133 | trans_matrix = torch.Tensor([[1.0, 0.0, 0.0], # rotation about the x-axis (z-up -> y-up: y -> -z, z -> y)
134 | [0.0, 0.0, 1.0],
135 | [0.0, -1.0, 0.0]])
136 |
137 |
138 | def get_orientation(motion, prev_frames, n_joints):
139 | positions = motion[:, :n_joints*3].reshape(-1, n_joints, 3) # [95,22,3]
140 | root_pos_init = positions[prev_frames]
141 |
142 | r_hip, l_hip, sdr_r, sdr_l = face_joint_indx
143 | across = root_pos_init[r_hip] - root_pos_init[l_hip] # lateral vector of the body (from left hip to right hip)
144 | across = across / np.sqrt((across ** 2).sum(axis=-1))[..., np.newaxis] # normalize so that `across` is a unit direction
145 |
146 | # forward (3,), rotate around y-axis
147 | forward_init = np.cross(np.array([[0, 1, 0]]), across, axis=-1) # cross the up vector [0,1,0] with `across` to get the body's forward direction
148 | # forward (3,)
149 | forward_init = forward_init / np.sqrt((forward_init ** 2).sum(axis=-1))[..., np.newaxis] # normalize
150 | return forward_init
151 |
152 | def process_motion_np(motion, feet_thre, prev_frames, n_joints, target=np.array([[0, 0, 1]])):
153 |
154 | # (seq_len, joints_num, 3)
155 | # '''Down Sample'''
156 | # positions = positions[::ds_num]
157 |
158 | '''Uniform Skeleton'''
159 | # positions = uniform_skeleton(positions, tgt_offsets)
160 |
161 | positions = motion[:, :n_joints*3].reshape(-1, n_joints, 3) # [95,22,3]
162 | rotations = motion[:, n_joints*3:] # [95,126=21*6]
163 |
164 | positions = np.einsum("mn, tjn->tjm", trans_matrix, positions)
165 |
166 | '''Put on Floor'''
167 | floor_height = positions.min(axis=0).min(axis=0)[1] # find the lowest point across all frames and joints
168 | positions[:, :, 1] -= floor_height
169 |
170 |
171 | '''XZ at origin'''
172 | # move the root pos of the first frame to (0,0) of xz plane
173 | # for example
174 | # poistion_root_init = [2,3,5] -> [0,3,5]
175 |
176 | root_pos_init = positions[prev_frames] # [95,22,3] -> [22,3]
177 | root_pose_init_xz = root_pos_init[0] * np.array([1, 0, 1]) # [3]
178 | positions = positions - root_pose_init_xz # translate so the first-frame root joint sits at (0, y, 0), [95,22,3]
179 |
180 | '''All initially face Z+'''
181 | r_hip, l_hip, sdr_r, sdr_l = face_joint_indx
182 | across = root_pos_init[r_hip] - root_pos_init[l_hip] # lateral vector of the body (from left hip to right hip)
183 | across = across / np.sqrt((across ** 2).sum(axis=-1))[..., np.newaxis] # normalize so that `across` is a unit direction
184 |
185 | # forward (3,), rotate around y-axis
186 | forward_init = np.cross(np.array([[0, 1, 0]]), across, axis=-1) # cross the up vector [0,1,0] with `across` to get the body's forward direction
187 | # forward (3,)
188 | forward_init = forward_init / np.sqrt((forward_init ** 2).sum(axis=-1))[..., np.newaxis] # normalize
189 |
190 | # target = np.array([[0, 0, 1]]) # 目标朝向,即z轴正方向
191 | root_quat_init = qbetween_np(forward_init, target) # quaternion that rotates the initial forward direction onto the target direction
192 | root_quat_init_for_all = np.ones(positions.shape[:-1] + (4,)) * root_quat_init # broadcast the same quaternion to every frame and joint
193 |
194 |
195 | positions = qrot_np(root_quat_init_for_all, positions) # [95,22,3] rotate all joint positions so the body initially faces the target direction
196 |
197 | """ Get Foot Contacts """
198 |
199 | def foot_detect(positions, thres):
200 | velfactor, heightfactor = np.array([thres, thres]), np.array([0.12, 0.05])
201 |
202 | feet_l_x = (positions[1:, fid_l, 0] - positions[:-1, fid_l, 0]) ** 2
203 | feet_l_y = (positions[1:, fid_l, 1] - positions[:-1, fid_l, 1]) ** 2
204 | feet_l_z = (positions[1:, fid_l, 2] - positions[:-1, fid_l, 2]) ** 2
205 | feet_l_h = positions[:-1,fid_l,1]
206 | feet_l = (((feet_l_x + feet_l_y + feet_l_z) < velfactor) & (feet_l_h < heightfactor)).astype(np.float32)
207 |
208 | feet_r_x = (positions[1:, fid_r, 0] - positions[:-1, fid_r, 0]) ** 2
209 | feet_r_y = (positions[1:, fid_r, 1] - positions[:-1, fid_r, 1]) ** 2
210 | feet_r_z = (positions[1:, fid_r, 2] - positions[:-1, fid_r, 2]) ** 2
211 | feet_r_h = positions[:-1,fid_r,1]
212 | feet_r = (((feet_r_x + feet_r_y + feet_r_z) < velfactor) & (feet_r_h < heightfactor)).astype(np.float32)
213 | return feet_l, feet_r
214 | #
215 | feet_l, feet_r = foot_detect(positions, feet_thre) # [94,2], [94,2]
216 |
217 | '''Get Joint Rotation Representation'''
218 | rot_data = rotations # [95,126]
219 |
220 | '''Get Joint Rotation Invariant Position Represention'''
221 | joint_positions = positions.reshape(len(positions), -1) # [95,66]
222 | joint_vels = positions[1:] - positions[:-1] # [94,66]
223 | joint_vels = joint_vels.reshape(len(joint_vels), -1) # [94,66]
224 |
225 | data = joint_positions[:-1] # [94,66]
226 | data = np.concatenate([data, joint_vels], axis=-1) # [94,66+66]
227 | data = np.concatenate([data, rot_data[:-1]], axis=-1) # [94,66+66+126]
228 | data = np.concatenate([data, feet_l, feet_r], axis=-1) # [94,66+66+126+2+2]
229 |
230 | return data, root_quat_init, root_pose_init_xz[None]
231 |
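# --- Editor's sketch (not in the original file): shape walk-through of process_motion_np for a
# --- 22-joint skeleton. The output concatenates positions (66) + velocities (66) + 6D rotations
# --- (21*6 = 126) + foot contacts (4) = 262 channels, matching data/global_mean.npy.
def _example_process_motion_np():
    import numpy as np
    T, n_joints = 20, 22
    rng = np.random.default_rng(0)
    motion = rng.standard_normal((T, n_joints * 3 + (n_joints - 1) * 6)).astype(np.float32)  # [20, 192]
    data, root_quat_init, root_pos_init_xz = process_motion_np(motion, feet_thre=0.002,
                                                               prev_frames=0, n_joints=n_joints)
    assert data.shape == (T - 1, 262)   # one frame is consumed by the velocity difference
    return data, root_quat_init, root_pos_init_xz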
232 | def mkdir(path):
233 | if not os.path.exists(path):
234 | os.makedirs(path)
235 |
236 | COLORS = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0],
237 | [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255],
238 | [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]]
239 |
240 | MISSING_VALUE = -1
241 |
242 | def save_image(image_numpy, image_path):
243 | img_pil = Image.fromarray(image_numpy)
244 | img_pil.save(image_path)
245 |
246 |
247 | def save_logfile(log_loss, save_path):
248 | with open(save_path, 'wt') as f:
249 | for k, v in log_loss.items():
250 | w_line = k
251 | for digit in v:
252 | w_line += ' %.3f' % digit
253 | f.write(w_line + '\n')
254 |
255 |
256 | def print_current_loss(start_time, niter_state, losses, epoch=None, inner_iter=None, lr=None):
257 |
258 | def as_minutes(s):
259 | m = math.floor(s / 60)
260 | s -= m * 60
261 | return '%dm %ds' % (m, s)
262 |
263 | def time_since(since, percent):
264 | now = time.time()
265 | s = now - since
266 | es = s / percent
267 | rs = es - s
268 | return '%s (- %s)' % (as_minutes(s), as_minutes(rs))
269 |
270 | if epoch is not None and lr is not None :
271 | epoch_mes = 'epoch: %3d niter:%6d inner_iter:%4d lr:%5f' % (epoch, niter_state, inner_iter, lr)
272 | print(epoch_mes, end=" ")
273 | elif epoch is not None:
274 | epoch_mes = 'epoch: %3d niter:%6d inner_iter:%4d' % (epoch, niter_state, inner_iter)
275 | print(epoch_mes, end=" ")
276 |
277 | now = time.time()
278 | message = '%s'%(as_minutes(now - start_time))
279 |
280 | for k, v in losses.items():
281 | message += ' %s: %.4f ' % (k, v)
282 | print(message)
283 | logging.info((epoch_mes + ' ' if epoch is not None else '') + message)
284 |
285 |
286 | def compose_gif_img_list(img_list, fp_out, duration):
287 | img, *imgs = [Image.fromarray(np.array(image)) for image in img_list]
288 | img.save(fp=fp_out, format='GIF', append_images=imgs, optimize=False,
289 | save_all=True, loop=0, duration=duration)
290 |
291 |
292 | def save_images(visuals, image_path):
293 | if not os.path.exists(image_path):
294 | os.makedirs(image_path)
295 |
296 | for i, (label, img_numpy) in enumerate(visuals.items()):
297 | img_name = '%d_%s.jpg' % (i, label)
298 | save_path = os.path.join(image_path, img_name)
299 | save_image(img_numpy, save_path)
300 |
301 |
302 | def save_images_test(visuals, image_path, from_name, to_name):
303 | if not os.path.exists(image_path):
304 | os.makedirs(image_path)
305 |
306 | for i, (label, img_numpy) in enumerate(visuals.items()):
307 | img_name = "%s_%s_%s" % (from_name, to_name, label)
308 | save_path = os.path.join(image_path, img_name)
309 | save_image(img_numpy, save_path)
310 |
311 |
312 | def compose_and_save_img(img_list, save_dir, img_name, col=4, row=1, img_size=(256, 200)):
313 | # print(col, row)
314 | compose_img = compose_image(img_list, col, row, img_size)
315 | if not os.path.exists(save_dir):
316 | os.makedirs(save_dir)
317 | img_path = os.path.join(save_dir, img_name)
318 | # print(img_path)
319 | compose_img.save(img_path)
320 |
321 |
322 | def compose_image(img_list, col, row, img_size):
323 | to_image = Image.new('RGB', (col * img_size[0], row * img_size[1]))
324 | for y in range(0, row):
325 | for x in range(0, col):
326 | from_img = Image.fromarray(img_list[y * col + x])
327 | # print((x * img_size[0], y*img_size[1],
328 | # (x + 1) * img_size[0], (y + 1) * img_size[1]))
329 | paste_area = (x * img_size[0], y*img_size[1],
330 | (x + 1) * img_size[0], (y + 1) * img_size[1])
331 | to_image.paste(from_img, paste_area)
332 | # to_image[y*img_size[1]:(y + 1) * img_size[1], x * img_size[0] :(x + 1) * img_size[0]] = from_img
333 | return to_image
334 |
335 |
336 | def list_cut_average(ll, intervals):
337 | if intervals == 1:
338 | return ll
339 |
340 | bins = math.ceil(len(ll) * 1.0 / intervals)
341 | ll_new = []
342 | for i in range(bins):
343 | l_low = intervals * i
344 | l_high = l_low + intervals
345 | l_high = l_high if l_high < len(ll) else len(ll)
346 | ll_new.append(np.mean(ll[l_low:l_high]))
347 | return ll_new
348 |
349 |
350 | def motion_temporal_filter(motion, sigma=1):
351 | motion = motion.reshape(motion.shape[0], -1)
352 | # print(motion.shape)
353 | for i in range(motion.shape[1]):
354 | motion[:, i] = gaussian_filter(motion[:, i], sigma=sigma, mode="nearest")
355 | return motion.reshape(motion.shape[0], -1, 3)
356 |
357 |
--------------------------------------------------------------------------------
/common/quaternion.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2018-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | import torch
9 | import numpy as np
10 |
11 | _EPS4 = np.finfo(np.float32).eps * 4.0
12 |
13 | _FLOAT_EPS = np.finfo(np.float32).eps
14 |
15 | # PyTorch-backed implementations
16 | def qinv(q):
17 | assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)'
18 | mask = torch.ones_like(q)
19 | mask[..., 1:] = -mask[..., 1:]
20 | return q * mask
21 |
22 |
23 | def qinv_np(q):
24 | assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)'
25 | return qinv(torch.from_numpy(q).float()).numpy()
26 |
27 |
28 | def qnormalize(q):
29 | assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)'
30 | return q / torch.norm(q, dim=-1, keepdim=True)
31 |
32 |
33 | def qmul(q, r):
34 | """
35 | Multiply quaternion(s) q with quaternion(s) r.
36 | Expects two equally-sized tensors of shape (*, 4), where * denotes any number of dimensions.
37 | Returns q*r as a tensor of shape (*, 4).
38 | """
39 | assert q.shape[-1] == 4
40 | assert r.shape[-1] == 4
41 |
42 | original_shape = q.shape
43 |
44 | # Compute outer product
45 | terms = torch.bmm(r.view(-1, 4, 1), q.view(-1, 1, 4))
46 |
47 | w = terms[:, 0, 0] - terms[:, 1, 1] - terms[:, 2, 2] - terms[:, 3, 3]
48 | x = terms[:, 0, 1] + terms[:, 1, 0] - terms[:, 2, 3] + terms[:, 3, 2]
49 | y = terms[:, 0, 2] + terms[:, 1, 3] + terms[:, 2, 0] - terms[:, 3, 1]
50 | z = terms[:, 0, 3] - terms[:, 1, 2] + terms[:, 2, 1] + terms[:, 3, 0]
51 | return torch.stack((w, x, y, z), dim=1).view(original_shape)
52 |
53 |
54 | def qrot(q, v):
55 | """
56 | Rotate vector(s) v about the rotation described by quaternion(s) q.
57 | Expects a tensor of shape (*, 4) for q and a tensor of shape (*, 3) for v,
58 | where * denotes any number of dimensions.
59 | Returns a tensor of shape (*, 3).
60 | """
61 | assert q.shape[-1] == 4
62 | assert v.shape[-1] == 3
63 | assert q.shape[:-1] == v.shape[:-1]
64 |
65 | original_shape = list(v.shape)
66 | # print(q.shape)
67 | q = q.contiguous().view(-1, 4)
68 | v = v.contiguous().view(-1, 3)
69 |
70 | qvec = q[:, 1:]
71 | uv = torch.cross(qvec, v, dim=1)
72 | uuv = torch.cross(qvec, uv, dim=1)
73 | return (v + 2 * (q[:, :1] * uv + uuv)).view(original_shape)
74 |
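# --- Editor's sketch (not in the original file): composing two 90-degree rotations about z with
# --- qmul and applying them with qrot. Quaternions are (w, x, y, z) with the real part first.
def _example_qmul_qrot():
    s = 0.5 ** 0.5                                  # sin(45 deg) = cos(45 deg)
    q90z = torch.tensor([s, 0.0, 0.0, s])           # 90 degrees about +z
    q180z = qmul(q90z, q90z)                        # composition: 180 degrees about +z
    x_axis = torch.tensor([1.0, 0.0, 0.0])
    assert torch.allclose(qrot(q90z, x_axis), torch.tensor([0.0, 1.0, 0.0]), atol=1e-6)
    assert torch.allclose(qrot(q180z, x_axis), torch.tensor([-1.0, 0.0, 0.0]), atol=1e-6)
    return q180z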
75 |
76 | def qeuler(q, order, epsilon=0, deg=True):
77 | """
78 | Convert quaternion(s) q to Euler angles.
79 | Expects a tensor of shape (*, 4), where * denotes any number of dimensions.
80 | Returns a tensor of shape (*, 3).
81 | """
82 | assert q.shape[-1] == 4
83 |
84 | original_shape = list(q.shape)
85 | original_shape[-1] = 3
86 | q = q.view(-1, 4)
87 |
88 | q0 = q[:, 0]
89 | q1 = q[:, 1]
90 | q2 = q[:, 2]
91 | q3 = q[:, 3]
92 |
93 | if order == 'xyz':
94 | x = torch.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2))
95 | y = torch.asin(torch.clamp(2 * (q1 * q3 + q0 * q2), -1 + epsilon, 1 - epsilon))
96 | z = torch.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3))
97 | elif order == 'yzx':
98 | x = torch.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2 * (q1 * q1 + q3 * q3))
99 | y = torch.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2 * (q2 * q2 + q3 * q3))
100 | z = torch.asin(torch.clamp(2 * (q1 * q2 + q0 * q3), -1 + epsilon, 1 - epsilon))
101 | elif order == 'zxy':
102 | x = torch.asin(torch.clamp(2 * (q0 * q1 + q2 * q3), -1 + epsilon, 1 - epsilon))
103 | y = torch.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2 * (q1 * q1 + q2 * q2))
104 | z = torch.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2 * (q1 * q1 + q3 * q3))
105 | elif order == 'xzy':
106 | x = torch.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q3 * q3))
107 | y = torch.atan2(2 * (q0 * q2 + q1 * q3), 1 - 2 * (q2 * q2 + q3 * q3))
108 | z = torch.asin(torch.clamp(2 * (q0 * q3 - q1 * q2), -1 + epsilon, 1 - epsilon))
109 | elif order == 'yxz':
110 | x = torch.asin(torch.clamp(2 * (q0 * q1 - q2 * q3), -1 + epsilon, 1 - epsilon))
111 | y = torch.atan2(2 * (q1 * q3 + q0 * q2), 1 - 2 * (q1 * q1 + q2 * q2))
112 | z = torch.atan2(2 * (q1 * q2 + q0 * q3), 1 - 2 * (q1 * q1 + q3 * q3))
113 | elif order == 'zyx':
114 | x = torch.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2))
115 | y = torch.asin(torch.clamp(2 * (q0 * q2 - q1 * q3), -1 + epsilon, 1 - epsilon))
116 | z = torch.atan2(2 * (q0 * q3 + q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3))
117 | else:
118 | raise ValueError('Invalid rotation order: %s' % order)
119 |
120 | if deg:
121 | return torch.stack((x, y, z), dim=1).view(original_shape) * 180 / np.pi
122 | else:
123 | return torch.stack((x, y, z), dim=1).view(original_shape)
124 |
125 |
126 | # Numpy-backed implementations
127 |
128 | def qmul_np(q, r):
129 | q = torch.from_numpy(q).contiguous().float()
130 | r = torch.from_numpy(r).contiguous().float()
131 | return qmul(q, r).numpy()
132 |
133 |
134 | def qrot_np(q, v):
135 | q = torch.from_numpy(q).contiguous().float()
136 | v = torch.from_numpy(v).contiguous().float()
137 | return qrot(q, v).numpy()
138 |
139 |
140 | def qeuler_np(q, order, epsilon=0, use_gpu=False):
141 | if use_gpu:
142 | q = torch.from_numpy(q).cuda().float()
143 | return qeuler(q, order, epsilon).cpu().numpy()
144 | else:
145 | q = torch.from_numpy(q).contiguous().float()
146 | return qeuler(q, order, epsilon).numpy()
147 |
148 |
149 | def qfix(q):
150 | """
151 | Enforce quaternion continuity across the time dimension by selecting
152 | the representation (q or -q) with minimal distance (or, equivalently, maximal dot product)
153 | between two consecutive frames.
154 |
155 | Expects a tensor of shape (L, J, 4), where L is the sequence length and J is the number of joints.
156 | Returns a tensor of the same shape.
157 | """
158 | assert len(q.shape) == 3
159 | assert q.shape[-1] == 4
160 |
161 | result = q.copy()
162 | dot_products = np.sum(q[1:] * q[:-1], axis=2)
163 | mask = dot_products < 0
164 | mask = (np.cumsum(mask, axis=0) % 2).astype(bool)
165 | result[1:][mask] *= -1
166 | return result
167 |
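# --- Editor's sketch (not in the original file): qfix removes sign flips between consecutive
# --- frames (q and -q encode the same rotation). Flipping one frame's sign should be undone.
def _example_qfix():
    q = np.tile(np.array([[1.0, 0.0, 0.0, 0.0]]), (4, 1))[:, None, :]   # (L=4, J=1, 4), identity
    q[2] *= -1                                                          # artificial antipodal flip
    fixed = qfix(q)
    assert np.allclose(fixed, np.array([1.0, 0.0, 0.0, 0.0]))
    return fixed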
168 |
169 | def euler2quat(e, order, deg=True):
170 | """
171 | Convert Euler angles to quaternions.
172 | """
173 | assert e.shape[-1] == 3
174 |
175 | original_shape = list(e.shape)
176 | original_shape[-1] = 4
177 |
178 | e = e.view(-1, 3)
179 |
180 | ## if euler angles in degrees
181 | if deg:
182 | e = e * np.pi / 180.
183 |
184 | x = e[:, 0]
185 | y = e[:, 1]
186 | z = e[:, 2]
187 |
188 | rx = torch.stack((torch.cos(x / 2), torch.sin(x / 2), torch.zeros_like(x), torch.zeros_like(x)), dim=1)
189 | ry = torch.stack((torch.cos(y / 2), torch.zeros_like(y), torch.sin(y / 2), torch.zeros_like(y)), dim=1)
190 | rz = torch.stack((torch.cos(z / 2), torch.zeros_like(z), torch.zeros_like(z), torch.sin(z / 2)), dim=1)
191 |
192 | result = None
193 | for coord in order:
194 | if coord == 'x':
195 | r = rx
196 | elif coord == 'y':
197 | r = ry
198 | elif coord == 'z':
199 | r = rz
200 | else:
201 | raise ValueError('Invalid axis in rotation order: %s' % coord)
202 | if result is None:
203 | result = r
204 | else:
205 | result = qmul(result, r)
206 |
207 | # Reverse antipodal representation to have a non-negative "w"
208 | if order in ['xyz', 'yzx', 'zxy']:
209 | result *= -1
210 |
211 | return result.view(original_shape)
212 |
213 |
214 | def expmap_to_quaternion(e):
215 | """
216 | Convert axis-angle rotations (aka exponential maps) to quaternions.
217 | Stable formula from "Practical Parameterization of Rotations Using the Exponential Map".
218 | Expects a tensor of shape (*, 3), where * denotes any number of dimensions.
219 | Returns a tensor of shape (*, 4).
220 | """
221 | assert e.shape[-1] == 3
222 |
223 | original_shape = list(e.shape)
224 | original_shape[-1] = 4
225 | e = e.reshape(-1, 3)
226 |
227 | theta = np.linalg.norm(e, axis=1).reshape(-1, 1)
228 | w = np.cos(0.5 * theta).reshape(-1, 1)
229 | xyz = 0.5 * np.sinc(0.5 * theta / np.pi) * e
230 | return np.concatenate((w, xyz), axis=1).reshape(original_shape)
231 |
232 |
233 | def euler_to_quaternion(e, order):
234 | """
235 | Convert Euler angles to quaternions.
236 | """
237 | assert e.shape[-1] == 3
238 |
239 | original_shape = list(e.shape)
240 | original_shape[-1] = 4
241 |
242 | e = e.reshape(-1, 3)
243 |
244 | x = e[:, 0]
245 | y = e[:, 1]
246 | z = e[:, 2]
247 |
248 | rx = np.stack((np.cos(x / 2), np.sin(x / 2), np.zeros_like(x), np.zeros_like(x)), axis=1)
249 | ry = np.stack((np.cos(y / 2), np.zeros_like(y), np.sin(y / 2), np.zeros_like(y)), axis=1)
250 | rz = np.stack((np.cos(z / 2), np.zeros_like(z), np.zeros_like(z), np.sin(z / 2)), axis=1)
251 |
252 | result = None
253 | for coord in order:
254 | if coord == 'x':
255 | r = rx
256 | elif coord == 'y':
257 | r = ry
258 | elif coord == 'z':
259 | r = rz
260 | else:
261 | raise ValueError('Invalid axis in rotation order: %s' % coord)
262 | if result is None:
263 | result = r
264 | else:
265 | result = qmul_np(result, r)
266 |
267 | # Reverse antipodal representation to have a non-negative "w"
268 | if order in ['xyz', 'yzx', 'zxy']:
269 | result *= -1
270 |
271 | return result.reshape(original_shape)
272 |
273 |
274 | def quaternion_to_matrix(quaternions):
275 | """
276 | Convert rotations given as quaternions to rotation matrices.
277 | Args:
278 | quaternions: quaternions with real part first,
279 | as tensor of shape (..., 4).
280 | Returns:
281 | Rotation matrices as tensor of shape (..., 3, 3).
282 | """
283 | r, i, j, k = torch.unbind(quaternions, -1)
284 | two_s = 2.0 / (quaternions * quaternions).sum(-1)
285 |
286 | o = torch.stack(
287 | (
288 | 1 - two_s * (j * j + k * k),
289 | two_s * (i * j - k * r),
290 | two_s * (i * k + j * r),
291 | two_s * (i * j + k * r),
292 | 1 - two_s * (i * i + k * k),
293 | two_s * (j * k - i * r),
294 | two_s * (i * k - j * r),
295 | two_s * (j * k + i * r),
296 | 1 - two_s * (i * i + j * j),
297 | ),
298 | -1,
299 | )
300 | return o.reshape(quaternions.shape[:-1] + (3, 3))
301 |
302 |
303 | def quaternion_to_matrix_np(quaternions):
304 | q = torch.from_numpy(quaternions).contiguous().float()
305 | return quaternion_to_matrix(q).numpy()
306 |
307 |
308 | def quaternion_to_cont6d_np(quaternions):
309 | rotation_mat = quaternion_to_matrix_np(quaternions)
310 | cont_6d = np.concatenate([rotation_mat[..., 0], rotation_mat[..., 1]], axis=-1)
311 | return cont_6d
312 |
313 |
314 | def quaternion_to_cont6d(quaternions):
315 | rotation_mat = quaternion_to_matrix(quaternions)
316 | cont_6d = torch.cat([rotation_mat[..., 0], rotation_mat[..., 1]], dim=-1)
317 | return cont_6d
318 |
319 |
320 | def cont6d_to_matrix(cont6d):
321 | assert cont6d.shape[-1] == 6, "The last dimension must be 6"
322 | x_raw = cont6d[..., 0:3]
323 | y_raw = cont6d[..., 3:6]
324 |
325 | x = x_raw / torch.norm(x_raw, dim=-1, keepdim=True)
326 | z = torch.cross(x, y_raw, dim=-1)
327 | z = z / torch.norm(z, dim=-1, keepdim=True)
328 |
329 | y = torch.cross(z, x, dim=-1)
330 |
331 | x = x[..., None]
332 | y = y[..., None]
333 | z = z[..., None]
334 |
335 | mat = torch.cat([x, y, z], dim=-1)
336 | return mat
337 |
338 |
339 | def cont6d_to_matrix_np(cont6d):
340 | q = torch.from_numpy(cont6d).contiguous().float()
341 | return cont6d_to_matrix(q).numpy()
342 |
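# --- Editor's sketch (not in the original file): the continuous 6D representation above keeps the
# --- first two columns of the rotation matrix; cont6d_to_matrix re-orthonormalises them, so a
# --- quaternion -> 6D -> matrix round trip recovers quaternion_to_matrix.
def _example_cont6d_round_trip():
    q = qnormalize(torch.tensor([0.8, 0.1, -0.3, 0.5]))
    mat_direct = quaternion_to_matrix(q)
    mat_via_6d = cont6d_to_matrix(quaternion_to_cont6d(q))
    assert torch.allclose(mat_direct, mat_via_6d, atol=1e-5)
    return mat_via_6d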
343 |
344 | def qpow(q0, t, dtype=torch.float):
345 | ''' q0 : tensor of quaternions
346 | t: tensor of powers
347 | '''
348 | q0 = qnormalize(q0)
349 | theta0 = torch.acos(q0[..., 0])
350 |
351 | ## if theta0 is close to zero, add epsilon to avoid NaNs
352 | mask = (theta0 <= 10e-10) * (theta0 >= -10e-10)
353 | theta0 = (1 - mask) * theta0 + mask * 10e-10
354 | v0 = q0[..., 1:] / torch.sin(theta0).view(-1, 1)
355 |
356 | if isinstance(t, torch.Tensor):
357 | q = torch.zeros(t.shape + q0.shape)
358 | theta = t.view(-1, 1) * theta0.view(1, -1)
359 | else: ## if t is a number
360 | q = torch.zeros(q0.shape)
361 | theta = t * theta0
362 |
363 | q[..., 0] = torch.cos(theta)
364 | q[..., 1:] = v0 * torch.sin(theta).unsqueeze(-1)
365 |
366 | return q.to(dtype)
367 |
368 |
369 | def qslerp(q0, q1, t):
370 | '''
371 | q0: starting quaternion
372 | q1: ending quaternion
373 | t: array of points along the way
374 |
375 | Returns:
376 | Tensor of Slerps: t.shape + q0.shape
377 | '''
378 |
379 | q0 = qnormalize(q0)
380 | q1 = qnormalize(q1)
381 | q_ = qpow(qmul(q1, qinv(q0)), t)
382 |
383 | return qmul(q_,
384 | q0.contiguous().view(torch.Size([1] * len(t.shape)) + q0.shape).expand(t.shape + q0.shape).contiguous())
385 |
386 |
387 | def qbetween(v0, v1):
388 | '''
389 | find the quaternion used to rotate v0 to v1
390 | '''
391 | assert v0.shape[-1] == 3, 'v0 must be of the shape (*, 3)'
392 | assert v1.shape[-1] == 3, 'v1 must be of the shape (*, 3)'
393 |
394 | v = torch.cross(v0, v1)
395 | w = torch.sqrt((v0 ** 2).sum(dim=-1, keepdim=True) * (v1 ** 2).sum(dim=-1, keepdim=True)) + (v0 * v1).sum(dim=-1,
396 | keepdim=True)
397 | return qnormalize(torch.cat([w, v], dim=-1))
398 |
399 |
400 | def qbetween_np(v0, v1):
401 | '''
402 | find the quaternion used to rotate v0 to v1
403 | '''
404 | assert v0.shape[-1] == 3, 'v0 must be of the shape (*, 3)'
405 | assert v1.shape[-1] == 3, 'v1 must be of the shape (*, 3)'
406 |
407 | v0 = torch.from_numpy(v0).float()
408 | v1 = torch.from_numpy(v1).float()
409 | return qbetween(v0, v1).numpy()
410 |
411 |
412 | def lerp(p0, p1, t):
413 | if not isinstance(t, torch.Tensor):
414 | t = torch.Tensor([t])
415 |
416 | new_shape = t.shape + p0.shape
417 | new_view_t = t.shape + torch.Size([1] * len(p0.shape))
418 | new_view_p = torch.Size([1] * len(t.shape)) + p0.shape
419 | p0 = p0.view(new_view_p).expand(new_shape)
420 | p1 = p1.view(new_view_p).expand(new_shape)
421 | t = t.view(new_view_t).expand(new_shape)
422 |
423 | return p0 + t * (p1 - p0)
424 |
--------------------------------------------------------------------------------
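Note: the 6D-rotation helpers above (quaternion_to_cont6d / cont6d_to_matrix) keep the first two columns of the rotation matrix and rebuild the third via Gram-Schmidt. A minimal round-trip check, assuming the module is importable as utils.quaternion (the identical copy listed next):

import torch
from utils.quaternion import (qnormalize, quaternion_to_matrix,
                              quaternion_to_cont6d, cont6d_to_matrix)

q = qnormalize(torch.randn(8, 4))      # random unit quaternions, real part first
R = quaternion_to_matrix(q)            # (8, 3, 3)
six_d = quaternion_to_cont6d(q)        # first two matrix columns, (8, 6)
R_rec = cont6d_to_matrix(six_d)        # Gram-Schmidt reconstruction, (8, 3, 3)

# The reconstruction should match the original rotation up to numerical error.
assert torch.allclose(R, R_rec, atol=1e-5)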
/utils/quaternion.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2018-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | import torch
9 | import numpy as np
10 |
11 | _EPS4 = np.finfo(np.float32).eps * 4.0
12 |
13 | _FLOAT_EPS = np.finfo(np.float32).eps
14 |
15 | # PyTorch-backed implementations
16 | def qinv(q):
17 | assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)'
18 | mask = torch.ones_like(q)
19 | mask[..., 1:] = -mask[..., 1:]
20 | return q * mask
21 |
22 |
23 | def qinv_np(q):
24 | assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)'
25 | return qinv(torch.from_numpy(q).float()).numpy()
26 |
27 |
28 | def qnormalize(q):
29 | assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)'
30 | return q / torch.norm(q, dim=-1, keepdim=True)
31 |
32 |
33 | def qmul(q, r):
34 | """
35 | Multiply quaternion(s) q with quaternion(s) r.
36 | Expects two equally-sized tensors of shape (*, 4), where * denotes any number of dimensions.
37 | Returns q*r as a tensor of shape (*, 4).
38 | """
39 | assert q.shape[-1] == 4
40 | assert r.shape[-1] == 4
41 |
42 | original_shape = q.shape
43 |
44 | # Compute outer product
45 | terms = torch.bmm(r.view(-1, 4, 1), q.view(-1, 1, 4))
46 |
47 | w = terms[:, 0, 0] - terms[:, 1, 1] - terms[:, 2, 2] - terms[:, 3, 3]
48 | x = terms[:, 0, 1] + terms[:, 1, 0] - terms[:, 2, 3] + terms[:, 3, 2]
49 | y = terms[:, 0, 2] + terms[:, 1, 3] + terms[:, 2, 0] - terms[:, 3, 1]
50 | z = terms[:, 0, 3] - terms[:, 1, 2] + terms[:, 2, 1] + terms[:, 3, 0]
51 | return torch.stack((w, x, y, z), dim=1).view(original_shape)
52 |
53 |
54 | def qrot(q, v):
55 | """
56 | Rotate vector(s) v about the rotation described by quaternion(s) q.
57 | Expects a tensor of shape (*, 4) for q and a tensor of shape (*, 3) for v,
58 | where * denotes any number of dimensions.
59 | Returns a tensor of shape (*, 3).
60 | """
61 | assert q.shape[-1] == 4
62 | assert v.shape[-1] == 3
63 | assert q.shape[:-1] == v.shape[:-1]
64 |
65 | original_shape = list(v.shape)
66 | # print(q.shape)
67 | q = q.contiguous().view(-1, 4)
68 | v = v.contiguous().view(-1, 3)
69 |
70 | qvec = q[:, 1:]
71 | uv = torch.cross(qvec, v, dim=1)
72 | uuv = torch.cross(qvec, uv, dim=1)
73 | return (v + 2 * (q[:, :1] * uv + uuv)).view(original_shape)
74 |
75 |
76 | def qeuler(q, order, epsilon=0, deg=True):
77 | """
78 | Convert quaternion(s) q to Euler angles.
79 | Expects a tensor of shape (*, 4), where * denotes any number of dimensions.
80 | Returns a tensor of shape (*, 3).
81 | """
82 | assert q.shape[-1] == 4
83 |
84 | original_shape = list(q.shape)
85 | original_shape[-1] = 3
86 | q = q.view(-1, 4)
87 |
88 | q0 = q[:, 0]
89 | q1 = q[:, 1]
90 | q2 = q[:, 2]
91 | q3 = q[:, 3]
92 |
93 | if order == 'xyz':
94 | x = torch.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2))
95 | y = torch.asin(torch.clamp(2 * (q1 * q3 + q0 * q2), -1 + epsilon, 1 - epsilon))
96 | z = torch.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3))
97 | elif order == 'yzx':
98 | x = torch.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2 * (q1 * q1 + q3 * q3))
99 | y = torch.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2 * (q2 * q2 + q3 * q3))
100 | z = torch.asin(torch.clamp(2 * (q1 * q2 + q0 * q3), -1 + epsilon, 1 - epsilon))
101 | elif order == 'zxy':
102 | x = torch.asin(torch.clamp(2 * (q0 * q1 + q2 * q3), -1 + epsilon, 1 - epsilon))
103 | y = torch.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2 * (q1 * q1 + q2 * q2))
104 | z = torch.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2 * (q1 * q1 + q3 * q3))
105 | elif order == 'xzy':
106 | x = torch.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q3 * q3))
107 | y = torch.atan2(2 * (q0 * q2 + q1 * q3), 1 - 2 * (q2 * q2 + q3 * q3))
108 | z = torch.asin(torch.clamp(2 * (q0 * q3 - q1 * q2), -1 + epsilon, 1 - epsilon))
109 | elif order == 'yxz':
110 | x = torch.asin(torch.clamp(2 * (q0 * q1 - q2 * q3), -1 + epsilon, 1 - epsilon))
111 | y = torch.atan2(2 * (q1 * q3 + q0 * q2), 1 - 2 * (q1 * q1 + q2 * q2))
112 | z = torch.atan2(2 * (q1 * q2 + q0 * q3), 1 - 2 * (q1 * q1 + q3 * q3))
113 | elif order == 'zyx':
114 | x = torch.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2))
115 | y = torch.asin(torch.clamp(2 * (q0 * q2 - q1 * q3), -1 + epsilon, 1 - epsilon))
116 | z = torch.atan2(2 * (q0 * q3 + q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3))
117 | else:
118 |         raise ValueError(f'Unsupported rotation order: {order}')
119 |
120 | if deg:
121 | return torch.stack((x, y, z), dim=1).view(original_shape) * 180 / np.pi
122 | else:
123 | return torch.stack((x, y, z), dim=1).view(original_shape)
124 |
125 |
126 | # Numpy-backed implementations
127 |
128 | def qmul_np(q, r):
129 | q = torch.from_numpy(q).contiguous().float()
130 | r = torch.from_numpy(r).contiguous().float()
131 | return qmul(q, r).numpy()
132 |
133 |
134 | def qrot_np(q, v):
135 | q = torch.from_numpy(q).contiguous().float()
136 | v = torch.from_numpy(v).contiguous().float()
137 | return qrot(q, v).numpy()
138 |
139 |
140 | def qeuler_np(q, order, epsilon=0, use_gpu=False):
141 | if use_gpu:
142 | q = torch.from_numpy(q).cuda().float()
143 | return qeuler(q, order, epsilon).cpu().numpy()
144 | else:
145 | q = torch.from_numpy(q).contiguous().float()
146 | return qeuler(q, order, epsilon).numpy()
147 |
148 |
149 | def qfix(q):
150 | """
151 | Enforce quaternion continuity across the time dimension by selecting
152 | the representation (q or -q) with minimal distance (or, equivalently, maximal dot product)
153 | between two consecutive frames.
154 |
155 | Expects a tensor of shape (L, J, 4), where L is the sequence length and J is the number of joints.
156 | Returns a tensor of the same shape.
157 | """
158 | assert len(q.shape) == 3
159 | assert q.shape[-1] == 4
160 |
161 | result = q.copy()
162 | dot_products = np.sum(q[1:] * q[:-1], axis=2)
163 | mask = dot_products < 0
164 | mask = (np.cumsum(mask, axis=0) % 2).astype(bool)
165 | result[1:][mask] *= -1
166 | return result
167 |
168 |
169 | def euler2quat(e, order, deg=True):
170 | """
171 | Convert Euler angles to quaternions.
172 | """
173 | assert e.shape[-1] == 3
174 |
175 | original_shape = list(e.shape)
176 | original_shape[-1] = 4
177 |
178 | e = e.view(-1, 3)
179 |
180 | ## if euler angles in degrees
181 | if deg:
182 | e = e * np.pi / 180.
183 |
184 | x = e[:, 0]
185 | y = e[:, 1]
186 | z = e[:, 2]
187 |
188 | rx = torch.stack((torch.cos(x / 2), torch.sin(x / 2), torch.zeros_like(x), torch.zeros_like(x)), dim=1)
189 | ry = torch.stack((torch.cos(y / 2), torch.zeros_like(y), torch.sin(y / 2), torch.zeros_like(y)), dim=1)
190 | rz = torch.stack((torch.cos(z / 2), torch.zeros_like(z), torch.zeros_like(z), torch.sin(z / 2)), dim=1)
191 |
192 | result = None
193 | for coord in order:
194 | if coord == 'x':
195 | r = rx
196 | elif coord == 'y':
197 | r = ry
198 | elif coord == 'z':
199 | r = rz
200 | else:
201 |             raise ValueError(f'Unknown axis in order: {coord}')
202 | if result is None:
203 | result = r
204 | else:
205 | result = qmul(result, r)
206 |
207 | # Reverse antipodal representation to have a non-negative "w"
208 | if order in ['xyz', 'yzx', 'zxy']:
209 | result *= -1
210 |
211 | return result.view(original_shape)
212 |
213 |
214 | def expmap_to_quaternion(e):
215 | """
216 | Convert axis-angle rotations (aka exponential maps) to quaternions.
217 | Stable formula from "Practical Parameterization of Rotations Using the Exponential Map".
218 | Expects a tensor of shape (*, 3), where * denotes any number of dimensions.
219 | Returns a tensor of shape (*, 4).
220 | """
221 | assert e.shape[-1] == 3
222 |
223 | original_shape = list(e.shape)
224 | original_shape[-1] = 4
225 | e = e.reshape(-1, 3)
226 |
227 | theta = np.linalg.norm(e, axis=1).reshape(-1, 1)
228 | w = np.cos(0.5 * theta).reshape(-1, 1)
229 | xyz = 0.5 * np.sinc(0.5 * theta / np.pi) * e
230 | return np.concatenate((w, xyz), axis=1).reshape(original_shape)
231 |
232 |
233 | def euler_to_quaternion(e, order):
234 | """
235 | Convert Euler angles to quaternions.
236 | """
237 | assert e.shape[-1] == 3
238 |
239 | original_shape = list(e.shape)
240 | original_shape[-1] = 4
241 |
242 | e = e.reshape(-1, 3)
243 |
244 | x = e[:, 0]
245 | y = e[:, 1]
246 | z = e[:, 2]
247 |
248 | rx = np.stack((np.cos(x / 2), np.sin(x / 2), np.zeros_like(x), np.zeros_like(x)), axis=1)
249 | ry = np.stack((np.cos(y / 2), np.zeros_like(y), np.sin(y / 2), np.zeros_like(y)), axis=1)
250 | rz = np.stack((np.cos(z / 2), np.zeros_like(z), np.zeros_like(z), np.sin(z / 2)), axis=1)
251 |
252 | result = None
253 | for coord in order:
254 | if coord == 'x':
255 | r = rx
256 | elif coord == 'y':
257 | r = ry
258 | elif coord == 'z':
259 | r = rz
260 | else:
261 |             raise ValueError(f'Unknown axis in order: {coord}')
262 | if result is None:
263 | result = r
264 | else:
265 | result = qmul_np(result, r)
266 |
267 | # Reverse antipodal representation to have a non-negative "w"
268 | if order in ['xyz', 'yzx', 'zxy']:
269 | result *= -1
270 |
271 | return result.reshape(original_shape)
272 |
273 |
274 | def quaternion_to_matrix(quaternions):
275 | """
276 | Convert rotations given as quaternions to rotation matrices.
277 | Args:
278 | quaternions: quaternions with real part first,
279 | as tensor of shape (..., 4).
280 | Returns:
281 | Rotation matrices as tensor of shape (..., 3, 3).
282 | """
283 | r, i, j, k = torch.unbind(quaternions, -1)
284 | two_s = 2.0 / (quaternions * quaternions).sum(-1)
285 |
286 | o = torch.stack(
287 | (
288 | 1 - two_s * (j * j + k * k),
289 | two_s * (i * j - k * r),
290 | two_s * (i * k + j * r),
291 | two_s * (i * j + k * r),
292 | 1 - two_s * (i * i + k * k),
293 | two_s * (j * k - i * r),
294 | two_s * (i * k - j * r),
295 | two_s * (j * k + i * r),
296 | 1 - two_s * (i * i + j * j),
297 | ),
298 | -1,
299 | )
300 | return o.reshape(quaternions.shape[:-1] + (3, 3))
301 |
302 |
303 | def quaternion_to_matrix_np(quaternions):
304 | q = torch.from_numpy(quaternions).contiguous().float()
305 | return quaternion_to_matrix(q).numpy()
306 |
307 |
308 | def quaternion_to_cont6d_np(quaternions):
309 | rotation_mat = quaternion_to_matrix_np(quaternions)
310 | cont_6d = np.concatenate([rotation_mat[..., 0], rotation_mat[..., 1]], axis=-1)
311 | return cont_6d
312 |
313 |
314 | def quaternion_to_cont6d(quaternions):
315 | rotation_mat = quaternion_to_matrix(quaternions)
316 | cont_6d = torch.cat([rotation_mat[..., 0], rotation_mat[..., 1]], dim=-1)
317 | return cont_6d
318 |
319 |
320 | def cont6d_to_matrix(cont6d):
321 | assert cont6d.shape[-1] == 6, "The last dimension must be 6"
322 | x_raw = cont6d[..., 0:3]
323 | y_raw = cont6d[..., 3:6]
324 |
325 | x = x_raw / torch.norm(x_raw, dim=-1, keepdim=True)
326 | z = torch.cross(x, y_raw, dim=-1)
327 | z = z / torch.norm(z, dim=-1, keepdim=True)
328 |
329 | y = torch.cross(z, x, dim=-1)
330 |
331 | x = x[..., None]
332 | y = y[..., None]
333 | z = z[..., None]
334 |
335 | mat = torch.cat([x, y, z], dim=-1)
336 | return mat
337 |
338 |
339 | def cont6d_to_matrix_np(cont6d):
340 | q = torch.from_numpy(cont6d).contiguous().float()
341 | return cont6d_to_matrix(q).numpy()
342 |
343 |
344 | def qpow(q0, t, dtype=torch.float):
345 | ''' q0 : tensor of quaternions
346 | t: tensor of powers
347 | '''
348 | q0 = qnormalize(q0)
349 | theta0 = torch.acos(q0[..., 0])
350 |
351 | ## if theta0 is close to zero, add epsilon to avoid NaNs
352 |     mask = ((theta0 <= 10e-10) & (theta0 >= -10e-10)).float()  # float mask so (1 - mask) below is valid
353 | theta0 = (1 - mask) * theta0 + mask * 10e-10
354 | v0 = q0[..., 1:] / torch.sin(theta0).view(-1, 1)
355 |
356 | if isinstance(t, torch.Tensor):
357 | q = torch.zeros(t.shape + q0.shape)
358 | theta = t.view(-1, 1) * theta0.view(1, -1)
359 | else: ## if t is a number
360 | q = torch.zeros(q0.shape)
361 | theta = t * theta0
362 |
363 | q[..., 0] = torch.cos(theta)
364 | q[..., 1:] = v0 * torch.sin(theta).unsqueeze(-1)
365 |
366 | return q.to(dtype)
367 |
368 |
369 | def qslerp(q0, q1, t):
370 | '''
371 | q0: starting quaternion
372 | q1: ending quaternion
373 | t: array of points along the way
374 |
375 | Returns:
376 | Tensor of Slerps: t.shape + q0.shape
377 | '''
378 |
379 | q0 = qnormalize(q0)
380 | q1 = qnormalize(q1)
381 | q_ = qpow(qmul(q1, qinv(q0)), t)
382 |
383 | return qmul(q_,
384 | q0.contiguous().view(torch.Size([1] * len(t.shape)) + q0.shape).expand(t.shape + q0.shape).contiguous())
385 |
386 |
387 | def qbetween(v0, v1):
388 | '''
389 | find the quaternion used to rotate v0 to v1
390 | '''
391 | assert v0.shape[-1] == 3, 'v0 must be of the shape (*, 3)'
392 | assert v1.shape[-1] == 3, 'v1 must be of the shape (*, 3)'
393 |
394 | v = torch.cross(v0, v1)
395 | w = torch.sqrt((v0 ** 2).sum(dim=-1, keepdim=True) * (v1 ** 2).sum(dim=-1, keepdim=True)) + (v0 * v1).sum(dim=-1,
396 | keepdim=True)
397 | return qnormalize(torch.cat([w, v], dim=-1))
398 |
399 |
400 | def qbetween_np(v0, v1):
401 | '''
402 | find the quaternion used to rotate v0 to v1
403 | '''
404 | assert v0.shape[-1] == 3, 'v0 must be of the shape (*, 3)'
405 | assert v1.shape[-1] == 3, 'v1 must be of the shape (*, 3)'
406 |
407 | v0 = torch.from_numpy(v0).float()
408 | v1 = torch.from_numpy(v1).float()
409 | return qbetween(v0, v1).numpy()
410 |
411 |
412 | def lerp(p0, p1, t):
413 | if not isinstance(t, torch.Tensor):
414 | t = torch.Tensor([t])
415 |
416 | new_shape = t.shape + p0.shape
417 | new_view_t = t.shape + torch.Size([1] * len(p0.shape))
418 | new_view_p = torch.Size([1] * len(t.shape)) + p0.shape
419 | p0 = p0.view(new_view_p).expand(new_shape)
420 | p1 = p1.view(new_view_p).expand(new_shape)
421 | t = t.view(new_view_t).expand(new_shape)
422 |
423 | return p0 + t * (p1 - p0)
424 |
--------------------------------------------------------------------------------
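A short usage sketch for the interpolation helpers above (qbetween, qslerp, lerp); shapes follow the docstrings, and the import path assumes this file is used as utils.quaternion:

import torch
from utils.quaternion import qbetween, qslerp, lerp

# Quaternion (w, x, y, z) that rotates +X onto +Z.
v0 = torch.tensor([[1.0, 0.0, 0.0]])
v1 = torch.tensor([[0.0, 0.0, 1.0]])
q = qbetween(v0, v1)                          # (1, 4), already normalized

# Slerp from the identity quaternion to q at five interpolation points.
q_id = torch.tensor([[1.0, 0.0, 0.0, 0.0]])
t = torch.linspace(0.0, 1.0, 5)
q_interp = qslerp(q_id, q, t)                 # t.shape + q_id.shape = (5, 1, 4)

# Linear interpolation of positions with the same broadcasting convention.
p = lerp(torch.zeros(3), torch.ones(3), t)    # (5, 3)
print(q_interp.shape, p.shape)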
/models/nets.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Mapping
2 | import torch
3 |
4 | from models.utils import *
5 | from models.cfg_sampler import ClassifierFreeSampleModel
6 | from models.blocks import *
7 | from utils.utils import *
8 |
9 | from models.gaussian_diffusion import (
10 | # MotionDiffusion,
11 | MotionSpatialControlNetDiffusion,
12 | space_timesteps,
13 | get_named_beta_schedule,
14 | create_named_schedule_sampler,
15 | ModelMeanType,
16 | ModelVarType,
17 | LossType
18 | )
19 | import random
20 |
21 | class MotionEncoder(nn.Module):
22 | def __init__(self, cfg):
23 | super().__init__()
24 |
25 | self.cfg = cfg
26 | self.input_feats = cfg.INPUT_DIM
27 | self.latent_dim = cfg.LATENT_DIM
28 | self.ff_size = cfg.FF_SIZE
29 | self.num_layers = cfg.NUM_LAYERS
30 | self.num_heads = cfg.NUM_HEADS
31 | self.dropout = cfg.DROPOUT
32 | self.activation = cfg.ACTIVATION
33 |
34 | self.query_token = nn.Parameter(torch.randn(1, self.latent_dim))
35 |
36 | self.embed_motion = nn.Linear(self.input_feats*2, self.latent_dim)
37 | self.sequence_pos_encoder = PositionalEncoding(self.latent_dim, self.dropout, max_len=2000)
38 |
39 | seqTransEncoderLayer = nn.TransformerEncoderLayer(d_model=self.latent_dim,
40 | nhead=self.num_heads,
41 | dim_feedforward=self.ff_size,
42 | dropout=self.dropout,
43 | activation=self.activation,
44 | batch_first=True)
45 | self.transformer = nn.TransformerEncoder(seqTransEncoderLayer, num_layers=self.num_layers)
46 | self.out_ln = nn.LayerNorm(self.latent_dim)
47 | self.out = nn.Linear(self.latent_dim, 512)
48 |
49 |
50 | def forward(self, batch):
51 | x, mask = batch["motions"], batch["mask"]
52 | B, T, D = x.shape
53 |
54 | x = x.reshape(B, T, 2, -1)[..., :-4].reshape(B, T, -1)
55 |
56 | x_emb = self.embed_motion(x)
57 |
58 | emb = torch.cat([self.query_token[torch.zeros(B, dtype=torch.long, device=x.device)][:,None], x_emb], dim=1)
59 |
60 | seq_mask = (mask>0.5)
61 | token_mask = torch.ones((B, 1), dtype=bool, device=x.device)
62 | valid_mask = torch.cat([token_mask, seq_mask], dim=1)
63 |
64 | h = self.sequence_pos_encoder(emb)
65 | h = self.transformer(h, src_key_padding_mask=~valid_mask)
66 | h = self.out_ln(h)
67 | motion_emb = self.out(h[:,0])
68 |
69 | batch["motion_emb"] = motion_emb
70 |
71 | return batch
72 |
73 |
74 | class InterDenoiser(nn.Module):
75 | def __init__(self,
76 | input_feats,
77 | latent_dim=512,
78 | num_frames=240,
79 | ff_size=1024,
80 | num_layers=8,
81 | num_heads=8,
82 | dropout=0.1,
83 | activation="gelu",
84 | cfg_weight=0.,
85 | archi='single',
86 | return_intermediate=False,
87 | **kargs):
88 | super().__init__()
89 |
90 | self.cfg_weight = cfg_weight
91 | self.num_frames = num_frames
92 | self.latent_dim = latent_dim
93 | self.ff_size = ff_size
94 | self.num_layers = num_layers
95 | self.num_heads = num_heads
96 | self.dropout = dropout
97 | self.activation = activation
98 | self.input_feats = input_feats
99 | self.time_embed_dim = latent_dim
100 |
101 | self.archi = archi
102 |
103 | self.text_emb_dim = 768
104 | self.spatial_emb_dim = 4
105 |
106 | self.sequence_pos_encoder = PositionalEncoding(self.latent_dim, dropout=0)
107 | self.embed_timestep = TimestepEmbedder(self.latent_dim, self.sequence_pos_encoder)
108 |
109 | # Input Embedding
110 | self.motion_embed = nn.Linear(self.input_feats, self.latent_dim)
111 | self.text_embed = nn.Linear(self.text_emb_dim, self.latent_dim)
112 |
113 | self.return_intermediate = return_intermediate
114 |
115 | self.blocks = nn.ModuleList()
116 | for _ in range(num_layers):
117 | self.blocks.append(TransformerMotionGuidanceBlock(num_heads=num_heads,latent_dim=latent_dim, dropout=dropout, ff_size=ff_size))
118 | # self.blocks.append(TransformerBlock(num_heads=num_heads,latent_dim=latent_dim, dropout=dropout, ff_size=ff_size))
119 |
120 | # Output Module
121 | self.out = zero_module(FinalLayer(self.latent_dim, self.input_feats))
122 |
123 |
124 | def mask_motion_guidance(self, motion_mask, cond_mask_prob = 0.1, force_mask=False):
125 |         # This function masks the motion guidance with probability cond_mask_prob.
126 |         # The input motion_mask already has invalid time steps masked out.
127 | # mask.shape=[B,P,T]
128 | bs, guidance_num = motion_mask.shape[0], motion_mask.shape[1]-1
129 | if guidance_num == 0:
130 | return motion_mask.clone()
131 | elif force_mask:
132 | mask_new = motion_mask.clone()
133 | mask_new[:,1:,:] = 0
134 | return mask_new
135 | elif cond_mask_prob > 0.:
136 |
137 | mask_new = torch.bernoulli(torch.ones(bs,guidance_num, device=motion_mask.device) * cond_mask_prob).view(bs,guidance_num,1).contiguous() # 1-> use null_cond, 0-> use real cond
138 | mask_new = torch.cat([torch.zeros(bs,1,1, device=motion_mask.device), mask_new], dim=1)
139 |
140 | return motion_mask * (1. - mask_new)
141 | else:
142 | return motion_mask
143 |
144 | def forward(self, x, timesteps, mask=None, cond=None, motion_guidance=None, **kwargs): # spatial_condition=None, **kwargs):
145 | """
146 | x: B, T, D
147 | """
148 |
149 | B, T = x.shape[0], x.shape[1]
150 |
151 | if motion_guidance is None: # single motion generation
152 | x_cat = x.unsqueeze(1) # B,1,T,D
153 | elif self.return_intermediate: # control branch for two motion
154 | motion_guidance = motion_guidance.permute(0,2,1,3) # B,P,T,D
155 | x_cat = torch.cat([x.unsqueeze(1), motion_guidance], dim=1) # B,1,T,D, B,P,T,D -> B,P+1,T,D
156 | else: # main branch for two motion
157 | x_cat = x.unsqueeze(1) # B,1,T,D
158 |
159 | if mask is not None:
160 | mask = mask.permute(0,2,1).contiguous()[:,:x_cat.shape[1],...] # B,T,P -> B,P,T
161 | else:
162 | mask = torch.ones(B, x_cat.shape[1], T).to(x_cat.device)
163 |
164 | if x_cat.shape[1] > 1:
165 | mask = self.mask_motion_guidance(mask, 0.1)
166 |
167 | emb = self.embed_timestep(timesteps) + self.text_embed(cond) # [4],[4,768] -> [4,1024]
168 |
169 | x_cat_emb = self.motion_embed(x_cat) # B,P+1,T,D -> B,P+1,T,1024
170 |
171 | h_cat_prev = self.sequence_pos_encoder(x_cat_emb.view(-1, T, self.latent_dim)).view(B, -1, T, self.latent_dim)
172 |
173 | key_padding_mask = ~(mask > 0.5)
174 |
175 | h_cat_prev = h_cat_prev.view(B,-1,self.latent_dim) # [B,P*T,D]
176 | key_padding_mask = key_padding_mask.view(B,-1) # [B,P*T]
177 |
178 | if self.return_intermediate: # For control branch | spatial control branch
179 | intermediate_feats = []
180 |
181 | for idx ,block in enumerate(self.blocks):
182 |
183 | h_cat_prev = block(h_cat_prev, T, emb, key_padding_mask)
184 |
185 | if motion_guidance is None: # single motion generation
186 | pass
187 | elif self.return_intermediate: # control branch
188 | intermediate_feats.append(h_cat_prev.view(B,-1,T,self.latent_dim)[:,:1,...])
189 | else: # main branch for two motion
190 | h_cat_prev = h_cat_prev.view(B,-1,T,self.latent_dim)
191 | h_cat_prev = h_cat_prev + motion_guidance[idx]
192 | h_cat_prev = h_cat_prev.view(B,-1,self.latent_dim)
193 |
194 | if self.return_intermediate:
195 | return intermediate_feats
196 |
197 | h_cat_prev = h_cat_prev.view(B,-1,T,self.latent_dim)
198 | output = self.out(h_cat_prev)
199 |
200 |         return output[:,0,...] # only return the first person, to match the tensor shape expected by the diffusion process.
201 |
202 |
203 | class InterDenoiserSpatialControlNet(nn.Module):
204 | def __init__(self,
205 | input_feats,
206 | latent_dim=512,
207 | num_frames=240,
208 | ff_size=1024,
209 | num_layers=8,
210 | num_heads=8,
211 | dropout=0.1,
212 | activation="gelu",
213 | cfg_weight=0.,
214 | archi='single',
215 | **kargs):
216 | super().__init__()
217 |
218 | self.cfg_weight = cfg_weight
219 | self.num_frames = num_frames
220 | self.latent_dim = latent_dim
221 | self.ff_size = ff_size
222 | self.num_layers = num_layers
223 | self.num_heads = num_heads
224 | self.dropout = dropout
225 | self.activation = activation
226 | self.input_feats = input_feats
227 | self.time_embed_dim = latent_dim
228 |
229 | self.archi = archi
230 |
231 | if self.input_feats == 262 or self.input_feats == 263:
232 | self.n_joints = 22
233 | else:
234 | self.n_joints = 21
235 |
236 | self.net = InterDenoiser(self.input_feats, self.latent_dim, ff_size=self.ff_size, num_layers=self.num_layers,
237 | num_heads=self.num_heads, dropout=self.dropout, activation=self.activation,
238 | cfg_weight=self.cfg_weight, archi=self.archi, return_intermediate=False)
239 |
240 | if self.archi != 'single':
241 | self.control_branch = InterDenoiser(self.input_feats, self.latent_dim, ff_size=self.ff_size, num_layers=self.num_layers,
242 | num_heads=self.num_heads, dropout=self.dropout, activation=self.activation,
243 | cfg_weight=self.cfg_weight, archi=self.archi, return_intermediate=True)
244 |
245 | self.zero_linear = zero_module(nn.ModuleList([nn.Linear(self.latent_dim, self.latent_dim) for _ in range(self.num_layers)]))
246 |
247 | set_requires_grad(self.net, False)
248 |
249 |
250 | def forward(self, x, timesteps, mask=None, cond=None, motion_guidance=None, spatial_condition=None, **kwargs):
251 |
252 | """
253 | x: B, T, D
254 | spatial_condition: B, T, D
255 | """
256 |
257 | if motion_guidance is not None:
258 | intermediate_feats = self.control_branch(x, timesteps, mask=mask, cond=cond, motion_guidance=motion_guidance, spatial_condition=None,**kwargs)
259 | intermediate_feats = [self.zero_linear[i](intermediate_feats[i]) for i in range(len(intermediate_feats))]
260 | else:
261 | intermediate_feats = None
262 |
263 | output = self.net(x, timesteps, mask=mask, cond=cond, motion_guidance=intermediate_feats,**kwargs) #spatial_condition=None,**kwargs)
264 |
265 | return output
266 |
267 |
268 | class InterDiffusionSpatialControlNet(nn.Module):
269 | def __init__(self, cfg, sampling_strategy="ddim50"):
270 | super().__init__()
271 | self.cfg = cfg
272 | self.archi = cfg.ARCHI
273 | self.nfeats = cfg.INPUT_DIM
274 | self.latent_dim = cfg.LATENT_DIM
275 | self.ff_size = cfg.FF_SIZE
276 | self.num_layers = cfg.NUM_LAYERS
277 | self.num_heads = cfg.NUM_HEADS
278 | self.dropout = cfg.DROPOUT
279 | self.activation = cfg.ACTIVATION
280 | self.motion_rep = cfg.MOTION_REP
281 |
282 | self.cfg_weight = cfg.CFG_WEIGHT
283 | self.diffusion_steps = cfg.DIFFUSION_STEPS
284 | self.beta_scheduler = cfg.BETA_SCHEDULER
285 | self.sampler = cfg.SAMPLER
286 | self.sampling_strategy = sampling_strategy
287 |
288 | self.net = InterDenoiserSpatialControlNet(self.nfeats, self.latent_dim, ff_size=self.ff_size, num_layers=self.num_layers,
289 | num_heads=self.num_heads, dropout=self.dropout, activation=self.activation, cfg_weight=self.cfg_weight, archi=self.archi)
290 |
291 | self.diffusion_steps = self.diffusion_steps
292 | self.betas = get_named_beta_schedule(self.beta_scheduler, self.diffusion_steps)
293 |
294 | timestep_respacing=[self.diffusion_steps]
295 | self.diffusion = MotionSpatialControlNetDiffusion(
296 | use_timesteps=space_timesteps(self.diffusion_steps, timestep_respacing),
297 | betas=self.betas,
298 | motion_rep=self.motion_rep,
299 | model_mean_type=ModelMeanType.START_X,
300 | model_var_type=ModelVarType.FIXED_SMALL,
301 | loss_type=LossType.MSE,
302 | rescale_timesteps = False,
303 | archi= self.archi,
304 | )
305 | self.sampler = create_named_schedule_sampler(self.sampler, self.diffusion)
306 |
307 | def mask_cond(self, cond, cond_mask_prob = 0.1, force_mask=False):
308 | bs = cond.shape[0]
309 | if force_mask:
310 | return torch.zeros_like(cond)
311 | elif cond_mask_prob > 0.:
312 | mask = torch.bernoulli(torch.ones(bs, device=cond.device) * cond_mask_prob).view([bs]+[1]*len(cond.shape[1:])) # 1-> use null_cond, 0-> use real cond
313 | return cond * (1. - mask), (1. - mask)
314 | else:
315 | return cond, None
316 |
317 | def generate_src_mask(self, T, length, person_num, B=0):
318 | if B==0:
319 | B = length.shape[0]
320 | else:
321 | if len(length.shape) == 1:
322 | length = torch.cat([length]*B, dim=0)
323 | src_mask = torch.ones(B, T, person_num)
324 | for p in range(person_num):
325 | for i in range(B):
326 | for j in range(length[i], T):
327 | src_mask[i, j, p] = 0
328 | return src_mask
329 |
330 |     def compute_loss(self, batch): # this function appears to be used during the training stage
331 | cond = batch["cond"]
332 | x_and_xCondition = batch["motions"] # [4,300,524]
333 | B,T = batch["motions"].shape[:2]
334 |
335 | if cond is not None:
336 | cond, cond_mask = self.mask_cond(cond, 0.1)
337 |
338 | seq_mask = self.generate_src_mask(batch["motions"].shape[1], batch["motion_lens"], batch["person_num"]).to(x_and_xCondition.device)
339 |
340 | t, _ = self.sampler.sample(B, x_and_xCondition.device)
341 |
342 | output = self.diffusion.training_losses(
343 | model=self.net,
344 | x_start=x_and_xCondition,
345 | t=t,
346 | mask=seq_mask,
347 | t_bar=self.cfg.T_BAR,
348 | cond_mask=cond_mask,
349 | model_kwargs={"mask":seq_mask,
350 | "cond":cond,
351 | "person_num":batch["person_num"],
352 | # "spatial_condition":batch["spatial_condition"],
353 | },
354 | )
355 | return output
356 |
357 | def forward(self, batch):
358 |
359 | cond = batch["cond"]
360 |
361 | motion_guidance = batch["motion_guidance"]
362 | # spatial_condition = batch["spatial_condition"]
363 |
364 | if motion_guidance is not None:
365 | B, T = motion_guidance.shape[:2]
366 |         else: # T equals the valid motion length, so every entry of the mask is valid.
367 | B = cond.shape[0]
368 | T = batch["motion_lens"].item()
369 |
370 | # add valid time length mask
371 | seq_mask = self.generate_src_mask(T, batch["motion_lens"], batch["person_num"], B=cond.shape[0]).to(batch["motion_lens"].device)
372 |
373 | timestep_respacing= self.sampling_strategy
374 | self.diffusion_test = MotionSpatialControlNetDiffusion( # MotionSpatialControlNetDiffusion
375 | use_timesteps=space_timesteps(self.diffusion_steps, timestep_respacing),
376 | betas=self.betas,
377 | motion_rep=self.motion_rep,
378 | model_mean_type=ModelMeanType.START_X,
379 | model_var_type=ModelVarType.FIXED_SMALL,
380 | loss_type=LossType.MSE,
381 | rescale_timesteps = False,
382 | archi= self.archi,
383 | )
384 |
385 | self.cfg_model = ClassifierFreeSampleModel(self.net, self.cfg_weight) # cfg_weight=3.5
386 |
387 | output = self.diffusion_test.ddim_sample_loop(
388 | self.cfg_model,
389 | (B,T,self.nfeats),
390 | clip_denoised=False,
391 | progress=True,
392 | model_kwargs={
393 | "mask":seq_mask, # None,
394 | "cond":cond,
395 | "motion_guidance":motion_guidance,
396 | # "spatial_condition": spatial_condition
397 | },
398 | x_start=None)
399 |
400 | return {"output":output} # output: [batch_size, n_ctx, 2*d_model]=[1,210,524]
401 |
402 |
403 |
404 |
405 |
--------------------------------------------------------------------------------
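The two-branch design above follows the ControlNet recipe: the main InterDenoiser is frozen (set_requires_grad(self.net, False)), a trainable copy (control_branch) consumes the motion guidance and returns per-layer intermediate features, and zero-initialized linear layers inject those features back into the frozen branch so the control signal starts as a no-op. A stripped-down, self-contained sketch of that wiring; ToyBlock, the dimensions, and the toy forward pass are illustrative only and not taken from this repository:

import torch
import torch.nn as nn

def zero_module(module):
    # Same idea as the zero_module helper used above: zero every parameter so the
    # control path initially contributes nothing to the frozen branch.
    for p in module.parameters():
        nn.init.zeros_(p)
    return module

class ToyBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.ff = nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim))

    def forward(self, x):
        return x + self.ff(x)

class ToyControlNet(nn.Module):
    def __init__(self, dim=64, num_layers=4):
        super().__init__()
        self.main = nn.ModuleList([ToyBlock(dim) for _ in range(num_layers)])
        self.control = nn.ModuleList([ToyBlock(dim) for _ in range(num_layers)])
        self.zero_linear = nn.ModuleList(
            [zero_module(nn.Linear(dim, dim)) for _ in range(num_layers)])
        for p in self.main.parameters():      # freeze the main branch, train the rest
            p.requires_grad_(False)

    def forward(self, x, guidance=None):
        feats = []
        if guidance is not None:              # control branch consumes the guidance
            h = guidance
            for block, zero_linear in zip(self.control, self.zero_linear):
                h = block(h)
                feats.append(zero_linear(h))
        h = x
        for i, block in enumerate(self.main):
            h = block(h)                      # frozen main branch
            if feats:
                h = h + feats[i]              # per-layer residual injection
        return h

x = torch.randn(2, 16, 64)                    # (batch, frames, latent)
out = ToyControlNet()(x, guidance=torch.randn(2, 16, 64))
print(out.shape)                              # torch.Size([2, 16, 64])

In the repository the injected features come from TransformerMotionGuidanceBlock layers and the zero-initialized linears live in InterDenoiserSpatialControlNet.zero_linear; the toy version only keeps the residual-injection pattern.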
/datasets/evaluator.py:
--------------------------------------------------------------------------------
1 | from os.path import join as pjoin
2 | from torch.utils.data import Dataset, DataLoader
3 | from datasets import InterHumanPipelineInferDataset
4 | from models import *
5 | import copy
6 | from datasets.evaluator_models import InterCLIP
7 | from tqdm import tqdm
8 | import torch
9 |
10 | from utils.quaternion import *
11 |
12 | normalizer = MotionNormalizerTorch()
13 |
14 | def process_motion_np(motion, feet_thre, prev_frames, n_joints, target=np.array([[0, 0, 1]])):
15 |
16 | positions = motion[:, :n_joints*3].reshape(-1, n_joints, 3) # [95,22,3]
17 |
18 | vels = motion[:, n_joints*3:n_joints*6].reshape(-1, n_joints, 3) # [95,22,3]
19 |
20 | '''XZ at origin'''
21 | # move the root pos of the first frame to (0,0) of xz plane
22 | # for example
23 |     # position_root_init = [2,3,5] -> [0,3,0]
24 |
25 | root_pos_init = positions[prev_frames] # [95,22,3] -> [22,3]
26 | root_pose_init_xz = root_pos_init[0] * np.array([1, 0, 1]) # [3]
27 | positions = positions - root_pose_init_xz
28 |
29 | '''All initially face Z+'''
30 | r_hip, l_hip, sdr_r, sdr_l = face_joint_indx
31 | across = root_pos_init[r_hip] - root_pos_init[l_hip]
32 | across = across / np.sqrt((across ** 2).sum(axis=-1))[..., np.newaxis]
33 |
34 | # forward (3,), rotate around y-axis
35 | forward_init = np.cross(np.array([[0, 1, 0]]), across, axis=-1)
36 | # forward (3,)
37 |     forward_init = forward_init / np.sqrt((forward_init ** 2).sum(axis=-1))[..., np.newaxis] # normalize
38 |
39 | # target = np.array([[0, 0, 1]])
40 | root_quat_init = qbetween_np(forward_init, target)
41 | root_quat_init_for_all = np.ones(positions.shape[:-1] + (4,)) * root_quat_init
42 |
43 | positions = qrot_np(root_quat_init_for_all, positions)
44 |
45 | vels = qrot_np(root_quat_init_for_all, vels)
46 |
47 |     '''Get Joint Rotation Invariant Position Representation'''
48 | joint_positions = positions.reshape(len(positions), -1) # [95,66]
49 | joint_vels = vels.reshape(len(vels), -1) # [95,66]
50 |
51 | motion[:, :n_joints*3] = joint_positions
52 | motion[:, n_joints*3:n_joints*6] = joint_vels
53 |
54 | return motion, root_quat_init, root_pose_init_xz[None]
55 |
56 | def rigid_transform(relative, data):
57 |
58 | global_positions = data[..., :22 * 3].reshape(data.shape[:-1] + (22, 3))
59 | global_vel = data[..., 22 * 3:22 * 6].reshape(data.shape[:-1] + (22, 3))
60 |
61 | relative_rot = relative[0]
62 | relative_t = relative[1:3]
63 | relative_r_rot_quat = np.zeros(global_positions.shape[:-1] + (4,))
64 | relative_r_rot_quat[..., 0] = np.cos(relative_rot)
65 | relative_r_rot_quat[..., 2] = np.sin(relative_rot)
66 | global_positions = qrot_np(qinv_np(relative_r_rot_quat), global_positions)
67 | global_positions[..., [0, 2]] += relative_t
68 | data[..., :22 * 3] = global_positions.reshape(data.shape[:-1] + (-1,))
69 | global_vel = qrot_np(qinv_np(relative_r_rot_quat), global_vel)
70 | data[..., 22 * 3:22 * 6] = global_vel.reshape(data.shape[:-1] + (-1,))
71 |
72 | return data
73 |
74 |
75 | class EvaluationDataset(Dataset):
76 |
77 | def __init__(self, model, dataset, device, mm_num_samples, mm_num_repeats):
78 |
79 | self.normalizer = MotionNormalizer()
80 | self.device = device
81 | self.model = model.to(device)
82 | self.model.eval()
83 | dataloader = DataLoader(dataset, batch_size=1, num_workers=0, shuffle=True)
84 | self.max_length = dataset.max_length
85 |
86 | idxs = list(range(len(dataset)))
87 | random.shuffle(idxs)
88 | mm_idxs = idxs[:mm_num_samples]
89 |
90 | generated_motions = []
91 | mm_generated_motions = []
92 | # Pre-process all target captions
93 | with torch.no_grad():
94 | for i, data in tqdm(enumerate(dataloader)):
95 | name, text, text1, text2, motion1, motion2, motion_lens, hint = data
96 | batch = {}
97 |
98 | if min(hint.shape) == 0:
99 | hint = None
100 |
101 | if i in mm_idxs:
102 | text1 = list(text1) * mm_num_repeats
103 | text2 = list(text2) * mm_num_repeats
104 | text = list(text) * mm_num_repeats
105 |
106 | batch["motion_lens"] = motion_lens
107 |
108 | model_name = self.model.__class__.__name__
109 | motions_output = self.pipeline_generate(self.model, motion_lens, [text1,text2], text, [motion1,motion2], 1, hint) # [1,27,2,262]
110 |
111 | motions_output = self.normalizer.backward(motions_output.cpu().detach().numpy())
112 |
113 | B,T = motions_output.shape[0], motions_output.shape[1]
114 | if T < self.max_length:
115 | padding_len = self.max_length - T
116 | D = motions_output.shape[-1]
117 | padding_zeros = np.zeros((B, padding_len, 2, D))
118 | motions_output = np.concatenate((motions_output, padding_zeros), axis=1)
119 | assert motions_output.shape[1] == self.max_length
120 |
121 | sub_dict = {'motion1': motions_output[0, :,0],
122 | 'motion2': motions_output[0, :,1],
123 | 'motion_lens': motion_lens[0],
124 | 'text': text[0]}
125 |
126 | if hint is not None:
127 | sub_dict['spatial_condition'] = hint[0]
128 | # else:
129 | # sub_dict['hint'] = None
130 | generated_motions.append(sub_dict)
131 | if i in mm_idxs:
132 | mm_sub_dict = {'mm_motions': motions_output,
133 | 'motion_lens': motion_lens[0],
134 | 'text': text[0]}
135 | mm_generated_motions.append(mm_sub_dict)
136 |
137 |
138 | self.generated_motions = generated_motions
139 | self.mm_generated_motions = mm_generated_motions
140 |
141 | def __len__(self):
142 | return len(self.generated_motions)
143 |
144 | def __getitem__(self, item):
145 | data = self.generated_motions[item]
146 | motion1, motion2, motion_lens, text = data['motion1'], data['motion2'], data['motion_lens'], data['text']
147 | hint = data['spatial_condition'] if 'spatial_condition' in data else torch.zeros(0)
148 | return "generated", text, "placeholder", "placeholder", motion1, motion2, motion_lens, hint
149 |
150 | def pipeline_generate(self, model, motion_lens, texts, text_multi_person, motions, FLAG=0, hint=None):
151 |
152 | def generate(motion_lens, text, text_multi_person, person_num=1, motion_guidance=None, hint=None):
153 | T = motion_lens
154 | batch = {}
155 | batch["prompt"] = list(text)
156 | batch["text"] = list(text)
157 | batch["text_multi_person"] = list(text_multi_person)
158 | batch["person_num"] = person_num
159 | batch["motion_lens"] = T # torch.tensor([T]).unsqueeze(0).long().to(torch.device("cuda:0"))
160 | batch["motion_guidance"] = motion_guidance
161 |
162 | if hint is not None:
163 | hint = hint[:,:T,...]
164 | batch["spatial_condition"] = hint
165 | batch = model.forward_test(batch)
166 | output = batch["output"].reshape(batch["output"].shape[0], batch["output"].shape[1], 1, -1)
167 | return output
168 |
169 | generated_motions = []
170 | motion = generate(motion_lens.to(self.device), texts[0], text_multi_person)
171 |
172 | generated_motions.append(motion)
173 |
174 | if FLAG != 0:
175 | for person_idx, text in enumerate(texts[1:]):
176 | tmp_motions = [ m.detach().float().to(self.device) for m in generated_motions]
177 | motion_guidance = torch.cat(tmp_motions, dim=-2)
178 | motion = generate(motion_lens.to(self.device), text, text_multi_person, person_idx+2, motion_guidance, None) #hint.to(self.device))
179 | generated_motions.append(motion)
180 | return torch.cat(generated_motions, dim=-2) # [:motion_lens.item()] # [B,T,P,D]
181 |
182 | def intergen_generate(self, model, motion_lens, texts, text_multi_person, motions, person_num=0):
183 |
184 | def generate(motion_lens, texts, text_multi_person, person_num=1):
185 | T = motion_lens
186 | batch = {}
187 | batch["text1"] = list(texts[0])
188 | batch["text2"] = list(texts[1])
189 |
190 | batch["text"] = list(text_multi_person)
191 | batch["person_num"] = person_num
192 | batch["motion_lens"] = T
193 |
194 | batch = model.forward_test(batch)
195 | output = batch["output"].reshape(batch["output"].shape[0], batch["output"].shape[1], person_num, -1)
196 | return output
197 |
198 | motion = generate(motion_lens.to(self.device), texts, text_multi_person, person_num=person_num)
199 |
200 | return motion
201 |
202 | class MMGeneratedDataset(Dataset):
203 | def __init__(self, motion_dataset):
204 | self.dataset = motion_dataset.mm_generated_motions
205 |
206 | def __len__(self):
207 | return len(self.dataset)
208 |
209 | def __getitem__(self, item):
210 | data = self.dataset[item]
211 | mm_motions = data['mm_motions']
212 | motion_lens = data['motion_lens']
213 | mm_motions1 = mm_motions[:,:,0]
214 | mm_motions2 = mm_motions[:,:,1]
215 | text = data['text']
216 | motion_lens = np.array([motion_lens]*mm_motions1.shape[0])
217 | return "mm_generated", text, "placeholder", "placeholder", mm_motions1, mm_motions2, motion_lens, torch.zeros(0)
218 |
219 |
220 | def get_dataset_motion_loader(opt, batch_size):
221 | opt = copy.deepcopy(opt)
222 |     # Configurations of the T2M dataset and the KIT dataset are almost the same
223 |
224 | if opt.NAME == 'interhuman':
225 | print('Loading dataset %s ...' % opt.NAME)
226 | dataset = InterHumanPipelineInferDataset(opt)
227 | dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=0, drop_last=True, shuffle=True)
228 | else:
229 | raise KeyError('Dataset not Recognized !!')
230 |
231 | print('Ground Truth Dataset Loading Completed!!!')
232 | return dataloader, dataset
233 |
234 | def get_motion_loader(batch_size, model, ground_truth_dataset, device, mm_num_samples, mm_num_repeats):
235 |
236 | dataset = EvaluationDataset(model, ground_truth_dataset, device, mm_num_samples=mm_num_samples, mm_num_repeats=mm_num_repeats)
237 | mm_dataset = MMGeneratedDataset(dataset)
238 |
239 | motion_loader = DataLoader(dataset, batch_size=batch_size, drop_last=True, num_workers=0, shuffle=True)
240 | mm_motion_loader = DataLoader(mm_dataset, batch_size=1, num_workers=0)
241 |
242 | print('Generated Dataset Loading Completed!!!')
243 |
244 | return motion_loader, mm_motion_loader
245 |
246 |
247 |
248 |
249 | def build_models(cfg):
250 | model = InterCLIP(cfg)
251 | checkpoint = torch.load(pjoin('eval_model/interclip.ckpt'),map_location="cpu")
252 |
253 | # checkpoint = torch.load(pjoin('checkpoints/interclip/model/5.ckpt'),map_location="cpu")
254 | for k in list(checkpoint["state_dict"].keys()):
255 | if "model" in k:
256 | checkpoint["state_dict"][k.replace("model.", "")] = checkpoint["state_dict"].pop(k)
257 | model.load_state_dict(checkpoint["state_dict"], strict=True)
258 |
259 | # print('Loading Evaluation Model Wrapper (Epoch %d) Completed!!' % (checkpoint['epoch']))
260 | return model
261 |
262 |
263 | class EvaluatorModelWrapper(object):
264 |
265 | def __init__(self, cfg, device):
266 |
267 | self.model = build_models(cfg)
268 | self.cfg = cfg
269 | self.device = device
270 |
271 | self.model = self.model.to(device)
272 | self.model.eval()
273 |
274 |
275 |     # Please note that the results do not follow the order of the inputs
276 | def get_co_embeddings(self, batch_data):
277 | with torch.no_grad():
278 | name, text, text1, text2, motion1, motion2, motion_lens, _ = batch_data
279 |
280 | motion1 = motion1.detach().float() # .to(self.device)
281 | motion2 = motion2.detach().float() # .to(self.device)
282 | motions = torch.cat([motion1, motion2], dim=-1)
283 | motions = motions.detach().to(self.device).float()
284 |
285 | align_idx = np.argsort(motion_lens.data.tolist())[::-1].copy()
286 | motions = motions[align_idx]
287 | motion_lens = motion_lens[align_idx]
288 | text = list(text)
289 |
290 | B, T = motions.shape[:2]
291 | cur_len = torch.LongTensor([min(T, m_len) for m_len in motion_lens]).to(self.device)
292 | padded_len = cur_len.max()
293 |
294 | batch = {}
295 | batch["text"] = text
296 | batch["motions"] = motions.reshape(B, T, -1)[:, :padded_len]
297 | batch["motion_lens"] = motion_lens
298 |
299 | '''Motion Encoding'''
300 | motion_embedding = self.model.encode_motion(batch)['motion_emb']
301 |
302 | '''Text Encoding'''
303 | text_embedding = self.model.encode_text(batch)['text_emb'][align_idx]
304 |
305 | return text_embedding, motion_embedding
306 |
307 |     # Please note that the results do not follow the order of the inputs
308 | def get_motion_embeddings(self, batch_data):
309 | with torch.no_grad():
310 | name, text, text1, text2, motion1, motion2, motion_lens, _ = batch_data
311 | motion1 = motion1.detach().float() # .to(self.device)
312 | motion2 = motion2.detach().float() # .to(self.device)
313 | motions = torch.cat([motion1, motion2], dim=-1)
314 | motions = motions.detach().to(self.device).float()
315 |
316 | align_idx = np.argsort(motion_lens.data.tolist())[::-1].copy()
317 | motions = motions[align_idx]
318 | motion_lens = motion_lens[align_idx]
319 | text = list(text)
320 |
321 | B, T = motions.shape[:2]
322 | cur_len = torch.LongTensor([min(T, m_len) for m_len in motion_lens]).to(self.device)
323 | padded_len = cur_len.max()
324 |
325 | batch = {}
326 | batch["text"] = text
327 | batch["motions"] = motions.reshape(B, T, -1)[:, :padded_len]
328 | batch["motion_lens"] = motion_lens
329 |
330 | '''Motion Encoding'''
331 | motion_embedding = self.model.encode_motion(batch)['motion_emb']
332 |
333 | return motion_embedding
334 |
335 |
336 | class TrainEvaluatorModelWrapper(object):
337 |
338 | def __init__(self, cfg, device):
339 |
340 | self.model = build_models(cfg)
341 | self.cfg = cfg
342 | self.device = device
343 |
344 | self.model = self.model.to(device)
345 | # self.model.eval()
346 |
347 |
348 |     # Please note that the results do not follow the order of the inputs
349 | def get_co_embeddings(self, batch_data):
350 | name, text, motion1, motion2, motion_lens = batch_data
351 | motion1 = motion1.detach().float() # .to(self.device)
352 | motion2 = motion2.detach().float() # .to(self.device)
353 | motions = torch.cat([motion1, motion2], dim=-1)
354 | motions = motions.detach().to(self.device).float()
355 | align_idx = np.argsort(motion_lens.data.tolist())[::-1].copy()
356 | motions = motions[align_idx]
357 | motion_lens = motion_lens[align_idx]
358 | text = list(text)
359 | B, T = motions.shape[:2]
360 | cur_len = torch.LongTensor([min(T, m_len) for m_len in motion_lens]).to(self.device)
361 | padded_len = cur_len.max()
362 | batch = {}
363 | batch["text"] = text
364 | batch["motions"] = motions.reshape(B, T, -1)[:, :padded_len]
365 | batch["motion_lens"] = motion_lens
366 | '''Motion Encoding'''
367 | motion_embedding = self.model.encode_motion(batch)['motion_emb']
368 | '''Text Encoding'''
369 | text_embedding = self.model.encode_text(batch)['text_emb'][align_idx]
370 |
371 | return text_embedding, motion_embedding
372 |
373 |     # Please note that the results do not follow the order of the inputs
374 | def get_motion_embeddings(self, batch_data):
375 | with torch.no_grad():
376 | name, text, motion1, motion2, motion_lens = batch_data
377 | motion1 = motion1.detach().float() # .to(self.device)
378 | motion2 = motion2.detach().float() # .to(self.device)
379 | motions = torch.cat([motion1, motion2], dim=-1)
380 | motions = motions.detach().to(self.device).float()
381 |
382 | align_idx = np.argsort(motion_lens.data.tolist())[::-1].copy()
383 | motions = motions[align_idx]
384 | motion_lens = motion_lens[align_idx]
385 | text = list(text)
386 |
387 | B, T = motions.shape[:2]
388 | cur_len = torch.LongTensor([min(T, m_len) for m_len in motion_lens]).to(self.device)
389 | padded_len = cur_len.max()
390 |
391 | batch = {}
392 | batch["text"] = text
393 | batch["motions"] = motions.reshape(B, T, -1)[:, :padded_len]
394 | batch["motion_lens"] = motion_lens
395 |
396 | '''Motion Encoding'''
397 | motion_embedding = self.model.encode_motion(batch)['motion_emb']
398 |
399 | return motion_embedding
400 |
401 | def compute_contrastive_loss(self, batch_data):
402 |
403 | loss_total, losses = self.model(batch_data)
404 |
405 | return loss_total, losses
406 |
--------------------------------------------------------------------------------
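The wrapper classes above only produce paired text/motion embeddings (both are re-ordered by the same align_idx, so pairs stay aligned even though the batch order changes). A sketch of how such embeddings are commonly consumed for an R-precision style metric downstream; the helper names here are illustrative and not taken from this repository:

import numpy as np

def euclidean_distance_matrix(a, b):
    # (N, D) x (M, D) -> (N, M) pairwise Euclidean distances
    return np.linalg.norm(a[:, None, :] - b[None, :, :], axis=-1)

def r_precision(text_emb, motion_emb, top_k=3):
    # For each text, check whether its own motion is among the top_k nearest motions.
    dist = euclidean_distance_matrix(text_emb, motion_emb)
    ranks = np.argsort(dist, axis=1)
    gt = np.arange(len(text_emb))[:, None]
    return (ranks[:, :top_k] == gt).any(axis=1).mean()

text_emb = np.random.randn(32, 512)
motion_emb = text_emb + 0.05 * np.random.randn(32, 512)   # toy, well-matched pairs
print(r_precision(text_emb, motion_emb, top_k=3))          # ~1.0 for this toy data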
/models/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from utils.utils import *
5 |
6 | kinematic_chain = [[0, 2, 5, 8, 11],
7 | [0, 1, 4, 7, 10],
8 | [0, 3, 6, 9, 12, 15],
9 | [9, 14, 17, 19, 21],
10 | [9, 13, 16, 18, 20]]
11 |
12 | class InterLoss(nn.Module):
13 | def __init__(self, recons_loss, nb_joints):
14 | super(InterLoss, self).__init__()
15 | self.nb_joints = nb_joints
16 | if recons_loss == 'l1':
17 | self.Loss = torch.nn.L1Loss(reduction='none')
18 | elif recons_loss == 'l2':
19 | self.Loss = torch.nn.MSELoss(reduction='none')
20 | elif recons_loss == 'l1_smooth':
21 | self.Loss = torch.nn.SmoothL1Loss(reduction='none')
22 |
23 | self.normalizer = MotionNormalizerTorch()
24 |
25 | self.weights = {}
26 | self.weights["RO"] = 0.01
27 | self.weights["JA"] = 3
28 | self.weights["DM"] = 3
29 |
30 | self.losses = {}
31 |
32 | def seq_masked_mse(self, prediction, target, mask):
33 | loss = self.Loss(prediction, target).mean(dim=-1, keepdim=True)
34 | loss = (loss * mask).sum() / (mask.sum() + 1.e-7)
35 | return loss
36 |
37 | def mix_masked_mse(self, prediction, target, mask, batch_mask, contact_mask=None, dm_mask=None):
38 | if dm_mask is not None:
39 | loss = (self.Loss(prediction, target) * dm_mask).sum(dim=-1, keepdim=True)/ (dm_mask.sum(dim=-1, keepdim=True) + 1.e-7)
40 | else:
41 | loss = self.Loss(prediction, target).mean(dim=-1, keepdim=True) # [b,t,p,4,1]
42 | if contact_mask is not None:
43 | loss = (loss[..., 0] * contact_mask).sum(dim=-1, keepdim=True) / (contact_mask.sum(dim=-1, keepdim=True) + 1.e-7)
44 | loss = (loss * mask).sum(dim=(-1, -2, -3)) / (mask.sum(dim=(-1, -2, -3)) + 1.e-7) # [b]
45 | loss = (loss * batch_mask).sum(dim=0) / (batch_mask.sum(dim=0) + 1.e-7)
46 |
47 | return loss
48 |
49 | def forward(self, motion_pred, motion_gt, mask, timestep_mask):
50 | B, T = motion_pred.shape[:2]
51 | self.losses["simple"] = self.seq_masked_mse(motion_pred, motion_gt, mask)
52 | target = self.normalizer.backward(motion_gt, global_rt=True)
53 | prediction = self.normalizer.backward(motion_pred, global_rt=True)
54 |
55 | self.pred_g_joints = prediction[..., :self.nb_joints * 3].reshape(B, T, -1, self.nb_joints, 3)
56 | self.tgt_g_joints = target[..., :self.nb_joints * 3].reshape(B, T, -1, self.nb_joints, 3)
57 |
58 | self.mask = mask
59 | self.timestep_mask = timestep_mask
60 |
61 | self.forward_distance_map(thresh=1)
62 | self.forward_joint_affinity(thresh=0.1)
63 |         self.forward_relative_rot()
64 | self.accum_loss()
65 |
66 |
67 |     def forward_relative_rot(self):
68 |
69 | r_hip, l_hip, sdr_r, sdr_l = face_joint_indx
70 | across = self.pred_g_joints[..., r_hip, :] - self.pred_g_joints[..., l_hip, :]
71 | across = across / across.norm(dim=-1, keepdim=True)
72 | across_gt = self.tgt_g_joints[..., r_hip, :] - self.tgt_g_joints[..., l_hip, :]
73 | across_gt = across_gt / across_gt.norm(dim=-1, keepdim=True)
74 |
75 | y_axis = torch.zeros_like(across)
76 | y_axis[..., 1] = 1
77 |
78 |         forward = torch.cross(y_axis, across, dim=-1)
79 | forward = forward / forward.norm(dim=-1, keepdim=True)
80 |
81 |         forward_gt = torch.cross(y_axis, across_gt, dim=-1)
82 | forward_gt = forward_gt / forward_gt.norm(dim=-1, keepdim=True)
83 |
84 | pred_relative_rot = qbetween(forward[..., 0, :], forward[..., 1, :])
85 | tgt_relative_rot = qbetween(forward_gt[..., 0, :], forward_gt[..., 1, :])
86 |
87 | self.losses["RO"] = self.mix_masked_mse(pred_relative_rot[..., [0, 2]],
88 | tgt_relative_rot[..., [0, 2]],
89 | self.mask[..., 0, :], self.timestep_mask) * self.weights["RO"]
90 |
91 |
92 | def forward_distance_map(self, thresh):
93 | pred_g_joints = self.pred_g_joints.reshape(self.mask.shape[:-1] + (-1,))
94 | tgt_g_joints = self.tgt_g_joints.reshape(self.mask.shape[:-1] + (-1,))
95 |
96 | pred_g_joints1 = pred_g_joints[..., 0:1, :].reshape(-1, self.nb_joints, 3)
97 | pred_g_joints2 = pred_g_joints[..., 1:2, :].reshape(-1, self.nb_joints, 3)
98 | tgt_g_joints1 = tgt_g_joints[..., 0:1, :].reshape(-1, self.nb_joints, 3)
99 | tgt_g_joints2 = tgt_g_joints[..., 1:2, :].reshape(-1, self.nb_joints, 3)
100 |
101 | pred_distance_matrix = torch.cdist(pred_g_joints1.contiguous(), pred_g_joints2).reshape(
102 | self.mask.shape[:-2] + (1, -1,))
103 | tgt_distance_matrix = torch.cdist(tgt_g_joints1.contiguous(), tgt_g_joints2).reshape(
104 | self.mask.shape[:-2] + (1, -1,))
105 |
106 | distance_matrix_mask = (pred_distance_matrix < thresh).float()
107 |
108 | self.losses["DM"] = self.mix_masked_mse(pred_distance_matrix, tgt_distance_matrix,
109 | self.mask[..., 0:1, :],
110 | self.timestep_mask, dm_mask=distance_matrix_mask) * self.weights["DM"]
111 |
112 | def forward_joint_affinity(self, thresh):
113 | pred_g_joints = self.pred_g_joints.reshape(self.mask.shape[:-1] + (-1,))
114 | tgt_g_joints = self.tgt_g_joints.reshape(self.mask.shape[:-1] + (-1,))
115 |
116 | pred_g_joints1 = pred_g_joints[..., 0:1, :].reshape(-1, self.nb_joints, 3)
117 | pred_g_joints2 = pred_g_joints[..., 1:2, :].reshape(-1, self.nb_joints, 3)
118 | tgt_g_joints1 = tgt_g_joints[..., 0:1, :].reshape(-1, self.nb_joints, 3)
119 | tgt_g_joints2 = tgt_g_joints[..., 1:2, :].reshape(-1, self.nb_joints, 3)
120 |
121 | pred_distance_matrix = torch.cdist(pred_g_joints1.contiguous(), pred_g_joints2).reshape(
122 | self.mask.shape[:-2] + (1, -1,))
123 | tgt_distance_matrix = torch.cdist(tgt_g_joints1.contiguous(), tgt_g_joints2).reshape(
124 | self.mask.shape[:-2] + (1, -1,))
125 |
126 | distance_matrix_mask = (tgt_distance_matrix < thresh).float()
127 |
128 | self.losses["JA"] = self.mix_masked_mse(pred_distance_matrix, torch.zeros_like(tgt_distance_matrix),
129 | self.mask[..., 0:1, :],
130 | self.timestep_mask, dm_mask=distance_matrix_mask) * self.weights["JA"]
131 |
132 | def accum_loss(self):
133 | loss = 0
134 | for term in self.losses.keys():
135 | loss += self.losses[term]
136 | self.losses["total"] = loss
137 | return self.losses
138 |
139 |
140 | class SingleLoss(nn.Module):
141 | def __init__(self, recons_loss, nb_joints):
142 | super(SingleLoss, self).__init__()
143 | self.nb_joints = nb_joints
144 | if recons_loss == 'l1':
145 | self.Loss = torch.nn.L1Loss(reduction='none')
146 | elif recons_loss == 'l2':
147 | self.Loss = torch.nn.MSELoss(reduction='none')
148 | elif recons_loss == 'l1_smooth':
149 | self.Loss = torch.nn.SmoothL1Loss(reduction='none')
150 |
151 | self.normalizer = MotionNormalizerTorch()
152 |
153 | self.losses = {}
154 |
155 | def seq_masked_mse(self, prediction, target, mask):
156 | loss = self.Loss(prediction, target).mean(dim=-1, keepdim=True)
157 | loss = (loss * mask).sum() / (mask.sum() + 1.e-7)
158 | return loss
159 |
160 | def forward(self, motion_pred, motion_gt, mask, timestep_mask):
161 | B, T = motion_pred.shape[:2]
162 | self.losses["simple"] = self.seq_masked_mse(motion_pred, motion_gt, mask)
163 | target = self.normalizer.backward(motion_gt, global_rt=True)
164 | prediction = self.normalizer.backward(motion_pred, global_rt=True)
165 |
166 | self.pred_g_joints = prediction[..., :self.nb_joints * 3].reshape(B, T, -1, self.nb_joints, 3)
167 | self.tgt_g_joints = target[..., :self.nb_joints * 3].reshape(B, T, -1, self.nb_joints, 3)
168 |
169 | self.mask = mask
170 | self.timestep_mask = timestep_mask
171 |
172 | self.accum_loss()
173 |
174 | def accum_loss(self):
175 | loss = 0
176 | for term in self.losses.keys():
177 | loss += self.losses[term]
178 | self.losses["total"] = loss
179 | return self.losses
180 |
181 |
182 | class GeometricLoss(nn.Module):
183 | def __init__(self, recons_loss, nb_joints, name):
184 | super(GeometricLoss, self).__init__()
185 | self.name = name
186 | self.nb_joints = nb_joints
187 | if recons_loss == 'l1':
188 | self.Loss = torch.nn.L1Loss(reduction='none')
189 | elif recons_loss == 'l2':
190 | self.Loss = torch.nn.MSELoss(reduction='none')
191 | elif recons_loss == 'l1_smooth':
192 | self.Loss = torch.nn.SmoothL1Loss(reduction='none')
193 |
194 | self.normalizer = MotionNormalizerTorch()
195 | self.fids = [7, 10, 8, 11]
196 |
197 | self.weights = {}
198 | self.weights["VEL"] = 30
199 | self.weights["BL"] = 10
200 | self.weights["FC"] = 30
201 | self.weights["POSE"] = 1
202 | self.weights["TR"] = 100
203 |
204 | self.losses = {}
205 |
206 | def seq_masked_mse(self, prediction, target, mask):
207 | loss = self.Loss(prediction, target).mean(dim=-1, keepdim=True)
208 | loss = (loss * mask).sum() / (mask.sum() + 1.e-7)
209 | return loss
210 |
211 | def mix_masked_mse(self, prediction, target, mask, batch_mask, contact_mask=None, dm_mask=None):
212 | if dm_mask is not None:
213 | loss = (self.Loss(prediction, target) * dm_mask).sum(dim=-1, keepdim=True)/ (dm_mask.sum(dim=-1, keepdim=True) + 1.e-7) # [b,t,p,4,1]
214 | else:
215 | loss = self.Loss(prediction, target).mean(dim=-1, keepdim=True) # [b,t,p,4,1]
216 | if contact_mask is not None:
217 | loss = (loss[..., 0] * contact_mask).sum(dim=-1, keepdim=True) / (contact_mask.sum(dim=-1, keepdim=True) + 1.e-7)
218 | loss = (loss * mask).sum(dim=(-1, -2)) / (mask.sum(dim=(-1, -2)) + 1.e-7) # [b]
219 | loss = (loss * batch_mask).sum(dim=0) / (batch_mask.sum(dim=0) + 1.e-7)
220 |
221 | return loss
222 |
223 | def forward(self, motion_pred, motion_gt, mask, timestep_mask):
224 | B, T = motion_pred.shape[:2]
225 | # self.losses["simple"] = self.seq_masked_mse(motion_pred, motion_gt, mask) # * 0.01
226 | target = self.normalizer.backward(motion_gt, global_rt=True)
227 | prediction = self.normalizer.backward(motion_pred, global_rt=True)
228 |
229 | self.first_motion_pred =motion_pred[:,0:1]
230 | self.first_motion_gt =motion_gt[:,0:1]
231 |
232 | self.pred_g_joints = prediction[..., :self.nb_joints * 3].reshape(B, T, self.nb_joints, 3)
233 | self.tgt_g_joints = target[..., :self.nb_joints * 3].reshape(B, T, self.nb_joints, 3)
234 | self.mask = mask
235 | self.timestep_mask = timestep_mask
236 |
237 | self.forward_vel()
238 | self.forward_bone_length()
239 | self.forward_contact()
240 | self.accum_loss()
241 | # return self.losses["simple"]
242 |
243 | def get_local_positions(self, positions, r_rot):
244 | '''Local pose'''
245 | positions[..., 0] -= positions[..., 0:1, 0]
246 | positions[..., 2] -= positions[..., 0:1, 2]
247 |         '''All poses face Z+'''
248 | positions = qrot(r_rot[..., None, :].repeat(1, 1, positions.shape[-2], 1), positions)
249 | return positions
250 |
251 | def forward_local_pose(self):
252 | r_hip, l_hip, sdr_r, sdr_l = face_joint_indx
253 |
254 | pred_g_joints = self.pred_g_joints.clone()
255 | tgt_g_joints = self.tgt_g_joints.clone()
256 |
257 | across = pred_g_joints[..., r_hip, :] - pred_g_joints[..., l_hip, :]
258 | across = across / across.norm(dim=-1, keepdim=True)
259 | across_gt = tgt_g_joints[..., r_hip, :] - tgt_g_joints[..., l_hip, :]
260 | across_gt = across_gt / across_gt.norm(dim=-1, keepdim=True)
261 |
262 | y_axis = torch.zeros_like(across)
263 | y_axis[..., 1] = 1
264 |
265 |         forward = torch.cross(y_axis, across, dim=-1)
266 |         forward = forward / forward.norm(dim=-1, keepdim=True)
267 |         forward_gt = torch.cross(y_axis, across_gt, dim=-1)
268 | forward_gt = forward_gt / forward_gt.norm(dim=-1, keepdim=True)
269 |
270 | z_axis = torch.zeros_like(forward)
271 | z_axis[..., 2] = 1
272 |         noise = torch.randn_like(z_axis) * 0.0001
273 |         z_axis = z_axis + noise
274 | z_axis = z_axis / z_axis.norm(dim=-1, keepdim=True)
275 |
276 |
277 | pred_rot = qbetween(forward, z_axis)
278 | tgt_rot = qbetween(forward_gt, z_axis)
279 |
280 | B, T, J, D = self.pred_g_joints.shape
281 | pred_joints = self.get_local_positions(pred_g_joints, pred_rot).reshape(B, T, -1)
282 | tgt_joints = self.get_local_positions(tgt_g_joints, tgt_rot).reshape(B, T, -1)
283 |
284 | self.losses["POSE_"+self.name] = self.mix_masked_mse(pred_joints, tgt_joints, self.mask, self.timestep_mask) * self.weights["POSE"]
285 |
286 | def forward_vel(self):
287 | B, T = self.pred_g_joints.shape[:2]
288 |
289 | pred_vel = self.pred_g_joints[:, 1:] - self.pred_g_joints[:, :-1]
290 | tgt_vel = self.tgt_g_joints[:, 1:] - self.tgt_g_joints[:, :-1]
291 |
292 | pred_vel = pred_vel.reshape(pred_vel.shape[:-2] + (-1,))
293 | tgt_vel = tgt_vel.reshape(tgt_vel.shape[:-2] + (-1,))
294 |
295 | self.losses["VEL_"+self.name] = self.mix_masked_mse(pred_vel, tgt_vel, self.mask[:, :-1], self.timestep_mask) * self.weights["VEL"]
296 |
297 |
298 | def forward_contact(self):
299 |
300 | feet_vel = self.pred_g_joints[:, 1:, self.fids, :] - self.pred_g_joints[:, :-1, self.fids,:]
301 | feet_h = self.pred_g_joints[:, :-1, self.fids, 1]
302 | # contact = target[:,:-1,:,-8:-4] # [b,t,p,4]
303 |
304 | contact = self.foot_detect(feet_vel, feet_h, 0.001)
305 |
306 | self.losses["FC_"+self.name] = self.mix_masked_mse(feet_vel, torch.zeros_like(feet_vel), self.mask[:, :-1],
307 | self.timestep_mask,
308 | contact) * self.weights["FC"]
309 |
310 |
311 |
312 | def forward_bone_length(self):
313 | pred_g_joints = self.pred_g_joints
314 | tgt_g_joints = self.tgt_g_joints
315 | pred_bones = []
316 | tgt_bones = []
317 | for chain in kinematic_chain:
318 | for i, joint in enumerate(chain[:-1]):
319 | pred_bone = (pred_g_joints[..., chain[i], :] - pred_g_joints[..., chain[i + 1], :]).norm(dim=-1,
320 | keepdim=True) # [B,T,P,1]
321 | tgt_bone = (tgt_g_joints[..., chain[i], :] - tgt_g_joints[..., chain[i + 1], :]).norm(dim=-1,
322 | keepdim=True)
323 | pred_bones.append(pred_bone)
324 | tgt_bones.append(tgt_bone)
325 |
326 | pred_bones = torch.cat(pred_bones, dim=-1)
327 | tgt_bones = torch.cat(tgt_bones, dim=-1)
328 |
329 | self.losses["BL_"+self.name] = self.mix_masked_mse(pred_bones, tgt_bones, self.mask, self.timestep_mask) * self.weights[
330 | "BL"]
331 |
332 |
333 | def forward_traj(self):
334 | B, T = self.pred_g_joints.shape[:2]
335 |
336 | pred_traj = self.pred_g_joints[..., 0, [0, 2]]
337 | tgt_g_traj = self.tgt_g_joints[..., 0, [0, 2]]
338 |
339 | self.losses["TR_"+self.name] = self.mix_masked_mse(pred_traj, tgt_g_traj, self.mask, self.timestep_mask) * self.weights["TR"]
340 |
341 |
342 | def accum_loss(self):
343 | loss = 0
344 | for term in self.losses.keys():
345 | loss += self.losses[term]
346 | self.losses[self.name] = loss
347 |
348 | def foot_detect(self, feet_vel, feet_h, thres):
349 | velfactor, heightfactor = torch.Tensor([thres, thres, thres, thres]).to(feet_vel.device), torch.Tensor(
350 | [0.12, 0.05, 0.12, 0.05]).to(feet_vel.device)
351 |
352 | feet_x = (feet_vel[..., 0]) ** 2
353 | feet_y = (feet_vel[..., 1]) ** 2
354 | feet_z = (feet_vel[..., 2]) ** 2
355 |
356 | contact = (((feet_x + feet_y + feet_z) < velfactor) & (feet_h < heightfactor)).float()
357 | return contact
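    # Contact heuristic: a foot joint counts as "in contact" when its squared velocity
    # magnitude is below `thres` and its height (y coordinate) is below the per-foot
    # threshold [0.12, 0.05, 0.12, 0.05]. For example, zero velocity at zero height
    # gives contact == 1 for every foot; that mask gates the FC velocity penalty
    # computed in forward_contact above.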
358 |
359 | class InitOrientationAndPositionLoss(nn.Module):
360 | def __init__(self, recons_loss, nb_joints):
361 | super(InitOrientationAndPositionLoss, self).__init__()
362 | self.nb_joints = nb_joints
363 | if recons_loss == 'l1':
364 | self.Loss = torch.nn.L1Loss(reduction='none')
365 | elif recons_loss == 'l2':
366 | self.Loss = torch.nn.MSELoss(reduction='none')
367 | elif recons_loss == 'l1_smooth':
368 | self.Loss = torch.nn.SmoothL1Loss(reduction='none')
369 |
370 | self.normalizer = MotionNormalizerTorch()
371 |
372 | self.weights = {}
373 | self.weights["init_ori"] = 0.01
374 | self.weights["init_pos"] = 3
375 |
376 | self.losses = {}
377 |
378 | def mse(self, prediction, target, weights=1.0):
379 | loss = self.Loss(prediction, target).mean(dim=-1, keepdim=True)
380 | loss = loss.sum() * weights
381 | return loss
382 |
383 | def forward(self, motion_pred, spatial_gt):
384 |
385 | B, T = motion_pred.shape[:2]
386 |
387 | prediction = self.normalizer.backward(motion_pred, global_rt=True)
388 |
389 | self.pred_g_joints = prediction[..., :self.nb_joints * 3].reshape(B, T, -1, self.nb_joints, 3) # [B, T, 1, J, 3]
390 |
391 | r_hip, l_hip, sdr_r, sdr_l = face_joint_indx
392 | across = self.pred_g_joints[..., r_hip, :] - self.pred_g_joints[..., l_hip, :]
393 | across = across / across.norm(dim=-1, keepdim=True)
394 |
395 | y_axis = torch.zeros_like(across) # [B, T, 1, 3]
396 | y_axis[..., 1] = 1
397 |
398 |         pred_init_forward = torch.cross(y_axis, across, dim=-1)  # [B, T, 1, 3]
399 | 
400 |         pred_init_forward = (pred_init_forward / pred_init_forward.norm(dim=-1, keepdim=True))[:, 0, 0, [0, 2]]  # [B, T, 1, 3] -> [B, 2]
401 | 
402 |         pred_init_pos = self.pred_g_joints[..., 0, :][:, 0, 0, [0, 2]]  # [B, T, 1, J, 3] -> [B, T, 1, 3] -> [B, 2]
403 |
404 | self.losses['spatial'] = self.mse(pred_init_forward, spatial_gt[..., 0:2], self.weights['init_ori']) + self.mse(pred_init_pos, spatial_gt[..., 2:4], self.weights['init_pos'])
405 |
406 |
--------------------------------------------------------------------------------
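All of the loss terms above reduce through the same masked-mean pattern: a per-element error is averaged over features, masked over valid frames, and finally weighted by the per-sample timestep mask. The following is a minimal, self-contained sketch of that pattern with made-up shapes (B=2, T=4, 3 features); it is illustrative only and does not depend on MotionNormalizerTorch or the repository's data files.

import torch

def masked_mean(err, frame_mask, batch_mask, eps=1e-7):
    # err:        [B, T, 1] per-frame error, already averaged over features
    # frame_mask: [B, T, 1] 1 for valid frames, 0 for padding
    # batch_mask: [B]       1 for samples that receive this loss term
    per_sample = (err * frame_mask).sum(dim=(-1, -2)) / (frame_mask.sum(dim=(-1, -2)) + eps)  # [B]
    return (per_sample * batch_mask).sum() / (batch_mask.sum() + eps)                         # scalar

B, T = 2, 4
pred, tgt = torch.randn(B, T, 3), torch.randn(B, T, 3)
err = torch.nn.MSELoss(reduction='none')(pred, tgt).mean(dim=-1, keepdim=True)
frame_mask = torch.ones(B, T, 1)
frame_mask[1, 2:] = 0                      # pretend the second sequence is padded after frame 2
batch_mask = torch.tensor([1.0, 1.0])
print(masked_mean(err, frame_mask, batch_mask))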
/utils/rotation_conversions.py:
--------------------------------------------------------------------------------
1 | # This code is based on https://github.com/Mathux/ACTOR.git
2 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
3 | # Check PYTORCH3D_LICENCE before use
4 |
5 | import functools
6 | from typing import Optional
7 |
8 | import torch
9 | import torch.nn.functional as F
10 |
11 |
12 | """
13 | The transformation matrices returned from the functions in this file assume
14 | the points on which the transformation will be applied are column vectors.
15 | i.e. the R matrix is structured as
16 |
17 | R = [
18 | [Rxx, Rxy, Rxz],
19 | [Ryx, Ryy, Ryz],
20 | [Rzx, Rzy, Rzz],
21 | ] # (3, 3)
22 |
23 | This matrix can be applied to column vectors by post multiplication
24 | by the points e.g.
25 |
26 | points = [[0], [1], [2]] # (3 x 1) xyz coordinates of a point
27 | transformed_points = R * points
28 |
29 | To apply the same matrix to points which are row vectors, the R matrix
30 | can be transposed and pre multiplied by the points:
31 |
32 | e.g.
33 | points = [[0, 1, 2]] # (1 x 3) xyz coordinates of a point
34 | transformed_points = points * R.transpose(1, 0)
35 | """
36 |
37 |
38 | def quaternion_to_matrix(quaternions):
39 | """
40 | Convert rotations given as quaternions to rotation matrices.
41 |
42 | Args:
43 | quaternions: quaternions with real part first,
44 | as tensor of shape (..., 4).
45 |
46 | Returns:
47 | Rotation matrices as tensor of shape (..., 3, 3).
48 | """
49 | r, i, j, k = torch.unbind(quaternions, -1)
50 | two_s = 2.0 / (quaternions * quaternions).sum(-1)
51 |
52 | o = torch.stack(
53 | (
54 | 1 - two_s * (j * j + k * k),
55 | two_s * (i * j - k * r),
56 | two_s * (i * k + j * r),
57 | two_s * (i * j + k * r),
58 | 1 - two_s * (i * i + k * k),
59 | two_s * (j * k - i * r),
60 | two_s * (i * k - j * r),
61 | two_s * (j * k + i * r),
62 | 1 - two_s * (i * i + j * j),
63 | ),
64 | -1,
65 | )
66 | return o.reshape(quaternions.shape[:-1] + (3, 3))
67 |
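# Example: the identity quaternion maps to the identity matrix,
#   quaternion_to_matrix(torch.tensor([1.0, 0.0, 0.0, 0.0]))  # -> torch.eye(3)
# (real part first, as in the docstring above).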
68 |
69 | def _copysign(a, b):
70 | """
71 |     Return a tensor where each element has the absolute value taken from the
72 | corresponding element of a, with sign taken from the corresponding
73 | element of b. This is like the standard copysign floating-point operation,
74 | but is not careful about negative 0 and NaN.
75 |
76 | Args:
77 | a: source tensor.
78 | b: tensor whose signs will be used, of the same shape as a.
79 |
80 | Returns:
81 | Tensor of the same shape as a with the signs of b.
82 | """
83 | signs_differ = (a < 0) != (b < 0)
84 | return torch.where(signs_differ, -a, a)
85 |
86 |
87 | def _sqrt_positive_part(x):
88 | """
89 | Returns torch.sqrt(torch.max(0, x))
90 | but with a zero subgradient where x is 0.
91 | """
92 | ret = torch.zeros_like(x)
93 | positive_mask = x > 0
94 | ret[positive_mask] = torch.sqrt(x[positive_mask])
95 | return ret
96 |
97 |
98 | def matrix_to_quaternion(matrix):
99 | """
100 | Convert rotations given as rotation matrices to quaternions.
101 |
102 | Args:
103 | matrix: Rotation matrices as tensor of shape (..., 3, 3).
104 |
105 | Returns:
106 | quaternions with real part first, as tensor of shape (..., 4).
107 | """
108 | if matrix.size(-1) != 3 or matrix.size(-2) != 3:
109 |         raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
110 | m00 = matrix[..., 0, 0]
111 | m11 = matrix[..., 1, 1]
112 | m22 = matrix[..., 2, 2]
113 | o0 = 0.5 * _sqrt_positive_part(1 + m00 + m11 + m22)
114 | x = 0.5 * _sqrt_positive_part(1 + m00 - m11 - m22)
115 | y = 0.5 * _sqrt_positive_part(1 - m00 + m11 - m22)
116 | z = 0.5 * _sqrt_positive_part(1 - m00 - m11 + m22)
117 | o1 = _copysign(x, matrix[..., 2, 1] - matrix[..., 1, 2])
118 | o2 = _copysign(y, matrix[..., 0, 2] - matrix[..., 2, 0])
119 | o3 = _copysign(z, matrix[..., 1, 0] - matrix[..., 0, 1])
120 | return torch.stack((o0, o1, o2, o3), -1)
121 |
122 |
123 | def _axis_angle_rotation(axis: str, angle):
124 | """
125 |     Return the rotation matrices for a rotation about one of the coordinate
126 |     axes (as used in Euler angle conventions), for each value of the angle given.
127 |
128 | Args:
129 |         axis: Axis label "X" or "Y" or "Z".
130 | angle: any shape tensor of Euler angles in radians
131 |
132 | Returns:
133 | Rotation matrices as tensor of shape (..., 3, 3).
134 | """
135 |
136 | cos = torch.cos(angle)
137 | sin = torch.sin(angle)
138 | one = torch.ones_like(angle)
139 | zero = torch.zeros_like(angle)
140 |
141 | if axis == "X":
142 | R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos)
143 | if axis == "Y":
144 | R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos)
145 | if axis == "Z":
146 | R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one)
147 |
148 | return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3))
149 |
150 |
151 | def euler_angles_to_matrix(euler_angles, convention: str):
152 | """
153 | Convert rotations given as Euler angles in radians to rotation matrices.
154 |
155 | Args:
156 | euler_angles: Euler angles in radians as tensor of shape (..., 3).
157 | convention: Convention string of three uppercase letters from
158 | {"X", "Y", and "Z"}.
159 |
160 | Returns:
161 | Rotation matrices as tensor of shape (..., 3, 3).
162 | """
163 | if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3:
164 | raise ValueError("Invalid input euler angles.")
165 | if len(convention) != 3:
166 | raise ValueError("Convention must have 3 letters.")
167 | if convention[1] in (convention[0], convention[2]):
168 | raise ValueError(f"Invalid convention {convention}.")
169 | for letter in convention:
170 | if letter not in ("X", "Y", "Z"):
171 | raise ValueError(f"Invalid letter {letter} in convention string.")
172 | matrices = map(_axis_angle_rotation, convention, torch.unbind(euler_angles, -1))
173 | return functools.reduce(torch.matmul, matrices)
174 |
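# Example: the convention letters are applied left to right by matrix multiplication,
# so with `import math`,
#   euler_angles_to_matrix(torch.tensor([0.0, 0.0, math.pi / 2]), "XYZ")
# reduces to a pure rotation of pi/2 about Z.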
175 |
176 | def _angle_from_tan(
177 | axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool
178 | ):
179 | """
180 | Extract the first or third Euler angle from the two members of
181 |     the matrix which are a positive constant times its sine and cosine.
182 |
183 | Args:
184 |         axis: Axis label "X" or "Y" or "Z" for the angle we are finding.
185 |         other_axis: Axis label "X" or "Y" or "Z" for the middle axis in the
186 | convention.
187 | data: Rotation matrices as tensor of shape (..., 3, 3).
188 | horizontal: Whether we are looking for the angle for the third axis,
189 | which means the relevant entries are in the same row of the
190 | rotation matrix. If not, they are in the same column.
191 | tait_bryan: Whether the first and third axes in the convention differ.
192 |
193 | Returns:
194 |         Euler angles in radians for each matrix in the data, as a tensor
195 | of shape (...).
196 | """
197 |
198 | i1, i2 = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis]
199 | if horizontal:
200 | i2, i1 = i1, i2
201 | even = (axis + other_axis) in ["XY", "YZ", "ZX"]
202 | if horizontal == even:
203 | return torch.atan2(data[..., i1], data[..., i2])
204 | if tait_bryan:
205 | return torch.atan2(-data[..., i2], data[..., i1])
206 | return torch.atan2(data[..., i2], -data[..., i1])
207 |
208 |
209 | def _index_from_letter(letter: str):
210 | if letter == "X":
211 | return 0
212 | if letter == "Y":
213 | return 1
214 | if letter == "Z":
215 | return 2
216 |
217 |
218 | def matrix_to_euler_angles(matrix, convention: str):
219 | """
220 | Convert rotations given as rotation matrices to Euler angles in radians.
221 |
222 | Args:
223 | matrix: Rotation matrices as tensor of shape (..., 3, 3).
224 | convention: Convention string of three uppercase letters.
225 |
226 | Returns:
227 | Euler angles in radians as tensor of shape (..., 3).
228 | """
229 | if len(convention) != 3:
230 | raise ValueError("Convention must have 3 letters.")
231 | if convention[1] in (convention[0], convention[2]):
232 | raise ValueError(f"Invalid convention {convention}.")
233 | for letter in convention:
234 | if letter not in ("X", "Y", "Z"):
235 | raise ValueError(f"Invalid letter {letter} in convention string.")
236 | if matrix.size(-1) != 3 or matrix.size(-2) != 3:
237 |         raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
238 | i0 = _index_from_letter(convention[0])
239 | i2 = _index_from_letter(convention[2])
240 | tait_bryan = i0 != i2
241 | if tait_bryan:
242 | central_angle = torch.asin(
243 | matrix[..., i0, i2] * (-1.0 if i0 - i2 in [-1, 2] else 1.0)
244 | )
245 | else:
246 | central_angle = torch.acos(matrix[..., i0, i0])
247 |
248 | o = (
249 | _angle_from_tan(
250 | convention[0], convention[1], matrix[..., i2], False, tait_bryan
251 | ),
252 | central_angle,
253 | _angle_from_tan(
254 | convention[2], convention[1], matrix[..., i0, :], True, tait_bryan
255 | ),
256 | )
257 | return torch.stack(o, -1)
258 |
259 |
260 | def random_quaternions(
261 | n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False
262 | ):
263 | """
264 | Generate random quaternions representing rotations,
265 | i.e. versors with nonnegative real part.
266 |
267 | Args:
268 | n: Number of quaternions in a batch to return.
269 | dtype: Type to return.
270 | device: Desired device of returned tensor. Default:
271 | uses the current device for the default tensor type.
272 | requires_grad: Whether the resulting tensor should have the gradient
273 | flag set.
274 |
275 | Returns:
276 | Quaternions as tensor of shape (N, 4).
277 | """
278 | o = torch.randn((n, 4), dtype=dtype, device=device, requires_grad=requires_grad)
279 | s = (o * o).sum(1)
280 | o = o / _copysign(torch.sqrt(s), o[:, 0])[:, None]
281 | return o
282 |
283 |
284 | def random_rotations(
285 | n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False
286 | ):
287 | """
288 | Generate random rotations as 3x3 rotation matrices.
289 |
290 | Args:
291 | n: Number of rotation matrices in a batch to return.
292 | dtype: Type to return.
293 | device: Device of returned tensor. Default: if None,
294 | uses the current device for the default tensor type.
295 | requires_grad: Whether the resulting tensor should have the gradient
296 | flag set.
297 |
298 | Returns:
299 | Rotation matrices as tensor of shape (n, 3, 3).
300 | """
301 | quaternions = random_quaternions(
302 | n, dtype=dtype, device=device, requires_grad=requires_grad
303 | )
304 | return quaternion_to_matrix(quaternions)
305 |
306 |
307 | def random_rotation(
308 | dtype: Optional[torch.dtype] = None, device=None, requires_grad=False
309 | ):
310 | """
311 | Generate a single random 3x3 rotation matrix.
312 |
313 | Args:
314 | dtype: Type to return
315 | device: Device of returned tensor. Default: if None,
316 | uses the current device for the default tensor type
317 | requires_grad: Whether the resulting tensor should have the gradient
318 | flag set
319 |
320 | Returns:
321 | Rotation matrix as tensor of shape (3, 3).
322 | """
323 | return random_rotations(1, dtype, device, requires_grad)[0]
324 |
325 |
326 | def standardize_quaternion(quaternions):
327 | """
328 | Convert a unit quaternion to a standard form: one in which the real
329 |     part is non-negative.
330 |
331 | Args:
332 | quaternions: Quaternions with real part first,
333 | as tensor of shape (..., 4).
334 |
335 | Returns:
336 | Standardized quaternions as tensor of shape (..., 4).
337 | """
338 | return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions)
339 |
340 |
341 | def quaternion_raw_multiply(a, b):
342 | """
343 | Multiply two quaternions.
344 | Usual torch rules for broadcasting apply.
345 |
346 | Args:
347 | a: Quaternions as tensor of shape (..., 4), real part first.
348 | b: Quaternions as tensor of shape (..., 4), real part first.
349 |
350 | Returns:
351 |         The product of a and b, a tensor of quaternions of shape (..., 4).
352 | """
353 | aw, ax, ay, az = torch.unbind(a, -1)
354 | bw, bx, by, bz = torch.unbind(b, -1)
355 | ow = aw * bw - ax * bx - ay * by - az * bz
356 | ox = aw * bx + ax * bw + ay * bz - az * by
357 | oy = aw * by - ax * bz + ay * bw + az * bx
358 | oz = aw * bz + ax * by - ay * bx + az * bw
359 | return torch.stack((ow, ox, oy, oz), -1)
360 |
361 |
362 | def quaternion_multiply(a, b):
363 | """
364 | Multiply two quaternions representing rotations, returning the quaternion
365 | representing their composition, i.e. the versor with nonnegative real part.
366 | Usual torch rules for broadcasting apply.
367 |
368 | Args:
369 | a: Quaternions as tensor of shape (..., 4), real part first.
370 | b: Quaternions as tensor of shape (..., 4), real part first.
371 |
372 | Returns:
373 | The product of a and b, a tensor of quaternions of shape (..., 4).
374 | """
375 | ab = quaternion_raw_multiply(a, b)
376 | return standardize_quaternion(ab)
377 |
378 |
379 | def quaternion_invert(quaternion):
380 | """
381 | Given a quaternion representing rotation, get the quaternion representing
382 | its inverse.
383 |
384 | Args:
385 | quaternion: Quaternions as tensor of shape (..., 4), with real part
386 | first, which must be versors (unit quaternions).
387 |
388 | Returns:
389 | The inverse, a tensor of quaternions of shape (..., 4).
390 | """
391 |
392 | return quaternion * quaternion.new_tensor([1, -1, -1, -1])
393 |
394 |
395 | def quaternion_apply(quaternion, point):
396 | """
397 | Apply the rotation given by a quaternion to a 3D point.
398 | Usual torch rules for broadcasting apply.
399 |
400 | Args:
401 | quaternion: Tensor of quaternions, real part first, of shape (..., 4).
402 | point: Tensor of 3D points of shape (..., 3).
403 |
404 | Returns:
405 | Tensor of rotated points of shape (..., 3).
406 | """
407 | if point.size(-1) != 3:
408 |         raise ValueError(f"Points are not in 3D, {point.shape}.")
409 | real_parts = point.new_zeros(point.shape[:-1] + (1,))
410 | point_as_quaternion = torch.cat((real_parts, point), -1)
411 | out = quaternion_raw_multiply(
412 | quaternion_raw_multiply(quaternion, point_as_quaternion),
413 | quaternion_invert(quaternion),
414 | )
415 | return out[..., 1:]
416 |
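# Example: a 90-degree rotation about +Z, as the quaternion [cos(pi/4), 0, 0, sin(pi/4)]:
#   q = torch.tensor([0.7071, 0.0, 0.0, 0.7071])
#   quaternion_apply(q, torch.tensor([1.0, 0.0, 0.0]))  # -> approximately [0., 1., 0.]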
417 |
418 | def axis_angle_to_matrix(axis_angle):
419 | """
420 | Convert rotations given as axis/angle to rotation matrices.
421 |
422 | Args:
423 | axis_angle: Rotations given as a vector in axis angle form,
424 | as a tensor of shape (..., 3), where the magnitude is
425 | the angle turned anticlockwise in radians around the
426 | vector's direction.
427 |
428 | Returns:
429 | Rotation matrices as tensor of shape (..., 3, 3).
430 | """
431 | return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle))
432 |
433 |
434 | def matrix_to_axis_angle(matrix):
435 | """
436 | Convert rotations given as rotation matrices to axis/angle.
437 |
438 | Args:
439 | matrix: Rotation matrices as tensor of shape (..., 3, 3).
440 |
441 | Returns:
442 | Rotations given as a vector in axis angle form, as a tensor
443 | of shape (..., 3), where the magnitude is the angle
444 | turned anticlockwise in radians around the vector's
445 | direction.
446 | """
447 | return quaternion_to_axis_angle(matrix_to_quaternion(matrix))
448 |
449 |
450 | def axis_angle_to_quaternion(axis_angle):
451 | """
452 | Convert rotations given as axis/angle to quaternions.
453 |
454 | Args:
455 | axis_angle: Rotations given as a vector in axis angle form,
456 | as a tensor of shape (..., 3), where the magnitude is
457 | the angle turned anticlockwise in radians around the
458 | vector's direction.
459 |
460 | Returns:
461 | quaternions with real part first, as tensor of shape (..., 4).
462 | """
463 | angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True)
464 | half_angles = 0.5 * angles
465 | eps = 1e-6
466 | small_angles = angles.abs() < eps
467 | sin_half_angles_over_angles = torch.empty_like(angles)
468 | sin_half_angles_over_angles[~small_angles] = (
469 | torch.sin(half_angles[~small_angles]) / angles[~small_angles]
470 | )
471 | # for x small, sin(x/2) is about x/2 - (x/2)^3/6
472 | # so sin(x/2)/x is about 1/2 - (x*x)/48
473 | sin_half_angles_over_angles[small_angles] = (
474 | 0.5 - (angles[small_angles] * angles[small_angles]) / 48
475 | )
476 | quaternions = torch.cat(
477 | [torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1
478 | )
479 | return quaternions
480 |
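# Example: a rotation of pi radians about +Z (with `import math`),
#   axis_angle_to_quaternion(torch.tensor([0.0, 0.0, math.pi]))  # -> approx. [0., 0., 0., 1.]
# with the real part first, matching the docstring above.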
481 |
482 | def quaternion_to_axis_angle(quaternions):
483 | """
484 | Convert rotations given as quaternions to axis/angle.
485 |
486 | Args:
487 | quaternions: quaternions with real part first,
488 | as tensor of shape (..., 4).
489 |
490 | Returns:
491 | Rotations given as a vector in axis angle form, as a tensor
492 | of shape (..., 3), where the magnitude is the angle
493 | turned anticlockwise in radians around the vector's
494 | direction.
495 | """
496 | norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True)
497 | half_angles = torch.atan2(norms, quaternions[..., :1])
498 | angles = 2 * half_angles
499 | eps = 1e-6
500 | small_angles = angles.abs() < eps
501 | sin_half_angles_over_angles = torch.empty_like(angles)
502 | sin_half_angles_over_angles[~small_angles] = (
503 | torch.sin(half_angles[~small_angles]) / angles[~small_angles]
504 | )
505 | # for x small, sin(x/2) is about x/2 - (x/2)^3/6
506 | # so sin(x/2)/x is about 1/2 - (x*x)/48
507 | sin_half_angles_over_angles[small_angles] = (
508 | 0.5 - (angles[small_angles] * angles[small_angles]) / 48
509 | )
510 | return quaternions[..., 1:] / sin_half_angles_over_angles
511 |
512 |
513 | def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor:
514 | """
515 | Converts 6D rotation representation by Zhou et al. [1] to rotation matrix
516 | using Gram--Schmidt orthogonalisation per Section B of [1].
517 | Args:
518 | d6: 6D rotation representation, of size (*, 6)
519 |
520 | Returns:
521 | batch of rotation matrices of size (*, 3, 3)
522 |
523 | [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
524 | On the Continuity of Rotation Representations in Neural Networks.
525 | IEEE Conference on Computer Vision and Pattern Recognition, 2019.
526 | Retrieved from http://arxiv.org/abs/1812.07035
527 | """
528 |
529 | a1, a2 = d6[..., [0,2,4]], d6[..., [1,3,5]]
530 | b1 = F.normalize(a1, dim=-1)
531 | b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1
532 | b2 = F.normalize(b2, dim=-1)
533 | b3 = torch.cross(b1, b2, dim=-1)
534 | return torch.cat((b1[...,None], b2[...,None], b3[...,None]), dim=-1)
535 |
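# Note: unlike the stock PyTorch3D implementation, which splits d6 into the contiguous
# halves d6[..., :3] and d6[..., 3:], this variant reads the two reference vectors from
# the interleaved components [0, 2, 4] and [1, 3, 5] and concatenates b1, b2, b3 as
# matrix columns. For example, d6 = [1, 0, 0, 1, 0, 0] maps to torch.eye(3).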
536 |
537 | def matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor:
538 | """
539 | Converts rotation matrices to 6D rotation representation by Zhou et al. [1]
540 | by dropping the last row. Note that 6D representation is not unique.
541 | Args:
542 | matrix: batch of rotation matrices of size (*, 3, 3)
543 |
544 | Returns:
545 | 6D rotation representation, of size (*, 6)
546 |
547 | [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
548 | On the Continuity of Rotation Representations in Neural Networks.
549 | IEEE Conference on Computer Vision and Pattern Recognition, 2019.
550 | Retrieved from http://arxiv.org/abs/1812.07035
551 | """
552 | return matrix[..., :2, :].clone().reshape(*matrix.size()[:-2], 6)
553 |
--------------------------------------------------------------------------------
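As a quick self-contained sanity check of the conversions above (assuming the file is importable as utils.rotation_conversions from the repository root), an axis-angle rotation with a small positive angle survives a round trip through quaternions and rotation matrices:

import torch
from utils.rotation_conversions import (
    axis_angle_to_quaternion, quaternion_to_matrix,
    matrix_to_quaternion, quaternion_to_axis_angle,
)

aa = torch.tensor([[0.0, 0.5, 0.0]])                      # 0.5 rad about +Y
q = axis_angle_to_quaternion(aa)                          # (1, 4), real part first
R = quaternion_to_matrix(q)                               # (1, 3, 3)
aa_back = quaternion_to_axis_angle(matrix_to_quaternion(R))
assert torch.allclose(aa, aa_back, atol=1e-5)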