├── LICENSE ├── LICENSE-MF ├── README.md ├── assets ├── editing.gif └── simcurve.gif ├── compute_metrics.py ├── configs ├── callback │ ├── base.yaml │ ├── best_checkpoint.yaml │ ├── last_checkpoint.yaml │ ├── latest_checkpoint.yaml │ ├── lr_logging.yaml │ ├── progress.yaml │ └── render.yaml ├── compute_metrics.yaml ├── data │ └── motionfix.yaml ├── evaluator │ └── edit_evaluator.yaml ├── hydra │ ├── hydra_logging │ │ ├── console.yaml │ │ ├── custom.yaml │ │ └── rich.yaml │ └── job_logging │ │ ├── console.yaml │ │ ├── custom.yaml │ │ └── rich.yaml ├── logger │ ├── none.yaml │ ├── tensorboard.yaml │ └── wandb.yaml ├── machine │ └── server.yaml ├── model │ ├── basic_clip.yaml │ ├── basic_clip_cls.yaml │ ├── basic_clip_cls_arch.yaml │ ├── denoiser │ │ ├── denoiser.yaml │ │ ├── ditdenoiser_cls.yaml │ │ └── ditdenoiser_cls_arch.yaml │ ├── infer_scheduler │ │ ├── ddim.yaml │ │ └── ddpm.yaml │ ├── losses │ │ ├── basic.yaml │ │ └── function │ │ │ ├── kl.yaml │ │ │ ├── klmulti.yaml │ │ │ ├── mse.yaml │ │ │ ├── recons.yaml │ │ │ ├── recons_bp.yaml │ │ │ └── smoothL1.yaml │ ├── motion_condition_encoder │ │ └── actor.yaml │ ├── optim │ │ └── adamw.yaml │ ├── text_encoder │ │ ├── clipenc.yaml │ │ ├── clipenc_sim.yaml │ │ ├── distilbert_enc.yaml │ │ └── t5_enc.yaml │ └── train_scheduler │ │ ├── ddim.yaml │ │ └── ddpm.yaml ├── motionfix_eval.yaml ├── path.yaml ├── sampler │ ├── all_conseq.yaml │ ├── fix_conseq.yaml │ ├── upper_bound_variable_conseq.yaml │ └── variable_conseq.yaml ├── stats.yaml ├── train.yaml ├── train_cls.yaml ├── train_cls_arch.yaml └── trainer │ ├── base-longer.yaml │ └── base.yaml ├── deps ├── gpt │ ├── edit │ │ ├── gpt-labels_full.json │ │ └── gpt-labels_full_v2.json │ └── gpt3-labels-list.json ├── inference │ ├── labels.json │ ├── labels_train_spatial.json │ ├── labels_val_spatial.json │ ├── qual.txt │ ├── smpl_part_seg.json │ ├── text-annot_val_spatial-pairs.json │ ├── texts_pairs_train.json │ └── texts_pairs_val.json ├── smplh │ └── smpl.faces └── stats │ └── statistics_motionfix.npy ├── gpt_parts ├── joint_utils.py └── prompts.py ├── motionfix_evaluate.py ├── requirements.txt ├── scripts ├── download_data.sh └── install.sh ├── src ├── __init__.py ├── callback │ ├── __init__.py │ ├── progress.py │ └── render.py ├── data │ ├── __init__.py │ ├── base.py │ ├── comp.py │ ├── features.py │ ├── motionfix.py │ ├── sampling │ │ ├── __init__.py │ │ ├── base.py │ │ ├── custom_batch_sampler.py │ │ ├── framerate.py │ │ └── frames.py │ └── tools │ │ ├── __init__.py │ │ ├── amass_utils.py │ │ ├── collate.py │ │ ├── contacts.py │ │ ├── extract_pairs.py │ │ ├── rotation_transformation.py │ │ ├── smpl.py │ │ ├── spatiotempo.py │ │ ├── tensors.py │ │ └── utils.py ├── diffusion │ ├── __init__.py │ ├── diffusion_utils.py │ ├── gaussian_diffusion.py │ ├── respace.py │ └── timestep_sampler.py ├── evaluator │ └── evaluate_edits.py ├── info │ └── joints.py ├── launch │ ├── blender.py │ └── prepare.py ├── logger │ ├── __init__.py │ ├── tools.py │ └── wandb_log.py ├── model │ ├── DiT_denoiser.py │ ├── DiT_denoiser_cls.py │ ├── DiT_denoiser_cls_arch.py │ ├── DiT_models.py │ ├── __init__.py │ ├── base.py │ ├── base_diffusion.py │ ├── dummy.py │ ├── losses │ │ ├── __init__.py │ │ ├── compute.py │ │ ├── compute_mld.py │ │ ├── kl.py │ │ ├── recons.py │ │ ├── recons_bp.py │ │ └── utils.py │ ├── metrics │ │ ├── __init__.py │ │ └── compute.py │ ├── motiondecoder │ │ ├── __init__.py │ │ └── actor.py │ ├── motionencoder │ │ ├── __init__.py │ │ └── actor.py │ ├── readme.txt │ ├── textencoder │ │ ├── __init__.py │ │ 
├── clip_encoder.py │ │ ├── distilbert.py │ │ ├── distilbert_encoder.py │ │ ├── t5_encoder.py │ │ └── text_space.py │ ├── tmed_denoiser.py │ ├── tmr_utils │ │ ├── __init__.py │ │ ├── actor.py │ │ ├── losses.py │ │ ├── metrics.py │ │ ├── temos.py │ │ ├── text_encoder.py │ │ ├── tmr.py │ │ └── utils.py │ ├── trans_enc.py │ ├── transformer_encoder │ │ ├── encoder.py │ │ ├── encoder_layer.py │ │ ├── feed_forward.py │ │ └── note.md │ └── utils │ │ ├── __init__.py │ │ ├── all_positional_encodings.py │ │ ├── body_parts.py │ │ ├── lr_scheduler.py │ │ ├── positional_encoding.py │ │ ├── smpl_fast.py │ │ ├── timestep_embed.py │ │ ├── tools.py │ │ ├── transf_utils.py │ │ └── vae.py ├── render │ ├── __init__.py │ ├── anim.py │ ├── mesh_viz.py │ └── video.py ├── repr_utils │ └── tmr_utils.py ├── tmr │ ├── __init__.py │ ├── actor.py │ ├── data │ │ ├── __init__.py │ │ └── motionfix_loader.py │ ├── load_model.py │ ├── losses.py │ ├── metrics.py │ ├── temos.py │ ├── text_encoder.py │ └── tmr.py ├── tools │ ├── __init__.py │ ├── easyconvert.py │ ├── frank.py │ ├── geometry.py │ ├── interpolation.py │ ├── logging.py │ ├── runid.py │ └── transforms3d.py └── utils │ ├── __init__.py │ ├── art_utils.py │ ├── cherrypick.py │ ├── eval_utils.py │ ├── file_io.py │ ├── genutils.py │ ├── inference.py │ ├── mesh_utils.py │ ├── motionfix_utils.py │ ├── nlp_consts.py │ ├── smpl_body_utils.py │ └── text_constants.py ├── tmr_evaluator ├── __init__.py ├── fid.py ├── motion2motion_retr.py ├── retr_ori.py └── text2motion_retr.py ├── train.py └── utils ├── __init__.py ├── masking.py ├── misc.py └── transformations.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Zhengyuan Li 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /LICENSE-MF: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Nikos Athanasiou 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

# SimMotionEdit: Text-Based Human Motion Editing with Motion Similarity Prediction

Zhengyuan Li, Kai Cheng, Anindita Ghosh, Uttaran Bhattacharya, Liang-Yan Gui, Aniket Bera
Purdue University, DFKI, MPI-INF, Adobe Inc., University of Illinois Urbana-Champaign
CVPR 2025

29 | 30 | | Motion Similarity | Text-based Motion Editing | 31 | |-------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------| 32 | | | | 33 | 34 |
35 | 36 | ## Environment Setup 37 | Please follow [motionfix](https://github.com/atnikos/motionfix) to download the dataset and set up the environment. 38 | 39 | We provide pretrained model and pre-processed similarity in [this link](https://drive.google.com/drive/folders/1LjiKVjDHqOEnykZMP3ZiTJEz_EY-TqC3?usp=sharing). 40 | 41 | Constant setup: set `preproc.sim_file` to the path of preprocessed similarity file in `./configs/data/motionfix.yaml`. 42 | 43 | ## Evaluation 44 | 45 | #### Step 1: Extract the samples 46 | ```bash 47 | python motionfix_evaluate.py folder=/path/to/exp/ guidance_scale_text_n_motion=2.0 guidance_scale_motion=2.0 data=motionfix 48 | ``` 49 | 50 | #### Step 2: Compute the metrics 51 | ```bash 52 | python compute_metrics.py folder=/path/to/exp/samples/npys 53 | ``` 54 | 55 | ## Training 56 | ```bash 57 | python -u train.py --config-name="train_cls_arch" experiment=cls_arch run_id=no_text 58 | ``` 59 | 60 | ## Acknowledgements 61 | Our code is based on: [motionfix](https://github.com/atnikos/motionfix). 62 | 63 | ## License 64 | This code is distributed under an MIT LICENSE. We also include the LICENSE of motionfix in our repo. Other third-party datasets and software are subject to their respective licenses. 65 | 66 | ## Citation 67 | You can cite this paper using: 68 | 69 | ```bibtex 70 | @article{li2025simmotionedittextbasedhumanmotion, 71 | title={SimMotionEdit: Text-Based Human Motion Editing with Motion Similarity Prediction}, 72 | author={Zhengyuan Li and Kai Cheng and Anindita Ghosh and Uttaran Bhattacharya and Liangyan Gui and Aniket Bera}, 73 | year={2025}, 74 | eprint={2503.18211}, 75 | archivePrefix={arXiv}, 76 | primaryClass={cs.CV} 77 | } 78 | ``` 79 | -------------------------------------------------------------------------------- /assets/editing.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lzhyu/SimMotionEdit/384135a8707bc247fba23005d57cca7ca2d751ce/assets/editing.gif -------------------------------------------------------------------------------- /assets/simcurve.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lzhyu/SimMotionEdit/384135a8707bc247fba23005d57cca7ca2d751ce/assets/simcurve.gif -------------------------------------------------------------------------------- /compute_metrics.py: -------------------------------------------------------------------------------- 1 | from omegaconf import DictConfig 2 | import logging 3 | import hydra 4 | import torch 5 | from tqdm import tqdm 6 | from pathlib import Path 7 | import numpy as np 8 | 9 | 10 | def collect_gen_samples(motion_gen_path, normalizer, device): 11 | cur_samples = {} 12 | cur_samples_raw = {} 13 | # it becomes from 14 | # translation | root_orient | rots --> trans | rots | root_orient 15 | print("Collecting Generated Samples") 16 | from src.data.features import _get_body_transl_delta_pelv_infer 17 | import glob 18 | 19 | sample_files = glob.glob(f'{motion_gen_path}/*.npy') 20 | for fname in tqdm(sample_files): 21 | keyid = str(Path(fname).name).replace('.npy', '') 22 | gen_motion_b = np.load(fname, 23 | allow_pickle=True).item()['pose'] 24 | gen_motion_b = torch.from_numpy(gen_motion_b) 25 | trans = gen_motion_b[..., :3] 26 | global_orient_6d = gen_motion_b[..., 3:9] 27 | body_pose_6d = gen_motion_b[..., 9:] 28 | trans_delta = _get_body_transl_delta_pelv_infer(global_orient_6d, 29 | trans) 30 | gen_motion_b_fixed = torch.cat([trans_delta, 
body_pose_6d, 31 | global_orient_6d], dim=-1) 32 | gen_motion_b_fixed = normalizer(gen_motion_b_fixed) 33 | cur_samples[keyid] = gen_motion_b_fixed.to(device) 34 | cur_samples_raw[keyid] = torch.cat([trans, global_orient_6d, 35 | body_pose_6d], dim=-1).to(device) 36 | return cur_samples, cur_samples_raw 37 | 38 | 39 | @hydra.main(config_path="configs", version_base="1.2", config_name="compute_metrics") 40 | def _compute_metrics(cfg: DictConfig): 41 | return compute_metrics(cfg) 42 | 43 | def compute_metrics(newcfg: DictConfig) -> None: 44 | from tmr_evaluator.motion2motion_retr import retrieval 45 | from pathlib import Path 46 | samples_folder = newcfg.folder 47 | metrs_batches, metrs_full = retrieval(samples_folder) 48 | print("\n===== Metrics for Retrieval on Batches of 32 =====") 49 | print(metrs_batches) 50 | print("\n===== Metrics for Retrieval on the full test set =====") 51 | print( metrs_full) 52 | 53 | if __name__ == '__main__': 54 | _compute_metrics() 55 | 56 | -------------------------------------------------------------------------------- /configs/callback/base.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - /callback/last_checkpoint@last_ckpt 3 | - /callback/latest_checkpoint@latest_ckpt 4 | - /callback/progress@progress 5 | - /callback/render@render 6 | - /callback/lr_logging@lr_logging -------------------------------------------------------------------------------- /configs/callback/best_checkpoint.yaml: -------------------------------------------------------------------------------- 1 | _target_: pytorch_lightning.callbacks.ModelCheckpoint 2 | dirpath: ${path.working_dir}/checkpoints 3 | filename: best 4 | monitor: "Metrics/APE_root" 5 | mode: min 6 | every_n_epochs: 1 7 | save_top_k: 1 8 | -------------------------------------------------------------------------------- /configs/callback/last_checkpoint.yaml: -------------------------------------------------------------------------------- 1 | _target_: pytorch_lightning.callbacks.ModelCheckpoint 2 | dirpath: ${path.working_dir}/checkpoints 3 | 4 | filename: latest-{epoch} 5 | every_n_epochs: 1 6 | save_top_k: 1 7 | save_last: true -------------------------------------------------------------------------------- /configs/callback/latest_checkpoint.yaml: -------------------------------------------------------------------------------- 1 | _target_: pytorch_lightning.callbacks.ModelCheckpoint 2 | dirpath: ${path.working_dir}/checkpoints 3 | filename: latest-{epoch} 4 | monitor: step 5 | mode: max 6 | every_n_epochs: 100 7 | save_top_k: -1 8 | save_last: false 9 | -------------------------------------------------------------------------------- /configs/callback/lr_logging.yaml: -------------------------------------------------------------------------------- 1 | _target_: pytorch_lightning.callbacks.LearningRateMonitor 2 | # proper way to log LR but does not work with the custom logger thing ... 
3 | logging_interval: epoch 4 | -------------------------------------------------------------------------------- /configs/callback/progress.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.callback.ProgressLogger 2 | -------------------------------------------------------------------------------- /configs/callback/render.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.callback.RenderCallback 2 | 3 | every_n_epochs: 50 4 | num_workers: ${machine.num_workers} 5 | save_last: true 6 | nvids_to_save: 3 7 | bm_path: ${path.data} 8 | modelname: ${model.modelname} -------------------------------------------------------------------------------- /configs/compute_metrics.yaml: -------------------------------------------------------------------------------- 1 | hydra: 2 | run: 3 | dir: . 4 | job: 5 | chdir: true 6 | output_subdir: null 7 | 8 | folder: ??? 9 | -------------------------------------------------------------------------------- /configs/data/motionfix.yaml: -------------------------------------------------------------------------------- 1 | dataname: motionfix 2 | _target_: src.data.motionfix.MotionFixDataModule 3 | debug: ${debug} 4 | 5 | datapath: ${path.data}/motionfix-dataset/motionfix.pth.tar # amass_bodilex.pth.tar 6 | 7 | # Amass 8 | smplh_path: ${path.data}/body_models 9 | smplh_path_dbg: ${path.minidata}/body_models 10 | 11 | load_with_rot: true 12 | 13 | load_splits: 14 | - "train" 15 | - "val" 16 | - "test" 17 | 18 | proportion: 1.0 19 | text_augment: false 20 | 21 | # Machine 22 | batch_size: ${machine.batch_size} # it's tiny 23 | num_workers: ${machine.num_workers} 24 | rot_repr: '6d' 25 | preproc: 26 | stats_file: ${path.deps}/stats/statistics_${data.dataname}.npy # full path for statistics 27 | split_seed: 0 28 | calculate_minmax: True 29 | generate_joint_files: True 30 | use_cuda: True 31 | n_body_joints: 22 32 | norm_type: std # norm or std 33 | sim_file: PATH/TO/SIM 34 | 35 | # Motion 36 | framerate: 30 37 | sampler: ${sampler} 38 | 39 | load_feats: 40 | - "body_transl_delta_pelv" 41 | - "body_orient_xy" 42 | - "z_orient_delta" 43 | - "body_pose" 44 | - "body_joints_local_wo_z_rot" 45 | - "body_transl" 46 | - "body_orient" 47 | # Other 48 | progress_bar: true 49 | -------------------------------------------------------------------------------- /configs/evaluator/edit_evaluator.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.evaluator.evaluate_edits.MotionEditEvaluator 2 | metrics_to_eval: 3 | # - 'foot_skating' 4 | - 'loc_pres' 5 | - 'glo_pres' 6 | - 'lc_pre_gt' 7 | - 'gb_pre_gt' 8 | - 'loc_edit' 9 | - 'glob_edit' 10 | smplh_path: ${path.data}/body_models 11 | -------------------------------------------------------------------------------- /configs/hydra/hydra_logging/console.yaml: -------------------------------------------------------------------------------- 1 | version: 1 2 | 3 | filters: 4 | onlyimportant: 5 | (): sinc.tools.logging.LevelsFilter 6 | levels: 7 | - CRITICAL 8 | - ERROR 9 | - WARNING 10 | noimportant: 11 | (): sinc.tools.logging.LevelsFilter 12 | levels: 13 | - INFO 14 | - DEBUG 15 | - NOTSET 16 | 17 | formatters: 18 | simple: 19 | format: '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s' 20 | datefmt: '%d/%m/%y %H:%M:%S' 21 | 22 | colorlog: 23 | (): colorlog.ColoredFormatter 24 | format: '[%(cyan)s%(asctime)s%(reset)s][%(blue)s%(name)s%(reset)s][%(log_color)s%(levelname)s%(reset)s] 
25 | - %(message)s' 26 | datefmt: '%d/%m/%y %H:%M:%S' 27 | 28 | log_colors: 29 | DEBUG: purple 30 | INFO: green 31 | WARNING: yellow 32 | ERROR: red 33 | CRITICAL: red 34 | 35 | handlers: 36 | console: 37 | class: logging.StreamHandler 38 | formatter: colorlog 39 | stream: ext://sys.stdout 40 | 41 | root: 42 | level: ${logger_level} 43 | handlers: 44 | - console 45 | 46 | disable_existing_loggers: false 47 | -------------------------------------------------------------------------------- /configs/hydra/hydra_logging/custom.yaml: -------------------------------------------------------------------------------- 1 | version: 1 2 | 3 | formatters: 4 | colorlog: 5 | (): colorlog.ColoredFormatter 6 | format: '[%(cyan)s%(asctime)s%(reset)s][%(purple)sHYDRA%(reset)s] %(message)s' 7 | datefmt: '%d/%m/%y %H:%M:%S' 8 | 9 | handlers: 10 | console: 11 | class: logging.StreamHandler 12 | formatter: colorlog 13 | stream: ext://sys.stdout 14 | 15 | root: 16 | level: INFO 17 | handlers: 18 | - console 19 | 20 | disable_existing_loggers: false 21 | -------------------------------------------------------------------------------- /configs/hydra/hydra_logging/rich.yaml: -------------------------------------------------------------------------------- 1 | version: 1 2 | 3 | formatters: 4 | colorlog: 5 | (): colorlog.ColoredFormatter 6 | format: '[%(cyan)s%(asctime)s%(reset)s][%(purple)sHYDRA%(reset)s] %(message)s' 7 | datefmt: '%d/%m/%y %H:%M:%S' 8 | 9 | handlers: 10 | console: 11 | class: rich.logging.RichHandler # logging.StreamHandler 12 | formatter: colorlog 13 | 14 | root: 15 | level: INFO 16 | handlers: 17 | - console 18 | 19 | disable_existing_loggers: false -------------------------------------------------------------------------------- /configs/hydra/job_logging/console.yaml: -------------------------------------------------------------------------------- 1 | version: 1 2 | 3 | filters: 4 | onlyimportant: 5 | (): sinc.tools.logging.LevelsFilter 6 | levels: 7 | - CRITICAL 8 | - ERROR 9 | - WARNING 10 | noimportant: 11 | (): sinc.tools.logging.LevelsFilter 12 | levels: 13 | - INFO 14 | - DEBUG 15 | - NOTSET 16 | 17 | formatters: 18 | simple: 19 | format: '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s' 20 | datefmt: '%d/%m/%y %H:%M:%S' 21 | 22 | colorlog: 23 | (): colorlog.ColoredFormatter 24 | format: '[%(cyan)s%(asctime)s%(reset)s][%(blue)s%(name)s%(reset)s][%(log_color)s%(levelname)s%(reset)s] 25 | - %(message)s' 26 | datefmt: '%d/%m/%y %H:%M:%S' 27 | 28 | log_colors: 29 | DEBUG: purple 30 | INFO: green 31 | WARNING: yellow 32 | ERROR: red 33 | CRITICAL: red 34 | 35 | handlers: 36 | console: 37 | class: logging.StreamHandler 38 | formatter: colorlog 39 | stream: ext://sys.stdout 40 | 41 | root: 42 | level: ${logger_level} 43 | handlers: 44 | - console 45 | 46 | disable_existing_loggers: false 47 | -------------------------------------------------------------------------------- /configs/hydra/job_logging/custom.yaml: -------------------------------------------------------------------------------- 1 | version: 1 2 | 3 | filters: 4 | onlyimportant: 5 | (): sinc.tools.logging.LevelsFilter 6 | levels: 7 | - CRITICAL 8 | - ERROR 9 | - WARNING 10 | noimportant: 11 | (): sinc.tools.logging.LevelsFilter 12 | levels: 13 | - INFO 14 | - DEBUG 15 | - NOTSET 16 | 17 | formatters: 18 | simple: 19 | format: '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s' 20 | datefmt: '%d/%m/%y %H:%M:%S' 21 | 22 | colorlog: 23 | (): colorlog.ColoredFormatter 24 | format: 
'[%(cyan)s%(asctime)s%(reset)s][%(blue)s%(name)s%(reset)s][%(log_color)s%(levelname)s%(reset)s] 25 | - %(message)s' 26 | datefmt: '%d/%m/%y %H:%M:%S' 27 | 28 | log_colors: 29 | DEBUG: purple 30 | INFO: green 31 | WARNING: yellow 32 | ERROR: red 33 | CRITICAL: red 34 | 35 | handlers: 36 | console: 37 | class: logging.StreamHandler 38 | formatter: colorlog 39 | stream: ext://sys.stdout 40 | 41 | file_out: 42 | class: logging.FileHandler 43 | formatter: simple 44 | filename: logs.out 45 | filters: 46 | - noimportant 47 | 48 | file_err: 49 | class: logging.FileHandler 50 | formatter: simple 51 | filename: logs.err 52 | filters: 53 | - onlyimportant 54 | 55 | root: 56 | level: ${logger_level} 57 | handlers: 58 | - console 59 | - file_out 60 | - file_err 61 | 62 | disable_existing_loggers: false 63 | -------------------------------------------------------------------------------- /configs/hydra/job_logging/rich.yaml: -------------------------------------------------------------------------------- 1 | version: 1 2 | 3 | filters: 4 | onlyimportant: 5 | (): src.tools.logging.LevelsFilter 6 | levels: 7 | - CRITICAL 8 | - ERROR 9 | - WARNING 10 | noimportant: 11 | (): src.tools.logging.LevelsFilter 12 | levels: 13 | - INFO 14 | - DEBUG 15 | - NOTSET 16 | 17 | formatters: 18 | verysimple: 19 | format: '%(message)s' 20 | simple: 21 | format: '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s' 22 | datefmt: '%d/%m/%y %H:%M:%S' 23 | 24 | colorlog: 25 | (): colorlog.ColoredFormatter 26 | format: '[%(cyan)s%(asctime)s%(reset)s][%(blue)s%(name)s%(reset)s][%(log_color)s%(levelname)s%(reset)s] 27 | - %(message)s' 28 | datefmt: '%d/%m/%y %H:%M:%S' 29 | 30 | log_colors: 31 | DEBUG: purple 32 | INFO: green 33 | WARNING: yellow 34 | ERROR: red 35 | CRITICAL: red 36 | 37 | handlers: 38 | console: 39 | class: rich.logging.RichHandler # logging.StreamHandler 40 | formatter: verysimple # colorlog 41 | rich_tracebacks: true 42 | 43 | file_out: 44 | class: logging.FileHandler 45 | formatter: simple 46 | filename: logs.out 47 | filters: 48 | - noimportant 49 | 50 | file_err: 51 | class: logging.FileHandler 52 | formatter: simple 53 | filename: logs.err 54 | filters: 55 | - onlyimportant 56 | 57 | root: 58 | level: ${logger_level} 59 | handlers: 60 | - console 61 | - file_out 62 | - file_err 63 | 64 | disable_existing_loggers: false -------------------------------------------------------------------------------- /configs/logger/none.yaml: -------------------------------------------------------------------------------- 1 | logger_name: null 2 | version: ${run_id} 3 | 4 | project: null -------------------------------------------------------------------------------- /configs/logger/tensorboard.yaml: -------------------------------------------------------------------------------- 1 | logger_name: tensorboard 2 | # version: ${run_id} 3 | 4 | save_dir: tensorboard 5 | version: "" 6 | log_graph: false 7 | default_hp_metric: false 8 | project: null -------------------------------------------------------------------------------- /configs/logger/wandb.yaml: -------------------------------------------------------------------------------- 1 | logger_name: wandb 2 | name: ${experiment}-${run_id} 3 | 4 | save_dir: wandb 5 | project: ${project} 6 | group: null 7 | tags: null 8 | notes: null 9 | 10 | offline: false 11 | resume: false 12 | save_code: false 13 | log_model: false 14 | 15 | # wandb.watch(model, log=cfg.log, log_freq=cfg.log_freq) 16 | -------------------------------------------------------------------------------- 
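The logger configs above (`none.yaml`, `tensorboard.yaml`, `wandb.yaml`) carry plain fields rather than a Hydra `_target_`, so the training code has to map `logger_name` to a concrete PyTorch Lightning logger itself. A minimal sketch of that mapping, assuming the field names from these YAMLs; the repo's own wiring in `src/logger` may differ:

```python
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger


def build_logger(cfg):
    """Hypothetical helper: turn a configs/logger/*.yaml node into a PL logger."""
    if cfg.logger_name == "wandb":
        return WandbLogger(name=cfg.name, save_dir=cfg.save_dir, project=cfg.project,
                           offline=cfg.offline, log_model=cfg.log_model,
                           # fields WandbLogger does not consume directly are
                           # forwarded to wandb.init as keyword arguments
                           group=cfg.group, tags=cfg.tags, notes=cfg.notes)
    if cfg.logger_name == "tensorboard":
        return TensorBoardLogger(save_dir=cfg.save_dir, version=cfg.version,
                                 log_graph=cfg.log_graph,
                                 default_hp_metric=cfg.default_hp_metric)
    return None  # logger=none -> train without a logger
```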
/configs/machine/server.yaml: -------------------------------------------------------------------------------- 1 | name: server 2 | 3 | # specific attributes to this machine 4 | batch_size: 128 5 | smpl_batch_size: 64 6 | num_workers: 16 7 | 8 | # Specific attributes for training 9 | -------------------------------------------------------------------------------- /configs/model/basic_clip.yaml: -------------------------------------------------------------------------------- 1 | modelname: basic_clip 2 | _target_: src.model.base_diffusion.MD 3 | 4 | latent_dim: 768 5 | 6 | ff_size: 1024 7 | num_layers: 9 8 | num_head: 8 9 | droupout: 0.1 10 | activation: "gelu" 11 | render_vids_every_n_epochs: 100 12 | num_vids_to_render: 2 13 | lr_scheduler: null # cosine # null # reduceonplateau, steplr 14 | 15 | zero_len_source: false 16 | old_way: false 17 | # normalization 18 | statistics_path: ${statistics_path} 19 | norm_type: standardize # min_max standardize 20 | 21 | # diffusion related 22 | diff_params: 23 | num_inference_timesteps: 200 24 | num_train_timesteps: 300 25 | prob_uncondp: 0.05 # 0.1 0.25 26 | prob_drop_text: 0.05 # 0.1 0.25 27 | prob_drop_motion: 0.05 28 | guidance_scale_text: 2.5 # 29 | guidance_scale_motion: 2.0 30 | noise_schedule: 'squaredcos_cap_v2' # Optional: ['linear', 'scaled_linear', 'squaredcos_cap_v2'] 31 | predict_type: 'sample' # noise 32 | 33 | motion_condition: source 34 | source_encoder: null # trans_enc 35 | condition: text 36 | smpl_path: ${data.smplh_path} 37 | copy_target: false 38 | nfeats: 135 39 | dim_per_feat: [135] 40 | # data related 41 | 42 | input_feats: 43 | # - "body_transl_delta" 44 | - "body_transl_delta_pelv" 45 | # - "body_transl_delta_pelv_xy_wo_z" 46 | # - "body_transl_z" 47 | # - "body_transl" 48 | - "body_orient_xy" 49 | - "z_orient_delta" 50 | - "body_pose" 51 | - "body_joints_local_wo_z_rot" 52 | 53 | pad_inputs: false 54 | 55 | loss_func_pos: mse # l1 mse 56 | loss_func_feats: mse # l1 mse 57 | 58 | defaults: 59 | # diffusion stuff 60 | - _self_ 61 | - /path@path 62 | - train_scheduler: ddpm 63 | - infer_scheduler: ddpm 64 | - denoiser: denoiser_larger # ditdenoiser 65 | - text_encoder: clipenc_sim #clipenc #t5_enc # distilbert_enc # t5_enc 66 | - motion_condition_encoder: actor 67 | - losses: basic 68 | - optim: adamw 69 | - /model/losses/function/recons@func_recons 70 | - /model/losses/function/recons@func_latent 71 | - /model/losses/function/kl@func_kl 72 | -------------------------------------------------------------------------------- /configs/model/basic_clip_cls.yaml: -------------------------------------------------------------------------------- 1 | modelname: basic_clip_repr 2 | _target_: src.model.base_diffusion.MD 3 | 4 | latent_dim: 512 5 | 6 | ff_size: 1024 7 | num_layers: 9 8 | num_head: 8 9 | droupout: 0.1 10 | activation: "gelu" 11 | render_vids_every_n_epochs: 100 12 | num_vids_to_render: 2 13 | lr_scheduler: null # cosine # null # reduceonplateau, steplr 14 | 15 | zero_len_source: false 16 | old_way: false 17 | # normalization 18 | statistics_path: ${statistics_path} 19 | norm_type: standardize # min_max standardize 20 | use_cls: True 21 | use_repr: False 22 | use_v_weight: True 23 | source_align_coef: 0 24 | target_align_coef: 0 25 | cls_coef: 0.05 26 | n_cls: 3 27 | target_align_depth: 6 28 | 29 | # diffusion related 30 | diff_params: 31 | num_inference_timesteps: 200 32 | num_train_timesteps: 300 33 | prob_uncondp: 0.05 # 0.1 0.25 34 | prob_drop_text: 0.05 # 0.1 0.25 35 | prob_drop_motion: 0.05 36 | guidance_scale_text: 
2.5 # 37 | guidance_scale_motion: 2.0 38 | noise_schedule: 'squaredcos_cap_v2' # Optional: ['linear', 'scaled_linear', 'squaredcos_cap_v2'] 39 | predict_type: 'sample' # noise 40 | 41 | motion_condition: source 42 | source_encoder: null # trans_enc 43 | condition: text 44 | smpl_path: ${data.smplh_path} 45 | copy_target: false 46 | nfeats: 135 47 | dim_per_feat: [135] 48 | # data related 49 | 50 | input_feats: 51 | # - "body_transl_delta" 52 | - "body_transl_delta_pelv" 53 | # - "body_transl_delta_pelv_xy_wo_z" 54 | # - "body_transl_z" 55 | # - "body_transl" 56 | - "body_orient_xy" 57 | - "z_orient_delta" 58 | - "body_pose" 59 | - "body_joints_local_wo_z_rot" 60 | 61 | pad_inputs: false 62 | 63 | loss_func_pos: mse # l1 mse 64 | loss_func_feats: mse # l1 mse 65 | 66 | defaults: 67 | # diffusion stuff 68 | - _self_ 69 | - /path@path 70 | - train_scheduler: ddpm 71 | - infer_scheduler: ddpm 72 | - denoiser: ditdenoiser_cls # ditdenoiser 73 | - text_encoder: clipenc_sim #clipenc #t5_enc # distilbert_enc # t5_enc 74 | - motion_condition_encoder: actor 75 | - losses: basic 76 | - optim: adamw 77 | - /model/losses/function/recons@func_recons 78 | - /model/losses/function/recons@func_latent 79 | - /model/losses/function/kl@func_kl 80 | -------------------------------------------------------------------------------- /configs/model/basic_clip_cls_arch.yaml: -------------------------------------------------------------------------------- 1 | modelname: basic_clip_repr 2 | _target_: src.model.base_diffusion.MD 3 | 4 | latent_dim: 512 5 | 6 | ff_size: 1024 7 | num_layers: 9 8 | num_head: 8 9 | droupout: 0.1 10 | activation: "gelu" 11 | render_vids_every_n_epochs: 100 12 | num_vids_to_render: 2 13 | lr_scheduler: null # cosine # null # reduceonplateau, steplr 14 | 15 | zero_len_source: false 16 | old_way: false 17 | # normalization 18 | statistics_path: ${statistics_path} 19 | norm_type: standardize # min_max standardize 20 | use_cls: True 21 | use_repr: False 22 | use_v_weight: True 23 | source_align_coef: 0 24 | target_align_coef: 0 25 | cls_coef: 0.05 26 | n_cls: 3 27 | target_align_depth: 6 28 | 29 | # diffusion related 30 | diff_params: 31 | num_inference_timesteps: 200 32 | num_train_timesteps: 300 33 | prob_uncondp: 0.05 # 0.1 0.25 34 | prob_drop_text: 0.05 # 0.1 0.25 35 | prob_drop_motion: 0.05 36 | guidance_scale_text: 2.5 # 37 | guidance_scale_motion: 2.0 38 | noise_schedule: 'squaredcos_cap_v2' # Optional: ['linear', 'scaled_linear', 'squaredcos_cap_v2'] 39 | predict_type: 'sample' # noise 40 | 41 | motion_condition: source 42 | source_encoder: null # trans_enc 43 | condition: text 44 | smpl_path: ${data.smplh_path} 45 | copy_target: false 46 | nfeats: 135 47 | dim_per_feat: [135] 48 | # data related 49 | 50 | input_feats: 51 | # - "body_transl_delta" 52 | - "body_transl_delta_pelv" 53 | # - "body_transl_delta_pelv_xy_wo_z" 54 | # - "body_transl_z" 55 | # - "body_transl" 56 | - "body_orient_xy" 57 | - "z_orient_delta" 58 | - "body_pose" 59 | - "body_joints_local_wo_z_rot" 60 | 61 | pad_inputs: false 62 | 63 | loss_func_pos: mse # l1 mse 64 | loss_func_feats: mse # l1 mse 65 | 66 | defaults: 67 | # diffusion stuff 68 | - _self_ 69 | - /path@path 70 | - train_scheduler: ddpm 71 | - infer_scheduler: ddpm 72 | - denoiser: ditdenoiser_cls_arch # ditdenoiser 73 | - text_encoder: clipenc_sim #clipenc #t5_enc # distilbert_enc # t5_enc 74 | - motion_condition_encoder: actor 75 | - losses: basic 76 | - optim: adamw 77 | - /model/losses/function/recons@func_recons 78 | - 
/model/losses/function/recons@func_latent 79 | - /model/losses/function/kl@func_kl 80 | -------------------------------------------------------------------------------- /configs/model/denoiser/denoiser.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.model.tmed_denoiser.TMED_denoiser 2 | text_encoded_dim: 768 # or 512 patch-14-large or base 3 | ff_size: 1024 4 | num_layers: 8 5 | num_heads: 4 6 | dropout: 0.1 7 | activation: 'gelu' 8 | condition: ${model.condition} 9 | motion_condition: ${model.motion_condition} 10 | latent_dim: ${model.latent_dim} 11 | nfeats: ${model.nfeats} # TODO FIX THIS 12 | use_sep: true 13 | pred_delta_motion: false 14 | -------------------------------------------------------------------------------- /configs/model/denoiser/ditdenoiser_cls.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.model.DiT_denoiser_cls.DiT_Denoiser_CLS 2 | text_encoded_dim: 768 # or 512 patch-14-large or base 3 | ff_size: 1024 4 | num_layers: 8 5 | num_heads: 8 6 | dropout: 0.1 7 | activation: 'gelu' 8 | condition: ${model.condition} 9 | motion_condition: ${model.motion_condition} 10 | latent_dim: ${model.latent_dim} 11 | nfeats: ${model.nfeats} 12 | use_sep: true 13 | pred_delta_motion: false 14 | repr_dim: 256 15 | n_cls: 3 16 | encoder_layers: 4 # 4? -------------------------------------------------------------------------------- /configs/model/denoiser/ditdenoiser_cls_arch.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.model.DiT_denoiser_cls_arch.DiT_Denoiser_CLS_Arch 2 | text_encoded_dim: 768 # or 512 patch-14-large or base 3 | ff_size: 1024 4 | num_layers: 8 5 | num_heads: 8 6 | dropout: 0.1 7 | activation: 'gelu' 8 | condition: ${model.condition} 9 | motion_condition: ${model.motion_condition} 10 | latent_dim: ${model.latent_dim} 11 | nfeats: ${model.nfeats} 12 | use_sep: true 13 | pred_delta_motion: false 14 | repr_dim: 256 15 | n_cls: 3 16 | encoder_layers: 4 17 | ablation: no_text # 'raw_text', 'no_text', 'raw_source', 'raw_text_and_source' -------------------------------------------------------------------------------- /configs/model/infer_scheduler/ddim.yaml: -------------------------------------------------------------------------------- 1 | _target_: diffusers.DDIMScheduler 2 | num_train_timesteps: 1000 3 | beta_start: 0.00085 4 | beta_end: 0.012 5 | beta_schedule: 'scaled_linear' # Optional: ['linear', 'scaled_linear', 'squaredcos_cap_v2'] 6 | # variance_type: 'fixed_small' 7 | clip_sample: false # clip sample to -1~1 8 | prediction_type: sample 9 | # below are for ddim 10 | set_alpha_to_one: false 11 | steps_offset: 1 -------------------------------------------------------------------------------- /configs/model/infer_scheduler/ddpm.yaml: -------------------------------------------------------------------------------- 1 | _target_: diffusers.DDPMScheduler 2 | num_train_timesteps: 1000 3 | beta_start: 0.00085 4 | beta_end: 0.012 5 | beta_schedule: 'squaredcos_cap_v2' # Optional: ['linear', 'scaled_linear', 'squaredcos_cap_v2'] 6 | variance_type: 'fixed_small' 7 | clip_sample: false # clip sample to -1~1 8 | prediction_type: sample 9 | -------------------------------------------------------------------------------- /configs/model/losses/basic.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.model.losses.MLDLosses 2 | 3 | # Loss terms 4 | ## Reconstruction 
losses 5 | lmd_rfeats_recons: 1.0 6 | lmd_jfeats_recons: 1.0 7 | predict_epsilon: false 8 | 9 | modelname: ${model.modelname} 10 | 11 | ## Latent sinc. losses 12 | lmd_latent: 1e-5 13 | lmd_kl: 1e-5 14 | lmd_prior: 0.0 15 | lmd_recons: 1.0 16 | lmd_gen: 1.0 17 | fuse: 'concat' # 'add' null 18 | # Ablations 19 | loss_on_both: true 20 | # loss on joint position features 21 | loss_on_jfeats: false 22 | ablation_no_kl_combine: false 23 | ablation_no_kl_gaussian: false 24 | ablation_no_motionencoder: false 25 | 26 | # # Text => rfeats (rotation features) 27 | # recons_text2rfeats: ${.lmd_rfeats_recons} 28 | # recons_text2rfeats_func: ${model.func_recons} 29 | 30 | # # Text => jfeats (xyz features) 31 | # recons_text2jfeats: ${.lmd_jfeats_recons} 32 | # recons_text2jfeats_func: ${model.func_recons} 33 | 34 | # # rfeats => rfeats 35 | # recons_rfeats2rfeats: ${.lmd_rfeats_recons} 36 | # recons_rfeats2rfeats_func: ${model.func_recons} 37 | 38 | # # vts => vts 39 | # recons_vertex2vertex: ${.lmd_rfeats_recons} 40 | # recons_vertex2vertex_func: ${model.func_recons} 41 | 42 | # # jfeats => jfeats 43 | # recons_jfeats2jfeats: ${.lmd_jfeats_recons} 44 | # recons_jfeats2jfeats_func: ${model.func_recons} 45 | 46 | # # Latent sinc.losses 47 | # latent_manifold: ${.lmd_latent} 48 | # latent_manifold_func: ${model.func_latent} 49 | 50 | # # VAE losses 51 | # kl_text: ${.lmd_kl} 52 | # kl_text_func: ${model.func_kl} 53 | 54 | # kl_motion: ${.lmd_kl} 55 | # kl_motion_func: ${model.func_kl} 56 | 57 | # kl_text2motion: ${.lmd_kl} 58 | # kl_text2motion_func: ${model.func_kl} 59 | 60 | # kl_motion2text: ${.lmd_kl} 61 | # kl_motion2text_func: ${model.func_kl} 62 | 63 | -------------------------------------------------------------------------------- /configs/model/losses/function/kl.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.model.losses.KLLoss -------------------------------------------------------------------------------- /configs/model/losses/function/klmulti.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.model.losses.KLLossMulti -------------------------------------------------------------------------------- /configs/model/losses/function/mse.yaml: -------------------------------------------------------------------------------- 1 | _target_: torch.nn.MSELoss 2 | reduction: mean 3 | -------------------------------------------------------------------------------- /configs/model/losses/function/recons.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.model.losses.recons.Recons -------------------------------------------------------------------------------- /configs/model/losses/function/recons_bp.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.model.losses.recons_bp.ReconsBP 2 | -------------------------------------------------------------------------------- /configs/model/losses/function/smoothL1.yaml: -------------------------------------------------------------------------------- 1 | _target_: torch.nn.SmoothL1Loss 2 | reduction: mean -------------------------------------------------------------------------------- /configs/model/motion_condition_encoder/actor.yaml: -------------------------------------------------------------------------------- 1 | name: actor_encoder 2 | _target_: src.model.motionencoder.ActorAgnosticEncoder 3 | 4 | latent_dim: ${model.latent_dim} 5 | 6 | ff_size: 
${model.ff_size} 7 | num_layers: ${model.num_layers} 8 | num_head: ${model.num_head} 9 | droupout: ${model.droupout} 10 | activation: ${model.activation} 11 | nfeats: ${model.nfeats} # TODO FIX THIS -------------------------------------------------------------------------------- /configs/model/optim/adamw.yaml: -------------------------------------------------------------------------------- 1 | _target_: torch.optim.AdamW 2 | lr: 1e-4 3 | lr_final: 1e-6 4 | t_warmup: 150 5 | -------------------------------------------------------------------------------- /configs/model/text_encoder/clipenc.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.model.textencoder.clip_encoder.ClipTextEncoder 2 | finetune: false # if false, model weights are frozen 3 | # true --> 77x768 | false --> 1x768 4 | last_hidden_state: true # if true, the last hidden state is used as the text embedding 5 | # clip-vit-base-patch32 | clip-vit-large-patch14 6 | modelpath: ${path.deps}/clip-vit-large-patch14 7 | -------------------------------------------------------------------------------- /configs/model/text_encoder/clipenc_sim.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.model.textencoder.clip_encoder.ClipTextEncoder 2 | finetune: false # if false, model weights are frozen 3 | # true --> 77x768 | false --> 1x768 4 | last_hidden_state: false # if true, the last hidden state is used as the text embedding 5 | # clip-vit-base-patch32 | clip-vit-large-patch14 6 | modelpath: ${path.deps}/clip-vit-large-patch14 7 | -------------------------------------------------------------------------------- /configs/model/text_encoder/distilbert_enc.yaml: -------------------------------------------------------------------------------- 1 | # name: distilbert_linear_encoder 2 | _target_: src.model.textencoder.distilbert_encoder.DistilbertEncoderTransformer 3 | 4 | latent_dim: ${model.latent_dim} 5 | 6 | ff_size: ${model.ff_size} 7 | num_layers: ${model.num_layers} 8 | num_head: ${model.num_head} 9 | droupout: ${model.droupout} 10 | activation: ${model.activation} 11 | 12 | finetune: false 13 | modelpath: ${path.deps}/distilbert-base-uncased 14 | -------------------------------------------------------------------------------- /configs/model/text_encoder/t5_enc.yaml: -------------------------------------------------------------------------------- 1 | name: t5_text_encoder 2 | _target_: src.model.textencoder.t5_encoder.T5TextEncoder 3 | finetune: false # if false, model weights are frozen 4 | modelpath: ${path.deps}/flan-t5-base -------------------------------------------------------------------------------- /configs/model/train_scheduler/ddim.yaml: -------------------------------------------------------------------------------- 1 | _target_: diffusers.DDIMScheduler 2 | num_train_timesteps: 1000 3 | beta_start: 0.00085 4 | beta_end: 0.012 5 | beta_schedule: 'scaled_linear' # Optional: ['linear', 'scaled_linear', 'squaredcos_cap_v2'] 6 | # variance_type: 'fixed_small' 7 | clip_sample: false # clip sample to -1~1 8 | prediction_type: sample 9 | # below are for ddim 10 | set_alpha_to_one: false 11 | steps_offset: 1 12 | -------------------------------------------------------------------------------- /configs/model/train_scheduler/ddpm.yaml: -------------------------------------------------------------------------------- 1 | _target_: diffusers.DDPMScheduler 2 | num_train_timesteps: 1000 3 | beta_start: 0.00085 4 | beta_end: 0.012 5 | 
beta_schedule: 'squaredcos_cap_v2' # Optional: ['linear', 'scaled_linear', 'squaredcos_cap_v2'] 6 | variance_type: 'fixed_small' 7 | clip_sample: false # clip sample to -1~1 8 | prediction_type: sample 9 | -------------------------------------------------------------------------------- /configs/motionfix_eval.yaml: -------------------------------------------------------------------------------- 1 | hydra: 2 | run: 3 | dir: . 4 | job: 5 | chdir: true 6 | output_subdir: null 7 | 8 | folder: ??? 9 | mode: sample # denoise / sample / sthing else 10 | 11 | prob_way: '3way' # 2way 12 | 13 | savedir: null 14 | mean: false 15 | fact: 1 16 | number_of_samples: 1 17 | ckpt_name: 'last' 18 | last_ckpt_path: ${get_last_checkpoint:${folder},${ckpt_name}} 19 | logger_level: INFO 20 | save_pkl: false 21 | render_vids: true 22 | subset: null 23 | 24 | num_sampling_steps: 1000 25 | 26 | guidance_scale_text_n_motion: null 27 | guidance_scale_motion: null 28 | 29 | init_from: 'noise' # noise 30 | condition_mode: 'full_cond' # 'mot_cond' 'text_cond' 31 | inpaint: false 32 | linear_gd: false 33 | 34 | defaults: 35 | - _self_ 36 | - data: motionfix 37 | - /path@path 38 | - override hydra/job_logging: rich # custom 39 | - override hydra/hydra_logging: rich # custom 40 | 41 | split_to_load: 42 | - "test" 43 | - "val" 44 | -------------------------------------------------------------------------------- /configs/path.yaml: -------------------------------------------------------------------------------- 1 | # path to additional modules 2 | deps: ${code_path:./deps} 3 | data: ${code_path:./data} 4 | minidata: ${code_path:./minidata} 5 | code_dir: ${code_path:} 6 | working_dir: ${working_path:""} 7 | minilog: ${code_path:./miniexperiments} -------------------------------------------------------------------------------- /configs/sampler/all_conseq.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.data.sampling.FrameSampler 2 | 3 | # For fix size data 4 | request_frames: null 5 | sampling: conseq 6 | sampling_step: 1 7 | 8 | # For data of any size 9 | max_len: 800 10 | min_len: 15 11 | -------------------------------------------------------------------------------- /configs/sampler/fix_conseq.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.data.sampling.FrameSampler 2 | 3 | # For fix size data 4 | request_frames: 60 5 | threshold_reject: 0.75 6 | sampling: conseq 7 | sampling_step: 1 8 | 9 | # For data of any size 10 | max_len: 10000 11 | min_len: 10 -------------------------------------------------------------------------------- /configs/sampler/upper_bound_variable_conseq.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.data.sampling.FrameSampler 2 | 3 | # For fix size data 4 | request_frames: 5 | sampling: conseq 6 | sampling_step: 1 7 | threshold_reject: 1.0 8 | 9 | # For data of any size 10 | max_len: 500 11 | min_len: 15 -------------------------------------------------------------------------------- /configs/sampler/variable_conseq.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.data.sampling.FrameSampler 2 | 3 | # For fix size data 4 | request_frames: 300 5 | sampling: conseq 6 | sampling_step: 1 7 | threshold_reject: 0.75 8 | 9 | # For data of any size 10 | max_len: 600 11 | min_len: 15 12 | -------------------------------------------------------------------------------- 
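The sampler configs above all point to `src.data.sampling.FrameSampler` with different window settings (`request_frames`, `threshold_reject`, `sampling_step`, `min_len`/`max_len`). As an illustration only — the actual `FrameSampler` implementation is not shown here and may differ — a consecutive sampler with these parameters typically behaves like this sketch:

```python
import random
from typing import Optional


def sample_conseq_frames(num_frames: int,
                         request_frames: Optional[int] = 300,
                         sampling_step: int = 1,
                         threshold_reject: float = 0.75,
                         min_len: int = 15,
                         max_len: int = 600) -> Optional[list]:
    """Illustrative consecutive-window sampling (hypothetical, not the repo's code)."""
    num_frames = min(num_frames, max_len)   # never use more than max_len frames
    if num_frames < min_len:
        return None                         # too short to be usable
    if request_frames is None or num_frames <= request_frames:
        if request_frames is not None and num_frames < threshold_reject * request_frames:
            return None                     # reject sequences far below the window size
        return list(range(0, num_frames, sampling_step))
    start = random.randint(0, num_frames - request_frames)  # random crop position
    return list(range(start, start + request_frames, sampling_step))
```

Under this reading, the `variable_conseq.yaml` values would drop motions shorter than 0.75 × 300 frames and crop longer ones to a random 300-frame window.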
/configs/stats.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - model: motion_prior_mlp_cvae 4 | - scheduler: reduce_on_plateau 5 | - dataloader: amass 6 | model_family: "mlp" 7 | debug: False 8 | variant: "base" 9 | dataset: "amass" 10 | render_train_outputs: True 11 | 12 | # all used directories 13 | project_dir: "." # /data/deps/stats 14 | chkpnt_dir: "data/checkpoints" # experiments 15 | amass_path: "data/amass/amass.pth.tar" 16 | smplx_models_path: "data/body_models/" 17 | # grab_path: "data/amass/GRAB/GRAB.pth.tar" 18 | grab_path: "/ps/project/datasets/GRAB" 19 | joints_path: "/is/cluster/fast/mdiomataris/grab_joints_new" 20 | 21 | # experiment details 22 | wandb_logger: 23 | # _target_: pytorch_lightning.loggers.WandbLogger 24 | resume: 'allow' 25 | id: 26 | group: 27 | tags: 28 | mode: "online" 29 | # project: 'HOI-common-sense' 30 | project: 'Motion_Prior-HPT' 31 | save_dir: "data/wandb_logs" 32 | 33 | rendering: 34 | dir: "data/rendered_outputs" 35 | in_training: True 36 | in_testing: True 37 | every_epochs: 100 38 | train_ids: 39 | - 0 40 | # - 10 41 | - 42 42 | # - 100 43 | - 200 44 | # - 300 45 | # - 345 46 | # - 128 47 | - 333 48 | # - 444 49 | test_ids: 50 | - 1 51 | - 2 52 | - 3 53 | - 5 54 | - 8 55 | - 13 56 | - 17 57 | - 16 58 | 59 | exp_id: None 60 | 61 | # lightning checkpoint callback 62 | trainer: 63 | _target_: pytorch_lightning.Trainer 64 | default_root_dir: ${chkpnt_dir} 65 | max_epochs: 300 66 | accelerator: 'gpu' 67 | devices: 1 68 | fast_dev_run: False 69 | overfit_batches: 0.0 70 | enable_progress_bar: True 71 | auto_scale_batch_size: 72 | accumulate_grad_batches: 73 | gradient_clip_val: 2.0 74 | callbacks: 75 | logger: 76 | resume_from_checkpoint: 77 | 78 | # precision: "bf16" 79 | 80 | batch_size: 32 81 | num_workers: 4 82 | tune: False 83 | 84 | # lightning checkpoint callback 85 | model_checkpoint: 86 | _target_: pytorch_lightning.callbacks.ModelCheckpoint 87 | dirpath: 88 | filename: 89 | monitor: "val/loss_monitor" 90 | save_top_k: 1 91 | mode: "min" 92 | save_last: True 93 | every_n_epochs: 5 94 | 95 | optimizer: 96 | _target_: torch.optim.Adam 97 | lr: 1e-4 98 | 99 | # training parameters 100 | monitor_loss: "loss" 101 | 102 | # preprocessing parameters 103 | preproc: 104 | joints_dump_path: "/is/cluster/fast/mdiomataris/grab_joints_new" 105 | split_seed: 0 106 | calculate_minmax: True 107 | generate_joint_files: True 108 | use_cuda: True 109 | 110 | # dataloading parameters 111 | dl: 112 | framerate_ratio: 4 113 | chunk_duration: 8.0 114 | trim_nointeractions: False 115 | force_recalculate_stats: False 116 | 117 | # augmentation parameters 118 | aug: 119 | undo_interaction_prob: 0.1 120 | out_of_reach_prob: 0.1 121 | min_oor_distance: 0.05 122 | max_oor_distance: 0.3 123 | random_rotate: False 124 | random_rot_type: "3d" 125 | framerate_dev: 0 126 | 127 | loss_type: "mse" 128 | joint_loss: False 129 | n_body_joints: 22 130 | 131 | rot_repr: "6d" 132 | norm_type: "std" 133 | load_feats: 134 | - "body_transl" 135 | - "body_transl_delta" 136 | - "body_transl_delta_pelv" 137 | - "body_transl_delta_pelv_xy" 138 | - "body_transl_z" 139 | - "body_orient" 140 | - "body_pose" 141 | - "body_orient_delta" 142 | - "body_pose_delta" 143 | - "body_orient_xy" 144 | - "body_joints" 145 | # - "body_joints_rel" 146 | # - "body_joints_vel" 147 | # - "object_transl" 148 | # - "object_transl_rel" 149 | # - "object_transl_vel" 150 | # - "object_orient" 151 | # - "obj_contact_bin" 152 | # - 
"hands_contact_bin" 153 | # - "obj_wrists_dist" 154 | # - "wrist_joints_transl" 155 | # - "wrist_joints_transl_rel" 156 | # - "wrist_joints_vel" 157 | # - "joint_global_oris" 158 | # - "joint_ang_vel" 159 | # - "wrists_ang_vel" 160 | # - "wrists_ang_vel_euler" 161 | # - "active_grasp_frames" 162 | # - "index_tips_vel" 163 | 164 | 165 | feats_dims: 166 | body_transl: 3 167 | body_transl_delta: 3 168 | body_transl_delta_pelv: 3 169 | body_transl_delta_pelv_xy: 3 170 | body_transl_z: 1 171 | body_orient: 6 172 | body_orient_delta: 6 173 | body_orient_xy: 6 174 | body_pose: 126 175 | body_pose_delta: 126 176 | body_joints: 66 177 | body_joints_rel: 66 178 | body_joints_vel: 66 179 | object_transl: 3 180 | object_transl_rel: 3 181 | object_transl_vel: 3 182 | object_orient: 6 183 | obj_contact_bin: 1 184 | obj_wrists_dist: 6 185 | wrist_joints_transl: 6 186 | wrist_joints_transl_rel: 6 187 | wrist_joints_vel: 6 188 | index_tips_vel: 6 189 | joint_ang_vel: 6 190 | wrists_ang_vel: 6 191 | hands_contact_bin: 2 192 | 193 | # rendering: 194 | # fps: 30 195 | # choose_random: 196 | # indices: 197 | # - 36 #-> drinking from bottle 198 | # - 160 #-> headphones over head 199 | # - 269 #-> airplane 200 | # - 272 #-> failing case, scisors 201 | # - 51 #-> good example (easy) -------------------------------------------------------------------------------- /configs/train.yaml: -------------------------------------------------------------------------------- 1 | hydra: 2 | run: 3 | dir: ${expdir}/${project}/${experiment}/${run_id}/ 4 | job: 5 | chdir: true 6 | env_set: 7 | # if you want to use wandb you should assign this key 8 | WANDB_API_KEY: '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' 9 | PYOPENGL_PLATFORM: 'egl' 10 | HYDRA_FULL_ERROR: 1 11 | WANDB__SERVICE_WAIT: 300 12 | 13 | debug: false 14 | 15 | # Global configurations shared between different modules 16 | expdir: ${get_expdir:${debug}} 17 | experiment: ${data.dataname} 18 | # must be the same when you are resuming experiment 19 | project: new_code 20 | seed: 42 21 | logger_level: INFO 22 | run_id: ${generate_id:} 23 | # For finetuning 24 | resume: ${working_path:${expdir}/${project}/${experiment}/${run_id}/} 25 | resume_ckpt_name: 'last' 26 | renderer: null # ait # null 27 | 28 | # log gradients/weights 29 | watch_model: false 30 | log_freq: 1000 31 | log: 'all' 32 | 33 | devices: 1 34 | 35 | # For finetuning 36 | ftune: null #${working_path:""} #/depot/bera89/data/li5280/project/motionfix/experiments/sigga-cr/baseline/baseline 37 | ftune_ckpt_name: 'last' # 'all' 38 | ftune_ckpt_path: ${get_last_checkpoint:${ftune},${ftune_ckpt_name}} 39 | 40 | statistics_file: statistics_${data.dataname}${get_debug:${debug}}.npy 41 | # statistics_file: "statistics_amass_circle.py" 42 | statistics_path: ${path.deps}/stats/${statistics_file} 43 | 44 | 45 | # Composing nested config with default 46 | defaults: 47 | - data: motionfix 48 | - model: basic_clip 49 | - machine: server 50 | - trainer: base 51 | - sampler: variable_conseq # cut it 52 | - logger: tensorboard # wandb 53 | - callback: base 54 | - /path@path 55 | - override hydra/job_logging: rich 56 | - override hydra/hydra_logging: rich 57 | - _self_ 58 | -------------------------------------------------------------------------------- /configs/train_cls.yaml: -------------------------------------------------------------------------------- 1 | hydra: 2 | run: 3 | dir: ${expdir}/${project}/${experiment}/${run_id}/ 4 | job: 5 | chdir: true 6 | env_set: 7 | # if you want to use wandb you should assign 
this key 8 | WANDB_API_KEY: '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' 9 | PYOPENGL_PLATFORM: 'egl' 10 | HYDRA_FULL_ERROR: 1 11 | WANDB__SERVICE_WAIT: 300 12 | 13 | debug: false 14 | 15 | # Global configurations shared between different modules 16 | expdir: ${get_expdir:${debug}} 17 | experiment: ${data.dataname} 18 | # must be the same when you are resuming experiment 19 | project: new_code 20 | seed: 42 21 | logger_level: INFO 22 | run_id: ${generate_id:} 23 | # For finetuning 24 | resume: ${working_path:${expdir}/${project}/${experiment}/${run_id}/} 25 | resume_ckpt_name: 'last' 26 | renderer: null # ait # null 27 | 28 | # log gradients/weights 29 | watch_model: false 30 | log_freq: 1000 31 | log: 'all' 32 | 33 | devices: 1 34 | 35 | # For finetuning 36 | ftune: null #${working_path:""} #/depot/bera89/data/li5280/project/motionfix/experiments/sigga-cr/baseline/baseline 37 | ftune_ckpt_name: 'last' # 'all' 38 | ftune_ckpt_path: ${get_last_checkpoint:${ftune},${ftune_ckpt_name}} 39 | 40 | statistics_file: statistics_${data.dataname}${get_debug:${debug}}.npy 41 | # statistics_file: "statistics_amass_circle.py" 42 | statistics_path: ${path.deps}/stats/${statistics_file} 43 | 44 | 45 | # Composing nested config with default 46 | defaults: 47 | - data: motionfix 48 | - model: basic_clip_cls 49 | - machine: server 50 | - trainer: base-longer 51 | - sampler: variable_conseq # cut it 52 | - logger: tensorboard # wandb 53 | - callback: base 54 | - /path@path 55 | - override hydra/job_logging: rich 56 | - override hydra/hydra_logging: rich 57 | - _self_ 58 | -------------------------------------------------------------------------------- /configs/train_cls_arch.yaml: -------------------------------------------------------------------------------- 1 | hydra: 2 | run: 3 | dir: ${expdir}/${project}/${experiment}/${run_id}/ 4 | job: 5 | chdir: true 6 | env_set: 7 | # if you want to use wandb you should assign this key 8 | WANDB_API_KEY: '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' 9 | PYOPENGL_PLATFORM: 'egl' 10 | HYDRA_FULL_ERROR: 1 11 | WANDB__SERVICE_WAIT: 300 12 | 13 | debug: false 14 | 15 | # Global configurations shared between different modules 16 | expdir: ${get_expdir:${debug}} 17 | experiment: ${data.dataname} 18 | # must be the same when you are resuming experiment 19 | project: new_code 20 | seed: 42 21 | logger_level: INFO 22 | run_id: ${generate_id:} 23 | # For finetuning 24 | resume: ${working_path:${expdir}/${project}/${experiment}/${run_id}/} 25 | resume_ckpt_name: 'last' 26 | renderer: null # ait # null 27 | 28 | # log gradients/weights 29 | watch_model: false 30 | log_freq: 1000 31 | log: 'all' 32 | 33 | devices: 1 34 | 35 | # For finetuning 36 | ftune: null #${working_path:""} #/depot/bera89/data/li5280/project/motionfix/experiments/sigga-cr/baseline/baseline 37 | ftune_ckpt_name: 'last' # 'all' 38 | ftune_ckpt_path: ${get_last_checkpoint:${ftune},${ftune_ckpt_name}} 39 | 40 | statistics_file: statistics_${data.dataname}${get_debug:${debug}}.npy 41 | # statistics_file: "statistics_amass_circle.py" 42 | statistics_path: ${path.deps}/stats/${statistics_file} 43 | 44 | 45 | # Composing nested config with default 46 | defaults: 47 | - data: motionfix 48 | - model: basic_clip_cls_arch 49 | - machine: server 50 | - trainer: base-longer 51 | - sampler: variable_conseq # cut it 52 | - logger: tensorboard # wandb 53 | - callback: base 54 | - /path@path 55 | - override hydra/job_logging: rich 56 | - override hydra/hydra_logging: rich 57 | - _self_ 58 | 
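`train.yaml`, `train_cls.yaml`, and `train_cls_arch.yaml` differ mainly in which `model` and `trainer` groups their `defaults` list pulls in; the remaining fields (run directory, resume/finetune paths, statistics file) are shared. They are consumed by a standard Hydra entry point, so group and field overrides can be passed on the command line as in the README. A minimal sketch of such an entry point (the repo's actual `train.py` adds model and data instantiation around this):

```python
import hydra
from omegaconf import DictConfig, OmegaConf


@hydra.main(config_path="configs", config_name="train_cls_arch", version_base="1.2")
def main(cfg: DictConfig) -> None:
    # At this point Hydra has composed the groups listed under `defaults:`
    # (data, model, machine, trainer, sampler, logger, callback, path) and
    # applied any command-line overrides, e.g.:
    #   python train.py --config-name=train_cls_arch experiment=cls_arch \
    #       run_id=no_text logger=wandb machine.batch_size=64
    print(OmegaConf.to_yaml(cfg.model.denoiser))


if __name__ == "__main__":
    main()
```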
-------------------------------------------------------------------------------- /configs/trainer/base-longer.yaml: -------------------------------------------------------------------------------- 1 | strategy: null # 'ddp' for multi gpu 2 | benchmark: True 3 | max_epochs: 1501 4 | accelerator: gpu 5 | log_every_n_steps: 40 # 100 6 | deterministic: False 7 | detect_anomaly: False 8 | enable_progress_bar: True 9 | check_val_every_n_epoch: 80 # if k --> happens every (k-1), 2*(k-1), ... 100 10 | limit_train_batches: 1.0 11 | limit_val_batches: 1.0 12 | num_sanity_val_steps: 0 # 2 13 | precision: 32 # => 'bf16' | 16 | 32 14 | -------------------------------------------------------------------------------- /configs/trainer/base.yaml: -------------------------------------------------------------------------------- 1 | strategy: null # 'ddp' for multi gpu 2 | benchmark: True 3 | max_epochs: 1001 4 | accelerator: gpu 5 | log_every_n_steps: 40 # 100 6 | deterministic: False 7 | detect_anomaly: False 8 | enable_progress_bar: True 9 | check_val_every_n_epoch: 100 # if k --> happens every (k-1), 2*(k-1), ... 100 10 | limit_train_batches: 1.0 11 | limit_val_batches: 1.0 12 | num_sanity_val_steps: 0 # 2 13 | precision: 32 # => 'bf16' | 16 | 32 14 | -------------------------------------------------------------------------------- /deps/inference/qual.txt: -------------------------------------------------------------------------------- 1 | indep-seg+seq/ 2 | climb down stairs, hold rail with left hand 3 | spatial_pairs-1008-0.mp4 4 | 5 | wide side step to the right 6 | seg-10094-0.mp4 7 | 8 | walk up stairs 9 | seg-10090-2.mp4 10 | 11 | lift right arm 12 | seg-10072-19.mp4 13 | 14 | 15 | left step left, hold arms as in waltz 16 | spatial_pairs-9918-1.mp4 17 | 18 | 19 | high jump, right hand full high 20 | spatial_pairs-9854-6.mp4 21 | spatial_pairs-9854-5.mp4 22 | spatial_pairs-9854-3.mp4 23 | 24 | 25 | look down, walk 26 | spatial_pairs-10090-0.mp4 27 | 28 | 29 | spatial_pairs-5740-2 30 | spatial_pairs-12301-0 31 | -------------------------------------------------------------------------------- /deps/smplh/smpl.faces: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lzhyu/SimMotionEdit/384135a8707bc247fba23005d57cca7ca2d751ce/deps/smplh/smpl.faces -------------------------------------------------------------------------------- /deps/stats/statistics_motionfix.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lzhyu/SimMotionEdit/384135a8707bc247fba23005d57cca7ca2d751ce/deps/stats/statistics_motionfix.npy -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.0.0 2 | accelerate==0.23.0 3 | aiohttp==3.8.1 4 | aiosignal==1.2.0 5 | aitviewer==1.12.0 6 | antlr4-python3-runtime==4.9.3 7 | appdirs==1.4.4 8 | asttokens==2.0.5 9 | async-timeout==4.0.2 10 | attrs==21.4.0 11 | backcall==0.2.0 12 | blis==0.7.9 13 | cachetools==5.0.0 14 | catalogue==2.0.8 15 | certifi==2021.10.8 16 | charset-normalizer==2.0.12 17 | click==8.0.4 18 | cmake==3.26.0 19 | colorlog==6.6.0 20 | confection==0.0.4 21 | contourpy==1.1.1 22 | cycler==0.12.1 23 | cymem==2.0.7 24 | decorator==4.4.2 25 | diffusers==0.21.2 26 | docker-pycreds==0.4.0 27 | einops==0.6.1 28 | en-core-web-sm @ 
https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.5.0/en_core_web_sm-3.5.0-py3-none-any.whl#sha256=0964370218b7e1672a30ac50d72cdc6b16f7c867496f1d60925691188f4d2510 29 | executing==0.8.3 30 | filelock==3.6.0 31 | fonttools==4.43.1 32 | freetype-py==2.2.0 33 | frozenlist==1.3.0 34 | fsspec==2024.3.1 35 | future==0.18.2 36 | gitdb==4.0.9 37 | GitPython==3.1.27 38 | glcontext==2.4.0 39 | google-auth==2.6.2 40 | google-auth-oauthlib==0.4.6 41 | grpcio==1.44.0 42 | huggingface-hub==0.17.3 43 | hydra-colorlog==1.2.0 44 | hydra-core==1.3.2 45 | idna==3.3 46 | imageio==2.16.1 47 | imageio-ffmpeg==0.4.5 48 | imgui==2.0.0 49 | importlib-metadata==4.11.3 50 | ipdb==0.13.9 51 | ipython==8.1.1 52 | jedi==0.18.1 53 | Jinja2==3.1.2 54 | joblib==1.3.2 55 | kiwisolver==1.4.5 56 | langcodes==3.3.0 57 | lightning-utilities==0.8.0 58 | lit==16.0.0 59 | loguru==0.6.0 60 | Markdown==3.3.6 61 | markdown-it-py==3.0.0 62 | MarkupSafe==2.1.2 63 | matplotlib==3.8.0 64 | matplotlib-inline==0.1.3 65 | mdurl==0.1.2 66 | moderngl==5.8.2 67 | moderngl-window==2.4.4 68 | more-itertools==9.1.0 69 | moviepy==1.0.3 70 | mpmath==1.3.0 71 | multidict==6.0.2 72 | multipledispatch==1.0.0 73 | murmurhash==1.0.9 74 | networkx==2.7.1 75 | numpy==1.22.3 76 | oauthlib==3.2.0 77 | omegaconf==2.3.0 78 | opencv-contrib-python-headless==4.8.0.76 79 | opencv-python==4.5.5.64 80 | orjson==3.9.15 81 | packaging==21.3 82 | pandas==1.4.1 83 | parso==0.8.3 84 | pathtools==0.1.2 85 | pathy==0.10.1 86 | pexpect==4.8.0 87 | pickleshare==0.7.5 88 | Pillow==9.0.1 89 | preshed==3.0.8 90 | proglog==0.1.9 91 | promise==2.3 92 | prompt-toolkit==3.0.28 93 | protobuf==3.19.4 94 | psutil==5.9.0 95 | ptyprocess==0.7.0 96 | pudb==2022.1.1 97 | pure-eval==0.2.2 98 | pyasn1==0.4.8 99 | pyasn1-modules==0.2.8 100 | pydantic==1.10.2 101 | pyDeprecate==0.3.1 102 | pyglet==2.0.9 103 | Pygments==2.16.1 104 | PyOpenGL==3.1.0 105 | pyparsing==3.0.7 106 | PyQt5==5.15.9 107 | PyQt5-Qt5==5.15.2 108 | PyQt5-sip==12.12.2 109 | pyrender==0.1.45 110 | pyrr==0.10.3 111 | python-dateutil==2.8.2 112 | pytz==2021.3 113 | PyWavelets==1.4.1 114 | PyYAML==6.0 115 | regex==2022.3.15 116 | requests==2.27.1 117 | requests-oauthlib==1.3.1 118 | rich==13.5.3 119 | roma==1.4.0 120 | rsa==4.8 121 | sacremoses==0.0.49 122 | safetensors==0.3.3 123 | scikit-image==0.19.3 124 | scikit-video==1.1.11 125 | scipy==1.7.2 126 | sentry-sdk==1.5.8 127 | setproctitle==1.3.2 128 | shortuuid==1.0.8 129 | six==1.16.0 130 | smart-open==6.3.0 131 | smmap==5.0.0 132 | smplx==0.1.28 133 | spacy==3.5.1 134 | spacy-legacy==3.0.12 135 | spacy-loggers==1.0.4 136 | srsly==2.4.6 137 | stack-data==0.2.0 138 | sympy==1.11.1 139 | tensorboard==2.10.0 140 | tensorboard-data-server==0.6.1 141 | tensorboard-plugin-wit==1.8.1 142 | termcolor==1.1.0 143 | thinc==8.1.9 144 | tifffile==2023.9.18 145 | tokenizers==0.11.6 146 | toml==0.10.2 147 | tqdm==4.63.0 148 | traitlets==5.1.1 149 | transformers==4.17.0 150 | trimesh==3.10.5 151 | triton==2.1.0 152 | typer==0.7.0 153 | typing_extensions==4.11.0 154 | urllib3==1.26.9 155 | urwid==2.1.2 156 | urwid-readline==0.13 157 | usd-core==23.8 158 | uuid==1.30 159 | wandb==0.16.6 160 | wasabi==1.1.1 161 | wcwidth==0.2.5 162 | websockets==11.0.3 163 | Werkzeug==2.0.3 164 | yarl==1.7.2 165 | yaspin==2.1.0 166 | zipp==3.7.0 167 | -------------------------------------------------------------------------------- /scripts/download_data.sh: -------------------------------------------------------------------------------- 1 | pip install gdown 2 | mkdir data 3 | 
mkdir data/body_models 4 | # dataset 5 | gdown --folder "https://drive.google.com/drive/folders/1DM7oIJwxwoljVxAfhfktocTptwVX5sqR?usp=sharing" 6 | mv motionfix-dataset data/ 7 | # tmr folder 8 | gdown --folder "https://drive.google.com/drive/folders/15LHeriOCjmh4Cp5H9M94xoFGBN0amxI8?usp=sharing" 9 | mv tmr-evaluator eval-deps 10 | # smpl models 11 | gdown --folder "https://drive.google.com/drive/folders/1s3re2I1OzBimQIpudUEFB1hClFWOPJjC?usp=drive_link" 12 | mv smplh data/body_models 13 | # tmed checkpoints 14 | gdown --folder "https://drive.google.com/drive/folders/1M_i_zUSlktdEKf-xBF9g6y7N-lfDtuPD?usp=sharing" 15 | mkdir experiments 16 | mv tmed experiments/ 17 | mkdir experiments/tmed/checkpoints 18 | mv experiments/tmed/last.ckpt experiments/tmed/checkpoints/ 19 | -------------------------------------------------------------------------------- /scripts/install.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | echo "Creating virtual environment" 4 | python3.10 -m venv mfix-env 5 | echo "Activating virtual environment" 6 | 7 | source $PWD/mfix-env/bin/activate 8 | 9 | $PWD/mfix-env/bin/pip install --upgrade pip setuptools 10 | 11 | $PWD/mfix-env/bin/pip install "torch==2.1.2" "torchvision==0.16.2" --index-url https://download.pytorch.org/whl/cu118 12 | 13 | $PWD/mfix-env/bin/pip install "pytorch-lightning==2.2.4" 14 | $PWD/mfix-env/bin/pip install -r requirements.txt 15 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lzhyu/SimMotionEdit/384135a8707bc247fba23005d57cca7ca2d751ce/src/__init__.py -------------------------------------------------------------------------------- /src/callback/__init__.py: -------------------------------------------------------------------------------- 1 | from .progress import ProgressLogger 2 | from .render import RenderCallback -------------------------------------------------------------------------------- /src/callback/progress.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from pytorch_lightning import LightningModule, Trainer 4 | from pytorch_lightning.callbacks import Callback 5 | import psutil 6 | 7 | logger = logging.getLogger(__name__) 8 | 9 | 10 | class ProgressLogger(Callback): 11 | def __init__(self, 12 | metric_monitor: dict, 13 | precision: int = 3): 14 | # Metric to monitor 15 | self.metric_monitor = metric_monitor 16 | self.precision = precision 17 | 18 | def on_train_start(self, trainer: Trainer, pl_module: LightningModule, **kwargs) -> None: 19 | logger.info("Training started") 20 | 21 | def on_train_end(self, trainer: Trainer, pl_module: LightningModule, **kwargs) -> None: 22 | logger.info("Training done") 23 | 24 | def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule, **kwargs) -> None: 25 | if trainer.sanity_checking: 26 | logger.info("Sanity checking ok.") 27 | 28 | def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule, **kwargs) -> None: 29 | metric_format = f"{{:.{self.precision}e}}" 30 | line = f"Epoch {trainer.current_epoch}" 31 | line = f"{line:>{len('Epoch xxxx')}}" # Right padding 32 | metrics_str = [] 33 | 34 | losses_dict = trainer.callback_metrics 35 | for metric_name, dico_name in self.metric_monitor.items(): 36 | if dico_name not in losses_dict: 37 | dico_name = f"losses/{dico_name}" 38 | 
39 | if dico_name in losses_dict: 40 | metric = losses_dict[dico_name].item() 41 | metric = metric_format.format(metric) 42 | metric = f"{metric_name} {metric}" 43 | metrics_str.append(metric) 44 | 45 | if len(metrics_str) == 0: 46 | return 47 | 48 | memory = f"Memory {psutil.virtual_memory().percent}%" 49 | line = line + ": " + " ".join(metrics_str) + " " + memory 50 | logger.info(line) 51 | -------------------------------------------------------------------------------- /src/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lzhyu/SimMotionEdit/384135a8707bc247fba23005d57cca7ca2d751ce/src/data/__init__.py -------------------------------------------------------------------------------- /src/data/base.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | from torch.utils.data import DataLoader 3 | 4 | from src.data.tools.collate import collate_batch_last_padding, collate_datastruct_and_text 5 | import torch 6 | from typing import List 7 | 8 | class BASEDataModule(pl.LightningDataModule): 9 | def __init__(self, 10 | batch_size: int, 11 | num_workers: int, 12 | load_feats: List[str], 13 | batch_sampler: str | None = None, 14 | dataset_percentages: dict[str, float] | None = None): 15 | super().__init__() 16 | 17 | collate_fn = lambda b: collate_batch_last_padding(b, load_feats) 18 | 19 | def set_worker_sharing_strategy(worker_id: int) -> None: 20 | sharing_strategy = "file_system" 21 | torch.multiprocessing.set_sharing_strategy(sharing_strategy) 22 | self.dataloader_options = { 23 | 'batch_size': batch_size, 24 | 'num_workers': num_workers, 25 | 'collate_fn': collate_fn, 26 | 'drop_last': False, 27 | 'worker_init_fn': set_worker_sharing_strategy 28 | # 'pin_memory': True, 29 | } 30 | self.batch_sampler = batch_sampler 31 | self.ds_perc = dataset_percentages 32 | self.batch_size = batch_size 33 | # need to be overloaded: 34 | # - self.Dataset 35 | # - self._sample_set => load only a small subset 36 | # There is an helper below (get_sample_set) 37 | # - self.nfeats 38 | # - self.transforms 39 | self._train_dataset = None 40 | self._val_dataset = None 41 | self._test_dataset = None 42 | 43 | # Optional 44 | self._subset_dataset = None 45 | 46 | def get_sample_set(self, overrides={}): 47 | sample_params = self.hparams.copy() 48 | sample_params.update(overrides) 49 | return self.Dataset(**sample_params) 50 | 51 | def train_dataloader(self): 52 | if self.batch_sampler is not None: 53 | from src.data.sampling.custom_batch_sampler import PercBatchSampler, CustomBatchSampler, CustomBatchSamplerV2, CustomBatchSamplerV4 54 | from src.data.sampling.custom_batch_sampler import mix_datasets_anysize 55 | # ratio_batch_sampler = CustomBatchSamplerV2(concat_dataset=self.dataset['train'], 56 | # batch_size=self.batch_size) 57 | ratio_batch_sampler = CustomBatchSamplerV4(concat_dataset=self.dataset['train'], 58 | batch_size=self.batch_size, 59 | mix_percentages=self.ds_perc) 60 | 61 | # ratio_batch_sampler = PercBatchSampler(data_source=self.dataset['train'], 62 | # baxtch_size=self.batch_size) 63 | # dataset_percentages=self.ds_perc) 64 | del self.dataloader_options['batch_size'] 65 | return DataLoader(self.dataset['train'], 66 | batch_sampler=ratio_batch_sampler, 67 | **self.dataloader_options) 68 | else: 69 | return DataLoader(self.dataset['train'], 70 | shuffle=True, 71 | **self.dataloader_options) 72 | 73 | def val_dataloader(self): 74 | if 
self.batch_sampler is not None: 75 | return DataLoader(self.dataset['test'], 76 | #batch_sampler=ratio_batch_sampler, 77 | shuffle=False, 78 | **self.dataloader_options) 79 | else: 80 | return DataLoader(self.dataset['test'], 81 | shuffle=False, 82 | **self.dataloader_options) 83 | 84 | def test_dataloader(self): 85 | return DataLoader(self.dataset['test'], 86 | shuffle=False, 87 | **self.dataloader_options) 88 | 89 | # def train_dataloader(self): 90 | # return DataLoader(self.train_dataset, 91 | # shuffle=True, 92 | # **self.dataloader_options) 93 | 94 | # def predict_dataloader(self): 95 | # return DataLoader(self.train_dataset, 96 | # shuffle=False, 97 | # **self.dataloader_options) 98 | 99 | # def val_dataloader(self): 100 | # return DataLoader(self.val_dataset, 101 | # shuffle=False, 102 | # **self.dataloader_options) 103 | 104 | # def test_dataloader(self): 105 | # return DataLoader(self.test_dataset, 106 | # shuffle=False, 107 | # **self.dataloader_options) 108 | 109 | # def subset_dataloader(self): 110 | # return DataLoader(self.subset_dataset, 111 | # shuffle=False, 112 | # **self.dataloader_options) 113 | -------------------------------------------------------------------------------- /src/data/features.py: -------------------------------------------------------------------------------- 1 | from src.tools.transforms3d import change_for, transform_body_pose 2 | from src.utils.genutils import to_tensor 3 | import torch 4 | 5 | def _get_body_pose(data): 6 | """get body pose""" 7 | # default is axis-angle representation: Frames x (Jx3) (J=21) 8 | if not torch.is_tensor(data): 9 | pose = to_tensor(data['rots'][..., 3:3 + 21*3]) # drop pelvis orientation 10 | else: 11 | pose = to_tensor(data[..., 3:3 + 21*3]) 12 | pose = transform_body_pose(pose, f"aa->6d") 13 | return pose 14 | 15 | def _get_body_transl(data): 16 | """get body pelvis translation""" 17 | if not torch.is_tensor(data): 18 | tran = data['trans'] 19 | else: 20 | tran = data 21 | return to_tensor(tran) 22 | 23 | def _get_body_orient(data): 24 | """get body global orientation""" 25 | # default is axis-angle representation 26 | if not torch.is_tensor(data): 27 | pelvis_orient = to_tensor(data['rots'][..., :3]) 28 | else: 29 | pelvis_orient = data 30 | # axis-angle to rotation matrix & drop last row 31 | pelvis_orient = transform_body_pose(pelvis_orient, "aa->6d") 32 | return pelvis_orient 33 | 34 | def _get_body_transl_delta_pelv(data): 35 | """ 36 | get body pelvis tranlation delta relative to pelvis coord.frame 37 | v_i = t_i - t_{i-1} relative to R_{i-1} 38 | """ 39 | trans = to_tensor(data['trans']) 40 | trans_vel = trans - trans.roll(1, 0) # shift one right and subtract 41 | pelvis_orient = transform_body_pose(to_tensor(data['rots'][..., :3]), "aa->rot") 42 | trans_vel_pelv = change_for(trans_vel, pelvis_orient.roll(1, 0)) 43 | trans_vel_pelv[0] = 0 # zero out velocity of first frame 44 | return trans_vel_pelv 45 | 46 | def _get_body_transl_delta_pelv_infer(pelvis_orient, trans): 47 | """ 48 | get body pelvis tranlation delta relative to pelvis coord.frame 49 | v_i = t_i - t_{i-1} relative to R_{i-1} 50 | """ 51 | trans = to_tensor(trans) 52 | trans_vel = trans - trans.roll(1, 0) # shift one right and subtract 53 | pelvis_orient = transform_body_pose(to_tensor(pelvis_orient), "6d->rot") 54 | trans_vel_pelv = change_for(trans_vel, pelvis_orient.roll(1, 0)) 55 | trans_vel_pelv[0] = 0 # zero out velocity of first frame 56 | return trans_vel_pelv 57 | 
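The extractors in src/data/features.py above share one convention: each accepts either a raw AMASS-style dictionary (axis-angle 'rots', pelvis 'trans') or an already-sliced tensor, and returns per-frame features whose widths should line up with the feats_dims table in the data config (body_orient: 6, body_pose: 126, body_transl_delta_pelv: 3, ...). Below is a minimal usage sketch on dummy zero tensors; the frame count and the 66-dim rots layout (3 pelvis + 21x3 body joints) are assumptions made for illustration, and the repo must be on PYTHONPATH for the imports to resolve.

import torch
from src.data.features import (_get_body_pose, _get_body_orient,
                               _get_body_transl_delta_pelv)

T = 8  # hypothetical number of frames
data = {
    "rots": torch.zeros(T, 66),   # per-frame SMPL axis-angle: pelvis (3) + 21 body joints (63)
    "trans": torch.zeros(T, 3),   # per-frame pelvis translation
}

pose = _get_body_pose(data)                       # 21 body joints as 6d rotations
orient = _get_body_orient(data)                   # pelvis orientation as a 6d rotation
transl_delta = _get_body_transl_delta_pelv(data)  # frame-to-frame velocity in the pelvis frame

print(pose.shape, orient.shape, transl_delta.shape)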
-------------------------------------------------------------------------------- /src/data/sampling/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import FrameSampler 2 | from .framerate import subsample, upsample 3 | -------------------------------------------------------------------------------- /src/data/sampling/base.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | from dataclasses import dataclass 3 | 4 | 5 | @dataclass 6 | class FrameSampler: 7 | sampling: str = "conseq" 8 | sampling_step: int = 1 9 | request_frames: Optional[int] = None 10 | threshold_reject: int = 0.75 11 | max_len: int = 1000 12 | min_len: int = 10 13 | 14 | def __call__(self, num_frames): 15 | from .frames import get_frameix_from_data_index 16 | return get_frameix_from_data_index(num_frames, 17 | self.max_len, 18 | self.request_frames, 19 | self.sampling, 20 | self.sampling_step) 21 | 22 | def accept(self, duration): 23 | # Outputs have original lengths 24 | # Check if it is too long 25 | if self.request_frames is None: 26 | if duration > self.max_len: 27 | return False 28 | if duration < self.min_len: 29 | return False 30 | else: 31 | # Reject sample if the length is 32 | # too little relative to 33 | # the request frames 34 | 35 | # min_number = self.threshold_reject * self.request_frames 36 | if duration < self.min_len: # min_number: 37 | return False 38 | return True 39 | 40 | def get(self, key, default=None): 41 | return getattr(self, key, default) 42 | 43 | def __getitem__(self, key): 44 | return getattr(self, key) 45 | -------------------------------------------------------------------------------- /src/data/sampling/framerate.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | # TODO: use a real subsampler.. 5 | def subsample(num_frames, last_framerate, new_framerate): 6 | step = int(last_framerate / new_framerate) 7 | assert step >= 1 8 | frames = np.arange(0, num_frames, step) 9 | return frames 10 | 11 | 12 | # TODO: use a real upsampler.. 
13 | def upsample(motion, last_framerate, new_framerate): 14 | step = int(new_framerate / last_framerate) 15 | assert step >= 1 16 | 17 | # Alpha blending => interpolation 18 | alpha = np.linspace(0, 1, step+1) 19 | last = np.einsum("l,...->l...", 1-alpha, motion[:-1]) 20 | new = np.einsum("l,...->l...", alpha, motion[1:]) 21 | 22 | chuncks = (last + new)[:-1] 23 | output = np.concatenate(chuncks.swapaxes(1, 0)) 24 | # Don't forget the last one 25 | output = np.concatenate((output, motion[[-1]])) 26 | return output 27 | 28 | 29 | if __name__ == "__main__": 30 | motion = np.arange(105) 31 | submotion = motion[subsample(len(motion), 100.0, 12.5)] 32 | newmotion = upsample(submotion, 12.5, 100) 33 | 34 | print(newmotion) 35 | -------------------------------------------------------------------------------- /src/data/sampling/frames.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import numpy as np 4 | from numpy import ndarray as Array 5 | import random 6 | 7 | 8 | def get_frameix_from_data_index(num_frames: int, 9 | max_len: Optional[int], 10 | request_frames: Optional[int], 11 | sampling: str = "conseq", 12 | sampling_step: int = 1) -> Array: 13 | nframes = num_frames 14 | # do not pad small sequences sample from long ones 15 | if request_frames is None or request_frames > nframes: 16 | frame_ix = np.arange(nframes) 17 | else: 18 | # sampling goal: input: ----------- 11 nframes 19 | # o--o--o--o- 4 ninputs 20 | # 21 | # step number is computed like that: [(11-1)/(4-1)] = 3 22 | # [---][---][---][- 23 | # So step = 3, and we take 0 to step*ninputs+1 with steps 24 | # [o--][o--][o--][o-] 25 | # then we can randomly shift the vector 26 | # -[o--][o--][o--]o 27 | # If there are too much frames required 28 | # Nikos: It never gets here now. Should add a pad flag instead of this. 
29 | if request_frames > nframes: 30 | fair = False # True 31 | if fair: 32 | # distills redundancy everywhere 33 | choices = np.random.choice(range(nframes), 34 | request_frames, 35 | replace=True) 36 | frame_ix = sorted(choices) 37 | else: 38 | # adding the last frame until done 39 | # Nikos: do not pad 40 | ntoadd = max(0, request_frames - nframes) 41 | lastframe = nframes - 1 42 | padding = lastframe * np.ones(ntoadd, dtype=int) 43 | frame_ix = np.concatenate((np.arange(0, nframes), 44 | padding)) 45 | 46 | elif sampling in ["conseq", "random_conseq"]: 47 | step_max = (nframes - 1) // (request_frames - 1) 48 | if sampling == "conseq": 49 | if sampling_step == -1 or sampling_step * (request_frames - 1) >= nframes: 50 | step = step_max 51 | else: 52 | step = sampling_step 53 | elif sampling == "random_conseq": 54 | step = random.randint(1, step_max) 55 | 56 | lastone = step * (request_frames - 1) 57 | shift_max = nframes - lastone - 1 58 | shift = random.randint(0, max(0, shift_max - 1)) 59 | frame_ix = shift + np.arange(0, lastone + 1, step) 60 | 61 | elif sampling == "random": 62 | choices = np.random.choice(range(nframes), 63 | request_frames, 64 | replace=False) 65 | frame_ix = sorted(choices) 66 | 67 | else: 68 | raise ValueError("Sampling not recognized.") 69 | 70 | return frame_ix 71 | -------------------------------------------------------------------------------- /src/data/tools/__init__.py: -------------------------------------------------------------------------------- 1 | from .tensors import lengths_to_mask, lengths_to_mask_njoints 2 | from .collate import collate_text_and_length, collate_pairs_and_text, collate_datastruct_and_text, collate_tensor_with_padding 3 | -------------------------------------------------------------------------------- /src/data/tools/amass_utils.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import torch 3 | name_mapping = {'MPI_mosh':'MoSh', 4 | 'ACCAD':'ACCAD', 5 | 'WEIZMANN': 'WEIZMANN', 6 | 'CNRS': 'CNRS', 7 | 'DFaust': 'DFaust', 8 | 'TCDHands': 'TCDHands', 9 | 'SOMA': 'SOMA', 10 | 'CMU': 'CMU', 11 | 'SFU': 'SFU', 12 | 'TotalCapture': 'TotalCapture', 13 | 'HDM05': 'HDM05', 14 | 'Transitions': 'Transitions', 15 | 'MoSh': 'MoSh', 16 | 'KIT': 'KIT', 17 | 'DanceDB': 'DanceDB', 18 | 'Transitions_mocap': 'Transitions', 19 | 'PosePrior': 'PosePrior', 20 | 'MPI_Limits': 'PosePrior', 21 | 'BMLhandball': 'BMLhandball', 22 | 'SSM': 'SSM', 23 | 'TCD_handMocap': 'TCDHands', 24 | 'BMLrub': 'BMLrub', 25 | 'BioMotionLab_NTroje': 'BMLrub', 26 | 'SSM_synced': 'SSM', 27 | 'Eyes_Japan_Dataset': 'Eyes_Japan_Dataset', 28 | 'DFaust_67': 'DFaust', 29 | 'EKUT': 'EKUT', 30 | 'MPI_HDM05': 'HDM05', 31 | 'GRAB': 'GRAB', 32 | 'HumanEva': 'HumanEva', 33 | 'HUMAN4D': 'HUMAN4D', 34 | 'BMLmovi': 'BMLmovi' 35 | } 36 | 37 | 38 | def path_normalizer(paths_list_or_str): 39 | if isinstance(paths_list_or_str, str): 40 | paths_list = [paths_list_or_str] 41 | else: 42 | paths_list = list(paths_list_or_str) 43 | # works only for dir1/dir2/fname.npz to normalize dir1 44 | plist = ['/'.join(p.split('/')[-3:]) for p in paths_list] 45 | norm_path = ['/'.join([name_mapping[p.split('/')[0]], p.split('/')[1], p.split('/')[2]]) for p in plist if p.split('/')[0] in name_mapping.keys()] 46 | if isinstance(paths_list_or_str, str): 47 | return norm_path[0] 48 | else: 49 | return norm_path 50 | 51 | def fname_normalizer(fname): 52 | dataset_name, subject, sequence_name = fname.split('/') 53 | sequence_name = 
sequence_name.replace('_poses.npz', '') 54 | sequence_name = sequence_name.replace('_poses', '') 55 | sequence_name = sequence_name.replace('poses', '') 56 | sequence_name = sequence_name.replace('_stageii.npz', '') 57 | sequence_name = sequence_name.replace('_stageii', '') 58 | sequence_name = sequence_name.rstrip() 59 | getVals = list([val for val in sequence_name 60 | if val.isalpha() or val.isnumeric()]) 61 | sequence_name = ''.join(getVals) 62 | 63 | return '/'.join([dataset_name, subject, sequence_name]) 64 | 65 | def flip_motion(pose, trans): 66 | # expects T, Jx3 67 | # Permutation of SMPL pose parameters when flipping the shape 68 | SMPL_JOINTS_FLIP_PERM = [0, 2, 1, 3, 5, 4, 6, 8, 7, 9, 11, 10, 69 | 12, 14, 13, 15, 17, 16, 19, 70 | 18, 21, 20] #, 23, 22] 71 | SMPL_POSE_FLIP_PERM = [] 72 | for i in SMPL_JOINTS_FLIP_PERM: 73 | SMPL_POSE_FLIP_PERM.append(3*i) 74 | SMPL_POSE_FLIP_PERM.append(3*i+1) 75 | SMPL_POSE_FLIP_PERM.append(3*i+2) 76 | 77 | flipped_parts = SMPL_POSE_FLIP_PERM 78 | pose = pose[:, flipped_parts] 79 | # we also negate the second and the third dimension of the axis-angle 80 | pose[:, 1::3] = -pose[:, 1::3] 81 | pose[:, 2::3] = -pose[:, 2::3] 82 | x, y, z = torch.unbind(trans, dim=-1) 83 | mirrored_trans = torch.stack((-x, y, z), axis=-1) 84 | 85 | return pose, mirrored_trans -------------------------------------------------------------------------------- /src/data/tools/contacts.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | from src.info.joints import smplh_joints 4 | left_foot_joints = [] 5 | right_foot_joints = [] 6 | jointnames = ['foot', 'small_toe', 'heel', 'big_toe', 'ankle'] 7 | 8 | def foot_detect(positions, thres): 9 | """ Get Foot Contacts """ 10 | velfactor, heightfactor = np.array([thres, thres]), np.array([3.0, 2.0]) 11 | 12 | feet_l_x = (positions[1:, left_foot_joints, 0] - positions[:-1, left_foot_joints, 0]) ** 2 13 | feet_l_y = (positions[1:, left_foot_joints, 1] - positions[:-1, left_foot_joints, 1]) ** 2 14 | feet_l_z = (positions[1:, left_foot_joints, 2] - positions[:-1, left_foot_joints, 2]) ** 2 15 | # feet_l_h = positions[:-1,fid_l,1] 16 | # feet_l = (((feet_l_x + feet_l_y + feet_l_z) < velfactor) & (feet_l_h < heightfactor)).astype(np.float) 17 | feet_l = ((feet_l_x + feet_l_y + feet_l_z) < velfactor).astype(np.float32) 18 | 19 | feet_r_x = (positions[1:, right_foot_joints, 0] - positions[:-1, right_foot_joints, 0]) ** 2 20 | feet_r_y = (positions[1:, right_foot_joints, 1] - positions[:-1, right_foot_joints, 1]) ** 2 21 | feet_r_z = (positions[1:, right_foot_joints, 2] - positions[:-1, right_foot_joints, 2]) ** 2 22 | # feet_r_h = positions[:-1,fid_r,1] 23 | # feet_r = (((feet_r_x + feet_r_y + feet_r_z) < velfactor) & (feet_r_h < heightfactor)).astype(np.float) 24 | feet_r = (((feet_r_x + feet_r_y + feet_r_z) < velfactor)).astype(np.float32) 25 | return feet_l, feet_r 26 | -------------------------------------------------------------------------------- /src/data/tools/extract_pairs.py: -------------------------------------------------------------------------------- 1 | from src.utils.nlp_consts import fix_spell 2 | from src.data.tools.spatiotempo import temporal_compositions, spatial_compositions 3 | from src.data.tools.spatiotempo import EXCLUDED_ACTIONS, EXCLUDED_ACTIONS_WO_TR 4 | 5 | 6 | def extract_frame_labels_onlytext(babel_labels): 7 | seg_acts = [] 8 | # if is_valid: 9 | if babel_labels['frame_ann'] is None: 10 | # 'transl' 'pose''betas' 11 | action_label = 
babel_labels['seq_ann']['labels'][0]['proc_label'] 12 | action_label = fix_spell(action_label) 13 | seg_acts.append(action_label) 14 | else: 15 | for seg_an in babel_labels['frame_ann']['labels']: 16 | action_label = fix_spell(seg_an['proc_label']) 17 | if action_label not in EXCLUDED_ACTIONS: 18 | seg_acts.append(action_label) 19 | 20 | return seg_acts 21 | 22 | 23 | def extract_frame_labels(babel_labels, fps, seqlen, max_simultaneous): 24 | 25 | seg_ids = [] 26 | seg_acts = [] 27 | # is_valid = True 28 | # possible_frame_dtypes = ['seg', 'pairs', 'separate_pairs', 'spatial_pairs'] 29 | # # if 'seq' in datatype and babel_labels['frame_ann'] is not None: 30 | # # is_valid = False 31 | # if bool(set(datatype.split('+')) & set(possible_frame_dtypes)) \ 32 | # and babel_labels['frame_ann'] is None: 33 | # is_valid = False 34 | 35 | possible_motions = {} 36 | # if is_valid: 37 | if babel_labels['frame_ann'] is None: 38 | # 'transl' 'pose''betas' 39 | action_label = babel_labels['seq_ann']['labels'][0]['proc_label'] 40 | possible_motions['seq'] = [(0, seqlen, fix_spell(action_label))] 41 | else: 42 | # Get segments 43 | # segments_dict = {k: {} for k in range(babel_labels['frame_ann']['labels'])} 44 | seg_list = [] 45 | 46 | for seg_an in babel_labels['frame_ann']['labels']: 47 | action_label = fix_spell(seg_an['proc_label']) 48 | 49 | st_f = int(seg_an['start_t'] * fps) 50 | end_f = int(seg_an['end_t'] * fps) 51 | 52 | if end_f > seqlen: 53 | end_f = seqlen 54 | seg_ids.append((st_f, end_f)) 55 | seg_acts.append(action_label) 56 | 57 | if action_label not in EXCLUDED_ACTIONS and end_f > st_f: 58 | seg_list.append((st_f, end_f, action_label)) 59 | 60 | possible_motions['seg'] = seg_list 61 | spatial = spatial_compositions(seg_list, 62 | actions_up_to=max_simultaneous) 63 | possible_motions['spatial_pairs'] = spatial 64 | possible_motions['separate_pairs'] = temporal_compositions( 65 | seg_ids, seg_acts) 66 | 67 | return possible_motions 68 | -------------------------------------------------------------------------------- /src/data/tools/rotation_transformation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import src.tools.geometry as geometry 3 | from einops import rearrange 4 | 5 | 6 | def rotate_trajectory(traj, rotZ, inverse=False): 7 | if inverse: 8 | # transpose 9 | rotZ = rearrange(rotZ, "... i j -> ... j i") 10 | 11 | vel = torch.diff(traj, dim=-2) 12 | # 0 for the first one => keep the dimentionality 13 | vel = torch.cat((0 * vel[..., [0], :], vel), dim=-2) 14 | vel_local = torch.einsum("...kj,...k->...j", rotZ[..., :2, :2], vel[..., :2]) 15 | # Integrate the trajectory 16 | traj_local = torch.cumsum(vel_local, dim=-2) 17 | # First frame should be the same as before 18 | traj_local = traj_local - traj_local[..., [0], :] + traj[..., [0], :] 19 | return traj_local 20 | 21 | 22 | def rotate_trans(trans, rotZ, inverse=False): 23 | traj = trans[..., :2] 24 | transZ = trans[..., 2] 25 | traj_local = rotate_trajectory(traj, rotZ, inverse=inverse) 26 | trans_local = torch.cat((traj_local, transZ[..., None]), axis=-1) 27 | return trans_local 28 | 29 | 30 | def rotate_joints2D(joints, rotZ, inverse=False): 31 | if inverse: 32 | # transpose 33 | rotZ = rearrange(rotZ, "... i j -> ... 
j i") 34 | 35 | assert joints.shape[-1] == 2 36 | vel = torch.diff(joints, dim=-2) 37 | # 0 for the first one => keep the dimentionality 38 | vel = torch.cat((0 * vel[..., [0], :], vel), dim=-2) 39 | vel_local = torch.einsum("...kj,...lk->...lj", rotZ[..., :2, :2], vel[..., :2]) 40 | # Integrate the trajectory 41 | joints_local = torch.cumsum(vel_local, dim=-2) 42 | # First frame should be the same as before 43 | joints_local = joints_local - joints_local[..., [0], :] + joints[..., [0], :] 44 | return joints_local 45 | 46 | 47 | def rotate_joints(joints, rotZ, inverse=False): 48 | joints2D = joints[..., :2] 49 | jointsZ = joints[..., 2] 50 | joints2D_local = rotate_joints2D(joints2D, rotZ, inverse=inverse) 51 | joints_local = torch.cat((joints2D_local, jointsZ[..., None]), axis=-1) 52 | return joints_local 53 | 54 | 55 | def canonicalize_rotations(global_orient, trans, angle=0.0): 56 | global_euler = geometry.matrix_to_euler_angles(global_orient, "ZYX") 57 | anglesZ, anglesY, anglesX = torch.unbind(global_euler, -1) 58 | 59 | rotZ = geometry._axis_angle_rotation("Z", anglesZ) 60 | 61 | # remove the current rotation 62 | # make it local 63 | local_trans = rotate_trans(trans, rotZ) 64 | 65 | # For information: 66 | # rotate_joints(joints, rotZ) == joints_local 67 | 68 | diff_mat_rotZ = rotZ[..., 1:, :, :] @ rotZ.transpose(-1, -2)[..., :-1, :, :] 69 | 70 | vel_anglesZ = geometry.matrix_to_axis_angle(diff_mat_rotZ)[..., 2] 71 | # padding "same" 72 | vel_anglesZ = torch.cat((vel_anglesZ[..., [0]], vel_anglesZ), dim=-1) 73 | 74 | # Compute new rotation: 75 | # canonicalized 76 | anglesZ = torch.cumsum(vel_anglesZ, -1) 77 | anglesZ += angle 78 | rotZ = geometry._axis_angle_rotation("Z", anglesZ) 79 | 80 | new_trans = rotate_trans(local_trans, rotZ, inverse=True) 81 | 82 | new_global_euler = torch.stack((anglesZ, anglesY, anglesX), -1) 83 | new_global_orient = geometry.euler_angles_to_matrix(new_global_euler, "ZYX") 84 | 85 | return new_global_orient, new_trans -------------------------------------------------------------------------------- /src/data/tools/smpl.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | import src.tools.geometry as geometry 3 | import torch 4 | from torch import Tensor 5 | from src.tools.easyconvert import matrix_to, axis_angle_to 6 | from src.transforms.smpl import RotTransDatastruct 7 | 8 | 9 | def canonicalize_smplh(poses: Tensor, trans: Optional[Tensor] = None): 10 | bs, nframes, njoints = poses.shape[:3] 11 | 12 | global_orient = poses[:, :, 0] 13 | 14 | # first global rotations 15 | rot2d = geometry.matrix_to_axis_angle(global_orient[:, 0]) 16 | rot2d[:, :2] = 0 # Remove the rotation along the vertical axis 17 | rot2d = geometry.axis_angle_to_matrix(rot2d) 18 | 19 | # Rotate the global rotation to eliminate Z rotations 20 | global_orient = torch.einsum("ikj,imkl->imjl", rot2d, global_orient) 21 | 22 | # Construct canonicalized version of x 23 | xc = torch.cat((global_orient[:, :, None], poses[:, :, 1:]), dim=2) 24 | 25 | if trans is not None: 26 | vel = trans[:, 1:] - trans[:, :-1] 27 | # Turn the translation as well 28 | vel = torch.einsum("ikj,ilk->ilj", rot2d, vel) 29 | trans = torch.cat((torch.zeros(bs, 1, 3, device=vel.device), 30 | torch.cumsum(vel, 1)), 1) 31 | return xc, trans 32 | else: 33 | return xc 34 | 35 | 36 | def smpl_data_to_matrix_and_trans(data, nohands=True): 37 | trans = data['trans'] 38 | nframes = len(trans) 39 | try: 40 | axis_angle_poses = data['poses'] 41 | 
axis_angle_poses = data['poses'].reshape(nframes, -1, 3) 42 | except: 43 | breakpoint() 44 | 45 | if nohands: 46 | axis_angle_poses = axis_angle_poses[:, :22] 47 | 48 | matrix_poses = axis_angle_to("matrix", axis_angle_poses) 49 | 50 | return RotTransDatastruct(rots=matrix_poses, trans=trans) 51 | -------------------------------------------------------------------------------- /src/data/tools/tensors.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict 2 | import torch 3 | from torch import Tensor 4 | 5 | def lengths_to_mask_njoints(lengths: List[int], njoints: int, device: torch.device) -> Tensor: 6 | # joints*lenghts 7 | joints_lengths = [njoints*l for l in lengths] 8 | joints_mask = lengths_to_mask(joints_lengths, device) 9 | return joints_mask 10 | 11 | 12 | def lengths_to_mask(lengths: List[int], device: torch.device) -> Tensor: 13 | lengths = torch.tensor(lengths, device=device) 14 | max_len = max(lengths) 15 | mask = torch.arange(max_len, 16 | device=device).expand(len(lengths), 17 | max_len) < lengths.unsqueeze(1) 18 | return mask 19 | 20 | from copy import copy 21 | import numpy as np 22 | import torch 23 | from torch import Tensor 24 | import logging 25 | import os 26 | import random 27 | from einops import rearrange 28 | 29 | def read_json(file_path): 30 | import json 31 | with open(file_path, 'r') as json_file: 32 | data = json.load(json_file) 33 | return data 34 | 35 | def freeze(model) -> None: 36 | r""" 37 | Freeze all params for inference. 38 | """ 39 | for param in model.parameters(): 40 | param.requires_grad = False 41 | 42 | model.eval() 43 | 44 | # A logger for this file 45 | log = logging.getLogger(__name__) 46 | def to_tensor(array): 47 | if torch.is_tensor(array): 48 | return array 49 | else: 50 | return torch.tensor(array) 51 | 52 | def DotDict(in_dict): 53 | if isinstance(in_dict, dotdict): 54 | return in_dict 55 | out_dict = copy(in_dict) 56 | for k,v in out_dict.items(): 57 | if isinstance(v,dict): 58 | out_dict[k] = DotDict(v) 59 | return dotdict(out_dict) 60 | 61 | 62 | def dict_to_device(tensor_dict, device): 63 | return {k: v.to(device) for k, v in tensor_dict.items()} 64 | 65 | class dotdict(dict): 66 | """dot.notation access to dictionary attributes""" 67 | __getattr__ = dict.get 68 | __setattr__ = dict.__setitem__ 69 | __delattr__ = dict.__delitem__ 70 | 71 | def cast_dict_to_tensors(d, device="cpu"): 72 | if isinstance(d, dict): 73 | return {k: cast_dict_to_tensors(v, device) for k, v in d.items()} 74 | elif isinstance(d, np.ndarray): 75 | return torch.from_numpy(d).float().to(device) 76 | elif isinstance(d, torch.Tensor): 77 | return d.to(device) 78 | else: 79 | return d -------------------------------------------------------------------------------- /src/data/tools/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | # from matplotlib import collections as mc 3 | # from matplotlib.pyplot import cm 4 | # import matplotlib.pyplot as plt 5 | from typing import Dict, List, Optional, Tuple 6 | 7 | def separate_actions(pair: Tuple[Tuple]): 8 | 9 | if len(pair) == 3: 10 | if pair[0][1] < pair[2][0]: 11 | # a1 10, 15 t 14, 18 a2 17, 25 12 | # a1 10, 15 t 16, 16 a2 17, 25 13 | # transition only --> transition does not matter 14 | final_pair = [(pair[0][0], pair[1][0]), 15 | (pair[1][0] + 1, pair[2][0] - 1), 16 | (pair[2][0], pair[2][1])] 17 | else: 18 | # overlap + transition --> transition does not matter 19 | over = pair[2][0] - 
pair[0][1] 20 | final_pair = [(pair[0][0], int(pair[0][1] + over/2)), 21 | (int(pair[0][1] + over/2 + 1), pair[2][1])] 22 | else: 23 | # give based on small or long 24 | # p1_prop = (pair[0][1] - pair[0][0]) / (eoq - soq) 25 | # p2_prop = (pair[1][1] - pair[1][0]) / (eoq - soq) 26 | # over = pair[1][0] - pair[0][1] 27 | # final_pair = [(pair[0][0], int(p1_prop*over) + pair[0][1]), 28 | # (int(p1_prop*over) + pair[0][1] + 1, pair[1][1])] 29 | 30 | # no transition at all 31 | over = pair[0][1] - pair[1][0] 32 | final_pair = [(pair[0][0], int(pair[0][1] + over/2)), 33 | (int(pair[0][1] + over/2 + 1), pair[1][1])] 34 | 35 | return final_pair 36 | 37 | def timeline_overlaps(interval: Tuple, interval_list: List[Tuple]) -> Tuple[List[Tuple], 38 | List[Tuple], 39 | List[Tuple], 40 | List[Tuple]]: 41 | ''' 42 | Returns the intervals for which: 43 | (1) arr1 has overlap with 44 | (2) arr1 is a subset of 45 | (3) arr1 is a superset of 46 | ''' 47 | l = interval[0] 48 | r = interval[1] 49 | inter_sub = [] 50 | inter_super = [] 51 | inter_before = [] 52 | inter_after = [] 53 | for s in interval_list: 54 | 55 | if (s[0] > l and s[0] > r) or (s[1] < l and s[1] < r): 56 | continue 57 | if s[0] <= l and s[1] >= r: 58 | inter_sub.append(s) 59 | if s[0] >= l and s[1] <= r: 60 | inter_super.append(s) 61 | if s[0] < l and s[1] < r and s[1] >= l: 62 | inter_before.append(s) 63 | if s[0] > l and s[0] <= r and s[1] > r: 64 | inter_after.append(s) 65 | 66 | return inter_before, inter_after 67 | 68 | def segments_sorted(segs_fr: List[List], acts: List) -> Tuple[List[List], List]: 69 | 70 | assert len(segs_fr) == len(acts) 71 | if len(segs_fr) == 1: return segs_fr, acts 72 | L = [ (segs_fr[i],i) for i in range(len(segs_fr)) ] 73 | L.sort() 74 | sorted_segs_fr, permutation = zip(*L) 75 | sort_acts = [acts[i] for i in permutation] 76 | 77 | return list(sorted_segs_fr), sort_acts 78 | 79 | 80 | # def plot_timeline(segments, babel_id, outdir=get_original_cwd(), accel=None): 81 | 82 | # seg_ids = [(s_s, s_e) for s_s, s_e, _ in segments] 83 | # seg_acts = [f'{a}\n{s_s}|---|{s_e}' for s_s, s_e, a in segments] 84 | # seg_lns = [ [(x[0], i*0.01), (x[1], i*0.01)] for i, x in enumerate(seg_ids) ] 85 | # colorline = cm.rainbow(np.linsinc.0, 1, len(seg_acts))) 86 | # lc = mc.LineCollection(seg_lns, colors=colorline, linewidths=3, 87 | # label=seg_acts) 88 | # fig, ax = plt.subplots() 89 | 90 | # ax.add_collection(lc) 91 | # fig.tight_layout() 92 | # ax.autoscale() 93 | # ax.margins(0.1) 94 | # # alternative for putting text there 95 | # # from matplotlib.lines import Line2D 96 | # # proxies = [ Line2D([0, 1], [0, 1], color=x) for x in colorline] 97 | # # ax.legend(proxies, seg_acts, fontsize='x-small', loc='upper left') 98 | # for i, a in enumerate(seg_acts): 99 | # plt.text((seg_ids[i][0]+seg_ids[i][1])/2, i*0.01 - 0.002, a, 100 | # fontsize='x-small', ha='center') 101 | # if accel is not None: 102 | # plt.plot(accel) 103 | # plt.title(f'Babel Sequence ID\n{babel_id}') 104 | # plt.savefig(f'{outdir}/plot_{babel_id}.png') 105 | # plt.close() 106 | -------------------------------------------------------------------------------- /src/diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: 
https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | from . import gaussian_diffusion as gd 7 | from .respace import SpacedDiffusion, space_timesteps 8 | 9 | 10 | def create_diffusion( 11 | timestep_respacing, 12 | noise_schedule="linear", 13 | use_kl=False, 14 | sigma_small=False, 15 | predict_xstart=False, 16 | learn_sigma=True, 17 | rescale_learned_sigmas=False, 18 | diffusion_steps=1000 19 | ): 20 | betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps) 21 | if use_kl: 22 | loss_type = gd.LossType.RESCALED_KL 23 | elif rescale_learned_sigmas: 24 | loss_type = gd.LossType.RESCALED_MSE 25 | else: 26 | loss_type = gd.LossType.MSE 27 | if timestep_respacing is None or timestep_respacing == "": 28 | timestep_respacing = [diffusion_steps] 29 | return SpacedDiffusion( 30 | use_timesteps=space_timesteps(diffusion_steps, timestep_respacing), 31 | betas=betas, 32 | model_mean_type=( 33 | gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X 34 | ), 35 | model_var_type=( 36 | ( 37 | gd.ModelVarType.FIXED_LARGE 38 | if not sigma_small 39 | else gd.ModelVarType.FIXED_SMALL 40 | ) 41 | if not learn_sigma 42 | else gd.ModelVarType.LEARNED_RANGE 43 | ), 44 | loss_type=loss_type 45 | # rescale_timesteps=rescale_timesteps, 46 | ) 47 | -------------------------------------------------------------------------------- /src/diffusion/diffusion_utils.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | import torch as th 7 | import numpy as np 8 | 9 | 10 | def normal_kl(mean1, logvar1, mean2, logvar2): 11 | """ 12 | Compute the KL divergence between two gaussians. 13 | Shapes are automatically broadcasted, so batches can be compared to 14 | scalars, among other use cases. 15 | """ 16 | tensor = None 17 | for obj in (mean1, logvar1, mean2, logvar2): 18 | if isinstance(obj, th.Tensor): 19 | tensor = obj 20 | break 21 | assert tensor is not None, "at least one argument must be a Tensor" 22 | 23 | # Force variances to be Tensors. Broadcasting helps convert scalars to 24 | # Tensors, but it does not work for th.exp(). 25 | logvar1, logvar2 = [ 26 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) 27 | for x in (logvar1, logvar2) 28 | ] 29 | 30 | return 0.5 * ( 31 | -1.0 32 | + logvar2 33 | - logvar1 34 | + th.exp(logvar1 - logvar2) 35 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2) 36 | ) 37 | 38 | 39 | def approx_standard_normal_cdf(x): 40 | """ 41 | A fast approximation of the cumulative distribution function of the 42 | standard normal. 43 | """ 44 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) 45 | 46 | 47 | def continuous_gaussian_log_likelihood(x, *, means, log_scales): 48 | """ 49 | Compute the log-likelihood of a continuous Gaussian distribution. 50 | :param x: the targets 51 | :param means: the Gaussian mean Tensor. 52 | :param log_scales: the Gaussian log stddev Tensor. 53 | :return: a tensor like x of log probabilities (in nats). 
54 | """ 55 | centered_x = x - means 56 | inv_stdv = th.exp(-log_scales) 57 | normalized_x = centered_x * inv_stdv 58 | log_probs = th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(normalized_x) 59 | return log_probs 60 | 61 | 62 | def discretized_gaussian_log_likelihood(x, *, means, log_scales): 63 | """ 64 | Compute the log-likelihood of a Gaussian distribution discretizing to a 65 | given image. 66 | :param x: the target images. It is assumed that this was uint8 values, 67 | rescaled to the range [-1, 1]. 68 | :param means: the Gaussian mean Tensor. 69 | :param log_scales: the Gaussian log stddev Tensor. 70 | :return: a tensor like x of log probabilities (in nats). 71 | """ 72 | assert x.shape == means.shape == log_scales.shape 73 | centered_x = x - means 74 | inv_stdv = th.exp(-log_scales) 75 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0) 76 | cdf_plus = approx_standard_normal_cdf(plus_in) 77 | min_in = inv_stdv * (centered_x - 1.0 / 255.0) 78 | cdf_min = approx_standard_normal_cdf(min_in) 79 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) 80 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) 81 | cdf_delta = cdf_plus - cdf_min 82 | log_probs = th.where( 83 | x < -0.999, 84 | log_cdf_plus, 85 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), 86 | ) 87 | assert log_probs.shape == x.shape 88 | return log_probs 89 | -------------------------------------------------------------------------------- /src/launch/blender.py: -------------------------------------------------------------------------------- 1 | # Fix blender path 2 | import sys 3 | # sys.path.append("/linkhome/rech/genlgm01/uwm78rj/.local/lib/python3.9/site-packages") 4 | 5 | import bpy 6 | import os 7 | DIR = os.path.dirname(bpy.data.filepath) 8 | if DIR not in sys.path: 9 | sys.path.append(DIR) 10 | 11 | from argparse import ArgumentParser 12 | 13 | # Workaround temorary for cluster vs local 14 | # TODO fix it 15 | import socket 16 | if socket.gethostname() == 'ps018': 17 | packages_path = '/home/nathanasiou/.local/lib/python3.10/site-packages' 18 | sys.path.insert(0, packages_path) 19 | sys.path.append("/home/nathanasiou/.venvs/teach/lib/python3.10/site-packages") 20 | sys.path.append('/usr/lib/python3/dist-packages') 21 | 22 | # Monkey patch argparse such that 23 | # blender / python / hydra parsing works 24 | def parse_args(self, args=None, namesinc.None): 25 | if args is not None: 26 | return self.parse_args_bak(args=args, namesinc.namesinc. 27 | try: 28 | idx = sys.argv.index("--") 29 | args = sys.argv[idx+1:] # the list after '--' 30 | except ValueError as e: # '--' not in the list: 31 | args = [] 32 | return self.parse_args_bak(args=args, namesinc.namesinc. 
33 | 34 | setattr(ArgumentParser, 'parse_args_bak', ArgumentParser.parse_args) 35 | setattr(ArgumentParser, 'parse_args', parse_args) 36 | -------------------------------------------------------------------------------- /src/launch/prepare.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | # os.environ['HOME']='/home/nathanasiou' 4 | # sys.path.insert(0,'/usr/lib/python3.10/') 5 | # os.environ['PYTHONPATH']='/home/nathanasiou/.venvs/teach/lib/python3.10/site-packages' 6 | import warnings 7 | from pathlib import Path 8 | from omegaconf import OmegaConf 9 | from src.tools.runid import generate_id 10 | import hydra 11 | 12 | # Local paths 13 | def code_path(path=""): 14 | code_dir = hydra.utils.get_original_cwd() 15 | code_dir = Path(code_dir) 16 | return str(code_dir / path) 17 | 18 | def get_local_debug(): 19 | import socket 20 | hostname = socket.gethostname() 21 | if hostname == 'ps018': 22 | local_debug = True 23 | else: 24 | local_debug = False 25 | return local_debug 26 | 27 | def working_path(path): 28 | return str(Path(os.getcwd()) / path) 29 | 30 | 31 | # fix the id for this run 32 | ID = generate_id() 33 | def generate_id(): 34 | return ID 35 | 36 | def concat_string_list(l, d1, d2, d3): 37 | """ 38 | Concatenate the strings of a list in a sorted order 39 | """ 40 | # import ipdb; ipdb.set_trace() 41 | if d1 == 0: 42 | if 'hml3d' in l: 43 | l.remove('hml3d') 44 | if d2 == 0: 45 | if 'motionfix' in l: 46 | l.remove('motionfix') 47 | if d3 == 0: 48 | if 'sinc_synth' in l: 49 | l.remove('sinc_synth') 50 | 51 | return '_'.join(sorted(l)) 52 | 53 | def get_last_checkpoint(path, ckpt_name="last"): 54 | if path is None: 55 | return None 56 | output_dir = Path(hydra.utils.to_absolute_path(path)) 57 | if ckpt_name != 'last': 58 | last_ckpt_path = output_dir / "checkpoints" / f'latest-epoch={ckpt_name}.ckpt' 59 | else: 60 | last_ckpt_path = output_dir / "checkpoints/last.ckpt" 61 | # check exsitence 62 | if last_ckpt_path.exists(): 63 | return str(last_ckpt_path) 64 | else: 65 | return None 66 | 67 | def get_samples_folder(path): 68 | output_dir = Path(hydra.utils.to_absolute_path(path)) 69 | samples_path = output_dir / "samples" 70 | return str(samples_path) 71 | 72 | def get_expdir(debug): 73 | if debug: 74 | return 'experiments' 75 | else: 76 | return 'experiments' 77 | 78 | def get_debug(debug): 79 | if debug: 80 | return '_debug' 81 | else: 82 | return '' 83 | 84 | # this has to run -- pytorch memory leak in the dataloader associated with #973 pytorch issues 85 | #import resource 86 | #rlimit = resource.getrlimit(resource.RLIMIT_NOFILE) 87 | #resource.setrlimit(resource.RLIMIT_NOFILE, (12000, rlimit[1])) 88 | # Solutions summarized in --> https://github.com/Project-MONAI/MONAI/issues/701 89 | OmegaConf.register_new_resolver("get_debug", get_debug) 90 | OmegaConf.register_new_resolver("get_expdir", get_expdir) 91 | OmegaConf.register_new_resolver("code_path", code_path) 92 | OmegaConf.register_new_resolver("working_path", working_path) 93 | OmegaConf.register_new_resolver("generate_id", generate_id) 94 | OmegaConf.register_new_resolver("concat_string_list", concat_string_list) 95 | OmegaConf.register_new_resolver("absolute_path", hydra.utils.to_absolute_path) 96 | OmegaConf.register_new_resolver("get_last_checkpoint", get_last_checkpoint) 97 | OmegaConf.register_new_resolver("get_samples_folder", get_samples_folder) 98 | 99 | 100 | # Remove warnings 101 | warnings.filterwarnings( 102 | "ignore", ".*Trying to infer the 
`batch_size` from an ambiguous collection.*" 103 | ) 104 | 105 | warnings.filterwarnings( 106 | "ignore", ".*does not have many workers which may be a bottleneck*" 107 | ) 108 | 109 | warnings.filterwarnings( 110 | "ignore", ".*Our suggested max number of worker in current system is*" 111 | ) 112 | -------------------------------------------------------------------------------- /src/logger/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | from .wandb_log import WandbLogger 4 | # from pytorch_lightning.loggers import WandbLogger 5 | from pytorch_lightning.loggers import TensorBoardLogger 6 | from hydra.utils import to_absolute_path 7 | from omegaconf import DictConfig, OmegaConf 8 | from .tools import cfg_to_flatten_config 9 | import types 10 | import wandb 11 | 12 | def instantiate_logger(cfg: DictConfig): 13 | conf = OmegaConf.to_container(cfg.logger, resolve=True) 14 | name = conf.pop('logger_name') 15 | if name is None: 16 | return False 17 | if name == 'wandb': 18 | project_save_dir = to_absolute_path(Path(cfg.path.working_dir) / conf['save_dir']) 19 | Path(cfg.path.working_dir) 20 | Path(project_save_dir).mkdir(exist_ok=True) 21 | conf['dir'] = project_save_dir 22 | conf['config'] = cfg_to_flatten_config(cfg) 23 | # maybe do this for connection error in cluster, could be redundant 24 | 25 | # conf['settings'] = wandb.Settings(start_method="fork") 26 | 27 | # conf['mode']= 'online' if not cfg.logger.offline else 'offline' 28 | conf['notes']= cfg.logger.notes if cfg.logger.notes is not None else None 29 | conf['tags'] = cfg.logger.tags.strip().split(',')\ 30 | if cfg.logger.tags is not None else None 31 | print('init WandbLogger') 32 | logger = WandbLogger(**conf) 33 | print('after init WandbLogger') 34 | 35 | # begin / end already defined 36 | 37 | else: 38 | def begin(self, *args, **kwargs): 39 | return 40 | 41 | def end(self, *args, **kwargs): 42 | return 43 | 44 | if name == 'tensorboard': 45 | # TODO: need to update 46 | # logger = TensorBoardLogger(**conf) 47 | logger = TensorBoardLogger(save_dir = \ 48 | to_absolute_path(Path(cfg.path.working_dir) / conf['save_dir']), name='lightning_logs') 49 | 50 | logger.begin = begin 51 | logger.end = end 52 | else: 53 | raise NotImplementedError("This logger is not recognized.") 54 | 55 | logger.lname = name 56 | return logger 57 | -------------------------------------------------------------------------------- /src/logger/tools.py: -------------------------------------------------------------------------------- 1 | # Taken from pytorch lighting / loggers / base 2 | # Mimic the log hyperparams of wandb 3 | from argparse import Namespace 4 | from typing import Dict, Any, MutableMapping, Union 5 | from omegaconf import DictConfig 6 | 7 | import numpy as np 8 | import torch 9 | 10 | 11 | def _convert_params(params: Union[Dict[str, Any], Namespace]) -> Dict[str, Any]: 12 | # in case converting from namespace 13 | if isinstance(params, Namespace): 14 | params = vars(params) 15 | 16 | if params is None: 17 | params = {} 18 | 19 | return params 20 | 21 | 22 | def _flatten_dict(params: Dict[Any, Any], delimiter: str = "/") -> Dict[str, Any]: 23 | """Flatten hierarchical dict, e.g. ``{'a': {'b': 'c'}} -> {'a/b': 'c'}``. 24 | 25 | Args: 26 | params: Dictionary containing the hyperparameters 27 | delimiter: Delimiter to express the hierarchy. Defaults to ``'/'``. 28 | 29 | Returns: 30 | Flattened dict. 
31 | 32 | Examples: 33 | >>> _flatten_dict({'a': {'b': 'c'}}) 34 | {'a/b': 'c'} 35 | >>> _flatten_dict({'a': {'b': 123}}) 36 | {'a/b': 123} 37 | >>> _flatten_dict({5: {'a': 123}}) 38 | {'5/a': 123} 39 | """ 40 | 41 | def _dict_generator(input_dict, prefixes=None): 42 | prefixes = prefixes[:] if prefixes else [] 43 | if isinstance(input_dict, MutableMapping): 44 | for key, value in input_dict.items(): 45 | key = str(key) 46 | if isinstance(value, (MutableMapping, Namespace)): 47 | value = vars(value) if isinstance(value, Namespace) else value 48 | yield from _dict_generator(value, prefixes + [key]) 49 | else: 50 | yield prefixes + [key, value if value is not None else str(None)] 51 | else: 52 | yield prefixes + [input_dict if input_dict is None else str(input_dict)] 53 | 54 | return {delimiter.join(keys): val for *keys, val in _dict_generator(params)} 55 | 56 | 57 | def _sanitize_params(params: Dict[str, Any]) -> Dict[str, Any]: 58 | """Returns params with non-primitives converted to strings for logging. 59 | 60 | >>> params = {"float": 0.3, 61 | ... "int": 1, 62 | ... "string": "abc", 63 | ... "bool": True, 64 | ... "list": [1, 2, 3], 65 | ... "namespace": Namespace(foo=3), 66 | ... "layer": torch.nn.BatchNorm1d} 67 | >>> import pprint 68 | >>> pprint.pprint(_sanitize_params(params)) # doctest: +NORMALIZE_WHITESPACE 69 | {'bool': True, 70 | 'float': 0.3, 71 | 'int': 1, 72 | 'layer': "", 73 | 'list': '[1, 2, 3]', 74 | 'namespace': 'Namespace(foo=3)', 75 | 'string': 'abc'} 76 | """ 77 | for k in params.keys(): 78 | # convert relevant np scalars to python types first (instead of str) 79 | if isinstance(params[k], (np.bool_, np.integer, np.floating)): 80 | params[k] = params[k].item() 81 | elif type(params[k]) not in [bool, int, float, str, torch.Tensor]: 82 | params[k] = str(params[k]) 83 | return params 84 | 85 | 86 | def cfg_to_flatten_config(cfg: DictConfig): 87 | params = _convert_params(cfg) 88 | params = _flatten_dict(params) 89 | params = _sanitize_params(params) 90 | return params 91 | -------------------------------------------------------------------------------- /src/logger/wandb_log.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional 2 | from pytorch_lightning.utilities import rank_zero_only 3 | from pytorch_lightning.loggers import WandbLogger as _pl_WandbLogger 4 | import os 5 | from pathlib import Path 6 | import time 7 | 8 | # Fix the step logging 9 | class WandbLogger(_pl_WandbLogger): 10 | # @rank_zero_only 11 | # def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None: 12 | 13 | # assert rank_zero_only.rank == 0, "experiment tried to log from global_rank != 0" 14 | # metrics = _add_prefix(metrics, self._prefix, self.LOGGER_JOIN_CHAR) 15 | 16 | # # if 'epoch' not in metrics: 17 | # # self.experiment.log({**metrics, "trainer/global_step": step}, 18 | # # step=step) 19 | # # else: 20 | 21 | # if 'epoch' not in metrics: 22 | # wandb_step = self.experiment.step 23 | # else: 24 | # wandb_step = int(metrics["epoch"]) 25 | 26 | # if step is not None: 27 | # self.experiment.log({**metrics, "trainer/global_step": step}, 28 | # step=wandb_step) 29 | # else: 30 | # self.experiment.log(metrics, step=wandb_step) 31 | 32 | # @property 33 | def name(self) -> Optional[str]: 34 | """ Override the method because model checkpointing defines the path before 35 | initialization, and in offline mode you can't get the correct path 36 | """ 37 | # don't create an experiment if we don't have one 38 | #
return self._experiment.project_name() if self._experiment else self._name 39 | return self._wandb_init["project"] 40 | 41 | def symlink_checkpoint(self, code_dir, project, run_id): 42 | # this is the hydra run dir!! see train.yaml 43 | local_project_dir = Path("wandb") / project 44 | local_project_dir.mkdir(parents=True, exist_ok=True) 45 | 46 | # ... but code_dir is the current dir see path.yaml 47 | Path(code_dir) / project / run_id 48 | os.symlink(Path(code_dir) / "wandb" / project / run_id, 49 | local_project_dir / f'{run_id}_{time.strftime("%Y%m%d%H%M%S")}') 50 | # # Creating a another symlink for easy access 51 | os.symlink(Path(code_dir) / "wandb" / project / run_id / "checkpoints", 52 | Path("checkpoints")) 53 | # if it exists an error is spawned which makes sense, but ... 54 | 55 | def symlink_run(self, checkpoint_folder: str): 56 | 57 | code_dir = checkpoint_folder.split("wandb/")[0] 58 | # # local run 59 | local_wandb = Path("wandb/wandb") 60 | local_wandb.mkdir(parents=True, exist_ok=True) 61 | offline_run = self.experiment.dir.split("wandb/wandb/")[1].split("/files")[0] 62 | # # Create the symlink 63 | os.symlink(Path(code_dir) / "wandb/wandb" / offline_run, local_wandb / offline_run) 64 | 65 | def begin(self, code_dir, project, run_id): 66 | self.symlink_checkpoint(code_dir, project, run_id) 67 | 68 | def end(self, checkpoint_folder): 69 | self.symlink_run(checkpoint_folder) 70 | -------------------------------------------------------------------------------- /src/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lzhyu/SimMotionEdit/384135a8707bc247fba23005d57cca7ca2d751ce/src/model/__init__.py -------------------------------------------------------------------------------- /src/model/dummy.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | import torch 3 | from torch import nn 4 | 5 | from src.data.tools import PoseData 6 | 7 | 8 | class Dummy(pl.LightningModule): 9 | def __init__(self, *args, **kwargs): 10 | super().__init__() 11 | self.linear = nn.Linear(10, 10) 12 | self.relu = nn.ReLU() 13 | self.store_examples = {"train": None, 14 | "val": None, 15 | "test": None} 16 | 17 | def forward(self, batch: dict) -> PoseData: 18 | return batch["pose_data"] 19 | 20 | def allsplit_step(self, split: str, batch, batch_idx): 21 | joints = batch["pose_data"].joints 22 | if batch_idx == 0: 23 | self.store_examples[split] = { 24 | "text": batch["text"], 25 | "length": batch["length"], 26 | "ref": joints, 27 | "from_text": joints, 28 | "from_motion": joints 29 | } 30 | 31 | x = batch["pose_data"].poses 32 | x = torch.rand((5, 10), device=x.device) 33 | xhat = self.linear(x) 34 | 35 | loss = torch.nn.functional.mse_loss(x, xhat) 36 | return loss 37 | 38 | def training_step(self, batch, batch_idx): 39 | return self.allsplit_step("train", batch, batch_idx) 40 | 41 | def validation_step(self, batch, batch_idx): 42 | self.allsplit_step("val", batch, batch_idx) 43 | 44 | def test_step(self, batch, batch_idx): 45 | self.allsplit_step("test", batch, batch_idx) 46 | 47 | def configure_optimizers(self): 48 | optimizer = torch.optim.AdamW(lr=1e-4, params=self.parameters()) 49 | return {"optimizer": optimizer} -------------------------------------------------------------------------------- /src/model/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from .compute_mld import MLDLosses 
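A minimal usage sketch (not part of the repository) for the MLDLosses metric exported above; it assumes the constructor and update()/compute() signatures defined in compute_mld.py further down this listing, and the tensor shapes are illustrative only:

import torch
from src.model.losses import MLDLosses

# predict-noise setup; all auxiliary lambdas disabled for brevity
criterion = MLDLosses(predict_epsilon=True, lmd_prior=0.0,
                      lmd_kl=0.0, lmd_gen=0.0, lmd_recons=0.0)

# update() reads 'noise_pred' / 'noise' when predicting epsilon
rs_set = {"noise_pred": torch.randn(8, 64), "noise": torch.randn(8, 64)}
step_loss = criterion.update(rs_set)        # weighted loss used for backprop
epoch_means = criterion.compute()           # running means keyed by loss name
log_key = criterion.loss2logname("inst_loss", "train")   # -> "inst/loss/train"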
-------------------------------------------------------------------------------- /src/model/losses/compute.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | import torch 3 | 4 | from torchmetrics import Metric 5 | 6 | 7 | class TemosComputeLosses(Metric): 8 | def __init__(self, vae: bool, 9 | mode: str, 10 | loss_on_both: bool = False, 11 | motion_branch: bool = False, 12 | loss_on_jfeats: bool = True, 13 | ablation_no_kl_combine: bool = False, 14 | ablation_no_motionencoder: bool = False, 15 | ablation_no_kl_gaussian: bool = False, 16 | dist_sync_on_step=True, **kwargs): 17 | super().__init__(dist_sync_on_step=dist_sync_on_step) 18 | 19 | # Save parameters 20 | self.vae = vae 21 | self.mode = mode 22 | self.motion_branch = motion_branch 23 | 24 | self.loss_on_both = loss_on_both 25 | self.ablation_no_kl_combine = ablation_no_kl_combine 26 | self.ablation_no_kl_gaussian = ablation_no_kl_gaussian 27 | self.ablation_no_motionencoder = ablation_no_motionencoder 28 | 29 | self.loss_on_jfeats = loss_on_jfeats 30 | losses = [] 31 | if mode == "xyz" or loss_on_jfeats: 32 | if motion_branch: 33 | losses.append("recons_jfeats2jfeats") 34 | losses.append("recons_text2jfeats") 35 | if mode == "smpl": 36 | if motion_branch: 37 | losses.append("recons_rfeats2rfeats") 38 | losses.append("recons_text2rfeats") 39 | else: 40 | ValueError("This mode is not recognized.") 41 | 42 | if vae or loss_on_both: 43 | kl_losses = [] 44 | if not ablation_no_kl_combine and not ablation_no_motionencoder: 45 | kl_losses.extend(["kl_text2motion", "kl_motion2text"]) 46 | if not ablation_no_kl_gaussian: 47 | if not motion_branch: 48 | kl_losses.extend(["kl_text"]) 49 | else: 50 | kl_losses.extend(["kl_text", "kl_motion"]) 51 | losses.extend(kl_losses) 52 | if not self.vae or loss_on_both: 53 | if motion_branch: 54 | losses.append("latent_manifold") 55 | losses.append("total") 56 | 57 | for loss in losses: 58 | self.add_state(loss, default=torch.tensor(0.0), dist_reduce_fx="sum") 59 | self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum") 60 | self.losses = losses 61 | 62 | # Instantiate loss functions 63 | self._losses_func = {loss: hydra.utils.instantiate(kwargs[loss + "_func"]) 64 | for loss in losses if loss != "total"} 65 | # Save the lambda parameters 66 | self._params = {loss: kwargs[loss] for loss in losses if loss != "total"} 67 | 68 | def update(self, ds_text=None, ds_motion=None, ds_ref=None, 69 | lat_text=None, lat_motion=None, dis_text=None, 70 | dis_motion=None, dis_ref=None): 71 | total: float = 0.0 72 | 73 | if self.mode == "xyz" or self.loss_on_jfeats: 74 | if self.motion_branch: 75 | total += self._update_loss("recons_jfeats2jfeats", ds_motion.jfeats, ds_ref.jfeats) 76 | total += self._update_loss("recons_text2jfeats", ds_text.jfeats, ds_ref.jfeats) 77 | if self.mode == "smpl": 78 | if self.motion_branch: 79 | total += self._update_loss("recons_rfeats2rfeats", ds_motion.rfeats, ds_ref.rfeats) 80 | total += self._update_loss("recons_text2rfeats", ds_text.rfeats, ds_ref.rfeats) 81 | 82 | if self.vae or self.loss_on_both: 83 | if not self.ablation_no_kl_combine and self.motion_branch: 84 | total += self._update_loss("kl_text2motion", dis_text, dis_motion) 85 | total += self._update_loss("kl_motion2text", dis_motion, dis_text) 86 | if not self.ablation_no_kl_gaussian: 87 | total += self._update_loss("kl_text", dis_text, dis_ref) 88 | if self.motion_branch: 89 | total += self._update_loss("kl_motion", dis_motion, dis_ref) 90 | if not 
self.vae or self.loss_on_both: 91 | if self.motion_branch: 92 | total += self._update_loss("latent_manifold", lat_text, lat_motion) 93 | 94 | self.total += total.detach() 95 | self.count += 1 96 | 97 | return total 98 | 99 | def compute(self, split): 100 | count = getattr(self, "count") 101 | return {loss: getattr(self, loss)/count for loss in self.losses} 102 | 103 | def _update_loss(self, loss: str, outputs, inputs): 104 | # Update the loss 105 | val = self._losses_func[loss](outputs, inputs) 106 | getattr(self, loss).__iadd__(val.detach()) 107 | # Return a weighted sum 108 | weighted_loss = self._params[loss] * val 109 | return weighted_loss 110 | 111 | def loss2logname(self, loss: str, split: str): 112 | if loss == "total": 113 | log_name = f"losses/{loss}/{split}" 114 | else: 115 | loss_type, name = loss.split("_") 116 | log_name = f"losses/{loss_type}/{name}/{split}" 117 | return log_name 118 | -------------------------------------------------------------------------------- /src/model/losses/compute_mld.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from torchmetrics import Metric 5 | 6 | 7 | class MLDLosses(Metric): 8 | """ 9 | MLD Loss 10 | """ 11 | 12 | def __init__(self, predict_epsilon, lmd_prior, 13 | lmd_kl, lmd_gen, lmd_recons, 14 | **kwargs): 15 | super().__init__(dist_sync_on_step=True) 16 | 17 | # Save parameters 18 | # self.vae = vae 19 | self.predict_epsilon = predict_epsilon 20 | self.lmd_prior = lmd_prior 21 | self.lmd_kl = lmd_kl 22 | self.lmd_gen = lmd_gen 23 | self.lmd_recons = lmd_recons 24 | losses = [] 25 | 26 | # diffusion loss 27 | # instance noise loss 28 | losses.append("inst_loss") 29 | losses.append("x_loss") 30 | if lmd_prior != 0.0: 31 | # prior noise loss 32 | losses.append("prior_loss") 33 | 34 | # if self.stage in ['vae', 'vae_diffusion']: 35 | # # reconstruction loss 36 | # losses.append("recons_feature") 37 | # losses.append("recons_verts") 38 | # losses.append("recons_joints") 39 | # losses.append("recons_limb") 40 | 41 | # losses.append("gen_feature") 42 | # losses.append("gen_joints") 43 | 44 | # # KL loss 45 | # losses.append("kl_motion") 46 | 47 | 48 | losses.append("loss") 49 | 50 | for loss in losses: 51 | self.add_state(loss, 52 | default=torch.tensor(0.0), 53 | dist_reduce_fx="sum") 54 | # self.register_buffer(loss, torch.tensor(0.0)) 55 | self.add_state("count", torch.tensor(0), dist_reduce_fx="sum") 56 | self.losses = losses 57 | 58 | self._losses_func = {} 59 | self._params = {} 60 | for loss in losses: 61 | if loss.split('_')[0] == 'inst': 62 | self._losses_func[loss] = nn.MSELoss(reduction='mean') 63 | self._params[loss] = 1 64 | elif loss.split('_')[0] == 'x': 65 | self._losses_func[loss] = nn.MSELoss(reduction='mean') 66 | self._params[loss] = 1 67 | elif loss.split('_')[0] == 'prior': 68 | self._losses_func[loss] = nn.MSELoss(reduction='mean') 69 | self._params[loss] = self.lmd_prior 70 | if loss.split('_')[0] == 'kl': 71 | if self.lmd_kl != 0.0: 72 | self._losses_func[loss] = KLLoss() 73 | self._params[loss] = self.lmd_kl 74 | elif loss.split('_')[0] == 'recons': 75 | self._losses_func[loss] = torch.nn.SmoothL1Loss( 76 | reduction='mean') 77 | self._params[loss] = self.lmd_recons 78 | elif loss.split('_')[0] == 'gen': 79 | self._losses_func[loss] = torch.nn.SmoothL1Loss( 80 | reduction='mean') 81 | self._params[loss] = self.lmd_gen 82 | else: 83 | ValueError("This loss is not recognized.") 84 | 85 | def update(self, rs_set): 
86 | total_loss: float = 0.0 87 | # Compute the losses 88 | # Compute instance loss 89 | 90 | # predict noise 91 | if self.predict_epsilon: 92 | total_loss += self._update_loss("inst_loss", rs_set['noise_pred'], 93 | rs_set['noise']) 94 | # predict x 95 | else: 96 | total_loss += self._update_loss("x_loss", rs_set['pred'], 97 | rs_set['diff_in']) 98 | 99 | if self.lmd_prior != 0.0: 100 | # loss - prior loss 101 | total_loss += self._update_loss("prior_loss", rs_set['noise_prior'], 102 | rs_set['dist_m1']) 103 | 104 | self.loss += total_loss 105 | self.count += 1 106 | 107 | return total_loss 108 | 109 | def compute(self): 110 | count = getattr(self, "count") 111 | return {loss: getattr(self, loss) / count for loss in self.losses} 112 | 113 | def _update_loss(self, loss: str, outputs, inputs): 114 | # Update the loss 115 | val = self._losses_func[loss](outputs, inputs) 116 | getattr(self, loss).__iadd__(val) 117 | # Return a weighted sum 118 | weighted_loss = self._params[loss] * val 119 | return weighted_loss 120 | 121 | def loss2logname(self, loss: str, split: str): 122 | if loss == "loss": 123 | log_name = f"total_{loss}/{split}" 124 | else: 125 | loss_type, name = loss.split("_") 126 | log_name = f"{loss_type}/{name}/{split}" 127 | return log_name 128 | 129 | class KLLoss: 130 | 131 | def __init__(self): 132 | pass 133 | 134 | def __call__(self, q, p): 135 | div = torch.distributions.kl_divergence(q, p) 136 | return div.mean() 137 | 138 | def __repr__(self): 139 | return "KLLoss()" 140 | 141 | 142 | class KLLossMulti: 143 | 144 | def __init__(self): 145 | self.klloss = KLLoss() 146 | 147 | def __call__(self, qlist, plist): 148 | return sum([self.klloss(q, p) for q, p in zip(qlist, plist)]) 149 | 150 | def __repr__(self): 151 | return "KLLossMulti()" 152 | -------------------------------------------------------------------------------- /src/model/losses/kl.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class KLLoss: 5 | def __init__(self): 6 | pass 7 | 8 | def __call__(self, q, p, reduce_fx='mean'): 9 | div = torch.distributions.kl_divergence(q, p) 10 | if reduce_fx == 'mean': 11 | return div.mean() 12 | else: 13 | return div 14 | 15 | def __repr__(self): 16 | return "KLLoss()" 17 | 18 | 19 | class KLLossMulti: 20 | def __init__(self): 21 | self.klloss = KLLoss() 22 | 23 | def __call__(self, qlist, plist): 24 | return sum([self.klloss(q, p) 25 | for q, p in zip(qlist, plist)]) 26 | 27 | def __repr__(self): 28 | return "KLLossMulti()" 29 | -------------------------------------------------------------------------------- /src/model/losses/recons.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn.functional import smooth_l1_loss 3 | 4 | 5 | class Recons: 6 | def __call__(self, input_motion_feats_lst, output_features_lst): 7 | # for x,y in zip(input_motion_feats_lst, output_features_lst): 8 | # print('----------------------') 9 | # print(x.shape, y.shape) 10 | # print(smooth_l1_loss(x, y, reduction="mean").shape) 11 | # print('----------------------') 12 | recons = torch.stack([smooth_l1_loss(x.squeeze(), y.squeeze(), reduction="mean") for x,y in zip(input_motion_feats_lst, 13 | output_features_lst)]).mean() 14 | return recons 15 | 16 | def __repr__(self): 17 | return "Recons()" 18 | -------------------------------------------------------------------------------- /src/model/losses/recons_bp.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn.functional import smooth_l1_loss 3 | 4 | 5 | class ReconsBP: 6 | def __call__(self, input_motion_feats_lst, output_features_lst): 7 | recons = [smooth_l1_loss(x.squeeze(), 8 | y.squeeze(), 9 | reduction='none') for x,y in zip(input_motion_feats_lst, 10 | output_features_lst)] 11 | 12 | recons = torch.stack([bpl.mean((1,2)) for bpl in recons], dim=1) 13 | return recons 14 | 15 | def __repr__(self): 16 | return "ReconsBP()" 17 | 18 | 19 | -------------------------------------------------------------------------------- /src/model/losses/utils.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from torch.nn import Module 4 | 5 | class LossTracker(Module): 6 | def __init__(self, losses): 7 | super().__init__() 8 | self.losses = losses 9 | self.training = False 10 | self.count = 0 11 | for loss in self.losses: 12 | self.register_buffer(loss, torch.tensor(0.0, device='cpu'), persistent=False) 13 | 14 | def reset(self): 15 | self.count = 0 16 | for loss in self.losses: 17 | getattr(self, loss).__imul__(0) 18 | 19 | def update(self, losses_dict): 20 | self.count += 1 21 | for loss_name, loss_val in losses_dict.items(): 22 | getattr(self, loss_name).__iadd__(loss_val) 23 | 24 | def compute(self): 25 | if self.count == 0: 26 | raise ValueError("compute should be called after update") 27 | # compute the mean 28 | return {loss: getattr(self, loss)/self.count for loss in self.losses} 29 | 30 | def loss2logname(self, loss: str, split: str): 31 | if loss == "total": 32 | log_name = f"{loss}/{split}" 33 | else: 34 | 35 | if '_multi' in loss: 36 | if 'bodypart' in loss: 37 | loss_type, name, multi, _ = loss.split("_") 38 | name = f'{name}_multiple_bp' 39 | else: 40 | loss_type, name, multi = loss.split("_") 41 | name = f'{name}_multiple' 42 | else: 43 | loss_type, name = loss.split("_") 44 | log_name = f"{loss_type}/{name}/{split}" 45 | return log_name 46 | -------------------------------------------------------------------------------- /src/model/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | import imp 2 | from .compute import ComputeMetrics -------------------------------------------------------------------------------- /src/model/motiondecoder/__init__.py: -------------------------------------------------------------------------------- 1 | from .actor import ActorAgnosticDecoder 2 | -------------------------------------------------------------------------------- /src/model/motiondecoder/actor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import pytorch_lightning as pl 5 | 6 | from typing import List, Optional 7 | from torch import nn, Tensor 8 | 9 | from src.model.utils import PositionalEncoding 10 | from src.data.tools import lengths_to_mask 11 | from einops import rearrange 12 | 13 | 14 | class ActorAgnosticDecoder(pl.LightningModule): 15 | def __init__(self, nfeats: int, 16 | latent_dim: int = 256, ff_size: int = 1024, 17 | num_layers: int = 4, num_heads: int = 4, 18 | dropout: float = 0.1, 19 | activation: str = "gelu", **kwargs) -> None: 20 | 21 | super().__init__() 22 | self.save_hyperparameters(logger=False) 23 | 24 | output_feats = nfeats 25 | 26 | self.sequence_pos_encoding = PositionalEncoding(latent_dim, dropout) # multi GPU 27 | 28 | 
seq_trans_decoder_layer = nn.TransformerDecoderLayer(d_model=latent_dim, 29 | nhead=num_heads, 30 | dim_feedforward=ff_size, 31 | dropout=dropout, 32 | activation=activation) # for multi GPU 33 | 34 | self.seqTransDecoder = nn.TransformerDecoder(seq_trans_decoder_layer, 35 | num_layers=num_layers) 36 | 37 | self.final_layer = nn.Linear(latent_dim, output_feats) 38 | 39 | def forward(self, z: Tensor, mask: Tensor, mem_masks=None): 40 | latent_dim = z.shape[-1] 41 | bs, nframes = mask.shape 42 | nfeats = self.hparams.nfeats 43 | 44 | # z = z[:, None] # sequence of 1 element for the memory 45 | # separate latents 46 | # torch.cat((z0[:, None], z1[:, None]), 1) 47 | if len(z.shape) > 3: 48 | z = rearrange(z, "bs nz z_len latent_dim -> (nz z_len) bs latent_dim") 49 | else: 50 | z = rearrange(z, "bs z_len latent_dim -> z_len bs latent_dim") 51 | 52 | # Construct time queries 53 | time_queries = torch.zeros(nframes, bs, latent_dim, device=z.device) 54 | time_queries = self.sequence_pos_encoding(time_queries) 55 | 56 | # Pass through the transformer decoder 57 | # with the latent vector for memory 58 | if mem_masks is not None: 59 | mem_masks = ~mem_masks 60 | output = self.seqTransDecoder(tgt=time_queries, memory=z, 61 | tgt_key_padding_mask=~mask, 62 | memory_key_padding_mask=mem_masks) 63 | output = self.final_layer(output) 64 | # zero for padded area 65 | output[~mask.T] = 0 66 | # Pytorch Transformer: [Sequence, Batch size, ...] 67 | feats = output.permute(1, 0, 2) 68 | return feats 69 | -------------------------------------------------------------------------------- /src/model/motionencoder/__init__.py: -------------------------------------------------------------------------------- 1 | from .actor import ActorAgnosticEncoder -------------------------------------------------------------------------------- /src/model/motionencoder/actor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import pytorch_lightning as pl 5 | 6 | from typing import List, Optional, Union 7 | from torch import nn, Tensor 8 | from torch.distributions.distribution import Distribution 9 | 10 | from src.model.utils import PositionalEncoding 11 | from src.data.tools import lengths_to_mask 12 | 13 | 14 | class ActorAgnosticEncoder(nn.Module): 15 | def __init__(self, nfeats: int, 16 | latent_dim: int = 256, ff_size: int = 1024, 17 | num_layers: int = 4, num_heads: int = 4, 18 | dropout: float = 0.1, 19 | activation: str = "gelu", **kwargs) -> None: 20 | super().__init__() 21 | 22 | input_feats = nfeats 23 | self.skel_embedding = nn.Linear(input_feats, latent_dim) 24 | # self.layer_norm = nn.LayerNorm(nfeats) 25 | # Action agnostic: only one set of params 26 | self.emb_token = nn.Parameter(torch.randn(latent_dim)) 27 | 28 | self.sequence_pos_encoding = PositionalEncoding(latent_dim, dropout) # multi-GPU 29 | 30 | seq_trans_encoder_layer = nn.TransformerEncoderLayer(d_model=latent_dim, 31 | nhead=num_heads, 32 | dim_feedforward=ff_size, 33 | dropout=dropout, 34 | activation=activation) # multi-gpu 35 | 36 | self.seqTransEncoder = nn.TransformerEncoder(seq_trans_encoder_layer, 37 | num_layers=num_layers) 38 | 39 | def forward(self, features: Tensor, mask: Tensor) -> Union[Tensor, Distribution]: 40 | in_mask = mask 41 | device = features.device 42 | 43 | nframes, bs, nfeats = features.shape 44 | 45 | x = features 46 | # Embed each human poses into latent vectors 47 | # x = self.layer_norm(x) 48 | x = self.skel_embedding(x) 
49 | # Switch sequence and batch_size because the input of 50 | # Pytorch Transformer is [Sequence, Batch size, ...] 51 | # x = x.permute(1, 0, 2) # now it is [nframes, bs, latent_dim] 52 | # Each batch has its own set of tokens 53 | emb_token = torch.tile(self.emb_token, (bs,)).reshape(bs, -1) 54 | 55 | # adding the embedding token for all sequences 56 | xseq = torch.cat((emb_token[None], x), 0) 57 | 58 | # create a bigger mask, to allow attend to emb 59 | token_mask = torch.ones((bs, 1), dtype=bool, device=x.device) 60 | aug_mask = torch.cat((token_mask, in_mask), 1) 61 | # add positional encoding 62 | xseq = self.sequence_pos_encoding(xseq) 63 | final = self.seqTransEncoder(xseq, src_key_padding_mask=~aug_mask) 64 | 65 | return final[0] 66 | -------------------------------------------------------------------------------- /src/model/readme.txt: -------------------------------------------------------------------------------- 1 | denoiser models 2 | phase encoding 3 | length input -------------------------------------------------------------------------------- /src/model/textencoder/__init__.py: -------------------------------------------------------------------------------- 1 | from .distilbert_encoder import DistilbertEncoderTransformer 2 | from .clip_encoder import ClipTextEncoder 3 | from .t5_encoder import T5TextEncoder -------------------------------------------------------------------------------- /src/model/textencoder/clip_encoder.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List, Union 3 | 4 | import torch 5 | from torch import Tensor, nn 6 | from torch.distributions.distribution import Distribution 7 | from src.utils.file_io import hack_path 8 | import pytorch_lightning as pl 9 | 10 | class ClipTextEncoder(pl.LightningModule): 11 | 12 | def __init__( 13 | self, 14 | modelpath: str, 15 | finetune: bool = False, 16 | last_hidden_state: bool = True, 17 | **kwargs 18 | ) -> None: 19 | 20 | super().__init__() 21 | self.save_hyperparameters(logger=False) 22 | from transformers import logging 23 | from transformers import AutoModel, AutoTokenizer 24 | logging.set_verbosity_error() 25 | # Tokenizer 26 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 27 | 28 | self.tokenizer = AutoTokenizer.from_pretrained(hack_path(modelpath)) 29 | self.text_model = AutoModel.from_pretrained(hack_path(modelpath)) 30 | 31 | # Don't train the model 32 | if not finetune: 33 | self.text_model.training = False 34 | for p in self.text_model.parameters(): 35 | p.requires_grad = False 36 | 37 | # Then configure the model 38 | self.max_length = self.tokenizer.model_max_length 39 | if "clip" in modelpath: 40 | self.text_encoded_dim = self.text_model.config.text_config.hidden_size 41 | if last_hidden_state: 42 | self.variant = "clip_hidden" 43 | else: 44 | self.variant = "clip" 45 | elif "bert" in modelpath: 46 | self.variant = "bert" 47 | self.text_encoded_dim = self.text_model.config.hidden_size 48 | else: 49 | raise ValueError(f"Model {modelpath} not supported") 50 | 51 | def forward(self, texts: List[str]): 52 | # get prompt text embeddings 53 | if self.variant in ["clip", "clip_hidden"]: 54 | text_inputs = self.tokenizer( 55 | texts, 56 | padding="max_length", 57 | truncation=True, 58 | max_length=self.max_length, 59 | return_tensors="pt", 60 | ) 61 | text_input_ids = text_inputs.input_ids.to(self.text_model.device) 62 | txt_att_mask = text_inputs.attention_mask.to(self.text_model.device) 63 | # split into max length Clip can handle 
64 | if text_input_ids.shape[-1] > self.tokenizer.model_max_length: 65 | text_input_ids = text_input_ids[:, :self.tokenizer. 66 | model_max_length] 67 | elif self.variant == "bert": 68 | text_inputs = self.tokenizer(texts, 69 | return_tensors="pt", 70 | padding=True) 71 | 72 | # use pooled ouuput if latent dim is two-dimensional 73 | # pooled = 0 if self.latent_dim[0] == 1 else 1 # (bs, seq_len, text_encoded_dim) -> (bs, text_encoded_dim) 74 | # text encoder forward, clip must use get_text_features 75 | # TODO check the CLIP network 76 | if self.variant == "clip": 77 | # (batch_Size, text_encoded_dim) 78 | text_embeddings = self.text_model.get_text_features( 79 | text_input_ids.to(self.text_model.device)) 80 | # (batch_Size, 1, text_encoded_dim) 81 | text_embeddings = text_embeddings.unsqueeze(1) 82 | elif self.variant == "clip_hidden": 83 | # (batch_Size, seq_length , text_encoded_dim) 84 | text_embeddings = self.text_model.text_model(text_input_ids, 85 | # attention_mask=txt_att_mask 86 | ).last_hidden_state 87 | elif self.variant == "bert": 88 | # (batch_Size, seq_length , text_encoded_dim) 89 | text_embeddings = self.text_model( 90 | **text_inputs.to(self.text_model.device)).last_hidden_state 91 | else: 92 | raise NotImplementedError(f"Model {self.name} not implemented") 93 | 94 | return text_embeddings, txt_att_mask.bool() 95 | -------------------------------------------------------------------------------- /src/model/textencoder/distilbert.py: -------------------------------------------------------------------------------- 1 | from typing import List, Union 2 | import pytorch_lightning as pl 3 | 4 | import torch.nn as nn 5 | import os 6 | 7 | import torch 8 | from torch import Tensor 9 | from torch.distributions.distribution import Distribution 10 | 11 | 12 | class DistilbertEncoderBase(pl.LightningModule): 13 | def __init__(self, modelpath: str, 14 | finetune: bool = False) -> None: 15 | super().__init__() 16 | 17 | from transformers import AutoTokenizer, AutoModel 18 | from transformers import logging 19 | logging.set_verbosity_error() 20 | # Tokenizer 21 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 22 | self.tokenizer = AutoTokenizer.from_pretrained(modelpath) 23 | 24 | # Text model 25 | self.text_model = AutoModel.from_pretrained(modelpath) 26 | # Don't train the model 27 | if not finetune: 28 | self.text_model.training = False 29 | for p in self.text_model.parameters(): 30 | p.requires_grad = False 31 | 32 | # Then configure the model 33 | self.text_encoded_dim = self.text_model.config.hidden_size 34 | 35 | def train(self, mode: bool = True): 36 | self.training = mode 37 | for module in self.children(): 38 | # Don't put the model in 39 | if module == self.text_model and not self.hparams.finetune: 40 | continue 41 | module.train(mode) 42 | return self 43 | 44 | def get_last_hidden_state(self, texts: List[str], 45 | return_mask: bool = False 46 | ) -> Union[Tensor, tuple[Tensor, Tensor]]: 47 | encoded_inputs = self.tokenizer(texts, return_tensors="pt", padding=True) 48 | output = self.text_model(**encoded_inputs.to(self.text_model.device)) 49 | if not return_mask: 50 | return output.last_hidden_state 51 | return output.last_hidden_state, encoded_inputs.attention_mask.to(dtype=bool) -------------------------------------------------------------------------------- /src/model/textencoder/distilbert_encoder.py: -------------------------------------------------------------------------------- 1 | from .distilbert import DistilbertEncoderBase 2 | import torch 3 | 4 | from typing 
import List, Union 5 | from torch import nn, Tensor 6 | from torch.distributions.distribution import Distribution 7 | 8 | from src.model.utils import PositionalEncoding 9 | from src.data.tools import lengths_to_mask 10 | from src.utils.file_io import hack_path 11 | 12 | class DistilbertEncoderTransformer(DistilbertEncoderBase): 13 | def __init__(self, modelpath: str, 14 | finetune: bool = False, 15 | vae: bool = False, 16 | latent_dim: int = 256, 17 | ff_size: int = 1024, 18 | num_layers: int = 4, num_heads: int = 4, 19 | dropout: float = 0.1, 20 | activation: str = "gelu", **kwargs) -> None: 21 | super().__init__(modelpath=hack_path(modelpath), finetune=finetune) 22 | self.save_hyperparameters(logger=False) 23 | encoded_dim = self.text_encoded_dim 24 | latent_dim = latent_dim 25 | # Projection of the text-outputs into the latent space 26 | self.projection = nn.Sequential(nn.ReLU(), 27 | nn.Linear(encoded_dim, latent_dim)) 28 | 29 | # TransformerVAE adapted from ACTOR 30 | # Action agnostic: only one set of params 31 | if vae: 32 | self.mu_token = nn.Parameter(torch.randn(latent_dim)) 33 | self.logvar_token = nn.Parameter(torch.randn(latent_dim)) 34 | else: 35 | self.emb_token = nn.Parameter(torch.randn(latent_dim)) 36 | 37 | self.sequence_pos_encoding = PositionalEncoding(latent_dim, dropout) 38 | 39 | seq_trans_encoder_layer = nn.TransformerEncoderLayer(d_model=latent_dim, 40 | nhead=num_heads, 41 | dim_feedforward=ff_size, 42 | dropout=dropout, 43 | activation=activation) 44 | 45 | self.seqTransEncoder = nn.TransformerEncoder(seq_trans_encoder_layer, 46 | num_layers=num_layers) 47 | 48 | def forward(self, texts: List[str]) -> Union[Tensor, Distribution]: 49 | text_encoded, mask = self.get_last_hidden_state(texts, return_mask=True) 50 | 51 | x = self.projection(text_encoded) 52 | bs, nframes, _ = x.shape 53 | # bs, nframes, totjoints, nfeats = x.shape 54 | # Switch sequence and batch_size because the input of 55 | # Pytorch Transformer is [Sequence, Batch size, ...] 
56 | x = x.permute(1, 0, 2) # now it is [nframes, bs, latent_dim] 57 | 58 | if self.hparams.vae: 59 | mu_token = torch.tile(self.mu_token, (bs,)).reshape(bs, -1) 60 | logvar_token = torch.tile(self.logvar_token, (bs,)).reshape(bs, -1) 61 | 62 | # adding the distribution tokens for all sequences 63 | xseq = torch.cat((mu_token[None], logvar_token[None], x), 0) 64 | 65 | # create a bigger mask, to allow attend to mu and logvar 66 | token_mask = torch.ones((bs, 2), dtype=bool, device=x.device) 67 | aug_mask = torch.cat((token_mask, mask), 1) 68 | else: 69 | emb_token = torch.tile(self.emb_token, (bs,)).reshape(bs, -1) 70 | 71 | # adding the embedding token for all sequences 72 | xseq = torch.cat((emb_token[None], x), 0) 73 | 74 | # create a bigger mask, to allow attend to emb 75 | token_mask = torch.ones((bs, 1), dtype=bool, device=x.device) 76 | aug_mask = torch.cat((token_mask, mask), 1) 77 | 78 | # add positional encoding 79 | xseq = self.sequence_pos_encoding(xseq) 80 | final = self.seqTransEncoder(xseq, src_key_padding_mask=~aug_mask) 81 | 82 | if self.hparams.vae: 83 | mu, logvar = final[0], final[1] 84 | std = logvar.exp().pow(0.5) 85 | # https://github.com/kampta/pytorch-distributions/blob/master/gaussian_vae.py 86 | try: 87 | dist = torch.distributions.normal.Normal(mu, std) 88 | except ValueError: 89 | import ipdb; ipdb.set_trace() # noqa 90 | pass 91 | return dist 92 | else: 93 | return final[0] 94 | -------------------------------------------------------------------------------- /src/model/textencoder/t5_encoder.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List, Union 3 | 4 | import torch 5 | from torch import Tensor, nn 6 | from torch.distributions.distribution import Distribution 7 | from src.utils.file_io import hack_path 8 | import pytorch_lightning as pl 9 | 10 | class T5TextEncoder(pl.LightningModule): 11 | 12 | def __init__( 13 | self, 14 | modelpath: str, 15 | finetune: bool = False, 16 | **kwargs 17 | ) -> None: 18 | 19 | super().__init__() 20 | self.save_hyperparameters(logger=False) 21 | from transformers import logging 22 | from transformers import AutoModel 23 | from transformers import AutoTokenizer, T5EncoderModel 24 | 25 | logging.set_verbosity_error() 26 | # Tokenizer 27 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 28 | 29 | self.tokenizer = AutoTokenizer.from_pretrained(hack_path(modelpath), 30 | legacy=True) 31 | self.language_model = T5EncoderModel.from_pretrained( 32 | hack_path(modelpath)) 33 | self.language_model.resize_token_embeddings(len(self.tokenizer)) 34 | self.max_length = 128 # self.tokenizer.model_max_length 35 | # Don't train the model 36 | if not finetune: 37 | self.language_model.training = False 38 | for p in self.language_model.parameters(): 39 | p.requires_grad = False 40 | 41 | def forward(self, texts: List[str]): 42 | 43 | # # Tokenize 44 | text_inputs = self.tokenizer(texts, 45 | padding='max_length', 46 | max_length=self.max_length, 47 | truncation=True, 48 | return_attention_mask=True, 49 | add_special_tokens=True, 50 | return_tensors="pt") 51 | 52 | 53 | # input_ids = self.tokenizer(texts, return_tensors="pt").input_ids # Batch size 1 54 | text_input_ids = text_inputs.input_ids.to(self.language_model.device) 55 | txt_att_mask = text_inputs.attention_mask.to(self.language_model.device) 56 | 57 | outputs = self.language_model(input_ids=text_input_ids, attention_mask=txt_att_mask) 58 | last_hidden_states = outputs.last_hidden_state 59 | 60 | return 
last_hidden_states, txt_att_mask.bool() -------------------------------------------------------------------------------- /src/model/tmr_utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .actor import PositionalEncoding, ACTORStyleEncoder, ACTORStyleDecoder # noqa 2 | from .temos import TEMOS # noqa 3 | from .tmr import TMR # noqa 4 | from .text_encoder import TextToEmb # noqa 5 | -------------------------------------------------------------------------------- /src/model/tmr_utils/actor.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch import Tensor 6 | import numpy as np 7 | 8 | from einops import repeat 9 | 10 | 11 | class PositionalEncoding(nn.Module): 12 | def __init__(self, d_model, dropout=0.1, max_len=5000, batch_first=False) -> None: 13 | super().__init__() 14 | self.batch_first = batch_first 15 | 16 | self.dropout = nn.Dropout(p=dropout) 17 | 18 | pe = torch.zeros(max_len, d_model) 19 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) 20 | div_term = torch.exp( 21 | torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model) 22 | ) 23 | pe[:, 0::2] = torch.sin(position * div_term) 24 | pe[:, 1::2] = torch.cos(position * div_term) 25 | pe = pe.unsqueeze(0).transpose(0, 1) 26 | self.register_buffer("pe", pe, persistent=False) 27 | 28 | def forward(self, x: Tensor) -> Tensor: 29 | if self.batch_first: 30 | x = x + self.pe.permute(1, 0, 2)[:, : x.shape[1], :] 31 | else: 32 | x = x + self.pe[: x.shape[0], :] 33 | return self.dropout(x) 34 | 35 | 36 | class ACTORStyleEncoder(nn.Module): 37 | # Similar to ACTOR but "action agnostic" and more general 38 | def __init__( 39 | self, 40 | nfeats: int, 41 | vae: bool, 42 | latent_dim: int = 256, 43 | ff_size: int = 1024, 44 | num_layers: int = 4, 45 | num_heads: int = 4, 46 | dropout: float = 0.1, 47 | activation: str = "gelu", 48 | ) -> None: 49 | super().__init__() 50 | 51 | self.nfeats = nfeats 52 | self.projection = nn.Linear(nfeats, latent_dim) 53 | 54 | self.vae = vae 55 | self.nbtokens = 2 if vae else 1 56 | self.tokens = nn.Parameter(torch.randn(self.nbtokens, latent_dim)) 57 | 58 | self.sequence_pos_encoding = PositionalEncoding( 59 | latent_dim, dropout=dropout, batch_first=True 60 | ) 61 | 62 | seq_trans_encoder_layer = nn.TransformerEncoderLayer( 63 | d_model=latent_dim, 64 | nhead=num_heads, 65 | dim_feedforward=ff_size, 66 | dropout=dropout, 67 | activation=activation, 68 | batch_first=True, 69 | ) 70 | 71 | self.seqTransEncoder = nn.TransformerEncoder( 72 | seq_trans_encoder_layer, num_layers=num_layers 73 | ) 74 | 75 | def forward(self, x_dict: Dict) -> Tensor: 76 | x = x_dict["x"] 77 | mask = x_dict["mask"] 78 | 79 | x = self.projection(x) 80 | 81 | device = x.device 82 | bs = len(x) 83 | 84 | tokens = repeat(self.tokens, "nbtoken dim -> bs nbtoken dim", bs=bs) 85 | xseq = torch.cat((tokens, x), 1) 86 | 87 | token_mask = torch.ones((bs, self.nbtokens), dtype=bool, device=device) 88 | aug_mask = torch.cat((token_mask, mask), 1) 89 | 90 | # add positional encoding 91 | xseq = self.sequence_pos_encoding(xseq) 92 | final = self.seqTransEncoder(xseq, src_key_padding_mask=~aug_mask) 93 | return final[:, : self.nbtokens] 94 | 95 | 96 | class ACTORStyleDecoder(nn.Module): 97 | # Similar to ACTOR Decoder 98 | 99 | def __init__( 100 | self, 101 | nfeats: int, 102 | latent_dim: int = 256, 103 | ff_size: int = 1024, 104 | num_layers: 
int = 4, 105 | num_heads: int = 4, 106 | dropout: float = 0.1, 107 | activation: str = "gelu", 108 | ) -> None: 109 | super().__init__() 110 | output_feats = nfeats 111 | self.nfeats = nfeats 112 | 113 | self.sequence_pos_encoding = PositionalEncoding( 114 | latent_dim, dropout, batch_first=True 115 | ) 116 | 117 | seq_trans_decoder_layer = nn.TransformerDecoderLayer( 118 | d_model=latent_dim, 119 | nhead=num_heads, 120 | dim_feedforward=ff_size, 121 | dropout=dropout, 122 | activation=activation, 123 | batch_first=True, 124 | ) 125 | 126 | self.seqTransDecoder = nn.TransformerDecoder( 127 | seq_trans_decoder_layer, num_layers=num_layers 128 | ) 129 | 130 | self.final_layer = nn.Linear(latent_dim, output_feats) 131 | 132 | def forward(self, z_dict: Dict) -> Tensor: 133 | z = z_dict["z"] 134 | mask = z_dict["mask"] 135 | 136 | latent_dim = z.shape[1] 137 | bs, nframes = mask.shape 138 | 139 | z = z[:, None] # sequence of 1 element for the memory 140 | 141 | # Construct time queries 142 | time_queries = torch.zeros(bs, nframes, latent_dim, device=z.device) 143 | time_queries = self.sequence_pos_encoding(time_queries) 144 | 145 | # Pass through the transformer decoder 146 | # with the latent vector for memory 147 | output = self.seqTransDecoder( 148 | tgt=time_queries, memory=z, tgt_key_padding_mask=~mask 149 | ) 150 | 151 | output = self.final_layer(output) 152 | # zero for padded area 153 | output[~mask] = 0 154 | return output 155 | -------------------------------------------------------------------------------- /src/model/tmr_utils/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | # For reference 6 | # https://stats.stackexchange.com/questions/7440/kl-divergence-between-two-univariate-gaussians 7 | # https://pytorch.org/docs/stable/_modules/torch/distributions/kl.html#kl_divergence 8 | class KLLoss: 9 | def __call__(self, q, p): 10 | mu_q, logvar_q = q 11 | mu_p, logvar_p = p 12 | 13 | log_var_ratio = logvar_q - logvar_p 14 | t1 = (mu_p - mu_q).pow(2) / logvar_p.exp() 15 | div = 0.5 * (log_var_ratio.exp() + t1 - 1 - log_var_ratio) 16 | return div.mean() 17 | 18 | def __repr__(self): 19 | return "KLLoss()" 20 | 21 | 22 | class InfoNCE_with_filtering: 23 | def __init__(self, temperature=0.7, threshold_selfsim=0.8): 24 | self.temperature = temperature 25 | self.threshold_selfsim = threshold_selfsim 26 | 27 | def get_sim_matrix(self, x, y): 28 | x_logits = torch.nn.functional.normalize(x, dim=-1) 29 | y_logits = torch.nn.functional.normalize(y, dim=-1) 30 | sim_matrix = x_logits @ y_logits.T 31 | return sim_matrix 32 | 33 | def __call__(self, x, y, sent_emb=None): 34 | bs, device = len(x), x.device 35 | sim_matrix = self.get_sim_matrix(x, y) / self.temperature 36 | 37 | if sent_emb is not None and self.threshold_selfsim: 38 | # put the threshold value between -1 and 1 39 | real_threshold_selfsim = 2 * self.threshold_selfsim - 1 40 | # Filtering too close values 41 | # mask them by putting -inf in the sim_matrix 42 | selfsim = sent_emb @ sent_emb.T 43 | selfsim_nodiag = selfsim - selfsim.diag().diag() 44 | idx = torch.where(selfsim_nodiag > real_threshold_selfsim) 45 | sim_matrix[idx] = -torch.inf 46 | 47 | labels = torch.arange(bs, device=device) 48 | 49 | total_loss = ( 50 | F.cross_entropy(sim_matrix, labels) + F.cross_entropy(sim_matrix.T, labels) 51 | ) / 2 52 | 53 | return total_loss 54 | 55 | def __repr__(self): 56 | return f"Constrastive(temp={self.temp})" 57 | 
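A quick sketch (not part of the repository) of how the TMR-style losses above are typically driven; the latent and sentence-embedding shapes are illustrative assumptions, not values taken from the code:

import torch
import torch.nn.functional as F
from src.model.tmr_utils.losses import InfoNCE_with_filtering, KLLoss

contrastive = InfoNCE_with_filtering(temperature=0.1, threshold_selfsim=0.8)
motion_lat = torch.randn(32, 256)                      # per-sample motion latents
text_lat = torch.randn(32, 256)                        # per-sample text latents
sent_emb = F.normalize(torch.randn(32, 768), dim=-1)   # sentence embeddings used for filtering
nce_loss = contrastive(motion_lat, text_lat, sent_emb=sent_emb)

# this KLLoss expects (mu, logvar) tuples rather than torch Distribution objects
q = (torch.randn(32, 256), torch.zeros(32, 256))
p = (torch.zeros(32, 256), torch.zeros(32, 256))
kl_loss = KLLoss()(q, p)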
-------------------------------------------------------------------------------- /src/model/tmr_utils/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def print_latex_metrics(metrics): 5 | vals = [str(x).zfill(2) for x in [1, 2, 3, 5, 10]] 6 | t2m_keys = [f"t2m/R{i}" for i in vals] + ["t2m/MedR"] 7 | m2t_keys = [f"m2t/R{i}" for i in vals] + ["m2t/MedR"] 8 | 9 | keys = t2m_keys + m2t_keys 10 | 11 | def ff(val_): 12 | val = str(val_).ljust(5, "0") 13 | # make decimal fine when only one digit 14 | if val[1] == ".": 15 | val = str(val_).ljust(4, "0") 16 | return val 17 | 18 | str_ = "& " + " & ".join([ff(metrics[key]) for key in keys]) + r" \\" 19 | dico = {key: ff(metrics[key]) for key in keys} 20 | print(dico) 21 | print("Number of samples: {}".format(int(metrics["t2m/len"]))) 22 | print(str_) 23 | 24 | 25 | def all_contrastive_metrics( 26 | sims, emb=None, threshold=None, rounding=2, return_cols=False 27 | ): 28 | text_selfsim = None 29 | if emb is not None: 30 | text_selfsim = emb @ emb.T 31 | 32 | t2m_m, t2m_cols = contrastive_metrics( 33 | sims, text_selfsim, threshold, return_cols=True, rounding=rounding 34 | ) 35 | m2t_m, m2t_cols = contrastive_metrics( 36 | sims.T, text_selfsim, threshold, return_cols=True, rounding=rounding 37 | ) 38 | 39 | all_m = {} 40 | for key in t2m_m: 41 | all_m[f"t2m/{key}"] = t2m_m[key] 42 | all_m[f"m2t/{key}"] = m2t_m[key] 43 | 44 | all_m["t2m/len"] = float(len(sims)) 45 | all_m["m2t/len"] = float(len(sims[0])) 46 | if return_cols: 47 | return all_m, t2m_cols, m2t_cols 48 | return all_m 49 | 50 | 51 | def contrastive_metrics( 52 | sims, 53 | text_selfsim=None, 54 | threshold=None, 55 | return_cols=False, 56 | rounding=2, 57 | break_ties="averaging", 58 | ): 59 | n, m = sims.shape 60 | assert n == m 61 | num_queries = n 62 | 63 | dists = -sims 64 | sorted_dists = np.sort(dists, axis=1) 65 | # GT is in the diagonal 66 | gt_dists = np.diag(dists)[:, None] 67 | 68 | if text_selfsim is not None and threshold is not None: 69 | real_threshold = 2 * threshold - 1 70 | idx = np.argwhere(text_selfsim > real_threshold) 71 | partition = np.unique(idx[:, 0], return_index=True)[1] 72 | # take as GT the minimum score of similar values 73 | gt_dists = np.minimum.reduceat(dists[tuple(idx.T)], partition) 74 | gt_dists = gt_dists[:, None] 75 | 76 | rows, cols = np.where((sorted_dists - gt_dists) == 0) # find column position of GT 77 | 78 | # if there are ties 79 | if rows.size > num_queries: 80 | assert np.unique(rows).size == num_queries, "issue in metric evaluation" 81 | if break_ties == "optimistically": 82 | opti_cols = break_ties_optimistically(sorted_dists, gt_dists) 83 | cols = opti_cols 84 | elif break_ties == "averaging": 85 | avg_cols = break_ties_average(sorted_dists, gt_dists) 86 | cols = avg_cols 87 | 88 | msg = "expected ranks to match queries ({} vs {}) " 89 | assert cols.size == num_queries, msg 90 | 91 | if return_cols: 92 | return cols2metrics(cols, num_queries, rounding=rounding), cols 93 | return cols2metrics(cols, num_queries, rounding=rounding) 94 | 95 | 96 | def break_ties_average(sorted_dists, gt_dists): 97 | # fast implementation, based on this code: 98 | # https://stackoverflow.com/a/49239335 99 | locs = np.argwhere((sorted_dists - gt_dists) == 0) 100 | 101 | # Find the split indices 102 | steps = np.diff(locs[:, 0]) 103 | splits = np.nonzero(steps)[0] + 1 104 | splits = np.insert(splits, 0, 0) 105 | 106 | # Compute the result columns 107 | summed_cols = 
np.add.reduceat(locs[:, 1], splits) 108 | counts = np.diff(np.append(splits, locs.shape[0])) 109 | avg_cols = summed_cols / counts 110 | return avg_cols 111 | 112 | 113 | def break_ties_optimistically(sorted_dists, gt_dists): 114 | rows, cols = np.where((sorted_dists - gt_dists) == 0) 115 | _, idx = np.unique(rows, return_index=True) 116 | cols = cols[idx] 117 | return cols 118 | 119 | 120 | def cols2metrics(cols, num_queries, rounding=2): 121 | metrics = {} 122 | vals = [str(x).zfill(2) for x in [1, 2, 3, 5, 10]] 123 | for val in vals: 124 | metrics[f"R{val}"] = 100 * float(np.sum(cols < int(val))) / num_queries 125 | 126 | metrics["MedR"] = float(np.median(cols) + 1) 127 | 128 | if rounding is not None: 129 | for key in metrics: 130 | metrics[key] = round(metrics[key], rounding) 131 | return metrics 132 | -------------------------------------------------------------------------------- /src/model/tmr_utils/text_encoder.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch.nn as nn 3 | import torch 4 | from torch import Tensor 5 | from typing import Dict, List 6 | import torch.nn.functional as F 7 | 8 | 9 | class TextToEmb(nn.Module): 10 | def __init__( 11 | self, modelpath: str, mean_pooling: bool = False, device: str = "cpu" 12 | ) -> None: 13 | super().__init__() 14 | 15 | self.device = device 16 | from transformers import AutoTokenizer, AutoModel 17 | from transformers import logging 18 | 19 | logging.set_verbosity_error() 20 | 21 | # Tokenizer 22 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 23 | self.tokenizer = AutoTokenizer.from_pretrained(modelpath) 24 | 25 | # Text model 26 | self.text_model = AutoModel.from_pretrained(modelpath) 27 | # Then configure the model 28 | self.text_encoded_dim = self.text_model.config.hidden_size 29 | 30 | if mean_pooling: 31 | self.forward = self.forward_pooling 32 | 33 | # put it in eval mode by default 34 | self.eval() 35 | 36 | # Freeze the weights just in case 37 | for param in self.parameters(): 38 | param.requires_grad = False 39 | 40 | self.to(device) 41 | 42 | def train(self, mode: bool = True) -> nn.Module: 43 | # override it to be always false 44 | self.training = False 45 | for module in self.children(): 46 | module.train(False) 47 | return self 48 | 49 | @torch.no_grad() 50 | def forward(self, texts: List[str], device=None) -> Dict: 51 | device = device if device is not None else self.device 52 | 53 | squeeze = False 54 | if isinstance(texts, str): 55 | texts = [texts] 56 | squeeze = True 57 | 58 | encoded_inputs = self.tokenizer(texts, return_tensors="pt", padding=True) 59 | output = self.text_model(**encoded_inputs.to(device)) 60 | length = encoded_inputs.attention_mask.to(dtype=bool).sum(1) 61 | 62 | if squeeze: 63 | x_dict = {"x": output.last_hidden_state[0], "length": length[0]} 64 | else: 65 | x_dict = {"x": output.last_hidden_state, "length": length} 66 | return x_dict 67 | 68 | @torch.no_grad() 69 | def forward_pooling(self, texts: List[str], device=None) -> Tensor: 70 | device = device if device is not None else self.device 71 | 72 | squeeze = False 73 | if isinstance(texts, str): 74 | texts = [texts] 75 | squeeze = True 76 | 77 | # From: https://huggingface.co/sentence-transformers/all-mpnet-base-v2 78 | encoded_inputs = self.tokenizer(texts, return_tensors="pt", padding=True) 79 | output = self.text_model(**encoded_inputs.to(device)) 80 | attention_mask = encoded_inputs["attention_mask"] 81 | 82 | # Mean Pooling - Take attention mask into account for correct 
averaging 83 | token_embeddings = output["last_hidden_state"] 84 | input_mask_expanded = ( 85 | attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() 86 | ) 87 | sentence_embeddings = torch.sum( 88 | token_embeddings * input_mask_expanded, 1 89 | ) / torch.clamp(input_mask_expanded.sum(1), min=1e-9) 90 | # Normalize embeddings 91 | sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1) 92 | if squeeze: 93 | sentence_embeddings = sentence_embeddings[0] 94 | return sentence_embeddings 95 | -------------------------------------------------------------------------------- /src/model/tmr_utils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from omegaconf import DictConfig 3 | import logging 4 | import hydra 5 | 6 | 7 | logger = logging.getLogger(__name__) 8 | 9 | import os 10 | import json 11 | from omegaconf import DictConfig, OmegaConf 12 | 13 | 14 | def save_config(cfg: DictConfig) -> str: 15 | path = os.path.join(cfg.run_dir, "config.json") 16 | config = OmegaConf.to_container(cfg, resolve=True) 17 | with open(path, "w") as f: 18 | string = json.dumps(config, indent=4) 19 | f.write(string) 20 | return path 21 | 22 | 23 | def read_config(run_dir: str, return_json=False) -> DictConfig: 24 | path = os.path.join(run_dir, "config.json") 25 | with open(path, "r") as f: 26 | config = json.load(f) 27 | if return_json: 28 | return config 29 | cfg = OmegaConf.create(config) 30 | cfg.run_dir = run_dir 31 | return cfg 32 | 33 | # split the lightning checkpoint into 34 | # seperate state_dict modules for faster loading 35 | def extract_ckpt(run_dir, ckpt_name="last"): 36 | import torch 37 | 38 | ckpt_path = os.path.join(run_dir, f"logs/checkpoints/{ckpt_name}.ckpt") 39 | 40 | extracted_path = os.path.join(run_dir, f"{ckpt_name}_weights") 41 | os.makedirs(extracted_path, exist_ok=True) 42 | 43 | new_path_template = os.path.join(extracted_path, "{}.pt") 44 | ckpt_dict = torch.load(ckpt_path) 45 | state_dict = ckpt_dict["state_dict"] 46 | module_names = list(set([x.split(".")[0] for x in state_dict.keys()])) 47 | 48 | # should be ['motion_encoder', 'text_encoder', 'motion_decoder'] for example 49 | for module_name in module_names: 50 | path = new_path_template.format(module_name) 51 | sub_state_dict = { 52 | ".".join(x.split(".")[1:]): y.cpu() 53 | for x, y in state_dict.items() 54 | if x.split(".")[0] == module_name 55 | } 56 | torch.save(sub_state_dict, path) 57 | 58 | 59 | def load_model(run_dir, **params): 60 | # Load last config 61 | cfg = read_config(run_dir) 62 | cfg.run_dir = run_dir 63 | return load_model_from_cfg(cfg, **params) 64 | 65 | 66 | def load_model_from_cfg(cfg, ckpt_name="last", device="cpu", eval_mode=True): 67 | import torch 68 | 69 | run_dir = cfg.run_dir 70 | model = hydra.utils.instantiate(cfg.model) 71 | 72 | # Loading modules one by one 73 | # motion_encoder / text_encoder / text_decoder 74 | pt_path = os.path.join(run_dir, f"{ckpt_name}_weights") 75 | 76 | if not os.path.exists(pt_path): 77 | logger.info("The extracted model is not found. 
Split into submodules..") 78 | extract_ckpt(run_dir, ckpt_name) 79 | 80 | for fname in os.listdir(pt_path): 81 | module_name, ext = os.path.splitext(fname) 82 | if ext != ".pt": 83 | continue 84 | 85 | module = getattr(model, module_name, None) 86 | if module is None: 87 | continue 88 | 89 | module_path = os.path.join(pt_path, fname) 90 | state_dict = torch.load(module_path) 91 | module.load_state_dict(state_dict) 92 | logger.info(f" {module_name} loaded") 93 | 94 | logger.info("Loading previous checkpoint done") 95 | model = model.to(device) 96 | logger.info(f"Put the model on {device}") 97 | if eval_mode: 98 | model = model.eval() 99 | logger.info("Put the model in eval mode") 100 | return model 101 | 102 | 103 | @hydra.main(version_base=None, config_path="../configs", config_name="load_model") 104 | def hydra_load_model(cfg: DictConfig) -> None: 105 | run_dir = cfg.run_dir 106 | ckpt_name = cfg.ckpt 107 | device = cfg.device 108 | eval_mode = cfg.eval_mode 109 | return load_model(run_dir, ckpt_name, device, eval_mode) 110 | 111 | 112 | if __name__ == "__main__": 113 | hydra_load_model() 114 | -------------------------------------------------------------------------------- /src/model/trans_enc.py: -------------------------------------------------------------------------------- 1 | from src.model.DiT_models import * 2 | 3 | class EncoderBlock(nn.Module): 4 | """ 5 | A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning. 6 | """ 7 | def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs): 8 | super().__init__() 9 | self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 10 | self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs) 11 | self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 12 | mlp_hidden_dim = int(hidden_size * mlp_ratio) 13 | approx_gelu = lambda: nn.GELU(approximate="tanh") 14 | self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0) 15 | 16 | def forward(self, x, c, mask=None): 17 | x = x + self.attn((self.norm1(x)), mask) 18 | x = x + self.mlp((self.norm2(x))) 19 | return x 20 | 21 | def forward_with_att(self, x, c, mask=None): 22 | attention_out, attention_mask = self.attn.forward_w_attention((self.norm1(x)), mask) 23 | x = x + attention_out 24 | x = x + self.mlp((self.norm2(x))) 25 | return x, attention_mask 26 | 27 | import copy 28 | def clones(module, N): 29 | """Produce N identical layers.""" 30 | return nn.ModuleList([copy.deepcopy(module) for _ in range(N)]) 31 | 32 | class LayerNorm(nn.Module): 33 | def __init__(self, features: int, eps: float = 1e-6): 34 | # features = d_model 35 | super(LayerNorm, self).__init__() 36 | self.a = nn.Parameter(torch.ones(features)) 37 | self.b = nn.Parameter(torch.zeros(features)) 38 | self.eps = eps 39 | 40 | def forward(self, x: torch.FloatTensor) -> torch.FloatTensor: 41 | mean = x.mean(-1, keepdim=True) 42 | std = x.std(-1, keepdim=True) 43 | return self.a * (x - mean) / (std + self.eps) + self.b 44 | 45 | class TransEncoder(nn.Module): 46 | """Core encoder is a stack of N layers""" 47 | 48 | def __init__(self, layer, N: int): 49 | super(TransEncoder, self).__init__() 50 | self.layers = clones(layer, N) 51 | 52 | def forward(self, x: torch.FloatTensor, mask: torch.ByteTensor) -> torch.FloatTensor: 53 | """Pass the input (and mask) through each layer in turn.""" 54 | for layer in self.layers: 55 | x = layer(x, mask) 56 | return x 57 | def forward_with_att(self, x, c, mask=None): 
58 | attention_masks = [] 59 | for layer in self.layers: 60 | x, attention_mask = layer.forward_with_att(x, c, mask) 61 | attention_masks.append(attention_mask) 62 | return x, attention_masks -------------------------------------------------------------------------------- /src/model/transformer_encoder/encoder.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lzhyu/SimMotionEdit/384135a8707bc247fba23005d57cca7ca2d751ce/src/model/transformer_encoder/encoder.py -------------------------------------------------------------------------------- /src/model/transformer_encoder/encoder_layer.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lzhyu/SimMotionEdit/384135a8707bc247fba23005d57cca7ca2d751ce/src/model/transformer_encoder/encoder_layer.py -------------------------------------------------------------------------------- /src/model/transformer_encoder/feed_forward.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lzhyu/SimMotionEdit/384135a8707bc247fba23005d57cca7ca2d751ce/src/model/transformer_encoder/feed_forward.py -------------------------------------------------------------------------------- /src/model/transformer_encoder/note.md: -------------------------------------------------------------------------------- 1 | Based on https://github.com/guocheng18/Transformer-Encoder/tree/master -------------------------------------------------------------------------------- /src/model/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .positional_encoding import PositionalEncoding 2 | from .vae import reparameterize 3 | -------------------------------------------------------------------------------- /src/model/utils/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | from torch.optim import Optimizer 2 | from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau 3 | from functools import partial 4 | import math 5 | import torch 6 | import math 7 | import functools 8 | 9 | def _cosine_decay_warmup(iteration, warmup_iterations, total_iterations, final_learning_rate, initial_learning_rate): 10 | """ 11 | Adjusted function to handle linear warmup and cosine decay that correctly transitions to the final learning rate. 
12 | """ 13 | if iteration <= warmup_iterations: 14 | # Linear warmup: multiplier increases from 0 to 1 15 | return iteration / warmup_iterations 16 | else: 17 | # Cosine decay phase: decay from 1 to final_learning_rate/initial_learning_rate 18 | decayed = (iteration - warmup_iterations) / (total_iterations - warmup_iterations) 19 | decayed = 0.5 * (1 + math.cos(math.pi * decayed)) 20 | # Normalize the decay such that it ends at final_learning_rate/initial_learning_rate 21 | decay_multiplier = decayed * (1 - final_learning_rate / initial_learning_rate) + final_learning_rate / initial_learning_rate 22 | return decay_multiplier 23 | 24 | def CosineAnnealingLRWarmup(optimizer, T_max, T_warmup, lr_final, lr_initial): 25 | _decay_func = functools.partial( 26 | _cosine_decay_warmup, 27 | warmup_iterations=T_warmup, total_iterations=T_max, 28 | final_learning_rate=lr_final, 29 | initial_learning_rate=lr_initial 30 | ) 31 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, _decay_func) 32 | return scheduler 33 | -------------------------------------------------------------------------------- /src/model/utils/positional_encoding.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | 5 | 6 | class PositionalEncoding(nn.Module): 7 | def __init__(self, d_model, dropout=0.1, 8 | max_len=5000, batch_first=False, negative=False): 9 | super().__init__() 10 | self.batch_first = batch_first 11 | 12 | self.dropout = nn.Dropout(p=dropout) 13 | self.max_len = max_len 14 | 15 | self.negative = negative 16 | 17 | if negative: 18 | pe = torch.zeros(2*max_len, d_model) 19 | position = torch.arange(-max_len, max_len, dtype=torch.float).unsqueeze(1) 20 | else: 21 | pe = torch.zeros(max_len, d_model) 22 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) 23 | 24 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model)) 25 | pe[:, 0::2] = torch.sin(position * div_term) 26 | pe[:, 1::2] = torch.cos(position * div_term) 27 | pe = pe.unsqueeze(0).transpose(0, 1) 28 | 29 | self.register_buffer('pe', pe, persistent=False) 30 | 31 | def forward(self, x, hist_frames=0): 32 | if not self.negative: 33 | center = 0 34 | assert hist_frames == 0 35 | first = 0 36 | else: 37 | center = self.max_len 38 | first = center-hist_frames 39 | if self.batch_first: 40 | last = first + x.shape[1] 41 | x = x + self.pe.permute(1, 0, 2)[:, first:last, :] 42 | else: 43 | last = first + x.shape[0] 44 | x = x + self.pe[first:last, :] 45 | return self.dropout(x) -------------------------------------------------------------------------------- /src/model/utils/timestep_embed.py: -------------------------------------------------------------------------------- 1 | import select 2 | from torch import nn 3 | import torch 4 | import math 5 | 6 | def get_timestep_embedding( 7 | timesteps: torch.Tensor, 8 | embedding_dim: int, 9 | flip_sin_to_cos: bool = False, 10 | downscale_freq_shift: float = 1, 11 | scale: float = 1, 12 | max_period: int = 10000, 13 | ): 14 | """ 15 | This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. 16 | 17 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 18 | These may be fractional. 19 | :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the 20 | embeddings. :return: an [N x dim] Tensor of positional embeddings. 
21 | """ 22 | assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" 23 | 24 | half_dim = embedding_dim // 2 25 | exponent = -math.log(max_period) * torch.arange( 26 | start=0, end=half_dim, dtype=torch.float32, device=timesteps.device 27 | ) 28 | exponent = exponent / (half_dim - downscale_freq_shift) 29 | 30 | emb = torch.exp(exponent) 31 | emb = timesteps[:, None].float() * emb[None, :] 32 | 33 | # scale embeddings 34 | emb = scale * emb 35 | 36 | # concat sine and cosine embeddings 37 | emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) 38 | 39 | # flip sine and cosine embeddings 40 | if flip_sin_to_cos: 41 | emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) 42 | 43 | # zero pad 44 | if embedding_dim % 2 == 1: 45 | emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) 46 | return emb 47 | 48 | class TimestepEmbedding(nn.Module): 49 | def __init__(self, channel: int, time_embed_dim: int, act_fn: str = "silu"): 50 | super().__init__() 51 | 52 | self.linear_1 = nn.Linear(channel, time_embed_dim) 53 | self.act = None 54 | if act_fn == "silu": 55 | self.act = nn.SiLU() 56 | self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim) 57 | 58 | def forward(self, sample): 59 | sample = self.linear_1(sample) 60 | 61 | if self.act is not None: 62 | sample = self.act(sample) 63 | 64 | sample = self.linear_2(sample) 65 | return sample 66 | 67 | 68 | class Timesteps(nn.Module): 69 | def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float): 70 | super().__init__() 71 | self.num_channels = num_channels 72 | self.flip_sin_to_cos = flip_sin_to_cos 73 | self.downscale_freq_shift = downscale_freq_shift 74 | 75 | def forward(self, timesteps): 76 | t_emb = get_timestep_embedding(timesteps, self.num_channels, 77 | flip_sin_to_cos=self.flip_sin_to_cos, 78 | downscale_freq_shift=self.downscale_freq_shift,) 79 | return t_emb 80 | 81 | 82 | class TimestepEmbedderMDM(nn.Module): 83 | def __init__(self, latent_dim): 84 | super().__init__() 85 | self.latent_dim = latent_dim 86 | from src.model.utils.positional_encoding import PositionalEncoding 87 | 88 | time_embed_dim = self.latent_dim 89 | self.sequence_pos_encoder = PositionalEncoding(d_model=self.latent_dim) 90 | self.time_embed = nn.Sequential( 91 | nn.Linear(self.latent_dim, time_embed_dim), 92 | nn.SiLU(), 93 | nn.Linear(time_embed_dim, time_embed_dim), 94 | ) 95 | 96 | def forward(self, timesteps): 97 | return self.time_embed(self.sequence_pos_encoder.pe[timesteps]).permute(1, 0, 2) 98 | -------------------------------------------------------------------------------- /src/model/utils/tools.py: -------------------------------------------------------------------------------- 1 | from src.tools.transforms3d import transform_body_pose 2 | import torch 3 | 4 | def detach_to_numpy(tensor): 5 | return tensor.detach().cpu().numpy() 6 | 7 | def remove_padding(tensors, lengths): 8 | return [tensor[:tensor_length] for tensor, tensor_length in zip(tensors, lengths)] 9 | 10 | def pack_to_render(rots, trans, pose_repr='6d'): 11 | # make axis-angle 12 | # global_orient = transform_body_pose(rots, f"{pose_repr}->aa") 13 | if pose_repr != 'aa': 14 | body_pose = transform_body_pose(rots, f"{pose_repr}->aa") 15 | else: 16 | body_pose = rots 17 | if trans is None: 18 | trans = torch.zeros((rots.shape[0], rots.shape[1], 3), 19 | device=rots.device) 20 | render_d = {'body_transl': trans, 21 | 'body_orient': body_pose[..., :3], 22 | 'body_pose': body_pose[..., 3:]} 23 | return render_d 
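# Illustrative usage sketch (not part of the repository): calling pack_to_render from
# src/model/utils/tools.py with axis-angle input, so no rotation conversion is triggered.
# The batch/frame/joint sizes below are assumptions made only for this example.
import torch
from src.model.utils.tools import pack_to_render

rots = torch.zeros(2, 60, 22 * 3)                         # 2 motions, 60 frames, 22 joints in axis-angle (66 dims)
render_dict = pack_to_render(rots, trans=None, pose_repr='aa')
print(render_dict['body_transl'].shape)                   # (2, 60, 3), zeros because trans was None
print(render_dict['body_orient'].shape)                   # (2, 60, 3), global orientation (first joint)
print(render_dict['body_pose'].shape)                     # (2, 60, 63), remaining joint rotations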
-------------------------------------------------------------------------------- /src/model/utils/vae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def reparameterize(mu, logvar, seed=None): 5 | std = torch.exp(logvar / 2) 6 | 7 | if seed is None: 8 | eps = std.data.new(std.size()).normal_() 9 | else: 10 | generator = torch.Generator(device=mu.device) 11 | generator.manual_seed(seed) 12 | eps = std.data.new(std.size()).normal_(generator=generator) 13 | 14 | return eps.mul(std).add_(mu) -------------------------------------------------------------------------------- /src/render/__init__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | if 'blender' not in sys.executable: 3 | import sys 4 | from .anim import render_animation 5 | from .mesh_viz import render_motion 6 | -------------------------------------------------------------------------------- /src/repr_utils/tmr_utils.py: -------------------------------------------------------------------------------- 1 | 2 | import hydra 3 | from pathlib import Path 4 | from tmr_evaluator.motion2motion_retr import read_config 5 | import logging 6 | logger = logging.getLogger(__name__) 7 | 8 | import pytorch_lightning as pl 9 | import numpy as np 10 | from hydra.utils import instantiate 11 | from src.tmr.load_model import load_model_from_cfg 12 | from src.tmr.metrics import all_contrastive_metrics_mot2mot, print_latex_metrics_m2m 13 | from omegaconf import DictConfig 14 | from omegaconf import OmegaConf 15 | import src.launch.prepare 16 | from src.data.tools.collate import collate_batch_last_padding 17 | import torch 18 | from tmr_evaluator.motion2motion_retr import length_to_mask 19 | # Step 1: load TMR model 20 | 21 | def mean_flat(x): 22 | """ 23 | Take the mean over all non-batch dimensions. 24 | """ 25 | return torch.mean(x, dim=list(range(1, len(x.size())))) 26 | 27 | def masked_loss(loss, mask): 28 | # loss: B, T 29 | # mask: B, T 30 | masked_loss = loss * mask # keep only the loss at valid (masked) positions 31 | 32 | # compute the mean masked loss 33 | # 1. count the valid elements to avoid division by zero 34 | num_valid_elements = mask.sum() # total number of valid elements 35 | masked_loss_mean = masked_loss.sum() / (num_valid_elements + 1) 36 | return masked_loss_mean 37 | 38 | def load_tmr_model(): 39 | protocol = ['normal', 'batches'] 40 | device = 'cpu' 41 | run_dir = 'eval-deps' 42 | ckpt_name = 'last' 43 | batch_size = 256 44 | 45 | protocols = protocol 46 | dataset = 'motionfix' # motionfix 47 | sets = 'test' # val all 48 | # save_dir = os.path.join(run_dir, "motionfix/contrastive_metrics") 49 | # os.makedirs(save_dir, exist_ok=True) 50 | 51 | # Load last config 52 | curdir = Path("/depot/bera89/data/li5280/project/motionfix") 53 | cfg = read_config(curdir / run_dir) 54 | logger.info("Loading the evaluation TMR model") 55 | model = load_model_from_cfg(cfg, ckpt_name, eval_mode=True, device=device) 56 | return model 57 | 58 | 59 | 60 | # STEP 2: load dataset 61 | 62 | def load_testloader(cfg: DictConfig): 63 | # What do you want?
64 | exp_folder = Path("/depot/bera89/data/li5280/project/motionfix/experiments/tmed") 65 | prevcfg = OmegaConf.load(exp_folder / ".hydra/config.yaml") 66 | cfg = OmegaConf.merge(prevcfg, cfg) 67 | data_module = instantiate(cfg.data, amt_only=True, 68 | load_splits=['test', 'val']) 69 | # load the test set and collate it properly 70 | features_to_load = data_module.dataset['test'].load_feats 71 | SMPL_feats = ['body_transl', 'body_pose', 'body_orient'] 72 | for feat in SMPL_feats: 73 | if feat not in features_to_load: 74 | data_module.dataset['test'].load_feats.append(feat) 75 | print(features_to_load) 76 | 77 | # TODO: change features of including SMPL features 78 | test_dataset = data_module.dataset['test'] + data_module.dataset['val'] 79 | collate_fn = lambda b: collate_batch_last_padding(b, features_to_load) 80 | 81 | testloader = torch.utils.data.DataLoader(test_dataset, 82 | shuffle=False, 83 | num_workers=8, 84 | batch_size=128, 85 | collate_fn=collate_fn) 86 | return testloader 87 | 88 | from src.data.features import _get_body_transl_delta_pelv_infer 89 | def batch_to_smpl(batch, normalizer, mot_from='source'): 90 | # batch: dict 91 | # return: padded batch with lengths 92 | # simple concatenation of the body pose and body transl 93 | # trans_delta, body_pose_6d, global_orient_6d 94 | smpl_keys = ['body_transl', 'body_orient', 'body_pose'] 95 | lengths = batch[f'length_{mot_from}'] 96 | # T, zeros, trans, global, local 97 | tensor_list = [] 98 | trans = batch[f'body_transl_{mot_from}'] 99 | body_pose_6d = batch[f'body_pose_{mot_from}'] 100 | global_orient_6d = batch[f'body_orient_{mot_from}'] 101 | trans_delta = _get_body_transl_delta_pelv_infer(global_orient_6d, 102 | trans) # T, 3, [0] is [0,0,0] 103 | motion_smpl = torch.cat([trans_delta, body_pose_6d, 104 | global_orient_6d], dim=-1) 105 | motion_smpl = normalizer(motion_smpl) 106 | # In some motions, translation is always zeros. 
107 | # first 3 dimensions of the first frame should be zeros 108 | # B, T, D 109 | return motion_smpl, lengths 110 | from src.tmr.data.motionfix_loader import Normalizer 111 | 112 | @hydra.main(config_path="/depot/bera89/data/li5280/project/motionfix/configs", config_name="motionfix_eval") 113 | def main(cfg: DictConfig): 114 | model = load_tmr_model() 115 | dataloader = load_testloader(cfg) 116 | normalizer = Normalizer("/depot/bera89/data/li5280/project/motionfix/eval-deps/stats/humanml3d/amass_feats") 117 | 118 | # STEP 3: dataset to SMPL 119 | for batch in dataloader: 120 | source_smpl, lengths = batch_to_smpl(batch, normalizer) # B, T, 135 121 | masks = length_to_mask(lengths, device=source_smpl.device) 122 | in_batch = {'x': source_smpl, 'mask': masks} 123 | res = model.encode_motion(in_batch) 124 | print(res.size()) 125 | break 126 | # what we want: the original SMPL data and the lengths, 127 | # which are then converted to the model-specific representation 128 | 129 | if __name__ == '__main__': 130 | main() 131 | -------------------------------------------------------------------------------- /src/tmr/__init__.py: -------------------------------------------------------------------------------- 1 | from .actor import PositionalEncoding, ACTORStyleEncoder, ACTORStyleDecoder # noqa 2 | from .temos import TEMOS # noqa 3 | from .tmr import TMR # noqa 4 | from .text_encoder import TextToEmb # noqa 5 | -------------------------------------------------------------------------------- /src/tmr/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lzhyu/SimMotionEdit/384135a8707bc247fba23005d57cca7ca2d751ce/src/tmr/data/__init__.py -------------------------------------------------------------------------------- /src/tmr/load_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | from omegaconf import DictConfig 3 | import logging 4 | import hydra 5 | 6 | import os 7 | import json 8 | from omegaconf import DictConfig, OmegaConf 9 | 10 | 11 | def save_config(cfg: DictConfig) -> str: 12 | path = os.path.join(cfg.run_dir, "config.json") 13 | config = OmegaConf.to_container(cfg, resolve=True) 14 | with open(path, "w") as f: 15 | string = json.dumps(config, indent=4) 16 | f.write(string) 17 | return path 18 | 19 | 20 | def read_config(run_dir: str, return_json=False) -> DictConfig: 21 | path = os.path.join(run_dir, "config.json") 22 | with open(path, "r") as f: 23 | config = json.load(f) 24 | if return_json: 25 | return config 26 | cfg = OmegaConf.create(config) 27 | cfg.run_dir = run_dir 28 | return cfg 29 | 30 | logger = logging.getLogger(__name__) 31 | 32 | 33 | # split the lightning checkpoint into 34 | # separate state_dict modules for faster loading 35 | def extract_ckpt(run_dir, ckpt_name="last"): 36 | import torch 37 | 38 | ckpt_path = os.path.join(run_dir, f"logs/checkpoints/{ckpt_name}.ckpt") 39 | 40 | extracted_path = os.path.join(run_dir, f"{ckpt_name}_weights") 41 | os.makedirs(extracted_path, exist_ok=True) 42 | 43 | new_path_template = os.path.join(extracted_path, "{}.pt") 44 | ckpt_dict = torch.load(ckpt_path) 45 | state_dict = ckpt_dict["state_dict"] 46 | module_names = list(set([x.split(".")[0] for x in state_dict.keys()])) 47 | 48 | # should be ['motion_encoder', 'text_encoder', 'motion_decoder'] for example 49 | for module_name in module_names: 50 | path = new_path_template.format(module_name) 51 | sub_state_dict = { 52 | ".".join(x.split(".")[1:]): y.cpu() 53 | for
x, y in state_dict.items() 54 | if x.split(".")[0] == module_name 55 | } 56 | torch.save(sub_state_dict, path) 57 | 58 | 59 | def load_model(run_dir, **params): 60 | # Load last config 61 | cfg = read_config(run_dir) 62 | cfg.run_dir = run_dir 63 | return load_model_from_cfg(cfg, **params) 64 | 65 | 66 | def load_model_from_cfg(cfg, ckpt_name="last", device="cpu", eval_mode=True): 67 | import torch 68 | 69 | run_dir = cfg.run_dir 70 | 71 | from omegaconf import DictConfig, OmegaConf 72 | 73 | def replace_model_with_tmr(d): 74 | if isinstance(d, DictConfig): 75 | new_dict = {} 76 | for k, v in d.items(): 77 | new_key = k.replace('.model.', '.tmr.') 78 | new_value = replace_model_with_tmr(v) 79 | new_dict[new_key] = new_value 80 | return OmegaConf.create(new_dict) 81 | elif isinstance(d, dict): 82 | new_dict = {} 83 | for k, v in d.items(): 84 | new_key = k.replace('.model.', '.tmr.') 85 | new_value = replace_model_with_tmr(v) 86 | new_dict[new_key] = new_value 87 | return new_dict 88 | elif isinstance(d, list): 89 | return [replace_model_with_tmr(item) for item in d] 90 | elif isinstance(d, str): 91 | return d.replace('.model.', '.tmr.') 92 | else: 93 | return d 94 | model_conf = replace_model_with_tmr(cfg.model) 95 | 96 | # TODO: see what it is? 97 | # import ipdb;ipdb.set_trace() 98 | 99 | # model_conf = OmegaConf.to_yaml(model_conf_d) 100 | model = hydra.utils.instantiate(model_conf) 101 | 102 | # Loading modules one by one 103 | # motion_encoder / text_encoder / text_decoder 104 | pt_path = os.path.join(run_dir, f"{ckpt_name}_weights") 105 | 106 | if not os.path.exists(pt_path): 107 | print("The extracted model was not found. Splitting the checkpoint into submodules...") 108 | extract_ckpt(run_dir, ckpt_name) 109 | 110 | for fname in os.listdir(pt_path): 111 | module_name, ext = os.path.splitext(fname) 112 | if ext != ".pt": 113 | continue 114 | 115 | module = getattr(model, module_name, None) 116 | if module is None: 117 | continue 118 | 119 | module_path = os.path.join(pt_path, fname) 120 | state_dict = torch.load(module_path) 121 | module.load_state_dict(state_dict) 122 | 123 | print("Loading previous checkpoint done") 124 | model = model.to(device) 125 | if eval_mode: 126 | model = model.eval() 127 | return model 128 | 129 | 130 | @hydra.main(version_base=None, config_path="../configs", config_name="load_model") 131 | def hydra_load_model(cfg: DictConfig) -> None: 132 | run_dir = cfg.run_dir 133 | ckpt_name = cfg.ckpt 134 | device = cfg.device 135 | eval_mode = cfg.eval_mode 136 | return load_model(run_dir, ckpt_name=ckpt_name, device=device, eval_mode=eval_mode) 137 | 138 | 139 | if __name__ == "__main__": 140 | hydra_load_model() 141 | -------------------------------------------------------------------------------- /src/tmr/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | # For reference 6 | # https://stats.stackexchange.com/questions/7440/kl-divergence-between-two-univariate-gaussians 7 | # https://pytorch.org/docs/stable/_modules/torch/distributions/kl.html#kl_divergence 8 | class KLLoss: 9 | def __call__(self, q, p): 10 | mu_q, logvar_q = q 11 | mu_p, logvar_p = p 12 | 13 | log_var_ratio = logvar_q - logvar_p 14 | t1 = (mu_p - mu_q).pow(2) / logvar_p.exp() 15 | div = 0.5 * (log_var_ratio.exp() + t1 - 1 - log_var_ratio) 16 | return div.mean() 17 | 18 | def __repr__(self): 19 | return "KLLoss()" 20 | 21 | 22 | class InfoNCE_with_filtering: 23 | def __init__(self, temperature=0.7, threshold_selfsim=0.8): 24 |
self.temperature = temperature 25 | self.threshold_selfsim = threshold_selfsim 26 | 27 | def get_sim_matrix(self, x, y): 28 | x_logits = torch.nn.functional.normalize(x, dim=-1) 29 | y_logits = torch.nn.functional.normalize(y, dim=-1) 30 | sim_matrix = x_logits @ y_logits.T 31 | return sim_matrix 32 | 33 | def __call__(self, x, y, sent_emb=None): 34 | bs, device = len(x), x.device 35 | sim_matrix = self.get_sim_matrix(x, y) / self.temperature 36 | 37 | if sent_emb is not None and self.threshold_selfsim: 38 | # put the threshold value between -1 and 1 39 | real_threshold_selfsim = 2 * self.threshold_selfsim - 1 40 | # Filter out values that are too close: 41 | # mask them by putting -inf in the sim_matrix 42 | selfsim = sent_emb @ sent_emb.T 43 | selfsim_nodiag = selfsim - selfsim.diag().diag() 44 | idx = torch.where(selfsim_nodiag > real_threshold_selfsim) 45 | sim_matrix[idx] = -torch.inf 46 | 47 | labels = torch.arange(bs, device=device) 48 | 49 | total_loss = ( 50 | F.cross_entropy(sim_matrix, labels) + F.cross_entropy(sim_matrix.T, labels) 51 | ) / 2 52 | 53 | return total_loss 54 | 55 | def __repr__(self): 56 | return f"Contrastive(temp={self.temperature})" 57 | -------------------------------------------------------------------------------- /src/tmr/text_encoder.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch.nn as nn 3 | import torch 4 | from torch import Tensor 5 | from typing import Dict, List 6 | import torch.nn.functional as F 7 | 8 | 9 | class TextToEmb(nn.Module): 10 | def __init__( 11 | self, modelpath: str, mean_pooling: bool = False, device: str = "cpu" 12 | ) -> None: 13 | super().__init__() 14 | 15 | self.device = device 16 | from transformers import AutoTokenizer, AutoModel 17 | from transformers import logging 18 | 19 | logging.set_verbosity_error() 20 | 21 | # Tokenizer 22 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 23 | self.tokenizer = AutoTokenizer.from_pretrained(modelpath) 24 | 25 | # Text model 26 | self.text_model = AutoModel.from_pretrained(modelpath) 27 | # Then configure the model 28 | self.text_encoded_dim = self.text_model.config.hidden_size 29 | 30 | if mean_pooling: 31 | self.forward = self.forward_pooling 32 | 33 | # put it in eval mode by default 34 | self.eval() 35 | 36 | # Freeze the weights just in case 37 | for param in self.parameters(): 38 | param.requires_grad = False 39 | 40 | self.to(device) 41 | 42 | def train(self, mode: bool = True) -> nn.Module: 43 | # override it to be always false 44 | self.training = False 45 | for module in self.children(): 46 | module.train(False) 47 | return self 48 | 49 | @torch.no_grad() 50 | def forward(self, texts: List[str], device=None) -> Dict: 51 | device = device if device is not None else self.device 52 | 53 | squeeze = False 54 | if isinstance(texts, str): 55 | texts = [texts] 56 | squeeze = True 57 | 58 | encoded_inputs = self.tokenizer(texts, return_tensors="pt", padding=True) 59 | output = self.text_model(**encoded_inputs.to(device)) 60 | length = encoded_inputs.attention_mask.to(dtype=bool).sum(1) 61 | 62 | if squeeze: 63 | x_dict = {"x": output.last_hidden_state[0], "length": length[0]} 64 | else: 65 | x_dict = {"x": output.last_hidden_state, "length": length} 66 | return x_dict 67 | 68 | @torch.no_grad() 69 | def forward_pooling(self, texts: List[str], device=None) -> Tensor: 70 | device = device if device is not None else self.device 71 | 72 | squeeze = False 73 | if isinstance(texts, str): 74 | texts = [texts] 75 | squeeze = True 76 | 77 | #
From: https://huggingface.co/sentence-transformers/all-mpnet-base-v2 78 | encoded_inputs = self.tokenizer(texts, return_tensors="pt", padding=True) 79 | output = self.text_model(**encoded_inputs.to(device)) 80 | attention_mask = encoded_inputs["attention_mask"] 81 | 82 | # Mean Pooling - Take attention mask into account for correct averaging 83 | token_embeddings = output["last_hidden_state"] 84 | input_mask_expanded = ( 85 | attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() 86 | ) 87 | sentence_embeddings = torch.sum( 88 | token_embeddings * input_mask_expanded, 1 89 | ) / torch.clamp(input_mask_expanded.sum(1), min=1e-9) 90 | # Normalize embeddings 91 | sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1) 92 | if squeeze: 93 | sentence_embeddings = sentence_embeddings[0] 94 | return sentence_embeddings 95 | -------------------------------------------------------------------------------- /src/tools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lzhyu/SimMotionEdit/384135a8707bc247fba23005d57cca7ca2d751ce/src/tools/__init__.py -------------------------------------------------------------------------------- /src/tools/easyconvert.py: -------------------------------------------------------------------------------- 1 | import src.tools.geometry as geometry 2 | 3 | def nfeats_of(rottype): 4 | if rottype in ["rotvec", "axisangle"]: 5 | return 3 6 | elif rottype in ["rotquat", "quaternion"]: 7 | return 4 8 | elif rottype in ["rot6d", "6drot", "rotation6d"]: 9 | return 6 10 | elif rottype in ["rotmat"]: 11 | return 9 12 | else: 13 | raise TypeError("This rotation type doesn't have features.") 14 | 15 | 16 | def axis_angle_to(newtype, rotations): 17 | if newtype in ["matrix"]: 18 | rotations = geometry.axis_angle_to_matrix(rotations) 19 | return rotations 20 | elif newtype in ["rotmat"]: 21 | rotations = geometry.axis_angle_to_matrix(rotations) 22 | rotations = matrix_to("rotmat", rotations) 23 | return rotations 24 | elif newtype in ["rot6d", "6drot", "rotation6d"]: 25 | rotations = geometry.axis_angle_to_matrix(rotations) 26 | rotations = matrix_to("rot6d", rotations) 27 | return rotations 28 | elif newtype in ["rotquat", "quaternion"]: 29 | rotations = geometry.axis_angle_to_quaternion(rotations) 30 | return rotations 31 | elif newtype in ["rotvec", "axisangle"]: 32 | return rotations 33 | else: 34 | raise NotImplementedError 35 | 36 | 37 | def matrix_to(newtype, rotations): 38 | if newtype in ["matrix"]: 39 | return rotations 40 | if newtype in ["rotmat"]: 41 | rotations = rotations.reshape((*rotations.shape[:-2], 9)) 42 | return rotations 43 | elif newtype in ["rot6d", "6drot", "rotation6d"]: 44 | rotations = geometry.matrix_to_rotation_6d(rotations) 45 | return rotations 46 | elif newtype in ["rotquat", "quaternion"]: 47 | rotations = geometry.matrix_to_quaternion(rotations) 48 | return rotations 49 | elif newtype in ["rotvec", "axisangle"]: 50 | rotations = geometry.matrix_to_axis_angle(rotations) 51 | return rotations 52 | else: 53 | raise NotImplementedError 54 | 55 | 56 | def to_matrix(oldtype, rotations): 57 | if oldtype in ["matrix"]: 58 | return rotations 59 | if oldtype in ["rotmat"]: 60 | rotations = rotations.reshape((*rotations.shape[:-2], 3, 3)) 61 | return rotations 62 | elif oldtype in ["rot6d", "6drot", "rotation6d"]: 63 | rotations = geometry.rotation_6d_to_matrix(rotations) 64 | return rotations 65 | elif oldtype in ["rotquat", "quaternion"]: 66 | rotations =
geometry.quaternion_to_matrix(rotations) 67 | return rotations 68 | elif oldtype in ["rotvec", "axisangle"]: 69 | rotations = geometry.axis_angle_to_matrix(rotations) 70 | return rotations 71 | else: 72 | raise NotImplementedError 73 | -------------------------------------------------------------------------------- /src/tools/logging.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import tqdm 3 | 4 | 5 | class LevelsFilter(logging.Filter): 6 | def __init__(self, levels): 7 | self.levels = [getattr(logging, level) for level in levels] 8 | 9 | def filter(self, record): 10 | return record.levelno in self.levels 11 | 12 | 13 | class StreamToLogger(object): 14 | """ 15 | Fake file-like stream object that redirects writes to a logger instance. 16 | """ 17 | def __init__(self, logger, level): 18 | self.logger = logger 19 | self.level = level 20 | self.linebuf = '' 21 | 22 | def write(self, buf): 23 | for line in buf.rstrip().splitlines(): 24 | self.logger.log(self.level, line.rstrip()) 25 | 26 | def flush(self): 27 | pass 28 | 29 | 30 | class TqdmLoggingHandler(logging.Handler): 31 | def __init__(self, level=logging.NOTSET): 32 | super().__init__(level) 33 | 34 | def emit(self, record): 35 | try: 36 | msg = self.format(record) 37 | tqdm.tqdm.write(msg) 38 | self.flush() 39 | except Exception: 40 | self.handleError(record) 41 | -------------------------------------------------------------------------------- /src/tools/runid.py: -------------------------------------------------------------------------------- 1 | # 2 | """ 3 | runid util. 4 | Taken from wandb.sdk.lib.runid 5 | """ 6 | 7 | import shortuuid # type: ignore 8 | 9 | 10 | def generate_id() -> str: 11 | # ~3t run ids (36**8) 12 | run_gen = shortuuid.ShortUUID(alphabet=list("0123456789abcdefghijklmnopqrstuvwxyz")) 13 | return run_gen.random(8) -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lzhyu/SimMotionEdit/384135a8707bc247fba23005d57cca7ca2d751ce/src/utils/__init__.py -------------------------------------------------------------------------------- /src/utils/art_utils.py: -------------------------------------------------------------------------------- 1 | 2 | def rgba(c: str): 3 | from matplotlib import colors as mcolors 4 | return mcolors.to_rgba(c) 5 | 6 | def rgb(c: str): 7 | from matplotlib import colors as mcolors 8 | return mcolors.to_rgb(c) 9 | 10 | color_map = { 11 | 'source_motion': rgba('darkred'), 12 | 'source': rgba('darkred'), 13 | 'target_motion': rgba('olivedrab'), 14 | 'input': rgba('olivedrab'), 15 | 'target': rgba('olivedrab'), 16 | 'generation': rgba('purple'), 17 | 'generated': rgba('steelblue'), 18 | 'denoised': rgba('purple'), 19 | 'noised': rgba('darkgrey'), 20 | } -------------------------------------------------------------------------------- /src/utils/cherrypick.py: -------------------------------------------------------------------------------- 1 | test_subset_mfix = ['000044', '000050', '000161', '000164', '000169', '000175', 2 | '000320', '000331', '000488', '000496', '000513', '001247', 3 | '001288', '001297', '001345', '001361', '001362', '001372', 4 | '001386', '001221', '001745', '001745', '000604', '001876', 5 | '000763', '001856', '001577', '000469', '000555', '000555', 6 | '000605', '001907', '001687', '001437', '000819', '000161', 7 | '000973'] 8 | 9 | 
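# Illustrative usage sketch (not part of the repository): round-tripping rotations
# through the converters in src/tools/easyconvert.py shown above. It assumes the repo
# and its geometry helpers are importable; the tensor shapes are arbitrary.
import torch
import src.tools.easyconvert as easyconvert

aa = torch.randn(8, 22, 3) * 0.1                         # batch of axis-angle joint rotations
rot6d = easyconvert.axis_angle_to("rot6d", aa)           # (8, 22, 6) continuous 6D representation
mats = easyconvert.to_matrix("rot6d", rot6d)             # (8, 22, 3, 3) rotation matrices
aa_back = easyconvert.matrix_to("rotvec", mats)          # (8, 22, 3) back to axis-angle
assert easyconvert.nfeats_of("rot6d") == 6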
-------------------------------------------------------------------------------- /src/utils/eval_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | test_keyds = ["001895", "001893", "001892", "001888", "001848", "001881", 4 | "001827", "001845", "001858", "001812", "001798", "001755"] 5 | 6 | 7 | keyids_for_testing = ["001892", "001881", "001800", "001790", 8 | "001849", "001860", "001862", 9 | "001641", "001696", "001676", "001690", "001692", 10 | "001762", "001794", "001694", "001812", "001874", 11 | "001883"] 12 | 13 | def split_txt_into_multi_lines(input_str: str, line_length: int): 14 | words = input_str.split(' ') 15 | line_count = 0 16 | split_input = '' 17 | for word in words: 18 | line_count += 1 19 | line_count += len(word) 20 | if line_count > line_length: 21 | split_input += '\n' 22 | line_count = len(word) + 1 23 | split_input += word 24 | split_input += ' ' 25 | else: 26 | split_input += word 27 | split_input += ' ' 28 | 29 | return split_input 30 | 31 | def regroup_metrics(metrics): 32 | from src.info.joints import smplh_joints 33 | pose_names = smplh_joints[1:23] 34 | dico = {key: val.numpy() for key, val in metrics.items()} 35 | 36 | if "APE_pose" in dico: 37 | APE_pose = dico.pop("APE_pose") 38 | for name, ape in zip(pose_names, APE_pose): 39 | dico[f"APE_pose_{name}"] = ape 40 | 41 | if "APE_joints" in dico: 42 | APE_joints = dico.pop("APE_joints") 43 | for name, ape in zip(smplh_joints, APE_joints): 44 | dico[f"APE_joints_{name}"] = ape 45 | 46 | if "AVE_pose" in dico: 47 | AVE_pose = dico.pop("AVE_pose") 48 | for name, ave in zip(pose_names, AVE_pose): 49 | dico[f"AVE_pose_{name}"] = ave 50 | 51 | if "AVE_joints" in dico: 52 | AVE_joints = dico.pop("AVE_joints") 53 | for name, ave in zip(smplh_joints, AVE_joints): 54 | dico[f"AVE_joints_{name}"] = ave 55 | 56 | return dico 57 | 58 | 59 | def sanitize(dico): 60 | dico = {key: "{:.5f}".format(float(val)) for key, val in dico.items()} 61 | return dico 62 | 63 | def out2blender(dicto): 64 | blender_dic = {} 65 | blender_dic['trans'] = dicto['body_transl'] 66 | blender_dic['rots'] = torch.cat([dicto['body_orient'], 67 | dicto['body_pose']], dim=-1) 68 | return blender_dic -------------------------------------------------------------------------------- /src/utils/genutils.py: -------------------------------------------------------------------------------- 1 | from copy import copy 2 | import numpy as np 3 | import torch 4 | from torch import Tensor 5 | import logging 6 | import os 7 | import random 8 | from einops import rearrange 9 | from pathlib import Path 10 | 11 | def extract_data_path(full_path, directory_name="data"): 12 | """ 13 | Slices the given path up to and including the specified directory. 14 | 15 | Args: 16 | full_path (str): The full path as a string. 17 | directory_name (str): The directory to slice up to (included in the result). 18 | 19 | Returns: 20 | str: The sliced path as a string up to and including the specified directory. 21 | """ 22 | path = Path(full_path) 23 | subpath = Path() 24 | for part in path.parts: 25 | subpath /= part 26 | if part == directory_name: 27 | break 28 | 29 | return str(subpath) 30 | 31 | 32 | def freeze(model) -> None: 33 | r""" 34 | Freeze all params for inference.
35 | """ 36 | for param in model.parameters(): 37 | param.requires_grad = False 38 | 39 | model.eval() 40 | 41 | # A logger for this file 42 | log = logging.getLogger(__name__) 43 | def to_tensor(array): 44 | if torch.is_tensor(array): 45 | return array 46 | else: 47 | return torch.tensor(array) 48 | 49 | def DotDict(in_dict): 50 | if isinstance(in_dict, dotdict): 51 | return in_dict 52 | out_dict = copy(in_dict) 53 | for k,v in out_dict.items(): 54 | if isinstance(v,dict): 55 | out_dict[k] = DotDict(v) 56 | return dotdict(out_dict) 57 | 58 | 59 | def dict_to_device(tensor_dict, device): 60 | return {k: v.to(device) for k, v in tensor_dict.items()} 61 | 62 | class dotdict(dict): 63 | """dot.notation access to dictionary attributes""" 64 | __getattr__ = dict.get 65 | __setattr__ = dict.__setitem__ 66 | __delattr__ = dict.__delitem__ 67 | 68 | def cast_dict_to_tensors(d, device="cpu"): 69 | if isinstance(d, dict): 70 | return {k: cast_dict_to_tensors(v, device) for k, v in d.items()} 71 | elif isinstance(d, np.ndarray): 72 | return torch.from_numpy(d).float().to(device) 73 | elif isinstance(d, torch.Tensor): 74 | return d.to(device) 75 | else: 76 | return d -------------------------------------------------------------------------------- /src/utils/motionfix_utils.py: -------------------------------------------------------------------------------- 1 | test_subset_amt = ['001247','000038', '000091','000100', 2 | '000122','000133', '000177', '000227', 3 | '000230', '000246', '000289', '000362', 4 | '000522', '000544','000564', '000689', 5 | '000686','000847', '000917', '000947', 6 | '000931', '000899','001327', '001546', 7 | '000432'] 8 | -------------------------------------------------------------------------------- /tmr_evaluator/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lzhyu/SimMotionEdit/384135a8707bc247fba23005d57cca7ca2d751ce/tmr_evaluator/__init__.py -------------------------------------------------------------------------------- /tmr_evaluator/fid.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy import linalg 3 | 4 | 5 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): 6 | """Numpy implementation of the Frechet Distance. 7 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) 8 | and X_2 ~ N(mu_2, C_2) is 9 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). 10 | Stable version by Dougal J. Sutherland. 11 | Params: 12 | -- mu1 : Numpy array containing the activations of a layer of the 13 | inception net (like returned by the function 'get_predictions') 14 | for generated samples. 15 | -- mu2 : The sample mean over activations, precalculated on an 16 | representative dataset set. 17 | -- sigma1: The covariance matrix over activations for generated samples. 18 | -- sigma2: The covariance matrix over activations, precalculated on an 19 | representative dataset set. 20 | Returns: 21 | -- : The Frechet Distance. 
22 | """ 23 | 24 | mu1 = np.atleast_1d(mu1) 25 | mu2 = np.atleast_1d(mu2) 26 | 27 | sigma1 = np.atleast_2d(sigma1) 28 | sigma2 = np.atleast_2d(sigma2) 29 | 30 | assert mu1.shape == mu2.shape, \ 31 | 'Training and test mean vectors have different lengths' 32 | assert sigma1.shape == sigma2.shape, \ 33 | 'Training and test covariances have different dimensions' 34 | 35 | diff = mu1 - mu2 36 | 37 | # Product might be almost singular 38 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) 39 | if not np.isfinite(covmean).all(): 40 | msg = ('fid calculation produces singular product; ' 41 | 'adding %s to diagonal of cov estimates') % eps 42 | print(msg) 43 | offset = np.eye(sigma1.shape[0]) * eps 44 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 45 | 46 | # Numerical error might give slight imaginary component 47 | if np.iscomplexobj(covmean): 48 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 49 | m = np.max(np.abs(covmean.imag)) 50 | raise ValueError('Imaginary component {}'.format(m)) 51 | covmean = covmean.real 52 | 53 | tr_covmean = np.trace(covmean) 54 | 55 | return (diff.dot(diff) + np.trace(sigma1) + 56 | np.trace(sigma2) - 2 * tr_covmean) 57 | 58 | def calculate_activation_statistics(activations): 59 | """ 60 | Params: 61 | -- activation: num_samples x dim_feat 62 | Returns: 63 | -- mu: dim_feat 64 | -- sigma: dim_feat x dim_feat 65 | """ 66 | mu = np.mean(activations, axis=0) 67 | cov = np.cov(activations, rowvar=False) 68 | return mu, cov 69 | 70 | def evaluate_fid(gt_motion_embeddings, gen_motion_embeddings): 71 | gt_mu, gt_cov = calculate_activation_statistics(gt_motion_embeddings) 72 | mu, cov = calculate_activation_statistics(gen_motion_embeddings) 73 | fid = calculate_frechet_distance(gt_mu, gt_cov, mu, cov) 74 | print(f'FID: {fid:.4f}') 75 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lzhyu/SimMotionEdit/384135a8707bc247fba23005d57cca7ca2d751ce/utils/__init__.py -------------------------------------------------------------------------------- /utils/masking.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class BaseMask(object): 5 | @property 6 | def bool_matrix(self): 7 | """Return a bool (uint8) matrix with 1s to all places that should be 8 | kept.""" 9 | raise NotImplementedError() 10 | 11 | @property 12 | def float_matrix(self): 13 | """Return the bool matrix as a float to be used as a multiplicative 14 | mask for non softmax attentions.""" 15 | if not hasattr(self, "_float_matrix"): 16 | with torch.no_grad(): 17 | self._float_matrix = self.bool_matrix.float() 18 | return self._float_matrix 19 | 20 | @property 21 | def lengths(self): 22 | """If the matrix is of the following form 23 | 24 | 1 1 1 0 0 0 0 25 | 1 0 0 0 0 0 0 26 | 1 1 0 0 0 0 0 27 | 28 | then return it as a vector of integers 29 | 30 | 3 1 2. 
31 | """ 32 | if not hasattr(self, "_lengths"): 33 | with torch.no_grad(): 34 | lengths = self.bool_matrix.long().sum(dim=-1) 35 | # make sure that the mask starts with 1s and continues with 0s 36 | # this should be changed to something more efficient, however, 37 | # I chose simplicity over efficiency since the LengthMask class 38 | # will be used anyway (and the result is cached) 39 | m = self.bool_matrix.view(-1, self.shape[-1]) 40 | for i, l in enumerate(lengths.view(-1)): 41 | if not torch.all(m[i, :l]): 42 | raise ValueError("The mask is not a length mask") 43 | self._lengths = lengths 44 | return self._lengths 45 | 46 | @property 47 | def shape(self): 48 | """Return the shape of the boolean mask.""" 49 | return self.bool_matrix.shape 50 | 51 | @property 52 | def additive_matrix(self): 53 | """Return a float matrix to be added to an attention matrix before 54 | softmax.""" 55 | if not hasattr(self, "_additive_matrix"): 56 | with torch.no_grad(): 57 | self._additive_matrix = torch.log(self.bool_matrix.float()) 58 | return self._additive_matrix 59 | 60 | @property 61 | def additive_matrix_finite(self): 62 | """Same as additive_matrix but with -1e24 instead of infinity.""" 63 | if not hasattr(self, "_additive_matrix_finite"): 64 | with torch.no_grad(): 65 | self._additive_matrix_finite = ( 66 | (~self.bool_matrix).float() * (-1e24) 67 | ) 68 | return self._additive_matrix_finite 69 | 70 | @property 71 | def all_ones(self): 72 | """Return true if the mask is all ones.""" 73 | if not hasattr(self, "_all_ones"): 74 | with torch.no_grad(): 75 | self._all_ones = torch.all(self.bool_matrix) 76 | return self._all_ones 77 | 78 | @property 79 | def lower_triangular(self): 80 | """Return true if the attention is a triangular causal mask.""" 81 | if not hasattr(self, "_lower_triangular"): 82 | self._lower_triangular = False 83 | with torch.no_grad(): 84 | try: 85 | lengths = self.lengths 86 | if len(lengths.shape) == 1: 87 | target = torch.arange( 88 | 1, 89 | len(lengths)+1, 90 | device=lengths.device 91 | ) 92 | self._lower_triangular = torch.all(lengths == target) 93 | except ValueError: 94 | pass 95 | return self._lower_triangular 96 | 97 | class LengthMask(BaseMask): 98 | """Provide a BaseMask interface for lengths. Mostly to be used with 99 | sequences of different lengths. 100 | 101 | Arguments 102 | --------- 103 | lengths: The lengths as a PyTorch long tensor 104 | max_len: The maximum length for the mask (defaults to lengths.max()) 105 | device: The device to be used for creating the masks (defaults to 106 | lengths.device) 107 | """ 108 | def __init__(self, lengths, max_len=None, device=None): 109 | self._device = device or lengths.device 110 | with torch.no_grad(): 111 | self._lengths = lengths.clone().to(self._device) 112 | self._max_len = max_len or self._lengths.max() 113 | 114 | self._bool_matrix = None 115 | self._all_ones = torch.all(self._lengths == self._max_len).item() 116 | 117 | @property 118 | def bool_matrix(self): 119 | if self._bool_matrix is None: 120 | with torch.no_grad(): 121 | indices = torch.arange(self._max_len, device=self._device) 122 | self._bool_matrix = ( 123 | indices.view(1, -1) < self._lengths.view(-1, 1) 124 | ) 125 | return self._bool_matrix 126 | 127 | --------------------------------------------------------------------------------
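# Illustrative usage sketch (not part of the repository): building a LengthMask for a
# padded batch of three sequences and inspecting the derived matrices used by the
# attention code. The import path assumes the repo root is on PYTHONPATH.
import torch
from utils.masking import LengthMask

lengths = torch.tensor([3, 1, 2])
mask = LengthMask(lengths, max_len=4)
print(mask.bool_matrix)               # True where a position is kept, False where padded
print(mask.lengths)                   # tensor([3, 1, 2]), recovered from the boolean matrix
print(mask.additive_matrix_finite)    # 0 for kept positions, -1e24 for masked ones (softmax-safe)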