├── configs ├── data │ ├── babel.yaml │ ├── kitml.yaml │ ├── humanml3d.yaml │ ├── babel_actions_120.yaml │ ├── babel_actions_60.yaml │ ├── humanml3d_kitml.yaml │ ├── humanml3d_kitml_babel.yaml │ ├── motion_loader │ │ ├── guoh3dfeats.yaml │ │ └── _base.yaml │ ├── _base_augmented.yaml │ └── _base.yaml ├── extract.yaml ├── load_model.yaml ├── model │ ├── tmr_text_averaging.yaml │ ├── tmr_hn.yaml │ ├── tmr_text_averaging_hn.yaml │ ├── tmr.yaml │ └── temos.yaml ├── train_hn_with_augmentation.yaml ├── encode_text.yaml ├── text_motion_sim.yaml ├── encode_motion.yaml ├── debug │ ├── profiler.yaml │ └── train.yaml ├── train_hn.yaml ├── renderer │ └── matplotlib.yaml ├── compute_guoh3dfeats.yaml ├── render.yaml ├── text_embeddings.yaml ├── text_embeddings_with_augmentation.yaml ├── defaults.yaml ├── encode_dataset.yaml ├── motion_stats.yaml ├── combine_datasets.yaml ├── hydra │ ├── hydra_logging │ │ └── tqdm.yaml │ └── job_logging │ │ └── tqdm.yaml ├── retrieval.yaml ├── text_dataset_sim.yaml ├── train.yaml ├── retrieval_action_multi_labels.yaml ├── train_with_augmentation.yaml └── trainer.yaml ├── src ├── callback │ ├── __init__.py │ ├── tqdmbar.py │ └── progress.py ├── guofeats │ ├── __init__.py │ ├── skeleton_example_h3d.npy │ └── paramUtil.py ├── model │ ├── __init__.py │ ├── tmr_text_averaging.py │ ├── text_encoder.py │ ├── losses.py │ ├── actor.py │ ├── tmr.py │ └── temos.py ├── prepare.py ├── logging.py ├── config.py ├── data │ ├── motion.py │ ├── collate.py │ ├── text_motion_multi_labels.py │ ├── text_motion.py │ ├── augmented_text_motion.py │ └── text.py ├── joints.py ├── load.py ├── logger │ ├── csv.py │ └── csv_fabric.py ├── rifke.py └── renderer │ └── matplotlib.py ├── stats ├── babel │ └── guoh3dfeats │ │ ├── std.pt │ │ └── mean.pt ├── kitml │ └── guoh3dfeats │ │ ├── std.pt │ │ └── mean.pt ├── humanml3d │ └── guoh3dfeats │ │ ├── mean.pt │ │ └── std.pt ├── humanml3d_kitml │ └── guoh3dfeats │ │ ├── std.pt │ │ └── mean.pt ├── babel_actions_120 │ └── guoh3dfeats │ │ ├── mean.pt │ │ └── std.pt ├── babel_actions_60 │ └── guoh3dfeats │ │ ├── mean.pt │ │ └── std.pt └── humanml3d_kitml_babel │ └── guoh3dfeats │ ├── mean.pt │ └── std.pt ├── LICENSE.md ├── datasets └── annotations │ ├── kitml │ └── splits │ │ ├── test_tiny.txt │ │ ├── val_tiny.txt │ │ ├── train_tiny.txt │ │ ├── nsim_test.txt │ │ ├── val.txt │ │ └── test.txt │ └── humanml3d │ └── splits │ ├── test_tiny.txt │ ├── train_tiny.txt │ ├── val_tiny.txt │ └── nsim_test.txt ├── extract.py ├── prepare ├── download_pretrain_models.sh ├── motion_stats.py ├── tools.py ├── text_embeddings.py ├── compute_guoh3dfeats.py └── combine_datasets.py ├── requirements.txt ├── demo ├── load.py └── model.py ├── encode_text.py ├── encode_motion.py ├── text_motion_sim.py ├── train.py ├── encode_dataset.py ├── .gitignore ├── retrieval_action.py ├── retrieval_action_multi_labels.py ├── DATASETS.md ├── retrieval.py └── README.md /configs/data/babel.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _base 3 | - _self_ 4 | -------------------------------------------------------------------------------- /configs/data/kitml.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _base 3 | - _self_ 4 | -------------------------------------------------------------------------------- /configs/data/humanml3d.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _base 3 | - _self_ 4 | 
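# Note: each file in configs/data only layers the shared "_base" config; the dataset itself is
# selected by name through Hydra's defaults list (configs/train.yaml defaults to "data: humanml3d"),
# and that choice is reused to resolve the annotation path, the normalization-stats folder and the
# run directory. Illustrative launch commands (a sketch, assuming the scripts are run from the repo root):
#   python train.py data=humanml3d   # default dataset choice
#   python train.py data=kitml       # switch dataset; run_dir follows the Hydra choices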
-------------------------------------------------------------------------------- /configs/data/babel_actions_120.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _base 3 | - _self_ 4 | -------------------------------------------------------------------------------- /configs/data/babel_actions_60.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _base 3 | - _self_ 4 | -------------------------------------------------------------------------------- /configs/data/humanml3d_kitml.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _base 3 | - _self_ 4 | -------------------------------------------------------------------------------- /configs/data/humanml3d_kitml_babel.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _base 3 | - _self_ 4 | -------------------------------------------------------------------------------- /configs/extract.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - defaults 3 | - _self_ 4 | 5 | run_dir: ??? 6 | ckpt: last 7 | -------------------------------------------------------------------------------- /src/callback/__init__.py: -------------------------------------------------------------------------------- 1 | # from .render import RenderCallback 2 | from .progress import ProgressLogger 3 | -------------------------------------------------------------------------------- /src/guofeats/__init__.py: -------------------------------------------------------------------------------- 1 | from .motion_representation import joints_to_guofeats, guofeats_to_joints # noqa 2 | -------------------------------------------------------------------------------- /stats/babel/guoh3dfeats/std.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leorebensabath/TMRPlusPlus/HEAD/stats/babel/guoh3dfeats/std.pt -------------------------------------------------------------------------------- /stats/kitml/guoh3dfeats/std.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leorebensabath/TMRPlusPlus/HEAD/stats/kitml/guoh3dfeats/std.pt -------------------------------------------------------------------------------- /stats/babel/guoh3dfeats/mean.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leorebensabath/TMRPlusPlus/HEAD/stats/babel/guoh3dfeats/mean.pt -------------------------------------------------------------------------------- /stats/kitml/guoh3dfeats/mean.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leorebensabath/TMRPlusPlus/HEAD/stats/kitml/guoh3dfeats/mean.pt -------------------------------------------------------------------------------- /stats/humanml3d/guoh3dfeats/mean.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leorebensabath/TMRPlusPlus/HEAD/stats/humanml3d/guoh3dfeats/mean.pt -------------------------------------------------------------------------------- /stats/humanml3d/guoh3dfeats/std.pt: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/leorebensabath/TMRPlusPlus/HEAD/stats/humanml3d/guoh3dfeats/std.pt -------------------------------------------------------------------------------- /src/guofeats/skeleton_example_h3d.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leorebensabath/TMRPlusPlus/HEAD/src/guofeats/skeleton_example_h3d.npy -------------------------------------------------------------------------------- /configs/load_model.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - defaults 3 | - _self_ 4 | 5 | run_dir: ??? 6 | ckpt: last 7 | device: cuda 8 | eval_mode: true 9 | -------------------------------------------------------------------------------- /configs/model/tmr_text_averaging.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - tmr 3 | - _self_ 4 | 5 | _target_: src.model.tmr_text_averaging.TMRTextAveraging 6 | -------------------------------------------------------------------------------- /stats/humanml3d_kitml/guoh3dfeats/std.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leorebensabath/TMRPlusPlus/HEAD/stats/humanml3d_kitml/guoh3dfeats/std.pt -------------------------------------------------------------------------------- /configs/train_hn_with_augmentation.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - train_with_augmentation 3 | - override model: tmr_text_averaging_hn 4 | - _self_ 5 | -------------------------------------------------------------------------------- /stats/babel_actions_120/guoh3dfeats/mean.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leorebensabath/TMRPlusPlus/HEAD/stats/babel_actions_120/guoh3dfeats/mean.pt -------------------------------------------------------------------------------- /stats/babel_actions_120/guoh3dfeats/std.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leorebensabath/TMRPlusPlus/HEAD/stats/babel_actions_120/guoh3dfeats/std.pt -------------------------------------------------------------------------------- /stats/babel_actions_60/guoh3dfeats/mean.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leorebensabath/TMRPlusPlus/HEAD/stats/babel_actions_60/guoh3dfeats/mean.pt -------------------------------------------------------------------------------- /stats/babel_actions_60/guoh3dfeats/std.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leorebensabath/TMRPlusPlus/HEAD/stats/babel_actions_60/guoh3dfeats/std.pt -------------------------------------------------------------------------------- /stats/humanml3d_kitml/guoh3dfeats/mean.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leorebensabath/TMRPlusPlus/HEAD/stats/humanml3d_kitml/guoh3dfeats/mean.pt -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Léore Bensabath 4 | 5 | TMR LICENCE 6 | 
-------------------------------------------------------------------------------- /configs/encode_text.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - defaults 3 | - _self_ 4 | 5 | run_dir: ??? 6 | text: ??? 7 | 8 | ckpt_name: last 9 | device: cuda 10 | -------------------------------------------------------------------------------- /datasets/annotations/kitml/splits/test_tiny.txt: -------------------------------------------------------------------------------- 1 | 00004 2 | 00010 3 | 00019 4 | 00033 5 | 00035 6 | 00036 7 | 00044 8 | 00053 9 | 00059 10 | 00074 11 | -------------------------------------------------------------------------------- /datasets/annotations/kitml/splits/val_tiny.txt: -------------------------------------------------------------------------------- 1 | 00006 2 | 00015 3 | 00025 4 | 00052 5 | 00068 6 | 00092 7 | 00093 8 | 00126 9 | 00143 10 | 00170 11 | -------------------------------------------------------------------------------- /stats/humanml3d_kitml_babel/guoh3dfeats/mean.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leorebensabath/TMRPlusPlus/HEAD/stats/humanml3d_kitml_babel/guoh3dfeats/mean.pt -------------------------------------------------------------------------------- /stats/humanml3d_kitml_babel/guoh3dfeats/std.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leorebensabath/TMRPlusPlus/HEAD/stats/humanml3d_kitml_babel/guoh3dfeats/std.pt -------------------------------------------------------------------------------- /datasets/annotations/kitml/splits/train_tiny.txt: -------------------------------------------------------------------------------- 1 | 00001 2 | 00002 3 | 00003 4 | 00005 5 | 00007 6 | 00008 7 | 00009 8 | 00011 9 | 00013 10 | 00014 11 | -------------------------------------------------------------------------------- /configs/data/motion_loader/guoh3dfeats.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _base 3 | - _self_ 4 | 5 | base_dir: datasets/motions/guoh3dfeats 6 | fps: 20.0 7 | nfeats: 263 8 | -------------------------------------------------------------------------------- /datasets/annotations/humanml3d/splits/test_tiny.txt: -------------------------------------------------------------------------------- 1 | 000000 2 | 000019 3 | 000021 4 | 000022 5 | 000026 6 | 000048 7 | 000055 8 | 000063 9 | 000066 10 | 000067 11 | -------------------------------------------------------------------------------- /datasets/annotations/humanml3d/splits/train_tiny.txt: -------------------------------------------------------------------------------- 1 | 000001 2 | 000002 3 | 000003 4 | 000004 5 | 000005 6 | 000006 7 | 000007 8 | 000008 9 | 000009 10 | 000010 11 | -------------------------------------------------------------------------------- /datasets/annotations/humanml3d/splits/val_tiny.txt: -------------------------------------------------------------------------------- 1 | 012698 2 | 012808 3 | 008646 4 | 013022 5 | 003172 6 | 008859 7 | 005095 8 | 012044 9 | 002345 10 | 008039 11 | -------------------------------------------------------------------------------- /configs/text_motion_sim.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - defaults 3 | - _self_ 4 | 5 | run_dir: ??? 6 | npy: ??? 7 | feats: false 8 | text: ??? 
9 | 10 | ckpt_name: last 11 | device: cuda 12 | -------------------------------------------------------------------------------- /configs/encode_motion.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - defaults 3 | - _self_ 4 | 5 | run_dir: ??? 6 | npy: ??? 7 | 8 | start: !!null 9 | end: !!null 10 | 11 | ckpt_name: last 12 | device: cuda 13 | -------------------------------------------------------------------------------- /configs/debug/profiler.yaml: -------------------------------------------------------------------------------- 1 | # @package __global__ 2 | 3 | run_dir: debug/ 4 | 5 | trainer: 6 | max_epochs: 1 7 | check_val_every_n_epoch: 1 8 | callbacks: null 9 | profiler: simple 10 | -------------------------------------------------------------------------------- /configs/train_hn.yaml: -------------------------------------------------------------------------------- 1 | run_dir: outputs/${hydra:runtime.choices.model}_${hydra:runtime.choices.data}_${hydra:runtime.choices.data/motion_loader} 2 | 3 | defaults: 4 | - train 5 | - override /model: tmr_hn 6 | - _self_ 7 | -------------------------------------------------------------------------------- /src/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .actor import PositionalEncoding, ACTORStyleEncoder, ACTORStyleDecoder # noqa 2 | from .temos import TEMOS # noqa 3 | from .tmr import TMR # noqa 4 | from .text_encoder import TextToEmb # noqa 5 | -------------------------------------------------------------------------------- /configs/renderer/matplotlib.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.renderer.matplotlib.MatplotlibRender 2 | 3 | jointstype: "guoh3djoints" 4 | fps: 20.0 5 | colors: ['black', 'magenta', 'red', 'green', 'blue'] 6 | figsize: 4 7 | canonicalize: true 8 | -------------------------------------------------------------------------------- /configs/compute_guoh3dfeats.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - defaults 3 | - _self_ 4 | 5 | base_folder: datasets/motions/pose_data 6 | output_folder: datasets/motions/guoh3dfeats 7 | 8 | force_redo: false # true to recompute the features 9 | -------------------------------------------------------------------------------- /configs/model/tmr_hn.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - tmr 3 | - _self_ 4 | 5 | contrastive_loss: 6 | _target_: src.model.losses.HN_InfoNCE_with_filtering 7 | temperature: 0.1 8 | threshold_selfsim: 0.80 9 | alpha: 0.999 10 | beta: 0.5 11 | -------------------------------------------------------------------------------- /configs/render.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - renderer: matplotlib 3 | - defaults 4 | - _self_ 5 | 6 | npy_path: ??? 
7 | title: "" 8 | 9 | swap_axis: false 10 | guofeats: false 11 | rifkefeats: false 12 | 13 | renderer: 14 | canonicalize: true 15 | -------------------------------------------------------------------------------- /configs/text_embeddings.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - data: humanml3d 3 | - defaults 4 | - _self_ 5 | 6 | device: cuda 7 | 8 | annotations_filename: annotations.json 9 | output_folder_name_token: token_embeddings 10 | output_folder_name_sent: sent_embeddings 11 | -------------------------------------------------------------------------------- /configs/data/motion_loader/_base.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.data.motion.AMASSMotionLoader 2 | 3 | base_dir: ??? 4 | 5 | normalizer: 6 | _target_: src.data.motion.Normalizer 7 | base_dir: stats/${hydra:runtime.choices.data}/${hydra:runtime.choices.data/motion_loader} 8 | eps: 1e-12 9 | -------------------------------------------------------------------------------- /configs/text_embeddings_with_augmentation.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - data: humanml3d 3 | - defaults 4 | - _self_ 5 | 6 | device: cuda 7 | 8 | annotations_filename: annotations_all.json 9 | output_folder_name_token: token_embeddings_all 10 | output_folder_name_sent: sent_embeddings_all 11 | -------------------------------------------------------------------------------- /configs/defaults.yaml: -------------------------------------------------------------------------------- 1 | # @package __global__ 2 | 3 | run_dir: logs 4 | 5 | hydra: 6 | run: 7 | dir: ${run_dir} 8 | 9 | seed: 1234 10 | logger_level: INFO 11 | 12 | 13 | defaults: 14 | - _self_ 15 | - override hydra/job_logging: tqdm 16 | - override hydra/hydra_logging: tqdm 17 | -------------------------------------------------------------------------------- /configs/encode_dataset.yaml: -------------------------------------------------------------------------------- 1 | dataloader: 2 | _target_: torch.utils.data.DataLoader 3 | batch_size: 32 4 | num_workers: 8 5 | shuffle: true 6 | 7 | defaults: 8 | - data: humanml3d 9 | - defaults 10 | - _self_ 11 | 12 | run_dir: ??? 
13 | 14 | ckpt_name: last 15 | device: cuda 16 | -------------------------------------------------------------------------------- /configs/motion_stats.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - data: humanml3d 3 | - defaults 4 | - _self_ 5 | 6 | data: 7 | preload: false 8 | motion_loader: 9 | normalizer: 10 | disable: true 11 | text_to_token_emb: 12 | disable: true 13 | text_to_sent_emb: 14 | disable: true 15 | -------------------------------------------------------------------------------- /configs/debug/train.yaml: -------------------------------------------------------------------------------- 1 | # @package __global__ 2 | 3 | run_dir: debug/ 4 | 5 | data: 6 | tiny: true 7 | preload: false 8 | 9 | dataloader: 10 | num_workers: 0 11 | shuffle: false 12 | 13 | trainer: 14 | enable_model_summary: false 15 | max_epochs: 3 16 | check_val_every_n_epoch: 1 17 | -------------------------------------------------------------------------------- /configs/model/tmr_text_averaging_hn.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - tmr_hn 3 | - _self_ 4 | 5 | _target_: src.model.tmr_text_averaging.TMRTextAveraging 6 | 7 | contrastive_loss: 8 | _target_: src.model.losses.HN_InfoNCE_with_filtering 9 | temperature: 0.1 10 | threshold_selfsim: 0.80 11 | alpha: 0.999 12 | beta: 0.5 13 | -------------------------------------------------------------------------------- /configs/combine_datasets.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - defaults 3 | - _self_ 4 | 5 | annotations_path: datasets/annotations 6 | 7 | datasets: 8 | - humanml3d 9 | - kitml 10 | 11 | test_sets: 12 | - humanml3d 13 | - kitml 14 | 15 | filter_babel_seg: False 16 | 17 | split_suffix: '' 18 | 19 | min_duration: null 20 | max_duration: null 21 | -------------------------------------------------------------------------------- /configs/hydra/hydra_logging/tqdm.yaml: -------------------------------------------------------------------------------- 1 | version: 1 2 | 3 | formatters: 4 | verysimple: 5 | format: '%(message)s' 6 | 7 | handlers: 8 | console: 9 | class: src.logging.TqdmLoggingHandler 10 | formatter: verysimple 11 | 12 | root: 13 | level: ${logger_level} 14 | handlers: [console] 15 | 16 | 17 | disable_existing_loggers: false 18 | -------------------------------------------------------------------------------- /configs/model/tmr.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - temos 3 | - _self_ 4 | 5 | _target_: src.model.TMR 6 | 7 | lmd: 8 | recons: 1.0 9 | latent: 1.0e-5 10 | kl: 1.0e-5 11 | contrastive: 0.1 12 | 13 | lr: 1e-4 14 | threshold_selfsim_metrics: 0.95 15 | 16 | contrastive_loss: 17 | _target_: src.model.losses.InfoNCE_with_filtering 18 | temperature: 0.1 19 | threshold_selfsim: 0.80 20 | -------------------------------------------------------------------------------- /configs/retrieval.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - defaults 3 | - _self_ 4 | - data: humanml3d 5 | 6 | device: cuda 7 | 8 | run_dir: ??? 
9 | save_file_name: contrastive_metrics_${hydra:runtime.choices.data} 10 | protocol: all # (is all 4), normal (a), threshold (b), nsim (c), guo (d) 11 | threshold: 0.95 # threashold to compute (b) 12 | 13 | ckpt: last 14 | batch_size: 256 15 | 16 | split: test 17 | -------------------------------------------------------------------------------- /configs/text_dataset_sim.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - data: humanml3d 3 | - defaults 4 | - _self_ 5 | 6 | run_dir: outputs/tmr_humanml3d 7 | feats: false 8 | text: ??? 9 | split: train 10 | 11 | ckpt_name: last 12 | device: cuda 13 | 14 | data: 15 | preload: false 16 | text_to_token_emb: 17 | preload: false 18 | disable: true 19 | text_to_sent_emb: 20 | preload: false 21 | disable: true 22 | -------------------------------------------------------------------------------- /configs/train.yaml: -------------------------------------------------------------------------------- 1 | ckpt: last 2 | resume_dir: null 3 | ckpt_path: null 4 | 5 | run_dir: outputs/${hydra:runtime.choices.model}_${hydra:runtime.choices.data}_${hydra:runtime.choices.data/motion_loader} 6 | 7 | dataloader: 8 | _target_: torch.utils.data.DataLoader 9 | batch_size: 32 10 | num_workers: 8 11 | 12 | defaults: 13 | - data: humanml3d 14 | - data_val: null 15 | - model: tmr 16 | - trainer 17 | - defaults 18 | - _self_ 19 | -------------------------------------------------------------------------------- /src/prepare.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch 3 | import warnings 4 | 5 | logger = logging.getLogger("torch.distributed.nn.jit.instantiator") 6 | logger.setLevel(logging.ERROR) 7 | 8 | 9 | warnings.filterwarnings( 10 | "ignore", "The PyTorch API of nested tensors is in prototype stage*" 11 | ) 12 | 13 | warnings.filterwarnings("ignore", "Converting mask without torch.bool dtype to bool*") 14 | 15 | torch.set_float32_matmul_precision("high") 16 | -------------------------------------------------------------------------------- /configs/data/_base_augmented.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - motion_loader: guoh3dfeats 3 | - _base 4 | - _self_ 5 | 6 | _target_: src.data.augmented_text_motion.AugmentedTextMotionDataset 7 | 8 | paraphrase_filename: annotations_paraphrased.json 9 | summary_filename: annotations_summarized.json 10 | paraphrase_prob: 0.2 11 | summary_prob: 0.2 12 | averaging_prob: 0.4 13 | text_sampling_nbr: null 14 | 15 | text_to_token_emb: 16 | name: token_embeddings_all # TODO 17 | 18 | text_to_sent_emb: 19 | name: sent_embeddings_all 20 | -------------------------------------------------------------------------------- /configs/retrieval_action_multi_labels.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - data: babel_actions_120 3 | - defaults 4 | - _self_ 5 | 6 | device: cuda 7 | 8 | run_dir: ??? 
9 | save_file_name: contrastive_metrics_${hydra:runtime.choices.data} 10 | 11 | ckpt: last 12 | batch_size: 256 13 | 14 | split: test 15 | 16 | data: 17 | _target_: src.data.text_motion_multi_labels.TextMotionMultiLabelsDataset 18 | tiny: False 19 | 20 | text_to_token_emb: 21 | name: token_embeddings 22 | 23 | text_to_sent_emb: 24 | name: sent_embeddings 25 | -------------------------------------------------------------------------------- /extract.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import hydra 3 | from omegaconf import DictConfig 4 | 5 | logger = logging.getLogger(__name__) 6 | 7 | 8 | @hydra.main(config_path="configs", config_name="extract", version_base="1.3") 9 | def extract(cfg: DictConfig): 10 | run_dir = cfg.run_dir 11 | ckpt = cfg.ckpt 12 | 13 | from src.load import extract_ckpt 14 | 15 | logger.info("Extracting the checkpoint...") 16 | extract_ckpt(run_dir, ckpt_name=ckpt) 17 | logger.info("Done") 18 | 19 | 20 | if __name__ == "__main__": 21 | extract() 22 | -------------------------------------------------------------------------------- /src/logging.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import tqdm 3 | 4 | 5 | # from https://stackoverflow.com/questions/38543506/change-logging-print-function-to-tqdm-write-so-logging-doesnt-interfere-wit 6 | class TqdmLoggingHandler(logging.Handler): 7 | def __init__(self, level=logging.NOTSET): 8 | super().__init__(level) 9 | 10 | def emit(self, record): 11 | try: 12 | msg = self.format(record) 13 | tqdm.tqdm.write(msg) 14 | self.flush() 15 | except Exception: 16 | self.handleError(record) 17 | -------------------------------------------------------------------------------- /prepare/download_pretrain_models.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | echo -e "The pretrained models will be stored in the 'models' folder\n" 3 | mkdir -p models 4 | python -m gdown.cli "https://drive.google.com/uc?id=1n6kRb-d2gKsk8EXfFULFIpaUKYcnaYmm" 5 | 6 | echo -e "Please check that the md5sum is: 7b6d8814f9c1ca972f62852ebb6c7a6f" 7 | echo -e "+ md5sum tmr_models.tgz" 8 | md5sum tmr_models.tgz 9 | 10 | echo -e "If it is not, please rerun this script" 11 | 12 | sleep 5 13 | tar xfzv tmr_models.tgz 14 | 15 | echo -e "Cleaning\n" 16 | rm tmr_models.tgz 17 | 18 | echo -e "Downloading done!"
19 | -------------------------------------------------------------------------------- /configs/data/_base.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - motion_loader: guoh3dfeats 3 | - _self_ 4 | 5 | _target_: src.data.text_motion.TextMotionDataset 6 | 7 | path: datasets/annotations/${hydra:runtime.choices.data} 8 | 9 | text_to_token_emb: 10 | _target_: src.data.text.TokenEmbeddings 11 | path: datasets/annotations/${hydra:runtime.choices.data} 12 | modelname: distilbert-base-uncased 13 | modelpath: null 14 | preload: true 15 | 16 | text_to_sent_emb: 17 | _target_: src.data.text.SentenceEmbeddings 18 | path: datasets/annotations/${hydra:runtime.choices.data} 19 | modelname: sentence-transformers/all-mpnet-base-v2 20 | modelpath: null 21 | preload: true 22 | 23 | preload: true 24 | -------------------------------------------------------------------------------- /src/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from omegaconf import DictConfig, OmegaConf 4 | 5 | 6 | def save_config(cfg: DictConfig) -> str: 7 | path = os.path.join(cfg.run_dir, "config.json") 8 | config = OmegaConf.to_container(cfg, resolve=True) 9 | with open(path, "w") as f: 10 | string = json.dumps(config, indent=4) 11 | f.write(string) 12 | return path 13 | 14 | 15 | def read_config(run_dir: str, return_json=False) -> DictConfig: 16 | path = os.path.join(run_dir, "config.json") 17 | with open(path, "r") as f: 18 | config = json.load(f) 19 | if return_json: 20 | return config 21 | cfg = OmegaConf.create(config) 22 | cfg.run_dir = run_dir 23 | return cfg 24 | -------------------------------------------------------------------------------- /configs/train_with_augmentation.yaml: -------------------------------------------------------------------------------- 1 | run_dir: outputs/${hydra:runtime.choices.model}_${hydra:runtime.choices.data}_augmented_${hydra:runtime.choices.data/motion_loader} 2 | 3 | defaults: 4 | - train 5 | - data_val: null 6 | - override data: humanml3d_kitml 7 | - override model: tmr_text_averaging 8 | - _self_ 9 | 10 | data: 11 | _target_: src.data.augmented_text_motion.AugmentedTextMotionDataset 12 | paraphrase_filename: annotations_paraphrases.json 13 | summary_filename: annotations_actions.json 14 | paraphrase_prob: 0.2 15 | summary_prob: 0.1 16 | averaging_prob: 0.3 17 | preload: True 18 | text_sampling_nbr: null 19 | 20 | text_to_token_emb: 21 | name: token_embeddings_all 22 | 23 | text_to_sent_emb: 24 | name: sent_embeddings_all 25 | 26 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | aiohttp==3.8.5 2 | aiosignal==1.3.1 3 | antlr4-python3-runtime==4.9.3 4 | async-timeout==4.0.3 5 | attrs==23.1.0 6 | certifi==2023.7.22 7 | charset-normalizer==2.1.1 8 | cmake==3.25.0 9 | colorlog==6.7.0 10 | einops==0.6.1 11 | filelock==3.9.0 12 | frozenlist==1.4.0 13 | fsspec==2023.9.0 14 | hydra-colorlog==1.2.0 15 | hydra-core==1.3.2 16 | idna==3.4 17 | Jinja2==3.1.2 18 | lightning-utilities==0.9.0 19 | lit==15.0.7 20 | MarkupSafe==2.1.2 21 | mpmath==1.3.0 22 | multidict==6.0.4 23 | networkx==3.0 24 | numpy==1.24.1 25 | omegaconf==2.3.0 26 | orjson==3.9.7 27 | packaging==23.1 28 | Pillow==9.3.0 29 | pytorch-lightning==2.0.9 30 | PyYAML==6.0.1 31 | requests==2.31.0 32 | scipy==1.11.2 33 | sympy==1.11.1 34 | torchmetrics==1.1.2 35 | 
transformers==4.41.2 36 | tqdm==4.66.1 37 | triton==2.0.0 38 | typing_extensions==4.4.0 39 | urllib3==1.26.13 40 | yarl==1.9.2 41 | -------------------------------------------------------------------------------- /prepare/motion_stats.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import hydra 3 | from omegaconf import DictConfig 4 | from hydra.utils import instantiate 5 | from tqdm import tqdm 6 | 7 | logger = logging.getLogger(__name__) 8 | 9 | 10 | @hydra.main(config_path="../configs", config_name="motion_stats", version_base="1.3") 11 | def motion_stats(cfg: DictConfig): 12 | logger.info("Computing motion stats") 13 | import src.prepare # noqa 14 | 15 | train_dataset = instantiate(cfg.data, split="train") 16 | import torch 17 | 18 | feats = torch.cat([x["motion_x_dict"]["x"] for x in tqdm(train_dataset)]) 19 | mean = feats.mean(0) 20 | std = feats.std(0) 21 | 22 | normalizer = train_dataset.motion_loader.normalizer 23 | logger.info(f"Saving them in {normalizer.base_dir}") 24 | normalizer.save(mean, std) 25 | 26 | 27 | if __name__ == "__main__": 28 | motion_stats() 29 | -------------------------------------------------------------------------------- /configs/model/temos.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.model.TEMOS 2 | 3 | motion_encoder: 4 | _target_: src.model.ACTORStyleEncoder 5 | nfeats: ${data.motion_loader.nfeats} 6 | vae: true 7 | latent_dim: 256 8 | ff_size: 1024 9 | num_layers: 6 10 | num_heads: 4 11 | dropout: 0.1 12 | activation: gelu 13 | 14 | text_encoder: 15 | _target_: src.model.ACTORStyleEncoder 16 | nfeats: 768 17 | vae: true 18 | latent_dim: 256 19 | ff_size: 1024 20 | num_layers: 6 21 | num_heads: 4 22 | dropout: 0.1 23 | activation: gelu 24 | 25 | motion_decoder: 26 | _target_: src.model.ACTORStyleDecoder 27 | nfeats: ${data.motion_loader.nfeats} 28 | latent_dim: 256 29 | ff_size: 1024 30 | num_layers: 6 31 | num_heads: 4 32 | dropout: 0.1 33 | activation: gelu 34 | 35 | vae: true 36 | 37 | lmd: 38 | recons: 1.0 39 | latent: 1.0e-5 40 | kl: 1.0e-5 41 | 42 | lr: 1e-4 43 | -------------------------------------------------------------------------------- /configs/hydra/job_logging/tqdm.yaml: -------------------------------------------------------------------------------- 1 | version: 1 2 | 3 | formatters: 4 | simple: 5 | format: '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s' 6 | datefmt: '%d/%m/%y %H:%M:%S' 7 | 8 | colorlog: 9 | (): colorlog.ColoredFormatter 10 | format: '[%(white)s%(asctime)s%(reset)s] %(log_color)s%(levelname)s%(reset)s %(message)s' 11 | datefmt: '%d/%m/%y %H:%M:%S' 12 | 13 | log_colors: 14 | DEBUG: purple 15 | INFO: blue 16 | WARNING: yellow 17 | ERROR: red 18 | CRITICAL: red 19 | 20 | handlers: 21 | console: 22 | class: src.logging.TqdmLoggingHandler 23 | formatter: colorlog 24 | file_out: 25 | class: logging.FileHandler 26 | formatter: simple 27 | filename: ${run_dir}/${hydra.job.name}.out 28 | 29 | root: 30 | level: ${logger_level} 31 | handlers: [console, file_out] 32 | 33 | disable_existing_loggers: false 34 | -------------------------------------------------------------------------------- /configs/trainer.yaml: -------------------------------------------------------------------------------- 1 | trainer: 2 | _target_: pytorch_lightning.Trainer 3 | 4 | max_epochs: 500 5 | log_every_n_steps: 50 6 | num_sanity_val_steps: 0 7 | check_val_every_n_epoch: 1 8 | accelerator: gpu 9 | devices: 1 10 | 11 | callbacks: 12 
| - _target_: pytorch_lightning.callbacks.ModelCheckpoint 13 | filename: latest-{epoch} 14 | every_n_epochs: 1 15 | save_top_k: 1 16 | save_last: true 17 | - _target_: pytorch_lightning.callbacks.ModelCheckpoint 18 | filename: latest-{epoch} 19 | monitor: step 20 | mode: max 21 | every_n_epochs: 100 22 | save_top_k: -1 23 | save_last: false 24 | - _target_: src.callback.progress.ProgressLogger 25 | precision: 3 26 | - _target_: src.callback.tqdmbar.TQDMProgressBar 27 | 28 | logger: 29 | _target_: src.logger.csv.CSVLogger 30 | save_dir: ${run_dir} 31 | name: logs 32 | -------------------------------------------------------------------------------- /prepare/tools.py: -------------------------------------------------------------------------------- 1 | import os 2 | from glob import glob 3 | from tqdm import tqdm 4 | 5 | 6 | def loop_amass( 7 | base_folder, 8 | new_base_folder, 9 | ext=".npz", 10 | newext=".npz", 11 | force_redo=False, 12 | exclude=None, 13 | ): 14 | match_str = f"**/*{ext}" 15 | 16 | for motion_file in tqdm(glob(match_str, root_dir=base_folder, recursive=True)): 17 | if exclude and exclude in motion_file: 18 | continue 19 | 20 | motion_path = os.path.join(base_folder, motion_file) 21 | 22 | if motion_path.endswith("shape.npz"): 23 | continue 24 | 25 | new_motion_path = os.path.join( 26 | new_base_folder, motion_file.replace(ext, newext) 27 | ) 28 | if not force_redo and os.path.exists(new_motion_path): 29 | continue 30 | 31 | new_folder = os.path.split(new_motion_path)[0] 32 | os.makedirs(new_folder, exist_ok=True) 33 | 34 | yield motion_path, new_motion_path 35 | -------------------------------------------------------------------------------- /demo/load.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import orjson 4 | import codecs as cs 5 | import torch 6 | 7 | 8 | def load_json(json_path): 9 | with open(json_path, "rb") as ff: 10 | return orjson.loads(ff.read()) 11 | 12 | 13 | def load_unit_embeddings(run_dir, dataset, device="cpu"): 14 | save_dir = os.path.join(run_dir, "latents") 15 | unit_emb_path = os.path.join(save_dir, f"{dataset}_all_unit.npy") 16 | motion_embs = torch.from_numpy(np.load(unit_emb_path)).to(device) 17 | 18 | # Loading the correspondance 19 | keyids_index = load_json(os.path.join(save_dir, f"{dataset}_keyids_index_all.json")) 20 | index_keyids = load_json(os.path.join(save_dir, f"{dataset}_index_keyids_all.json")) 21 | 22 | return motion_embs, keyids_index, index_keyids 23 | 24 | 25 | def load_split(path, split): 26 | split_file = os.path.join(path, "splits", split + ".txt") 27 | id_list = [] 28 | with cs.open(split_file, "r") as f: 29 | for line in f.readlines(): 30 | id_list.append(line.strip()) 31 | return id_list 32 | 33 | 34 | def load_splits(dataset, splits=["test", "all"]): 35 | path = f"datasets/annotations/{dataset}" 36 | return {split: load_split(path, split) for split in splits} 37 | -------------------------------------------------------------------------------- /datasets/annotations/kitml/splits/nsim_test.txt: -------------------------------------------------------------------------------- 1 | 00355 2 | 01496 3 | 01344 4 | 03107 5 | 02245 6 | 01109 7 | 01349 8 | 03128 9 | 02906 10 | 03861 11 | 03670 12 | 00989 13 | 02070 14 | 02971 15 | 01727 16 | 03045 17 | 02510 18 | 01391 19 | 01532 20 | 03771 21 | 03036 22 | 01472 23 | 03691 24 | 01463 25 | 00978 26 | 01639 27 | 02407 28 | 03917 29 | 01450 30 | 01854 31 | 00594 32 | 00736 33 | 03087 34 | 01440 35 | 02021 36 
| 01444 37 | 03683 38 | 01367 39 | 00390 40 | 03577 41 | 01485 42 | 02148 43 | 03190 44 | 01223 45 | 03215 46 | 03098 47 | 02139 48 | 02435 49 | 03532 50 | M00355 51 | M01496 52 | M02751 53 | M01344 54 | M03107 55 | M02245 56 | M00452 57 | M02556 58 | M01109 59 | M01027 60 | M01349 61 | M03128 62 | M02906 63 | M03861 64 | M03670 65 | M00989 66 | M02070 67 | M02971 68 | M01727 69 | M03045 70 | M02510 71 | M01391 72 | M01532 73 | M03036 74 | M01472 75 | M03691 76 | M01463 77 | M02407 78 | M03917 79 | M01450 80 | M01854 81 | M00594 82 | M00736 83 | M03087 84 | M01440 85 | M02021 86 | M01444 87 | M03683 88 | M00568 89 | M01298 90 | M01491 91 | M01367 92 | M03577 93 | M01485 94 | M02148 95 | M03190 96 | M03215 97 | M03098 98 | M02139 99 | M02435 100 | M03532 101 | M00669 102 | -------------------------------------------------------------------------------- /datasets/annotations/humanml3d/splits/nsim_test.txt: -------------------------------------------------------------------------------- 1 | 007742 2 | 005935 3 | 008597 4 | 010546 5 | 005697 6 | 005668 7 | 000565 8 | 000119 9 | 008877 10 | 006058 11 | 009566 12 | 005180 13 | 012578 14 | 011493 15 | 005946 16 | 006630 17 | 008900 18 | 004601 19 | 002459 20 | 006186 21 | 000552 22 | 005674 23 | 004545 24 | 001052 25 | 002635 26 | 005672 27 | 011004 28 | 013440 29 | 009455 30 | 003463 31 | 000824 32 | 006549 33 | 007655 34 | 006762 35 | 012222 36 | 012655 37 | 012956 38 | 004973 39 | 013403 40 | 008730 41 | 003439 42 | 008824 43 | 008340 44 | 010823 45 | 007806 46 | 013898 47 | 004996 48 | 010384 49 | 004344 50 | 005048 51 | 001152 52 | 012568 53 | M008664 54 | M007889 55 | M000389 56 | M011343 57 | M012558 58 | M010392 59 | M014283 60 | M001538 61 | M011643 62 | M003677 63 | M011972 64 | M009880 65 | M013023 66 | M012399 67 | M002761 68 | M014109 69 | M004319 70 | M001648 71 | M013778 72 | M008383 73 | M000178 74 | M009148 75 | M006433 76 | M011569 77 | M001577 78 | M008275 79 | M012813 80 | M012084 81 | M009123 82 | M000179 83 | M012639 84 | M010671 85 | M008583 86 | M000972 87 | M008349 88 | M002824 89 | M003301 90 | M008490 91 | M003902 92 | M002252 93 | M008668 94 | M000903 95 | M003689 96 | M003373 97 | M010964 98 | M001193 99 | M006533 100 | M014384 101 | -------------------------------------------------------------------------------- /src/callback/tqdmbar.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import pytorch_lightning as pl 3 | from pytorch_lightning.callbacks import TQDMProgressBar as OriginalTQDMProgressBar 4 | 5 | 6 | def customize_bar(bar): 7 | if not sys.stdout.isatty(): 8 | bar.disable = True 9 | bar.leave = True # remove the bar after completion 10 | return bar 11 | 12 | 13 | class TQDMProgressBar(OriginalTQDMProgressBar): 14 | # remove the annoying v_num in the bar 15 | def get_metrics(self, trainer: pl.Trainer, pl_module: pl.LightningModule): 16 | items_dict = super().get_metrics(trainer, pl_module).copy() 17 | 18 | if "v_num" in items_dict: 19 | items_dict.pop("v_num") 20 | return items_dict 21 | 22 | def init_sanity_tqdm(self): 23 | bar = super().init_sanity_tqdm() 24 | return customize_bar(bar) 25 | 26 | def init_train_tqdm(self): 27 | bar = super().init_train_tqdm() 28 | return customize_bar(bar) 29 | 30 | def init_validation_tqdm(self): 31 | bar = super().init_validation_tqdm() 32 | bar.disable = True 33 | return bar 34 | 35 | def init_predict_tqdm(self): 36 | bar = super().init_predict_tqdm() 37 | return customize_bar(bar) 38 | 39 | def init_test_tqdm(self): 
40 | bar = super().init_test_tqdm() 41 | return customize_bar(bar) 42 | -------------------------------------------------------------------------------- /encode_text.py: -------------------------------------------------------------------------------- 1 | import os 2 | from omegaconf import DictConfig 3 | import logging 4 | import hydra 5 | 6 | logger = logging.getLogger(__name__) 7 | 8 | 9 | @hydra.main(version_base=None, config_path="configs", config_name="encode_text") 10 | def encode_text(cfg: DictConfig) -> None: 11 | device = cfg.device 12 | run_dir = cfg.run_dir 13 | ckpt_name = cfg.ckpt_name 14 | text = cfg.text 15 | 16 | import src.prepare # noqa 17 | import torch 18 | import numpy as np 19 | from src.config import read_config 20 | from src.load import load_model_from_cfg 21 | from hydra.utils import instantiate 22 | from pytorch_lightning import seed_everything 23 | from src.data.collate import collate_x_dict 24 | 25 | cfg = read_config(run_dir) 26 | 27 | logger.info("Loading the text model") 28 | text_model = instantiate(cfg.data.text_to_token_emb, device=device) 29 | 30 | logger.info("Loading the model") 31 | model = load_model_from_cfg(cfg, ckpt_name, eval_mode=True, device=device) 32 | 33 | seed_everything(cfg.seed) 34 | with torch.inference_mode(): 35 | text_x_dict = collate_x_dict(text_model([text])) 36 | latent = model.encode(text_x_dict, sample_mean=True)[0] 37 | latent = latent.cpu().numpy() 38 | 39 | fname = text.lower().replace(" ", "_") + ".npy" 40 | 41 | output_folder = os.path.join(run_dir, "encoded") 42 | os.makedirs(output_folder, exist_ok=True) 43 | path = os.path.join(output_folder, fname) 44 | 45 | np.save(path, latent) 46 | logger.info(f"Encoding done, latent saved in:\n{path}") 47 | 48 | 49 | if __name__ == "__main__": 50 | encode_text() 51 | -------------------------------------------------------------------------------- /prepare/text_embeddings.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import logging 3 | import hydra 4 | from omegaconf import DictConfig 5 | 6 | logger = logging.getLogger(__name__) 7 | 8 | 9 | @hydra.main(config_path="../configs", config_name="text_embeddings", version_base="1.3") 10 | def text_embeddings(cfg: DictConfig): 11 | device = cfg.device 12 | 13 | from src.data.text import save_token_embeddings, save_sent_embeddings 14 | 15 | annotations_filename = cfg.annotations_filename 16 | 17 | # Compute token embeddings 18 | modelname = cfg.data.text_to_token_emb.modelname 19 | modelpath = cfg.data.text_to_token_emb.modelpath 20 | logger.info(f"Compute token embeddings for {modelname}") 21 | path = cfg.data.text_to_token_emb.path 22 | output_folder_name = cfg.output_folder_name_token 23 | save_token_embeddings(path, annotations_filename=annotations_filename, 24 | output_folder_name=output_folder_name, 25 | modelname=modelname, modelpath=modelpath, 26 | device=device) 27 | 28 | # Compute sent embeddings 29 | modelname = cfg.data.text_to_sent_emb.modelname 30 | modelpath = cfg.data.text_to_sent_emb.modelpath 31 | logger.info(f"Compute sentence embeddings for {modelname}") 32 | path = cfg.data.text_to_sent_emb.path 33 | output_folder_name = cfg.output_folder_name_sent 34 | save_sent_embeddings(path, annotations_filename=annotations_filename, 35 | output_folder_name=output_folder_name, 36 | modelname=modelname, modelpath=modelpath, 37 | device=device) 38 | 39 | 40 | if __name__ == "__main__": 41 | text_embeddings() 42 | 
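# Usage sketch (assumption: run from the repository root, with annotations under
# datasets/annotations/<dataset>): this script precomputes the token and sentence embeddings
# configured in configs/text_embeddings.yaml, e.g.
#   python prepare/text_embeddings.py data=humanml3d device=cuda
# The augmented variant, which reads annotations_all.json and writes the *_all embedding folders,
# can be selected with Hydra's config-name override:
#   python prepare/text_embeddings.py --config-name=text_embeddings_with_augmentation data=humanml3d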
-------------------------------------------------------------------------------- /encode_motion.py: -------------------------------------------------------------------------------- 1 | import os 2 | from omegaconf import DictConfig 3 | import logging 4 | import hydra 5 | 6 | logger = logging.getLogger(__name__) 7 | 8 | 9 | @hydra.main(version_base=None, config_path="configs", config_name="encode_motion") 10 | def encode_motion(cfg: DictConfig) -> None: 11 | device = cfg.device 12 | run_dir = cfg.run_dir 13 | ckpt_name = cfg.ckpt_name 14 | npy_path = cfg.npy 15 | 16 | import src.prepare # noqa 17 | import torch 18 | import numpy as np 19 | from src.config import read_config 20 | from src.load import load_model_from_cfg 21 | from hydra.utils import instantiate 22 | from pytorch_lightning import seed_everything 23 | from src.data.collate import collate_x_dict 24 | 25 | cfg = read_config(run_dir) 26 | 27 | logger.info("Loading the model") 28 | model = load_model_from_cfg(cfg, ckpt_name, eval_mode=True, device=device) 29 | normalizer = instantiate(cfg.data.motion_loader.normalizer) 30 | 31 | motion = torch.from_numpy(np.load(npy_path)).to(torch.float) 32 | motion = normalizer(motion) 33 | motion = motion.to(device) 34 | 35 | motion_x_dict = {"x": motion, "length": len(motion)} 36 | 37 | seed_everything(cfg.seed) 38 | with torch.inference_mode(): 39 | motion_x_dict = collate_x_dict([motion_x_dict]) 40 | latent = model.encode(motion_x_dict, sample_mean=True)[0] 41 | latent = latent.cpu().numpy() 42 | 43 | fname = os.path.split(npy_path)[1] 44 | output_folder = os.path.join(run_dir, "encoded") 45 | os.makedirs(output_folder, exist_ok=True) 46 | path = os.path.join(output_folder, fname) 47 | 48 | np.save(path, latent) 49 | logger.info(f"Encoding done, latent saved in:\n{path}") 50 | 51 | 52 | if __name__ == "__main__": 53 | encode_motion() 54 | -------------------------------------------------------------------------------- /text_motion_sim.py: -------------------------------------------------------------------------------- 1 | from omegaconf import DictConfig 2 | import logging 3 | import hydra 4 | 5 | logger = logging.getLogger(__name__) 6 | 7 | 8 | @hydra.main(version_base=None, config_path="configs", config_name="text_motion_sim") 9 | def text_motion_sim(cfg: DictConfig) -> None: 10 | device = cfg.device 11 | run_dir = cfg.run_dir 12 | ckpt_name = cfg.ckpt_name 13 | npy_path = cfg.npy 14 | text = cfg.text 15 | 16 | import src.prepare # noqa 17 | import torch 18 | import numpy as np 19 | from src.config import read_config 20 | from src.load import load_model_from_cfg 21 | from hydra.utils import instantiate 22 | from pytorch_lightning import seed_everything 23 | from src.data.collate import collate_x_dict 24 | from src.model.tmr import get_score_matrix 25 | 26 | cfg = read_config(run_dir) 27 | 28 | seed_everything(cfg.seed) 29 | 30 | logger.info("Loading the text model") 31 | text_model = instantiate(cfg.data.text_to_token_emb, device=device) 32 | 33 | logger.info("Loading the model") 34 | model = load_model_from_cfg(cfg, ckpt_name, eval_mode=True, device=device) 35 | 36 | normalizer = instantiate(cfg.data.motion_loader.normalizer) 37 | 38 | motion = torch.from_numpy(np.load(npy_path)).to(torch.float) 39 | motion = normalizer(motion) 40 | motion = motion.to(device) 41 | 42 | motion_x_dict = {"x": motion, "length": len(motion)} 43 | 44 | with torch.inference_mode(): 45 | # motion -> latent 46 | motion_x_dict = collate_x_dict([motion_x_dict]) 47 | lat_m = model.encode(motion_x_dict, 
sample_mean=True)[0] 48 | 49 | # text -> latent 50 | text_x_dict = collate_x_dict(text_model([text])) 51 | lat_t = model.encode(text_x_dict, sample_mean=True)[0] 52 | 53 | score = get_score_matrix(lat_t, lat_m).cpu() 54 | 55 | score_str = f"{score:.3}" 56 | logger.info( 57 | f"The similariy score s (0 <= s <= 1) between the text and the motion is: {score_str}" 58 | ) 59 | 60 | 61 | if __name__ == "__main__": 62 | text_motion_sim() 63 | -------------------------------------------------------------------------------- /src/guofeats/paramUtil.py: -------------------------------------------------------------------------------- 1 | # Taken from 2 | # https://github.com/EricGuo5513/HumanML3D/blob/main/paramUtil.py 3 | 4 | import numpy as np 5 | 6 | # Define a kinematic tree for the skeletal struture 7 | kit_kinematic_chain = [ 8 | [0, 11, 12, 13, 14, 15], 9 | [0, 16, 17, 18, 19, 20], 10 | [0, 1, 2, 3, 4], 11 | [3, 5, 6, 7], 12 | [3, 8, 9, 10], 13 | ] 14 | 15 | kit_raw_offsets = np.array( 16 | [ 17 | [0, 0, 0], 18 | [0, 1, 0], 19 | [0, 1, 0], 20 | [0, 1, 0], 21 | [0, 1, 0], 22 | [1, 0, 0], 23 | [0, -1, 0], 24 | [0, -1, 0], 25 | [-1, 0, 0], 26 | [0, -1, 0], 27 | [0, -1, 0], 28 | [1, 0, 0], 29 | [0, -1, 0], 30 | [0, -1, 0], 31 | [0, 0, 1], 32 | [0, 0, 1], 33 | [-1, 0, 0], 34 | [0, -1, 0], 35 | [0, -1, 0], 36 | [0, 0, 1], 37 | [0, 0, 1], 38 | ] 39 | ) 40 | 41 | t2m_raw_offsets = np.array( 42 | [ 43 | [0, 0, 0], 44 | [1, 0, 0], 45 | [-1, 0, 0], 46 | [0, 1, 0], 47 | [0, -1, 0], 48 | [0, -1, 0], 49 | [0, 1, 0], 50 | [0, -1, 0], 51 | [0, -1, 0], 52 | [0, 1, 0], 53 | [0, 0, 1], 54 | [0, 0, 1], 55 | [0, 1, 0], 56 | [1, 0, 0], 57 | [-1, 0, 0], 58 | [0, 0, 1], 59 | [0, -1, 0], 60 | [0, -1, 0], 61 | [0, -1, 0], 62 | [0, -1, 0], 63 | [0, -1, 0], 64 | [0, -1, 0], 65 | ] 66 | ) 67 | 68 | t2m_kinematic_chain = [ 69 | [0, 2, 5, 8, 11], 70 | [0, 1, 4, 7, 10], 71 | [0, 3, 6, 9, 12, 15], 72 | [9, 14, 17, 19, 21], 73 | [9, 13, 16, 18, 20], 74 | ] 75 | t2m_left_hand_chain = [ 76 | [20, 22, 23, 24], 77 | [20, 34, 35, 36], 78 | [20, 25, 26, 27], 79 | [20, 31, 32, 33], 80 | [20, 28, 29, 30], 81 | ] 82 | t2m_right_hand_chain = [ 83 | [21, 43, 44, 45], 84 | [21, 46, 47, 48], 85 | [21, 40, 41, 42], 86 | [21, 37, 38, 39], 87 | [21, 49, 50, 51], 88 | ] 89 | 90 | 91 | kit_tgt_skel_id = "03950" 92 | t2m_tgt_skel_id = "000021" 93 | -------------------------------------------------------------------------------- /src/data/motion.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | 5 | 6 | class AMASSMotionLoader: 7 | def __init__( 8 | self, base_dir, fps, normalizer=None, disable: bool = False, nfeats=None 9 | ): 10 | self.fps = fps 11 | self.base_dir = base_dir 12 | self.motions = {} 13 | self.normalizer = normalizer 14 | self.disable = disable 15 | self.nfeats = nfeats 16 | 17 | def __call__(self, path, start, end): 18 | if self.disable: 19 | return {"x": path, "length": int(self.fps * (end - start))} 20 | 21 | begin = int(start * self.fps) 22 | end = int(end * self.fps) 23 | if path not in self.motions: 24 | motion_path = os.path.join(self.base_dir, path + ".npy") 25 | motion = np.load(motion_path) 26 | motion = torch.from_numpy(motion).to(torch.float) 27 | if self.normalizer is not None: 28 | motion = self.normalizer(motion) 29 | self.motions[path] = motion 30 | 31 | motion = self.motions[path][begin:end] 32 | x_dict = {"x": motion, "length": len(motion)} 33 | return x_dict 34 | 35 | 36 | class Normalizer: 37 | def __init__(self, 
base_dir: str, eps: float = 1e-12, disable: bool = False): 38 | self.base_dir = base_dir 39 | self.mean_path = os.path.join(base_dir, "mean.pt") 40 | self.std_path = os.path.join(base_dir, "std.pt") 41 | self.eps = eps 42 | 43 | self.disable = disable 44 | if not disable: 45 | self.load() 46 | 47 | def load(self): 48 | self.mean = torch.load(self.mean_path) 49 | self.std = torch.load(self.std_path) 50 | 51 | def save(self, mean, std): 52 | os.makedirs(self.base_dir, exist_ok=True) 53 | torch.save(mean, self.mean_path) 54 | torch.save(std, self.std_path) 55 | 56 | def __call__(self, x): 57 | if self.disable: 58 | return x 59 | x = (x - self.mean) / (self.std + self.eps) 60 | return x 61 | 62 | def inverse(self, x): 63 | if self.disable: 64 | return x 65 | x = x * (self.std + self.eps) + self.mean 66 | return x 67 | -------------------------------------------------------------------------------- /src/model/tmr_text_averaging.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Optional, Tuple 2 | from torch import Tensor 3 | 4 | import torch 5 | from .tmr import TMR 6 | 7 | class TMRTextAveraging(TMR): 8 | """Compatible with the AugmentedTextMotionDataset dataset and the collate_text_motion_multiple_texts collate function.""" 9 | 10 | # Forward: X => motions 11 | def forward( 12 | self, 13 | inputs, 14 | text_slices: Optional[List[Tuple[int, int]]] = None, 15 | lengths: Optional[List[int]] = None, 16 | mask: Optional[Tensor] = None, 17 | sample_mean: Optional[bool] = None, 18 | fact: Optional[float] = None, 19 | return_all: bool = False, 20 | ) -> List[Tensor]: 21 | 22 | # Encoding the inputs and sampling if needed 23 | latent_vectors, distributions = self.encode( 24 | inputs, sample_mean=sample_mean, fact=fact, return_distribution=True 25 | ) 26 | 27 | # Average the latent vectors over the different text embeddings of each sample; text_slices holds the (start, end) indices of each sample's texts in the batch.
28 | if text_slices is not None: 29 | latent_vectors = [torch.mean(latent_vectors[i:j], dim=0) for i, j in text_slices] 30 | latent_vectors = torch.stack(latent_vectors, dim=0) 31 | distributions = list(distributions) 32 | distributions[0] = torch.stack([torch.mean(distributions[0][i:j], dim=0) for i, j in text_slices], dim=0) 33 | distributions[1] = torch.stack([torch.mean(distributions[1][i:j], dim=0) for i, j in text_slices], dim=0) 34 | distributions = tuple(distributions) 35 | 36 | # Decoding the latent vector: generating motions 37 | motions = self.decode(latent_vectors, lengths, mask) 38 | 39 | if return_all: 40 | return {"motions": motions, 41 | "latent_vectors": latent_vectors, 42 | "distributions": distributions} 43 | 44 | return {"motions": motions} 45 | 46 | def call_models(self, batch): 47 | text_x_dict = batch["text_x_dict"] 48 | motion_x_dict = batch["motion_x_dict"] 49 | text_slices = batch["text_slices"] 50 | 51 | mask = motion_x_dict["mask"] 52 | 53 | # text -> motion 54 | t_results = self(text_x_dict, mask=mask, return_all=True, text_slices=text_slices) 55 | 56 | # motion -> motion 57 | m_results = self(motion_x_dict, mask=mask, return_all=True) 58 | 59 | return t_results, m_results 60 | -------------------------------------------------------------------------------- /src/joints.py: -------------------------------------------------------------------------------- 1 | JOINT_NAMES = { 2 | "smpljoints": [ 3 | "pelvis", 4 | "left_hip", 5 | "right_hip", 6 | "spine1", 7 | "left_knee", 8 | "right_knee", 9 | "spine2", 10 | "left_ankle", 11 | "right_ankle", 12 | "spine3", 13 | "left_foot", 14 | "right_foot", 15 | "neck", 16 | "left_collar", 17 | "right_collar", 18 | "head", 19 | "left_shoulder", 20 | "right_shoulder", 21 | "left_elbow", 22 | "right_elbow", 23 | "left_wrist", 24 | "right_wrist", 25 | "left_hand", 26 | "right_hand", 27 | ], 28 | "guoh3djoints": [ 29 | "pelvis", 30 | "left_hip", 31 | "right_hip", 32 | "spine1", 33 | "left_knee", 34 | "right_knee", 35 | "spine2", 36 | "left_ankle", 37 | "right_ankle", 38 | "spine3", 39 | "left_foot", 40 | "right_foot", 41 | "neck", 42 | "left_collar", 43 | "right_collar", 44 | "head", 45 | "left_shoulder", 46 | "right_shoulder", 47 | "left_elbow", 48 | "right_elbow", 49 | "left_wrist", 50 | "right_wrist", 51 | ], 52 | } 53 | 54 | INFOS = { 55 | "smpljoints": { 56 | "LM": JOINT_NAMES["smpljoints"].index("left_ankle"), 57 | "RM": JOINT_NAMES["smpljoints"].index("right_ankle"), 58 | "LF": JOINT_NAMES["smpljoints"].index("left_foot"), 59 | "RF": JOINT_NAMES["smpljoints"].index("right_foot"), 60 | "LS": JOINT_NAMES["smpljoints"].index("left_shoulder"), 61 | "RS": JOINT_NAMES["smpljoints"].index("right_shoulder"), 62 | "LH": JOINT_NAMES["smpljoints"].index("left_hip"), 63 | "RH": JOINT_NAMES["smpljoints"].index("right_hip"), 64 | "njoints": len(JOINT_NAMES["smpljoints"]), 65 | }, 66 | "guoh3djoints": { 67 | "LM": JOINT_NAMES["guoh3djoints"].index("left_ankle"), 68 | "RM": JOINT_NAMES["guoh3djoints"].index("right_ankle"), 69 | "LF": JOINT_NAMES["guoh3djoints"].index("left_foot"), 70 | "RF": JOINT_NAMES["guoh3djoints"].index("right_foot"), 71 | "LS": JOINT_NAMES["guoh3djoints"].index("left_shoulder"), 72 | "RS": JOINT_NAMES["guoh3djoints"].index("right_shoulder"), 73 | "LH": JOINT_NAMES["guoh3djoints"].index("left_hip"), 74 | "RH": JOINT_NAMES["guoh3djoints"].index("right_hip"), 75 | "njoints": len(JOINT_NAMES["guoh3djoints"]), 76 | }, 77 | } 78 | -------------------------------------------------------------------------------- /train.py: 
-------------------------------------------------------------------------------- 1 | import hydra 2 | from hydra.utils import instantiate 3 | import logging 4 | from omegaconf import DictConfig 5 | import os 6 | import pytorch_lightning as pl 7 | 8 | from src.config import read_config, save_config 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | @hydra.main(config_path="configs", config_name="train", version_base="1.3") 14 | def train(cfg: DictConfig): 15 | # Resuming if needed 16 | ckpt = None 17 | 18 | if cfg.ckpt is not None: 19 | ckpt = cfg.ckpt 20 | 21 | if cfg.resume_dir is not None: 22 | assert cfg.ckpt is not None 23 | max_epochs = cfg.trainer.max_epochs 24 | ckpt = os.path.join(cfg.resume_dir, 'logs', 'checkpoints', f'{cfg.ckpt}.ckpt') 25 | cfg = read_config(cfg.resume_dir) 26 | cfg.trainer.max_epochs = max_epochs 27 | logger.info("Resuming training") 28 | logger.info(f"The config is loaded from: \n{cfg.run_dir}") 29 | else: 30 | if "ckpt_path" in cfg and cfg.ckpt_path is not None: 31 | ckpt = cfg.ckpt_path 32 | config_path = save_config(cfg) 33 | logger.info("Training script") 34 | logger.info(f"The config can be found here: \n{config_path}") 35 | 36 | pl.seed_everything(cfg.seed) 37 | 38 | text_to_token_emb = instantiate(cfg.data.text_to_token_emb) 39 | text_to_sent_emb = instantiate(cfg.data.text_to_sent_emb) 40 | 41 | logger.info("Loading the dataloaders") 42 | train_dataset = instantiate(cfg.data, split="train", 43 | text_to_token_emb=text_to_token_emb, 44 | text_to_sent_emb=text_to_sent_emb) 45 | 46 | if "data_val" not in cfg: 47 | data_val = cfg.data 48 | else: 49 | data_val = cfg.data_val 50 | text_to_token_emb = instantiate(cfg.data_val.text_to_token_emb) 51 | text_to_sent_emb = instantiate(cfg.data_val.text_to_sent_emb) 52 | 53 | val_dataset = instantiate(data_val, split="val", 54 | text_to_token_emb=text_to_token_emb, 55 | text_to_sent_emb=text_to_sent_emb) 56 | 57 | train_dataloader = instantiate( 58 | cfg.dataloader, 59 | dataset=train_dataset, 60 | collate_fn=train_dataset.collate_fn, 61 | shuffle=True, 62 | ) 63 | 64 | val_dataloader = instantiate( 65 | cfg.dataloader, 66 | dataset=val_dataset, 67 | collate_fn=val_dataset.collate_fn, 68 | shuffle=False, 69 | ) 70 | 71 | logger.info("Loading the model") 72 | model = instantiate(cfg.model) 73 | 74 | logger.info(f"Using checkpoint: {ckpt}") 75 | logger.info("Training") 76 | trainer = instantiate(cfg.trainer) 77 | trainer.fit(model, train_dataloader, val_dataloader, ckpt_path=ckpt) 78 | 79 | 80 | if __name__ == "__main__": 81 | train() 82 | -------------------------------------------------------------------------------- /src/callback/progress.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from pytorch_lightning import LightningModule, Trainer 4 | from pytorch_lightning.callbacks import Callback 5 | 6 | 7 | logger = logging.getLogger(__name__) 8 | 9 | 10 | class ProgressLogger(Callback): 11 | def __init__(self, precision: int = 2): 12 | self.precision = precision 13 | 14 | def on_train_start(self, trainer: Trainer, pl_module: LightningModule, **kwargs): 15 | logger.info("Training started") 16 | 17 | def on_train_end(self, trainer: Trainer, pl_module: LightningModule, **kwargs): 18 | logger.info("Training done") 19 | 20 | def on_validation_epoch_end( 21 | self, trainer: Trainer, pl_module: LightningModule, **kwargs 22 | ): 23 | if trainer.sanity_checking: 24 | logger.info("Sanity checking ok.") 25 | 26 | def on_train_epoch_end( 27 | self, 
trainer: Trainer, pl_module: LightningModule, **kwargs 28 | ): 29 | metric_format = f"{{:.{self.precision}e}}" 30 | line = f"Epoch {trainer.current_epoch}" 31 | metrics_str = [] 32 | 33 | losses_dict = trainer.callback_metrics 34 | 35 | def is_contrastive_metrics(x): 36 | return "t2m" in x or "m2t" in x 37 | 38 | losses_to_print = [ 39 | x 40 | for x in losses_dict.keys() 41 | for y in [x.split("_")] 42 | if len(y) == 3 43 | and y[2] == "epoch" 44 | and ( 45 | y[1] in pl_module.lmd or y[1] == "loss" or is_contrastive_metrics(y[1]) 46 | ) 47 | ] 48 | 49 | # Natual order for contrastive 50 | letters = "0123456789" 51 | mapping = str.maketrans(letters, letters[::-1]) 52 | 53 | def sort_losses(x): 54 | split, name, epoch_step = x.split("_") 55 | if is_contrastive_metrics(x): 56 | # put them at the end 57 | name = "a" + name.translate(mapping) 58 | return (name, split) 59 | 60 | losses_to_print = sorted(losses_to_print, key=sort_losses, reverse=True) 61 | for metric_name in losses_to_print: 62 | split, name, _ = metric_name.split("_") 63 | 64 | metric = losses_dict[metric_name].item() 65 | 66 | if is_contrastive_metrics(metric_name): 67 | if "len" in metric_name: 68 | metric = str(int(metric)) 69 | elif "MedR" in metric_name: 70 | metric = str(int(metric * 100) / 100) + "%" 71 | else: 72 | metric = str(int(metric * 100) / 100) + "%" 73 | else: 74 | metric = metric_format.format(metric) 75 | 76 | if split == "train": 77 | mname = name 78 | else: 79 | mname = f"v_{name}" 80 | 81 | metric = f"{mname} {metric}" 82 | metrics_str.append(metric) 83 | 84 | if len(metrics_str) == 0: 85 | return 86 | 87 | line = line + ": " + " ".join(metrics_str) 88 | logger.info(line) 89 | -------------------------------------------------------------------------------- /src/load.py: -------------------------------------------------------------------------------- 1 | import os 2 | from omegaconf import DictConfig 3 | import logging 4 | import hydra 5 | 6 | from src.config import read_config 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | # split the lightning checkpoint into 12 | # seperate state_dict modules for faster loading 13 | def extract_ckpt(run_dir, ckpt_name="last"): 14 | import torch 15 | 16 | ckpt_path = os.path.join(run_dir, f"logs/checkpoints/{ckpt_name}.ckpt") 17 | 18 | extracted_path = os.path.join(run_dir, f"{ckpt_name}_weights") 19 | os.makedirs(extracted_path, exist_ok=True) 20 | 21 | new_path_template = os.path.join(extracted_path, "{}.pt") 22 | ckpt_dict = torch.load(ckpt_path) 23 | state_dict = ckpt_dict["state_dict"] 24 | module_names = list(set([x.split(".")[0] for x in state_dict.keys()])) 25 | 26 | # should be ['motion_encoder', 'text_encoder', 'motion_decoder'] for example 27 | for module_name in module_names: 28 | path = new_path_template.format(module_name) 29 | sub_state_dict = { 30 | ".".join(x.split(".")[1:]): y.cpu() 31 | for x, y in state_dict.items() 32 | if x.split(".")[0] == module_name 33 | } 34 | torch.save(sub_state_dict, path) 35 | 36 | 37 | def load_model(run_dir, **params): 38 | # Load last config 39 | cfg = read_config(run_dir) 40 | cfg.run_dir = run_dir 41 | return load_model_from_cfg(cfg, **params) 42 | 43 | 44 | def load_model_from_cfg(cfg, ckpt_name="last", device="cpu", eval_mode=True): 45 | import src.prepare # noqa 46 | import torch 47 | 48 | run_dir = cfg.run_dir 49 | model = hydra.utils.instantiate(cfg.model) 50 | 51 | # Loading modules one by one 52 | # motion_encoder / text_encoder / text_decoder 53 | pt_path = os.path.join(run_dir, 
f"{ckpt_name}_weights") 54 | 55 | if not os.path.exists(pt_path): 56 | logger.info("The extracted model is not found. Split into submodules..") 57 | extract_ckpt(run_dir, ckpt_name) 58 | 59 | for fname in os.listdir(pt_path): 60 | module_name, ext = os.path.splitext(fname) 61 | 62 | if ext != ".pt": 63 | continue 64 | 65 | module = getattr(model, module_name, None) 66 | if module is None: 67 | continue 68 | 69 | module_path = os.path.join(pt_path, fname) 70 | state_dict = torch.load(module_path) 71 | module.load_state_dict(state_dict) 72 | logger.info(f" {module_name} loaded") 73 | 74 | logger.info("Loading previous checkpoint done") 75 | model = model.to(device) 76 | logger.info(f"Put the model on {device}") 77 | if eval_mode: 78 | model = model.eval() 79 | logger.info("Put the model in eval mode") 80 | return model 81 | 82 | 83 | @hydra.main(version_base=None, config_path="../configs", config_name="load_model") 84 | def hydra_load_model(cfg: DictConfig) -> None: 85 | run_dir = cfg.run_dir 86 | ckpt_name = cfg.ckpt 87 | device = cfg.device 88 | eval_mode = cfg.eval_mode 89 | return load_model(run_dir, ckpt_name, device, eval_mode) 90 | 91 | 92 | if __name__ == "__main__": 93 | hydra_load_model() 94 | -------------------------------------------------------------------------------- /src/data/collate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from typing import List, Dict, Optional 4 | from torch import Tensor 5 | from torch.utils.data import default_collate 6 | 7 | 8 | def length_to_mask(length, device: torch.device = None) -> Tensor: 9 | if device is None: 10 | device = "cpu" 11 | 12 | if isinstance(length, list): 13 | length = torch.tensor(length, device=device) 14 | 15 | max_len = max(length) 16 | mask = torch.arange(max_len, device=device).expand( 17 | len(length), max_len 18 | ) < length.unsqueeze(1) 19 | return mask 20 | 21 | 22 | def collate_tensor_with_padding(batch: List[Tensor]) -> Tensor: 23 | dims = batch[0].dim() 24 | max_size = [max([b.size(i) for b in batch]) for i in range(dims)] 25 | size = (len(batch),) + tuple(max_size) 26 | canvas = batch[0].new_zeros(size=size) 27 | for i, b in enumerate(batch): 28 | sub_tensor = canvas[i] 29 | for d in range(dims): 30 | sub_tensor = sub_tensor.narrow(d, 0, b.size(d)) 31 | sub_tensor.add_(b) 32 | return canvas 33 | 34 | 35 | def collate_x_dict(lst_x_dict: List, *, device: Optional[str] = None) -> Dict: 36 | x = collate_tensor_with_padding([x_dict["x"] for x_dict in lst_x_dict]) 37 | if device is not None: 38 | x = x.to(device) 39 | length = [x_dict["length"] for x_dict in lst_x_dict] 40 | mask = length_to_mask(length, device=x.device) 41 | batch = {"x": x, "length": length, "mask": mask} 42 | return batch 43 | 44 | 45 | def collate_text_motion(lst_elements: List, *, device: Optional[str] = None) -> Dict: 46 | one_el = lst_elements[0] 47 | keys = one_el.keys() 48 | 49 | x_dict_keys = [key for key in keys if "x_dict" in key] 50 | other_keys = [key for key in keys if "x_dict" not in key] 51 | 52 | batch = {key: default_collate([x[key] for x in lst_elements]) for key in other_keys} 53 | for key, val in batch.items(): 54 | if isinstance(val, torch.Tensor) and device is not None: 55 | batch[key] = val.to(device) 56 | 57 | for key in x_dict_keys: 58 | batch[key] = collate_x_dict([x[key] for x in lst_elements], device=device) 59 | return batch 60 | 61 | 62 | def collate_text_motion_multiple_texts(lst_elements: List, *, device: Optional[str] = None): 63 | other_keys = ['keyid', 
'sent_emb'] 64 | 65 | batch = {key: default_collate([x[key] for x in lst_elements]) for key in other_keys} 66 | batch["text"] = [elt["text"] for elt in lst_elements] 67 | 68 | for key, val in batch.items(): 69 | if isinstance(val, torch.Tensor) and device is not None: 70 | batch[key] = val.to(device) 71 | 72 | batch["motion_x_dict"] = collate_x_dict([x["motion_x_dict"] for x in lst_elements], device=device) 73 | 74 | batch["text_slices"] = [] 75 | current_index = 0 76 | for elt in lst_elements: 77 | batch["text_slices"].append((current_index, current_index + len(elt["text"]))) 78 | current_index += len(elt["text"]) 79 | 80 | texts_concat = [x_dict for x in lst_elements for x_dict in x["text_x_dict"]] 81 | batch["text_x_dict"] = collate_x_dict( 82 | texts_concat, 83 | device=device 84 | ) 85 | return batch 86 | -------------------------------------------------------------------------------- /encode_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from omegaconf import DictConfig 3 | import logging 4 | import hydra 5 | import json 6 | from hydra.core.hydra_config import HydraConfig 7 | 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | def x_dict_to_device(x_dict, device): 13 | import torch 14 | 15 | for key, val in x_dict.items(): 16 | if isinstance(val, torch.Tensor): 17 | x_dict[key] = val.to(device) 18 | return x_dict 19 | 20 | 21 | def write_json(data, path): 22 | with open(path, "w") as ff: 23 | ff.write(json.dumps(data, indent=4)) 24 | 25 | 26 | @hydra.main(version_base=None, config_path="configs", config_name="encode_dataset") 27 | def encode_dataset(cfg: DictConfig) -> None: 28 | device = cfg.device 29 | run_dir = cfg.run_dir 30 | ckpt_name = cfg.ckpt_name 31 | cfg_data = cfg.data 32 | 33 | choices = HydraConfig.get().runtime.choices 34 | data_name = choices.data 35 | 36 | import src.prepare # noqa 37 | import torch 38 | import numpy as np 39 | from src.config import read_config 40 | from src.load import load_model_from_cfg 41 | from hydra.utils import instantiate 42 | from pytorch_lightning import seed_everything 43 | 44 | cfg = read_config(run_dir) 45 | 46 | logger.info("Loading the model") 47 | model = load_model_from_cfg(cfg, ckpt_name, eval_mode=True, device=device) 48 | 49 | save_dir = os.path.join(run_dir, "latents") 50 | os.makedirs(save_dir, exist_ok=True) 51 | 52 | dataset = instantiate(cfg_data, split="all") 53 | dataloader = instantiate( 54 | cfg.dataloader, 55 | dataset=dataset, 56 | collate_fn=dataset.collate_fn, 57 | shuffle=False, 58 | ) 59 | seed_everything(cfg.seed) 60 | 61 | all_latents = [] 62 | all_keyids = [] 63 | 64 | with torch.inference_mode(): 65 | for batch in dataloader: 66 | motion_x_dict = batch["motion_x_dict"] 67 | x_dict_to_device(motion_x_dict, device) 68 | latents = model.encode(motion_x_dict, sample_mean=True) 69 | all_latents.append(latents.cpu().numpy()) 70 | keyids = batch["keyid"] 71 | all_keyids.extend(keyids) 72 | 73 | latents = np.concatenate(all_latents) 74 | path = os.path.join(save_dir, f"{data_name}_all.npy") 75 | logger.info(f"Encoding the latents of all the splits in {path}") 76 | np.save(path, latents) 77 | 78 | path_unit = os.path.join(save_dir, f"{data_name}_all_unit.npy") 79 | logger.info(f"Encoding the unit latents of all the splits in {path_unit}") 80 | 81 | unit_latents = latents / np.linalg.norm(latents, axis=-1)[:, None] 82 | np.save(path_unit, unit_latents) 83 | 84 | # Writing the correspondance 85 | logger.info("Writing the correspondance files") 86 | 
keyids_index_path = os.path.join(save_dir, f"{data_name}_keyids_index_all.json") 87 | index_keyids_path = os.path.join(save_dir, f"{data_name}_index_keyids_all.json") 88 | 89 | keyids_index = {x: i for i, x in enumerate(all_keyids)} 90 | index_keyids = {i: x for i, x in enumerate(all_keyids)} 91 | 92 | write_json(keyids_index, keyids_index_path) 93 | write_json(index_keyids, index_keyids_path) 94 | 95 | 96 | if __name__ == "__main__": 97 | encode_dataset() 98 | -------------------------------------------------------------------------------- /src/model/text_encoder.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch.nn as nn 3 | import torch 4 | from torch import Tensor 5 | from typing import Dict, List 6 | import torch.nn.functional as F 7 | 8 | 9 | class TextToEmb(nn.Module): 10 | def __init__( 11 | self, modelpath: str, mean_pooling: bool = False, device: str = "cpu" 12 | ) -> None: 13 | super().__init__() 14 | 15 | self.device = device 16 | from transformers import AutoTokenizer, AutoModel 17 | from transformers import logging 18 | 19 | logging.set_verbosity_error() 20 | 21 | # Tokenizer 22 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 23 | self.tokenizer = AutoTokenizer.from_pretrained(modelpath) 24 | 25 | # Text model 26 | self.text_model = AutoModel.from_pretrained(modelpath) 27 | # Then configure the model 28 | self.text_encoded_dim = self.text_model.config.hidden_size 29 | 30 | if mean_pooling: 31 | self.forward = self.forward_pooling 32 | 33 | # put it in eval mode by default 34 | self.eval() 35 | 36 | # Freeze the weights just in case 37 | for param in self.parameters(): 38 | param.requires_grad = False 39 | 40 | self.to(device) 41 | 42 | def train(self, mode: bool = True) -> nn.Module: 43 | # override it to be always false 44 | self.training = False 45 | for module in self.children(): 46 | module.train(False) 47 | return self 48 | 49 | @torch.no_grad() 50 | def forward(self, texts: List[str], device=None) -> Dict: 51 | device = device if device is not None else self.device 52 | 53 | squeeze = False 54 | if isinstance(texts, str): 55 | texts = [texts] 56 | squeeze = True 57 | 58 | encoded_inputs = self.tokenizer(texts, return_tensors="pt", padding=True) 59 | output = self.text_model(**encoded_inputs.to(device)) 60 | length = encoded_inputs.attention_mask.to(dtype=bool).sum(1) 61 | 62 | if squeeze: 63 | x_dict = {"x": output.last_hidden_state[0], "length": length[0]} 64 | else: 65 | x_dict = {"x": output.last_hidden_state, "length": length} 66 | return x_dict 67 | 68 | @torch.no_grad() 69 | def forward_pooling(self, texts: List[str], device=None) -> Tensor: 70 | device = device if device is not None else self.device 71 | 72 | squeeze = False 73 | if isinstance(texts, str): 74 | texts = [texts] 75 | squeeze = True 76 | 77 | # From: https://huggingface.co/sentence-transformers/all-mpnet-base-v2 78 | encoded_inputs = self.tokenizer(texts, return_tensors="pt", padding=True) 79 | output = self.text_model(**encoded_inputs.to(device)) 80 | attention_mask = encoded_inputs["attention_mask"] 81 | 82 | # Mean Pooling - Take attention mask into account for correct averaging 83 | token_embeddings = output["last_hidden_state"] 84 | input_mask_expanded = ( 85 | attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() 86 | ) 87 | sentence_embeddings = torch.sum( 88 | token_embeddings * input_mask_expanded, 1 89 | ) / torch.clamp(input_mask_expanded.sum(1), min=1e-9) 90 | # Normalize embeddings 91 | sentence_embeddings 
= F.normalize(sentence_embeddings, p=2, dim=1) 92 | if squeeze: 93 | sentence_embeddings = sentence_embeddings[0] 94 | return sentence_embeddings 95 | -------------------------------------------------------------------------------- /datasets/annotations/kitml/splits/val.txt: -------------------------------------------------------------------------------- 1 | 00006 2 | 00015 3 | 00025 4 | 00052 5 | 00068 6 | 00092 7 | 00093 8 | 00126 9 | 00143 10 | 00170 11 | 00183 12 | 00222 13 | 00244 14 | 00254 15 | 00270 16 | 00271 17 | 00306 18 | 00311 19 | 00331 20 | 00332 21 | 00374 22 | 00427 23 | 00432 24 | 00476 25 | 00491 26 | 00498 27 | 00510 28 | 00539 29 | 00549 30 | 00582 31 | 00589 32 | 00605 33 | 00612 34 | 00646 35 | 00685 36 | 00690 37 | 00691 38 | 00709 39 | 00728 40 | 00733 41 | 00754 42 | 00764 43 | 00834 44 | 00841 45 | 00873 46 | 00897 47 | 00905 48 | 00908 49 | 00909 50 | 00918 51 | 00922 52 | 00943 53 | 00954 54 | 00964 55 | 00967 56 | 00981 57 | 00983 58 | 00988 59 | 01001 60 | 01006 61 | 01016 62 | 01030 63 | 01070 64 | 01073 65 | 01092 66 | 01097 67 | 01118 68 | 01123 69 | 01176 70 | 01177 71 | 01209 72 | 01264 73 | 01282 74 | 01299 75 | 01319 76 | 01369 77 | 01407 78 | 01409 79 | 01412 80 | 01418 81 | 01428 82 | 01429 83 | 01443 84 | 01447 85 | 01479 86 | 01483 87 | 01494 88 | 01503 89 | 01655 90 | 01664 91 | 01700 92 | 01703 93 | 01745 94 | 01762 95 | 01794 96 | 01795 97 | 01809 98 | 01825 99 | 01834 100 | 01836 101 | 01877 102 | 01905 103 | 01931 104 | 01948 105 | 01975 106 | 01992 107 | 01995 108 | 02048 109 | 02110 110 | 02118 111 | 02121 112 | 02171 113 | 02184 114 | 02315 115 | 02322 116 | 02418 117 | 02499 118 | 02593 119 | 02598 120 | 02797 121 | 02904 122 | 02922 123 | 02940 124 | 02947 125 | 03009 126 | 03155 127 | 03227 128 | 03232 129 | 03262 130 | 03276 131 | 03284 132 | 03290 133 | 03365 134 | 03419 135 | 03462 136 | 03480 137 | 03631 138 | 03682 139 | 03708 140 | 03722 141 | 03732 142 | 03813 143 | 03817 144 | 03832 145 | 03877 146 | 03899 147 | M00006 148 | M00015 149 | M00025 150 | M00052 151 | M00068 152 | M00092 153 | M00093 154 | M00126 155 | M00143 156 | M00170 157 | M00183 158 | M00222 159 | M00244 160 | M00254 161 | M00270 162 | M00271 163 | M00306 164 | M00311 165 | M00331 166 | M00332 167 | M00374 168 | M00427 169 | M00432 170 | M00476 171 | M00491 172 | M00498 173 | M00510 174 | M00539 175 | M00549 176 | M00582 177 | M00589 178 | M00605 179 | M00612 180 | M00646 181 | M00685 182 | M00690 183 | M00691 184 | M00709 185 | M00728 186 | M00733 187 | M00754 188 | M00764 189 | M00834 190 | M00841 191 | M00873 192 | M00897 193 | M00905 194 | M00908 195 | M00909 196 | M00918 197 | M00922 198 | M00943 199 | M00954 200 | M00964 201 | M00967 202 | M00981 203 | M00983 204 | M00988 205 | M01001 206 | M01006 207 | M01016 208 | M01030 209 | M01070 210 | M01073 211 | M01092 212 | M01097 213 | M01118 214 | M01123 215 | M01176 216 | M01177 217 | M01209 218 | M01264 219 | M01282 220 | M01299 221 | M01319 222 | M01369 223 | M01407 224 | M01409 225 | M01412 226 | M01418 227 | M01428 228 | M01429 229 | M01443 230 | M01447 231 | M01479 232 | M01483 233 | M01494 234 | M01503 235 | M01655 236 | M01664 237 | M01700 238 | M01703 239 | M01745 240 | M01762 241 | M01794 242 | M01795 243 | M01809 244 | M01825 245 | M01834 246 | M01836 247 | M01877 248 | M01905 249 | M01931 250 | M01948 251 | M01975 252 | M01992 253 | M01995 254 | M02048 255 | M02110 256 | M02118 257 | M02121 258 | M02171 259 | M02184 260 | M02315 261 | M02322 262 | M02418 263 | M02499 264 | M02593 265 | M02598 266 | 
M02797 267 | M02904 268 | M02922 269 | M02940 270 | M02947 271 | M03009 272 | M03155 273 | M03227 274 | M03232 275 | M03262 276 | M03276 277 | M03284 278 | M03290 279 | M03365 280 | M03419 281 | M03462 282 | M03480 283 | M03631 284 | M03682 285 | M03708 286 | M03722 287 | M03732 288 | M03813 289 | M03817 290 | M03832 291 | M03877 292 | M03899 293 | -------------------------------------------------------------------------------- /prepare/compute_guoh3dfeats.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import hydra 4 | from omegaconf import DictConfig 5 | 6 | import numpy as np 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | def extract_h3d(feats): 12 | from einops import unpack 13 | 14 | root_data, ric_data, rot_data, local_vel, feet_l, feet_r = unpack( 15 | feats, [[4], [63], [126], [66], [2], [2]], "i *" 16 | ) 17 | return root_data, ric_data, rot_data, local_vel, feet_l, feet_r 18 | 19 | 20 | def swap_left_right(data): 21 | assert len(data.shape) == 3 and data.shape[-1] == 3 22 | data = data.copy() 23 | data[..., 0] *= -1 24 | right_chain = [2, 5, 8, 11, 14, 17, 19, 21] 25 | left_chain = [1, 4, 7, 10, 13, 16, 18, 20] 26 | left_hand_chain = [22, 23, 24, 34, 35, 36, 25, 26, 27, 31, 32, 33, 28, 29, 30] 27 | right_hand_chain = [43, 44, 45, 46, 47, 48, 40, 41, 42, 37, 38, 39, 49, 50, 51] 28 | tmp = data[:, right_chain] 29 | data[:, right_chain] = data[:, left_chain] 30 | data[:, left_chain] = tmp 31 | if data.shape[1] > 24: 32 | tmp = data[:, right_hand_chain] 33 | data[:, right_hand_chain] = data[:, left_hand_chain] 34 | data[:, left_hand_chain] = tmp 35 | return data 36 | 37 | 38 | @hydra.main( 39 | config_path="../configs", config_name="compute_guoh3dfeats", version_base="1.3" 40 | ) 41 | def compute_guoh3dfeats(cfg: DictConfig): 42 | base_folder = cfg.base_folder 43 | output_folder = cfg.output_folder 44 | force_redo = cfg.force_redo 45 | 46 | from src.guofeats import joints_to_guofeats 47 | from .tools import loop_amass 48 | 49 | output_folder_M = os.path.join(output_folder, "M") 50 | 51 | print("Get h3d features from Guo et al.") 52 | print("The processed motions will be stored in this folder:") 53 | print(output_folder) 54 | 55 | iterator = loop_amass( 56 | base_folder, output_folder, ext=".npy", newext=".npy", force_redo=force_redo 57 | ) 58 | 59 | for motion_path, new_motion_path in iterator: 60 | joints = np.load(motion_path) 61 | 62 | if "humanact12" not in motion_path: 63 | # This is because the authors of HumanML3D 64 | # save the motions by swapping Y and Z (det = -1) 65 | # which is not a proper rotation (det = 1) 66 | # so we should invert x, to make it a rotation 67 | # that is why the authors use "data[..., 0] *= -1" inside the "if" 68 | # before swapping left/right 69 | # https://github.com/EricGuo5513/HumanML3D/blob/main/raw_pose_processing.ipynb 70 | joints[..., 0] *= -1 71 | # the humanact12 motions are normally saved correctly, no need to swap again 72 | # (but in fact this may not be true and the orignal H3D features 73 | # corresponding to HumanAct12 appears to be left/right flipped..) 
74 | # At least we are compatible with previous work :/ 75 | 76 | joints_m = swap_left_right(joints) 77 | 78 | # apply transformation 79 | try: 80 | features = joints_to_guofeats(joints) 81 | features_m = joints_to_guofeats(joints_m) 82 | except (IndexError, ValueError): 83 | # The sequence should be only 1 frame long 84 | # so we cannot compute features (which involve velocities etc) 85 | assert len(joints) == 1 86 | continue 87 | # save the features 88 | np.save(new_motion_path, features) 89 | 90 | # save the mirrored features as well 91 | new_motion_path_M = new_motion_path.replace(output_folder, output_folder_M) 92 | os.makedirs(os.path.split(new_motion_path_M)[0], exist_ok=True) 93 | np.save(new_motion_path_M, features_m) 94 | 95 | 96 | if __name__ == "__main__": 97 | compute_guoh3dfeats() 98 | -------------------------------------------------------------------------------- /src/model/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | # For reference 6 | # https://stats.stackexchange.com/questions/7440/kl-divergence-between-two-univariate-gaussians 7 | # https://pytorch.org/docs/stable/_modules/torch/distributions/kl.html#kl_divergence 8 | class KLLoss: 9 | def __call__(self, q, p): 10 | mu_q, logvar_q = q 11 | mu_p, logvar_p = p 12 | 13 | log_var_ratio = logvar_q - logvar_p 14 | t1 = (mu_p - mu_q).pow(2) / logvar_p.exp() 15 | div = 0.5 * (log_var_ratio.exp() + t1 - 1 - log_var_ratio) 16 | return div.mean() 17 | 18 | def __repr__(self): 19 | return "KLLoss()" 20 | 21 | 22 | def get_sim_matrix(x, y): 23 | x_logits = torch.nn.functional.normalize(x, dim=-1) 24 | y_logits = torch.nn.functional.normalize(y, dim=-1) 25 | sim_matrix = x_logits @ y_logits.T 26 | return sim_matrix 27 | 28 | 29 | class InfoNCE_with_filtering: 30 | def __init__(self, temperature=0.7, threshold_selfsim=0.8): 31 | self.temperature = temperature 32 | self.threshold_selfsim = threshold_selfsim 33 | 34 | def filter_sim_mat_with_sent_emb(self, sim_matrix, sent_emb): 35 | # put the threshold value between -1 and 1 36 | real_threshold_selfsim = 2 * self.threshold_selfsim - 1 37 | # Filtering too close values 38 | # mask them by putting -inf in the sim_matrix 39 | selfsim = sent_emb @ sent_emb.T 40 | selfsim_nodiag = selfsim - selfsim.diag().diag() 41 | idx = torch.where(selfsim_nodiag > real_threshold_selfsim) 42 | sim_matrix[idx] = -torch.inf 43 | return sim_matrix # TODO check if return necessary or in place operation 44 | 45 | def __call__(self, x, y, sent_emb=None): 46 | bs, device = len(x), x.device 47 | sim_matrix = get_sim_matrix(x, y) / self.temperature 48 | 49 | if sent_emb is not None and self.threshold_selfsim: 50 | sim_matrix = self.filter_sim_mat_with_sent_emb(sim_matrix, sent_emb) 51 | 52 | labels = torch.arange(bs, device=device) 53 | 54 | total_loss = ( 55 | F.cross_entropy(sim_matrix, labels) + F.cross_entropy(sim_matrix.T, labels) 56 | ) / 2 57 | 58 | return total_loss 59 | 60 | def __repr__(self): 61 | return f"Constrastive(temp={self.temp})" 62 | 63 | 64 | class HN_InfoNCE_with_filtering(InfoNCE_with_filtering): 65 | def __init__(self, temperature=0.7, threshold_selfsim=0.8, alpha=1.0, beta=0.25): 66 | super().__init__(temperature=temperature, threshold_selfsim=threshold_selfsim) 67 | self.alpha = alpha 68 | self.beta = beta 69 | 70 | def cross_entropy_with_HN_weights(self, sim_matrix): 71 | n = sim_matrix.shape[0] 72 | 73 | labels = range(sim_matrix.shape[0]) 74 | exp_mat = 
torch.exp(sim_matrix) 75 | num = exp_mat[range(exp_mat.shape[0]), labels] 76 | 77 | exp_mat_beta = torch.exp(self.beta * sim_matrix) 78 | weights = (n - 1) * exp_mat_beta / torch.unsqueeze((torch.sum(exp_mat_beta, axis=1) - exp_mat_beta.diag()), dim=1) 79 | weights = weights.fill_diagonal_(self.alpha) 80 | denum = torch.sum(weights * exp_mat, axis=1) 81 | 82 | return -torch.mean(torch.log(num/denum)) 83 | 84 | def __call__(self, x, y, sent_emb=None): 85 | bs, device = len(x), x.device 86 | sim_matrix = get_sim_matrix(x, y) / self.temperature 87 | 88 | if sent_emb is not None and self.threshold_selfsim: 89 | sim_matrix = self.filter_sim_mat_with_sent_emb(sim_matrix, sent_emb) 90 | 91 | total_loss = ( 92 | self.cross_entropy_with_HN_weights(sim_matrix) + self.cross_entropy_with_HN_weights(sim_matrix.T) 93 | ) / 2 94 | 95 | return total_loss 96 | 97 | def __repr__(self): 98 | return f"Constrastive(temp={self.temp})" 99 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | -------------------------------------------------------------------------------- /src/data/text_motion_multi_labels.py: -------------------------------------------------------------------------------- 1 | import os 2 | import codecs as cs 3 | import orjson # loading faster than json 4 | import json 5 | import logging 6 | import random 7 | 8 | import torch 9 | import numpy as np 10 | from .text_motion import TextMotionDataset 11 | from tqdm import tqdm 12 | 13 | from .collate import collate_text_motion_multiple_texts 14 | 15 | 16 | 17 | def read_split(path, split): 18 | split_file = os.path.join(path, "splits", split + ".txt") 19 | id_list = [] 20 | with cs.open(split_file, "r") as f: 21 | for line in f.readlines(): 22 | id_list.append(line.strip()) 23 | return id_list 24 | 25 | 26 | def load_annotations(path, name="annotations.json"): 27 | json_path = os.path.join(path, name) 28 | with open(json_path, "rb") as ff: 29 | return orjson.loads(ff.read()) 30 | 31 | 32 | class TextMotionMultiLabelsDataset(TextMotionDataset): 33 | def __init__( 34 | self, 35 | path: str, 36 | motion_loader, 37 | text_to_sent_emb, 38 | text_to_token_emb, 39 | split: str = "train", 40 | min_seconds: float = 2.0, 41 | max_seconds: float = 10.0, 42 | preload: bool = True, 43 | tiny: bool = False, 44 | ): 45 | if tiny: 46 | split = split + "_tiny" 47 | 48 | self.collate_fn = collate_text_motion_multiple_texts 49 | self.split = split 50 | self.keyids = read_split(path, split) 51 | 52 | self.text_to_sent_emb = text_to_sent_emb 53 | self.text_to_token_emb = text_to_token_emb 54 | self.motion_loader = motion_loader 55 | 56 | self.min_seconds = min_seconds 57 | self.max_seconds = max_seconds 58 | 59 | # remove too short or too long annotations 60 | self.annotations = load_annotations(path) 61 | if "test" not in split: 62 | self.annotations = self.filter_annotations(self.annotations) 63 | 64 | self.is_training = "train" in 
split 65 | self.keyids = [keyid for keyid in self.keyids if keyid in self.annotations] 66 | 67 | self.nfeats = self.motion_loader.nfeats 68 | 69 | if preload: 70 | for _ in tqdm(self, desc="Preloading the dataset"): 71 | continue 72 | 73 | def load_keyid(self, keyid, device=None, text_idx=None, sent_emb_mode="first"): 74 | annotations = self.annotations[keyid] 75 | 76 | index = 0 77 | 78 | path = annotations["path"] 79 | annotation = annotations["annotations"][index] 80 | start = annotation["start"] 81 | end = annotation["end"] 82 | 83 | texts = [ann["text"] for ann in annotations["annotations"]] 84 | 85 | text_x_dicts = self.text_to_token_emb(texts) # [{"x": ..., "length": ...}, {"x": ..., "length"}: ..., ... ] 86 | motion_x_dict = self.motion_loader( 87 | path=path, 88 | start=start, 89 | end=end, 90 | ) 91 | 92 | if sent_emb_mode == "first": 93 | sent_emb = self.text_to_sent_emb(texts[0]) 94 | elif sent_emb_mode == "average": 95 | sent_emb = torch.stack([self.text_to_sent_emb(text) for text in texts]) 96 | sent_emb = torch.mean(sent_emb, axis=0) 97 | sent_emb = torch.nn.functional.normalize(sent_emb, dim=0) 98 | 99 | output = { 100 | "motion_x_dict": motion_x_dict, 101 | "text_x_dict": text_x_dicts, 102 | "text": texts, 103 | "keyid": keyid, 104 | "sent_emb": sent_emb, 105 | } 106 | 107 | if device is not None: 108 | output["motion_x_dict"]["x"] = output["motion_x_dict"]["x"].to(device) 109 | for text_x_dict in output["text_x_dict"]: 110 | text_x_dict["x"] = text_x_dict["x"].to(device) 111 | 112 | return output 113 | 114 | 115 | def write_json(data, path): 116 | with open(path, "w") as ff: 117 | ff.write(json.dumps(data, indent=4)) 118 | 119 | -------------------------------------------------------------------------------- /src/data/text_motion.py: -------------------------------------------------------------------------------- 1 | import os 2 | import codecs as cs 3 | import orjson # loading faster than json 4 | import json 5 | 6 | import numpy as np 7 | from torch.utils.data import Dataset 8 | from tqdm import tqdm 9 | 10 | from .collate import collate_text_motion 11 | 12 | 13 | def read_split(path, split): 14 | split_file = os.path.join(path, "splits", split + ".txt") 15 | id_list = [] 16 | with cs.open(split_file, "r") as f: 17 | for line in f.readlines(): 18 | id_list.append(line.strip()) 19 | return id_list 20 | 21 | 22 | def load_annotations(path, name="annotations.json"): 23 | json_path = os.path.join(path, name) 24 | with open(json_path, "rb") as ff: 25 | return orjson.loads(ff.read()) 26 | 27 | 28 | class TextMotionDataset(Dataset): 29 | def __init__( 30 | self, 31 | path: str, 32 | motion_loader, 33 | text_to_sent_emb, 34 | text_to_token_emb, 35 | split: str = "train", 36 | min_seconds: float = 2.0, 37 | max_seconds: float = 10.0, 38 | preload: bool = True, 39 | tiny: bool = False, 40 | ): 41 | if tiny: 42 | split = split + "_tiny" 43 | 44 | self.collate_fn = collate_text_motion 45 | self.split = split 46 | self.keyids = read_split(path, split) 47 | 48 | self.text_to_sent_emb = text_to_sent_emb 49 | self.text_to_token_emb = text_to_token_emb 50 | self.motion_loader = motion_loader 51 | 52 | self.min_seconds = min_seconds 53 | self.max_seconds = max_seconds 54 | 55 | # remove too short or too long annotations 56 | self.annotations = load_annotations(path) 57 | 58 | # filter annotations (min/max) 59 | # but not for the test set 60 | # otherwise it is not fair for everyone 61 | if "test" not in split: 62 | self.annotations = self.filter_annotations(self.annotations) 63 | 64 | 
self.is_training = "train" in split 65 | self.keyids = [keyid for keyid in self.keyids if keyid in self.annotations] 66 | self.nfeats = self.motion_loader.nfeats 67 | 68 | if preload: 69 | for _ in tqdm(self, desc="Preloading the dataset"): 70 | continue 71 | 72 | def __len__(self): 73 | return len(self.keyids) 74 | 75 | def __getitem__(self, index): 76 | keyid = self.keyids[index] 77 | return self.load_keyid(keyid) 78 | 79 | def load_keyid(self, keyid): 80 | annotations = self.annotations[keyid] 81 | 82 | # Take the first one for testing/validation 83 | # Otherwise take a random one 84 | index = 0 85 | if self.is_training: 86 | index = np.random.randint(len(annotations["annotations"])) 87 | annotation = annotations["annotations"][index] 88 | 89 | text = annotation["text"] 90 | text_x_dict = self.text_to_token_emb(text) 91 | motion_x_dict = self.motion_loader( 92 | path=annotations["path"], 93 | start=annotation["start"], 94 | end=annotation["end"], 95 | ) 96 | sent_emb = self.text_to_sent_emb(text) 97 | 98 | output = { 99 | "motion_x_dict": motion_x_dict, 100 | "text_x_dict": text_x_dict, 101 | "text": text, 102 | "keyid": keyid, 103 | "sent_emb": sent_emb, 104 | } 105 | return output 106 | 107 | def filter_annotations(self, annotations): 108 | filtered_annotations = {} 109 | for key, val in annotations.items(): 110 | annots = val.pop("annotations") 111 | filtered_annots = [] 112 | for annot in annots: 113 | duration = annot["end"] - annot["start"] 114 | if self.max_seconds >= duration >= self.min_seconds: 115 | filtered_annots.append(annot) 116 | 117 | if filtered_annots: 118 | val["annotations"] = filtered_annots 119 | filtered_annotations[key] = val 120 | 121 | return filtered_annotations 122 | 123 | 124 | def write_json(data, path): 125 | with open(path, "w") as ff: 126 | ff.write(json.dumps(data, indent=4)) 127 | -------------------------------------------------------------------------------- /retrieval_action.py: -------------------------------------------------------------------------------- 1 | import os 2 | from omegaconf import DictConfig 3 | import logging 4 | import hydra 5 | import yaml 6 | from tqdm import tqdm 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | def save_metric(path, metrics): 12 | strings = yaml.dump(metrics, indent=4, sort_keys=False) 13 | with open(path, "w") as f: 14 | f.write(strings) 15 | 16 | 17 | def compute_sim_matrix(model, dataset, keyids, batch_size=256): 18 | import torch 19 | import numpy as np 20 | from src.data.collate import collate_text_motion 21 | from src.model.tmr import get_sim_matrix 22 | 23 | device = model.device 24 | 25 | nsplit = int(np.ceil(len(dataset) / batch_size)) 26 | with torch.inference_mode(): 27 | all_data = [dataset.load_keyid(keyid) for keyid in keyids] 28 | all_data_splitted = np.array_split(all_data, nsplit) 29 | 30 | # by batch (can be too costly on cuda device otherwise) 31 | latent_texts = [] 32 | latent_motions = [] 33 | sent_embs = [] 34 | 35 | for data in tqdm(all_data_splitted, leave=True): 36 | batch = collate_text_motion(data, device=device) 37 | 38 | # Text is already encoded 39 | text_x_dict = batch["text_x_dict"] 40 | motion_x_dict = batch["motion_x_dict"] 41 | sent_emb = batch["sent_emb"] 42 | 43 | # Encode both motion and text 44 | latent_text = model.encode(text_x_dict, sample_mean=True) 45 | latent_motion = model.encode(motion_x_dict, sample_mean=True) 46 | 47 | latent_texts.append(latent_text) 48 | latent_motions.append(latent_motion) 49 | sent_embs.append(sent_emb) 50 | 51 | latent_texts = 
torch.cat(latent_texts) 52 | action_latent_text = torch.unique(latent_texts, dim=0) 53 | 54 | action_latent_text_idx = {tuple(action_latent_text[i].to("cpu").numpy()): i for i in range(len(action_latent_text))} 55 | 56 | latent_motions = torch.cat(latent_motions) 57 | motion_cat_idx = [action_latent_text_idx[tuple(latent_texts[i].to("cpu").numpy())] for i in range(len(latent_motions))] 58 | 59 | #sent_embs = torch.cat(sent_embs) 60 | sim_matrix = get_sim_matrix(action_latent_text, latent_motions) 61 | returned = { 62 | "sim_matrix": sim_matrix.cpu().numpy(), 63 | "motion_cat_idx": motion_cat_idx 64 | } 65 | return returned 66 | 67 | @hydra.main(version_base=None, config_path="configs", config_name="retrieval") 68 | def retrieval(newcfg: DictConfig) -> None: 69 | device = newcfg.device 70 | run_dir = newcfg.run_dir 71 | ckpt_name = newcfg.ckpt 72 | batch_size = newcfg.batch_size 73 | save_file_name = newcfg.save_file_name 74 | split = newcfg.split 75 | 76 | assert split == "test" 77 | protocols = ["normal"] 78 | 79 | save_dir = os.path.join(run_dir, save_file_name) 80 | os.makedirs(save_dir, exist_ok=True) 81 | 82 | # Load last config 83 | from src.config import read_config 84 | import src.prepare # noqa 85 | 86 | cfg = read_config(run_dir) 87 | 88 | import pytorch_lightning as pl 89 | import numpy as np 90 | from hydra.utils import instantiate 91 | from src.load import load_model_from_cfg 92 | from src.model.metrics import all_contrastive_metrics_action_retrieval, print_latex_metrics 93 | 94 | pl.seed_everything(cfg.seed) 95 | 96 | logger.info("Loading the model") 97 | model = load_model_from_cfg(cfg, ckpt_name, eval_mode=True, device=device) 98 | 99 | 100 | data = newcfg.data 101 | if data is None: 102 | data = cfg.data 103 | 104 | datasets = {} 105 | for protocol in protocols: 106 | # Load the dataset if not already 107 | if protocol not in datasets: 108 | dataset = instantiate(data, split=split) 109 | datasets.update( 110 | {key: dataset for key in ["normal", "threshold", "guo"]} 111 | ) 112 | dataset = datasets[protocol] 113 | 114 | # Compute sim_matrix for each protocol 115 | protocol = "normal" 116 | result = compute_sim_matrix( 117 | model, dataset, dataset.keyids, batch_size=batch_size 118 | ) 119 | 120 | # Compute the metrics 121 | sim_matrix = result["sim_matrix"] 122 | motion_cat_idx = result["motion_cat_idx"] 123 | 124 | protocol_name = protocol 125 | metrics = all_contrastive_metrics_action_retrieval(sim_matrix, motion_cat_idx, norm_metrics=True) 126 | 127 | print_latex_metrics(metrics, ranks=[1, 2, 3, 5, 10], t2m=False, m2t=True, MedR=False) 128 | 129 | metric_name = f"{protocol_name}.yaml" 130 | path = os.path.join(save_dir, metric_name) 131 | save_metric(path, metrics) 132 | 133 | logger.info(f"Testing done, metrics saved in:\n{path}") 134 | 135 | 136 | if __name__ == "__main__": 137 | retrieval() 138 | -------------------------------------------------------------------------------- /retrieval_action_multi_labels.py: -------------------------------------------------------------------------------- 1 | import os 2 | from omegaconf import DictConfig 3 | import logging 4 | import hydra 5 | import yaml 6 | from tqdm import tqdm 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | def save_metric(path, metrics): 12 | strings = yaml.dump(metrics, indent=4, sort_keys=False) 13 | with open(path, "w") as f: 14 | f.write(strings) 15 | 16 | 17 | def compute_sim_matrix(model, dataset, keyids, batch_size=256): 18 | import torch 19 | import numpy as np 20 | from src.data.collate 
import collate_text_motion_multiple_texts 21 | from src.model.tmr import get_sim_matrix 22 | 23 | device = model.device 24 | 25 | nsplit = int(np.ceil(len(dataset) / batch_size)) 26 | with torch.inference_mode(): 27 | all_data = [dataset.load_keyid(keyid) for keyid in keyids] 28 | all_data_splitted = np.array_split(all_data, nsplit) 29 | 30 | # by batch (can be too costly on cuda device otherwise) 31 | latent_texts = [] 32 | latent_motions = [] 33 | sent_embs = [] 34 | text_indices = [] 35 | indices_shift = 0 36 | 37 | #for data in tqdm(all_data_splitted, leave=True): 38 | for data in all_data_splitted: 39 | 40 | batch = collate_text_motion_multiple_texts(data, device=device) 41 | # Text is already encoded 42 | text_x_dict = batch["text_x_dict"] 43 | motion_x_dict = batch["motion_x_dict"] 44 | sent_emb = batch["sent_emb"] 45 | 46 | # Encode both motion and text 47 | latent_text = model.encode(text_x_dict, sample_mean=True) 48 | latent_motion = model.encode(motion_x_dict, sample_mean=True) 49 | 50 | latent_texts.append(latent_text) 51 | latent_motions.append(latent_motion) 52 | sent_embs.append(sent_emb) 53 | idx = batch["text_slices"] 54 | idx = [[elt[0] + indices_shift, elt[1] + indices_shift] for elt in idx] 55 | text_indices.extend(idx) 56 | indices_shift += len(latent_text) 57 | 58 | latent_texts = torch.cat(latent_texts) 59 | action_latent_text = torch.unique(latent_texts, dim=0) 60 | action_latent_text_idx = {tuple(action_latent_text[i].to("cpu").numpy()): i for i in range(len(action_latent_text))} 61 | text_cat_idx = [action_latent_text_idx[tuple(latent_texts[i].to("cpu").numpy())] for i in range(len(latent_texts))] 62 | 63 | latent_motions = torch.cat(latent_motions) 64 | motion_cat_idx = [] 65 | for start_ind, end_ind in text_indices: 66 | motion_cat_idx.append(text_cat_idx[start_ind:end_ind]) 67 | 68 | #sent_embs = torch.cat(sent_embs) 69 | sim_matrix = get_sim_matrix(action_latent_text, latent_motions) 70 | 71 | returned = { 72 | "sim_matrix": sim_matrix.cpu().numpy(), 73 | "motion_cat_idx": motion_cat_idx 74 | } 75 | return returned 76 | 77 | @hydra.main(version_base=None, config_path="configs", config_name="retrieval_action_multi_labels") 78 | def retrieval_action_multi_labels(newcfg: DictConfig) -> None: 79 | device = newcfg.device 80 | run_dir = newcfg.run_dir 81 | ckpt_name = newcfg.ckpt 82 | batch_size = newcfg.batch_size 83 | save_file_name = newcfg.save_file_name 84 | split = newcfg.split 85 | 86 | assert split == "test" 87 | 88 | save_dir = os.path.join(run_dir, save_file_name) 89 | os.makedirs(save_dir, exist_ok=True) 90 | 91 | # Load last config 92 | from src.config import read_config 93 | import src.prepare # noqa 94 | 95 | cfg = read_config(run_dir) 96 | 97 | import pytorch_lightning as pl 98 | import numpy as np 99 | from hydra.utils import instantiate 100 | from src.load import load_model_from_cfg 101 | from src.model.metrics import all_contrastive_metrics_action_retrieval_multi_labels, print_latex_metrics 102 | 103 | pl.seed_everything(cfg.seed) 104 | 105 | logger.info("Loading the model") 106 | model = load_model_from_cfg(cfg, ckpt_name, eval_mode=True, device=device) 107 | 108 | 109 | data = newcfg.data 110 | if data is None: 111 | data = cfg.data 112 | 113 | dataset = instantiate(data, split=split) 114 | 115 | # Compute sim_matrix for each protocol 116 | protocol = "normal" 117 | result = compute_sim_matrix( 118 | model, dataset, dataset.keyids, batch_size=batch_size 119 | ) 120 | 121 | # Compute the metrics 122 | sim_matrix = result["sim_matrix"] 123 | 
motion_cat_idx = result["motion_cat_idx"] 124 | 125 | protocol_name = protocol 126 | metrics = all_contrastive_metrics_action_retrieval_multi_labels(sim_matrix, motion_cat_idx, norm_metrics=True) 127 | 128 | print_latex_metrics(metrics, ranks=[1, 2, 3, 5, 10], t2m=False, m2t=True, MedR=False) 129 | 130 | metric_name = f"{protocol_name}.yaml" 131 | path = os.path.join(save_dir, metric_name) 132 | save_metric(path, metrics) 133 | 134 | logger.info(f"Testing done, metrics saved in:\n{path}") 135 | 136 | 137 | if __name__ == "__main__": 138 | retrieval_action_multi_labels() 139 | -------------------------------------------------------------------------------- /src/model/actor.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch import Tensor 6 | import numpy as np 7 | 8 | from einops import repeat 9 | 10 | 11 | class PositionalEncoding(nn.Module): 12 | def __init__(self, d_model, dropout=0.1, max_len=5000, batch_first=False) -> None: 13 | super().__init__() 14 | self.batch_first = batch_first 15 | 16 | self.dropout = nn.Dropout(p=dropout) 17 | 18 | pe = torch.zeros(max_len, d_model) 19 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) 20 | div_term = torch.exp( 21 | torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model) 22 | ) 23 | pe[:, 0::2] = torch.sin(position * div_term) 24 | pe[:, 1::2] = torch.cos(position * div_term) 25 | pe = pe.unsqueeze(0).transpose(0, 1) 26 | self.register_buffer("pe", pe, persistent=False) 27 | 28 | def forward(self, x: Tensor) -> Tensor: 29 | if self.batch_first: 30 | x = x + self.pe.permute(1, 0, 2)[:, : x.shape[1], :] 31 | else: 32 | x = x + self.pe[: x.shape[0], :] 33 | return self.dropout(x) 34 | 35 | 36 | class ACTORStyleEncoder(nn.Module): 37 | # Similar to ACTOR but "action agnostic" and more general 38 | def __init__( 39 | self, 40 | nfeats: int, 41 | vae: bool, 42 | latent_dim: int = 256, 43 | ff_size: int = 1024, 44 | num_layers: int = 4, 45 | num_heads: int = 4, 46 | dropout: float = 0.1, 47 | activation: str = "gelu", 48 | ) -> None: 49 | super().__init__() 50 | 51 | self.nfeats = nfeats 52 | self.projection = nn.Linear(nfeats, latent_dim) 53 | 54 | self.vae = vae 55 | self.nbtokens = 2 if vae else 1 56 | self.tokens = nn.Parameter(torch.randn(self.nbtokens, latent_dim)) 57 | 58 | self.sequence_pos_encoding = PositionalEncoding( 59 | latent_dim, dropout=dropout, batch_first=True 60 | ) 61 | 62 | seq_trans_encoder_layer = nn.TransformerEncoderLayer( 63 | d_model=latent_dim, 64 | nhead=num_heads, 65 | dim_feedforward=ff_size, 66 | dropout=dropout, 67 | activation=activation, 68 | batch_first=True, 69 | ) 70 | 71 | self.seqTransEncoder = nn.TransformerEncoder( 72 | seq_trans_encoder_layer, num_layers=num_layers 73 | ) 74 | 75 | def forward(self, x_dict: Dict) -> Tensor: 76 | x = x_dict["x"] 77 | mask = x_dict["mask"] 78 | 79 | x = self.projection(x) 80 | 81 | device = x.device 82 | bs = len(x) 83 | 84 | tokens = repeat(self.tokens, "nbtoken dim -> bs nbtoken dim", bs=bs) 85 | xseq = torch.cat((tokens, x), 1) 86 | 87 | token_mask = torch.ones((bs, self.nbtokens), dtype=bool, device=device) 88 | aug_mask = torch.cat((token_mask, mask), 1) 89 | 90 | # add positional encoding 91 | xseq = self.sequence_pos_encoding(xseq) 92 | final = self.seqTransEncoder(xseq, src_key_padding_mask=~aug_mask) 93 | return final[:, : self.nbtokens] 94 | 95 | 96 | class ACTORStyleDecoder(nn.Module): 97 | # Similar to ACTOR 
Decoder 98 | 99 | def __init__( 100 | self, 101 | nfeats: int, 102 | latent_dim: int = 256, 103 | ff_size: int = 1024, 104 | num_layers: int = 4, 105 | num_heads: int = 4, 106 | dropout: float = 0.1, 107 | activation: str = "gelu", 108 | ) -> None: 109 | super().__init__() 110 | output_feats = nfeats 111 | self.nfeats = nfeats 112 | 113 | self.sequence_pos_encoding = PositionalEncoding( 114 | latent_dim, dropout, batch_first=True 115 | ) 116 | 117 | seq_trans_decoder_layer = nn.TransformerDecoderLayer( 118 | d_model=latent_dim, 119 | nhead=num_heads, 120 | dim_feedforward=ff_size, 121 | dropout=dropout, 122 | activation=activation, 123 | batch_first=True, 124 | ) 125 | 126 | self.seqTransDecoder = nn.TransformerDecoder( 127 | seq_trans_decoder_layer, num_layers=num_layers 128 | ) 129 | 130 | self.final_layer = nn.Linear(latent_dim, output_feats) 131 | 132 | def forward(self, z_dict: Dict) -> Tensor: 133 | z = z_dict["z"] 134 | mask = z_dict["mask"] 135 | 136 | latent_dim = z.shape[1] 137 | bs, nframes = mask.shape 138 | 139 | z = z[:, None] # sequence of 1 element for the memory 140 | 141 | # Construct time queries 142 | time_queries = torch.zeros(bs, nframes, latent_dim, device=z.device) 143 | time_queries = self.sequence_pos_encoding(time_queries) 144 | 145 | # Pass through the transformer decoder 146 | # with the latent vector for memory 147 | output = self.seqTransDecoder( 148 | tgt=time_queries, memory=z, tgt_key_padding_mask=~mask 149 | ) 150 | 151 | output = self.final_layer(output) 152 | # zero for padded area 153 | output[~mask] = 0 154 | return output 155 | -------------------------------------------------------------------------------- /DATASETS.md: -------------------------------------------------------------------------------- 1 | ## Note on datasets 2 | 3 | Currently, three datasets are widely used for 3D text-to-motion: [KIT-ML](https://motion-annotation.humanoids.kit.edu/dataset/), [HumanML3D](https://github.com/EricGuo5513/HumanML3D) and [BABEL](https://babel.is.tue.mpg.de). 4 | 5 | ### Unifying the datasets 6 | 7 | As explained on their website, [AMASS](https://amass.is.tue.mpg.de) dataset is a large database of human motion unifying different optical marker-based motion capture datasets by representing them within a common framework and parameterization. 8 | 9 | Except from a part of HumanML3D which is based on [HumanAct12](https://ericguo5513.github.io/action-to-motion/) (which is also based on [PhSPD](https://drive.google.com/drive/folders/1ZGkpiI99J-4ygD9i3ytJdmyk_hkejKCd?usp=sharing)), almost all the motion data of KIT-ML, HumanML3D and BABEL are included in AMASS. 10 | 11 | Currently, the text-to-motion datasets are not compatible in terms of motion representation: 12 | - KIT-ML uses [Master Motor Map](https://mmm.humanoids.kit.edu) (robot-like joints) 13 | - HumanML3D takes motion from AMASS, extract joints using the SMPL layer, rotate the joints (make Y the gravity axis), crop the motions, make all the skeleton similar to a reference, and compute motion features. 14 | - BABEL use raw SMPL parameters from AMASS 15 | 16 | To be able to use any text-to-motion dataset with the same representation, I propose in [this repo](https://github.com/Mathux/AMASS-Annotation-Unifier) to unify the datasets, to have the same annotation format. With the agreement of the authors, I included the annotations files in TMR repo, in this folder: [datasets/annotations](datasets/annotations) (for BABEL please follow the instructions). 
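As a quick orientation, here is a minimal sketch of how such an annotation file can be read (the dataset path and the presence of a prepared `annotations.json` are assumptions for illustration; the fields follow the format described just below):

```python
import json

# Hypothetical location: datasets/annotations/<dataset>/annotations.json
with open("datasets/annotations/humanml3d/annotations.json") as f:
    annotations = json.load(f)

# Inspect the first entry: its AMASS path, duration and text segments
keyid, entry = next(iter(annotations.items()))
print(keyid, entry["path"], entry["duration"])
for ann in entry["annotations"]:
    print(f'[{ann["start"]:.2f}s -> {ann["end"]:.2f}s] {ann["text"]}')
```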
For each datasets, I provide a .json file with: 17 | - The ID of the motion (as found in the original dataset) 18 | - The path of the motion in AMASS (or HumanAct12) 19 | - The duration in seconds 20 | - A list of annotations which contains: 21 | - An ID 22 | - The corresponding text 23 | - The start and end in seconds 24 | 25 | Like this one: 26 | 27 | ```json 28 | { 29 | "000000": { 30 | "path": "KIT/3/kick_high_left02_poses", 31 | "duration": 5.82, 32 | "annotations": [ 33 | { 34 | "seg_id": "000000_0", 35 | "text": "a man kicks something or someone with his left leg.", 36 | "start": 0.0, 37 | "end": 5.82 38 | }, 39 | ... 40 | ``` 41 | 42 | We are now free to use any motion representation. 43 | 44 | 45 | ### Motion representation 46 | 47 | Guo et al. uses a representation of motion which includes rotation invariant forward kinematics features, 3D rotations, velocities, foot contacts. Currently, a lot of works in 3D motion generation uses these features. However, these features are not the same for HumanML3D and KIT-ML (not the same number of joints, the scale is different, the reference skeleton is different etc). 48 | 49 | To let people use TMR as an evaluator, and be comparable with Guo et al. feature extractor, I propose to process the whole AMASS (+HumanAct12) dataset into the HumanML3D Guo features (which I refer to ``guoh3dfeats`` in the code). Then, we can crop each feature file according to any dataset. I also included the mirrored version of each motions. 50 | 51 | ### Differences with the released version of HumanML3D 52 | For motion shorter than 10s, this process corresponds to exactly the features file of HumanML3D (example "000000.npy"). 53 | As a sanity check, you can verify in python that both .npy corresponds to the same data: 54 | 55 | ```python 56 | import numpy as np 57 | new = np.load("datasets/motions/guoh3dfeats/humanact12/humanact12/P11G01R02F1812T1847A0402.npy") 58 | old = np.load("/path/to/HumanML3D/HumanML3D/new_joint_vecs/000001.npy") 59 | assert np.abs(new - old).mean() < 1e-10 60 | ``` 61 | 62 | For motion longer than 10s and which are cropped (like "000004.npy"), the results of cropping the features is a bit different than computing the features of the cropped motion. That is because the ``uniform skeleton`` function takes the first frame as reference to compute bone length. However, the difference is quite small. 63 | 64 | ### Installation 65 | Go to the section "Installation - Set up the datasets" of the [README.md](README.md) to compute the features. 66 | 67 | 68 | ## Credits 69 | For all the datasets, be sure to read and follow their license agreements, and cite them accordingly. 70 | 71 | ### KIT-ML 72 | ```bibtex 73 | @article{Plappert2016, 74 | author = {Matthias Plappert and Christian Mandery and Tamim Asfour}, 75 | title = {The {KIT} Motion-Language Dataset}, 76 | journal = {Big Data} 77 | year = 2016 78 | } 79 | ``` 80 | 81 | ### HumanML3D 82 | ```bibtex 83 | @inproceedings{Guo_2022_CVPR, 84 | author = {Guo, Chuan and Zou, Shihao and Zuo, Xinxin and Wang, Sen and Ji, Wei and Li, Xingyu and Cheng, Li}, 85 | title = {Generating Diverse and Natural 3D Human Motions From Text}, 86 | booktitle = {Computer Vision and Pattern Recognition (CVPR)}, 87 | year = 2022 88 | } 89 | ``` 90 | 91 | ### BABEL 92 | ```bibtex 93 | @inproceedings{BABEL:CVPR:2021, 94 | title = {{BABEL}: Bodies, Action and Behavior with English Labels}, 95 | author = {Punnakkal, Abhinanda R. 
and Chandrasekaran, Arjun and Athanasiou, Nikos and Quiros-Ramirez, Alejandra and Black, Michael J.}, 96 | booktitle = {Computer Vision and Pattern Recognition (CVPR)}, 97 | year = 2021 98 | } 99 | ``` 100 | -------------------------------------------------------------------------------- /src/logger/csv.py: -------------------------------------------------------------------------------- 1 | # from pytorch_lightning/loggers/csv_logs.py 2 | # of pytorch_lightning version 2.04 3 | 4 | # Copyright The Lightning AI team. 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | """ 18 | CSV logger 19 | ---------- 20 | 21 | CSV logger for basic experiment logging that does not require opening ports 22 | 23 | """ 24 | import logging 25 | import os 26 | from argparse import Namespace 27 | from typing import Any, Dict, Optional, Union 28 | 29 | # from lightning_fabric.loggers.csv_logs import _ExperimentWriter as _FabricExperimentWriter 30 | # from lightning_fabric.loggers.csv_logs import CSVLogger as FabricCSVLogger 31 | # Local replacement 32 | from .csv_fabric import _ExperimentWriter as _FabricExperimentWriter 33 | from .csv_fabric import CSVLogger as FabricCSVLogger 34 | 35 | from lightning_fabric.loggers.logger import rank_zero_experiment 36 | from lightning_fabric.utilities.logger import _convert_params 37 | from lightning_fabric.utilities.types import _PATH 38 | from pytorch_lightning.core.saving import save_hparams_to_yaml 39 | from pytorch_lightning.loggers.logger import Logger 40 | from pytorch_lightning.utilities.rank_zero import rank_zero_only 41 | 42 | log = logging.getLogger(__name__) 43 | 44 | 45 | class ExperimentWriter(_FabricExperimentWriter): 46 | r"""Experiment writer for CSVLogger. 47 | 48 | Currently, supports to log hyperparameters and metrics in YAML and CSV 49 | format, respectively. 50 | 51 | Args: 52 | log_dir: Directory for the experiment logs 53 | """ 54 | 55 | NAME_HPARAMS_FILE = "hparams.yaml" 56 | 57 | def __init__(self, log_dir: str) -> None: 58 | super().__init__(log_dir=log_dir) 59 | self.hparams: Dict[str, Any] = {} 60 | 61 | def log_hparams(self, params: Dict[str, Any]) -> None: 62 | """Record hparams.""" 63 | self.hparams.update(params) 64 | 65 | def save(self) -> None: 66 | """Save recorded hparams and metrics into files.""" 67 | hparams_file = os.path.join(self.log_dir, self.NAME_HPARAMS_FILE) 68 | save_hparams_to_yaml(hparams_file, self.hparams) 69 | return super().save() 70 | 71 | 72 | class CSVLogger(Logger, FabricCSVLogger): 73 | r"""Log to local file system in yaml and CSV format. 74 | 75 | Logs are saved to ``os.path.join(save_dir, name)``. 76 | 77 | Example: 78 | >>> from pytorch_lightning import Trainer 79 | >>> from pytorch_lightning.loggers import CSVLogger 80 | >>> logger = CSVLogger("logs", name="my_exp_name") 81 | >>> trainer = Trainer(logger=logger) 82 | 83 | Args: 84 | save_dir: Save directory 85 | name: Experiment name. Defaults to ``'lightning_logs'``. 
86 | prefix: A string to put at the beginning of metric keys. 87 | flush_logs_every_n_steps: How often to flush logs to disk (defaults to every 100 steps). 88 | """ 89 | 90 | LOGGER_JOIN_CHAR = "-" 91 | 92 | def __init__( 93 | self, 94 | save_dir: _PATH, 95 | name: str = "lightning_logs", 96 | prefix: str = "", 97 | flush_logs_every_n_steps: int = 100, 98 | ): 99 | super().__init__( 100 | root_dir=save_dir, 101 | name=name, 102 | prefix=prefix, 103 | flush_logs_every_n_steps=flush_logs_every_n_steps, 104 | ) 105 | self._save_dir = os.fspath(save_dir) 106 | 107 | @property 108 | def root_dir(self) -> str: 109 | """Parent directory for all checkpoint subdirectories. 110 | 111 | If the experiment name parameter is an empty string, no experiment subdirectory is used and the checkpoint will 112 | be saved in "save_dir/" 113 | """ 114 | return os.path.join(self.save_dir, self.name) 115 | 116 | @property 117 | def log_dir(self) -> str: 118 | """The log directory for this run.""" 119 | return self.root_dir 120 | 121 | @property 122 | def save_dir(self) -> str: 123 | """The current directory where logs are saved. 124 | 125 | Returns: 126 | The path to current directory where logs are saved. 127 | """ 128 | return self._save_dir 129 | 130 | @rank_zero_only 131 | def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: 132 | # don't log hyperparameters 133 | # already done in the config 134 | return 135 | 136 | @property 137 | @rank_zero_experiment 138 | def experiment(self) -> _FabricExperimentWriter: 139 | r""" 140 | 141 | Actual _ExperimentWriter object. To use _ExperimentWriter features in your 142 | :class:`~pytorch_lightning.core.module.LightningModule` do the following. 143 | 144 | Example:: 145 | 146 | self.logger.experiment.some_experiment_writer_function() 147 | 148 | """ 149 | if self._experiment is not None: 150 | return self._experiment 151 | 152 | self._fs.makedirs(self.root_dir, exist_ok=True) 153 | self._experiment = ExperimentWriter(log_dir=self.log_dir) 154 | return self._experiment 155 | -------------------------------------------------------------------------------- /src/data/augmented_text_motion.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import logging 3 | import numpy as np 4 | import random 5 | import torch 6 | from tqdm import tqdm 7 | 8 | from .collate import collate_text_motion_multiple_texts 9 | from .text_motion import load_annotations, TextMotionDataset 10 | 11 | 12 | class AugmentedTextMotionDataset(TextMotionDataset): 13 | def __init__( 14 | self, 15 | path: str, 16 | motion_loader, 17 | text_to_sent_emb, 18 | text_to_token_emb, 19 | split: str = "train", 20 | min_seconds: float = 2.0, 21 | max_seconds: float = 10.0, 22 | preload: bool = True, 23 | tiny: bool = False, 24 | paraphrase_filename: str = None, 25 | summary_filename: str = None, 26 | paraphrase_prob: float = 0.2, 27 | summary_prob: float = 0.2, 28 | averaging_prob: float = 0.4, 29 | text_sampling_nbr: int = 4 30 | ): 31 | super().__init__(path, motion_loader, text_to_sent_emb, text_to_token_emb, 32 | split=split, min_seconds=min_seconds, max_seconds=max_seconds, 33 | preload=False, tiny=tiny) 34 | 35 | self.collate_fn = collate_text_motion_multiple_texts 36 | 37 | assert paraphrase_prob == 0 or paraphrase_filename is not None 38 | assert summary_prob == 0 or summary_filename is not None 39 | 40 | self.text_sampling_nbr = text_sampling_nbr 41 | self.paraphrase_prob = 0 42 | if split=="train" and paraphrase_filename is not None: 
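            # Text-augmentation set-up: the three probabilities partition a single
            # uniform draw in load_keyid -- paraphrase_prob selects a paraphrased text,
            # summary_prob a summary (falling back to paraphrases when a key has no
            # summary), averaging_prob samples several texts at once, and the remaining
            # probability mass keeps the original annotation. All of this is disabled
            # outside the train split.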
43 | self.annotations_paraphrased = load_annotations(path, name=paraphrase_filename) 44 | self.paraphrase_prob = paraphrase_prob 45 | self.summary_prob = 0 46 | if split=="train" and summary_filename is not None: 47 | self.annotations_summary = load_annotations(path, name=summary_filename) 48 | self.summary_prob = summary_prob 49 | self.averaging_prob = 0 50 | if split=="train" and paraphrase_filename is not None: 51 | self.averaging_prob = averaging_prob 52 | 53 | # filter annotations (min/max) 54 | # but not for the test set 55 | # otherwise it is not fair for everyone 56 | if "test" not in split: 57 | if "train" in split and paraphrase_filename is not None: 58 | self.annotations_paraphrased = self.filter_annotations(self.annotations_paraphrased) 59 | if "train" in split and summary_filename is not None: 60 | self.annotations_summary = self.filter_annotations(self.annotations_summary) 61 | 62 | if preload: 63 | for _ in tqdm(self, desc="Preloading the dataset"): 64 | continue 65 | 66 | def load_keyid(self, keyid, text_idx=None, sent_emb_mode="first"): 67 | 68 | p = random.random() # Probability that will determine if we use data from augmentation, and with which config 69 | averaging = False 70 | if self.is_training and p < self.paraphrase_prob: 71 | annotations = self.annotations_paraphrased[keyid] 72 | elif self.is_training and p < self.summary_prob + self.paraphrase_prob: 73 | if keyid in self.annotations_summary: 74 | annotations = self.annotations_summary[keyid] 75 | else: 76 | annotations = self.annotations_paraphrased[keyid] # For Babel that has no summary 77 | elif self.is_training and p < self.averaging_prob + self.summary_prob + self.paraphrase_prob: 78 | annotations = copy.deepcopy(self.annotations[keyid]) 79 | if hasattr(self, "annotations_paraphrased") and keyid in self.annotations_paraphrased: 80 | annotations["annotations"] += self.annotations_paraphrased[keyid]["annotations"] 81 | if hasattr(self, "annotations_summary") and keyid in self.annotations_summary: 82 | annotations["annotations"] += self.annotations_summary[keyid]["annotations"] 83 | averaging = True 84 | else: 85 | annotations = self.annotations[keyid] 86 | 87 | # Take the first one for testing/validation 88 | # Otherwise take a random one 89 | index = 0 90 | if averaging: 91 | if isinstance(self.text_sampling_nbr, int): # If number of samples if provided 92 | n = min(self.text_sampling_nbr, len(annotations["annotations"])) 93 | else: # If number of sample not provided, it's chosen randomly 94 | n = random.randint(2, len(annotations["annotations"])) 95 | index = random.sample(range(0, len(annotations["annotations"])), n) 96 | elif text_idx is not None: 97 | index = text_idx % len(annotations["annotations"]) 98 | elif self.is_training: 99 | index = np.random.randint(len(annotations["annotations"])) 100 | 101 | if isinstance(index, int): 102 | index = [index] 103 | 104 | annotation_list = [annotations["annotations"][i] for i in index] 105 | text = [ann["text"] for ann in annotation_list] 106 | annotation0 = annotations["annotations"][index[0]] 107 | 108 | text_x_dict = [self.text_to_token_emb(t) for t in text] 109 | 110 | motion_x_dict = self.motion_loader( 111 | path=annotations["path"], 112 | start=annotation0["start"], 113 | end=annotation0["end"], 114 | ) 115 | 116 | if sent_emb_mode == "first": 117 | sent_emb = self.text_to_sent_emb(text[0]) 118 | elif sent_emb_mode == "average": 119 | sent_emb = torch.stack([self.text_to_sent_emb(t) for t in text]) 120 | sent_emb = torch.mean(sent_emb, axis=0) 121 | 
sent_emb = torch.nn.functional.normalize(sent_emb, dim=0) 122 | 123 | output = { 124 | "motion_x_dict": motion_x_dict, 125 | "text_x_dict": text_x_dict, 126 | "text": text, 127 | "keyid": keyid, 128 | "sent_emb": sent_emb, 129 | } 130 | 131 | # TODO 132 | #if device is not None: 133 | # output["motion_x_dict"]["x"] = output["motion_x_dict"]["x"].to(device) 134 | # for i in range(len(output["text_x_dict"][i])): 135 | # output["text_x_dict"][i]["x"] = output["text_x_dict"][i]["x"].to(device) 136 | 137 | return output 138 | -------------------------------------------------------------------------------- /demo/model.py: -------------------------------------------------------------------------------- 1 | # Text model + TMR text encoder only 2 | 3 | from typing import List 4 | import torch.nn as nn 5 | import os 6 | 7 | import torch 8 | import numpy as np 9 | from torch import Tensor 10 | from transformers import AutoTokenizer, AutoModel 11 | from torch.nn.functional import normalize 12 | from einops import repeat 13 | import json 14 | import warnings 15 | 16 | import logging 17 | 18 | logger = logging.getLogger("torch.distributed.nn.jit.instantiator") 19 | logger.setLevel(logging.ERROR) 20 | 21 | 22 | warnings.filterwarnings( 23 | "ignore", "The PyTorch API of nested tensors is in prototype stage*" 24 | ) 25 | 26 | warnings.filterwarnings("ignore", "Converting mask without torch.bool dtype to bool*") 27 | 28 | torch.set_float32_matmul_precision("high") 29 | 30 | 31 | class PositionalEncoding(nn.Module): 32 | def __init__(self, d_model, dropout=0.1, max_len=5000, batch_first=False) -> None: 33 | super().__init__() 34 | self.batch_first = batch_first 35 | 36 | self.dropout = nn.Dropout(p=dropout) 37 | 38 | pe = torch.zeros(max_len, d_model) 39 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) 40 | div_term = torch.exp( 41 | torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model) 42 | ) 43 | pe[:, 0::2] = torch.sin(position * div_term) 44 | pe[:, 1::2] = torch.cos(position * div_term) 45 | pe = pe.unsqueeze(0).transpose(0, 1) 46 | self.register_buffer("pe", pe, persistent=False) 47 | 48 | def forward(self, x: Tensor) -> Tensor: 49 | if self.batch_first: 50 | x = x + self.pe.permute(1, 0, 2)[:, : x.shape[1], :] 51 | else: 52 | x = x + self.pe[: x.shape[0], :] 53 | return self.dropout(x) 54 | 55 | 56 | def read_config(run_dir: str): 57 | path = os.path.join(run_dir, "config.json") 58 | with open(path, "r") as f: 59 | config = json.load(f) 60 | return config 61 | 62 | 63 | class TMR_text_encoder(nn.Module): 64 | def __init__(self, run_dir: str) -> None: 65 | config = read_config(run_dir) 66 | modelpath = config["data"]["text_to_token_emb"]["modelname"] 67 | 68 | text_encoder_conf = config["model"]["text_encoder"] 69 | 70 | vae = text_encoder_conf["vae"] 71 | latent_dim = text_encoder_conf["latent_dim"] 72 | ff_size = text_encoder_conf["ff_size"] 73 | num_layers = text_encoder_conf["num_layers"] 74 | num_heads = text_encoder_conf["num_heads"] 75 | activation = text_encoder_conf["activation"] 76 | nfeats = text_encoder_conf["nfeats"] 77 | 78 | super().__init__() 79 | 80 | # Projection of the text-outputs into the latent space 81 | self.projection = nn.Linear(nfeats, latent_dim) 82 | self.vae = vae 83 | self.nbtokens = 2 if vae else 1 84 | 85 | self.tokens = nn.Parameter(torch.randn(self.nbtokens, latent_dim)) 86 | self.sequence_pos_encoding = PositionalEncoding( 87 | latent_dim, dropout=0.0, batch_first=True 88 | ) 89 | 90 | seq_trans_encoder_layer = 
nn.TransformerEncoderLayer( 91 | d_model=latent_dim, 92 | nhead=num_heads, 93 | dim_feedforward=ff_size, 94 | dropout=0.0, 95 | activation=activation, 96 | batch_first=True, 97 | ) 98 | 99 | self.seqTransEncoder = nn.TransformerEncoder( 100 | seq_trans_encoder_layer, num_layers=num_layers 101 | ) 102 | 103 | text_encoder_pt_path = os.path.join(run_dir, "last_weights/text_encoder.pt") 104 | state_dict = torch.load(text_encoder_pt_path) 105 | self.load_state_dict(state_dict) 106 | 107 | from transformers import logging 108 | 109 | # load text model 110 | logging.set_verbosity_error() 111 | 112 | # Tokenizer 113 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 114 | self.tokenizer = AutoTokenizer.from_pretrained(modelpath) 115 | 116 | # Text model 117 | self.text_model = AutoModel.from_pretrained(modelpath) 118 | # Then configure the model 119 | self.text_encoded_dim = self.text_model.config.hidden_size 120 | self.eval() 121 | 122 | def get_last_hidden_state(self, texts: List[str], return_mask: bool = False): 123 | encoded_inputs = self.tokenizer(texts, return_tensors="pt", padding=True) 124 | output = self.text_model(**encoded_inputs.to(self.text_model.device)) 125 | if not return_mask: 126 | return output.last_hidden_state 127 | return output.last_hidden_state, encoded_inputs.attention_mask.to(dtype=bool) 128 | 129 | def forward(self, texts: List[str]) -> Tensor: 130 | text_encoded, mask = self.get_last_hidden_state(texts, return_mask=True) 131 | 132 | x = self.projection(text_encoded) 133 | 134 | device = x.device 135 | bs = len(x) 136 | 137 | tokens = repeat(self.tokens, "nbtoken dim -> bs nbtoken dim", bs=bs) 138 | xseq = torch.cat((tokens, x), 1) 139 | 140 | token_mask = torch.ones((bs, self.nbtokens), dtype=bool, device=device) 141 | aug_mask = torch.cat((token_mask, mask), 1) 142 | 143 | # add positional encoding 144 | xseq = self.sequence_pos_encoding(xseq) 145 | final = self.seqTransEncoder(xseq, src_key_padding_mask=~aug_mask) 146 | return final[:, 0] 147 | 148 | # compute score for retrieval 149 | def compute_scores(self, texts, unit_embs=None, embs=None): 150 | # not both empty 151 | assert not (unit_embs is None and embs is None) 152 | # not both filled 153 | assert not (unit_embs is not None and embs is not None) 154 | 155 | output_str = False 156 | # if one input, squeeze the output 157 | if isinstance(texts, str): 158 | texts = [texts] 159 | output_str = True 160 | 161 | # compute unit_embs from embs if not given 162 | if embs is not None: 163 | unit_embs = normalize(embs) 164 | 165 | with torch.no_grad(): 166 | latent_unit_texts = normalize(self(texts)) 167 | # compute cosine similarity between 0 and 1 168 | scores = (unit_embs @ latent_unit_texts.T).T / 2 + 0.5 169 | scores = scores.cpu().numpy() 170 | 171 | if output_str: 172 | scores = scores[0] 173 | 174 | return scores 175 | -------------------------------------------------------------------------------- /src/rifke.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from einops import rearrange 4 | import numpy as np 5 | 6 | from torch import Tensor 7 | 8 | from .geometry import axis_angle_rotation, matrix_to_axis_angle 9 | from .joints import INFOS 10 | 11 | 12 | def joints_to_rifke(joints, jointstype="smpljoints"): 13 | # Joints to rotation invariant poses (Holden et. al.) 
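    # Feature layout, in the order produced by group() / read back by ungroup():
    #   [root height (1), local joint positions ((njoints - 1) * 3),
    #    Z angular velocity of the forward direction (1), local XY root velocity (2)]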
14 | # Similar function than fke2rifke in Language2Pose repository 15 | # Adapted from the pytorch version of TEMOS 16 | # https://github.com/Mathux/TEMOS 17 | # Estimate the last velocities based on acceleration 18 | # Difference of rotations are in SO3 space now 19 | 20 | # First remove the ground 21 | ground = joints[..., 2].min() 22 | poses = joints.clone() 23 | poses[..., 2] -= ground 24 | 25 | poses = joints.clone() 26 | translation = poses[..., 0, :].clone() 27 | 28 | # Let the root have the Z translation --> gravity axis 29 | root_grav_axis = translation[..., 2] 30 | 31 | # Trajectory => Translation without gravity axis (Z) 32 | trajectory = translation[..., [0, 1]] 33 | 34 | # Compute the forward direction (before removing the root joint) 35 | forward = get_forward_direction(poses, jointstype=jointstype) 36 | 37 | # Delete the root joints of the poses 38 | poses = poses[..., 1:, :] 39 | 40 | # Remove the trajectory of the poses 41 | poses[..., [0, 1]] -= trajectory[..., None, :] 42 | 43 | vel_trajectory = torch.diff(trajectory, dim=-2) 44 | 45 | # repeat the last acceleration 46 | # for the last (not seen) velocity 47 | last_acceleration = vel_trajectory[..., -1, :] - vel_trajectory[..., -2, :] 48 | 49 | future_velocity = vel_trajectory[..., -1, :] + last_acceleration 50 | vel_trajectory = torch.cat((vel_trajectory, future_velocity[..., None, :]), dim=-2) 51 | 52 | angles = torch.atan2(*(forward.transpose(0, -1))).transpose(0, -1) 53 | 54 | # True difference of angles 55 | mat_rotZ = axis_angle_rotation("Z", angles) 56 | vel_mat_rotZ = mat_rotZ[..., 1:, :, :] @ mat_rotZ.transpose(-1, -2)[..., :-1, :, :] 57 | # repeat the last acceleration (same as the trajectory but in the 3D rotation space) 58 | last_acc_rotZ = ( 59 | vel_mat_rotZ[..., -1, :, :] @ vel_mat_rotZ.transpose(-1, -2)[..., -2, :, :] 60 | ) 61 | future_vel_rotZ = vel_mat_rotZ[..., -1, :, :] @ last_acc_rotZ 62 | vel_mat_rotZ = torch.cat((vel_mat_rotZ, future_vel_rotZ[..., None, :, :]), dim=-3) 63 | vel_angles = matrix_to_axis_angle(vel_mat_rotZ)[..., 2] 64 | 65 | # Construct the inverse rotation matrix 66 | rotations_inv = mat_rotZ.transpose(-1, -2)[..., :2, :2] 67 | 68 | poses_local = torch.einsum("...lj,...jk->...lk", poses[..., [0, 1]], rotations_inv) 69 | poses_local = torch.stack( 70 | (poses_local[..., 0], poses_local[..., 1], poses[..., 2]), axis=-1 71 | ) 72 | 73 | # stack the xyz joints into feature vectors 74 | poses_features = rearrange(poses_local, "... joints xyz -> ... (joints xyz)") 75 | 76 | # Rotate the vel_trajectory 77 | vel_trajectory_local = torch.einsum( 78 | "...j,...jk->...k", vel_trajectory, rotations_inv 79 | ) 80 | 81 | # Stack things together 82 | features = group(root_grav_axis, poses_features, vel_angles, vel_trajectory_local) 83 | return features 84 | 85 | 86 | def rifke_to_joints(features: Tensor, jointstype="smpljoints") -> Tensor: 87 | root_grav_axis, poses_features, vel_angles, vel_trajectory_local = ungroup(features) 88 | 89 | # Remove the dummy last angle and integrate the angles 90 | angles = torch.cumsum(vel_angles[..., :-1], dim=-1) 91 | # The first angle is zero 92 | angles = torch.cat((0 * angles[..., [0]], angles), dim=-1) 93 | rotations = axis_angle_rotation("Z", angles)[..., :2, :2] 94 | 95 | # Get back the poses 96 | poses_local = rearrange(poses_features, "... (joints xyz) -> ... 
joints xyz", xyz=3) 97 | 98 | # Rotate the poses 99 | poses = torch.einsum("...lj,...jk->...lk", poses_local[..., [0, 1]], rotations) 100 | poses = torch.stack((poses[..., 0], poses[..., 1], poses_local[..., 2]), axis=-1) 101 | 102 | # Rotate the vel_trajectory 103 | vel_trajectory = torch.einsum("...j,...jk->...k", vel_trajectory_local, rotations) 104 | # Remove the dummy last velocity and integrate the trajectory 105 | trajectory = torch.cumsum(vel_trajectory[..., :-1, :], dim=-2) 106 | # The first position is zero 107 | trajectory = torch.cat((0 * trajectory[..., [0], :], trajectory), dim=-2) 108 | 109 | # Add the root joints (which is still zero) 110 | poses = torch.cat((0 * poses[..., [0], :], poses), -2) 111 | 112 | # put back the gravity offset 113 | poses[..., 0, 2] = root_grav_axis 114 | 115 | # Add the trajectory globally 116 | poses[..., [0, 1]] += trajectory[..., None, :] 117 | return poses 118 | 119 | 120 | def group(root_grav_axis, poses_features, vel_angles, vel_trajectory_local): 121 | # Stack things together 122 | features = torch.cat( 123 | ( 124 | root_grav_axis[..., None], 125 | poses_features, 126 | vel_angles[..., None], 127 | vel_trajectory_local, 128 | ), 129 | -1, 130 | ) 131 | return features 132 | 133 | 134 | def ungroup(features: Tensor) -> tuple[Tensor]: 135 | # Unbind things 136 | root_grav_axis = features[..., 0] 137 | poses_features = features[..., 1:-3] 138 | vel_angles = features[..., -3] 139 | vel_trajectory_local = features[..., -2:] 140 | return root_grav_axis, poses_features, vel_angles, vel_trajectory_local 141 | 142 | 143 | def get_forward_direction(poses, jointstype="smpljoints"): 144 | assert jointstype in INFOS 145 | infos = INFOS[jointstype] 146 | assert poses.shape[-2] == infos["njoints"] 147 | RH, LH, RS, LS = infos["RH"], infos["LH"], infos["RS"], infos["LS"] 148 | across = ( 149 | poses[..., RH, :] - poses[..., LH, :] + poses[..., RS, :] - poses[..., LS, :] 150 | ) 151 | forward = torch.stack((-across[..., 1], across[..., 0]), axis=-1) 152 | forward = torch.nn.functional.normalize(forward, dim=-1) 153 | return forward 154 | 155 | 156 | def canonicalize_rotation(joints, jointstype="smpljoints"): 157 | return_np = False 158 | if isinstance(joints, np.ndarray): 159 | joints = torch.from_numpy(joints) 160 | return_np = True 161 | 162 | features = joints_to_rifke(joints, jointstype=jointstype) 163 | joints_c = rifke_to_joints(features, jointstype=jointstype) 164 | if return_np: 165 | joints_c = joints_c.numpy() 166 | return joints_c 167 | -------------------------------------------------------------------------------- /src/logger/csv_fabric.py: -------------------------------------------------------------------------------- 1 | # from lightning_fabric/loggers/csv_logs.py 2 | # of lightning_fabric version 2.04 3 | 4 | # Copyright The Lightning AI team. 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 
17 | 18 | import csv 19 | import logging 20 | import os 21 | from argparse import Namespace 22 | from typing import Any, Dict, List, Optional, Union 23 | 24 | from torch import Tensor 25 | 26 | from lightning_fabric.loggers.logger import Logger, rank_zero_experiment 27 | from lightning_fabric.utilities.cloud_io import get_filesystem 28 | from lightning_fabric.utilities.logger import _add_prefix 29 | from lightning_fabric.utilities.rank_zero import rank_zero_only, rank_zero_warn 30 | from lightning_fabric.utilities.types import _PATH 31 | 32 | log = logging.getLogger(__name__) 33 | 34 | 35 | class CSVLogger(Logger): 36 | r"""Log to the local file system in CSV format. 37 | 38 | Logs are saved to ``os.path.join(root_dir, name)``. 39 | 40 | Args: 41 | root_dir: The root directory in which all your experiments with different names and versions will be stored. 42 | name: Experiment name. Defaults to ``'lightning_logs'``. 43 | prefix: A string to put at the beginning of metric keys. 44 | flush_logs_every_n_steps: How often to flush logs to disk (defaults to every 100 steps). 45 | 46 | Example:: 47 | 48 | from lightning_fabric.loggers import CSVLogger 49 | 50 | logger = CSVLogger("path/to/logs/root", name="my_model") 51 | logger.log_metrics({"loss": 0.235, "acc": 0.75}) 52 | logger.finalize("success") 53 | """ 54 | 55 | LOGGER_JOIN_CHAR = "-" 56 | 57 | def __init__( 58 | self, 59 | root_dir: _PATH, 60 | name: str = "lightning_logs", 61 | prefix: str = "", 62 | flush_logs_every_n_steps: int = 100, 63 | ): 64 | super().__init__() 65 | root_dir = os.fspath(root_dir) 66 | self._root_dir = root_dir 67 | self._name = name or "" 68 | self._prefix = prefix 69 | self._fs = get_filesystem(root_dir) 70 | self._experiment: Optional[_ExperimentWriter] = None 71 | self._flush_logs_every_n_steps = flush_logs_every_n_steps 72 | 73 | @property 74 | def name(self) -> str: 75 | """Gets the name of the experiment. 76 | 77 | Returns: 78 | The name of the experiment. 79 | """ 80 | return self._name 81 | 82 | @property 83 | def version(self) -> str: 84 | return "" 85 | 86 | @property 87 | def root_dir(self) -> str: 88 | """Gets the save directory where the versioned CSV experiments are saved.""" 89 | return self._root_dir 90 | 91 | @property 92 | def log_dir(self) -> str: 93 | """The log directory for this run.""" 94 | # create a pseudo standard path 95 | return os.path.join(self.root_dir, self.name) 96 | 97 | @property 98 | @rank_zero_experiment 99 | def experiment(self) -> "_ExperimentWriter": 100 | """Actual ExperimentWriter object. To use ExperimentWriter features anywhere in your code, do the 101 | following. 102 | 103 | Example:: 104 | 105 | self.logger.experiment.some_experiment_writer_function() 106 | """ 107 | if self._experiment is not None: 108 | return self._experiment 109 | 110 | os.makedirs(self.root_dir, exist_ok=True) 111 | self._experiment = _ExperimentWriter(log_dir=self.log_dir) 112 | return self._experiment 113 | 114 | @rank_zero_only 115 | def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: 116 | raise NotImplementedError( 117 | "The `CSVLogger` does not yet support logging hyperparameters." 
118 | ) 119 | 120 | @rank_zero_only 121 | def log_metrics( 122 | self, metrics: Dict[str, Union[Tensor, float]], step: Optional[int] = None 123 | ) -> None: 124 | metrics = _add_prefix(metrics, self._prefix, self.LOGGER_JOIN_CHAR) 125 | self.experiment.log_metrics(metrics, step) 126 | if step is not None and (step + 1) % self._flush_logs_every_n_steps == 0: 127 | self.save() 128 | 129 | @rank_zero_only 130 | def save(self) -> None: 131 | super().save() 132 | self.experiment.save() 133 | 134 | @rank_zero_only 135 | def finalize(self, status: str) -> None: 136 | if self._experiment is None: 137 | # When using multiprocessing, finalize() should be a no-op on the main process, as no experiment has been 138 | # initialized there 139 | return 140 | self.save() 141 | 142 | 143 | class _ExperimentWriter: 144 | r"""Experiment writer for CSVLogger. 145 | 146 | Args: 147 | log_dir: Directory for the experiment logs 148 | """ 149 | 150 | NAME_METRICS_FILE = "metrics.csv" 151 | 152 | def __init__(self, log_dir: str) -> None: 153 | self.metrics: List[Dict[str, float]] = [] 154 | 155 | self._fs = get_filesystem(log_dir) 156 | self.log_dir = log_dir 157 | self._fs.makedirs(self.log_dir, exist_ok=True) 158 | self.metrics_file_path = os.path.join(self.log_dir, self.NAME_METRICS_FILE) 159 | 160 | if self._fs.exists(self.log_dir) and self._fs.listdir(self.log_dir): 161 | # Read previous logs 162 | if os.path.exists(self.metrics_file_path): 163 | with self._fs.open(self.metrics_file_path, "r") as f: 164 | reader = csv.DictReader(f) 165 | self.metrics = [x for x in reader] 166 | 167 | def log_metrics( 168 | self, metrics_dict: Dict[str, float], step: Optional[int] = None 169 | ) -> None: 170 | """Record metrics.""" 171 | 172 | def _handle_value(value: Union[Tensor, Any]) -> Any: 173 | if isinstance(value, Tensor): 174 | return value.item() 175 | return value 176 | 177 | if step is None: 178 | step = len(self.metrics) 179 | 180 | metrics = {k: _handle_value(v) for k, v in metrics_dict.items()} 181 | metrics["step"] = step 182 | self.metrics.append(metrics) 183 | 184 | def save(self) -> None: 185 | """Save recorded metrics into files.""" 186 | if not self.metrics: 187 | return 188 | 189 | last_m = {} 190 | for m in self.metrics: 191 | last_m.update(m) 192 | metrics_keys = list(last_m.keys()) 193 | 194 | with self._fs.open(self.metrics_file_path, "w", newline="") as f: 195 | writer = csv.DictWriter(f, fieldnames=metrics_keys) 196 | writer.writeheader() 197 | writer.writerows(self.metrics) 198 | -------------------------------------------------------------------------------- /src/model/tmr.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional 2 | from torch import Tensor 3 | 4 | import torch 5 | import torch.nn as nn 6 | from .temos import TEMOS 7 | from .losses import InfoNCE_with_filtering 8 | from .metrics import all_contrastive_metrics 9 | 10 | 11 | # x.T will be deprecated in pytorch 12 | def transpose(x): 13 | return x.permute(*torch.arange(x.ndim - 1, -1, -1)) 14 | 15 | 16 | def get_sim_matrix(x, y): 17 | x_logits = torch.nn.functional.normalize(x, dim=-1) 18 | y_logits = torch.nn.functional.normalize(y, dim=-1) 19 | sim_matrix = x_logits @ transpose(y_logits) 20 | return sim_matrix 21 | 22 | 23 | # Scores are between 0 and 1 24 | def get_score_matrix(x, y): 25 | sim_matrix = get_sim_matrix(x, y) 26 | scores = sim_matrix / 2 + 0.5 27 | return scores 28 | 29 | 30 | class TMR(TEMOS): 31 | r"""TMR: Text-to-Motion Retrieval 32 | Using 
Contrastive 3D Human Motion Synthesis 33 | Find more information about the model on the following website: 34 | https://mathis.petrovich.fr/tmr 35 | 36 | Args: 37 | motion_encoder: a module to encode the input motion features in the latent space (required). 38 | text_encoder: a module to encode the text embeddings in the latent space (required). 39 | motion_decoder: a module to decode the latent vector into motion features (required). 40 | vae: a boolean to make the model probabilistic (required). 41 | fact: a scaling factor for sampling the VAE (optional). 42 | sample_mean: sample the mean vector instead of random sampling (optional). 43 | lmd: dictionary of losses weights (optional). 44 | lr: learninig rate for the optimizer (optional). 45 | temperature: temperature of the softmax in the contrastive loss (optional). 46 | threshold_selfsim: threshold used to filter wrong negatives for the contrastive loss (optional). 47 | threshold_selfsim_metrics: threshold used to filter wrong negatives for the metrics (optional). 48 | """ 49 | 50 | def __init__( 51 | self, 52 | motion_encoder: nn.Module, 53 | text_encoder: nn.Module, 54 | motion_decoder: nn.Module, 55 | vae: bool, 56 | contrastive_loss: Optional[InfoNCE_with_filtering] = None, 57 | temperature: float = 0.7, # For compatibility with TMR original code 58 | threshold_selfsim: float = 0.80, # For compatibility with TMR original code 59 | fact: Optional[float] = None, 60 | sample_mean: Optional[bool] = False, 61 | lmd: Dict = {"recons": 1.0, "latent": 1.0e-5, "kl": 1.0e-5, "contrastive": 0.1}, 62 | lr: float = 1e-4, 63 | threshold_selfsim_metrics: float = 0.95, 64 | ) -> None: 65 | # Initialize module like TEMOS 66 | super().__init__( 67 | motion_encoder=motion_encoder, 68 | text_encoder=text_encoder, 69 | motion_decoder=motion_decoder, 70 | vae=vae, 71 | fact=fact, 72 | sample_mean=sample_mean, 73 | lmd=lmd, 74 | lr=lr, 75 | ) 76 | 77 | # adding the contrastive loss 78 | self.contrastive_loss_fn = contrastive_loss 79 | if self.contrastive_loss_fn is None: # For compatibility with TMR original code 80 | self.contrastive_loss_fn = InfoNCE_with_filtering( 81 | temperature=temperature, threshold_selfsim=threshold_selfsim 82 | ) 83 | self.threshold_selfsim_metrics = threshold_selfsim_metrics 84 | 85 | # store validation values to compute retrieval metrics 86 | # on the whole validation set 87 | self.validation_step_t_latents = [] 88 | self.validation_step_m_latents = [] 89 | self.validation_step_sent_emb = [] 90 | 91 | def compute_loss(self, batch: Dict, return_all=False) -> Dict: 92 | t_results, m_results = self.call_models(batch) 93 | t_motions, t_latents, t_dists = t_results["motions"], t_results["latent_vectors"], t_results["distributions"] 94 | m_motions, m_latents, m_dists = m_results["motions"], m_results["latent_vectors"], m_results["distributions"] 95 | 96 | ref_motions = batch["motion_x_dict"]["x"] 97 | 98 | # sentence embeddings 99 | sent_emb = batch["sent_emb"] 100 | 101 | # Store all losses 102 | losses = {} 103 | 104 | # Reconstructions losses 105 | # fmt: off 106 | losses["recons"] = ( 107 | + self.reconstruction_loss_fn(t_motions, ref_motions) # text -> motion 108 | + self.reconstruction_loss_fn(m_motions, ref_motions) # motion -> motion 109 | ) 110 | # fmt: on 111 | 112 | # VAE losses 113 | if self.vae: 114 | # Create a centred normal distribution to compare with 115 | # logvar = 0 -> std = 1 116 | ref_mus = torch.zeros_like(m_dists[0]) 117 | ref_logvar = torch.zeros_like(m_dists[1]) 118 | ref_dists = (ref_mus, ref_logvar) 119 
| 120 | losses["kl"] = ( 121 | self.kl_loss_fn(t_dists, m_dists) # text_to_motion 122 | + self.kl_loss_fn(m_dists, t_dists) # motion_to_text 123 | + self.kl_loss_fn(m_dists, ref_dists) # motion 124 | + self.kl_loss_fn(t_dists, ref_dists) # text 125 | ) 126 | 127 | # Latent manifold loss 128 | losses["latent"] = self.latent_loss_fn(t_latents, m_latents) 129 | 130 | # TMR: adding the contrastive loss 131 | losses["contrastive"] = self.contrastive_loss_fn(t_latents, m_latents, sent_emb) 132 | 133 | # Weighted average of the losses 134 | losses["loss"] = sum( 135 | self.lmd[x] * val for x, val in losses.items() if x in self.lmd 136 | ) 137 | 138 | # Used for the validation step 139 | if return_all: 140 | return losses, t_latents, m_latents 141 | 142 | return losses 143 | 144 | def validation_step(self, batch: Dict, batch_idx: int) -> Tensor: 145 | bs = len(batch["motion_x_dict"]["x"]) 146 | losses, t_latents, m_latents = self.compute_loss(batch, return_all=True) 147 | 148 | # Store the latent vectors 149 | self.validation_step_t_latents.append(t_latents) 150 | self.validation_step_m_latents.append(m_latents) 151 | self.validation_step_sent_emb.append(batch["sent_emb"]) 152 | 153 | for loss_name in sorted(losses): 154 | loss_val = losses[loss_name] 155 | self.log( 156 | f"val_{loss_name}", 157 | loss_val, 158 | on_epoch=True, 159 | on_step=True, 160 | batch_size=bs, 161 | ) 162 | 163 | return losses["loss"] 164 | 165 | def on_validation_epoch_end(self): 166 | # Compute contrastive metrics on the whole batch 167 | t_latents = torch.cat(self.validation_step_t_latents) 168 | m_latents = torch.cat(self.validation_step_m_latents) 169 | sent_emb = torch.cat(self.validation_step_sent_emb) 170 | 171 | # Compute the similarity matrix 172 | sim_matrix = get_sim_matrix(t_latents, m_latents).cpu().numpy() 173 | 174 | contrastive_metrics = all_contrastive_metrics( 175 | sim_matrix, 176 | emb=sent_emb.cpu().numpy(), 177 | threshold=self.threshold_selfsim_metrics, 178 | ) 179 | 180 | for loss_name in sorted(contrastive_metrics): 181 | loss_val = contrastive_metrics[loss_name] 182 | self.log( 183 | f"val_{loss_name}_epoch", 184 | loss_val, 185 | on_epoch=True, 186 | on_step=False, 187 | ) 188 | 189 | self.validation_step_t_latents.clear() 190 | self.validation_step_m_latents.clear() 191 | self.validation_step_sent_emb.clear() 192 | -------------------------------------------------------------------------------- /retrieval.py: -------------------------------------------------------------------------------- 1 | import os 2 | from omegaconf import DictConfig 3 | import logging 4 | import hydra 5 | import yaml 6 | from tqdm import tqdm 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | def save_metric(path, metrics): 12 | strings = yaml.dump(metrics, indent=4, sort_keys=False) 13 | with open(path, "w") as f: 14 | f.write(strings) 15 | 16 | 17 | def compute_sim_matrix(model, dataset, keyids, batch_size=256): 18 | import torch 19 | import numpy as np 20 | from src.data.collate import collate_text_motion 21 | from src.model.tmr import get_sim_matrix 22 | 23 | device = model.device 24 | 25 | nsplit = int(np.ceil(len(keyids) / min(batch_size, len(keyids)))) 26 | with torch.inference_mode(): 27 | all_data = [dataset.load_keyid(keyid) for keyid in keyids] 28 | all_data_splitted = np.array_split(all_data, nsplit) 29 | 30 | # by batch (can be too costly on cuda device otherwise) 31 | latent_texts = [] 32 | latent_motions = [] 33 | sent_embs = [] 34 | for data in tqdm(all_data_splitted, leave=True): 35 | batch = 
collate_text_motion(data, device=device) 36 | # Text is already encoded 37 | text_x_dict = batch["text_x_dict"] 38 | motion_x_dict = batch["motion_x_dict"] 39 | sent_emb = batch["sent_emb"] 40 | 41 | # Encode both motion and text 42 | latent_text = model.encode(text_x_dict, sample_mean=True) 43 | latent_motion = model.encode(motion_x_dict, sample_mean=True) 44 | 45 | latent_texts.append(latent_text) 46 | latent_motions.append(latent_motion) 47 | sent_embs.append(sent_emb) 48 | 49 | latent_texts = torch.cat(latent_texts) 50 | latent_motions = torch.cat(latent_motions) 51 | sent_embs = torch.cat(sent_embs) 52 | sim_matrix = get_sim_matrix(latent_texts, latent_motions) 53 | returned = { 54 | "sim_matrix": sim_matrix.cpu().numpy(), 55 | "sent_emb": sent_embs.cpu().numpy(), 56 | } 57 | return returned 58 | 59 | @hydra.main(version_base=None, config_path="configs", config_name="retrieval") 60 | def retrieval(newcfg: DictConfig) -> None: 61 | protocol = newcfg.protocol 62 | threshold_val = newcfg.threshold 63 | device = newcfg.device 64 | run_dir = newcfg.run_dir 65 | ckpt_name = newcfg.ckpt 66 | batch_size = newcfg.batch_size 67 | save_file_name = newcfg.save_file_name 68 | split = newcfg.split 69 | 70 | print("protocol : ", protocol) 71 | assert protocol in ["all", "normal", "threshold", "nsim", "guo", "normal_no_mirror", "threshold_no_mirror"] 72 | assert split == "test" or (protocol != "nsim" and protocol != "all") 73 | 74 | if protocol == "all": 75 | protocols = ["normal", "threshold", "nsim", "guo"] 76 | else: 77 | protocols = [protocol] 78 | 79 | save_dir = os.path.join(run_dir, save_file_name) 80 | os.makedirs(save_dir, exist_ok=True) 81 | 82 | # Load last config 83 | from src.config import read_config 84 | import src.prepare # noqa 85 | 86 | cfg = read_config(run_dir) 87 | 88 | import pytorch_lightning as pl 89 | import numpy as np 90 | from hydra.utils import instantiate 91 | from src.load import load_model_from_cfg 92 | from src.model.metrics import all_contrastive_metrics, print_latex_metrics 93 | 94 | pl.seed_everything(cfg.seed) 95 | 96 | logger.info("Loading the model") 97 | model = load_model_from_cfg(cfg, ckpt_name, eval_mode=True, device=device) 98 | 99 | 100 | data = newcfg.data 101 | if data is None: 102 | data = cfg.data 103 | 104 | datasets = {} 105 | results = {} 106 | for protocol in protocols: 107 | # Load the dataset if not already 108 | if protocol not in datasets: 109 | if protocol in ["normal", "threshold", "guo"]: 110 | dataset = instantiate(data, split=split) 111 | datasets.update( 112 | {key: dataset for key in ["normal", "threshold", "guo"]} 113 | ) 114 | elif protocol in ["normal_no_mirror", "threshold_no_mirror"]: 115 | datasets[protocol] = instantiate(data, split=split + "_no_mirror") 116 | elif protocol == "nsim": 117 | datasets[protocol] = instantiate(data, split="nsim_test") 118 | dataset = datasets[protocol] 119 | 120 | # Compute sim_matrix for each protocol 121 | if protocol not in results: 122 | if protocol in ["normal", "threshold"]: 123 | res = compute_sim_matrix( 124 | model, dataset, dataset.keyids, batch_size=batch_size 125 | ) 126 | results.update({key: res for key in ["normal", "threshold"]}) 127 | elif protocol in ["normal_no_mirror", "threshold_no_mirror"]: 128 | res = compute_sim_matrix( 129 | model, dataset, dataset.keyids, batch_size=batch_size 130 | ) 131 | results[protocol] = res 132 | elif protocol == "nsim": 133 | res = compute_sim_matrix( 134 | model, dataset, dataset.keyids, batch_size=batch_size 135 | ) 136 | results[protocol] = res 137 
| elif protocol == "guo": 138 | keyids = sorted(dataset.keyids) 139 | N = len(keyids) 140 | 141 | # make batches of 32 142 | idx = np.arange(N) 143 | np.random.seed(0) 144 | np.random.shuffle(idx) 145 | idx_batches = [ 146 | idx[32 * i : 32 * (i + 1)] for i in range(len(keyids) // 32) 147 | ] 148 | 149 | # split into batches of 32 150 | # batched_keyids = [ [32], [32], [...]] 151 | results["guo"] = [ 152 | compute_sim_matrix( 153 | model, 154 | dataset, 155 | np.array(keyids)[idx_batch], 156 | batch_size=batch_size, 157 | ) 158 | for idx_batch in idx_batches 159 | ] 160 | result = results[protocol] 161 | 162 | # Compute the metrics 163 | if protocol == "guo": 164 | all_metrics = [] 165 | for x in result: 166 | sim_matrix = x["sim_matrix"] 167 | metrics = all_contrastive_metrics(sim_matrix, rounding=None) 168 | all_metrics.append(metrics) 169 | 170 | avg_metrics = {} 171 | for key in all_metrics[0].keys(): 172 | avg_metrics[key] = round( 173 | float(np.mean([metrics[key] for metrics in all_metrics])), 2 174 | ) 175 | 176 | metrics = avg_metrics 177 | protocol_name = protocol 178 | else: 179 | sim_matrix = result["sim_matrix"] 180 | 181 | protocol_name = protocol 182 | if protocol == "threshold": 183 | emb = result["sent_emb"] 184 | threshold = threshold_val 185 | protocol_name = protocol + f"_{threshold}" 186 | else: 187 | emb, threshold = None, None 188 | metrics = all_contrastive_metrics(sim_matrix, emb, threshold=threshold, t2m=True, m2t=False) 189 | 190 | print_latex_metrics(metrics, ranks=[1, 3, 10], t2m=True, m2t=False, MedR=False) 191 | 192 | print("protocol_name : ", protocol_name) 193 | metric_name = f"{protocol_name}.yaml" 194 | path = os.path.join(save_dir, metric_name) 195 | save_metric(path, metrics) 196 | 197 | logger.info(f"Testing done, metrics saved in:\n{path}") 198 | 199 | 200 | if __name__ == "__main__": 201 | retrieval() 202 | -------------------------------------------------------------------------------- /src/renderer/matplotlib.py: -------------------------------------------------------------------------------- 1 | # From TEMOS: temos/render/anim.py 2 | # Assume Z is the gravity axis 3 | # Inspired by 4 | # - https://github.com/anindita127/Complextext2animation/blob/main/src/utils/visualization.py 5 | # - https://github.com/facebookresearch/QuaterNet/blob/main/common/visualization.py 6 | 7 | import logging 8 | 9 | from dataclasses import dataclass 10 | from typing import List, Tuple, Optional 11 | import numpy as np 12 | from src.rifke import canonicalize_rotation 13 | 14 | logger = logging.getLogger("matplotlib.animation") 15 | logger.setLevel(logging.ERROR) 16 | 17 | colors = ("black", "magenta", "red", "green", "blue") 18 | 19 | KINEMATIC_TREES = { 20 | "smpljoints": [ 21 | [0, 3, 6, 9, 12, 15], 22 | [9, 13, 16, 18, 20], 23 | [9, 14, 17, 19, 21], 24 | [0, 1, 4, 7, 10], 25 | [0, 2, 5, 8, 11], 26 | ], 27 | "guoh3djoints": [ # no hands 28 | [0, 3, 6, 9, 12, 15], 29 | [9, 13, 16, 18, 20], 30 | [9, 14, 17, 19, 21], 31 | [0, 1, 4, 7, 10], 32 | [0, 2, 5, 8, 11], 33 | ], 34 | } 35 | 36 | 37 | @dataclass 38 | class MatplotlibRender: 39 | jointstype: str = "smpljoints" 40 | fps: float = 20.0 41 | colors: List[str] = colors 42 | figsize: int = 4 43 | fontsize: int = 15 44 | canonicalize: bool = False 45 | 46 | def __call__( 47 | self, 48 | joints, 49 | highlights=None, 50 | title: str = "", 51 | output: str = "notebook", 52 | jointstype=None, 53 | ): 54 | jointstype = jointstype if jointstype is not None else self.jointstype 55 | render_animation( 56 | joints, 57 | 
title=title, 58 | highlights=highlights, 59 | output=output, 60 | jointstype=jointstype, 61 | fps=self.fps, 62 | colors=self.colors, 63 | figsize=(self.figsize, self.figsize), 64 | fontsize=self.fontsize, 65 | canonicalize=self.canonicalize, 66 | ) 67 | 68 | 69 | def init_axis(fig, title, radius=1.5): 70 | ax = fig.add_subplot(1, 1, 1, projection="3d") 71 | ax.view_init(elev=20.0, azim=-60) 72 | 73 | fact = 2 74 | ax.set_xlim3d([-radius / fact, radius / fact]) 75 | ax.set_ylim3d([-radius / fact, radius / fact]) 76 | ax.set_zlim3d([0, radius]) 77 | 78 | ax.set_aspect("auto") 79 | ax.set_xticklabels([]) 80 | ax.set_yticklabels([]) 81 | ax.set_zticklabels([]) 82 | 83 | ax.set_axis_off() 84 | ax.grid(b=False) 85 | 86 | ax.set_title(title, loc="center", wrap=True) 87 | return ax 88 | 89 | 90 | def plot_floor(ax, minx, maxx, miny, maxy, minz): 91 | from mpl_toolkits.mplot3d.art3d import Poly3DCollection 92 | 93 | # Plot a plane XZ 94 | verts = [ 95 | [minx, miny, minz], 96 | [minx, maxy, minz], 97 | [maxx, maxy, minz], 98 | [maxx, miny, minz], 99 | ] 100 | xz_plane = Poly3DCollection([verts], zorder=1) 101 | xz_plane.set_facecolor((0.5, 0.5, 0.5, 1)) 102 | ax.add_collection3d(xz_plane) 103 | 104 | # Plot a bigger square plane XZ 105 | radius = max((maxx - minx), (maxy - miny)) 106 | 107 | # center +- radius 108 | minx_all = (maxx + minx) / 2 - radius 109 | maxx_all = (maxx + minx) / 2 + radius 110 | 111 | miny_all = (maxy + miny) / 2 - radius 112 | maxy_all = (maxy + miny) / 2 + radius 113 | 114 | verts = [ 115 | [minx_all, miny_all, minz], 116 | [minx_all, maxy_all, minz], 117 | [maxx_all, maxy_all, minz], 118 | [maxx_all, miny_all, minz], 119 | ] 120 | xz_plane = Poly3DCollection([verts], zorder=1) 121 | xz_plane.set_facecolor((0.5, 0.5, 0.5, 0.5)) 122 | ax.add_collection3d(xz_plane) 123 | return ax 124 | 125 | 126 | def update_camera(ax, root, radius=1.5): 127 | fact = 2 128 | ax.set_xlim3d([-radius / fact + root[0], radius / fact + root[0]]) 129 | ax.set_ylim3d([-radius / fact + root[1], radius / fact + root[1]]) 130 | 131 | 132 | def render_animation( 133 | joints: np.ndarray, 134 | output: str = "notebook", 135 | highlights: Optional[np.ndarray] = None, 136 | jointstype: str = "smpljoints", 137 | title: str = "", 138 | fps: float = 20.0, 139 | colors: List[str] = colors, 140 | figsize: Tuple[int] = (4, 4), 141 | fontsize: int = 15, 142 | canonicalize: bool = False, 143 | agg=True, 144 | ): 145 | if agg: 146 | import matplotlib 147 | 148 | matplotlib.use("Agg") 149 | 150 | if highlights is not None: 151 | assert len(highlights) == len(joints) 152 | 153 | assert jointstype in KINEMATIC_TREES 154 | kinematic_tree = KINEMATIC_TREES[jointstype] 155 | 156 | import matplotlib.pyplot as plt 157 | from matplotlib.animation import FuncAnimation 158 | import matplotlib.patheffects as pe 159 | 160 | mean_fontsize = fontsize 161 | 162 | # heuristic to change fontsize 163 | fontsize = mean_fontsize - (len(title) - 30) / 20 164 | plt.rcParams.update({"font.size": fontsize}) 165 | 166 | # Z is gravity here 167 | x, y, z = 0, 1, 2 168 | 169 | joints = joints.copy() 170 | 171 | if canonicalize: 172 | joints = canonicalize_rotation(joints, jointstype=jointstype) 173 | 174 | # Create a figure and initialize 3d plot 175 | fig = plt.figure(figsize=figsize) 176 | ax = init_axis(fig, title) 177 | 178 | # Create spline line 179 | trajectory = joints[:, 0, [x, y]] 180 | avg_segment_length = ( 181 | np.mean(np.linalg.norm(np.diff(trajectory, axis=0), axis=1)) + 1e-3 182 | ) 183 | draw_offset = int(25 / 
avg_segment_length) 184 | (spline_line,) = ax.plot(*trajectory.T, zorder=10, color="white") 185 | 186 | # Create a floor 187 | minx, miny, _ = joints.min(axis=(0, 1)) 188 | maxx, maxy, _ = joints.max(axis=(0, 1)) 189 | plot_floor(ax, minx, maxx, miny, maxy, 0) 190 | 191 | # Put the character on the floor 192 | height_offset = np.min(joints[:, :, z]) # Min height 193 | joints = joints.copy() 194 | joints[:, :, z] -= height_offset 195 | 196 | # Initialization for redrawing 197 | lines = [] 198 | initialized = False 199 | 200 | def update(frame): 201 | nonlocal initialized 202 | skeleton = joints[frame] 203 | 204 | root = skeleton[0] 205 | update_camera(ax, root) 206 | 207 | hcolors = colors 208 | if highlights is not None and highlights[frame]: 209 | hcolors = ("red", "red", "red", "red", "red") 210 | 211 | for index, (chain, color) in enumerate( 212 | zip(reversed(kinematic_tree), reversed(hcolors)) 213 | ): 214 | if not initialized: 215 | lines.append( 216 | ax.plot( 217 | skeleton[chain, x], 218 | skeleton[chain, y], 219 | skeleton[chain, z], 220 | linewidth=6.0, 221 | color=color, 222 | zorder=20, 223 | path_effects=[pe.SimpleLineShadow(), pe.Normal()], 224 | ) 225 | ) 226 | 227 | else: 228 | lines[index][0].set_xdata(skeleton[chain, x]) 229 | lines[index][0].set_ydata(skeleton[chain, y]) 230 | lines[index][0].set_3d_properties(skeleton[chain, z]) 231 | lines[index][0].set_color(color) 232 | 233 | left = max(frame - draw_offset, 0) 234 | right = min(frame + draw_offset, trajectory.shape[0]) 235 | 236 | spline_line.set_xdata(trajectory[left:right, 0]) 237 | spline_line.set_ydata(trajectory[left:right, 1]) 238 | spline_line.set_3d_properties(np.zeros_like(trajectory[left:right, 0])) 239 | initialized = True 240 | 241 | fig.tight_layout() 242 | frames = joints.shape[0] 243 | anim = FuncAnimation(fig, update, frames=frames, interval=1000 / fps, repeat=False) 244 | 245 | if output == "notebook": 246 | from IPython.display import HTML 247 | 248 | HTML(anim.to_jshtml()) 249 | else: 250 | # anim.save(output, writer='ffmpeg', fps=fps) 251 | anim.save(output, fps=fps) 252 | 253 | plt.close() 254 | -------------------------------------------------------------------------------- /prepare/combine_datasets.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import logging 3 | import hydra 4 | import json 5 | import numpy as np 6 | import os 7 | from omegaconf import DictConfig 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | SUFFIX_DICT = {"humanml3d": "h", "kitml": "k", "babel": "b"} 13 | 14 | @hydra.main(config_path="../configs", config_name="combine_datasets", version_base="1.3") 15 | def combine_datasets(cfg: DictConfig): 16 | 17 | train_datasets = cfg.datasets 18 | annotations_folder_path = cfg.annotations_path 19 | 20 | combined_dataset_name = "_".join(train_datasets) 21 | combined_dataset_folder = os.path.join(annotations_folder_path, combined_dataset_name) 22 | os.makedirs(combined_dataset_folder, exist_ok=True) 23 | 24 | annotations = {} 25 | annotations_paraphrases = {} 26 | annotations_actions = {} 27 | 28 | annotations = {} 29 | annotations_paraphrases = {} 30 | annotations_actions = {} 31 | annotations_all = {} 32 | 33 | dataset_annotations = {} 34 | 35 | for dataset in train_datasets: 36 | annotations_path = os.path.join(annotations_folder_path, dataset, "annotations.json") 37 | with open(annotations_path) as f: 38 | d = json.load(f) 39 | dataset_annotations[dataset] = d 40 | d_new = {f"{key}_{SUFFIX_DICT[dataset]}": val for 
key, val in d.items()} 41 | annotations.update(d_new) 42 | 43 | annotations_paraphrases_path = os.path.join(annotations_folder_path, dataset, "annotations_paraphrases.json") 44 | if os.path.exists(annotations_paraphrases_path): 45 | with open(annotations_paraphrases_path) as f: 46 | d = json.load(f) 47 | d = {f"{key}_{SUFFIX_DICT[dataset]}": val for key, val in d.items()} 48 | annotations_paraphrases.update(d) 49 | 50 | annotations_actions_path = os.path.join(annotations_folder_path, dataset, "annotations_actions.json") 51 | if os.path.exists(annotations_actions_path): 52 | with open(annotations_actions_path) as f: 53 | d = json.load(f) 54 | d = {f"{key}_{SUFFIX_DICT[dataset]}": val for key, val in d.items()} 55 | annotations_actions.update(d) 56 | 57 | annotations_all_path = os.path.join(annotations_folder_path, dataset, "annotations_all.json") 58 | if os.path.exists(annotations_all_path): 59 | with open(annotations_all_path) as f: 60 | d = json.load(f) 61 | d = {f"{key}_{SUFFIX_DICT[dataset]}": val for key, val in d.items()} 62 | annotations_all.update(d) 63 | 64 | with open(os.path.join(combined_dataset_folder, "annotations.json"), "w") as f: 65 | json.dump(annotations, f, indent=2) 66 | with open(os.path.join(combined_dataset_folder, "annotations_paraphrases.json"), "w") as f: 67 | json.dump(annotations_paraphrases, f, indent=2) 68 | with open(os.path.join(combined_dataset_folder, "annotations_actions.json"), "w") as f: 69 | json.dump(annotations_actions, f, indent=2) 70 | with open(os.path.join(combined_dataset_folder, "annotations_all.json"), "w") as f: 71 | json.dump(annotations_all, f, indent=2) 72 | 73 | test_datasets = cfg.test_sets 74 | 75 | for dataset in test_datasets: 76 | if dataset not in dataset_annotations: 77 | annotations_path = os.path.join(annotations_folder_path, dataset, "annotations.json") 78 | with open(annotations_path) as f: 79 | d = json.load(f) 80 | dataset_annotations[dataset] = d 81 | 82 | # Splits creation 83 | 84 | logger.info(f"Creating train split by combining train splits from: {', '.join(train_datasets)}") 85 | logger.info(f"Removing from the train splits samples overlapping with test split from: {', '.join(test_datasets)}") 86 | 87 | dataset_splits = {} 88 | 89 | splits = ["train", "val"] 90 | for dataset in train_datasets: 91 | if dataset not in dataset_splits: 92 | dataset_splits[dataset] = {} 93 | for split in splits: 94 | with open(os.path.join(annotations_folder_path, dataset, "splits", f"{split}.txt")) as f: 95 | str_inds = f.read() 96 | inds = str_inds.split("\n") 97 | if inds[-1] == "": 98 | inds.pop(-1) 99 | dic_ind_path = {ind: dataset_annotations[dataset][ind]["path"] for ind in inds} 100 | dataset_splits[dataset][split] = dic_ind_path 101 | 102 | split = 'test' 103 | for dataset in test_datasets: 104 | if dataset not in dataset_splits: 105 | dataset_splits[dataset] = {} 106 | with open(os.path.join(annotations_folder_path, dataset, "splits", f"{split}.txt")) as f: 107 | str_inds = f.read() 108 | inds = str_inds.split("\n") 109 | if inds[-1] == "": 110 | inds.pop(-1) 111 | dic_ind_path = {ind: dataset_annotations[dataset][ind]["path"] for ind in inds} 112 | dataset_splits[dataset][split] = dic_ind_path 113 | 114 | 115 | to_remove = {train_dataset: {test_dataset: {"train": [], "val": []} for test_dataset in test_datasets if test_dataset != train_dataset} for train_dataset in train_datasets} 116 | 117 | for train_dataset in train_datasets: 118 | 119 | for split in ["train", "val"]: 120 | for train_id, train_path in 
dataset_splits[train_dataset][split].items(): 121 | 122 | for test_dataset in set(test_datasets) - set([train_dataset]): 123 | 124 | for test_id, test_path in dataset_splits[test_dataset]["test"].items(): 125 | if train_path == test_path: 126 | 127 | if not cfg.filter_babel_seg: 128 | if test_dataset == "babel": 129 | test_duration = float(dataset_annotations[test_dataset][test_id]["duration"]) 130 | test_fragment_duration = float(dataset_annotations[test_dataset][test_id]["fragment_duration"]) 131 | 132 | if not np.isclose([test_duration], [test_fragment_duration], atol=0.1, rtol=0): 133 | continue 134 | 135 | if train_dataset == "babel": 136 | train_duration = float(dataset_annotations[train_dataset][train_id]["duration"]) 137 | train_fragment_duration = float(dataset_annotations[train_dataset][train_id]["fragment_duration"]) 138 | 139 | if not np.isclose([train_duration], [train_fragment_duration], atol=0.1, rtol=0): 140 | continue 141 | 142 | train_start = float(dataset_annotations[train_dataset][train_id]["annotations"][0]["start"]) 143 | train_end = float(dataset_annotations[train_dataset][train_id]["annotations"][0]["end"]) 144 | test_start = float(dataset_annotations[test_dataset][test_id]["annotations"][0]["start"]) 145 | test_end = float(dataset_annotations[test_dataset][test_id]["annotations"][0]["end"]) 146 | 147 | if not ((train_end <= test_start) or (test_end <= train_start)): 148 | to_remove[train_dataset][test_dataset][split].append(train_id) 149 | 150 | datasets_curated = {train_dataset: {split: list(dataset_splits[train_dataset][split].keys()) for split in dataset_splits[train_dataset].keys()} for train_dataset in train_datasets} 151 | 152 | for train_dataset in train_datasets: 153 | for test_dataset in set(test_datasets) - set([train_dataset]): 154 | for split in ["train", "val"]: 155 | for keyid in to_remove[train_dataset][test_dataset][split]: 156 | if keyid in datasets_curated[train_dataset][split]: 157 | datasets_curated[train_dataset][split].remove(keyid) 158 | 159 | splits_folder = os.path.join(annotations_folder_path, combined_dataset_name, "splits") 160 | os.makedirs(splits_folder, exist_ok=True) 161 | all_ids = [] 162 | for split in ["train", "val"]: 163 | ids = [] 164 | for train_dataset in train_datasets: 165 | ids += [f'{elt}_{SUFFIX_DICT[train_dataset]}' for elt in datasets_curated[train_dataset][split]] 166 | all_ids += ids 167 | ids_str = "\n".join(ids) 168 | filename = f"{split}{cfg.split_suffix}.txt" 169 | with open(os.path.join(splits_folder, filename), "w") as f: 170 | f.write(ids_str) 171 | 172 | all_ids_str = "\n".join(all_ids) 173 | with open(os.path.join(splits_folder, f"all{cfg.split_suffix}.txt"), "w") as f: 174 | f.write(all_ids_str) 175 | 176 | 177 | if __name__ == "__main__": 178 | combine_datasets() 179 | 180 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |