├── LICENSE
├── LICENSE-MF
├── README.md
├── assets
│   ├── editing.gif
│   └── simcurve.gif
├── compute_metrics.py
├── configs
│   ├── callback
│   │   ├── base.yaml
│   │   ├── best_checkpoint.yaml
│   │   ├── last_checkpoint.yaml
│   │   ├── latest_checkpoint.yaml
│   │   ├── lr_logging.yaml
│   │   ├── progress.yaml
│   │   └── render.yaml
│   ├── compute_metrics.yaml
│   ├── data
│   │   └── motionfix.yaml
│   ├── evaluator
│   │   └── edit_evaluator.yaml
│   ├── hydra
│   │   ├── hydra_logging
│   │   │   ├── console.yaml
│   │   │   ├── custom.yaml
│   │   │   └── rich.yaml
│   │   └── job_logging
│   │       ├── console.yaml
│   │       ├── custom.yaml
│   │       └── rich.yaml
│   ├── logger
│   │   ├── none.yaml
│   │   ├── tensorboard.yaml
│   │   └── wandb.yaml
│   ├── machine
│   │   └── server.yaml
│   ├── model
│   │   ├── basic_clip.yaml
│   │   ├── basic_clip_cls.yaml
│   │   ├── basic_clip_cls_arch.yaml
│   │   ├── denoiser
│   │   │   ├── denoiser.yaml
│   │   │   ├── ditdenoiser_cls.yaml
│   │   │   └── ditdenoiser_cls_arch.yaml
│   │   ├── infer_scheduler
│   │   │   ├── ddim.yaml
│   │   │   └── ddpm.yaml
│   │   ├── losses
│   │   │   ├── basic.yaml
│   │   │   └── function
│   │   │       ├── kl.yaml
│   │   │       ├── klmulti.yaml
│   │   │       ├── mse.yaml
│   │   │       ├── recons.yaml
│   │   │       ├── recons_bp.yaml
│   │   │       └── smoothL1.yaml
│   │   ├── motion_condition_encoder
│   │   │   └── actor.yaml
│   │   ├── optim
│   │   │   └── adamw.yaml
│   │   ├── text_encoder
│   │   │   ├── clipenc.yaml
│   │   │   ├── clipenc_sim.yaml
│   │   │   ├── distilbert_enc.yaml
│   │   │   └── t5_enc.yaml
│   │   └── train_scheduler
│   │       ├── ddim.yaml
│   │       └── ddpm.yaml
│   ├── motionfix_eval.yaml
│   ├── path.yaml
│   ├── sampler
│   │   ├── all_conseq.yaml
│   │   ├── fix_conseq.yaml
│   │   ├── upper_bound_variable_conseq.yaml
│   │   └── variable_conseq.yaml
│   ├── stats.yaml
│   ├── train.yaml
│   ├── train_cls.yaml
│   ├── train_cls_arch.yaml
│   └── trainer
│       ├── base-longer.yaml
│       └── base.yaml
├── deps
│   ├── gpt
│   │   ├── edit
│   │   │   ├── gpt-labels_full.json
│   │   │   └── gpt-labels_full_v2.json
│   │   └── gpt3-labels-list.json
│   ├── inference
│   │   ├── labels.json
│   │   ├── labels_train_spatial.json
│   │   ├── labels_val_spatial.json
│   │   ├── qual.txt
│   │   ├── smpl_part_seg.json
│   │   ├── text-annot_val_spatial-pairs.json
│   │   ├── texts_pairs_train.json
│   │   └── texts_pairs_val.json
│   ├── smplh
│   │   └── smpl.faces
│   └── stats
│       └── statistics_motionfix.npy
├── gpt_parts
│   ├── joint_utils.py
│   └── prompts.py
├── motionfix_evaluate.py
├── requirements.txt
├── scripts
│   ├── download_data.sh
│   └── install.sh
├── src
│   ├── __init__.py
│   ├── callback
│   │   ├── __init__.py
│   │   ├── progress.py
│   │   └── render.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── comp.py
│   │   ├── features.py
│   │   ├── motionfix.py
│   │   ├── sampling
│   │   │   ├── __init__.py
│   │   │   ├── base.py
│   │   │   ├── custom_batch_sampler.py
│   │   │   ├── framerate.py
│   │   │   └── frames.py
│   │   └── tools
│   │       ├── __init__.py
│   │       ├── amass_utils.py
│   │       ├── collate.py
│   │       ├── contacts.py
│   │       ├── extract_pairs.py
│   │       ├── rotation_transformation.py
│   │       ├── smpl.py
│   │       ├── spatiotempo.py
│   │       ├── tensors.py
│   │       └── utils.py
│   ├── diffusion
│   │   ├── __init__.py
│   │   ├── diffusion_utils.py
│   │   ├── gaussian_diffusion.py
│   │   ├── respace.py
│   │   └── timestep_sampler.py
│   ├── evaluator
│   │   └── evaluate_edits.py
│   ├── info
│   │   └── joints.py
│   ├── launch
│   │   ├── blender.py
│   │   └── prepare.py
│   ├── logger
│   │   ├── __init__.py
│   │   ├── tools.py
│   │   └── wandb_log.py
│   ├── model
│   │   ├── DiT_denoiser.py
│   │   ├── DiT_denoiser_cls.py
│   │   ├── DiT_denoiser_cls_arch.py
│   │   ├── DiT_models.py
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── base_diffusion.py
│   │   ├── dummy.py
│   │   ├── losses
│   │   │   ├── __init__.py
│   │   │   ├── compute.py
│   │   │   ├── compute_mld.py
│   │   │   ├── kl.py
│   │   │   ├── recons.py
│   │   │   ├── recons_bp.py
│   │   │   └── utils.py
│   │   ├── metrics
│   │   │   ├── __init__.py
│   │   │   └── compute.py
│   │   ├── motiondecoder
│   │   │   ├── __init__.py
│   │   │   └── actor.py
│   │   ├── motionencoder
│   │   │   ├── __init__.py
│   │   │   └── actor.py
│   │   ├── readme.txt
│   │   ├── textencoder
│   │   │   ├── __init__.py
│   │   │   ├── clip_encoder.py
│   │   │   ├── distilbert.py
│   │   │   ├── distilbert_encoder.py
│   │   │   ├── t5_encoder.py
│   │   │   └── text_space.py
│   │   ├── tmed_denoiser.py
│   │   ├── tmr_utils
│   │   │   ├── __init__.py
│   │   │   ├── actor.py
│   │   │   ├── losses.py
│   │   │   ├── metrics.py
│   │   │   ├── temos.py
│   │   │   ├── text_encoder.py
│   │   │   ├── tmr.py
│   │   │   └── utils.py
│   │   ├── trans_enc.py
│   │   ├── transformer_encoder
│   │   │   ├── encoder.py
│   │   │   ├── encoder_layer.py
│   │   │   ├── feed_forward.py
│   │   │   └── note.md
│   │   └── utils
│   │       ├── __init__.py
│   │       ├── all_positional_encodings.py
│   │       ├── body_parts.py
│   │       ├── lr_scheduler.py
│   │       ├── positional_encoding.py
│   │       ├── smpl_fast.py
│   │       ├── timestep_embed.py
│   │       ├── tools.py
│   │       ├── transf_utils.py
│   │       └── vae.py
│   ├── render
│   │   ├── __init__.py
│   │   ├── anim.py
│   │   ├── mesh_viz.py
│   │   └── video.py
│   ├── repr_utils
│   │   └── tmr_utils.py
│   ├── tmr
│   │   ├── __init__.py
│   │   ├── actor.py
│   │   ├── data
│   │   │   ├── __init__.py
│   │   │   └── motionfix_loader.py
│   │   ├── load_model.py
│   │   ├── losses.py
│   │   ├── metrics.py
│   │   ├── temos.py
│   │   ├── text_encoder.py
│   │   └── tmr.py
│   ├── tools
│   │   ├── __init__.py
│   │   ├── easyconvert.py
│   │   ├── frank.py
│   │   ├── geometry.py
│   │   ├── interpolation.py
│   │   ├── logging.py
│   │   ├── runid.py
│   │   └── transforms3d.py
│   └── utils
│       ├── __init__.py
│       ├── art_utils.py
│       ├── cherrypick.py
│       ├── eval_utils.py
│       ├── file_io.py
│       ├── genutils.py
│       ├── inference.py
│       ├── mesh_utils.py
│       ├── motionfix_utils.py
│       ├── nlp_consts.py
│       ├── smpl_body_utils.py
│       └── text_constants.py
├── tmr_evaluator
│   ├── __init__.py
│   ├── fid.py
│   ├── motion2motion_retr.py
│   ├── retr_ori.py
│   └── text2motion_retr.py
├── train.py
└── utils
    ├── __init__.py
    ├── masking.py
    ├── misc.py
    └── transformations.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 Zhengyuan Li
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/LICENSE-MF:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Nikos Athanasiou
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | SimMotionEdit: Text-Based Human Motion Editing with Motion Similarity Prediction
3 |
4 | Zhengyuan Li
5 | Kai Cheng
6 | Anindita Ghosh
7 | Uttaran Bhattacharya
8 | Liang-Yan Gui
9 | Aniket Bera
10 |
11 | Purdue University, DFKI, MPI-INF,
12 |
13 | Adobe Inc., University of Illinois Urbana-Champaign
14 |
15 | CVPR 2025
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 | | Motion Similarity | Text-based Motion Editing |
31 | |-------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------|
32 | | ![Motion Similarity](assets/simcurve.gif) | ![Text-based Motion Editing](assets/editing.gif) |
33 |
34 |
35 |
36 | ## Environment Setup
37 | Please follow [motionfix](https://github.com/atnikos/motionfix) to download the dataset and set up the environment.
38 |
39 | We provide a pretrained model and the pre-processed similarity file at [this link](https://drive.google.com/drive/folders/1LjiKVjDHqOEnykZMP3ZiTJEz_EY-TqC3?usp=sharing).
40 |
41 | Config setup: in `./configs/data/motionfix.yaml`, set `preproc.sim_file` to the path of the pre-processed similarity file.
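For reference, a minimal sketch of the relevant entry in `./configs/data/motionfix.yaml` (the path below is a placeholder; point it at the similarity file downloaded from the link above):
```yaml
preproc:
  # ... keep the other preprocessing options as they are ...
  sim_file: /path/to/preprocessed_similarity.npy  # placeholder path
```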
42 |
43 | ## Evaluation
44 |
45 | #### Step 1: Extract the samples
46 | ```bash
47 | python motionfix_evaluate.py folder=/path/to/exp/ guidance_scale_text_n_motion=2.0 guidance_scale_motion=2.0 data=motionfix
48 | ```
49 |
50 | #### Step 2: Compute the metrics
51 | ```bash
52 | python compute_metrics.py folder=/path/to/exp/samples/npys
53 | ```
54 |
55 | ## Training
56 | ```bash
57 | python -u train.py --config-name="train_cls_arch" experiment=cls_arch run_id=no_text
58 | ```
59 |
60 | ## Acknowledgements
61 | Our code is based on: [motionfix](https://github.com/atnikos/motionfix).
62 |
63 | ## License
64 | This code is distributed under an MIT LICENSE. We also include the LICENSE of motionfix in our repo. Other third-party datasets and software are subject to their respective licenses.
65 |
66 | ## Citation
67 | You can cite this paper using:
68 |
69 | ```bibtex
70 | @article{li2025simmotionedittextbasedhumanmotion,
71 | title={SimMotionEdit: Text-Based Human Motion Editing with Motion Similarity Prediction},
72 | author={Zhengyuan Li and Kai Cheng and Anindita Ghosh and Uttaran Bhattacharya and Liangyan Gui and Aniket Bera},
73 | year={2025},
74 | eprint={2503.18211},
75 | archivePrefix={arXiv},
76 | primaryClass={cs.CV}
77 | }
78 | ```
79 |
--------------------------------------------------------------------------------
/assets/editing.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lzhyu/SimMotionEdit/384135a8707bc247fba23005d57cca7ca2d751ce/assets/editing.gif
--------------------------------------------------------------------------------
/assets/simcurve.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lzhyu/SimMotionEdit/384135a8707bc247fba23005d57cca7ca2d751ce/assets/simcurve.gif
--------------------------------------------------------------------------------
/compute_metrics.py:
--------------------------------------------------------------------------------
1 | from omegaconf import DictConfig
2 | import logging
3 | import hydra
4 | import torch
5 | from tqdm import tqdm
6 | from pathlib import Path
7 | import numpy as np
8 |
9 |
10 | def collect_gen_samples(motion_gen_path, normalizer, device):
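# Loads every generated sample (<keyid>.npy) from motion_gen_path, rebuilds the model's
# feature layout (pelvis-relative translation deltas + 6D rotations), normalizes it,
# and also keeps the raw (un-normalized) features per key.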
11 | cur_samples = {}
12 | cur_samples_raw = {}
13 |         # the feature order changes from
14 | # translation | root_orient | rots --> trans | rots | root_orient
15 | print("Collecting Generated Samples")
16 | from src.data.features import _get_body_transl_delta_pelv_infer
17 | import glob
18 |
19 | sample_files = glob.glob(f'{motion_gen_path}/*.npy')
20 | for fname in tqdm(sample_files):
21 | keyid = str(Path(fname).name).replace('.npy', '')
22 | gen_motion_b = np.load(fname,
23 | allow_pickle=True).item()['pose']
24 | gen_motion_b = torch.from_numpy(gen_motion_b)
25 | trans = gen_motion_b[..., :3]
26 | global_orient_6d = gen_motion_b[..., 3:9]
27 | body_pose_6d = gen_motion_b[..., 9:]
28 | trans_delta = _get_body_transl_delta_pelv_infer(global_orient_6d,
29 | trans)
30 | gen_motion_b_fixed = torch.cat([trans_delta, body_pose_6d,
31 | global_orient_6d], dim=-1)
32 | gen_motion_b_fixed = normalizer(gen_motion_b_fixed)
33 | cur_samples[keyid] = gen_motion_b_fixed.to(device)
34 | cur_samples_raw[keyid] = torch.cat([trans, global_orient_6d,
35 | body_pose_6d], dim=-1).to(device)
36 | return cur_samples, cur_samples_raw
37 |
38 |
39 | @hydra.main(config_path="configs", version_base="1.2", config_name="compute_metrics")
40 | def _compute_metrics(cfg: DictConfig):
41 | return compute_metrics(cfg)
42 |
43 | def compute_metrics(newcfg: DictConfig) -> None:
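    # Expects newcfg.folder to point at the npys directory produced by
    # motionfix_evaluate.py (e.g. /path/to/exp/samples/npys, as in the README).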
44 | from tmr_evaluator.motion2motion_retr import retrieval
45 | from pathlib import Path
46 | samples_folder = newcfg.folder
47 | metrs_batches, metrs_full = retrieval(samples_folder)
48 | print("\n===== Metrics for Retrieval on Batches of 32 =====")
49 | print(metrs_batches)
50 | print("\n===== Metrics for Retrieval on the full test set =====")
51 |     print(metrs_full)
52 |
53 | if __name__ == '__main__':
54 | _compute_metrics()
55 |
56 |
--------------------------------------------------------------------------------
/configs/callback/base.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - /callback/last_checkpoint@last_ckpt
3 | - /callback/latest_checkpoint@latest_ckpt
4 | - /callback/progress@progress
5 | - /callback/render@render
6 | - /callback/lr_logging@lr_logging
--------------------------------------------------------------------------------
/configs/callback/best_checkpoint.yaml:
--------------------------------------------------------------------------------
1 | _target_: pytorch_lightning.callbacks.ModelCheckpoint
2 | dirpath: ${path.working_dir}/checkpoints
3 | filename: best
4 | monitor: "Metrics/APE_root"
5 | mode: min
6 | every_n_epochs: 1
7 | save_top_k: 1
8 |
--------------------------------------------------------------------------------
/configs/callback/last_checkpoint.yaml:
--------------------------------------------------------------------------------
1 | _target_: pytorch_lightning.callbacks.ModelCheckpoint
2 | dirpath: ${path.working_dir}/checkpoints
3 |
4 | filename: latest-{epoch}
5 | every_n_epochs: 1
6 | save_top_k: 1
7 | save_last: true
--------------------------------------------------------------------------------
/configs/callback/latest_checkpoint.yaml:
--------------------------------------------------------------------------------
1 | _target_: pytorch_lightning.callbacks.ModelCheckpoint
2 | dirpath: ${path.working_dir}/checkpoints
3 | filename: latest-{epoch}
4 | monitor: step
5 | mode: max
6 | every_n_epochs: 100
7 | save_top_k: -1
8 | save_last: false
9 |
--------------------------------------------------------------------------------
/configs/callback/lr_logging.yaml:
--------------------------------------------------------------------------------
1 | _target_: pytorch_lightning.callbacks.LearningRateMonitor
2 | # proper way to log LR but does not work with the custom logger thing ...
3 | logging_interval: epoch
4 |
--------------------------------------------------------------------------------
/configs/callback/progress.yaml:
--------------------------------------------------------------------------------
1 | _target_: src.callback.ProgressLogger
2 |
--------------------------------------------------------------------------------
/configs/callback/render.yaml:
--------------------------------------------------------------------------------
1 | _target_: src.callback.RenderCallback
2 |
3 | every_n_epochs: 50
4 | num_workers: ${machine.num_workers}
5 | save_last: true
6 | nvids_to_save: 3
7 | bm_path: ${path.data}
8 | modelname: ${model.modelname}
--------------------------------------------------------------------------------
/configs/compute_metrics.yaml:
--------------------------------------------------------------------------------
1 | hydra:
2 | run:
3 | dir: .
4 | job:
5 | chdir: true
6 | output_subdir: null
7 |
8 | folder: ???
9 |
--------------------------------------------------------------------------------
/configs/data/motionfix.yaml:
--------------------------------------------------------------------------------
1 | dataname: motionfix
2 | _target_: src.data.motionfix.MotionFixDataModule
3 | debug: ${debug}
4 |
5 | datapath: ${path.data}/motionfix-dataset/motionfix.pth.tar # amass_bodilex.pth.tar
6 |
7 | # Amass
8 | smplh_path: ${path.data}/body_models
9 | smplh_path_dbg: ${path.minidata}/body_models
10 |
11 | load_with_rot: true
12 |
13 | load_splits:
14 | - "train"
15 | - "val"
16 | - "test"
17 |
18 | proportion: 1.0
19 | text_augment: false
20 |
21 | # Machine
22 | batch_size: ${machine.batch_size} # it's tiny
23 | num_workers: ${machine.num_workers}
24 | rot_repr: '6d'
25 | preproc:
26 | stats_file: ${path.deps}/stats/statistics_${data.dataname}.npy # full path for statistics
27 | split_seed: 0
28 | calculate_minmax: True
29 | generate_joint_files: True
30 | use_cuda: True
31 | n_body_joints: 22
32 | norm_type: std # norm or std
33 | sim_file: PATH/TO/SIM
34 |
35 | # Motion
36 | framerate: 30
37 | sampler: ${sampler}
38 |
39 | load_feats:
40 | - "body_transl_delta_pelv"
41 | - "body_orient_xy"
42 | - "z_orient_delta"
43 | - "body_pose"
44 | - "body_joints_local_wo_z_rot"
45 | - "body_transl"
46 | - "body_orient"
47 | # Other
48 | progress_bar: true
49 |
--------------------------------------------------------------------------------
/configs/evaluator/edit_evaluator.yaml:
--------------------------------------------------------------------------------
1 | _target_: src.evaluator.evaluate_edits.MotionEditEvaluator
2 | metrics_to_eval:
3 | # - 'foot_skating'
4 | - 'loc_pres'
5 | - 'glo_pres'
6 | - 'lc_pre_gt'
7 | - 'gb_pre_gt'
8 | - 'loc_edit'
9 | - 'glob_edit'
10 | smplh_path: ${path.data}/body_models
11 |
--------------------------------------------------------------------------------
/configs/hydra/hydra_logging/console.yaml:
--------------------------------------------------------------------------------
1 | version: 1
2 |
3 | filters:
4 | onlyimportant:
5 | (): src.tools.logging.LevelsFilter
6 | levels:
7 | - CRITICAL
8 | - ERROR
9 | - WARNING
10 | noimportant:
11 | (): src.tools.logging.LevelsFilter
12 | levels:
13 | - INFO
14 | - DEBUG
15 | - NOTSET
16 |
17 | formatters:
18 | simple:
19 | format: '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s'
20 | datefmt: '%d/%m/%y %H:%M:%S'
21 |
22 | colorlog:
23 | (): colorlog.ColoredFormatter
24 | format: '[%(cyan)s%(asctime)s%(reset)s][%(blue)s%(name)s%(reset)s][%(log_color)s%(levelname)s%(reset)s]
25 | - %(message)s'
26 | datefmt: '%d/%m/%y %H:%M:%S'
27 |
28 | log_colors:
29 | DEBUG: purple
30 | INFO: green
31 | WARNING: yellow
32 | ERROR: red
33 | CRITICAL: red
34 |
35 | handlers:
36 | console:
37 | class: logging.StreamHandler
38 | formatter: colorlog
39 | stream: ext://sys.stdout
40 |
41 | root:
42 | level: ${logger_level}
43 | handlers:
44 | - console
45 |
46 | disable_existing_loggers: false
47 |
--------------------------------------------------------------------------------
/configs/hydra/hydra_logging/custom.yaml:
--------------------------------------------------------------------------------
1 | version: 1
2 |
3 | formatters:
4 | colorlog:
5 | (): colorlog.ColoredFormatter
6 | format: '[%(cyan)s%(asctime)s%(reset)s][%(purple)sHYDRA%(reset)s] %(message)s'
7 | datefmt: '%d/%m/%y %H:%M:%S'
8 |
9 | handlers:
10 | console:
11 | class: logging.StreamHandler
12 | formatter: colorlog
13 | stream: ext://sys.stdout
14 |
15 | root:
16 | level: INFO
17 | handlers:
18 | - console
19 |
20 | disable_existing_loggers: false
21 |
--------------------------------------------------------------------------------
/configs/hydra/hydra_logging/rich.yaml:
--------------------------------------------------------------------------------
1 | version: 1
2 |
3 | formatters:
4 | colorlog:
5 | (): colorlog.ColoredFormatter
6 | format: '[%(cyan)s%(asctime)s%(reset)s][%(purple)sHYDRA%(reset)s] %(message)s'
7 | datefmt: '%d/%m/%y %H:%M:%S'
8 |
9 | handlers:
10 | console:
11 | class: rich.logging.RichHandler # logging.StreamHandler
12 | formatter: colorlog
13 |
14 | root:
15 | level: INFO
16 | handlers:
17 | - console
18 |
19 | disable_existing_loggers: false
--------------------------------------------------------------------------------
/configs/hydra/job_logging/console.yaml:
--------------------------------------------------------------------------------
1 | version: 1
2 |
3 | filters:
4 | onlyimportant:
5 | (): src.tools.logging.LevelsFilter
6 | levels:
7 | - CRITICAL
8 | - ERROR
9 | - WARNING
10 | noimportant:
11 | (): src.tools.logging.LevelsFilter
12 | levels:
13 | - INFO
14 | - DEBUG
15 | - NOTSET
16 |
17 | formatters:
18 | simple:
19 | format: '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s'
20 | datefmt: '%d/%m/%y %H:%M:%S'
21 |
22 | colorlog:
23 | (): colorlog.ColoredFormatter
24 | format: '[%(cyan)s%(asctime)s%(reset)s][%(blue)s%(name)s%(reset)s][%(log_color)s%(levelname)s%(reset)s]
25 | - %(message)s'
26 | datefmt: '%d/%m/%y %H:%M:%S'
27 |
28 | log_colors:
29 | DEBUG: purple
30 | INFO: green
31 | WARNING: yellow
32 | ERROR: red
33 | CRITICAL: red
34 |
35 | handlers:
36 | console:
37 | class: logging.StreamHandler
38 | formatter: colorlog
39 | stream: ext://sys.stdout
40 |
41 | root:
42 | level: ${logger_level}
43 | handlers:
44 | - console
45 |
46 | disable_existing_loggers: false
47 |
--------------------------------------------------------------------------------
/configs/hydra/job_logging/custom.yaml:
--------------------------------------------------------------------------------
1 | version: 1
2 |
3 | filters:
4 | onlyimportant:
5 | (): src.tools.logging.LevelsFilter
6 | levels:
7 | - CRITICAL
8 | - ERROR
9 | - WARNING
10 | noimportant:
11 | (): src.tools.logging.LevelsFilter
12 | levels:
13 | - INFO
14 | - DEBUG
15 | - NOTSET
16 |
17 | formatters:
18 | simple:
19 | format: '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s'
20 | datefmt: '%d/%m/%y %H:%M:%S'
21 |
22 | colorlog:
23 | (): colorlog.ColoredFormatter
24 | format: '[%(cyan)s%(asctime)s%(reset)s][%(blue)s%(name)s%(reset)s][%(log_color)s%(levelname)s%(reset)s]
25 | - %(message)s'
26 | datefmt: '%d/%m/%y %H:%M:%S'
27 |
28 | log_colors:
29 | DEBUG: purple
30 | INFO: green
31 | WARNING: yellow
32 | ERROR: red
33 | CRITICAL: red
34 |
35 | handlers:
36 | console:
37 | class: logging.StreamHandler
38 | formatter: colorlog
39 | stream: ext://sys.stdout
40 |
41 | file_out:
42 | class: logging.FileHandler
43 | formatter: simple
44 | filename: logs.out
45 | filters:
46 | - noimportant
47 |
48 | file_err:
49 | class: logging.FileHandler
50 | formatter: simple
51 | filename: logs.err
52 | filters:
53 | - onlyimportant
54 |
55 | root:
56 | level: ${logger_level}
57 | handlers:
58 | - console
59 | - file_out
60 | - file_err
61 |
62 | disable_existing_loggers: false
63 |
--------------------------------------------------------------------------------
/configs/hydra/job_logging/rich.yaml:
--------------------------------------------------------------------------------
1 | version: 1
2 |
3 | filters:
4 | onlyimportant:
5 | (): src.tools.logging.LevelsFilter
6 | levels:
7 | - CRITICAL
8 | - ERROR
9 | - WARNING
10 | noimportant:
11 | (): src.tools.logging.LevelsFilter
12 | levels:
13 | - INFO
14 | - DEBUG
15 | - NOTSET
16 |
17 | formatters:
18 | verysimple:
19 | format: '%(message)s'
20 | simple:
21 | format: '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s'
22 | datefmt: '%d/%m/%y %H:%M:%S'
23 |
24 | colorlog:
25 | (): colorlog.ColoredFormatter
26 | format: '[%(cyan)s%(asctime)s%(reset)s][%(blue)s%(name)s%(reset)s][%(log_color)s%(levelname)s%(reset)s]
27 | - %(message)s'
28 | datefmt: '%d/%m/%y %H:%M:%S'
29 |
30 | log_colors:
31 | DEBUG: purple
32 | INFO: green
33 | WARNING: yellow
34 | ERROR: red
35 | CRITICAL: red
36 |
37 | handlers:
38 | console:
39 | class: rich.logging.RichHandler # logging.StreamHandler
40 | formatter: verysimple # colorlog
41 | rich_tracebacks: true
42 |
43 | file_out:
44 | class: logging.FileHandler
45 | formatter: simple
46 | filename: logs.out
47 | filters:
48 | - noimportant
49 |
50 | file_err:
51 | class: logging.FileHandler
52 | formatter: simple
53 | filename: logs.err
54 | filters:
55 | - onlyimportant
56 |
57 | root:
58 | level: ${logger_level}
59 | handlers:
60 | - console
61 | - file_out
62 | - file_err
63 |
64 | disable_existing_loggers: false
--------------------------------------------------------------------------------
/configs/logger/none.yaml:
--------------------------------------------------------------------------------
1 | logger_name: null
2 | version: ${run_id}
3 |
4 | project: null
--------------------------------------------------------------------------------
/configs/logger/tensorboard.yaml:
--------------------------------------------------------------------------------
1 | logger_name: tensorboard
2 | # version: ${run_id}
3 |
4 | save_dir: tensorboard
5 | version: ""
6 | log_graph: false
7 | default_hp_metric: false
8 | project: null
--------------------------------------------------------------------------------
/configs/logger/wandb.yaml:
--------------------------------------------------------------------------------
1 | logger_name: wandb
2 | name: ${experiment}-${run_id}
3 |
4 | save_dir: wandb
5 | project: ${project}
6 | group: null
7 | tags: null
8 | notes: null
9 |
10 | offline: false
11 | resume: false
12 | save_code: false
13 | log_model: false
14 |
15 | # wandb.watch(model, log=cfg.log, log_freq=cfg.log_freq)
16 |
--------------------------------------------------------------------------------
/configs/machine/server.yaml:
--------------------------------------------------------------------------------
1 | name: server
2 |
3 | # specific attributes to this machine
4 | batch_size: 128
5 | smpl_batch_size: 64
6 | num_workers: 16
7 |
8 | # Specific attributes for training
9 |
--------------------------------------------------------------------------------
/configs/model/basic_clip.yaml:
--------------------------------------------------------------------------------
1 | modelname: basic_clip
2 | _target_: src.model.base_diffusion.MD
3 |
4 | latent_dim: 768
5 |
6 | ff_size: 1024
7 | num_layers: 9
8 | num_head: 8
9 | droupout: 0.1
10 | activation: "gelu"
11 | render_vids_every_n_epochs: 100
12 | num_vids_to_render: 2
13 | lr_scheduler: null # cosine # null # reduceonplateau, steplr
14 |
15 | zero_len_source: false
16 | old_way: false
17 | # normalization
18 | statistics_path: ${statistics_path}
19 | norm_type: standardize # min_max standardize
20 |
21 | # diffusion related
22 | diff_params:
23 | num_inference_timesteps: 200
24 | num_train_timesteps: 300
25 | prob_uncondp: 0.05 # 0.1 0.25
26 | prob_drop_text: 0.05 # 0.1 0.25
27 | prob_drop_motion: 0.05
28 | guidance_scale_text: 2.5 #
29 | guidance_scale_motion: 2.0
30 | noise_schedule: 'squaredcos_cap_v2' # Optional: ['linear', 'scaled_linear', 'squaredcos_cap_v2']
31 | predict_type: 'sample' # noise
32 |
33 | motion_condition: source
34 | source_encoder: null # trans_enc
35 | condition: text
36 | smpl_path: ${data.smplh_path}
37 | copy_target: false
38 | nfeats: 135
39 | dim_per_feat: [135]
40 | # data related
41 |
42 | input_feats:
43 | # - "body_transl_delta"
44 | - "body_transl_delta_pelv"
45 | # - "body_transl_delta_pelv_xy_wo_z"
46 | # - "body_transl_z"
47 | # - "body_transl"
48 | - "body_orient_xy"
49 | - "z_orient_delta"
50 | - "body_pose"
51 | - "body_joints_local_wo_z_rot"
52 |
53 | pad_inputs: false
54 |
55 | loss_func_pos: mse # l1 mse
56 | loss_func_feats: mse # l1 mse
57 |
58 | defaults:
59 | # diffusion stuff
60 | - _self_
61 | - /path@path
62 | - train_scheduler: ddpm
63 | - infer_scheduler: ddpm
64 | - denoiser: denoiser_larger # ditdenoiser
65 | - text_encoder: clipenc_sim #clipenc #t5_enc # distilbert_enc # t5_enc
66 | - motion_condition_encoder: actor
67 | - losses: basic
68 | - optim: adamw
69 | - /model/losses/function/recons@func_recons
70 | - /model/losses/function/recons@func_latent
71 | - /model/losses/function/kl@func_kl
72 |
--------------------------------------------------------------------------------
/configs/model/basic_clip_cls.yaml:
--------------------------------------------------------------------------------
1 | modelname: basic_clip_repr
2 | _target_: src.model.base_diffusion.MD
3 |
4 | latent_dim: 512
5 |
6 | ff_size: 1024
7 | num_layers: 9
8 | num_head: 8
9 | droupout: 0.1
10 | activation: "gelu"
11 | render_vids_every_n_epochs: 100
12 | num_vids_to_render: 2
13 | lr_scheduler: null # cosine # null # reduceonplateau, steplr
14 |
15 | zero_len_source: false
16 | old_way: false
17 | # normalization
18 | statistics_path: ${statistics_path}
19 | norm_type: standardize # min_max standardize
20 | use_cls: True
21 | use_repr: False
22 | use_v_weight: True
23 | source_align_coef: 0
24 | target_align_coef: 0
25 | cls_coef: 0.05
26 | n_cls: 3
27 | target_align_depth: 6
28 |
29 | # diffusion related
30 | diff_params:
31 | num_inference_timesteps: 200
32 | num_train_timesteps: 300
33 | prob_uncondp: 0.05 # 0.1 0.25
34 | prob_drop_text: 0.05 # 0.1 0.25
35 | prob_drop_motion: 0.05
36 | guidance_scale_text: 2.5 #
37 | guidance_scale_motion: 2.0
38 | noise_schedule: 'squaredcos_cap_v2' # Optional: ['linear', 'scaled_linear', 'squaredcos_cap_v2']
39 | predict_type: 'sample' # noise
40 |
41 | motion_condition: source
42 | source_encoder: null # trans_enc
43 | condition: text
44 | smpl_path: ${data.smplh_path}
45 | copy_target: false
46 | nfeats: 135
47 | dim_per_feat: [135]
48 | # data related
49 |
50 | input_feats:
51 | # - "body_transl_delta"
52 | - "body_transl_delta_pelv"
53 | # - "body_transl_delta_pelv_xy_wo_z"
54 | # - "body_transl_z"
55 | # - "body_transl"
56 | - "body_orient_xy"
57 | - "z_orient_delta"
58 | - "body_pose"
59 | - "body_joints_local_wo_z_rot"
60 |
61 | pad_inputs: false
62 |
63 | loss_func_pos: mse # l1 mse
64 | loss_func_feats: mse # l1 mse
65 |
66 | defaults:
67 | # diffusion stuff
68 | - _self_
69 | - /path@path
70 | - train_scheduler: ddpm
71 | - infer_scheduler: ddpm
72 | - denoiser: ditdenoiser_cls # ditdenoiser
73 | - text_encoder: clipenc_sim #clipenc #t5_enc # distilbert_enc # t5_enc
74 | - motion_condition_encoder: actor
75 | - losses: basic
76 | - optim: adamw
77 | - /model/losses/function/recons@func_recons
78 | - /model/losses/function/recons@func_latent
79 | - /model/losses/function/kl@func_kl
80 |
--------------------------------------------------------------------------------
/configs/model/basic_clip_cls_arch.yaml:
--------------------------------------------------------------------------------
1 | modelname: basic_clip_repr
2 | _target_: src.model.base_diffusion.MD
3 |
4 | latent_dim: 512
5 |
6 | ff_size: 1024
7 | num_layers: 9
8 | num_head: 8
9 | droupout: 0.1
10 | activation: "gelu"
11 | render_vids_every_n_epochs: 100
12 | num_vids_to_render: 2
13 | lr_scheduler: null # cosine # null # reduceonplateau, steplr
14 |
15 | zero_len_source: false
16 | old_way: false
17 | # normalization
18 | statistics_path: ${statistics_path}
19 | norm_type: standardize # min_max standardize
20 | use_cls: True
21 | use_repr: False
22 | use_v_weight: True
23 | source_align_coef: 0
24 | target_align_coef: 0
25 | cls_coef: 0.05
26 | n_cls: 3
27 | target_align_depth: 6
28 |
29 | # diffusion related
30 | diff_params:
31 | num_inference_timesteps: 200
32 | num_train_timesteps: 300
33 | prob_uncondp: 0.05 # 0.1 0.25
34 | prob_drop_text: 0.05 # 0.1 0.25
35 | prob_drop_motion: 0.05
36 | guidance_scale_text: 2.5 #
37 | guidance_scale_motion: 2.0
38 | noise_schedule: 'squaredcos_cap_v2' # Optional: ['linear', 'scaled_linear', 'squaredcos_cap_v2']
39 | predict_type: 'sample' # noise
40 |
41 | motion_condition: source
42 | source_encoder: null # trans_enc
43 | condition: text
44 | smpl_path: ${data.smplh_path}
45 | copy_target: false
46 | nfeats: 135
47 | dim_per_feat: [135]
48 | # data related
49 |
50 | input_feats:
51 | # - "body_transl_delta"
52 | - "body_transl_delta_pelv"
53 | # - "body_transl_delta_pelv_xy_wo_z"
54 | # - "body_transl_z"
55 | # - "body_transl"
56 | - "body_orient_xy"
57 | - "z_orient_delta"
58 | - "body_pose"
59 | - "body_joints_local_wo_z_rot"
60 |
61 | pad_inputs: false
62 |
63 | loss_func_pos: mse # l1 mse
64 | loss_func_feats: mse # l1 mse
65 |
66 | defaults:
67 | # diffusion stuff
68 | - _self_
69 | - /path@path
70 | - train_scheduler: ddpm
71 | - infer_scheduler: ddpm
72 | - denoiser: ditdenoiser_cls_arch # ditdenoiser
73 | - text_encoder: clipenc_sim #clipenc #t5_enc # distilbert_enc # t5_enc
74 | - motion_condition_encoder: actor
75 | - losses: basic
76 | - optim: adamw
77 | - /model/losses/function/recons@func_recons
78 | - /model/losses/function/recons@func_latent
79 | - /model/losses/function/kl@func_kl
80 |
--------------------------------------------------------------------------------
/configs/model/denoiser/denoiser.yaml:
--------------------------------------------------------------------------------
1 | _target_: src.model.tmed_denoiser.TMED_denoiser
2 | text_encoded_dim: 768 # or 512 patch-14-large or base
3 | ff_size: 1024
4 | num_layers: 8
5 | num_heads: 4
6 | dropout: 0.1
7 | activation: 'gelu'
8 | condition: ${model.condition}
9 | motion_condition: ${model.motion_condition}
10 | latent_dim: ${model.latent_dim}
11 | nfeats: ${model.nfeats} # TODO FIX THIS
12 | use_sep: true
13 | pred_delta_motion: false
14 |
--------------------------------------------------------------------------------
/configs/model/denoiser/ditdenoiser_cls.yaml:
--------------------------------------------------------------------------------
1 | _target_: src.model.DiT_denoiser_cls.DiT_Denoiser_CLS
2 | text_encoded_dim: 768 # or 512 patch-14-large or base
3 | ff_size: 1024
4 | num_layers: 8
5 | num_heads: 8
6 | dropout: 0.1
7 | activation: 'gelu'
8 | condition: ${model.condition}
9 | motion_condition: ${model.motion_condition}
10 | latent_dim: ${model.latent_dim}
11 | nfeats: ${model.nfeats}
12 | use_sep: true
13 | pred_delta_motion: false
14 | repr_dim: 256
15 | n_cls: 3
16 | encoder_layers: 4 # 4?
--------------------------------------------------------------------------------
/configs/model/denoiser/ditdenoiser_cls_arch.yaml:
--------------------------------------------------------------------------------
1 | _target_: src.model.DiT_denoiser_cls_arch.DiT_Denoiser_CLS_Arch
2 | text_encoded_dim: 768 # or 512 patch-14-large or base
3 | ff_size: 1024
4 | num_layers: 8
5 | num_heads: 8
6 | dropout: 0.1
7 | activation: 'gelu'
8 | condition: ${model.condition}
9 | motion_condition: ${model.motion_condition}
10 | latent_dim: ${model.latent_dim}
11 | nfeats: ${model.nfeats}
12 | use_sep: true
13 | pred_delta_motion: false
14 | repr_dim: 256
15 | n_cls: 3
16 | encoder_layers: 4
17 | ablation: no_text # 'raw_text', 'no_text', 'raw_source', 'raw_text_and_source'
--------------------------------------------------------------------------------
/configs/model/infer_scheduler/ddim.yaml:
--------------------------------------------------------------------------------
1 | _target_: diffusers.DDIMScheduler
2 | num_train_timesteps: 1000
3 | beta_start: 0.00085
4 | beta_end: 0.012
5 | beta_schedule: 'scaled_linear' # Optional: ['linear', 'scaled_linear', 'squaredcos_cap_v2']
6 | # variance_type: 'fixed_small'
7 | clip_sample: false # clip sample to -1~1
8 | prediction_type: sample
9 | # below are for ddim
10 | set_alpha_to_one: false
11 | steps_offset: 1
--------------------------------------------------------------------------------
/configs/model/infer_scheduler/ddpm.yaml:
--------------------------------------------------------------------------------
1 | _target_: diffusers.DDPMScheduler
2 | num_train_timesteps: 1000
3 | beta_start: 0.00085
4 | beta_end: 0.012
5 | beta_schedule: 'squaredcos_cap_v2' # Optional: ['linear', 'scaled_linear', 'squaredcos_cap_v2']
6 | variance_type: 'fixed_small'
7 | clip_sample: false # clip sample to -1~1
8 | prediction_type: sample
9 |
--------------------------------------------------------------------------------
/configs/model/losses/basic.yaml:
--------------------------------------------------------------------------------
1 | _target_: src.model.losses.MLDLosses
2 |
3 | # Loss terms
4 | ## Reconstruction losses
5 | lmd_rfeats_recons: 1.0
6 | lmd_jfeats_recons: 1.0
7 | predict_epsilon: false
8 |
9 | modelname: ${model.modelname}
10 |
11 | ## Latent sinc. losses
12 | lmd_latent: 1e-5
13 | lmd_kl: 1e-5
14 | lmd_prior: 0.0
15 | lmd_recons: 1.0
16 | lmd_gen: 1.0
17 | fuse: 'concat' # 'add' null
18 | # Ablations
19 | loss_on_both: true
20 | # loss on joint position features
21 | loss_on_jfeats: false
22 | ablation_no_kl_combine: false
23 | ablation_no_kl_gaussian: false
24 | ablation_no_motionencoder: false
25 |
26 | # # Text => rfeats (rotation features)
27 | # recons_text2rfeats: ${.lmd_rfeats_recons}
28 | # recons_text2rfeats_func: ${model.func_recons}
29 |
30 | # # Text => jfeats (xyz features)
31 | # recons_text2jfeats: ${.lmd_jfeats_recons}
32 | # recons_text2jfeats_func: ${model.func_recons}
33 |
34 | # # rfeats => rfeats
35 | # recons_rfeats2rfeats: ${.lmd_rfeats_recons}
36 | # recons_rfeats2rfeats_func: ${model.func_recons}
37 |
38 | # # vts => vts
39 | # recons_vertex2vertex: ${.lmd_rfeats_recons}
40 | # recons_vertex2vertex_func: ${model.func_recons}
41 |
42 | # # jfeats => jfeats
43 | # recons_jfeats2jfeats: ${.lmd_jfeats_recons}
44 | # recons_jfeats2jfeats_func: ${model.func_recons}
45 |
46 | # # Latent sinc.losses
47 | # latent_manifold: ${.lmd_latent}
48 | # latent_manifold_func: ${model.func_latent}
49 |
50 | # # VAE losses
51 | # kl_text: ${.lmd_kl}
52 | # kl_text_func: ${model.func_kl}
53 |
54 | # kl_motion: ${.lmd_kl}
55 | # kl_motion_func: ${model.func_kl}
56 |
57 | # kl_text2motion: ${.lmd_kl}
58 | # kl_text2motion_func: ${model.func_kl}
59 |
60 | # kl_motion2text: ${.lmd_kl}
61 | # kl_motion2text_func: ${model.func_kl}
62 |
63 |
--------------------------------------------------------------------------------
/configs/model/losses/function/kl.yaml:
--------------------------------------------------------------------------------
1 | _target_: src.model.losses.KLLoss
--------------------------------------------------------------------------------
/configs/model/losses/function/klmulti.yaml:
--------------------------------------------------------------------------------
1 | _target_: src.model.losses.KLLossMulti
--------------------------------------------------------------------------------
/configs/model/losses/function/mse.yaml:
--------------------------------------------------------------------------------
1 | _target_: torch.nn.MSELoss
2 | reduction: mean
3 |
--------------------------------------------------------------------------------
/configs/model/losses/function/recons.yaml:
--------------------------------------------------------------------------------
1 | _target_: src.model.losses.recons.Recons
--------------------------------------------------------------------------------
/configs/model/losses/function/recons_bp.yaml:
--------------------------------------------------------------------------------
1 | _target_: src.model.losses.recons_bp.ReconsBP
2 |
--------------------------------------------------------------------------------
/configs/model/losses/function/smoothL1.yaml:
--------------------------------------------------------------------------------
1 | _target_: torch.nn.SmoothL1Loss
2 | reduction: mean
--------------------------------------------------------------------------------
/configs/model/motion_condition_encoder/actor.yaml:
--------------------------------------------------------------------------------
1 | name: actor_encoder
2 | _target_: src.model.motionencoder.ActorAgnosticEncoder
3 |
4 | latent_dim: ${model.latent_dim}
5 |
6 | ff_size: ${model.ff_size}
7 | num_layers: ${model.num_layers}
8 | num_head: ${model.num_head}
9 | droupout: ${model.droupout}
10 | activation: ${model.activation}
11 | nfeats: ${model.nfeats} # TODO FIX THIS
--------------------------------------------------------------------------------
/configs/model/optim/adamw.yaml:
--------------------------------------------------------------------------------
1 | _target_: torch.optim.AdamW
2 | lr: 1e-4
3 | lr_final: 1e-6
4 | t_warmup: 150
5 |
--------------------------------------------------------------------------------
/configs/model/text_encoder/clipenc.yaml:
--------------------------------------------------------------------------------
1 | _target_: src.model.textencoder.clip_encoder.ClipTextEncoder
2 | finetune: false # if false, model weights are frozen
3 | # true --> 77x768 | false --> 1x768
4 | last_hidden_state: true # if true, the last hidden state is used as the text embedding
5 | # clip-vit-base-patch32 | clip-vit-large-patch14
6 | modelpath: ${path.deps}/clip-vit-large-patch14
7 |
--------------------------------------------------------------------------------
/configs/model/text_encoder/clipenc_sim.yaml:
--------------------------------------------------------------------------------
1 | _target_: src.model.textencoder.clip_encoder.ClipTextEncoder
2 | finetune: false # if false, model weights are frozen
3 | # true --> 77x768 | false --> 1x768
4 | last_hidden_state: false # if true, the last hidden state is used as the text embedding
5 | # clip-vit-base-patch32 | clip-vit-large-patch14
6 | modelpath: ${path.deps}/clip-vit-large-patch14
7 |
--------------------------------------------------------------------------------
/configs/model/text_encoder/distilbert_enc.yaml:
--------------------------------------------------------------------------------
1 | # name: distilbert_linear_encoder
2 | _target_: src.model.textencoder.distilbert_encoder.DistilbertEncoderTransformer
3 |
4 | latent_dim: ${model.latent_dim}
5 |
6 | ff_size: ${model.ff_size}
7 | num_layers: ${model.num_layers}
8 | num_head: ${model.num_head}
9 | droupout: ${model.droupout}
10 | activation: ${model.activation}
11 |
12 | finetune: false
13 | modelpath: ${path.deps}/distilbert-base-uncased
14 |
--------------------------------------------------------------------------------
/configs/model/text_encoder/t5_enc.yaml:
--------------------------------------------------------------------------------
1 | name: t5_text_encoder
2 | _target_: src.model.textencoder.t5_encoder.T5TextEncoder
3 | finetune: false # if false, model weights are frozen
4 | modelpath: ${path.deps}/flan-t5-base
--------------------------------------------------------------------------------
/configs/model/train_scheduler/ddim.yaml:
--------------------------------------------------------------------------------
1 | _target_: diffusers.DDIMScheduler
2 | num_train_timesteps: 1000
3 | beta_start: 0.00085
4 | beta_end: 0.012
5 | beta_schedule: 'scaled_linear' # Optional: ['linear', 'scaled_linear', 'squaredcos_cap_v2']
6 | # variance_type: 'fixed_small'
7 | clip_sample: false # clip sample to -1~1
8 | prediction_type: sample
9 | # below are for ddim
10 | set_alpha_to_one: false
11 | steps_offset: 1
12 |
--------------------------------------------------------------------------------
/configs/model/train_scheduler/ddpm.yaml:
--------------------------------------------------------------------------------
1 | _target_: diffusers.DDPMScheduler
2 | num_train_timesteps: 1000
3 | beta_start: 0.00085
4 | beta_end: 0.012
5 | beta_schedule: 'squaredcos_cap_v2' # Optional: ['linear', 'scaled_linear', 'squaredcos_cap_v2']
6 | variance_type: 'fixed_small'
7 | clip_sample: false # clip sample to -1~1
8 | prediction_type: sample
9 |
--------------------------------------------------------------------------------
/configs/motionfix_eval.yaml:
--------------------------------------------------------------------------------
1 | hydra:
2 | run:
3 | dir: .
4 | job:
5 | chdir: true
6 | output_subdir: null
7 |
8 | folder: ???
9 | mode: sample # denoise / sample / something else
10 |
11 | prob_way: '3way' # 2way
12 |
13 | savedir: null
14 | mean: false
15 | fact: 1
16 | number_of_samples: 1
17 | ckpt_name: 'last'
18 | last_ckpt_path: ${get_last_checkpoint:${folder},${ckpt_name}}
19 | logger_level: INFO
20 | save_pkl: false
21 | render_vids: true
22 | subset: null
23 |
24 | num_sampling_steps: 1000
25 |
26 | guidance_scale_text_n_motion: null
27 | guidance_scale_motion: null
28 |
29 | init_from: 'noise' # noise
30 | condition_mode: 'full_cond' # 'mot_cond' 'text_cond'
31 | inpaint: false
32 | linear_gd: false
33 |
34 | defaults:
35 | - _self_
36 | - data: motionfix
37 | - /path@path
38 | - override hydra/job_logging: rich # custom
39 | - override hydra/hydra_logging: rich # custom
40 |
41 | split_to_load:
42 | - "test"
43 | - "val"
44 |
--------------------------------------------------------------------------------
/configs/path.yaml:
--------------------------------------------------------------------------------
1 | # path to additional modules
2 | deps: ${code_path:./deps}
3 | data: ${code_path:./data}
4 | minidata: ${code_path:./minidata}
5 | code_dir: ${code_path:}
6 | working_dir: ${working_path:""}
7 | minilog: ${code_path:./miniexperiments}
--------------------------------------------------------------------------------
/configs/sampler/all_conseq.yaml:
--------------------------------------------------------------------------------
1 | _target_: src.data.sampling.FrameSampler
2 |
3 | # For fix size data
4 | request_frames: null
5 | sampling: conseq
6 | sampling_step: 1
7 |
8 | # For data of any size
9 | max_len: 800
10 | min_len: 15
11 |
--------------------------------------------------------------------------------
/configs/sampler/fix_conseq.yaml:
--------------------------------------------------------------------------------
1 | _target_: src.data.sampling.FrameSampler
2 |
3 | # For fix size data
4 | request_frames: 60
5 | threshold_reject: 0.75
6 | sampling: conseq
7 | sampling_step: 1
8 |
9 | # For data of any size
10 | max_len: 10000
11 | min_len: 10
--------------------------------------------------------------------------------
/configs/sampler/upper_bound_variable_conseq.yaml:
--------------------------------------------------------------------------------
1 | _target_: src.data.sampling.FrameSampler
2 |
3 | # For fix size data
4 | request_frames:
5 | sampling: conseq
6 | sampling_step: 1
7 | threshold_reject: 1.0
8 |
9 | # For data of any size
10 | max_len: 500
11 | min_len: 15
--------------------------------------------------------------------------------
/configs/sampler/variable_conseq.yaml:
--------------------------------------------------------------------------------
1 | _target_: src.data.sampling.FrameSampler
2 |
3 | # For fix size data
4 | request_frames: 300
5 | sampling: conseq
6 | sampling_step: 1
7 | threshold_reject: 0.75
8 |
9 | # For data of any size
10 | max_len: 600
11 | min_len: 15
12 |
--------------------------------------------------------------------------------
/configs/stats.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - _self_
3 | - model: motion_prior_mlp_cvae
4 | - scheduler: reduce_on_plateau
5 | - dataloader: amass
6 | model_family: "mlp"
7 | debug: False
8 | variant: "base"
9 | dataset: "amass"
10 | render_train_outputs: True
11 |
12 | # all used directories
13 | project_dir: "." # /data/deps/stats
14 | chkpnt_dir: "data/checkpoints" # experiments
15 | amass_path: "data/amass/amass.pth.tar"
16 | smplx_models_path: "data/body_models/"
17 | # grab_path: "data/amass/GRAB/GRAB.pth.tar"
18 | grab_path: "/ps/project/datasets/GRAB"
19 | joints_path: "/is/cluster/fast/mdiomataris/grab_joints_new"
20 |
21 | # experiment details
22 | wandb_logger:
23 | # _target_: pytorch_lightning.loggers.WandbLogger
24 | resume: 'allow'
25 | id:
26 | group:
27 | tags:
28 | mode: "online"
29 | # project: 'HOI-common-sense'
30 | project: 'Motion_Prior-HPT'
31 | save_dir: "data/wandb_logs"
32 |
33 | rendering:
34 | dir: "data/rendered_outputs"
35 | in_training: True
36 | in_testing: True
37 | every_epochs: 100
38 | train_ids:
39 | - 0
40 | # - 10
41 | - 42
42 | # - 100
43 | - 200
44 | # - 300
45 | # - 345
46 | # - 128
47 | - 333
48 | # - 444
49 | test_ids:
50 | - 1
51 | - 2
52 | - 3
53 | - 5
54 | - 8
55 | - 13
56 | - 17
57 | - 16
58 |
59 | exp_id: None
60 |
61 | # lightning checkpoint callback
62 | trainer:
63 | _target_: pytorch_lightning.Trainer
64 | default_root_dir: ${chkpnt_dir}
65 | max_epochs: 300
66 | accelerator: 'gpu'
67 | devices: 1
68 | fast_dev_run: False
69 | overfit_batches: 0.0
70 | enable_progress_bar: True
71 | auto_scale_batch_size:
72 | accumulate_grad_batches:
73 | gradient_clip_val: 2.0
74 | callbacks:
75 | logger:
76 | resume_from_checkpoint:
77 |
78 | # precision: "bf16"
79 |
80 | batch_size: 32
81 | num_workers: 4
82 | tune: False
83 |
84 | # lightning checkpoint callback
85 | model_checkpoint:
86 | _target_: pytorch_lightning.callbacks.ModelCheckpoint
87 | dirpath:
88 | filename:
89 | monitor: "val/loss_monitor"
90 | save_top_k: 1
91 | mode: "min"
92 | save_last: True
93 | every_n_epochs: 5
94 |
95 | optimizer:
96 | _target_: torch.optim.Adam
97 | lr: 1e-4
98 |
99 | # training parameters
100 | monitor_loss: "loss"
101 |
102 | # preprocessing parameters
103 | preproc:
104 | joints_dump_path: "/is/cluster/fast/mdiomataris/grab_joints_new"
105 | split_seed: 0
106 | calculate_minmax: True
107 | generate_joint_files: True
108 | use_cuda: True
109 |
110 | # dataloading parameters
111 | dl:
112 | framerate_ratio: 4
113 | chunk_duration: 8.0
114 | trim_nointeractions: False
115 | force_recalculate_stats: False
116 |
117 | # augmentation parameters
118 | aug:
119 | undo_interaction_prob: 0.1
120 | out_of_reach_prob: 0.1
121 | min_oor_distance: 0.05
122 | max_oor_distance: 0.3
123 | random_rotate: False
124 | random_rot_type: "3d"
125 | framerate_dev: 0
126 |
127 | loss_type: "mse"
128 | joint_loss: False
129 | n_body_joints: 22
130 |
131 | rot_repr: "6d"
132 | norm_type: "std"
133 | load_feats:
134 | - "body_transl"
135 | - "body_transl_delta"
136 | - "body_transl_delta_pelv"
137 | - "body_transl_delta_pelv_xy"
138 | - "body_transl_z"
139 | - "body_orient"
140 | - "body_pose"
141 | - "body_orient_delta"
142 | - "body_pose_delta"
143 | - "body_orient_xy"
144 | - "body_joints"
145 | # - "body_joints_rel"
146 | # - "body_joints_vel"
147 | # - "object_transl"
148 | # - "object_transl_rel"
149 | # - "object_transl_vel"
150 | # - "object_orient"
151 | # - "obj_contact_bin"
152 | # - "hands_contact_bin"
153 | # - "obj_wrists_dist"
154 | # - "wrist_joints_transl"
155 | # - "wrist_joints_transl_rel"
156 | # - "wrist_joints_vel"
157 | # - "joint_global_oris"
158 | # - "joint_ang_vel"
159 | # - "wrists_ang_vel"
160 | # - "wrists_ang_vel_euler"
161 | # - "active_grasp_frames"
162 | # - "index_tips_vel"
163 |
164 |
165 | feats_dims:
166 | body_transl: 3
167 | body_transl_delta: 3
168 | body_transl_delta_pelv: 3
169 | body_transl_delta_pelv_xy: 3
170 | body_transl_z: 1
171 | body_orient: 6
172 | body_orient_delta: 6
173 | body_orient_xy: 6
174 | body_pose: 126
175 | body_pose_delta: 126
176 | body_joints: 66
177 | body_joints_rel: 66
178 | body_joints_vel: 66
179 | object_transl: 3
180 | object_transl_rel: 3
181 | object_transl_vel: 3
182 | object_orient: 6
183 | obj_contact_bin: 1
184 | obj_wrists_dist: 6
185 | wrist_joints_transl: 6
186 | wrist_joints_transl_rel: 6
187 | wrist_joints_vel: 6
188 | index_tips_vel: 6
189 | joint_ang_vel: 6
190 | wrists_ang_vel: 6
191 | hands_contact_bin: 2
192 |
193 | # rendering:
194 | # fps: 30
195 | # choose_random:
196 | # indices:
197 | # - 36 #-> drinking from bottle
198 | # - 160 #-> headphones over head
199 | # - 269 #-> airplane
200 | # - 272 #-> failing case, scisors
201 | # - 51 #-> good example (easy)
--------------------------------------------------------------------------------
/configs/train.yaml:
--------------------------------------------------------------------------------
1 | hydra:
2 | run:
3 | dir: ${expdir}/${project}/${experiment}/${run_id}/
4 | job:
5 | chdir: true
6 | env_set:
7 | # if you want to use wandb you should assign this key
8 | WANDB_API_KEY: 'YOUR_WANDB_API_KEY'
9 | PYOPENGL_PLATFORM: 'egl'
10 | HYDRA_FULL_ERROR: 1
11 | WANDB__SERVICE_WAIT: 300
12 |
13 | debug: false
14 |
15 | # Global configurations shared between different modules
16 | expdir: ${get_expdir:${debug}}
17 | experiment: ${data.dataname}
18 | # must be the same when you are resuming experiment
19 | project: new_code
20 | seed: 42
21 | logger_level: INFO
22 | run_id: ${generate_id:}
23 | # For finetuning
24 | resume: ${working_path:${expdir}/${project}/${experiment}/${run_id}/}
25 | resume_ckpt_name: 'last'
26 | renderer: null # ait # null
27 |
28 | # log gradients/weights
29 | watch_model: false
30 | log_freq: 1000
31 | log: 'all'
32 |
33 | devices: 1
34 |
35 | # For finetuning
36 | ftune: null #${working_path:""} #/depot/bera89/data/li5280/project/motionfix/experiments/sigga-cr/baseline/baseline
37 | ftune_ckpt_name: 'last' # 'all'
38 | ftune_ckpt_path: ${get_last_checkpoint:${ftune},${ftune_ckpt_name}}
39 |
40 | statistics_file: statistics_${data.dataname}${get_debug:${debug}}.npy
41 | # statistics_file: "statistics_amass_circle.py"
42 | statistics_path: ${path.deps}/stats/${statistics_file}
43 |
44 |
45 | # Composing nested config with default
46 | defaults:
47 | - data: motionfix
48 | - model: basic_clip
49 | - machine: server
50 | - trainer: base
51 | - sampler: variable_conseq # cut it
52 | - logger: tensorboard # wandb
53 | - callback: base
54 | - /path@path
55 | - override hydra/job_logging: rich
56 | - override hydra/hydra_logging: rich
57 | - _self_
58 |
--------------------------------------------------------------------------------
/configs/train_cls.yaml:
--------------------------------------------------------------------------------
1 | hydra:
2 | run:
3 | dir: ${expdir}/${project}/${experiment}/${run_id}/
4 | job:
5 | chdir: true
6 | env_set:
7 | # if you want to use wandb you should assign this key
8 | WANDB_API_KEY: 'YOUR_WANDB_API_KEY'
9 | PYOPENGL_PLATFORM: 'egl'
10 | HYDRA_FULL_ERROR: 1
11 | WANDB__SERVICE_WAIT: 300
12 |
13 | debug: false
14 |
15 | # Global configurations shared between different modules
16 | expdir: ${get_expdir:${debug}}
17 | experiment: ${data.dataname}
18 | # must be the same when you are resuming experiment
19 | project: new_code
20 | seed: 42
21 | logger_level: INFO
22 | run_id: ${generate_id:}
23 | # For finetuning
24 | resume: ${working_path:${expdir}/${project}/${experiment}/${run_id}/}
25 | resume_ckpt_name: 'last'
26 | renderer: null # ait # null
27 |
28 | # log gradients/weights
29 | watch_model: false
30 | log_freq: 1000
31 | log: 'all'
32 |
33 | devices: 1
34 |
35 | # For finetuning
36 | ftune: null #${working_path:""} #/depot/bera89/data/li5280/project/motionfix/experiments/sigga-cr/baseline/baseline
37 | ftune_ckpt_name: 'last' # 'all'
38 | ftune_ckpt_path: ${get_last_checkpoint:${ftune},${ftune_ckpt_name}}
39 |
40 | statistics_file: statistics_${data.dataname}${get_debug:${debug}}.npy
41 | # statistics_file: "statistics_amass_circle.py"
42 | statistics_path: ${path.deps}/stats/${statistics_file}
43 |
44 |
45 | # Composing nested config with default
46 | defaults:
47 | - data: motionfix
48 | - model: basic_clip_cls
49 | - machine: server
50 | - trainer: base-longer
51 | - sampler: variable_conseq # cut it
52 | - logger: tensorboard # wandb
53 | - callback: base
54 | - /path@path
55 | - override hydra/job_logging: rich
56 | - override hydra/hydra_logging: rich
57 | - _self_
58 |
--------------------------------------------------------------------------------
/configs/train_cls_arch.yaml:
--------------------------------------------------------------------------------
1 | hydra:
2 | run:
3 | dir: ${expdir}/${project}/${experiment}/${run_id}/
4 | job:
5 | chdir: true
6 | env_set:
7 | # if you want to use wandb you should assign this key
8 | WANDB_API_KEY: 'YOUR_WANDB_API_KEY'
9 | PYOPENGL_PLATFORM: 'egl'
10 | HYDRA_FULL_ERROR: 1
11 | WANDB__SERVICE_WAIT: 300
12 |
13 | debug: false
14 |
15 | # Global configurations shared between different modules
16 | expdir: ${get_expdir:${debug}}
17 | experiment: ${data.dataname}
18 | # must be the same when you are resuming experiment
19 | project: new_code
20 | seed: 42
21 | logger_level: INFO
22 | run_id: ${generate_id:}
23 | # For finetuning
24 | resume: ${working_path:${expdir}/${project}/${experiment}/${run_id}/}
25 | resume_ckpt_name: 'last'
26 | renderer: null # ait # null
27 |
28 | # log gradients/weights
29 | watch_model: false
30 | log_freq: 1000
31 | log: 'all'
32 |
33 | devices: 1
34 |
35 | # For finetuning
36 | ftune: null #${working_path:""} #/depot/bera89/data/li5280/project/motionfix/experiments/sigga-cr/baseline/baseline
37 | ftune_ckpt_name: 'last' # 'all'
38 | ftune_ckpt_path: ${get_last_checkpoint:${ftune},${ftune_ckpt_name}}
39 |
40 | statistics_file: statistics_${data.dataname}${get_debug:${debug}}.npy
41 | # statistics_file: "statistics_amass_circle.py"
42 | statistics_path: ${path.deps}/stats/${statistics_file}
43 |
44 |
45 | # Composing nested config with default
46 | defaults:
47 | - data: motionfix
48 | - model: basic_clip_cls_arch
49 | - machine: server
50 | - trainer: base-longer
51 | - sampler: variable_conseq # cut it
52 | - logger: tensorboard # wandb
53 | - callback: base
54 | - /path@path
55 | - override hydra/job_logging: rich
56 | - override hydra/hydra_logging: rich
57 | - _self_
58 |
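Note (not part of the file above): these training configs differ only in the model and trainer entries of their defaults lists. A hedged sketch of inspecting the composed config with Hydra's compose API, hypothetical and not part of the repo, assuming it is run from the repo root so config_path resolves to ./configs:

from hydra import compose, initialize
from omegaconf import OmegaConf

# Compose the config the same way the defaults list does, with overrides in the
# same group=option syntax used on the command line.
with initialize(version_base=None, config_path="configs"):
    cfg = compose(config_name="train_cls",
                  overrides=["trainer=base", "logger=wandb"])
# resolve=False because resolvers such as generate_id / get_expdir are registered
# in src/launch/prepare.py and are only available after importing that module.
print(OmegaConf.to_yaml(cfg, resolve=False))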
--------------------------------------------------------------------------------
/configs/trainer/base-longer.yaml:
--------------------------------------------------------------------------------
1 | strategy: null # 'ddp' for multi gpu
2 | benchmark: True
3 | max_epochs: 1501
4 | accelerator: gpu
5 | log_every_n_steps: 40 # 100
6 | deterministic: False
7 | detect_anomaly: False
8 | enable_progress_bar: True
9 | check_val_every_n_epoch: 80 # validate every k epochs (runs at epochs k-1, 2k-1, ... in 0-indexed logs)
10 | limit_train_batches: 1.0
11 | limit_val_batches: 1.0
12 | num_sanity_val_steps: 0 # 2
13 | precision: 32 # => 'bf16' | 16 | 32
14 |
--------------------------------------------------------------------------------
/configs/trainer/base.yaml:
--------------------------------------------------------------------------------
1 | strategy: null # 'ddp' for multi gpu
2 | benchmark: True
3 | max_epochs: 1001
4 | accelerator: gpu
5 | log_every_n_steps: 40 # 100
6 | deterministic: False
7 | detect_anomaly: False
8 | enable_progress_bar: True
9 | check_val_every_n_epoch: 100 # validate every k epochs (runs at epochs k-1, 2k-1, ... in 0-indexed logs)
10 | limit_train_batches: 1.0
11 | limit_val_batches: 1.0
12 | num_sanity_val_steps: 0 # 2
13 | precision: 32 # => 'bf16' | 16 | 32
14 |
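Note (not part of the file above): the keys in both trainer configs mirror pytorch_lightning.Trainer arguments. A minimal sketch of passing such a config through, assumed and not the repo's actual training code; null entries such as strategy are dropped so Lightning keeps its own defaults:

import pytorch_lightning as pl
from omegaconf import OmegaConf

trainer_cfg = OmegaConf.load("configs/trainer/base.yaml")
# drop nulls (e.g. strategy) before handing the rest to the Trainer
kwargs = {k: v for k, v in OmegaConf.to_container(trainer_cfg).items() if v is not None}
trainer = pl.Trainer(**kwargs)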
--------------------------------------------------------------------------------
/deps/inference/qual.txt:
--------------------------------------------------------------------------------
1 | indep-seg+seq/
2 | climb down stairs, hold rail with left hand
3 | spatial_pairs-1008-0.mp4
4 |
5 | wide side step to the right
6 | seg-10094-0.mp4
7 |
8 | walk up stairs
9 | seg-10090-2.mp4
10 |
11 | lift right arm
12 | seg-10072-19.mp4
13 |
14 |
15 | left step left, hold arms as in waltz
16 | spatial_pairs-9918-1.mp4
17 |
18 |
19 | high jump, right hand full high
20 | spatial_pairs-9854-6.mp4
21 | spatial_pairs-9854-5.mp4
22 | spatial_pairs-9854-3.mp4
23 |
24 |
25 | look down, walk
26 | spatial_pairs-10090-0.mp4
27 |
28 |
29 | spatial_pairs-5740-2
30 | spatial_pairs-12301-0
31 |
--------------------------------------------------------------------------------
/deps/smplh/smpl.faces:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lzhyu/SimMotionEdit/384135a8707bc247fba23005d57cca7ca2d751ce/deps/smplh/smpl.faces
--------------------------------------------------------------------------------
/deps/stats/statistics_motionfix.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lzhyu/SimMotionEdit/384135a8707bc247fba23005d57cca7ca2d751ce/deps/stats/statistics_motionfix.npy
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==1.0.0
2 | accelerate==0.23.0
3 | aiohttp==3.8.1
4 | aiosignal==1.2.0
5 | aitviewer==1.12.0
6 | antlr4-python3-runtime==4.9.3
7 | appdirs==1.4.4
8 | asttokens==2.0.5
9 | async-timeout==4.0.2
10 | attrs==21.4.0
11 | backcall==0.2.0
12 | blis==0.7.9
13 | cachetools==5.0.0
14 | catalogue==2.0.8
15 | certifi==2021.10.8
16 | charset-normalizer==2.0.12
17 | click==8.0.4
18 | cmake==3.26.0
19 | colorlog==6.6.0
20 | confection==0.0.4
21 | contourpy==1.1.1
22 | cycler==0.12.1
23 | cymem==2.0.7
24 | decorator==4.4.2
25 | diffusers==0.21.2
26 | docker-pycreds==0.4.0
27 | einops==0.6.1
28 | en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.5.0/en_core_web_sm-3.5.0-py3-none-any.whl#sha256=0964370218b7e1672a30ac50d72cdc6b16f7c867496f1d60925691188f4d2510
29 | executing==0.8.3
30 | filelock==3.6.0
31 | fonttools==4.43.1
32 | freetype-py==2.2.0
33 | frozenlist==1.3.0
34 | fsspec==2024.3.1
35 | future==0.18.2
36 | gitdb==4.0.9
37 | GitPython==3.1.27
38 | glcontext==2.4.0
39 | google-auth==2.6.2
40 | google-auth-oauthlib==0.4.6
41 | grpcio==1.44.0
42 | huggingface-hub==0.17.3
43 | hydra-colorlog==1.2.0
44 | hydra-core==1.3.2
45 | idna==3.3
46 | imageio==2.16.1
47 | imageio-ffmpeg==0.4.5
48 | imgui==2.0.0
49 | importlib-metadata==4.11.3
50 | ipdb==0.13.9
51 | ipython==8.1.1
52 | jedi==0.18.1
53 | Jinja2==3.1.2
54 | joblib==1.3.2
55 | kiwisolver==1.4.5
56 | langcodes==3.3.0
57 | lightning-utilities==0.8.0
58 | lit==16.0.0
59 | loguru==0.6.0
60 | Markdown==3.3.6
61 | markdown-it-py==3.0.0
62 | MarkupSafe==2.1.2
63 | matplotlib==3.8.0
64 | matplotlib-inline==0.1.3
65 | mdurl==0.1.2
66 | moderngl==5.8.2
67 | moderngl-window==2.4.4
68 | more-itertools==9.1.0
69 | moviepy==1.0.3
70 | mpmath==1.3.0
71 | multidict==6.0.2
72 | multipledispatch==1.0.0
73 | murmurhash==1.0.9
74 | networkx==2.7.1
75 | numpy==1.22.3
76 | oauthlib==3.2.0
77 | omegaconf==2.3.0
78 | opencv-contrib-python-headless==4.8.0.76
79 | opencv-python==4.5.5.64
80 | orjson==3.9.15
81 | packaging==21.3
82 | pandas==1.4.1
83 | parso==0.8.3
84 | pathtools==0.1.2
85 | pathy==0.10.1
86 | pexpect==4.8.0
87 | pickleshare==0.7.5
88 | Pillow==9.0.1
89 | preshed==3.0.8
90 | proglog==0.1.9
91 | promise==2.3
92 | prompt-toolkit==3.0.28
93 | protobuf==3.19.4
94 | psutil==5.9.0
95 | ptyprocess==0.7.0
96 | pudb==2022.1.1
97 | pure-eval==0.2.2
98 | pyasn1==0.4.8
99 | pyasn1-modules==0.2.8
100 | pydantic==1.10.2
101 | pyDeprecate==0.3.1
102 | pyglet==2.0.9
103 | Pygments==2.16.1
104 | PyOpenGL==3.1.0
105 | pyparsing==3.0.7
106 | PyQt5==5.15.9
107 | PyQt5-Qt5==5.15.2
108 | PyQt5-sip==12.12.2
109 | pyrender==0.1.45
110 | pyrr==0.10.3
111 | python-dateutil==2.8.2
112 | pytz==2021.3
113 | PyWavelets==1.4.1
114 | PyYAML==6.0
115 | regex==2022.3.15
116 | requests==2.27.1
117 | requests-oauthlib==1.3.1
118 | rich==13.5.3
119 | roma==1.4.0
120 | rsa==4.8
121 | sacremoses==0.0.49
122 | safetensors==0.3.3
123 | scikit-image==0.19.3
124 | scikit-video==1.1.11
125 | scipy==1.7.2
126 | sentry-sdk==1.5.8
127 | setproctitle==1.3.2
128 | shortuuid==1.0.8
129 | six==1.16.0
130 | smart-open==6.3.0
131 | smmap==5.0.0
132 | smplx==0.1.28
133 | spacy==3.5.1
134 | spacy-legacy==3.0.12
135 | spacy-loggers==1.0.4
136 | srsly==2.4.6
137 | stack-data==0.2.0
138 | sympy==1.11.1
139 | tensorboard==2.10.0
140 | tensorboard-data-server==0.6.1
141 | tensorboard-plugin-wit==1.8.1
142 | termcolor==1.1.0
143 | thinc==8.1.9
144 | tifffile==2023.9.18
145 | tokenizers==0.11.6
146 | toml==0.10.2
147 | tqdm==4.63.0
148 | traitlets==5.1.1
149 | transformers==4.17.0
150 | trimesh==3.10.5
151 | triton==2.1.0
152 | typer==0.7.0
153 | typing_extensions==4.11.0
154 | urllib3==1.26.9
155 | urwid==2.1.2
156 | urwid-readline==0.13
157 | usd-core==23.8
158 | uuid==1.30
159 | wandb==0.16.6
160 | wasabi==1.1.1
161 | wcwidth==0.2.5
162 | websockets==11.0.3
163 | Werkzeug==2.0.3
164 | yarl==1.7.2
165 | yaspin==2.1.0
166 | zipp==3.7.0
167 |
--------------------------------------------------------------------------------
/scripts/download_data.sh:
--------------------------------------------------------------------------------
1 | pip install gdown
2 | mkdir data
3 | mkdir data/body_models
4 | # dataset
5 | gdown --folder "https://drive.google.com/drive/folders/1DM7oIJwxwoljVxAfhfktocTptwVX5sqR?usp=sharing"
6 | mv motionfix-dataset data/
7 | # tmr folder
8 | gdown --folder "https://drive.google.com/drive/folders/15LHeriOCjmh4Cp5H9M94xoFGBN0amxI8?usp=sharing"
9 | mv tmr-evaluator eval-deps
10 | # smpl models
11 | gdown --folder "https://drive.google.com/drive/folders/1s3re2I1OzBimQIpudUEFB1hClFWOPJjC?usp=drive_link"
12 | mv smplh data/body_models
13 | # tmed checkpoints
14 | gdown --folder "https://drive.google.com/drive/folders/1M_i_zUSlktdEKf-xBF9g6y7N-lfDtuPD?usp=sharing"
15 | mkdir experiments
16 | mv tmed experiments/
17 | mkdir experiments/tmed/checkpoints
18 | mv experiments/tmed/last.ckpt experiments/tmed/checkpoints/
19 |
--------------------------------------------------------------------------------
/scripts/install.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | echo "Creating virtual environment"
4 | python3.10 -m venv mfix-env
5 | echo "Activating virtual environment"
6 |
7 | source $PWD/mfix-env/bin/activate
8 |
9 | $PWD/mfix-env/bin/pip install --upgrade pip setuptools
10 |
11 | $PWD/mfix-env/bin/pip install "torch==2.1.2" "torchvision==0.16.2" --index-url https://download.pytorch.org/whl/cu118
12 |
13 | $PWD/mfix-env/bin/pip install "pytorch-lightning==2.2.4"
14 | $PWD/mfix-env/bin/pip install -r requirements.txt
15 |
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lzhyu/SimMotionEdit/384135a8707bc247fba23005d57cca7ca2d751ce/src/__init__.py
--------------------------------------------------------------------------------
/src/callback/__init__.py:
--------------------------------------------------------------------------------
1 | from .progress import ProgressLogger
2 | from .render import RenderCallback
--------------------------------------------------------------------------------
/src/callback/progress.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | from pytorch_lightning import LightningModule, Trainer
4 | from pytorch_lightning.callbacks import Callback
5 | import psutil
6 |
7 | logger = logging.getLogger(__name__)
8 |
9 |
10 | class ProgressLogger(Callback):
11 | def __init__(self,
12 | metric_monitor: dict,
13 | precision: int = 3):
14 | # Metric to monitor
15 | self.metric_monitor = metric_monitor
16 | self.precision = precision
17 |
18 | def on_train_start(self, trainer: Trainer, pl_module: LightningModule, **kwargs) -> None:
19 | logger.info("Training started")
20 |
21 | def on_train_end(self, trainer: Trainer, pl_module: LightningModule, **kwargs) -> None:
22 | logger.info("Training done")
23 |
24 | def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule, **kwargs) -> None:
25 | if trainer.sanity_checking:
26 | logger.info("Sanity checking ok.")
27 |
28 | def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule, **kwargs) -> None:
29 | metric_format = f"{{:.{self.precision}e}}"
30 | line = f"Epoch {trainer.current_epoch}"
31 | line = f"{line:>{len('Epoch xxxx')}}" # right-align (pad on the left)
32 | metrics_str = []
33 |
34 | losses_dict = trainer.callback_metrics
35 | for metric_name, dico_name in self.metric_monitor.items():
36 | if dico_name not in losses_dict:
37 | dico_name = f"losses/{dico_name}"
38 |
39 | if dico_name in losses_dict:
40 | metric = losses_dict[dico_name].item()
41 | metric = metric_format.format(metric)
42 | metric = f"{metric_name} {metric}"
43 | metrics_str.append(metric)
44 |
45 | if len(metrics_str) == 0:
46 | return
47 |
48 | memory = f"Memory {psutil.virtual_memory().percent}%"
49 | line = line + ": " + " ".join(metrics_str) + " " + memory
50 | logger.info(line)
51 |
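Usage sketch (not part of the file above): metric_monitor maps display names to keys found in trainer.callback_metrics, with a "losses/" prefix tried as a fallback. The metric keys below are hypothetical placeholders:

progress_cb = ProgressLogger(metric_monitor={"loss": "total", "recons": "recons"})
# trainer = pl.Trainer(callbacks=[progress_cb, ...])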
--------------------------------------------------------------------------------
/src/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lzhyu/SimMotionEdit/384135a8707bc247fba23005d57cca7ca2d751ce/src/data/__init__.py
--------------------------------------------------------------------------------
/src/data/base.py:
--------------------------------------------------------------------------------
1 | import pytorch_lightning as pl
2 | from torch.utils.data import DataLoader
3 |
4 | from src.data.tools.collate import collate_batch_last_padding, collate_datastruct_and_text
5 | import torch
6 | from typing import List
7 |
8 | class BASEDataModule(pl.LightningDataModule):
9 | def __init__(self,
10 | batch_size: int,
11 | num_workers: int,
12 | load_feats: List[str],
13 | batch_sampler: str | None = None,
14 | dataset_percentages: dict[str, float] | None = None):
15 | super().__init__()
16 |
17 | collate_fn = lambda b: collate_batch_last_padding(b, load_feats)
18 |
19 | def set_worker_sharing_strategy(worker_id: int) -> None:
20 | sharing_strategy = "file_system"
21 | torch.multiprocessing.set_sharing_strategy(sharing_strategy)
22 | self.dataloader_options = {
23 | 'batch_size': batch_size,
24 | 'num_workers': num_workers,
25 | 'collate_fn': collate_fn,
26 | 'drop_last': False,
27 | 'worker_init_fn': set_worker_sharing_strategy
28 | # 'pin_memory': True,
29 | }
30 | self.batch_sampler = batch_sampler
31 | self.ds_perc = dataset_percentages
32 | self.batch_size = batch_size
33 | # need to be overloaded:
34 | # - self.Dataset
35 | # - self._sample_set => load only a small subset
36 | # There is a helper below (get_sample_set)
37 | # - self.nfeats
38 | # - self.transforms
39 | self._train_dataset = None
40 | self._val_dataset = None
41 | self._test_dataset = None
42 |
43 | # Optional
44 | self._subset_dataset = None
45 |
46 | def get_sample_set(self, overrides={}):
47 | sample_params = self.hparams.copy()
48 | sample_params.update(overrides)
49 | return self.Dataset(**sample_params)
50 |
51 | def train_dataloader(self):
52 | if self.batch_sampler is not None:
53 | from src.data.sampling.custom_batch_sampler import PercBatchSampler, CustomBatchSampler, CustomBatchSamplerV2, CustomBatchSamplerV4
54 | from src.data.sampling.custom_batch_sampler import mix_datasets_anysize
55 | # ratio_batch_sampler = CustomBatchSamplerV2(concat_dataset=self.dataset['train'],
56 | # batch_size=self.batch_size)
57 | ratio_batch_sampler = CustomBatchSamplerV4(concat_dataset=self.dataset['train'],
58 | batch_size=self.batch_size,
59 | mix_percentages=self.ds_perc)
60 |
61 | # ratio_batch_sampler = PercBatchSampler(data_source=self.dataset['train'],
62 | # batch_size=self.batch_size)
63 | # dataset_percentages=self.ds_perc)
64 | del self.dataloader_options['batch_size']
65 | return DataLoader(self.dataset['train'],
66 | batch_sampler=ratio_batch_sampler,
67 | **self.dataloader_options)
68 | else:
69 | return DataLoader(self.dataset['train'],
70 | shuffle=True,
71 | **self.dataloader_options)
72 |
73 | def val_dataloader(self):
74 | if self.batch_sampler is not None:
75 | return DataLoader(self.dataset['test'],
76 | #batch_sampler=ratio_batch_sampler,
77 | shuffle=False,
78 | **self.dataloader_options)
79 | else:
80 | return DataLoader(self.dataset['test'],
81 | shuffle=False,
82 | **self.dataloader_options)
83 |
84 | def test_dataloader(self):
85 | return DataLoader(self.dataset['test'],
86 | shuffle=False,
87 | **self.dataloader_options)
88 |
89 | # def train_dataloader(self):
90 | # return DataLoader(self.train_dataset,
91 | # shuffle=True,
92 | # **self.dataloader_options)
93 |
94 | # def predict_dataloader(self):
95 | # return DataLoader(self.train_dataset,
96 | # shuffle=False,
97 | # **self.dataloader_options)
98 |
99 | # def val_dataloader(self):
100 | # return DataLoader(self.val_dataset,
101 | # shuffle=False,
102 | # **self.dataloader_options)
103 |
104 | # def test_dataloader(self):
105 | # return DataLoader(self.test_dataset,
106 | # shuffle=False,
107 | # **self.dataloader_options)
108 |
109 | # def subset_dataloader(self):
110 | # return DataLoader(self.subset_dataset,
111 | # shuffle=False,
112 | # **self.dataloader_options)
113 |
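Sketch (not part of the file above): the comment block in __init__ lists what a subclass must provide. A minimal illustration of that contract; the class name and placeholder values are illustrative, not the repo's actual MotionFix datamodule:

class ToyDataModule(BASEDataModule):
    def __init__(self, batch_size: int, num_workers: int, load_feats):
        super().__init__(batch_size=batch_size, num_workers=num_workers,
                         load_feats=load_feats)
        self.Dataset = ...       # dataset class used by get_sample_set
        self.nfeats = ...        # feature dimensionality
        self.transforms = ...    # normalization / feature transforms
        self.dataset = {}        # filled with 'train' / 'test' splits in setup()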
--------------------------------------------------------------------------------
/src/data/features.py:
--------------------------------------------------------------------------------
1 | from src.tools.transforms3d import change_for, transform_body_pose
2 | from src.utils.genutils import to_tensor
3 | import torch
4 |
5 | def _get_body_pose(data):
6 | """get body pose"""
7 | # default is axis-angle representation: Frames x (Jx3) (J=21)
8 | if not torch.is_tensor(data):
9 | pose = to_tensor(data['rots'][..., 3:3 + 21*3]) # drop pelvis orientation
10 | else:
11 | pose = to_tensor(data[..., 3:3 + 21*3])
12 | pose = transform_body_pose(pose, f"aa->6d")
13 | return pose
14 |
15 | def _get_body_transl(data):
16 | """get body pelvis translation"""
17 | if not torch.is_tensor(data):
18 | tran = data['trans']
19 | else:
20 | tran = data
21 | return to_tensor(tran)
22 |
23 | def _get_body_orient(data):
24 | """get body global orientation"""
25 | # default is axis-angle representation
26 | if not torch.is_tensor(data):
27 | pelvis_orient = to_tensor(data['rots'][..., :3])
28 | else:
29 | pelvis_orient = data
30 | # axis-angle to rotation matrix & drop last row
31 | pelvis_orient = transform_body_pose(pelvis_orient, "aa->6d")
32 | return pelvis_orient
33 |
34 | def _get_body_transl_delta_pelv(data):
35 | """
36 | get body pelvis translation delta relative to the pelvis coordinate frame
37 | v_i = t_i - t_{i-1} relative to R_{i-1}
38 | """
39 | trans = to_tensor(data['trans'])
40 | trans_vel = trans - trans.roll(1, 0) # shift one right and subtract
41 | pelvis_orient = transform_body_pose(to_tensor(data['rots'][..., :3]), "aa->rot")
42 | trans_vel_pelv = change_for(trans_vel, pelvis_orient.roll(1, 0))
43 | trans_vel_pelv[0] = 0 # zero out velocity of first frame
44 | return trans_vel_pelv
45 |
46 | def _get_body_transl_delta_pelv_infer(pelvis_orient, trans):
47 | """
48 | get body pelvis translation delta relative to the pelvis coordinate frame
49 | v_i = t_i - t_{i-1} relative to R_{i-1}
50 | """
51 | trans = to_tensor(trans)
52 | trans_vel = trans - trans.roll(1, 0) # shift one right and subtract
53 | pelvis_orient = transform_body_pose(to_tensor(pelvis_orient), "6d->rot")
54 | trans_vel_pelv = change_for(trans_vel, pelvis_orient.roll(1, 0))
55 | trans_vel_pelv[0] = 0 # zero out velocity of first frame
56 | return trans_vel_pelv
57 |
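Note (not part of the file above): in both helpers the pelvis-frame velocity matches the docstring, v_0 = 0 and v_i = R_{i-1}^{-1} (t_i - t_{i-1}) for i >= 1; the roll(1, 0) shift is what pairs frame i with the previous frame's orientation and translation before change_for maps the velocity into that frame.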
--------------------------------------------------------------------------------
/src/data/sampling/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import FrameSampler
2 | from .framerate import subsample, upsample
3 |
--------------------------------------------------------------------------------
/src/data/sampling/base.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 | from dataclasses import dataclass
3 |
4 |
5 | @dataclass
6 | class FrameSampler:
7 | sampling: str = "conseq"
8 | sampling_step: int = 1
9 | request_frames: Optional[int] = None
10 | threshold_reject: float = 0.75
11 | max_len: int = 1000
12 | min_len: int = 10
13 |
14 | def __call__(self, num_frames):
15 | from .frames import get_frameix_from_data_index
16 | return get_frameix_from_data_index(num_frames,
17 | self.max_len,
18 | self.request_frames,
19 | self.sampling,
20 | self.sampling_step)
21 |
22 | def accept(self, duration):
23 | # Outputs have original lengths
24 | # Check if it is too long
25 | if self.request_frames is None:
26 | if duration > self.max_len:
27 | return False
28 | if duration < self.min_len:
29 | return False
30 | else:
31 | # Reject the sample if its length is
32 | # too small relative to
33 | # the requested frames
34 |
35 | # min_number = self.threshold_reject * self.request_frames
36 | if duration < self.min_len: # min_number:
37 | return False
38 | return True
39 |
40 | def get(self, key, default=None):
41 | return getattr(self, key, default)
42 |
43 | def __getitem__(self, key):
44 | return getattr(self, key)
45 |
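Usage sketch (not part of the file above), with hypothetical numbers: sample 60 frames from a 200-frame sequence with a fixed step of 2.

sampler = FrameSampler(request_frames=60, sampling="conseq", sampling_step=2)
if sampler.accept(200):
    frame_ix = sampler(200)   # 60 indices into the original sequence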
--------------------------------------------------------------------------------
/src/data/sampling/framerate.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | # TODO: use a real subsampler..
5 | def subsample(num_frames, last_framerate, new_framerate):
6 | step = int(last_framerate / new_framerate)
7 | assert step >= 1
8 | frames = np.arange(0, num_frames, step)
9 | return frames
10 |
11 |
12 | # TODO: use a real upsampler..
13 | def upsample(motion, last_framerate, new_framerate):
14 | step = int(new_framerate / last_framerate)
15 | assert step >= 1
16 |
17 | # Alpha blending => interpolation
18 | alpha = np.linspace(0, 1, step+1)
19 | last = np.einsum("l,...->l...", 1-alpha, motion[:-1])
20 | new = np.einsum("l,...->l...", alpha, motion[1:])
21 |
22 | chunks = (last + new)[:-1]
23 | output = np.concatenate(chunks.swapaxes(1, 0))
24 | # Don't forget the last one
25 | output = np.concatenate((output, motion[[-1]]))
26 | return output
27 |
28 |
29 | if __name__ == "__main__":
30 | motion = np.arange(105)
31 | submotion = motion[subsample(len(motion), 100.0, 12.5)]
32 | newmotion = upsample(submotion, 12.5, 100)
33 |
34 | print(newmotion)
35 |
--------------------------------------------------------------------------------
/src/data/sampling/frames.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import numpy as np
4 | from numpy import ndarray as Array
5 | import random
6 |
7 |
8 | def get_frameix_from_data_index(num_frames: int,
9 | max_len: Optional[int],
10 | request_frames: Optional[int],
11 | sampling: str = "conseq",
12 | sampling_step: int = 1) -> Array:
13 | nframes = num_frames
14 | # do not pad small sequences; sample from long ones
15 | if request_frames is None or request_frames > nframes:
16 | frame_ix = np.arange(nframes)
17 | else:
18 | # sampling goal: input: ----------- 11 nframes
19 | # o--o--o--o- 4 ninputs
20 | #
21 | # step number is computed like that: [(11-1)/(4-1)] = 3
22 | # [---][---][---][-
23 | # So step = 3, and we take 0 to step*ninputs+1 with steps
24 | # [o--][o--][o--][o-]
25 | # then we can randomly shift the vector
26 | # -[o--][o--][o--]o
27 | # If too many frames are required
28 | # Nikos: It never gets here now. Should add a pad flag instead of this.
29 | if request_frames > nframes:
30 | fair = False # True
31 | if fair:
32 | # distills redundancy everywhere
33 | choices = np.random.choice(range(nframes),
34 | request_frames,
35 | replace=True)
36 | frame_ix = sorted(choices)
37 | else:
38 | # adding the last frame until done
39 | # Nikos: do not pad
40 | ntoadd = max(0, request_frames - nframes)
41 | lastframe = nframes - 1
42 | padding = lastframe * np.ones(ntoadd, dtype=int)
43 | frame_ix = np.concatenate((np.arange(0, nframes),
44 | padding))
45 |
46 | elif sampling in ["conseq", "random_conseq"]:
47 | step_max = (nframes - 1) // (request_frames - 1)
48 | if sampling == "conseq":
49 | if sampling_step == -1 or sampling_step * (request_frames - 1) >= nframes:
50 | step = step_max
51 | else:
52 | step = sampling_step
53 | elif sampling == "random_conseq":
54 | step = random.randint(1, step_max)
55 |
56 | lastone = step * (request_frames - 1)
57 | shift_max = nframes - lastone - 1
58 | shift = random.randint(0, max(0, shift_max - 1))
59 | frame_ix = shift + np.arange(0, lastone + 1, step)
60 |
61 | elif sampling == "random":
62 | choices = np.random.choice(range(nframes),
63 | request_frames,
64 | replace=False)
65 | frame_ix = sorted(choices)
66 |
67 | else:
68 | raise ValueError("Sampling not recognized.")
69 |
70 | return frame_ix
71 |
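Worked example (not part of the file above), matching the comment block in the "conseq" branch: with 11 frames and 4 requested, step_max = (11 - 1) // (4 - 1) = 3, there is no slack left to shift, and the chosen indices are [0, 3, 6, 9].

ix = get_frameix_from_data_index(num_frames=11, max_len=None,
                                 request_frames=4, sampling="conseq",
                                 sampling_step=-1)
# ix == array([0, 3, 6, 9])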
--------------------------------------------------------------------------------
/src/data/tools/__init__.py:
--------------------------------------------------------------------------------
1 | from .tensors import lengths_to_mask, lengths_to_mask_njoints
2 | from .collate import collate_text_and_length, collate_pairs_and_text, collate_datastruct_and_text, collate_tensor_with_padding
3 |
--------------------------------------------------------------------------------
/src/data/tools/amass_utils.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | import torch
3 | name_mapping = {'MPI_mosh':'MoSh',
4 | 'ACCAD':'ACCAD',
5 | 'WEIZMANN': 'WEIZMANN',
6 | 'CNRS': 'CNRS',
7 | 'DFaust': 'DFaust',
8 | 'TCDHands': 'TCDHands',
9 | 'SOMA': 'SOMA',
10 | 'CMU': 'CMU',
11 | 'SFU': 'SFU',
12 | 'TotalCapture': 'TotalCapture',
13 | 'HDM05': 'HDM05',
14 | 'Transitions': 'Transitions',
15 | 'MoSh': 'MoSh',
16 | 'KIT': 'KIT',
17 | 'DanceDB': 'DanceDB',
18 | 'Transitions_mocap': 'Transitions',
19 | 'PosePrior': 'PosePrior',
20 | 'MPI_Limits': 'PosePrior',
21 | 'BMLhandball': 'BMLhandball',
22 | 'SSM': 'SSM',
23 | 'TCD_handMocap': 'TCDHands',
24 | 'BMLrub': 'BMLrub',
25 | 'BioMotionLab_NTroje': 'BMLrub',
26 | 'SSM_synced': 'SSM',
27 | 'Eyes_Japan_Dataset': 'Eyes_Japan_Dataset',
28 | 'DFaust_67': 'DFaust',
29 | 'EKUT': 'EKUT',
30 | 'MPI_HDM05': 'HDM05',
31 | 'GRAB': 'GRAB',
32 | 'HumanEva': 'HumanEva',
33 | 'HUMAN4D': 'HUMAN4D',
34 | 'BMLmovi': 'BMLmovi'
35 | }
36 |
37 |
38 | def path_normalizer(paths_list_or_str):
39 | if isinstance(paths_list_or_str, str):
40 | paths_list = [paths_list_or_str]
41 | else:
42 | paths_list = list(paths_list_or_str)
43 | # works only for dir1/dir2/fname.npz to normalize dir1
44 | plist = ['/'.join(p.split('/')[-3:]) for p in paths_list]
45 | norm_path = ['/'.join([name_mapping[p.split('/')[0]], p.split('/')[1], p.split('/')[2]]) for p in plist if p.split('/')[0] in name_mapping.keys()]
46 | if isinstance(paths_list_or_str, str):
47 | return norm_path[0]
48 | else:
49 | return norm_path
50 |
51 | def fname_normalizer(fname):
52 | dataset_name, subject, sequence_name = fname.split('/')
53 | sequence_name = sequence_name.replace('_poses.npz', '')
54 | sequence_name = sequence_name.replace('_poses', '')
55 | sequence_name = sequence_name.replace('poses', '')
56 | sequence_name = sequence_name.replace('_stageii.npz', '')
57 | sequence_name = sequence_name.replace('_stageii', '')
58 | sequence_name = sequence_name.rstrip()
59 | getVals = list([val for val in sequence_name
60 | if val.isalpha() or val.isnumeric()])
61 | sequence_name = ''.join(getVals)
62 |
63 | return '/'.join([dataset_name, subject, sequence_name])
64 |
65 | def flip_motion(pose, trans):
66 | # expects T, Jx3
67 | # Permutation of SMPL pose parameters when flipping the shape
68 | SMPL_JOINTS_FLIP_PERM = [0, 2, 1, 3, 5, 4, 6, 8, 7, 9, 11, 10,
69 | 12, 14, 13, 15, 17, 16, 19,
70 | 18, 21, 20] #, 23, 22]
71 | SMPL_POSE_FLIP_PERM = []
72 | for i in SMPL_JOINTS_FLIP_PERM:
73 | SMPL_POSE_FLIP_PERM.append(3*i)
74 | SMPL_POSE_FLIP_PERM.append(3*i+1)
75 | SMPL_POSE_FLIP_PERM.append(3*i+2)
76 |
77 | flipped_parts = SMPL_POSE_FLIP_PERM
78 | pose = pose[:, flipped_parts]
79 | # we also negate the second and the third dimension of the axis-angle
80 | pose[:, 1::3] = -pose[:, 1::3]
81 | pose[:, 2::3] = -pose[:, 2::3]
82 | x, y, z = torch.unbind(trans, dim=-1)
83 | mirrored_trans = torch.stack((-x, y, z), axis=-1)
84 |
85 | return pose, mirrored_trans
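Sketch (not part of the file above): path_normalizer maps the sub-dataset prefix to its canonical AMASS name and keeps the last three path components. The file path below is hypothetical:

path_normalizer('BioMotionLab_NTroje/rub001/walking_poses.npz')
# -> 'BMLrub/rub001/walking_poses.npz'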
--------------------------------------------------------------------------------
/src/data/tools/contacts.py:
--------------------------------------------------------------------------------
1 |
2 | import numpy as np
3 | from src.info.joints import smplh_joints
4 | jointnames = ['foot', 'small_toe', 'heel', 'big_toe', 'ankle'] # substrings of foot-related joint names
5 | left_foot_joints = [i for i, n in enumerate(smplh_joints) if n.startswith('left') and any(j in n for j in jointnames)]
6 | right_foot_joints = [i for i, n in enumerate(smplh_joints) if n.startswith('right') and any(j in n for j in jointnames)]
7 |
8 | def foot_detect(positions, thres):
9 | """ Get Foot Contacts """
10 | velfactor, heightfactor = np.array([thres, thres]), np.array([3.0, 2.0])
11 |
12 | feet_l_x = (positions[1:, left_foot_joints, 0] - positions[:-1, left_foot_joints, 0]) ** 2
13 | feet_l_y = (positions[1:, left_foot_joints, 1] - positions[:-1, left_foot_joints, 1]) ** 2
14 | feet_l_z = (positions[1:, left_foot_joints, 2] - positions[:-1, left_foot_joints, 2]) ** 2
15 | # feet_l_h = positions[:-1,fid_l,1]
16 | # feet_l = (((feet_l_x + feet_l_y + feet_l_z) < velfactor) & (feet_l_h < heightfactor)).astype(np.float)
17 | feet_l = ((feet_l_x + feet_l_y + feet_l_z) < velfactor).astype(np.float32)
18 |
19 | feet_r_x = (positions[1:, right_foot_joints, 0] - positions[:-1, right_foot_joints, 0]) ** 2
20 | feet_r_y = (positions[1:, right_foot_joints, 1] - positions[:-1, right_foot_joints, 1]) ** 2
21 | feet_r_z = (positions[1:, right_foot_joints, 2] - positions[:-1, right_foot_joints, 2]) ** 2
22 | # feet_r_h = positions[:-1,fid_r,1]
23 | # feet_r = (((feet_r_x + feet_r_y + feet_r_z) < velfactor) & (feet_r_h < heightfactor)).astype(np.float)
24 | feet_r = (((feet_r_x + feet_r_y + feet_r_z) < velfactor)).astype(np.float32)
25 | return feet_l, feet_r
26 |
--------------------------------------------------------------------------------
/src/data/tools/extract_pairs.py:
--------------------------------------------------------------------------------
1 | from src.utils.nlp_consts import fix_spell
2 | from src.data.tools.spatiotempo import temporal_compositions, spatial_compositions
3 | from src.data.tools.spatiotempo import EXCLUDED_ACTIONS, EXCLUDED_ACTIONS_WO_TR
4 |
5 |
6 | def extract_frame_labels_onlytext(babel_labels):
7 | seg_acts = []
8 | # if is_valid:
9 | if babel_labels['frame_ann'] is None:
10 | # 'transl' 'pose''betas'
11 | action_label = babel_labels['seq_ann']['labels'][0]['proc_label']
12 | action_label = fix_spell(action_label)
13 | seg_acts.append(action_label)
14 | else:
15 | for seg_an in babel_labels['frame_ann']['labels']:
16 | action_label = fix_spell(seg_an['proc_label'])
17 | if action_label not in EXCLUDED_ACTIONS:
18 | seg_acts.append(action_label)
19 |
20 | return seg_acts
21 |
22 |
23 | def extract_frame_labels(babel_labels, fps, seqlen, max_simultaneous):
24 |
25 | seg_ids = []
26 | seg_acts = []
27 | # is_valid = True
28 | # possible_frame_dtypes = ['seg', 'pairs', 'separate_pairs', 'spatial_pairs']
29 | # # if 'seq' in datatype and babel_labels['frame_ann'] is not None:
30 | # # is_valid = False
31 | # if bool(set(datatype.split('+')) & set(possible_frame_dtypes)) \
32 | # and babel_labels['frame_ann'] is None:
33 | # is_valid = False
34 |
35 | possible_motions = {}
36 | # if is_valid:
37 | if babel_labels['frame_ann'] is None:
38 | # 'transl' 'pose''betas'
39 | action_label = babel_labels['seq_ann']['labels'][0]['proc_label']
40 | possible_motions['seq'] = [(0, seqlen, fix_spell(action_label))]
41 | else:
42 | # Get segments
43 | # segments_dict = {k: {} for k in range(babel_labels['frame_ann']['labels'])}
44 | seg_list = []
45 |
46 | for seg_an in babel_labels['frame_ann']['labels']:
47 | action_label = fix_spell(seg_an['proc_label'])
48 |
49 | st_f = int(seg_an['start_t'] * fps)
50 | end_f = int(seg_an['end_t'] * fps)
51 |
52 | if end_f > seqlen:
53 | end_f = seqlen
54 | seg_ids.append((st_f, end_f))
55 | seg_acts.append(action_label)
56 |
57 | if action_label not in EXCLUDED_ACTIONS and end_f > st_f:
58 | seg_list.append((st_f, end_f, action_label))
59 |
60 | possible_motions['seg'] = seg_list
61 | spatial = spatial_compositions(seg_list,
62 | actions_up_to=max_simultaneous)
63 | possible_motions['spatial_pairs'] = spatial
64 | possible_motions['separate_pairs'] = temporal_compositions(
65 | seg_ids, seg_acts)
66 |
67 | return possible_motions
68 |
--------------------------------------------------------------------------------
/src/data/tools/rotation_transformation.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import src.tools.geometry as geometry
3 | from einops import rearrange
4 |
5 |
6 | def rotate_trajectory(traj, rotZ, inverse=False):
7 | if inverse:
8 | # transpose
9 | rotZ = rearrange(rotZ, "... i j -> ... j i")
10 |
11 | vel = torch.diff(traj, dim=-2)
12 | # 0 for the first one => keep the dimensionality
13 | vel = torch.cat((0 * vel[..., [0], :], vel), dim=-2)
14 | vel_local = torch.einsum("...kj,...k->...j", rotZ[..., :2, :2], vel[..., :2])
15 | # Integrate the trajectory
16 | traj_local = torch.cumsum(vel_local, dim=-2)
17 | # First frame should be the same as before
18 | traj_local = traj_local - traj_local[..., [0], :] + traj[..., [0], :]
19 | return traj_local
20 |
21 |
22 | def rotate_trans(trans, rotZ, inverse=False):
23 | traj = trans[..., :2]
24 | transZ = trans[..., 2]
25 | traj_local = rotate_trajectory(traj, rotZ, inverse=inverse)
26 | trans_local = torch.cat((traj_local, transZ[..., None]), axis=-1)
27 | return trans_local
28 |
29 |
30 | def rotate_joints2D(joints, rotZ, inverse=False):
31 | if inverse:
32 | # transpose
33 | rotZ = rearrange(rotZ, "... i j -> ... j i")
34 |
35 | assert joints.shape[-1] == 2
36 | vel = torch.diff(joints, dim=-2)
37 | # 0 for the first one => keep the dimensionality
38 | vel = torch.cat((0 * vel[..., [0], :], vel), dim=-2)
39 | vel_local = torch.einsum("...kj,...lk->...lj", rotZ[..., :2, :2], vel[..., :2])
40 | # Integrate the trajectory
41 | joints_local = torch.cumsum(vel_local, dim=-2)
42 | # First frame should be the same as before
43 | joints_local = joints_local - joints_local[..., [0], :] + joints[..., [0], :]
44 | return joints_local
45 |
46 |
47 | def rotate_joints(joints, rotZ, inverse=False):
48 | joints2D = joints[..., :2]
49 | jointsZ = joints[..., 2]
50 | joints2D_local = rotate_joints2D(joints2D, rotZ, inverse=inverse)
51 | joints_local = torch.cat((joints2D_local, jointsZ[..., None]), axis=-1)
52 | return joints_local
53 |
54 |
55 | def canonicalize_rotations(global_orient, trans, angle=0.0):
56 | global_euler = geometry.matrix_to_euler_angles(global_orient, "ZYX")
57 | anglesZ, anglesY, anglesX = torch.unbind(global_euler, -1)
58 |
59 | rotZ = geometry._axis_angle_rotation("Z", anglesZ)
60 |
61 | # remove the current rotation
62 | # make it local
63 | local_trans = rotate_trans(trans, rotZ)
64 |
65 | # For information:
66 | # rotate_joints(joints, rotZ) == joints_local
67 |
68 | diff_mat_rotZ = rotZ[..., 1:, :, :] @ rotZ.transpose(-1, -2)[..., :-1, :, :]
69 |
70 | vel_anglesZ = geometry.matrix_to_axis_angle(diff_mat_rotZ)[..., 2]
71 | # padding "same"
72 | vel_anglesZ = torch.cat((vel_anglesZ[..., [0]], vel_anglesZ), dim=-1)
73 |
74 | # Compute new rotation:
75 | # canonicalized
76 | anglesZ = torch.cumsum(vel_anglesZ, -1)
77 | anglesZ += angle
78 | rotZ = geometry._axis_angle_rotation("Z", anglesZ)
79 |
80 | new_trans = rotate_trans(local_trans, rotZ, inverse=True)
81 |
82 | new_global_euler = torch.stack((anglesZ, anglesY, anglesX), -1)
83 | new_global_orient = geometry.euler_angles_to_matrix(new_global_euler, "ZYX")
84 |
85 | return new_global_orient, new_trans
--------------------------------------------------------------------------------
/src/data/tools/smpl.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 | import src.tools.geometry as geometry
3 | import torch
4 | from torch import Tensor
5 | from src.tools.easyconvert import matrix_to, axis_angle_to
6 | from src.transforms.smpl import RotTransDatastruct
7 |
8 |
9 | def canonicalize_smplh(poses: Tensor, trans: Optional[Tensor] = None):
10 | bs, nframes, njoints = poses.shape[:3]
11 |
12 | global_orient = poses[:, :, 0]
13 |
14 | # first global rotations
15 | rot2d = geometry.matrix_to_axis_angle(global_orient[:, 0])
16 | rot2d[:, :2] = 0 # Remove the rotation along the vertical axis
17 | rot2d = geometry.axis_angle_to_matrix(rot2d)
18 |
19 | # Rotate the global rotation to eliminate Z rotations
20 | global_orient = torch.einsum("ikj,imkl->imjl", rot2d, global_orient)
21 |
22 | # Construct canonicalized version of x
23 | xc = torch.cat((global_orient[:, :, None], poses[:, :, 1:]), dim=2)
24 |
25 | if trans is not None:
26 | vel = trans[:, 1:] - trans[:, :-1]
27 | # Turn the translation as well
28 | vel = torch.einsum("ikj,ilk->ilj", rot2d, vel)
29 | trans = torch.cat((torch.zeros(bs, 1, 3, device=vel.device),
30 | torch.cumsum(vel, 1)), 1)
31 | return xc, trans
32 | else:
33 | return xc
34 |
35 |
36 | def smpl_data_to_matrix_and_trans(data, nohands=True):
37 | trans = data['trans']
38 | nframes = len(trans)
39 | try:
40 | axis_angle_poses = data['poses']
41 | axis_angle_poses = data['poses'].reshape(nframes, -1, 3)
42 | except:
43 | breakpoint()
44 |
45 | if nohands:
46 | axis_angle_poses = axis_angle_poses[:, :22]
47 |
48 | matrix_poses = axis_angle_to("matrix", axis_angle_poses)
49 |
50 | return RotTransDatastruct(rots=matrix_poses, trans=trans)
51 |
--------------------------------------------------------------------------------
/src/data/tools/tensors.py:
--------------------------------------------------------------------------------
1 | from typing import List, Dict
2 | import torch
3 | from torch import Tensor
4 |
5 | def lengths_to_mask_njoints(lengths: List[int], njoints: int, device: torch.device) -> Tensor:
6 | # joints*lengths
7 | joints_lengths = [njoints*l for l in lengths]
8 | joints_mask = lengths_to_mask(joints_lengths, device)
9 | return joints_mask
10 |
11 |
12 | def lengths_to_mask(lengths: List[int], device: torch.device) -> Tensor:
13 | lengths = torch.tensor(lengths, device=device)
14 | max_len = max(lengths)
15 | mask = torch.arange(max_len,
16 | device=device).expand(len(lengths),
17 | max_len) < lengths.unsqueeze(1)
18 | return mask
19 |
20 | from copy import copy
21 | import numpy as np
22 | import torch
23 | from torch import Tensor
24 | import logging
25 | import os
26 | import random
27 | from einops import rearrange
28 |
29 | def read_json(file_path):
30 | import json
31 | with open(file_path, 'r') as json_file:
32 | data = json.load(json_file)
33 | return data
34 |
35 | def freeze(model) -> None:
36 | r"""
37 | Freeze all params for inference.
38 | """
39 | for param in model.parameters():
40 | param.requires_grad = False
41 |
42 | model.eval()
43 |
44 | # A logger for this file
45 | log = logging.getLogger(__name__)
46 | def to_tensor(array):
47 | if torch.is_tensor(array):
48 | return array
49 | else:
50 | return torch.tensor(array)
51 |
52 | def DotDict(in_dict):
53 | if isinstance(in_dict, dotdict):
54 | return in_dict
55 | out_dict = copy(in_dict)
56 | for k,v in out_dict.items():
57 | if isinstance(v,dict):
58 | out_dict[k] = DotDict(v)
59 | return dotdict(out_dict)
60 |
61 |
62 | def dict_to_device(tensor_dict, device):
63 | return {k: v.to(device) for k, v in tensor_dict.items()}
64 |
65 | class dotdict(dict):
66 | """dot.notation access to dictionary attributes"""
67 | __getattr__ = dict.get
68 | __setattr__ = dict.__setitem__
69 | __delattr__ = dict.__delitem__
70 |
71 | def cast_dict_to_tensors(d, device="cpu"):
72 | if isinstance(d, dict):
73 | return {k: cast_dict_to_tensors(v, device) for k, v in d.items()}
74 | elif isinstance(d, np.ndarray):
75 | return torch.from_numpy(d).float().to(device)
76 | elif isinstance(d, torch.Tensor):
77 | return d.to(device)
78 | else:
79 | return d
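Quick check of the padding mask produced by lengths_to_mask (hypothetical lengths; not part of the file above):

import torch
lengths_to_mask([2, 4], device=torch.device("cpu"))
# tensor([[ True,  True, False, False],
#         [ True,  True,  True,  True]])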
--------------------------------------------------------------------------------
/src/data/tools/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | # from matplotlib import collections as mc
3 | # from matplotlib.pyplot import cm
4 | # import matplotlib.pyplot as plt
5 | from typing import Dict, List, Optional, Tuple
6 |
7 | def separate_actions(pair: Tuple[Tuple]):
8 |
9 | if len(pair) == 3:
10 | if pair[0][1] < pair[2][0]:
11 | # a1 10, 15 t 14, 18 a2 17, 25
12 | # a1 10, 15 t 16, 16 a2 17, 25
13 | # transition only --> transition does not matter
14 | final_pair = [(pair[0][0], pair[1][0]),
15 | (pair[1][0] + 1, pair[2][0] - 1),
16 | (pair[2][0], pair[2][1])]
17 | else:
18 | # overlap + transition --> transition does not matter
19 | over = pair[2][0] - pair[0][1]
20 | final_pair = [(pair[0][0], int(pair[0][1] + over/2)),
21 | (int(pair[0][1] + over/2 + 1), pair[2][1])]
22 | else:
23 | # give based on small or long
24 | # p1_prop = (pair[0][1] - pair[0][0]) / (eoq - soq)
25 | # p2_prop = (pair[1][1] - pair[1][0]) / (eoq - soq)
26 | # over = pair[1][0] - pair[0][1]
27 | # final_pair = [(pair[0][0], int(p1_prop*over) + pair[0][1]),
28 | # (int(p1_prop*over) + pair[0][1] + 1, pair[1][1])]
29 |
30 | # no transition at all
31 | over = pair[0][1] - pair[1][0]
32 | final_pair = [(pair[0][0], int(pair[0][1] + over/2)),
33 | (int(pair[0][1] + over/2 + 1), pair[1][1])]
34 |
35 | return final_pair
36 |
37 | def timeline_overlaps(interval: Tuple,
38 | interval_list: List[Tuple]) -> Tuple[List[Tuple],
39 | List[Tuple]]:
40 | '''
41 | Returns the intervals of interval_list that overlap the given interval:
42 | (1) those that start before it and end within it
43 | (2) those that start within it and end after it
44 | Subset/superset overlaps are collected internally but not returned.
45 | '''
46 |
47 | l = interval[0]
48 | r = interval[1]
49 | inter_sub = []
50 | inter_super = []
51 | inter_before = []
52 | inter_after = []
53 | for s in interval_list:
54 |
55 | if (s[0] > l and s[0] > r) or (s[1] < l and s[1] < r):
56 | continue
57 | if s[0] <= l and s[1] >= r:
58 | inter_sub.append(s)
59 | if s[0] >= l and s[1] <= r:
60 | inter_super.append(s)
61 | if s[0] < l and s[1] < r and s[1] >= l:
62 | inter_before.append(s)
63 | if s[0] > l and s[0] <= r and s[1] > r:
64 | inter_after.append(s)
65 |
66 | return inter_before, inter_after
67 |
68 | def segments_sorted(segs_fr: List[List], acts: List) -> Tuple[List[List], List]:
69 |
70 | assert len(segs_fr) == len(acts)
71 | if len(segs_fr) == 1: return segs_fr, acts
72 | L = [ (segs_fr[i],i) for i in range(len(segs_fr)) ]
73 | L.sort()
74 | sorted_segs_fr, permutation = zip(*L)
75 | sort_acts = [acts[i] for i in permutation]
76 |
77 | return list(sorted_segs_fr), sort_acts
78 |
79 |
80 | # def plot_timeline(segments, babel_id, outdir=get_original_cwd(), accel=None):
81 |
82 | # seg_ids = [(s_s, s_e) for s_s, s_e, _ in segments]
83 | # seg_acts = [f'{a}\n{s_s}|---|{s_e}' for s_s, s_e, a in segments]
84 | # seg_lns = [ [(x[0], i*0.01), (x[1], i*0.01)] for i, x in enumerate(seg_ids) ]
85 | # colorline = cm.rainbow(np.linspace(0, 1, len(seg_acts)))
86 | # lc = mc.LineCollection(seg_lns, colors=colorline, linewidths=3,
87 | # label=seg_acts)
88 | # fig, ax = plt.subplots()
89 |
90 | # ax.add_collection(lc)
91 | # fig.tight_layout()
92 | # ax.autoscale()
93 | # ax.margins(0.1)
94 | # # alternative for putting text there
95 | # # from matplotlib.lines import Line2D
96 | # # proxies = [ Line2D([0, 1], [0, 1], color=x) for x in colorline]
97 | # # ax.legend(proxies, seg_acts, fontsize='x-small', loc='upper left')
98 | # for i, a in enumerate(seg_acts):
99 | # plt.text((seg_ids[i][0]+seg_ids[i][1])/2, i*0.01 - 0.002, a,
100 | # fontsize='x-small', ha='center')
101 | # if accel is not None:
102 | # plt.plot(accel)
103 | # plt.title(f'Babel Sequence ID\n{babel_id}')
104 | # plt.savefig(f'{outdir}/plot_{babel_id}.png')
105 | # plt.close()
106 |
--------------------------------------------------------------------------------
/src/diffusion/__init__.py:
--------------------------------------------------------------------------------
1 | # Modified from OpenAI's diffusion repos
2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5 |
6 | from . import gaussian_diffusion as gd
7 | from .respace import SpacedDiffusion, space_timesteps
8 |
9 |
10 | def create_diffusion(
11 | timestep_respacing,
12 | noise_schedule="linear",
13 | use_kl=False,
14 | sigma_small=False,
15 | predict_xstart=False,
16 | learn_sigma=True,
17 | rescale_learned_sigmas=False,
18 | diffusion_steps=1000
19 | ):
20 | betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps)
21 | if use_kl:
22 | loss_type = gd.LossType.RESCALED_KL
23 | elif rescale_learned_sigmas:
24 | loss_type = gd.LossType.RESCALED_MSE
25 | else:
26 | loss_type = gd.LossType.MSE
27 | if timestep_respacing is None or timestep_respacing == "":
28 | timestep_respacing = [diffusion_steps]
29 | return SpacedDiffusion(
30 | use_timesteps=space_timesteps(diffusion_steps, timestep_respacing),
31 | betas=betas,
32 | model_mean_type=(
33 | gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X
34 | ),
35 | model_var_type=(
36 | (
37 | gd.ModelVarType.FIXED_LARGE
38 | if not sigma_small
39 | else gd.ModelVarType.FIXED_SMALL
40 | )
41 | if not learn_sigma
42 | else gd.ModelVarType.LEARNED_RANGE
43 | ),
44 | loss_type=loss_type
45 | # rescale_timesteps=rescale_timesteps,
46 | )
47 |
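Usage sketch (not part of the file above), assuming src is importable from the repo root: an empty respacing string keeps all diffusion_steps for training, while a numeric string respaces the schedule for faster sampling, as in the OpenAI/DiT code this module is modified from.

from src.diffusion import create_diffusion

train_diffusion = create_diffusion(timestep_respacing="")      # full 1000 steps
sample_diffusion = create_diffusion(timestep_respacing="100")  # respaced to 100 steps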
--------------------------------------------------------------------------------
/src/diffusion/diffusion_utils.py:
--------------------------------------------------------------------------------
1 | # Modified from OpenAI's diffusion repos
2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5 |
6 | import torch as th
7 | import numpy as np
8 |
9 |
10 | def normal_kl(mean1, logvar1, mean2, logvar2):
11 | """
12 | Compute the KL divergence between two gaussians.
13 | Shapes are automatically broadcasted, so batches can be compared to
14 | scalars, among other use cases.
15 | """
16 | tensor = None
17 | for obj in (mean1, logvar1, mean2, logvar2):
18 | if isinstance(obj, th.Tensor):
19 | tensor = obj
20 | break
21 | assert tensor is not None, "at least one argument must be a Tensor"
22 |
23 | # Force variances to be Tensors. Broadcasting helps convert scalars to
24 | # Tensors, but it does not work for th.exp().
25 | logvar1, logvar2 = [
26 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
27 | for x in (logvar1, logvar2)
28 | ]
29 |
30 | return 0.5 * (
31 | -1.0
32 | + logvar2
33 | - logvar1
34 | + th.exp(logvar1 - logvar2)
35 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
36 | )
37 |
38 |
39 | def approx_standard_normal_cdf(x):
40 | """
41 | A fast approximation of the cumulative distribution function of the
42 | standard normal.
43 | """
44 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
45 |
46 |
47 | def continuous_gaussian_log_likelihood(x, *, means, log_scales):
48 | """
49 | Compute the log-likelihood of a continuous Gaussian distribution.
50 | :param x: the targets
51 | :param means: the Gaussian mean Tensor.
52 | :param log_scales: the Gaussian log stddev Tensor.
53 | :return: a tensor like x of log probabilities (in nats).
54 | """
55 | centered_x = x - means
56 | inv_stdv = th.exp(-log_scales)
57 | normalized_x = centered_x * inv_stdv
58 | log_probs = th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(normalized_x)
59 | return log_probs
60 |
61 |
62 | def discretized_gaussian_log_likelihood(x, *, means, log_scales):
63 | """
64 | Compute the log-likelihood of a Gaussian distribution discretizing to a
65 | given image.
66 | :param x: the target images. It is assumed that this was uint8 values,
67 | rescaled to the range [-1, 1].
68 | :param means: the Gaussian mean Tensor.
69 | :param log_scales: the Gaussian log stddev Tensor.
70 | :return: a tensor like x of log probabilities (in nats).
71 | """
72 | assert x.shape == means.shape == log_scales.shape
73 | centered_x = x - means
74 | inv_stdv = th.exp(-log_scales)
75 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
76 | cdf_plus = approx_standard_normal_cdf(plus_in)
77 | min_in = inv_stdv * (centered_x - 1.0 / 255.0)
78 | cdf_min = approx_standard_normal_cdf(min_in)
79 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
80 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
81 | cdf_delta = cdf_plus - cdf_min
82 | log_probs = th.where(
83 | x < -0.999,
84 | log_cdf_plus,
85 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
86 | )
87 | assert log_probs.shape == x.shape
88 | return log_probs
89 |
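Note (not part of the file above): normal_kl implements the standard closed form for two Gaussians, written in terms of the log-variances it receives: KL = 0.5 * (logvar2 - logvar1 + exp(logvar1 - logvar2) + (mean1 - mean2)^2 * exp(-logvar2) - 1), which equals log(sigma2/sigma1) + (sigma1^2 + (mean1 - mean2)^2) / (2 * sigma2^2) - 1/2.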
--------------------------------------------------------------------------------
/src/launch/blender.py:
--------------------------------------------------------------------------------
1 | # Fix blender path
2 | import sys
3 | # sys.path.append("/linkhome/rech/genlgm01/uwm78rj/.local/lib/python3.9/site-packages")
4 |
5 | import bpy
6 | import os
7 | DIR = os.path.dirname(bpy.data.filepath)
8 | if DIR not in sys.path:
9 | sys.path.append(DIR)
10 |
11 | from argparse import ArgumentParser
12 |
13 | # Temporary workaround for cluster vs local
14 | # TODO fix it
15 | import socket
16 | if socket.gethostname() == 'ps018':
17 | packages_path = '/home/nathanasiou/.local/lib/python3.10/site-packages'
18 | sys.path.insert(0, packages_path)
19 | sys.path.append("/home/nathanasiou/.venvs/teach/lib/python3.10/site-packages")
20 | sys.path.append('/usr/lib/python3/dist-packages')
21 |
22 | # Monkey patch argparse such that
23 | # blender / python / hydra parsing works
24 | def parse_args(self, args=None, namespace=None):
25 | if args is not None:
26 | return self.parse_args_bak(args=args, namespace=namespace)
27 | try:
28 | idx = sys.argv.index("--")
29 | args = sys.argv[idx+1:] # the list after '--'
30 | except ValueError as e: # '--' not in the list:
31 | args = []
32 | return self.parse_args_bak(args=args, namespace=namespace)
33 |
34 | setattr(ArgumentParser, 'parse_args_bak', ArgumentParser.parse_args)
35 | setattr(ArgumentParser, 'parse_args', parse_args)
36 |
--------------------------------------------------------------------------------
/src/launch/prepare.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | # os.environ['HOME']='/home/nathanasiou'
4 | # sys.path.insert(0,'/usr/lib/python3.10/')
5 | # os.environ['PYTHONPATH']='/home/nathanasiou/.venvs/teach/lib/python3.10/site-packages'
6 | import warnings
7 | from pathlib import Path
8 | from omegaconf import OmegaConf
9 | from src.tools.runid import generate_id
10 | import hydra
11 |
12 | # Local paths
13 | def code_path(path=""):
14 | code_dir = hydra.utils.get_original_cwd()
15 | code_dir = Path(code_dir)
16 | return str(code_dir / path)
17 |
18 | def get_local_debug():
19 | import socket
20 | hostname = socket.gethostname()
21 | if hostname == 'ps018':
22 | local_debug = True
23 | else:
24 | local_debug = False
25 | return local_debug
26 |
27 | def working_path(path):
28 | return str(Path(os.getcwd()) / path)
29 |
30 |
31 | # fix the id for this run
32 | ID = generate_id()
33 | def generate_id():
34 | return ID
35 |
36 | def concat_string_list(l, d1, d2, d3):
37 | """
38 | Concatenate the strings of a list in a sorted order
39 | """
40 | # import ipdb; ipdb.set_trace()
41 | if d1 == 0:
42 | if 'hml3d' in l:
43 | l.remove('hml3d')
44 | if d2 == 0:
45 | if 'motionfix' in l:
46 | l.remove('motionfix')
47 | if d3 == 0:
48 | if 'sinc_synth' in l:
49 | l.remove('sinc_synth')
50 |
51 | return '_'.join(sorted(l))
52 |
53 | def get_last_checkpoint(path, ckpt_name="last"):
54 | if path is None:
55 | return None
56 | output_dir = Path(hydra.utils.to_absolute_path(path))
57 | if ckpt_name != 'last':
58 | last_ckpt_path = output_dir / "checkpoints" / f'latest-epoch={ckpt_name}.ckpt'
59 | else:
60 | last_ckpt_path = output_dir / "checkpoints/last.ckpt"
61 | # check existence
62 | if last_ckpt_path.exists():
63 | return str(last_ckpt_path)
64 | else:
65 | return None
66 |
67 | def get_samples_folder(path):
68 | output_dir = Path(hydra.utils.to_absolute_path(path))
69 | samples_path = output_dir / "samples"
70 | return str(samples_path)
71 |
72 | def get_expdir(debug):
73 | if debug:
74 | return 'experiments'
75 | else:
76 | return 'experiments'
77 |
78 | def get_debug(debug):
79 | if debug:
80 | return '_debug'
81 | else:
82 | return ''
83 |
84 | # this has to run -- pytorch memory leak in the dataloader associated with #973 pytorch issues
85 | #import resource
86 | #rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
87 | #resource.setrlimit(resource.RLIMIT_NOFILE, (12000, rlimit[1]))
88 | # Solutions summarized in --> https://github.com/Project-MONAI/MONAI/issues/701
89 | OmegaConf.register_new_resolver("get_debug", get_debug)
90 | OmegaConf.register_new_resolver("get_expdir", get_expdir)
91 | OmegaConf.register_new_resolver("code_path", code_path)
92 | OmegaConf.register_new_resolver("working_path", working_path)
93 | OmegaConf.register_new_resolver("generate_id", generate_id)
94 | OmegaConf.register_new_resolver("concat_string_list", concat_string_list)
95 | OmegaConf.register_new_resolver("absolute_path", hydra.utils.to_absolute_path)
96 | OmegaConf.register_new_resolver("get_last_checkpoint", get_last_checkpoint)
97 | OmegaConf.register_new_resolver("get_samples_folder", get_samples_folder)
98 |
99 |
100 | # Remove warnings
101 | warnings.filterwarnings(
102 | "ignore", ".*Trying to infer the `batch_size` from an ambiguous collection.*"
103 | )
104 |
105 | warnings.filterwarnings(
106 | "ignore", ".*does not have many workers which may be a bottleneck*"
107 | )
108 |
109 | warnings.filterwarnings(
110 | "ignore", ".*Our suggested max number of worker in current system is*"
111 | )
112 |
--------------------------------------------------------------------------------
/src/logger/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 | from .wandb_log import WandbLogger
4 | # from pytorch_lightning.loggers import WandbLogger
5 | from pytorch_lightning.loggers import TensorBoardLogger
6 | from hydra.utils import to_absolute_path
7 | from omegaconf import DictConfig, OmegaConf
8 | from .tools import cfg_to_flatten_config
9 | import types
10 | import wandb
11 |
12 | def instantiate_logger(cfg: DictConfig):
13 | conf = OmegaConf.to_container(cfg.logger, resolve=True)
14 | name = conf.pop('logger_name')
15 | if name is None:
16 | return False
17 | if name == 'wandb':
18 | project_save_dir = to_absolute_path(Path(cfg.path.working_dir) / conf['save_dir'])
19 | Path(cfg.path.working_dir)
20 | Path(project_save_dir).mkdir(exist_ok=True)
21 | conf['dir'] = project_save_dir
22 | conf['config'] = cfg_to_flatten_config(cfg)
23 | # maybe do this for connection error in cluster, could be redundant
24 |
25 | # conf['settings'] = wandb.Settings(start_method="fork")
26 |
27 | # conf['mode']= 'online' if not cfg.logger.offline else 'offline'
28 | conf['notes']= cfg.logger.notes if cfg.logger.notes is not None else None
29 | conf['tags'] = cfg.logger.tags.strip().split(',')\
30 | if cfg.logger.tags is not None else None
31 | print('init WandbLogger')
32 | logger = WandbLogger(**conf)
33 | print('after init WandbLogger')
34 |
35 | # begin / end already defined
36 |
37 | else:
38 | def begin(self, *args, **kwargs):
39 | return
40 |
41 | def end(self, *args, **kwargs):
42 | return
43 |
44 | if name == 'tensorboard':
45 | # TODO: need to update
46 | # logger = TensorBoardLogger(**conf)
47 | logger = TensorBoardLogger(save_dir = \
48 | to_absolute_path(Path(cfg.path.working_dir) / conf['save_dir']), name='lightning_logs')
49 |
50 | logger.begin = begin
51 | logger.end = end
52 | else:
53 | raise NotImplementedError("This logger is not recognized.")
54 |
55 | logger.lname = name
56 | return logger
57 |
--------------------------------------------------------------------------------
/src/logger/tools.py:
--------------------------------------------------------------------------------
1 | # Taken from pytorch lighting / loggers / base
2 | # Mimic the log hyperparams of wandb
3 | from argparse import Namespace
4 | from typing import Dict, Any, MutableMapping, Union
5 | from omegaconf import DictConfig
6 |
7 | import numpy as np
8 | import torch
9 |
10 |
11 | def _convert_params(params: Union[Dict[str, Any], Namespace]) -> Dict[str, Any]:
12 | # in case converting from namespace
13 | if isinstance(params, Namespace):
14 | params = vars(params)
15 |
16 | if params is None:
17 | params = {}
18 |
19 | return params
20 |
21 |
22 | def _flatten_dict(params: Dict[Any, Any], delimiter: str = "/") -> Dict[str, Any]:
23 | """Flatten hierarchical dict, e.g. ``{'a': {'b': 'c'}} -> {'a/b': 'c'}``.
24 |
25 | Args:
26 | params: Dictionary containing the hyperparameters
27 | delimiter: Delimiter to express the hierarchy. Defaults to ``'/'``.
28 |
29 | Returns:
30 | Flattened dict.
31 |
32 | Examples:
33 | >>> _flatten_dict({'a': {'b': 'c'}})
34 | {'a/b': 'c'}
35 | >>> _flatten_dict({'a': {'b': 123}})
36 | {'a/b': 123}
37 | >>> _flatten_dict({5: {'a': 123}})
38 | {'5/a': 123}
39 | """
40 |
41 | def _dict_generator(input_dict, prefixes=None):
42 | prefixes = prefixes[:] if prefixes else []
43 | if isinstance(input_dict, MutableMapping):
44 | for key, value in input_dict.items():
45 | key = str(key)
46 | if isinstance(value, (MutableMapping, Namespace)):
47 | value = vars(value) if isinstance(value, Namespace) else value
48 | yield from _dict_generator(value, prefixes + [key])
49 | else:
50 | yield prefixes + [key, value if value is not None else str(None)]
51 | else:
52 | yield prefixes + [input_dict if input_dict is None else str(input_dict)]
53 |
54 | return {delimiter.join(keys): val for *keys, val in _dict_generator(params)}
55 |
56 |
57 | def _sanitize_params(params: Dict[str, Any]) -> Dict[str, Any]:
58 | """Returns params with non-primitvies converted to strings for logging.
59 |
60 | >>> params = {"float": 0.3,
61 | ... "int": 1,
62 | ... "string": "abc",
63 | ... "bool": True,
64 | ... "list": [1, 2, 3],
65 | ... "namesinc.: Namesinc.foo=3),
66 | ... "layer": torch.nn.BatchNorm1d}
67 | >>> import pprint
68 | >>> pprint.pprint(_sanitize_params(params)) # doctest: +NORMALIZE_WHITESPACE
69 | {'bool': True,
70 | 'float': 0.3,
71 | 'int': 1,
72 | 'layer': "<class 'torch.nn.modules.batchnorm.BatchNorm1d'>",
73 | 'list': '[1, 2, 3]',
74 | 'namespace': 'Namespace(foo=3)',
75 | 'string': 'abc'}
76 | """
77 | for k in params.keys():
78 | # convert relevant np scalars to python types first (instead of str)
79 | if isinstance(params[k], (np.bool_, np.integer, np.floating)):
80 | params[k] = params[k].item()
81 | elif type(params[k]) not in [bool, int, float, str, torch.Tensor]:
82 | params[k] = str(params[k])
83 | return params
84 |
85 |
86 | def cfg_to_flatten_config(cfg: DictConfig):
87 | params = _convert_params(cfg)
88 | params = _flatten_dict(params)
89 | params = _sanitize_params(params)
90 | return params
91 |
--------------------------------------------------------------------------------
/src/logger/wandb_log.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Optional
2 | from pytorch_lightning.utilities import rank_zero_only
3 | from pytorch_lightning.loggers import WandbLogger as _pl_WandbLogger
4 | import os
5 | from pathlib import Path
6 | import time
7 |
8 | # Fix the step logging
9 | class WandbLogger(_pl_WandbLogger):
10 | # @rank_zero_only
11 | # def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
12 |
13 | # assert rank_zero_only.rank == 0, "experiment tried to log from global_rank != 0"
14 | # metrics = _add_prefix(metrics, self._prefix, self.LOGGER_JOIN_CHAR)
15 |
16 | # # if 'epoch' not in metrics:
17 | # # self.experiment.log({**metrics, "trainer/global_step": step},
18 | # # step=step)
19 | # # else:
20 |
21 | # if 'epoch' not in metrics:
22 | # wandb_step = self.experiment.step
23 | # else:
24 | # wandb_step = int(metrics["epoch"])
25 |
26 | # if step is not None:
27 | # self.experiment.log({**metrics, "trainer/global_step": step},
28 | # step=wandb_step)
29 | # else:
30 | # self.experiment.log(metrics, step=wandb_step)
31 |
32 | @property
33 | def name(self) -> Optional[str]:
34 | """Override this property because model checkpointing defines the path before
35 | initialization, and in offline mode the correct path is not available.
36 | """
37 | # don't create an experiment if we don't have one
38 | # return self._experiment.project_name() if self._experiment else self._name
39 | return self._wandb_init["project"]
40 |
41 | def symlink_checkpoint(self, code_dir, project, run_id):
42 | # this is the hydra run dir!! see train.yaml
43 | local_project_dir = Path("wandb") / project
44 | local_project_dir.mkdir(parents=True, exist_ok=True)
45 |
46 | # ... but code_dir is the current dir see path.yaml
47 | # the wandb run lives under <code_dir>/wandb/<project>/<run_id>
48 | os.symlink(Path(code_dir) / "wandb" / project / run_id,
49 | local_project_dir / f'{run_id}_{time.strftime("%Y%m%d%H%M%S")}')
50 | # Create another symlink for easy access
51 | os.symlink(Path(code_dir) / "wandb" / project / run_id / "checkpoints",
52 | Path("checkpoints"))
53 | # if the target already exists an error is raised, which makes sense, but ...
54 |
55 | def symlink_run(self, checkpoint_folder: str):
56 |
57 | code_dir = checkpoint_folder.split("wandb/")[0]
58 | # # local run
59 | local_wandb = Path("wandb/wandb")
60 | local_wandb.mkdir(parents=True, exist_ok=True)
61 | offline_run = self.experiment.dir.split("wandb/wandb/")[1].split("/files")[0]
62 | # # Create the symlink
63 | os.symlink(Path(code_dir) / "wandb/wandb" / offline_run, local_wandb / offline_run)
64 |
65 | def begin(self, code_dir, project, run_id):
66 | self.symlink_checkpoint(code_dir, project, run_id)
67 |
68 | def end(self, checkpoint_folder):
69 | self.symlink_run(checkpoint_folder)
70 |
--------------------------------------------------------------------------------
/src/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lzhyu/SimMotionEdit/384135a8707bc247fba23005d57cca7ca2d751ce/src/model/__init__.py
--------------------------------------------------------------------------------
/src/model/dummy.py:
--------------------------------------------------------------------------------
1 | import pytorch_lightning as pl
2 | import torch
3 | from torch import nn
4 |
5 | from src.data.tools import PoseData
6 |
7 |
8 | class Dummy(pl.LightningModule):
9 | def __init__(self, *args, **kwargs):
10 | super().__init__()
11 | self.linear = nn.Linear(10, 10)
12 | self.relu = nn.ReLU()
13 | self.store_examples = {"train": None,
14 | "val": None,
15 | "test": None}
16 |
17 | def forward(self, batch: dict) -> PoseData:
18 | return batch["pose_data"]
19 |
20 | def allsplit_step(self, split: str, batch, batch_idx):
21 | joints = batch["pose_data"].joints
22 | if batch_idx == 0:
23 | self.store_examples[split] = {
24 | "text": batch["text"],
25 | "length": batch["length"],
26 | "ref": joints,
27 | "from_text": joints,
28 | "from_motion": joints
29 | }
30 |
31 | x = batch["pose_data"].poses
32 | x = torch.rand((5, 10), device=x.device)
33 | xhat = self.linear(x)
34 |
35 | loss = torch.nn.functional.mse_loss(x, xhat)
36 | return loss
37 |
38 | def training_step(self, batch, batch_idx):
39 | return self.allsplit_step("train", batch, batch_idx)
40 |
41 | def validation_step(self, batch, batch_idx):
42 | self.allsplit_step("val", batch, batch_idx)
43 |
44 | def test_step(self, batch, batch_idx):
45 | self.allsplit_step("test", batch, batch_idx)
46 |
47 | def configure_optimizers(self):
48 | optimizer = torch.optim.AdamW(lr=1e-4, params=self.parameters())
49 | return {"optimizer": optimizer}
--------------------------------------------------------------------------------
/src/model/losses/__init__.py:
--------------------------------------------------------------------------------
1 | from .compute_mld import MLDLosses
--------------------------------------------------------------------------------
/src/model/losses/compute.py:
--------------------------------------------------------------------------------
1 | import hydra
2 | import torch
3 |
4 | from torchmetrics import Metric
5 |
6 |
7 | class TemosComputeLosses(Metric):
8 | def __init__(self, vae: bool,
9 | mode: str,
10 | loss_on_both: bool = False,
11 | motion_branch: bool = False,
12 | loss_on_jfeats: bool = True,
13 | ablation_no_kl_combine: bool = False,
14 | ablation_no_motionencoder: bool = False,
15 | ablation_no_kl_gaussian: bool = False,
16 | dist_sync_on_step=True, **kwargs):
17 | super().__init__(dist_sync_on_step=dist_sync_on_step)
18 |
19 | # Save parameters
20 | self.vae = vae
21 | self.mode = mode
22 | self.motion_branch = motion_branch
23 |
24 | self.loss_on_both = loss_on_both
25 | self.ablation_no_kl_combine = ablation_no_kl_combine
26 | self.ablation_no_kl_gaussian = ablation_no_kl_gaussian
27 | self.ablation_no_motionencoder = ablation_no_motionencoder
28 |
29 | self.loss_on_jfeats = loss_on_jfeats
30 | losses = []
31 | if mode == "xyz" or loss_on_jfeats:
32 | if motion_branch:
33 | losses.append("recons_jfeats2jfeats")
34 | losses.append("recons_text2jfeats")
35 | if mode == "smpl":
36 | if motion_branch:
37 | losses.append("recons_rfeats2rfeats")
38 | losses.append("recons_text2rfeats")
39 | elif mode != "xyz":
40 | raise ValueError("This mode is not recognized.")
41 |
42 | if vae or loss_on_both:
43 | kl_losses = []
44 | if not ablation_no_kl_combine and not ablation_no_motionencoder:
45 | kl_losses.extend(["kl_text2motion", "kl_motion2text"])
46 | if not ablation_no_kl_gaussian:
47 | if not motion_branch:
48 | kl_losses.extend(["kl_text"])
49 | else:
50 | kl_losses.extend(["kl_text", "kl_motion"])
51 | losses.extend(kl_losses)
52 | if not self.vae or loss_on_both:
53 | if motion_branch:
54 | losses.append("latent_manifold")
55 | losses.append("total")
56 |
57 | for loss in losses:
58 | self.add_state(loss, default=torch.tensor(0.0), dist_reduce_fx="sum")
59 | self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum")
60 | self.losses = losses
61 |
62 | # Instantiate loss functions
63 | self._losses_func = {loss: hydra.utils.instantiate(kwargs[loss + "_func"])
64 | for loss in losses if loss != "total"}
65 | # Save the lambda parameters
66 | self._params = {loss: kwargs[loss] for loss in losses if loss != "total"}
67 |
68 | def update(self, ds_text=None, ds_motion=None, ds_ref=None,
69 | lat_text=None, lat_motion=None, dis_text=None,
70 | dis_motion=None, dis_ref=None):
71 | total: float = 0.0
72 |
73 | if self.mode == "xyz" or self.loss_on_jfeats:
74 | if self.motion_branch:
75 | total += self._update_loss("recons_jfeats2jfeats", ds_motion.jfeats, ds_ref.jfeats)
76 | total += self._update_loss("recons_text2jfeats", ds_text.jfeats, ds_ref.jfeats)
77 | if self.mode == "smpl":
78 | if self.motion_branch:
79 | total += self._update_loss("recons_rfeats2rfeats", ds_motion.rfeats, ds_ref.rfeats)
80 | total += self._update_loss("recons_text2rfeats", ds_text.rfeats, ds_ref.rfeats)
81 |
82 | if self.vae or self.loss_on_both:
83 | if not self.ablation_no_kl_combine and self.motion_branch:
84 | total += self._update_loss("kl_text2motion", dis_text, dis_motion)
85 | total += self._update_loss("kl_motion2text", dis_motion, dis_text)
86 | if not self.ablation_no_kl_gaussian:
87 | total += self._update_loss("kl_text", dis_text, dis_ref)
88 | if self.motion_branch:
89 | total += self._update_loss("kl_motion", dis_motion, dis_ref)
90 | if not self.vae or self.loss_on_both:
91 | if self.motion_branch:
92 | total += self._update_loss("latent_manifold", lat_text, lat_motion)
93 |
94 | self.total += total.detach()
95 | self.count += 1
96 |
97 | return total
98 |
99 | def compute(self, split):
100 | count = getattr(self, "count")
101 | return {loss: getattr(self, loss)/count for loss in self.losses}
102 |
103 | def _update_loss(self, loss: str, outputs, inputs):
104 | # Update the loss
105 | val = self._losses_func[loss](outputs, inputs)
106 | getattr(self, loss).__iadd__(val.detach())
107 | # Return a weighted sum
108 | weighted_loss = self._params[loss] * val
109 | return weighted_loss
110 |
111 | def loss2logname(self, loss: str, split: str):
112 | if loss == "total":
113 | log_name = f"losses/{loss}/{split}"
114 | else:
115 | loss_type, name = loss.split("_")
116 | log_name = f"losses/{loss_type}/{name}/{split}"
117 | return log_name
118 |
--------------------------------------------------------------------------------
/src/model/losses/compute_mld.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | from torchmetrics import Metric
5 |
6 |
7 | class MLDLosses(Metric):
8 | """
9 | MLD Loss
10 | """
11 |
12 | def __init__(self, predict_epsilon, lmd_prior,
13 | lmd_kl, lmd_gen, lmd_recons,
14 | **kwargs):
15 | super().__init__(dist_sync_on_step=True)
16 |
17 | # Save parameters
18 | # self.vae = vae
19 | self.predict_epsilon = predict_epsilon
20 | self.lmd_prior = lmd_prior
21 | self.lmd_kl = lmd_kl
22 | self.lmd_gen = lmd_gen
23 | self.lmd_recons = lmd_recons
24 | losses = []
25 |
26 | # diffusion loss
27 | # instance noise loss
28 | losses.append("inst_loss")
29 | losses.append("x_loss")
30 | if lmd_prior != 0.0:
31 | # prior noise loss
32 | losses.append("prior_loss")
33 |
34 | # if self.stage in ['vae', 'vae_diffusion']:
35 | # # reconstruction loss
36 | # losses.append("recons_feature")
37 | # losses.append("recons_verts")
38 | # losses.append("recons_joints")
39 | # losses.append("recons_limb")
40 |
41 | # losses.append("gen_feature")
42 | # losses.append("gen_joints")
43 |
44 | # # KL loss
45 | # losses.append("kl_motion")
46 |
47 |
48 | losses.append("loss")
49 |
50 | for loss in losses:
51 | self.add_state(loss,
52 | default=torch.tensor(0.0),
53 | dist_reduce_fx="sum")
54 | # self.register_buffer(loss, torch.tensor(0.0))
55 | self.add_state("count", torch.tensor(0), dist_reduce_fx="sum")
56 | self.losses = losses
57 |
58 | self._losses_func = {}
59 | self._params = {}
60 | for loss in losses:
61 | if loss.split('_')[0] == 'inst':
62 | self._losses_func[loss] = nn.MSELoss(reduction='mean')
63 | self._params[loss] = 1
64 | elif loss.split('_')[0] == 'x':
65 | self._losses_func[loss] = nn.MSELoss(reduction='mean')
66 | self._params[loss] = 1
67 | elif loss.split('_')[0] == 'prior':
68 | self._losses_func[loss] = nn.MSELoss(reduction='mean')
69 | self._params[loss] = self.lmd_prior
70 | elif loss.split('_')[0] == 'kl':
71 | if self.lmd_kl != 0.0:
72 | self._losses_func[loss] = KLLoss()
73 | self._params[loss] = self.lmd_kl
74 | elif loss.split('_')[0] == 'recons':
75 | self._losses_func[loss] = torch.nn.SmoothL1Loss(
76 | reduction='mean')
77 | self._params[loss] = self.lmd_recons
78 | elif loss.split('_')[0] == 'gen':
79 | self._losses_func[loss] = torch.nn.SmoothL1Loss(
80 | reduction='mean')
81 | self._params[loss] = self.lmd_gen
82 | elif loss != "loss":
83 | raise ValueError("This loss is not recognized.")
84 |
85 | def update(self, rs_set):
86 | total_loss: float = 0.0
87 | # Compute the losses
88 | # Compute instance loss
89 |
90 | # predict noise
91 | if self.predict_epsilon:
92 | total_loss += self._update_loss("inst_loss", rs_set['noise_pred'],
93 | rs_set['noise'])
94 | # predict x
95 | else:
96 | total_loss += self._update_loss("x_loss", rs_set['pred'],
97 | rs_set['diff_in'])
98 |
99 | if self.lmd_prior != 0.0:
100 | # loss - prior loss
101 | total_loss += self._update_loss("prior_loss", rs_set['noise_prior'],
102 | rs_set['dist_m1'])
103 |
104 | self.loss += total_loss
105 | self.count += 1
106 |
107 | return total_loss
108 |
109 | def compute(self):
110 | count = getattr(self, "count")
111 | return {loss: getattr(self, loss) / count for loss in self.losses}
112 |
113 | def _update_loss(self, loss: str, outputs, inputs):
114 | # Update the loss
115 | val = self._losses_func[loss](outputs, inputs)
116 | getattr(self, loss).__iadd__(val)
117 | # Return a weighted sum
118 | weighted_loss = self._params[loss] * val
119 | return weighted_loss
120 |
121 | def loss2logname(self, loss: str, split: str):
122 | if loss == "loss":
123 | log_name = f"total_{loss}/{split}"
124 | else:
125 | loss_type, name = loss.split("_")
126 | log_name = f"{loss_type}/{name}/{split}"
127 | return log_name
128 |
129 | class KLLoss:
130 |
131 | def __init__(self):
132 | pass
133 |
134 | def __call__(self, q, p):
135 | div = torch.distributions.kl_divergence(q, p)
136 | return div.mean()
137 |
138 | def __repr__(self):
139 | return "KLLoss()"
140 |
141 |
142 | class KLLossMulti:
143 |
144 | def __init__(self):
145 | self.klloss = KLLoss()
146 |
147 | def __call__(self, qlist, plist):
148 | return sum([self.klloss(q, p) for q, p in zip(qlist, plist)])
149 |
150 | def __repr__(self):
151 | return "KLLossMulti()"
152 |
--------------------------------------------------------------------------------
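A toy sanity check for MLDLosses above, exercising the epsilon-prediction branch of update(). The import path follows src/model/losses/__init__.py, which re-exports MLDLosses; the tensor shapes are arbitrary placeholders.

    import torch
    from src.model.losses import MLDLosses  # re-exported in the package __init__ above

    # epsilon prediction: the instance loss is an MSE between predicted and true noise
    losses = MLDLosses(predict_epsilon=True, lmd_prior=0.0,
                       lmd_kl=0.0, lmd_gen=1.0, lmd_recons=1.0)
    rs_set = {"noise_pred": torch.randn(4, 16), "noise": torch.randn(4, 16)}

    losses.update(rs_set)      # accumulates inst_loss and the running total
    print(losses.compute())    # running means, keyed by loss name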
/src/model/losses/kl.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class KLLoss:
5 | def __init__(self):
6 | pass
7 |
8 | def __call__(self, q, p, reduce_fx='mean'):
9 | div = torch.distributions.kl_divergence(q, p)
10 | if reduce_fx == 'mean':
11 | return div.mean()
12 | else:
13 | return div
14 |
15 | def __repr__(self):
16 | return "KLLoss()"
17 |
18 |
19 | class KLLossMulti:
20 | def __init__(self):
21 | self.klloss = KLLoss()
22 |
23 | def __call__(self, qlist, plist):
24 | return sum([self.klloss(q, p)
25 | for q, p in zip(qlist, plist)])
26 |
27 | def __repr__(self):
28 | return "KLLossMulti()"
29 |
--------------------------------------------------------------------------------
/src/model/losses/recons.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn.functional import smooth_l1_loss
3 |
4 |
5 | class Recons:
6 | def __call__(self, input_motion_feats_lst, output_features_lst):
7 | # for x,y in zip(input_motion_feats_lst, output_features_lst):
8 | # print('----------------------')
9 | # print(x.shape, y.shape)
10 | # print(smooth_l1_loss(x, y, reduction="mean").shape)
11 | # print('----------------------')
12 | recons = torch.stack([smooth_l1_loss(x.squeeze(), y.squeeze(), reduction="mean") for x,y in zip(input_motion_feats_lst,
13 | output_features_lst)]).mean()
14 | return recons
15 |
16 | def __repr__(self):
17 | return "Recons()"
18 |
--------------------------------------------------------------------------------
/src/model/losses/recons_bp.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn.functional import smooth_l1_loss
3 |
4 |
5 | class ReconsBP:
6 | def __call__(self, input_motion_feats_lst, output_features_lst):
7 | recons = [smooth_l1_loss(x.squeeze(),
8 | y.squeeze(),
9 | reduction='none') for x,y in zip(input_motion_feats_lst,
10 | output_features_lst)]
11 |
12 | recons = torch.stack([bpl.mean((1,2)) for bpl in recons], dim=1)
13 | return recons
14 |
15 | def __repr__(self):
16 | return "ReconsBP()"
17 |
18 |
19 |
--------------------------------------------------------------------------------
/src/model/losses/utils.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | from torch.nn import Module
4 |
5 | class LossTracker(Module):
6 | def __init__(self, losses):
7 | super().__init__()
8 | self.losses = losses
9 | self.training = False
10 | self.count = 0
11 | for loss in self.losses:
12 | self.register_buffer(loss, torch.tensor(0.0, device='cpu'), persistent=False)
13 |
14 | def reset(self):
15 | self.count = 0
16 | for loss in self.losses:
17 | getattr(self, loss).__imul__(0)
18 |
19 | def update(self, losses_dict):
20 | self.count += 1
21 | for loss_name, loss_val in losses_dict.items():
22 | getattr(self, loss_name).__iadd__(loss_val)
23 |
24 | def compute(self):
25 | if self.count == 0:
26 | raise ValueError("compute should be called after update")
27 | # compute the mean
28 | return {loss: getattr(self, loss)/self.count for loss in self.losses}
29 |
30 | def loss2logname(self, loss: str, split: str):
31 | if loss == "total":
32 | log_name = f"{loss}/{split}"
33 | else:
34 |
35 | if '_multi' in loss:
36 | if 'bodypart' in loss:
37 | loss_type, name, multi, _ = loss.split("_")
38 | name = f'{name}_multiple_bp'
39 | else:
40 | loss_type, name, multi = loss.split("_")
41 | name = f'{name}_multiple'
42 | else:
43 | loss_type, name = loss.split("_")
44 | log_name = f"{loss_type}/{name}/{split}"
45 | return log_name
46 |
--------------------------------------------------------------------------------
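A short usage sketch for LossTracker above; the loss names are illustrative (the real ones come from the training losses config).

    import torch
    from src.model.losses.utils import LossTracker  # path as in the file above

    tracker = LossTracker(["recons_feats", "total"])
    for _ in range(3):
        tracker.update({"recons_feats": torch.tensor(0.5), "total": torch.tensor(1.0)})

    print(tracker.compute())                              # per-loss means over the 3 updates
    print(tracker.loss2logname("recons_feats", "train"))  # -> "recons/feats/train"
    print(tracker.loss2logname("total", "train"))         # -> "total/train"
    tracker.reset()                                       # zero the buffers before the next epoch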
/src/model/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | from .compute import ComputeMetrics
--------------------------------------------------------------------------------
/src/model/motiondecoder/__init__.py:
--------------------------------------------------------------------------------
1 | from .actor import ActorAgnosticDecoder
2 |
--------------------------------------------------------------------------------
/src/model/motiondecoder/actor.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | import pytorch_lightning as pl
5 |
6 | from typing import List, Optional
7 | from torch import nn, Tensor
8 |
9 | from src.model.utils import PositionalEncoding
10 | from src.data.tools import lengths_to_mask
11 | from einops import rearrange
12 |
13 |
14 | class ActorAgnosticDecoder(pl.LightningModule):
15 | def __init__(self, nfeats: int,
16 | latent_dim: int = 256, ff_size: int = 1024,
17 | num_layers: int = 4, num_heads: int = 4,
18 | dropout: float = 0.1,
19 | activation: str = "gelu", **kwargs) -> None:
20 |
21 | super().__init__()
22 | self.save_hyperparameters(logger=False)
23 |
24 | output_feats = nfeats
25 |
26 | self.sequence_pos_encoding = PositionalEncoding(latent_dim, dropout) # multi GPU
27 |
28 | seq_trans_decoder_layer = nn.TransformerDecoderLayer(d_model=latent_dim,
29 | nhead=num_heads,
30 | dim_feedforward=ff_size,
31 | dropout=dropout,
32 | activation=activation) # for multi GPU
33 |
34 | self.seqTransDecoder = nn.TransformerDecoder(seq_trans_decoder_layer,
35 | num_layers=num_layers)
36 |
37 | self.final_layer = nn.Linear(latent_dim, output_feats)
38 |
39 | def forward(self, z: Tensor, mask: Tensor, mem_masks=None):
40 | latent_dim = z.shape[-1]
41 | bs, nframes = mask.shape
42 | nfeats = self.hparams.nfeats
43 |
44 | # z = z[:, None] # sequence of 1 element for the memory
45 | # separate latents
46 | # torch.cat((z0[:, None], z1[:, None]), 1)
47 | if len(z.shape) > 3:
48 | z = rearrange(z, "bs nz z_len latent_dim -> (nz z_len) bs latent_dim")
49 | else:
50 | z = rearrange(z, "bs z_len latent_dim -> z_len bs latent_dim")
51 |
52 | # Construct time queries
53 | time_queries = torch.zeros(nframes, bs, latent_dim, device=z.device)
54 | time_queries = self.sequence_pos_encoding(time_queries)
55 |
56 | # Pass through the transformer decoder
57 | # with the latent vector for memory
58 | if mem_masks is not None:
59 | mem_masks = ~mem_masks
60 | output = self.seqTransDecoder(tgt=time_queries, memory=z,
61 | tgt_key_padding_mask=~mask,
62 | memory_key_padding_mask=mem_masks)
63 | output = self.final_layer(output)
64 | # zero for padded area
65 | output[~mask.T] = 0
66 | # Pytorch Transformer: [Sequence, Batch size, ...]
67 | feats = output.permute(1, 0, 2)
68 | return feats
69 |
--------------------------------------------------------------------------------
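The encoders and decoders in this package keep a boolean mask with True for valid frames, while nn.Transformer key-padding masks expect True for positions to ignore; hence the ~mask inversions above. A minimal sketch of that convention, with a hypothetical stand-in for src.data.tools.lengths_to_mask so the example is self-contained:

    import torch

    def lengths_to_mask(lengths):
        # stand-in for src.data.tools.lengths_to_mask: True marks a valid (non-padded) frame
        lengths = torch.as_tensor(lengths)
        return torch.arange(int(lengths.max()))[None, :] < lengths[:, None]

    mask = lengths_to_mask([3, 5])   # (bs=2, nframes=5), True = keep
    print(mask.int())                # [[1,1,1,0,0],[1,1,1,1,1]]
    print((~mask).int())             # passed as tgt/src_key_padding_mask: True = ignore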
/src/model/motionencoder/__init__.py:
--------------------------------------------------------------------------------
1 | from .actor import ActorAgnosticEncoder
--------------------------------------------------------------------------------
/src/model/motionencoder/actor.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | import pytorch_lightning as pl
5 |
6 | from typing import List, Optional, Union
7 | from torch import nn, Tensor
8 | from torch.distributions.distribution import Distribution
9 |
10 | from src.model.utils import PositionalEncoding
11 | from src.data.tools import lengths_to_mask
12 |
13 |
14 | class ActorAgnosticEncoder(nn.Module):
15 | def __init__(self, nfeats: int,
16 | latent_dim: int = 256, ff_size: int = 1024,
17 | num_layers: int = 4, num_heads: int = 4,
18 | dropout: float = 0.1,
19 | activation: str = "gelu", **kwargs) -> None:
20 | super().__init__()
21 |
22 | input_feats = nfeats
23 | self.skel_embedding = nn.Linear(input_feats, latent_dim)
24 | # self.layer_norm = nn.LayerNorm(nfeats)
25 | # Action agnostic: only one set of params
26 | self.emb_token = nn.Parameter(torch.randn(latent_dim))
27 |
28 | self.sequence_pos_encoding = PositionalEncoding(latent_dim, dropout) # multi-GPU
29 |
30 | seq_trans_encoder_layer = nn.TransformerEncoderLayer(d_model=latent_dim,
31 | nhead=num_heads,
32 | dim_feedforward=ff_size,
33 | dropout=dropout,
34 | activation=activation) # multi-gpu
35 |
36 | self.seqTransEncoder = nn.TransformerEncoder(seq_trans_encoder_layer,
37 | num_layers=num_layers)
38 |
39 | def forward(self, features: Tensor, mask: Tensor) -> Union[Tensor, Distribution]:
40 | in_mask = mask
41 | device = features.device
42 |
43 | nframes, bs, nfeats = features.shape
44 |
45 | x = features
46 | # Embed each human poses into latent vectors
47 | # x = self.layer_norm(x)
48 | x = self.skel_embedding(x)
49 | # Switch sequence and batch_size because the input of
50 | # Pytorch Transformer is [Sequence, Batch size, ...]
51 | # x = x.permute(1, 0, 2) # now it is [nframes, bs, latent_dim]
52 | # Each batch has its own set of tokens
53 | emb_token = torch.tile(self.emb_token, (bs,)).reshape(bs, -1)
54 |
55 | # adding the embedding token for all sequences
56 | xseq = torch.cat((emb_token[None], x), 0)
57 |
58 | # create a bigger mask, to allow attend to emb
59 | token_mask = torch.ones((bs, 1), dtype=bool, device=x.device)
60 | aug_mask = torch.cat((token_mask, in_mask), 1)
61 | # add positional encoding
62 | xseq = self.sequence_pos_encoding(xseq)
63 | final = self.seqTransEncoder(xseq, src_key_padding_mask=~aug_mask)
64 |
65 | return final[0]
66 |
--------------------------------------------------------------------------------
/src/model/readme.txt:
--------------------------------------------------------------------------------
1 | denoiser models
2 | phase encoding
3 | length input
--------------------------------------------------------------------------------
/src/model/textencoder/__init__.py:
--------------------------------------------------------------------------------
1 | from .distilbert_encoder import DistilbertEncoderTransformer
2 | from .clip_encoder import ClipTextEncoder
3 | from .t5_encoder import T5TextEncoder
--------------------------------------------------------------------------------
/src/model/textencoder/clip_encoder.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import List, Union
3 |
4 | import torch
5 | from torch import Tensor, nn
6 | from torch.distributions.distribution import Distribution
7 | from src.utils.file_io import hack_path
8 | import pytorch_lightning as pl
9 |
10 | class ClipTextEncoder(pl.LightningModule):
11 |
12 | def __init__(
13 | self,
14 | modelpath: str,
15 | finetune: bool = False,
16 | last_hidden_state: bool = True,
17 | **kwargs
18 | ) -> None:
19 |
20 | super().__init__()
21 | self.save_hyperparameters(logger=False)
22 | from transformers import logging
23 | from transformers import AutoModel, AutoTokenizer
24 | logging.set_verbosity_error()
25 | # Tokenizer
26 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
27 |
28 | self.tokenizer = AutoTokenizer.from_pretrained(hack_path(modelpath))
29 | self.text_model = AutoModel.from_pretrained(hack_path(modelpath))
30 |
31 | # Don't train the model
32 | if not finetune:
33 | self.text_model.training = False
34 | for p in self.text_model.parameters():
35 | p.requires_grad = False
36 |
37 | # Then configure the model
38 | self.max_length = self.tokenizer.model_max_length
39 | if "clip" in modelpath:
40 | self.text_encoded_dim = self.text_model.config.text_config.hidden_size
41 | if last_hidden_state:
42 | self.variant = "clip_hidden"
43 | else:
44 | self.variant = "clip"
45 | elif "bert" in modelpath:
46 | self.variant = "bert"
47 | self.text_encoded_dim = self.text_model.config.hidden_size
48 | else:
49 | raise ValueError(f"Model {modelpath} not supported")
50 |
51 | def forward(self, texts: List[str]):
52 | # get prompt text embeddings
53 | if self.variant in ["clip", "clip_hidden"]:
54 | text_inputs = self.tokenizer(
55 | texts,
56 | padding="max_length",
57 | truncation=True,
58 | max_length=self.max_length,
59 | return_tensors="pt",
60 | )
61 | text_input_ids = text_inputs.input_ids.to(self.text_model.device)
62 | txt_att_mask = text_inputs.attention_mask.to(self.text_model.device)
63 | # split into max length Clip can handle
64 | if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
65 | text_input_ids = text_input_ids[:, :self.tokenizer.
66 | model_max_length]
67 | elif self.variant == "bert":
68 | text_inputs = self.tokenizer(texts, return_tensors="pt",
69 | padding=True)
70 | txt_att_mask = text_inputs.attention_mask.to(self.text_model.device)  # needed for the mask returned below
71 |
72 | # use pooled output if latent dim is two-dimensional
73 | # pooled = 0 if self.latent_dim[0] == 1 else 1 # (bs, seq_len, text_encoded_dim) -> (bs, text_encoded_dim)
74 | # text encoder forward, clip must use get_text_features
75 | # TODO check the CLIP network
76 | if self.variant == "clip":
77 | # (batch_Size, text_encoded_dim)
78 | text_embeddings = self.text_model.get_text_features(
79 | text_input_ids.to(self.text_model.device))
80 | # (batch_Size, 1, text_encoded_dim)
81 | text_embeddings = text_embeddings.unsqueeze(1)
82 | elif self.variant == "clip_hidden":
83 | # (batch_Size, seq_length , text_encoded_dim)
84 | text_embeddings = self.text_model.text_model(text_input_ids,
85 | # attention_mask=txt_att_mask
86 | ).last_hidden_state
87 | elif self.variant == "bert":
88 | # (batch_Size, seq_length , text_encoded_dim)
89 | text_embeddings = self.text_model(
90 | **text_inputs.to(self.text_model.device)).last_hidden_state
91 | else:
92 | raise NotImplementedError(f"Model {self.variant} not implemented")
93 |
94 | return text_embeddings, txt_att_mask.bool()
95 |
--------------------------------------------------------------------------------
/src/model/textencoder/distilbert.py:
--------------------------------------------------------------------------------
1 | from typing import List, Union
2 | import pytorch_lightning as pl
3 |
4 | import torch.nn as nn
5 | import os
6 |
7 | import torch
8 | from torch import Tensor
9 | from torch.distributions.distribution import Distribution
10 |
11 |
12 | class DistilbertEncoderBase(pl.LightningModule):
13 | def __init__(self, modelpath: str,
14 | finetune: bool = False) -> None:
15 | super().__init__()
16 |
17 | from transformers import AutoTokenizer, AutoModel
18 | from transformers import logging
19 | logging.set_verbosity_error()
20 | # Tokenizer
21 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
22 | self.tokenizer = AutoTokenizer.from_pretrained(modelpath)
23 |
24 | # Text model
25 | self.text_model = AutoModel.from_pretrained(modelpath)
26 | # Don't train the model
27 | if not finetune:
28 | self.text_model.training = False
29 | for p in self.text_model.parameters():
30 | p.requires_grad = False
31 |
32 | # Then configure the model
33 | self.text_encoded_dim = self.text_model.config.hidden_size
34 |
35 | def train(self, mode: bool = True):
36 | self.training = mode
37 | for module in self.children():
38 | # Don't put the text model in training mode when it is frozen
39 | if module == self.text_model and not self.hparams.finetune:
40 | continue
41 | module.train(mode)
42 | return self
43 |
44 | def get_last_hidden_state(self, texts: List[str],
45 | return_mask: bool = False
46 | ) -> Union[Tensor, tuple[Tensor, Tensor]]:
47 | encoded_inputs = self.tokenizer(texts, return_tensors="pt", padding=True)
48 | output = self.text_model(**encoded_inputs.to(self.text_model.device))
49 | if not return_mask:
50 | return output.last_hidden_state
51 | return output.last_hidden_state, encoded_inputs.attention_mask.to(dtype=bool)
--------------------------------------------------------------------------------
/src/model/textencoder/distilbert_encoder.py:
--------------------------------------------------------------------------------
1 | from .distilbert import DistilbertEncoderBase
2 | import torch
3 |
4 | from typing import List, Union
5 | from torch import nn, Tensor
6 | from torch.distributions.distribution import Distribution
7 |
8 | from src.model.utils import PositionalEncoding
9 | from src.data.tools import lengths_to_mask
10 | from src.utils.file_io import hack_path
11 |
12 | class DistilbertEncoderTransformer(DistilbertEncoderBase):
13 | def __init__(self, modelpath: str,
14 | finetune: bool = False,
15 | vae: bool = False,
16 | latent_dim: int = 256,
17 | ff_size: int = 1024,
18 | num_layers: int = 4, num_heads: int = 4,
19 | dropout: float = 0.1,
20 | activation: str = "gelu", **kwargs) -> None:
21 | super().__init__(modelpath=hack_path(modelpath), finetune=finetune)
22 | self.save_hyperparameters(logger=False)
23 | encoded_dim = self.text_encoded_dim
24 |
25 | # Projection of the text-outputs into the latent space
26 | self.projection = nn.Sequential(nn.ReLU(),
27 | nn.Linear(encoded_dim, latent_dim))
28 |
29 | # TransformerVAE adapted from ACTOR
30 | # Action agnostic: only one set of params
31 | if vae:
32 | self.mu_token = nn.Parameter(torch.randn(latent_dim))
33 | self.logvar_token = nn.Parameter(torch.randn(latent_dim))
34 | else:
35 | self.emb_token = nn.Parameter(torch.randn(latent_dim))
36 |
37 | self.sequence_pos_encoding = PositionalEncoding(latent_dim, dropout)
38 |
39 | seq_trans_encoder_layer = nn.TransformerEncoderLayer(d_model=latent_dim,
40 | nhead=num_heads,
41 | dim_feedforward=ff_size,
42 | dropout=dropout,
43 | activation=activation)
44 |
45 | self.seqTransEncoder = nn.TransformerEncoder(seq_trans_encoder_layer,
46 | num_layers=num_layers)
47 |
48 | def forward(self, texts: List[str]) -> Union[Tensor, Distribution]:
49 | text_encoded, mask = self.get_last_hidden_state(texts, return_mask=True)
50 |
51 | x = self.projection(text_encoded)
52 | bs, nframes, _ = x.shape
53 | # bs, nframes, totjoints, nfeats = x.shape
54 | # Switch sequence and batch_size because the input of
55 | # Pytorch Transformer is [Sequence, Batch size, ...]
56 | x = x.permute(1, 0, 2) # now it is [nframes, bs, latent_dim]
57 |
58 | if self.hparams.vae:
59 | mu_token = torch.tile(self.mu_token, (bs,)).reshape(bs, -1)
60 | logvar_token = torch.tile(self.logvar_token, (bs,)).reshape(bs, -1)
61 |
62 | # adding the distribution tokens for all sequences
63 | xseq = torch.cat((mu_token[None], logvar_token[None], x), 0)
64 |
65 | # create a bigger mask, to allow attend to mu and logvar
66 | token_mask = torch.ones((bs, 2), dtype=bool, device=x.device)
67 | aug_mask = torch.cat((token_mask, mask), 1)
68 | else:
69 | emb_token = torch.tile(self.emb_token, (bs,)).reshape(bs, -1)
70 |
71 | # adding the embedding token for all sequences
72 | xseq = torch.cat((emb_token[None], x), 0)
73 |
74 | # create a bigger mask, to allow attend to emb
75 | token_mask = torch.ones((bs, 1), dtype=bool, device=x.device)
76 | aug_mask = torch.cat((token_mask, mask), 1)
77 |
78 | # add positional encoding
79 | xseq = self.sequence_pos_encoding(xseq)
80 | final = self.seqTransEncoder(xseq, src_key_padding_mask=~aug_mask)
81 |
82 | if self.hparams.vae:
83 | mu, logvar = final[0], final[1]
84 | std = logvar.exp().pow(0.5)
85 | # https://github.com/kampta/pytorch-distributions/blob/master/gaussian_vae.py
86 | try:
87 | dist = torch.distributions.normal.Normal(mu, std)
88 | except ValueError:
89 | import ipdb; ipdb.set_trace() # noqa
90 | pass
91 | return dist
92 | else:
93 | return final[0]
94 |
--------------------------------------------------------------------------------
/src/model/textencoder/t5_encoder.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import List, Union
3 |
4 | import torch
5 | from torch import Tensor, nn
6 | from torch.distributions.distribution import Distribution
7 | from src.utils.file_io import hack_path
8 | import pytorch_lightning as pl
9 |
10 | class T5TextEncoder(pl.LightningModule):
11 |
12 | def __init__(
13 | self,
14 | modelpath: str,
15 | finetune: bool = False,
16 | **kwargs
17 | ) -> None:
18 |
19 | super().__init__()
20 | self.save_hyperparameters(logger=False)
21 | from transformers import logging
22 | from transformers import AutoModel
23 | from transformers import AutoTokenizer, T5EncoderModel
24 |
25 | logging.set_verbosity_error()
26 | # Tokenizer
27 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
28 |
29 | self.tokenizer = AutoTokenizer.from_pretrained(hack_path(modelpath),
30 | legacy=True)
31 | self.language_model = T5EncoderModel.from_pretrained(
32 | hack_path(modelpath))
33 | self.language_model.resize_token_embeddings(len(self.tokenizer))
34 | self.max_length = 128 # self.tokenizer.model_max_length
35 | # Don't train the model
36 | if not finetune:
37 | self.language_model.training = False
38 | for p in self.language_model.parameters():
39 | p.requires_grad = False
40 |
41 | def forward(self, texts: List[str]):
42 |
43 | # # Tokenize
44 | text_inputs = self.tokenizer(texts,
45 | padding='max_length',
46 | max_length=self.max_length,
47 | truncation=True,
48 | return_attention_mask=True,
49 | add_special_tokens=True,
50 | return_tensors="pt")
51 |
52 |
53 | # input_ids = self.tokenizer(texts, return_tensors="pt").input_ids # Batch size 1
54 | text_input_ids = text_inputs.input_ids.to(self.language_model.device)
55 | txt_att_mask = text_inputs.attention_mask.to(self.language_model.device)
56 |
57 | outputs = self.language_model(input_ids=text_input_ids, attention_mask=txt_att_mask)
58 | last_hidden_states = outputs.last_hidden_state
59 |
60 | return last_hidden_states, txt_att_mask.bool()
--------------------------------------------------------------------------------
/src/model/tmr_utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .actor import PositionalEncoding, ACTORStyleEncoder, ACTORStyleDecoder # noqa
2 | from .temos import TEMOS # noqa
3 | from .tmr import TMR # noqa
4 | from .text_encoder import TextToEmb # noqa
5 |
--------------------------------------------------------------------------------
/src/model/tmr_utils/actor.py:
--------------------------------------------------------------------------------
1 | from typing import Dict
2 |
3 | import torch
4 | import torch.nn as nn
5 | from torch import Tensor
6 | import numpy as np
7 |
8 | from einops import repeat
9 |
10 |
11 | class PositionalEncoding(nn.Module):
12 | def __init__(self, d_model, dropout=0.1, max_len=5000, batch_first=False) -> None:
13 | super().__init__()
14 | self.batch_first = batch_first
15 |
16 | self.dropout = nn.Dropout(p=dropout)
17 |
18 | pe = torch.zeros(max_len, d_model)
19 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
20 | div_term = torch.exp(
21 | torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model)
22 | )
23 | pe[:, 0::2] = torch.sin(position * div_term)
24 | pe[:, 1::2] = torch.cos(position * div_term)
25 | pe = pe.unsqueeze(0).transpose(0, 1)
26 | self.register_buffer("pe", pe, persistent=False)
27 |
28 | def forward(self, x: Tensor) -> Tensor:
29 | if self.batch_first:
30 | x = x + self.pe.permute(1, 0, 2)[:, : x.shape[1], :]
31 | else:
32 | x = x + self.pe[: x.shape[0], :]
33 | return self.dropout(x)
34 |
35 |
36 | class ACTORStyleEncoder(nn.Module):
37 | # Similar to ACTOR but "action agnostic" and more general
38 | def __init__(
39 | self,
40 | nfeats: int,
41 | vae: bool,
42 | latent_dim: int = 256,
43 | ff_size: int = 1024,
44 | num_layers: int = 4,
45 | num_heads: int = 4,
46 | dropout: float = 0.1,
47 | activation: str = "gelu",
48 | ) -> None:
49 | super().__init__()
50 |
51 | self.nfeats = nfeats
52 | self.projection = nn.Linear(nfeats, latent_dim)
53 |
54 | self.vae = vae
55 | self.nbtokens = 2 if vae else 1
56 | self.tokens = nn.Parameter(torch.randn(self.nbtokens, latent_dim))
57 |
58 | self.sequence_pos_encoding = PositionalEncoding(
59 | latent_dim, dropout=dropout, batch_first=True
60 | )
61 |
62 | seq_trans_encoder_layer = nn.TransformerEncoderLayer(
63 | d_model=latent_dim,
64 | nhead=num_heads,
65 | dim_feedforward=ff_size,
66 | dropout=dropout,
67 | activation=activation,
68 | batch_first=True,
69 | )
70 |
71 | self.seqTransEncoder = nn.TransformerEncoder(
72 | seq_trans_encoder_layer, num_layers=num_layers
73 | )
74 |
75 | def forward(self, x_dict: Dict) -> Tensor:
76 | x = x_dict["x"]
77 | mask = x_dict["mask"]
78 |
79 | x = self.projection(x)
80 |
81 | device = x.device
82 | bs = len(x)
83 |
84 | tokens = repeat(self.tokens, "nbtoken dim -> bs nbtoken dim", bs=bs)
85 | xseq = torch.cat((tokens, x), 1)
86 |
87 | token_mask = torch.ones((bs, self.nbtokens), dtype=bool, device=device)
88 | aug_mask = torch.cat((token_mask, mask), 1)
89 |
90 | # add positional encoding
91 | xseq = self.sequence_pos_encoding(xseq)
92 | final = self.seqTransEncoder(xseq, src_key_padding_mask=~aug_mask)
93 | return final[:, : self.nbtokens]
94 |
95 |
96 | class ACTORStyleDecoder(nn.Module):
97 | # Similar to ACTOR Decoder
98 |
99 | def __init__(
100 | self,
101 | nfeats: int,
102 | latent_dim: int = 256,
103 | ff_size: int = 1024,
104 | num_layers: int = 4,
105 | num_heads: int = 4,
106 | dropout: float = 0.1,
107 | activation: str = "gelu",
108 | ) -> None:
109 | super().__init__()
110 | output_feats = nfeats
111 | self.nfeats = nfeats
112 |
113 | self.sequence_pos_encoding = PositionalEncoding(
114 | latent_dim, dropout, batch_first=True
115 | )
116 |
117 | seq_trans_decoder_layer = nn.TransformerDecoderLayer(
118 | d_model=latent_dim,
119 | nhead=num_heads,
120 | dim_feedforward=ff_size,
121 | dropout=dropout,
122 | activation=activation,
123 | batch_first=True,
124 | )
125 |
126 | self.seqTransDecoder = nn.TransformerDecoder(
127 | seq_trans_decoder_layer, num_layers=num_layers
128 | )
129 |
130 | self.final_layer = nn.Linear(latent_dim, output_feats)
131 |
132 | def forward(self, z_dict: Dict) -> Tensor:
133 | z = z_dict["z"]
134 | mask = z_dict["mask"]
135 |
136 | latent_dim = z.shape[1]
137 | bs, nframes = mask.shape
138 |
139 | z = z[:, None] # sequence of 1 element for the memory
140 |
141 | # Construct time queries
142 | time_queries = torch.zeros(bs, nframes, latent_dim, device=z.device)
143 | time_queries = self.sequence_pos_encoding(time_queries)
144 |
145 | # Pass through the transformer decoder
146 | # with the latent vector for memory
147 | output = self.seqTransDecoder(
148 | tgt=time_queries, memory=z, tgt_key_padding_mask=~mask
149 | )
150 |
151 | output = self.final_layer(output)
152 | # zero for padded area
153 | output[~mask] = 0
154 | return output
155 |
--------------------------------------------------------------------------------
/src/model/tmr_utils/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 |
5 | # For reference
6 | # https://stats.stackexchange.com/questions/7440/kl-divergence-between-two-univariate-gaussians
7 | # https://pytorch.org/docs/stable/_modules/torch/distributions/kl.html#kl_divergence
8 | class KLLoss:
9 | def __call__(self, q, p):
10 | mu_q, logvar_q = q
11 | mu_p, logvar_p = p
12 |
13 | log_var_ratio = logvar_q - logvar_p
14 | t1 = (mu_p - mu_q).pow(2) / logvar_p.exp()
15 | div = 0.5 * (log_var_ratio.exp() + t1 - 1 - log_var_ratio)
16 | return div.mean()
17 |
18 | def __repr__(self):
19 | return "KLLoss()"
20 |
21 |
22 | class InfoNCE_with_filtering:
23 | def __init__(self, temperature=0.7, threshold_selfsim=0.8):
24 | self.temperature = temperature
25 | self.threshold_selfsim = threshold_selfsim
26 |
27 | def get_sim_matrix(self, x, y):
28 | x_logits = torch.nn.functional.normalize(x, dim=-1)
29 | y_logits = torch.nn.functional.normalize(y, dim=-1)
30 | sim_matrix = x_logits @ y_logits.T
31 | return sim_matrix
32 |
33 | def __call__(self, x, y, sent_emb=None):
34 | bs, device = len(x), x.device
35 | sim_matrix = self.get_sim_matrix(x, y) / self.temperature
36 |
37 | if sent_emb is not None and self.threshold_selfsim:
38 | # put the threshold value between -1 and 1
39 | real_threshold_selfsim = 2 * self.threshold_selfsim - 1
40 | # Filtering too close values
41 | # mask them by putting -inf in the sim_matrix
42 | selfsim = sent_emb @ sent_emb.T
43 | selfsim_nodiag = selfsim - selfsim.diag().diag()
44 | idx = torch.where(selfsim_nodiag > real_threshold_selfsim)
45 | sim_matrix[idx] = -torch.inf
46 |
47 | labels = torch.arange(bs, device=device)
48 |
49 | total_loss = (
50 | F.cross_entropy(sim_matrix, labels) + F.cross_entropy(sim_matrix.T, labels)
51 | ) / 2
52 |
53 | return total_loss
54 |
55 | def __repr__(self):
56 | return f"Constrastive(temp={self.temp})"
57 |
--------------------------------------------------------------------------------
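The KLLoss here works on (mu, logvar) pairs rather than torch Distribution objects (unlike src/model/losses/kl.py). A quick numerical check of the closed form against torch.distributions.kl_divergence, with arbitrary toy shapes:

    import torch
    from torch.distributions import Normal, kl_divergence

    mu_q, logvar_q = torch.randn(8, 4), torch.randn(8, 4)
    mu_p, logvar_p = torch.randn(8, 4), torch.randn(8, 4)

    # closed form, as in KLLoss.__call__ above
    log_var_ratio = logvar_q - logvar_p
    t1 = (mu_p - mu_q).pow(2) / logvar_p.exp()
    manual = 0.5 * (log_var_ratio.exp() + t1 - 1 - log_var_ratio).mean()

    # reference: KL(q || p) between the same diagonal Gaussians
    q = Normal(mu_q, (0.5 * logvar_q).exp())
    p = Normal(mu_p, (0.5 * logvar_p).exp())
    assert torch.allclose(manual, kl_divergence(q, p).mean(), atol=1e-5)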
/src/model/tmr_utils/metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def print_latex_metrics(metrics):
5 | vals = [str(x).zfill(2) for x in [1, 2, 3, 5, 10]]
6 | t2m_keys = [f"t2m/R{i}" for i in vals] + ["t2m/MedR"]
7 | m2t_keys = [f"m2t/R{i}" for i in vals] + ["m2t/MedR"]
8 |
9 | keys = t2m_keys + m2t_keys
10 |
11 | def ff(val_):
12 | val = str(val_).ljust(5, "0")
13 | # make decimal fine when only one digit
14 | if val[1] == ".":
15 | val = str(val_).ljust(4, "0")
16 | return val
17 |
18 | str_ = "& " + " & ".join([ff(metrics[key]) for key in keys]) + r" \\"
19 | dico = {key: ff(metrics[key]) for key in keys}
20 | print(dico)
21 | print("Number of samples: {}".format(int(metrics["t2m/len"])))
22 | print(str_)
23 |
24 |
25 | def all_contrastive_metrics(
26 | sims, emb=None, threshold=None, rounding=2, return_cols=False
27 | ):
28 | text_selfsim = None
29 | if emb is not None:
30 | text_selfsim = emb @ emb.T
31 |
32 | t2m_m, t2m_cols = contrastive_metrics(
33 | sims, text_selfsim, threshold, return_cols=True, rounding=rounding
34 | )
35 | m2t_m, m2t_cols = contrastive_metrics(
36 | sims.T, text_selfsim, threshold, return_cols=True, rounding=rounding
37 | )
38 |
39 | all_m = {}
40 | for key in t2m_m:
41 | all_m[f"t2m/{key}"] = t2m_m[key]
42 | all_m[f"m2t/{key}"] = m2t_m[key]
43 |
44 | all_m["t2m/len"] = float(len(sims))
45 | all_m["m2t/len"] = float(len(sims[0]))
46 | if return_cols:
47 | return all_m, t2m_cols, m2t_cols
48 | return all_m
49 |
50 |
51 | def contrastive_metrics(
52 | sims,
53 | text_selfsim=None,
54 | threshold=None,
55 | return_cols=False,
56 | rounding=2,
57 | break_ties="averaging",
58 | ):
59 | n, m = sims.shape
60 | assert n == m
61 | num_queries = n
62 |
63 | dists = -sims
64 | sorted_dists = np.sort(dists, axis=1)
65 | # GT is in the diagonal
66 | gt_dists = np.diag(dists)[:, None]
67 |
68 | if text_selfsim is not None and threshold is not None:
69 | real_threshold = 2 * threshold - 1
70 | idx = np.argwhere(text_selfsim > real_threshold)
71 | partition = np.unique(idx[:, 0], return_index=True)[1]
72 | # take as GT the minimum score of similar values
73 | gt_dists = np.minimum.reduceat(dists[tuple(idx.T)], partition)
74 | gt_dists = gt_dists[:, None]
75 |
76 | rows, cols = np.where((sorted_dists - gt_dists) == 0) # find column position of GT
77 |
78 | # if there are ties
79 | if rows.size > num_queries:
80 | assert np.unique(rows).size == num_queries, "issue in metric evaluation"
81 | if break_ties == "optimistically":
82 | opti_cols = break_ties_optimistically(sorted_dists, gt_dists)
83 | cols = opti_cols
84 | elif break_ties == "averaging":
85 | avg_cols = break_ties_average(sorted_dists, gt_dists)
86 | cols = avg_cols
87 |
88 | msg = "expected ranks to match queries ({} vs {}) "
89 | assert cols.size == num_queries, msg
90 |
91 | if return_cols:
92 | return cols2metrics(cols, num_queries, rounding=rounding), cols
93 | return cols2metrics(cols, num_queries, rounding=rounding)
94 |
95 |
96 | def break_ties_average(sorted_dists, gt_dists):
97 | # fast implementation, based on this code:
98 | # https://stackoverflow.com/a/49239335
99 | locs = np.argwhere((sorted_dists - gt_dists) == 0)
100 |
101 | # Find the split indices
102 | steps = np.diff(locs[:, 0])
103 | splits = np.nonzero(steps)[0] + 1
104 | splits = np.insert(splits, 0, 0)
105 |
106 | # Compute the result columns
107 | summed_cols = np.add.reduceat(locs[:, 1], splits)
108 | counts = np.diff(np.append(splits, locs.shape[0]))
109 | avg_cols = summed_cols / counts
110 | return avg_cols
111 |
112 |
113 | def break_ties_optimistically(sorted_dists, gt_dists):
114 | rows, cols = np.where((sorted_dists - gt_dists) == 0)
115 | _, idx = np.unique(rows, return_index=True)
116 | cols = cols[idx]
117 | return cols
118 |
119 |
120 | def cols2metrics(cols, num_queries, rounding=2):
121 | metrics = {}
122 | vals = [str(x).zfill(2) for x in [1, 2, 3, 5, 10]]
123 | for val in vals:
124 | metrics[f"R{val}"] = 100 * float(np.sum(cols < int(val))) / num_queries
125 |
126 | metrics["MedR"] = float(np.median(cols) + 1)
127 |
128 | if rounding is not None:
129 | for key in metrics:
130 | metrics[key] = round(metrics[key], rounding)
131 | return metrics
132 |
--------------------------------------------------------------------------------
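A small sanity check for the retrieval metrics above: with a diagonal similarity matrix every query ranks its ground truth first, so R01 is 100 and the median rank is 1. The import path is assumed from the file layout.

    import numpy as np
    from src.model.tmr_utils.metrics import all_contrastive_metrics  # assumed import path

    sims = np.eye(4)                                  # perfect text-to-motion similarities
    metrics = all_contrastive_metrics(sims)
    print(metrics["t2m/R01"], metrics["t2m/MedR"])    # 100.0 1.0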
/src/model/tmr_utils/text_encoder.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch.nn as nn
3 | import torch
4 | from torch import Tensor
5 | from typing import Dict, List
6 | import torch.nn.functional as F
7 |
8 |
9 | class TextToEmb(nn.Module):
10 | def __init__(
11 | self, modelpath: str, mean_pooling: bool = False, device: str = "cpu"
12 | ) -> None:
13 | super().__init__()
14 |
15 | self.device = device
16 | from transformers import AutoTokenizer, AutoModel
17 | from transformers import logging
18 |
19 | logging.set_verbosity_error()
20 |
21 | # Tokenizer
22 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
23 | self.tokenizer = AutoTokenizer.from_pretrained(modelpath)
24 |
25 | # Text model
26 | self.text_model = AutoModel.from_pretrained(modelpath)
27 | # Then configure the model
28 | self.text_encoded_dim = self.text_model.config.hidden_size
29 |
30 | if mean_pooling:
31 | self.forward = self.forward_pooling
32 |
33 | # put it in eval mode by default
34 | self.eval()
35 |
36 | # Freeze the weights just in case
37 | for param in self.parameters():
38 | param.requires_grad = False
39 |
40 | self.to(device)
41 |
42 | def train(self, mode: bool = True) -> nn.Module:
43 | # override it to be always false
44 | self.training = False
45 | for module in self.children():
46 | module.train(False)
47 | return self
48 |
49 | @torch.no_grad()
50 | def forward(self, texts: List[str], device=None) -> Dict:
51 | device = device if device is not None else self.device
52 |
53 | squeeze = False
54 | if isinstance(texts, str):
55 | texts = [texts]
56 | squeeze = True
57 |
58 | encoded_inputs = self.tokenizer(texts, return_tensors="pt", padding=True)
59 | output = self.text_model(**encoded_inputs.to(device))
60 | length = encoded_inputs.attention_mask.to(dtype=bool).sum(1)
61 |
62 | if squeeze:
63 | x_dict = {"x": output.last_hidden_state[0], "length": length[0]}
64 | else:
65 | x_dict = {"x": output.last_hidden_state, "length": length}
66 | return x_dict
67 |
68 | @torch.no_grad()
69 | def forward_pooling(self, texts: List[str], device=None) -> Tensor:
70 | device = device if device is not None else self.device
71 |
72 | squeeze = False
73 | if isinstance(texts, str):
74 | texts = [texts]
75 | squeeze = True
76 |
77 | # From: https://huggingface.co/sentence-transformers/all-mpnet-base-v2
78 | encoded_inputs = self.tokenizer(texts, return_tensors="pt", padding=True)
79 | output = self.text_model(**encoded_inputs.to(device))
80 | attention_mask = encoded_inputs["attention_mask"]
81 |
82 | # Mean Pooling - Take attention mask into account for correct averaging
83 | token_embeddings = output["last_hidden_state"]
84 | input_mask_expanded = (
85 | attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
86 | )
87 | sentence_embeddings = torch.sum(
88 | token_embeddings * input_mask_expanded, 1
89 | ) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
90 | # Normalize embeddings
91 | sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
92 | if squeeze:
93 | sentence_embeddings = sentence_embeddings[0]
94 | return sentence_embeddings
95 |
--------------------------------------------------------------------------------
/src/model/tmr_utils/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import logging
4 |
5 | import hydra
6 | from omegaconf import DictConfig, OmegaConf
7 |
8 |
9 | logger = logging.getLogger(__name__)
10 |
11 |
12 |
13 |
14 | def save_config(cfg: DictConfig) -> str:
15 | path = os.path.join(cfg.run_dir, "config.json")
16 | config = OmegaConf.to_container(cfg, resolve=True)
17 | with open(path, "w") as f:
18 | string = json.dumps(config, indent=4)
19 | f.write(string)
20 | return path
21 |
22 |
23 | def read_config(run_dir: str, return_json=False) -> DictConfig:
24 | path = os.path.join(run_dir, "config.json")
25 | with open(path, "r") as f:
26 | config = json.load(f)
27 | if return_json:
28 | return config
29 | cfg = OmegaConf.create(config)
30 | cfg.run_dir = run_dir
31 | return cfg
32 |
33 | # split the lightning checkpoint into
34 | # separate state_dict modules for faster loading
35 | def extract_ckpt(run_dir, ckpt_name="last"):
36 | import torch
37 |
38 | ckpt_path = os.path.join(run_dir, f"logs/checkpoints/{ckpt_name}.ckpt")
39 |
40 | extracted_path = os.path.join(run_dir, f"{ckpt_name}_weights")
41 | os.makedirs(extracted_path, exist_ok=True)
42 |
43 | new_path_template = os.path.join(extracted_path, "{}.pt")
44 | ckpt_dict = torch.load(ckpt_path)
45 | state_dict = ckpt_dict["state_dict"]
46 | module_names = list(set([x.split(".")[0] for x in state_dict.keys()]))
47 |
48 | # should be ['motion_encoder', 'text_encoder', 'motion_decoder'] for example
49 | for module_name in module_names:
50 | path = new_path_template.format(module_name)
51 | sub_state_dict = {
52 | ".".join(x.split(".")[1:]): y.cpu()
53 | for x, y in state_dict.items()
54 | if x.split(".")[0] == module_name
55 | }
56 | torch.save(sub_state_dict, path)
57 |
58 |
59 | def load_model(run_dir, **params):
60 | # Load last config
61 | cfg = read_config(run_dir)
62 | cfg.run_dir = run_dir
63 | return load_model_from_cfg(cfg, **params)
64 |
65 |
66 | def load_model_from_cfg(cfg, ckpt_name="last", device="cpu", eval_mode=True):
67 | import torch
68 |
69 | run_dir = cfg.run_dir
70 | model = hydra.utils.instantiate(cfg.model)
71 |
72 | # Loading modules one by one
73 | # motion_encoder / text_encoder / text_decoder
74 | pt_path = os.path.join(run_dir, f"{ckpt_name}_weights")
75 |
76 | if not os.path.exists(pt_path):
77 | logger.info("The extracted model is not found. Split into submodules..")
78 | extract_ckpt(run_dir, ckpt_name)
79 |
80 | for fname in os.listdir(pt_path):
81 | module_name, ext = os.path.splitext(fname)
82 | if ext != ".pt":
83 | continue
84 |
85 | module = getattr(model, module_name, None)
86 | if module is None:
87 | continue
88 |
89 | module_path = os.path.join(pt_path, fname)
90 | state_dict = torch.load(module_path)
91 | module.load_state_dict(state_dict)
92 | logger.info(f" {module_name} loaded")
93 |
94 | logger.info("Loading previous checkpoint done")
95 | model = model.to(device)
96 | logger.info(f"Put the model on {device}")
97 | if eval_mode:
98 | model = model.eval()
99 | logger.info("Put the model in eval mode")
100 | return model
101 |
102 |
103 | @hydra.main(version_base=None, config_path="../configs", config_name="load_model")
104 | def hydra_load_model(cfg: DictConfig) -> None:
105 | run_dir = cfg.run_dir
106 | ckpt_name = cfg.ckpt
107 | device = cfg.device
108 | eval_mode = cfg.eval_mode
109 | return load_model(run_dir, ckpt_name=ckpt_name, device=device, eval_mode=eval_mode)
110 |
111 |
112 | if __name__ == "__main__":
113 | hydra_load_model()
114 |
--------------------------------------------------------------------------------
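A usage sketch for the loaders above. The run directory is a placeholder; on first use load_model_from_cfg splits logs/checkpoints/last.ckpt into per-module weight files (extract_ckpt) and then loads each submodule that exists on the instantiated model.

    from src.model.tmr_utils.utils import load_model  # assumed import path

    # "outputs/tmr_run" is hypothetical; it must contain config.json
    # and logs/checkpoints/last.ckpt from a previous run
    model = load_model("outputs/tmr_run", ckpt_name="last", device="cpu", eval_mode=True)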
/src/model/trans_enc.py:
--------------------------------------------------------------------------------
1 | from src.model.DiT_models import *
2 |
3 | class EncoderBlock(nn.Module):
4 | """
5 | A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
6 | """
7 | def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
8 | super().__init__()
9 | self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
10 | self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
11 | self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
12 | mlp_hidden_dim = int(hidden_size * mlp_ratio)
13 | approx_gelu = lambda: nn.GELU(approximate="tanh")
14 | self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
15 |
16 | def forward(self, x, c, mask=None):
17 | x = x + self.attn((self.norm1(x)), mask)
18 | x = x + self.mlp((self.norm2(x)))
19 | return x
20 |
21 | def forward_with_att(self, x, c, mask=None):
22 | attention_out, attention_mask = self.attn.forward_w_attention((self.norm1(x)), mask)
23 | x = x + attention_out
24 | x = x + self.mlp((self.norm2(x)))
25 | return x, attention_mask
26 |
27 | import copy
28 | def clones(module, N):
29 | """Produce N identical layers."""
30 | return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
31 |
32 | class LayerNorm(nn.Module):
33 | def __init__(self, features: int, eps: float = 1e-6):
34 | # features = d_model
35 | super(LayerNorm, self).__init__()
36 | self.a = nn.Parameter(torch.ones(features))
37 | self.b = nn.Parameter(torch.zeros(features))
38 | self.eps = eps
39 |
40 | def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
41 | mean = x.mean(-1, keepdim=True)
42 | std = x.std(-1, keepdim=True)
43 | return self.a * (x - mean) / (std + self.eps) + self.b
44 |
45 | class TransEncoder(nn.Module):
46 | """Core encoder is a stack of N layers"""
47 |
48 | def __init__(self, layer, N: int):
49 | super(TransEncoder, self).__init__()
50 | self.layers = clones(layer, N)
51 |
52 | def forward(self, x: torch.FloatTensor, mask: torch.ByteTensor) -> torch.FloatTensor:
53 | """Pass the input (and mask) through each layer in turn."""
54 | for layer in self.layers:
55 |             x = layer(x, None, mask)  # EncoderBlock.forward(x, c, mask): pass the mask explicitly, c is unused
56 | return x
57 | def forward_with_att(self, x, c, mask=None):
58 | attention_masks = []
59 | for layer in self.layers:
60 | x, attention_mask = layer.forward_with_att(x, c, mask)
61 | attention_masks.append(attention_mask)
62 | return x, attention_masks
--------------------------------------------------------------------------------
/src/model/transformer_encoder/encoder.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lzhyu/SimMotionEdit/384135a8707bc247fba23005d57cca7ca2d751ce/src/model/transformer_encoder/encoder.py
--------------------------------------------------------------------------------
/src/model/transformer_encoder/encoder_layer.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lzhyu/SimMotionEdit/384135a8707bc247fba23005d57cca7ca2d751ce/src/model/transformer_encoder/encoder_layer.py
--------------------------------------------------------------------------------
/src/model/transformer_encoder/feed_forward.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lzhyu/SimMotionEdit/384135a8707bc247fba23005d57cca7ca2d751ce/src/model/transformer_encoder/feed_forward.py
--------------------------------------------------------------------------------
/src/model/transformer_encoder/note.md:
--------------------------------------------------------------------------------
1 | Based on https://github.com/guocheng18/Transformer-Encoder/tree/master
--------------------------------------------------------------------------------
/src/model/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .positional_encoding import PositionalEncoding
2 | from .vae import reparameterize
3 |
--------------------------------------------------------------------------------
/src/model/utils/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | import math
2 | import functools
3 | 
4 | import torch
5 | from torch.optim import Optimizer
6 | from torch.optim.lr_scheduler import LambdaLR
7 | 
8 | 
9 | def _cosine_decay_warmup(iteration, warmup_iterations, total_iterations, final_learning_rate, initial_learning_rate):
10 | """
11 |     LR multiplier: linear warmup from 0 to 1 over warmup_iterations, then cosine decay from 1 down to final_learning_rate / initial_learning_rate at total_iterations.
12 | """
13 | if iteration <= warmup_iterations:
14 | # Linear warmup: multiplier increases from 0 to 1
15 | return iteration / warmup_iterations
16 | else:
17 | # Cosine decay phase: decay from 1 to final_learning_rate/initial_learning_rate
18 | decayed = (iteration - warmup_iterations) / (total_iterations - warmup_iterations)
19 | decayed = 0.5 * (1 + math.cos(math.pi * decayed))
20 | # Normalize the decay such that it ends at final_learning_rate/initial_learning_rate
21 | decay_multiplier = decayed * (1 - final_learning_rate / initial_learning_rate) + final_learning_rate / initial_learning_rate
22 | return decay_multiplier
23 |
24 | def CosineAnnealingLRWarmup(optimizer, T_max, T_warmup, lr_final, lr_initial):
25 | _decay_func = functools.partial(
26 | _cosine_decay_warmup,
27 | warmup_iterations=T_warmup, total_iterations=T_max,
28 | final_learning_rate=lr_final,
29 | initial_learning_rate=lr_initial
30 | )
31 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, _decay_func)
32 | return scheduler
33 |
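A short usage sketch for CosineAnnealingLRWarmup (illustration only; the optimizer, step counts, and learning rates below are arbitrary placeholders, not this project's training settings):

    import torch

    params = [torch.nn.Parameter(torch.zeros(1))]
    optimizer = torch.optim.AdamW(params, lr=1e-4)
    scheduler = CosineAnnealingLRWarmup(optimizer, T_max=1000, T_warmup=100,
                                        lr_final=1e-6, lr_initial=1e-4)
    for step in range(1000):
        optimizer.step()
        scheduler.step()  # linear warmup for 100 steps, then cosine decay towards 1e-6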
--------------------------------------------------------------------------------
/src/model/utils/positional_encoding.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch import nn
4 |
5 |
6 | class PositionalEncoding(nn.Module):
7 | def __init__(self, d_model, dropout=0.1,
8 | max_len=5000, batch_first=False, negative=False):
9 | super().__init__()
10 | self.batch_first = batch_first
11 |
12 | self.dropout = nn.Dropout(p=dropout)
13 | self.max_len = max_len
14 |
15 | self.negative = negative
16 |
17 | if negative:
18 | pe = torch.zeros(2*max_len, d_model)
19 | position = torch.arange(-max_len, max_len, dtype=torch.float).unsqueeze(1)
20 | else:
21 | pe = torch.zeros(max_len, d_model)
22 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
23 |
24 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
25 | pe[:, 0::2] = torch.sin(position * div_term)
26 | pe[:, 1::2] = torch.cos(position * div_term)
27 | pe = pe.unsqueeze(0).transpose(0, 1)
28 |
29 | self.register_buffer('pe', pe, persistent=False)
30 |
31 | def forward(self, x, hist_frames=0):
32 | if not self.negative:
33 | center = 0
34 | assert hist_frames == 0
35 | first = 0
36 | else:
37 | center = self.max_len
38 | first = center-hist_frames
39 | if self.batch_first:
40 | last = first + x.shape[1]
41 | x = x + self.pe.permute(1, 0, 2)[:, first:last, :]
42 | else:
43 | last = first + x.shape[0]
44 | x = x + self.pe[first:last, :]
45 | return self.dropout(x)
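A shape-convention sketch for PositionalEncoding (illustration only; the dimensions are arbitrary). By default the module expects (seq_len, batch, d_model); with batch_first=True it expects (batch, seq_len, d_model):

    import torch

    pe = PositionalEncoding(d_model=64, dropout=0.0, batch_first=True)
    x = torch.zeros(2, 30, 64)   # (batch, seq_len, d_model)
    out = pe(x)                  # same shape, with sinusoidal offsets added
    assert out.shape == (2, 30, 64)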
--------------------------------------------------------------------------------
/src/model/utils/timestep_embed.py:
--------------------------------------------------------------------------------
1 | 
2 | from torch import nn
3 | import torch
4 | import math
5 |
6 | def get_timestep_embedding(
7 | timesteps: torch.Tensor,
8 | embedding_dim: int,
9 | flip_sin_to_cos: bool = False,
10 | downscale_freq_shift: float = 1,
11 | scale: float = 1,
12 | max_period: int = 10000,
13 | ):
14 | """
15 | This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
16 |
17 | :param timesteps: a 1-D Tensor of N indices, one per batch element.
18 | These may be fractional.
19 | :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the
20 | embeddings. :return: an [N x dim] Tensor of positional embeddings.
21 | """
22 | assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
23 |
24 | half_dim = embedding_dim // 2
25 | exponent = -math.log(max_period) * torch.arange(
26 | start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
27 | )
28 | exponent = exponent / (half_dim - downscale_freq_shift)
29 |
30 | emb = torch.exp(exponent)
31 | emb = timesteps[:, None].float() * emb[None, :]
32 |
33 | # scale embeddings
34 | emb = scale * emb
35 |
36 | # concat sine and cosine embeddings
37 | emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
38 |
39 | # flip sine and cosine embeddings
40 | if flip_sin_to_cos:
41 | emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
42 |
43 | # zero pad
44 | if embedding_dim % 2 == 1:
45 | emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
46 | return emb
47 |
48 | class TimestepEmbedding(nn.Module):
49 | def __init__(self, channel: int, time_embed_dim: int, act_fn: str = "silu"):
50 | super().__init__()
51 |
52 | self.linear_1 = nn.Linear(channel, time_embed_dim)
53 | self.act = None
54 | if act_fn == "silu":
55 | self.act = nn.SiLU()
56 | self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim)
57 |
58 | def forward(self, sample):
59 | sample = self.linear_1(sample)
60 |
61 | if self.act is not None:
62 | sample = self.act(sample)
63 |
64 | sample = self.linear_2(sample)
65 | return sample
66 |
67 |
68 | class Timesteps(nn.Module):
69 | def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
70 | super().__init__()
71 | self.num_channels = num_channels
72 | self.flip_sin_to_cos = flip_sin_to_cos
73 | self.downscale_freq_shift = downscale_freq_shift
74 |
75 | def forward(self, timesteps):
76 | t_emb = get_timestep_embedding(timesteps, self.num_channels,
77 | flip_sin_to_cos=self.flip_sin_to_cos,
78 | downscale_freq_shift=self.downscale_freq_shift,)
79 | return t_emb
80 |
81 |
82 | class TimestepEmbedderMDM(nn.Module):
83 | def __init__(self, latent_dim):
84 | super().__init__()
85 | self.latent_dim = latent_dim
86 | from src.model.utils.positional_encoding import PositionalEncoding
87 |
88 | time_embed_dim = self.latent_dim
89 | self.sequence_pos_encoder = PositionalEncoding(d_model=self.latent_dim)
90 | self.time_embed = nn.Sequential(
91 | nn.Linear(self.latent_dim, time_embed_dim),
92 | nn.SiLU(),
93 | nn.Linear(time_embed_dim, time_embed_dim),
94 | )
95 |
96 | def forward(self, timesteps):
97 | return self.time_embed(self.sequence_pos_encoder.pe[timesteps]).permute(1, 0, 2)
98 |
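A usage sketch for the modules above (illustration only; the latent size of 256 and the timestep range are arbitrary values, not this project's configuration):

    import torch

    time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0.0)
    time_embedding = TimestepEmbedding(channel=256, time_embed_dim=256)
    t = torch.randint(0, 1000, (8,))    # one diffusion timestep per batch element
    emb = time_embedding(time_proj(t))  # (8, 256) conditioning vector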
--------------------------------------------------------------------------------
/src/model/utils/tools.py:
--------------------------------------------------------------------------------
1 | from src.tools.transforms3d import transform_body_pose
2 | import torch
3 |
4 | def detach_to_numpy(tensor):
5 | return tensor.detach().cpu().numpy()
6 |
7 | def remove_padding(tensors, lengths):
8 | return [tensor[:tensor_length] for tensor, tensor_length in zip(tensors, lengths)]
9 |
10 | def pack_to_render(rots, trans, pose_repr='6d'):
11 | # make axis-angle
12 | # global_orient = transform_body_pose(rots, f"{pose_repr}->aa")
13 | if pose_repr != 'aa':
14 | body_pose = transform_body_pose(rots, f"{pose_repr}->aa")
15 | else:
16 | body_pose = rots
17 | if trans is None:
18 | trans = torch.zeros((rots.shape[0], rots.shape[1], 3),
19 | device=rots.device)
20 | render_d = {'body_transl': trans,
21 | 'body_orient': body_pose[..., :3],
22 | 'body_pose': body_pose[..., 3:]}
23 | return render_d
--------------------------------------------------------------------------------
/src/model/utils/vae.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def reparameterize(mu, logvar, seed=None):
5 | std = torch.exp(logvar / 2)
6 |
7 | if seed is None:
8 | eps = std.data.new(std.size()).normal_()
9 | else:
10 | generator = torch.Generator(device=mu.device)
11 | generator.manual_seed(seed)
12 | eps = std.data.new(std.size()).normal_(generator=generator)
13 |
14 | return eps.mul(std).add_(mu)
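A minimal sketch of the reparameterisation trick implemented above (illustration only; shapes are arbitrary):

    import torch

    mu = torch.zeros(4, 256, requires_grad=True)
    logvar = torch.zeros(4, 256, requires_grad=True)
    z = reparameterize(mu, logvar)  # differentiable sample from N(mu, exp(logvar)), shape (4, 256)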
--------------------------------------------------------------------------------
/src/render/__init__.py:
--------------------------------------------------------------------------------
1 | import sys
2 | if 'blender' not in sys.executable:
3 |     # rendering helpers are only imported outside of Blender's bundled Python
4 | from .anim import render_animation
5 | from .mesh_viz import render_motion
6 |
--------------------------------------------------------------------------------
/src/repr_utils/tmr_utils.py:
--------------------------------------------------------------------------------
1 |
2 | import hydra
3 | from pathlib import Path
4 | from tmr_evaluator.motion2motion_retr import read_config
5 | import logging
6 | logger = logging.getLogger(__name__)
7 |
8 | import pytorch_lightning as pl
9 | import numpy as np
10 | from hydra.utils import instantiate
11 | from src.tmr.load_model import load_model_from_cfg
12 | from src.tmr.metrics import all_contrastive_metrics_mot2mot, print_latex_metrics_m2m
13 | from omegaconf import DictConfig
14 | from omegaconf import OmegaConf
15 | import src.launch.prepare
16 | from src.data.tools.collate import collate_batch_last_padding
17 | import torch
18 | from tmr_evaluator.motion2motion_retr import length_to_mask
19 | # Step 1: load TMR model
20 |
21 | def mean_flat(x):
22 | """
23 | Take the mean over all non-batch dimensions.
24 | """
25 | return torch.mean(x, dim=list(range(1, len(x.size()))))
26 |
27 | def masked_loss(loss, mask):
28 | # loss: B, T
29 | # mask: B, T
30 |     masked_loss = loss * mask  # keep only the loss at valid (mask == 1) positions
31 | 
32 |     # Average the masked loss:
33 |     # 1. count the valid elements, to avoid dividing by zero
34 |     num_valid_elements = mask.sum()  # total number of valid elements
35 |     masked_loss_mean = masked_loss.sum() / (num_valid_elements + 1)
36 | return masked_loss_mean
37 |
38 | def load_tmr_model():
39 | protocol = ['normal', 'batches']
40 | device = 'cpu'
41 | run_dir = 'eval-deps'
42 | ckpt_name = 'last'
43 | batch_size = 256
44 |
45 | protocols = protocol
46 | dataset = 'motionfix' # motionfix
47 | sets = 'test' # val all
48 | # save_dir = os.path.join(run_dir, "motionfix/contrastive_metrics")
49 | # os.makedirs(save_dir, exist_ok=True)
50 |
51 | # Load last config
52 | curdir = Path("/depot/bera89/data/li5280/project/motionfix")
53 | cfg = read_config(curdir / run_dir)
54 | logger.info("Loading the evaluation TMR model")
55 | model = load_model_from_cfg(cfg, ckpt_name, eval_mode=True, device=device)
56 | return model
57 |
58 |
59 |
60 | # STEP 2: load dataset
61 |
62 | def load_testloader(cfg: DictConfig):
63 |     # Load the experiment's saved Hydra config and merge it with the current one
64 | exp_folder = Path("/depot/bera89/data/li5280/project/motionfix/experiments/tmed")
65 | prevcfg = OmegaConf.load(exp_folder / ".hydra/config.yaml")
66 | cfg = OmegaConf.merge(prevcfg, cfg)
67 | data_module = instantiate(cfg.data, amt_only=True,
68 | load_splits=['test', 'val'])
69 | # load the test set and collate it properly
70 | features_to_load = data_module.dataset['test'].load_feats
71 | SMPL_feats = ['body_transl', 'body_pose', 'body_orient']
72 | for feat in SMPL_feats:
73 | if feat not in features_to_load:
74 | data_module.dataset['test'].load_feats.append(feat)
75 | print(features_to_load)
76 |
77 | # TODO: change features of including SMPL features
78 | test_dataset = data_module.dataset['test'] + data_module.dataset['val']
79 | collate_fn = lambda b: collate_batch_last_padding(b, features_to_load)
80 |
81 | testloader = torch.utils.data.DataLoader(test_dataset,
82 | shuffle=False,
83 | num_workers=8,
84 | batch_size=128,
85 | collate_fn=collate_fn)
86 | return testloader
87 |
88 | from src.data.features import _get_body_transl_delta_pelv_infer
89 | def batch_to_smpl(batch, normalizer, mot_from='source'):
90 | # batch: dict
91 | # return: padded batch with lengths
92 | # simple concatenation of the body pose and body transl
93 | # trans_delta, body_pose_6d, global_orient_6d
94 | smpl_keys = ['body_transl', 'body_orient', 'body_pose']
95 | lengths = batch[f'length_{mot_from}']
96 | # T, zeros, trans, global, local
97 | tensor_list = []
98 | trans = batch[f'body_transl_{mot_from}']
99 | body_pose_6d = batch[f'body_pose_{mot_from}']
100 | global_orient_6d = batch[f'body_orient_{mot_from}']
101 | trans_delta = _get_body_transl_delta_pelv_infer(global_orient_6d,
102 | trans) # T, 3, [0] is [0,0,0]
103 | motion_smpl = torch.cat([trans_delta, body_pose_6d,
104 | global_orient_6d], dim=-1)
105 | motion_smpl = normalizer(motion_smpl)
106 | # In some motions, translation is always zeros.
107 | # first 3 dimensions of the first frame should be zeros
108 | # B, T, D
109 | return motion_smpl, lengths
110 | from src.tmr.data.motionfix_loader import Normalizer
111 |
112 | @hydra.main(config_path="/depot/bera89/data/li5280/project/motionfix/configs", config_name="motionfix_eval")
113 | def main(cfg: DictConfig):
114 | model = load_tmr_model()
115 | dataloader = load_testloader(cfg)
116 | normalizer = Normalizer("/depot/bera89/data/li5280/project/motionfix/eval-deps/stats/humanml3d/amass_feats")
117 |
118 | # STEP 3: dataset to SMPL
119 | for batch in dataloader:
120 | source_smpl, lengths = batch_to_smpl(batch, normalizer) # B, T, 135
121 | masks = length_to_mask(lengths, device=source_smpl.device)
122 | in_batch = {'x': source_smpl, 'mask': masks}
123 | res = model.encode_motion(in_batch)
124 | print(res.size())
125 | break
126 |     # What we need here: the original SMPL data and the sequence lengths,
127 |     # which are then converted to the model-specific representation.
128 |
129 | if __name__ == '__main__':
130 | main()
131 |
--------------------------------------------------------------------------------
/src/tmr/__init__.py:
--------------------------------------------------------------------------------
1 | from .actor import PositionalEncoding, ACTORStyleEncoder, ACTORStyleDecoder # noqa
2 | from .temos import TEMOS # noqa
3 | from .tmr import TMR # noqa
4 | from .text_encoder import TextToEmb # noqa
5 |
--------------------------------------------------------------------------------
/src/tmr/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lzhyu/SimMotionEdit/384135a8707bc247fba23005d57cca7ca2d751ce/src/tmr/data/__init__.py
--------------------------------------------------------------------------------
/src/tmr/load_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import logging
4 | 
5 | import hydra
6 | from omegaconf import DictConfig, OmegaConf
7 | 
8 | 
9 |
10 |
11 | def save_config(cfg: DictConfig) -> str:
12 | path = os.path.join(cfg.run_dir, "config.json")
13 | config = OmegaConf.to_container(cfg, resolve=True)
14 | with open(path, "w") as f:
15 | string = json.dumps(config, indent=4)
16 | f.write(string)
17 | return path
18 |
19 |
20 | def read_config(run_dir: str, return_json=False) -> DictConfig:
21 | path = os.path.join(run_dir, "config.json")
22 | with open(path, "r") as f:
23 | config = json.load(f)
24 | if return_json:
25 | return config
26 | cfg = OmegaConf.create(config)
27 | cfg.run_dir = run_dir
28 | return cfg
29 |
30 | logger = logging.getLogger(__name__)
31 |
32 |
33 | # split the lightning checkpoint into
34 | # separate state_dict modules for faster loading
35 | def extract_ckpt(run_dir, ckpt_name="last"):
36 | import torch
37 |
38 | ckpt_path = os.path.join(run_dir, f"logs/checkpoints/{ckpt_name}.ckpt")
39 |
40 | extracted_path = os.path.join(run_dir, f"{ckpt_name}_weights")
41 | os.makedirs(extracted_path, exist_ok=True)
42 |
43 | new_path_template = os.path.join(extracted_path, "{}.pt")
44 | ckpt_dict = torch.load(ckpt_path)
45 | state_dict = ckpt_dict["state_dict"]
46 | module_names = list(set([x.split(".")[0] for x in state_dict.keys()]))
47 |
48 | # should be ['motion_encoder', 'text_encoder', 'motion_decoder'] for example
49 | for module_name in module_names:
50 | path = new_path_template.format(module_name)
51 | sub_state_dict = {
52 | ".".join(x.split(".")[1:]): y.cpu()
53 | for x, y in state_dict.items()
54 | if x.split(".")[0] == module_name
55 | }
56 | torch.save(sub_state_dict, path)
57 |
58 |
59 | def load_model(run_dir, **params):
60 | # Load last config
61 | cfg = read_config(run_dir)
62 | cfg.run_dir = run_dir
63 | return load_model_from_cfg(cfg, **params)
64 |
65 |
66 | def load_model_from_cfg(cfg, ckpt_name="last", device="cpu", eval_mode=True):
67 | import torch
68 |
69 | run_dir = cfg.run_dir
70 |
71 | from omegaconf import DictConfig, OmegaConf
72 |
73 | def replace_model_with_tmr(d):
74 | if isinstance(d, DictConfig):
75 | new_dict = {}
76 | for k, v in d.items():
77 | new_key = k.replace('.model.', '.tmr.')
78 | new_value = replace_model_with_tmr(v)
79 | new_dict[new_key] = new_value
80 | return OmegaConf.create(new_dict)
81 | elif isinstance(d, dict):
82 | new_dict = {}
83 | for k, v in d.items():
84 | new_key = k.replace('.model.', '.tmr.')
85 | new_value = replace_model_with_tmr(v)
86 | new_dict[new_key] = new_value
87 | return new_dict
88 | elif isinstance(d, list):
89 | return [replace_model_with_tmr(item) for item in d]
90 | elif isinstance(d, str):
91 | return d.replace('.model.', '.tmr.')
92 | else:
93 | return d
94 | model_conf = replace_model_with_tmr(cfg.model)
95 |
96 |     # NOTE: replace_model_with_tmr above rewrites '.model.' to '.tmr.' in the
97 |     # config keys and string values (e.g. _target_ paths), so the evaluation
98 |     # TMR modules are instantiated instead of the training-time model classes.
99 | 
100 | model = hydra.utils.instantiate(model_conf)
101 |
102 | # Loading modules one by one
103 | # motion_encoder / text_encoder / text_decoder
104 | pt_path = os.path.join(run_dir, f"{ckpt_name}_weights")
105 |
106 | if not os.path.exists(pt_path):
107 |         print("Extracted weights not found; splitting the checkpoint into submodules...")
108 | extract_ckpt(run_dir, ckpt_name)
109 |
110 | for fname in os.listdir(pt_path):
111 | module_name, ext = os.path.splitext(fname)
112 | if ext != ".pt":
113 | continue
114 |
115 | module = getattr(model, module_name, None)
116 | if module is None:
117 | continue
118 |
119 | module_path = os.path.join(pt_path, fname)
120 | state_dict = torch.load(module_path)
121 | module.load_state_dict(state_dict)
122 |
123 | print("Loading previous checkpoint done")
124 | model = model.to(device)
125 | if eval_mode:
126 | model = model.eval()
127 | return model
128 |
129 |
130 | @hydra.main(version_base=None, config_path="../configs", config_name="load_model")
131 | def hydra_load_model(cfg: DictConfig) -> None:
132 | run_dir = cfg.run_dir
133 | ckpt_name = cfg.ckpt
134 | device = cfg.device
135 | eval_mode = cfg.eval_mode
136 |     return load_model(run_dir, ckpt_name=ckpt_name, device=device, eval_mode=eval_mode)
137 |
138 |
139 | if __name__ == "__main__":
140 | hydra_load_model()
141 |
--------------------------------------------------------------------------------
/src/tmr/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 |
5 | # For reference
6 | # https://stats.stackexchange.com/questions/7440/kl-divergence-between-two-univariate-gaussians
7 | # https://pytorch.org/docs/stable/_modules/torch/distributions/kl.html#kl_divergence
8 | class KLLoss:
9 | def __call__(self, q, p):
10 | mu_q, logvar_q = q
11 | mu_p, logvar_p = p
12 |
13 | log_var_ratio = logvar_q - logvar_p
14 | t1 = (mu_p - mu_q).pow(2) / logvar_p.exp()
15 | div = 0.5 * (log_var_ratio.exp() + t1 - 1 - log_var_ratio)
16 | return div.mean()
17 |
18 | def __repr__(self):
19 | return "KLLoss()"
20 |
21 |
22 | class InfoNCE_with_filtering:
23 | def __init__(self, temperature=0.7, threshold_selfsim=0.8):
24 | self.temperature = temperature
25 | self.threshold_selfsim = threshold_selfsim
26 |
27 | def get_sim_matrix(self, x, y):
28 | x_logits = torch.nn.functional.normalize(x, dim=-1)
29 | y_logits = torch.nn.functional.normalize(y, dim=-1)
30 | sim_matrix = x_logits @ y_logits.T
31 | return sim_matrix
32 |
33 | def __call__(self, x, y, sent_emb=None):
34 | bs, device = len(x), x.device
35 | sim_matrix = self.get_sim_matrix(x, y) / self.temperature
36 |
37 | if sent_emb is not None and self.threshold_selfsim:
38 | # put the threshold value between -1 and 1
39 | real_threshold_selfsim = 2 * self.threshold_selfsim - 1
40 | # Filtering too close values
41 | # mask them by putting -inf in the sim_matrix
42 | selfsim = sent_emb @ sent_emb.T
43 | selfsim_nodiag = selfsim - selfsim.diag().diag()
44 | idx = torch.where(selfsim_nodiag > real_threshold_selfsim)
45 | sim_matrix[idx] = -torch.inf
46 |
47 | labels = torch.arange(bs, device=device)
48 |
49 | total_loss = (
50 | F.cross_entropy(sim_matrix, labels) + F.cross_entropy(sim_matrix.T, labels)
51 | ) / 2
52 |
53 | return total_loss
54 |
55 | def __repr__(self):
56 |         return f"Contrastive(temp={self.temperature}, threshold_selfsim={self.threshold_selfsim})"
57 |
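A usage sketch for InfoNCE_with_filtering (illustration only; the random embeddings and their sizes merely show the expected shapes):

    import torch

    loss_fn = InfoNCE_with_filtering(temperature=0.1, threshold_selfsim=0.8)
    motion_emb = torch.randn(32, 256)   # one embedding per batch element
    text_emb = torch.randn(32, 256)     # paired text embeddings
    sent_emb = torch.nn.functional.normalize(torch.randn(32, 384), dim=-1)
    loss = loss_fn(motion_emb, text_emb, sent_emb=sent_emb)  # scalar symmetric InfoNCE loss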
--------------------------------------------------------------------------------
/src/tmr/text_encoder.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch.nn as nn
3 | import torch
4 | from torch import Tensor
5 | from typing import Dict, List
6 | import torch.nn.functional as F
7 |
8 |
9 | class TextToEmb(nn.Module):
10 | def __init__(
11 | self, modelpath: str, mean_pooling: bool = False, device: str = "cpu"
12 | ) -> None:
13 | super().__init__()
14 |
15 | self.device = device
16 | from transformers import AutoTokenizer, AutoModel
17 | from transformers import logging
18 |
19 | logging.set_verbosity_error()
20 |
21 | # Tokenizer
22 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
23 | self.tokenizer = AutoTokenizer.from_pretrained(modelpath)
24 |
25 | # Text model
26 | self.text_model = AutoModel.from_pretrained(modelpath)
27 | # Then configure the model
28 | self.text_encoded_dim = self.text_model.config.hidden_size
29 |
30 | if mean_pooling:
31 | self.forward = self.forward_pooling
32 |
33 | # put it in eval mode by default
34 | self.eval()
35 |
36 | # Freeze the weights just in case
37 | for param in self.parameters():
38 | param.requires_grad = False
39 |
40 | self.to(device)
41 |
42 | def train(self, mode: bool = True) -> nn.Module:
43 | # override it to be always false
44 | self.training = False
45 | for module in self.children():
46 | module.train(False)
47 | return self
48 |
49 | @torch.no_grad()
50 | def forward(self, texts: List[str], device=None) -> Dict:
51 | device = device if device is not None else self.device
52 |
53 | squeeze = False
54 | if isinstance(texts, str):
55 | texts = [texts]
56 | squeeze = True
57 |
58 | encoded_inputs = self.tokenizer(texts, return_tensors="pt", padding=True)
59 | output = self.text_model(**encoded_inputs.to(device))
60 | length = encoded_inputs.attention_mask.to(dtype=bool).sum(1)
61 |
62 | if squeeze:
63 | x_dict = {"x": output.last_hidden_state[0], "length": length[0]}
64 | else:
65 | x_dict = {"x": output.last_hidden_state, "length": length}
66 | return x_dict
67 |
68 | @torch.no_grad()
69 | def forward_pooling(self, texts: List[str], device=None) -> Tensor:
70 | device = device if device is not None else self.device
71 |
72 | squeeze = False
73 | if isinstance(texts, str):
74 | texts = [texts]
75 | squeeze = True
76 |
77 | # From: https://huggingface.co/sentence-transformers/all-mpnet-base-v2
78 | encoded_inputs = self.tokenizer(texts, return_tensors="pt", padding=True)
79 | output = self.text_model(**encoded_inputs.to(device))
80 | attention_mask = encoded_inputs["attention_mask"]
81 |
82 | # Mean Pooling - Take attention mask into account for correct averaging
83 | token_embeddings = output["last_hidden_state"]
84 | input_mask_expanded = (
85 | attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
86 | )
87 | sentence_embeddings = torch.sum(
88 | token_embeddings * input_mask_expanded, 1
89 | ) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
90 | # Normalize embeddings
91 | sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
92 | if squeeze:
93 | sentence_embeddings = sentence_embeddings[0]
94 | return sentence_embeddings
95 |
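A usage sketch for TextToEmb (illustration only; the checkpoint name follows the sentence-transformers model referenced in the comment above, but any compatible Hugging Face encoder path could be substituted):

    model = TextToEmb("sentence-transformers/all-mpnet-base-v2", mean_pooling=True, device="cpu")
    emb = model(["raise the left hand higher", "walk slower"])  # (2, hidden_size), L2-normalized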
--------------------------------------------------------------------------------
/src/tools/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lzhyu/SimMotionEdit/384135a8707bc247fba23005d57cca7ca2d751ce/src/tools/__init__.py
--------------------------------------------------------------------------------
/src/tools/easyconvert.py:
--------------------------------------------------------------------------------
1 | import src.tools.geometry as geometry
2 |
3 | def nfeats_of(rottype):
4 | if rottype in ["rotvec", "axisangle"]:
5 | return 3
6 | elif rottype in ["rotquat", "quaternion"]:
7 | return 4
8 | elif rottype in ["rot6d", "6drot", "rotation6d"]:
9 | return 6
10 | elif rottype in ["rotmat"]:
11 | return 9
12 | else:
13 |         raise TypeError("This rotation type doesn't have features.")
14 |
15 |
16 | def axis_angle_to(newtype, rotations):
17 | if newtype in ["matrix"]:
18 | rotations = geometry.axis_angle_to_matrix(rotations)
19 | return rotations
20 | elif newtype in ["rotmat"]:
21 | rotations = geometry.axis_angle_to_matrix(rotations)
22 | rotations = matrix_to("rotmat", rotations)
23 | return rotations
24 | elif newtype in ["rot6d", "6drot", "rotation6d"]:
25 | rotations = geometry.axis_angle_to_matrix(rotations)
26 | rotations = matrix_to("rot6d", rotations)
27 | return rotations
28 | elif newtype in ["rotquat", "quaternion"]:
29 | rotations = geometry.axis_angle_to_quaternion(rotations)
30 | return rotations
31 | elif newtype in ["rotvec", "axisangle"]:
32 | return rotations
33 | else:
34 | raise NotImplementedError
35 |
36 |
37 | def matrix_to(newtype, rotations):
38 | if newtype in ["matrix"]:
39 | return rotations
40 | if newtype in ["rotmat"]:
41 | rotations = rotations.reshape((*rotations.shape[:-2], 9))
42 | return rotations
43 | elif newtype in ["rot6d", "6drot", "rotation6d"]:
44 | rotations = geometry.matrix_to_rotation_6d(rotations)
45 | return rotations
46 | elif newtype in ["rotquat", "quaternion"]:
47 | rotations = geometry.matrix_to_quaternion(rotations)
48 | return rotations
49 | elif newtype in ["rotvec", "axisangle"]:
50 | rotations = geometry.matrix_to_axis_angle(rotations)
51 | return rotations
52 | else:
53 | raise NotImplementedError
54 |
55 |
56 | def to_matrix(oldtype, rotations):
57 | if oldtype in ["matrix"]:
58 | return rotations
59 | if oldtype in ["rotmat"]:
60 | rotations = rotations.reshape((*rotations.shape[:-2], 3, 3))
61 | return rotations
62 | elif oldtype in ["rot6d", "6drot", "rotation6d"]:
63 | rotations = geometry.rotation_6d_to_matrix(rotations)
64 | return rotations
65 | elif oldtype in ["rotquat", "quaternion"]:
66 | rotations = geometry.quaternion_to_matrix(rotations)
67 | return rotations
68 | elif oldtype in ["rotvec", "axisangle"]:
69 | rotations = geometry.axis_angle_to_matrix(rotations)
70 | return rotations
71 | else:
72 | raise NotImplementedError
73 |
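A round-trip sketch using the helpers above (illustration only; it assumes src.tools.geometry, imported at the top of this module, provides the PyTorch3D-style conversion functions it calls):

    import torch

    aa = 0.1 * torch.randn(16, 3)                                # axis-angle rotations
    rot6d = axis_angle_to("rot6d", aa)                           # (16, 6)
    aa_back = matrix_to("axisangle", to_matrix("rot6d", rot6d))  # (16, 3), same rotations up to numerics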
--------------------------------------------------------------------------------
/src/tools/logging.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import tqdm
3 |
4 |
5 | class LevelsFilter(logging.Filter):
6 | def __init__(self, levels):
7 | self.levels = [getattr(logging, level) for level in levels]
8 |
9 | def filter(self, record):
10 | return record.levelno in self.levels
11 |
12 |
13 | class StreamToLogger(object):
14 | """
15 | Fake file-like stream object that redirects writes to a logger instance.
16 | """
17 | def __init__(self, logger, level):
18 | self.logger = logger
19 | self.level = level
20 | self.linebuf = ''
21 |
22 | def write(self, buf):
23 | for line in buf.rstrip().splitlines():
24 | self.logger.log(self.level, line.rstrip())
25 |
26 | def flush(self):
27 | pass
28 |
29 |
30 | class TqdmLoggingHandler(logging.Handler):
31 | def __init__(self, level=logging.NOTSET):
32 | super().__init__(level)
33 |
34 | def emit(self, record):
35 | try:
36 | msg = self.format(record)
37 | tqdm.tqdm.write(msg)
38 | self.flush()
39 | except Exception:
40 | self.handleError(record)
41 |
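A usage sketch for the two helpers above (illustration only): redirect print output to a logger while keeping tqdm progress bars readable.

    import sys, logging, tqdm

    logger = logging.getLogger("demo")
    logger.setLevel(logging.INFO)
    logger.addHandler(TqdmLoggingHandler())
    sys.stdout = StreamToLogger(logger, logging.INFO)  # print(...) is now routed to the logger
    for _ in tqdm.tqdm(range(3)):
        print("step done")                             # emitted through TqdmLoggingHandler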
--------------------------------------------------------------------------------
/src/tools/runid.py:
--------------------------------------------------------------------------------
1 | #
2 | """
3 | runid util.
4 | Taken from wandb.sdk.lib.runid
5 | """
6 |
7 | import shortuuid # type: ignore
8 |
9 |
10 | def generate_id() -> str:
11 |     # ~3 trillion possible run ids (36**8)
12 | run_gen = shortuuid.ShortUUID(alphabet=list("0123456789abcdefghijklmnopqrstuvwxyz"))
13 | return run_gen.random(8)
--------------------------------------------------------------------------------
/src/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lzhyu/SimMotionEdit/384135a8707bc247fba23005d57cca7ca2d751ce/src/utils/__init__.py
--------------------------------------------------------------------------------
/src/utils/art_utils.py:
--------------------------------------------------------------------------------
1 |
2 | def rgba(c: str):
3 | from matplotlib import colors as mcolors
4 | return mcolors.to_rgba(c)
5 |
6 | def rgb(c: str):
7 | from matplotlib import colors as mcolors
8 | return mcolors.to_rgb(c)
9 |
10 | color_map = {
11 | 'source_motion': rgba('darkred'),
12 | 'source': rgba('darkred'),
13 | 'target_motion': rgba('olivedrab'),
14 | 'input': rgba('olivedrab'),
15 | 'target': rgba('olivedrab'),
16 | 'generation': rgba('purple'),
17 | 'generated': rgba('steelblue'),
18 | 'denoised': rgba('purple'),
19 | 'noised': rgba('darkgrey'),
20 | }
--------------------------------------------------------------------------------
/src/utils/cherrypick.py:
--------------------------------------------------------------------------------
1 | test_subset_mfix = ['000044', '000050', '000161', '000164', '000169', '000175',
2 | '000320', '000331', '000488', '000496', '000513', '001247',
3 | '001288', '001297', '001345', '001361', '001362', '001372',
4 | '001386', '001221', '001745', '001745', '000604', '001876',
5 | '000763', '001856', '001577', '000469', '000555', '000555',
6 | '000605', '001907', '001687', '001437', '000819', '000161',
7 | '000973']
8 |
9 |
--------------------------------------------------------------------------------
/src/utils/eval_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | test_keyds= ["001895", "001893", "001892", "001888", "001848", "001881",
4 | "001827", "001845", "001858", "001812", "001798", "001755"]
5 |
6 |
7 | keyids_for_testing = ["001892", "001881", "001800", "001790",
8 | "001849", "001860", "001862",
9 | "001641", "001696", "001676", "001690", "001692",
10 | "001762", "001794", "001694", "001812", "001874",
11 | "001883"]
12 |
13 | def split_txt_into_multi_lines(input_str: str, line_length: int):
14 | words = input_str.split(' ')
15 | line_count = 0
16 | split_input = ''
17 | for word in words:
18 | line_count += 1
19 | line_count += len(word)
20 | if line_count > line_length:
21 | split_input += '\n'
22 | line_count = len(word) + 1
23 | split_input += word
24 | split_input += ' '
25 | else:
26 | split_input += word
27 | split_input += ' '
28 |
29 | return split_input
30 |
31 | def regroup_metrics(metrics):
32 | from src.info.joints import smplh_joints
33 | pose_names = smplh_joints[1:23]
34 | dico = {key: val.numpy() for key, val in metrics.items()}
35 |
36 | if "APE_pose" in dico:
37 | APE_pose = dico.pop("APE_pose")
38 | for name, ape in zip(pose_names, APE_pose):
39 | dico[f"APE_pose_{name}"] = ape
40 |
41 | if "APE_joints" in dico:
42 | APE_joints = dico.pop("APE_joints")
43 | for name, ape in zip(smplh_joints, APE_joints):
44 | dico[f"APE_joints_{name}"] = ape
45 |
46 | if "AVE_pose" in dico:
47 | AVE_pose = dico.pop("AVE_pose")
48 | for name, ave in zip(pose_names, AVE_pose):
49 | dico[f"AVE_pose_{name}"] = ave
50 |
51 | if "AVE_joints" in dico:
52 | AVE_joints = dico.pop("AVE_joints")
53 |         for name, ave in zip(smplh_joints, AVE_joints):
54 |             dico[f"AVE_joints_{name}"] = ave
55 |
56 | return dico
57 |
58 |
59 | def sanitize(dico):
60 | dico = {key: "{:.5f}".format(float(val)) for key, val in dico.items()}
61 | return dico
62 |
63 | def out2blender(dicto):
64 | blender_dic = {}
65 | blender_dic['trans'] = dicto['body_transl']
66 | blender_dic['rots'] = torch.cat([dicto['body_orient'],
67 | dicto['body_pose']], dim=-1)
68 | return blender_dic
--------------------------------------------------------------------------------
/src/utils/genutils.py:
--------------------------------------------------------------------------------
1 | from copy import copy
2 | import numpy as np
3 | import torch
4 | from torch import Tensor
5 | import logging
6 | import os
7 | import random
8 | from einops import rearrange
9 | from pathlib import Path
10 |
11 | def extract_data_path(full_path, directory_name="data"):
12 | """
13 | Slices the given path up to and including the specified directory.
14 |
15 | Args:
16 | full_path (str): The full path as a string.
17 | directory_name (str): The directory to slice up to (included in the result).
18 |
19 | Returns:
20 | str: The sliced path as a string up to and including the specified directory.
21 | """
22 | path = Path(full_path)
23 | subpath = Path()
24 | for part in path.parts:
25 | subpath /= part
26 | if part == directory_name:
27 | break
28 |
29 | return str(subpath)
30 |
31 |
32 | def freeze(model) -> None:
33 | r"""
34 | Freeze all params for inference.
35 | """
36 | for param in model.parameters():
37 | param.requires_grad = False
38 |
39 | model.eval()
40 |
41 | # A logger for this file
42 | log = logging.getLogger(__name__)
43 | def to_tensor(array):
44 | if torch.is_tensor(array):
45 | return array
46 | else:
47 | return torch.tensor(array)
48 |
49 | def DotDict(in_dict):
50 | if isinstance(in_dict, dotdict):
51 | return in_dict
52 | out_dict = copy(in_dict)
53 | for k,v in out_dict.items():
54 | if isinstance(v,dict):
55 | out_dict[k] = DotDict(v)
56 | return dotdict(out_dict)
57 |
58 |
59 | def dict_to_device(tensor_dict, device):
60 | return {k: v.to(device) for k, v in tensor_dict.items()}
61 |
62 | class dotdict(dict):
63 | """dot.notation access to dictionary attributes"""
64 | __getattr__ = dict.get
65 | __setattr__ = dict.__setitem__
66 | __delattr__ = dict.__delitem__
67 |
68 | def cast_dict_to_tensors(d, device="cpu"):
69 | if isinstance(d, dict):
70 | return {k: cast_dict_to_tensors(v, device) for k, v in d.items()}
71 | elif isinstance(d, np.ndarray):
72 | return torch.from_numpy(d).float().to(device)
73 | elif isinstance(d, torch.Tensor):
74 | return d.to(device)
75 | else:
76 | return d
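A small sketch of the dict helpers above (illustration only; the feature names and shapes are made up):

    import numpy as np

    batch = {"motion": {"body_pose": np.zeros((30, 63))}, "length": 30}
    batch = DotDict(cast_dict_to_tensors(batch))
    print(batch.motion.body_pose.shape)  # torch.Size([30, 63])
    print(batch.length)                  # 30 (non-array values pass through unchanged)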
--------------------------------------------------------------------------------
/src/utils/motionfix_utils.py:
--------------------------------------------------------------------------------
1 | test_subset_amt = ['001247','000038', '000091','000100',
2 | '000122','000133', '000177', '000227',
3 | '000230', '000246', '000289', '000362',
4 | '000522', '000544','000564', '000689',
5 | '000686','000847', '000917', '000947',
6 | '000931', '000899','001327', '001546',
7 | '000432']
8 |
--------------------------------------------------------------------------------
/tmr_evaluator/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lzhyu/SimMotionEdit/384135a8707bc247fba23005d57cca7ca2d751ce/tmr_evaluator/__init__.py
--------------------------------------------------------------------------------
/tmr_evaluator/fid.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy import linalg
3 |
4 |
5 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
6 | """Numpy implementation of the Frechet Distance.
7 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
8 | and X_2 ~ N(mu_2, C_2) is
9 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
10 | Stable version by Dougal J. Sutherland.
11 | Params:
12 | -- mu1 : Numpy array containing the activations of a layer of the
13 | inception net (like returned by the function 'get_predictions')
14 | for generated samples.
15 |     -- mu2   : The sample mean over activations, precalculated on a
16 |                representative dataset.
17 | -- sigma1: The covariance matrix over activations for generated samples.
18 |     -- sigma2: The covariance matrix over activations, precalculated on a
19 |                representative dataset.
20 | Returns:
21 | -- : The Frechet Distance.
22 | """
23 |
24 | mu1 = np.atleast_1d(mu1)
25 | mu2 = np.atleast_1d(mu2)
26 |
27 | sigma1 = np.atleast_2d(sigma1)
28 | sigma2 = np.atleast_2d(sigma2)
29 |
30 | assert mu1.shape == mu2.shape, \
31 | 'Training and test mean vectors have different lengths'
32 | assert sigma1.shape == sigma2.shape, \
33 | 'Training and test covariances have different dimensions'
34 |
35 | diff = mu1 - mu2
36 |
37 | # Product might be almost singular
38 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
39 | if not np.isfinite(covmean).all():
40 | msg = ('fid calculation produces singular product; '
41 | 'adding %s to diagonal of cov estimates') % eps
42 | print(msg)
43 | offset = np.eye(sigma1.shape[0]) * eps
44 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
45 |
46 | # Numerical error might give slight imaginary component
47 | if np.iscomplexobj(covmean):
48 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
49 | m = np.max(np.abs(covmean.imag))
50 | raise ValueError('Imaginary component {}'.format(m))
51 | covmean = covmean.real
52 |
53 | tr_covmean = np.trace(covmean)
54 |
55 | return (diff.dot(diff) + np.trace(sigma1) +
56 | np.trace(sigma2) - 2 * tr_covmean)
57 |
58 | def calculate_activation_statistics(activations):
59 | """
60 | Params:
61 | -- activation: num_samples x dim_feat
62 | Returns:
63 | -- mu: dim_feat
64 | -- sigma: dim_feat x dim_feat
65 | """
66 | mu = np.mean(activations, axis=0)
67 | cov = np.cov(activations, rowvar=False)
68 | return mu, cov
69 |
70 | def evaluate_fid(gt_motion_embeddings, gen_motion_embeddings):
71 | gt_mu, gt_cov = calculate_activation_statistics(gt_motion_embeddings)
72 | mu, cov = calculate_activation_statistics(gen_motion_embeddings)
73 | fid = calculate_frechet_distance(gt_mu, gt_cov, mu, cov)
74 | print(f'FID: {fid:.4f}')
75 |     return fid
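A usage sketch for the FID helpers above (illustration only; the random features stand in for motion embeddings and only show the expected (num_samples, dim) layout):

    import numpy as np

    gt_feats = np.random.randn(1000, 256)
    gen_feats = np.random.randn(1000, 256) + 0.1
    evaluate_fid(gt_feats, gen_feats)  # prints the FID between the two Gaussian fits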
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lzhyu/SimMotionEdit/384135a8707bc247fba23005d57cca7ca2d751ce/utils/__init__.py
--------------------------------------------------------------------------------
/utils/masking.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class BaseMask(object):
5 | @property
6 | def bool_matrix(self):
7 | """Return a bool (uint8) matrix with 1s to all places that should be
8 | kept."""
9 | raise NotImplementedError()
10 |
11 | @property
12 | def float_matrix(self):
13 | """Return the bool matrix as a float to be used as a multiplicative
14 | mask for non softmax attentions."""
15 | if not hasattr(self, "_float_matrix"):
16 | with torch.no_grad():
17 | self._float_matrix = self.bool_matrix.float()
18 | return self._float_matrix
19 |
20 | @property
21 | def lengths(self):
22 | """If the matrix is of the following form
23 |
24 | 1 1 1 0 0 0 0
25 | 1 0 0 0 0 0 0
26 | 1 1 0 0 0 0 0
27 |
28 | then return it as a vector of integers
29 |
30 | 3 1 2.
31 | """
32 | if not hasattr(self, "_lengths"):
33 | with torch.no_grad():
34 | lengths = self.bool_matrix.long().sum(dim=-1)
35 | # make sure that the mask starts with 1s and continues with 0s
36 | # this should be changed to something more efficient, however,
37 | # I chose simplicity over efficiency since the LengthMask class
38 | # will be used anyway (and the result is cached)
39 | m = self.bool_matrix.view(-1, self.shape[-1])
40 | for i, l in enumerate(lengths.view(-1)):
41 | if not torch.all(m[i, :l]):
42 | raise ValueError("The mask is not a length mask")
43 | self._lengths = lengths
44 | return self._lengths
45 |
46 | @property
47 | def shape(self):
48 | """Return the shape of the boolean mask."""
49 | return self.bool_matrix.shape
50 |
51 | @property
52 | def additive_matrix(self):
53 | """Return a float matrix to be added to an attention matrix before
54 | softmax."""
55 | if not hasattr(self, "_additive_matrix"):
56 | with torch.no_grad():
57 | self._additive_matrix = torch.log(self.bool_matrix.float())
58 | return self._additive_matrix
59 |
60 | @property
61 | def additive_matrix_finite(self):
62 | """Same as additive_matrix but with -1e24 instead of infinity."""
63 | if not hasattr(self, "_additive_matrix_finite"):
64 | with torch.no_grad():
65 | self._additive_matrix_finite = (
66 | (~self.bool_matrix).float() * (-1e24)
67 | )
68 | return self._additive_matrix_finite
69 |
70 | @property
71 | def all_ones(self):
72 | """Return true if the mask is all ones."""
73 | if not hasattr(self, "_all_ones"):
74 | with torch.no_grad():
75 | self._all_ones = torch.all(self.bool_matrix)
76 | return self._all_ones
77 |
78 | @property
79 | def lower_triangular(self):
80 | """Return true if the attention is a triangular causal mask."""
81 | if not hasattr(self, "_lower_triangular"):
82 | self._lower_triangular = False
83 | with torch.no_grad():
84 | try:
85 | lengths = self.lengths
86 | if len(lengths.shape) == 1:
87 | target = torch.arange(
88 | 1,
89 | len(lengths)+1,
90 | device=lengths.device
91 | )
92 | self._lower_triangular = torch.all(lengths == target)
93 | except ValueError:
94 | pass
95 | return self._lower_triangular
96 |
97 | class LengthMask(BaseMask):
98 | """Provide a BaseMask interface for lengths. Mostly to be used with
99 | sequences of different lengths.
100 |
101 | Arguments
102 | ---------
103 | lengths: The lengths as a PyTorch long tensor
104 | max_len: The maximum length for the mask (defaults to lengths.max())
105 | device: The device to be used for creating the masks (defaults to
106 | lengths.device)
107 | """
108 | def __init__(self, lengths, max_len=None, device=None):
109 | self._device = device or lengths.device
110 | with torch.no_grad():
111 | self._lengths = lengths.clone().to(self._device)
112 | self._max_len = max_len or self._lengths.max()
113 |
114 | self._bool_matrix = None
115 | self._all_ones = torch.all(self._lengths == self._max_len).item()
116 |
117 | @property
118 | def bool_matrix(self):
119 | if self._bool_matrix is None:
120 | with torch.no_grad():
121 | indices = torch.arange(self._max_len, device=self._device)
122 | self._bool_matrix = (
123 | indices.view(1, -1) < self._lengths.view(-1, 1)
124 | )
125 | return self._bool_matrix
126 |
127 |
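A usage sketch for LengthMask (illustration only): build a padding mask from per-sequence lengths.

    import torch

    mask = LengthMask(torch.tensor([3, 5, 1]), max_len=5)
    print(mask.bool_matrix)  # rows: [1 1 1 0 0], [1 1 1 1 1], [1 0 0 0 0]
    print(mask.lengths)      # tensor([3, 5, 1])
    print(mask.all_ones)     # False, since not every sequence reaches max_len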
--------------------------------------------------------------------------------