├── .gitignore ├── DATASETS.md ├── LICENSE.md ├── README.md ├── app.py ├── configs ├── combine_datasets.yaml ├── compute_guoh3dfeats.yaml ├── data │ ├── _base.yaml │ ├── _base_augmented.yaml │ ├── babel.yaml │ ├── babel_actions_120.yaml │ ├── babel_actions_60.yaml │ ├── humanml3d.yaml │ ├── humanml3d_kitml.yaml │ ├── humanml3d_kitml_babel.yaml │ ├── kitml.yaml │ └── motion_loader │ │ ├── _base.yaml │ │ └── guoh3dfeats.yaml ├── debug │ ├── profiler.yaml │ └── train.yaml ├── defaults.yaml ├── encode_dataset.yaml ├── encode_motion.yaml ├── encode_text.yaml ├── extract.yaml ├── hydra │ ├── hydra_logging │ │ └── tqdm.yaml │ └── job_logging │ │ └── tqdm.yaml ├── load_model.yaml ├── model │ ├── temos.yaml │ ├── tmr.yaml │ ├── tmr_hn.yaml │ ├── tmr_text_averaging.yaml │ └── tmr_text_averaging_hn.yaml ├── motion_stats.yaml ├── render.yaml ├── renderer │ └── matplotlib.yaml ├── retrieval.yaml ├── retrieval_action_multi_labels.yaml ├── text_dataset_sim.yaml ├── text_embeddings.yaml ├── text_embeddings_with_augmentation.yaml ├── text_motion_sim.yaml ├── train.yaml ├── train_hn.yaml ├── train_hn_with_augmentation.yaml ├── train_with_augmentation.yaml └── trainer.yaml ├── datasets └── annotations │ ├── humanml3d │ ├── annotations.json │ └── splits │ │ ├── all.txt │ │ ├── nsim_test.txt │ │ ├── test.txt │ │ ├── test_tiny.txt │ │ ├── train.txt │ │ ├── train_tiny.txt │ │ ├── val.txt │ │ └── val_tiny.txt │ └── kitml │ ├── annotations.json │ └── splits │ ├── all.txt │ ├── nsim_test.txt │ ├── test.txt │ ├── test_tiny.txt │ ├── train.txt │ ├── train_tiny.txt │ ├── val.txt │ └── val_tiny.txt ├── demo ├── amass_to_babel.json ├── load.py └── model.py ├── encode_dataset.py ├── encode_motion.py ├── encode_text.py ├── extract.py ├── prepare ├── combine_datasets.py ├── compute_guoh3dfeats.py ├── download_pretrain_models.sh ├── motion_stats.py ├── text_embeddings.py └── tools.py ├── requirements.txt ├── retrieval.py ├── retrieval_action.py ├── retrieval_action_multi_labels.py ├── src ├── callback │ ├── __init__.py │ ├── progress.py │ └── tqdmbar.py ├── config.py ├── data │ ├── augmented_text_motion.py │ ├── collate.py │ ├── motion.py │ ├── text.py │ ├── text_motion.py │ └── text_motion_multi_labels.py ├── geometry.py ├── guofeats │ ├── __init__.py │ ├── common │ │ ├── quaternion.py │ │ └── skeleton.py │ ├── motion_representation.py │ ├── paramUtil.py │ └── skeleton_example_h3d.npy ├── joints.py ├── load.py ├── logger │ ├── csv.py │ └── csv_fabric.py ├── logging.py ├── model │ ├── __init__.py │ ├── actor.py │ ├── losses.py │ ├── metrics.py │ ├── temos.py │ ├── text_encoder.py │ ├── tmr.py │ └── tmr_text_averaging.py ├── prepare.py ├── renderer │ └── matplotlib.py └── rifke.py ├── stats ├── babel │ └── guoh3dfeats │ │ ├── mean.pt │ │ └── std.pt ├── babel_actions_120 │ └── guoh3dfeats │ │ ├── mean.pt │ │ └── std.pt ├── babel_actions_60 │ └── guoh3dfeats │ │ ├── mean.pt │ │ └── std.pt ├── humanml3d │ └── guoh3dfeats │ │ ├── mean.pt │ │ └── std.pt ├── humanml3d_kitml │ └── guoh3dfeats │ │ ├── mean.pt │ │ └── std.pt ├── humanml3d_kitml_babel │ └── guoh3dfeats │ │ ├── mean.pt │ │ └── std.pt └── kitml │ └── guoh3dfeats │ ├── mean.pt │ └── std.pt ├── text_motion_sim.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | 
lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. 
For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | -------------------------------------------------------------------------------- /DATASETS.md: -------------------------------------------------------------------------------- 1 | ## Note on datasets 2 | 3 | Currently, three datasets are widely used for 3D text-to-motion: [KIT-ML](https://motion-annotation.humanoids.kit.edu/dataset/), [HumanML3D](https://github.com/EricGuo5513/HumanML3D) and [BABEL](https://babel.is.tue.mpg.de). 4 | 5 | ### Unifying the datasets 6 | 7 | As explained on its website, the [AMASS](https://amass.is.tue.mpg.de) dataset is a large database of human motion unifying different optical marker-based motion capture datasets by representing them within a common framework and parameterization. 8 | 9 | Except for a part of HumanML3D which is based on [HumanAct12](https://ericguo5513.github.io/action-to-motion/) (which is itself based on [PhSPD](https://drive.google.com/drive/folders/1ZGkpiI99J-4ygD9i3ytJdmyk_hkejKCd?usp=sharing)), almost all the motion data of KIT-ML, HumanML3D and BABEL is included in AMASS. 10 | 11 | Currently, the text-to-motion datasets are not compatible in terms of motion representation: 12 | - KIT-ML uses the [Master Motor Map](https://mmm.humanoids.kit.edu) (robot-like joints) 13 | - HumanML3D takes motions from AMASS, extracts joints using the SMPL layer, rotates the joints (making Y the gravity axis), crops the motions, retargets every skeleton to a common reference, and computes motion features. 14 | - BABEL uses raw SMPL parameters from AMASS 15 | 16 | To be able to use any text-to-motion dataset with the same representation, I propose in [this repo](https://github.com/Mathux/AMASS-Annotation-Unifier) to unify the datasets under a common annotation format. With the agreement of the authors, I included the annotation files in the TMR repo, in this folder: [datasets/annotations](datasets/annotations) (for BABEL, please follow the instructions). For each dataset, I provide a .json file with: 17 | - The ID of the motion (as found in the original dataset) 18 | - The path of the motion in AMASS (or HumanAct12) 19 | - The duration in seconds 20 | - A list of annotations, each of which contains: 21 | - An ID 22 | - The corresponding text 23 | - The start and end in seconds 24 | 25 | Like this one: 26 | 27 | ```json 28 | { 29 | "000000": { 30 | "path": "KIT/3/kick_high_left02_poses", 31 | "duration": 5.82, 32 | "annotations": [ 33 | { 34 | "seg_id": "000000_0", 35 | "text": "a man kicks something or someone with his left leg.", 36 | "start": 0.0, 37 | "end": 5.82 38 | }, 39 | ... 40 | ``` 41 | 42 | We are now free to use any motion representation. 43 | 44 | 45 | ### Motion representation 46 | 47 | Guo et al. use a motion representation which includes rotation-invariant forward-kinematics features, 3D rotations, velocities and foot contacts. Currently, many works in 3D motion generation use these features. However, these features are not the same for HumanML3D and KIT-ML (different number of joints, different scale, different reference skeleton, etc.). 48 | 49 | To let people use TMR as an evaluator, and to stay comparable with the Guo et al. feature extractor, I propose to process the whole AMASS (+HumanAct12) dataset into the HumanML3D Guo features (which I refer to as ``guoh3dfeats`` in the code). Then, we can crop each feature file according to any dataset. I also included the mirrored version of each motion.
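
For illustration, here is a minimal sketch (not part of the repo) of how an ``annotations.json`` entry can be used to crop a precomputed ``guoh3dfeats`` file, assuming the features are stored under ``datasets/motions/guoh3dfeats/{path}.npy`` and sampled at 20 fps as in ``configs/data/motion_loader/guoh3dfeats.yaml``:

```python
import json
import numpy as np

FPS = 20.0  # framerate of the guoh3dfeats features

with open("datasets/annotations/humanml3d/annotations.json") as f:
    annotations = json.load(f)

entry = annotations["000000"]          # one motion
segment = entry["annotations"][0]      # one text segment of that motion

# Load the precomputed features of the whole AMASS sequence
feats = np.load(f"datasets/motions/guoh3dfeats/{entry['path']}.npy")

# Crop the features to the annotated segment (start/end are in seconds)
start, end = int(segment["start"] * FPS), int(segment["end"] * FPS)
cropped = feats[start:end]  # shape: (num_frames, 263)
print(segment["text"], cropped.shape)
```

In the actual code, this loading and cropping is handled by the dataset and motion-loader classes in ``src/data``.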
50 | 51 | ### Differences with the released version of HumanML3D 52 | For motions shorter than 10s, this process produces exactly the same feature files as HumanML3D (for example "000000.npy"). 53 | As a sanity check, you can verify in Python that both .npy files correspond to the same data: 54 | 55 | ```python 56 | import numpy as np 57 | new = np.load("datasets/motions/guoh3dfeats/humanact12/humanact12/P11G01R02F1812T1847A0402.npy") 58 | old = np.load("/path/to/HumanML3D/HumanML3D/new_joint_vecs/000001.npy") 59 | assert np.abs(new - old).mean() < 1e-10 60 | ``` 61 | 62 | For motions longer than 10s which are cropped (like "000004.npy"), cropping the features gives slightly different results than computing the features of the cropped motion. That is because the ``uniform skeleton`` function takes the first frame as the reference to compute bone lengths. However, the difference is quite small. 63 | 64 | ### Installation 65 | Go to the section "Installation - Set up the datasets" of the [README.md](README.md) to compute the features. 66 | 67 | 68 | ## Credits 69 | For all the datasets, be sure to read and follow their license agreements, and cite them accordingly. 70 | 71 | ### KIT-ML 72 | ```bibtex 73 | @article{Plappert2016, 74 | author = {Matthias Plappert and Christian Mandery and Tamim Asfour}, 75 | title = {The {KIT} Motion-Language Dataset}, 76 | journal = {Big Data}, 77 | year = 2016 78 | } 79 | ``` 80 | 81 | ### HumanML3D 82 | ```bibtex 83 | @inproceedings{Guo_2022_CVPR, 84 | author = {Guo, Chuan and Zou, Shihao and Zuo, Xinxin and Wang, Sen and Ji, Wei and Li, Xingyu and Cheng, Li}, 85 | title = {Generating Diverse and Natural 3D Human Motions From Text}, 86 | booktitle = {Computer Vision and Pattern Recognition (CVPR)}, 87 | year = 2022 88 | } 89 | ``` 90 | 91 | ### BABEL 92 | ```bibtex 93 | @inproceedings{BABEL:CVPR:2021, 94 | title = {{BABEL}: Bodies, Action and Behavior with English Labels}, 95 | author = {Punnakkal, Abhinanda R. and Chandrasekaran, Arjun and Athanasiou, Nikos and Quiros-Ramirez, Alejandra and Black, Michael J.}, 96 | booktitle = {Computer Vision and Pattern Recognition (CVPR)}, 97 | year = 2021 98 | } 99 | ``` 100 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Léore Bensabath 4 | 5 | TMR LICENCE 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | # TMR++ 4 | ## A Cross-Dataset Study for Text-based 3D Human Motion Retrieval 5 | 6 | Léore Bensabath 7 | · 8 | Mathis Petrovich 9 | · 10 | Gül Varol 11 | 12 | 13 | [![arXiv](https://img.shields.io/badge/arXiv-TMR-A10717.svg?logo=arXiv)](https://arxiv.org/abs/2405.16909) 14 | 15 |
16 | 17 | 18 | ## Description 19 | Official PyTorch implementation of the paper: 20 |
21 | 22 | [**A Cross-Dataset Study for Text-based 3D Human Motion Retrieval**](https://arxiv.org/abs/2405.16909) 23 | 24 |
25 | 26 | This repo is based on the implementation of 27 | [**TMR: Text-to-Motion Retrieval Using Contrastive 3D Human Motion Synthesis**](https://github.com/Mathux/TMR/tree/master). 28 | 29 | Please visit our [**webpage**](https://imagine.enpc.fr/~leore.bensabath/TMR++) for more details. 30 | 31 | ### Bibtex 32 | If you find this code useful in your research, please cite: 33 | 34 | ```bibtex 35 | @inproceedings{lbensabath2024, 36 | title={TMR++: A Cross-Dataset Study for Text-based 3D Human Motion Retrieval}, 37 | author={Bensabath, Léore and Petrovich, Mathis and Varol, G{\"u}l}, 38 | journal={CVPRW HuMoGen}, 39 | year={2024} 40 | } 41 | ``` 42 | and 43 | ```bibtex 44 | @inproceedings{petrovich23tmr, 45 | title = {{TMR}: Text-to-Motion Retrieval Using Contrastive {3D} Human Motion Synthesis}, 46 | author = {Petrovich, Mathis and Black, Michael J. and Varol, G{\"u}l}, 47 | booktitle = {International Conference on Computer Vision ({ICCV})}, 48 | year = 2023 49 | } 50 | ``` 51 | 52 | You can also put a star :star:, if the code is useful to you. 53 | 54 | ## Installation :construction_worker: 55 | 56 |
Create environment 57 | 58 | 59 | Create a Python virtual environment: 60 | ```bash 61 | python -m venv ~/.venv/TMR 62 | source ~/.venv/TMR/bin/activate 63 | ``` 64 | 65 | Install [PyTorch](https://pytorch.org/get-started/locally/): 66 | ```bash 67 | python -m pip install torch torchvision --index-url https://download.pytorch.org/whl/cu118 68 | ``` 69 | 70 | Then install the remaining packages: 71 | ```bash 72 | python -m pip install -r requirements.txt 73 | ``` 74 | 75 | which corresponds to the packages: pytorch_lightning, einops, hydra-core, hydra-colorlog, orjson, tqdm and scipy. 76 | The code was tested on Python 3.10.12 and PyTorch 2.0.1. 77 | 78 |
79 | 80 |
Set up the datasets 81 | 82 | 83 | Please first set up the datasets as explained in https://github.com/Mathux/TMR/tree/master (same README section). 84 | 85 | In this repo, we provide augmented versions of the humanml3d, kitml and babel datasets. 86 | For a given dataset ($DATASET), up to 3 new annotation files have been created: 87 | - ``datasets/annotations/$DATASET/annotations_paraphrases.json``: includes all the paraphrases generated by an LLM 88 | - ``datasets/annotations/$DATASET/annotations_actions.json``: for humanml3d and kitml only, includes the action-type labels generated by an LLM 89 | - ``datasets/annotations/$DATASET/annotations_all.json``: includes a concatenation, by key ID, of all the annotations (original and LLM-generated) 90 | 91 | Copy the data into your repo from [here](https://drive.google.com/drive/u/1/folders/1_SpOgtYCZBPAXoVz00Zhyk6tPRObUIiW). 92 | 93 | ### Compute the text embeddings for the data with text augmentation 94 | 95 | Run this command to compute the sentence embeddings and token embeddings for the annotations with text augmentation: 96 | 97 | ```bash 98 | python -m prepare.text_embeddings --config-name=text_embeddings_with_augmentation data=$DATASET 99 | ``` 100 |
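
If you want to inspect what the augmentation adds, here is a minimal sketch (not part of the repo), assuming the augmented files follow the same structure as ``annotations.json`` (a dictionary keyed by motion ID with an ``annotations`` list):

```python
import json

dataset = "humanml3d"
with open(f"datasets/annotations/{dataset}/annotations.json") as f:
    original = json.load(f)
with open(f"datasets/annotations/{dataset}/annotations_paraphrases.json") as f:
    paraphrased = json.load(f)

# Compare the original captions with the LLM paraphrases for one motion
keyid = next(iter(original))
print("original:   ", [ann["text"] for ann in original[keyid]["annotations"]])
print("paraphrased:", [ann["text"] for ann in paraphrased[keyid]["annotations"]])
```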
101 | 102 |
Combine datasets 103 | 104 | 105 | To create a combination of any of the datasets, run: 106 | 107 | ```bash 108 | python -m prepare.combine_datasets datasets=$DATASETS test_sets=$TEST_DATASETS split_suffix=$SPLIT_SUFFIX [OPTIONS] 109 | ``` 110 | Where: 111 | - ``datasets``: the list of datasets to combine 112 | - ``test_sets``: the list of datasets on which the combination is intended to be tested. When generating the split files, this filters out of the training set any sample from one of the training datasets that overlaps with a sample from one of the provided test datasets (see the sketch at the end of this section). 113 | Note that you can create different splits for different intended test sets by leveraging the **split_suffix** parameter. The annotations file for a given combination stays the same regardless of the **test_sets** value. 114 | - ``split_suffix``: the split-file suffix for this given combination of test sets. Training and validation split files will be saved under ``datasets/annotations/splits/train{split_suffix}.txt`` and ``datasets/annotations/splits/val{split_suffix}.txt`` 115 | 116 | The new dataset will be created inside the folder ``datasets/annotations/{dataset1}_{dataset2}(_{dataset3})`` 117 | 118 | **Example:** 119 | ```bash 120 | python -m prepare.combine_datasets datasets=["humanml3d","kitml"] test_sets=["babel"] split_suffix="_wo_hkb" 121 | ``` 122 | 123 | Then run the ``python -m prepare.text_embeddings`` command, with or without text augmentation, on your new dataset combination. 124 | 125 | **Example:** 126 | ```bash 127 | python -m prepare.text_embeddings --config-name=text_embeddings_with_augmentation data=humanml3d_kitml 128 | ```
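
To make the ``test_sets`` filtering more concrete, here is an illustrative sketch of the idea only (the actual logic lives in ``prepare/combine_datasets.py`` and may differ in details): a training sample is dropped if its underlying AMASS sequence also appears in the test split of one of the intended test datasets.

```python
import json

def load_annotations(dataset):
    with open(f"datasets/annotations/{dataset}/annotations.json") as f:
        return json.load(f)

def load_split(dataset, split):
    with open(f"datasets/annotations/{dataset}/splits/{split}.txt") as f:
        return [line.strip() for line in f if line.strip()]

# AMASS sequences used by the test set we want to stay disjoint from
babel = load_annotations("babel")
test_paths = {babel[k]["path"] for k in load_split("babel", "test") if k in babel}

# Keep only the humanml3d training samples that do not reuse one of those sequences
h3d = load_annotations("humanml3d")
train_keep = [k for k in load_split("humanml3d", "train")
              if k in h3d and h3d[k]["path"] not in test_paths]
print(f"kept {len(train_keep)} training samples")
```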
130 | 131 | ## Training :rocket: 132 | 133 | ### Training with a combination of datasets 134 | 135 | To train with a combination of datasets without any text augmentation, run the same command as in TMR with the relevant dataset name: 136 | 137 | **Example:** 138 | ```bash 139 | python train.py data=humanml3d_kitml 140 | ``` 141 | 142 | ### Training with text augmentation 143 | 144 | ```bash 145 | python train.py --config-name=train_with_augmentation data=$DATASET 146 | ``` 147 | 148 |
Details 149 | Relevant parameters you can modify, in addition to the ones in TMR, are the text-augmentation sampling probabilities detailed in the paper (see the sketch below for what they control): 150 | **Example:** 151 | ```bash 152 | python train.py --config-name=train_with_augmentation data=humanml3d data.paraphrase_prob=0.2 data.summary_prob=0.2 data.averaging_prob=0.3 run_dir=outputs/tmr_humanml3d_w_textAugmentation_0.2_0.2_0.3 153 | ```
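
As an illustration only of what these three probabilities control (the actual sampling logic lives in ``src/data/augmented_text_motion.py`` and may differ in details), one training sample could pick its text source roughly like this:

```python
import random

def pick_text_source(paraphrase_prob=0.2, summary_prob=0.2, averaging_prob=0.3):
    # Choose which text to pair with the motion for this training sample
    u = random.random()
    if u < paraphrase_prob:
        source = "paraphrase"   # LLM paraphrase of the original caption
    elif u < paraphrase_prob + summary_prob:
        source = "summary"      # LLM action label / summary
    else:
        source = "original"     # original dataset caption
    # Decide whether to average the embeddings of several texts instead of one
    average_texts = random.random() < averaging_prob
    return source, average_texts
```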
155 | 156 |
Extracting weights 157 | After training, run the following command to extract the weights from the checkpoint: 158 | 159 | ```bash 160 | python extract.py run_dir=$RUN_DIR 161 | ``` 162 | 163 | It will take the last checkpoint by default. This should create the folder ``RUN_DIR/last_weights`` and populate it with the files ``motion_decoder.pt``, ``motion_encoder.pt`` and ``text_encoder.pt``. 164 | This process makes loading models faster, no longer depends on the file structure, and lets each module be loaded independently. This is already done for the pretrained models. 165 |
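
For example, here is a minimal sketch of loading one of the extracted modules on its own, assuming each ``.pt`` file stores a state dict matching the hyper-parameters in ``configs/model/tmr.yaml`` (in practice, the repo's own loader in ``src/load.py`` handles this):

```python
import torch
from src.model import ACTORStyleEncoder

run_dir = "outputs/tmr_humanml3d_guoh3dfeats"  # hypothetical run directory

# Text-encoder hyper-parameters taken from configs/model/temos.yaml
text_encoder = ACTORStyleEncoder(
    nfeats=768, vae=True, latent_dim=256, ff_size=1024,
    num_layers=6, num_heads=4, dropout=0.1, activation="gelu",
)
state_dict = torch.load(f"{run_dir}/last_weights/text_encoder.pt", map_location="cpu")
text_encoder.load_state_dict(state_dict)
text_encoder.eval()
```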
167 | 168 | 169 | ## Pretrained models :dvd: 170 | 171 | You can find the different models used in the paper here: 172 | [pre-trained models](https://drive.google.com/drive/u/1/folders/1otB-B4m4okpD_0crGMcpg0hOsSRYH45t) 173 | 174 | 175 | ## Evaluation :bar_chart: 176 | 177 | ### Motion to text / Text to motion retrieval 178 | 179 | ```bash 180 | python retrieval.py run_dir=$RUN_DIR data=$DATA 181 | ``` 182 | 183 | ### Action recognition 184 | 185 | For action recognition on the datasets babel_actions_60 and babel_actions_120, run: 186 | 187 | ```bash 188 | python retrieval_action_multi_labels.py run_dir=$RUN_DIR data=$DATA 189 | ``` 190 | 191 | 192 | It will compute the metrics, display them and save them in the folder ``RUN_DIR/contrastive_metrics_$DATA/``. 193 | You can change the name of the output file using the argument ``save_file_name``. 194 | 195 | 196 | ## Usage :computer: 197 | 198 | ### Encode a motion 199 | Note that the .npy file should correspond to HumanML3D Guo features. 200 | 201 | ```bash 202 | python encode_motion.py run_dir=RUN_DIR npy=/path/to/motion.npy 203 | ``` 204 | 205 | ### Encode a text 206 | 207 | ```bash 208 | python encode_text.py run_dir=RUN_DIR text="A person is walking forward." 209 | ``` 210 | 211 | ### Compute similarity between text and motion 212 | ```bash 213 | python text_motion_sim.py run_dir=RUN_DIR text=TEXT npy=/path/to/motion.npy 214 | ``` 215 | For example, with ``text="a man sets to do a backflips then fails back flip and falls to the ground"`` and ``npy=HumanML3D/HumanML3D/new_joint_vecs/001034.npy``, you should get a similarity of around 0.96. 216 | 217 | 218 | ## Launch the demo 219 | 220 | ### Encode the whole motion dataset 221 | ```bash 222 | python encode_dataset.py run_dir=RUN_DIR 223 | ``` 224 | 225 | 226 | ### Text-to-motion retrieval demo 227 | Run this command: 228 | 229 | ```bash 230 | python app.py 231 | ``` 232 | 233 | and then open your web browser at the address ``http://localhost:7860``. 234 | 235 | ## License :books: 236 | This code is distributed under an [MIT LICENSE](LICENSE.md). 237 | 238 | Note that our code depends on other libraries, including PyTorch, PyTorch3D, Hugging Face, Hydra, and uses datasets which each have their own respective licenses that must also be followed.
239 | -------------------------------------------------------------------------------- /configs/combine_datasets.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - defaults 3 | - _self_ 4 | 5 | annotations_path: datasets/annotations 6 | 7 | datasets: 8 | - humanml3d 9 | - kitml 10 | 11 | test_sets: 12 | - humanml3d 13 | - kitml 14 | 15 | filter_babel_seg: False 16 | 17 | split_suffix: '' 18 | 19 | min_duration: null 20 | max_duration: null 21 | -------------------------------------------------------------------------------- /configs/compute_guoh3dfeats.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - defaults 3 | - _self_ 4 | 5 | base_folder: datasets/motions/pose_data 6 | output_folder: datasets/motions/guoh3dfeats 7 | 8 | force_redo: false # true to recompute the features 9 | -------------------------------------------------------------------------------- /configs/data/_base.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - motion_loader: guoh3dfeats 3 | - _self_ 4 | 5 | _target_: src.data.text_motion.TextMotionDataset 6 | 7 | path: datasets/annotations/${hydra:runtime.choices.data} 8 | 9 | text_to_token_emb: 10 | _target_: src.data.text.TokenEmbeddings 11 | path: datasets/annotations/${hydra:runtime.choices.data} 12 | modelname: distilbert-base-uncased 13 | modelpath: null 14 | preload: true 15 | 16 | text_to_sent_emb: 17 | _target_: src.data.text.SentenceEmbeddings 18 | path: datasets/annotations/${hydra:runtime.choices.data} 19 | modelname: sentence-transformers/all-mpnet-base-v2 20 | modelpath: null 21 | preload: true 22 | 23 | preload: true 24 | -------------------------------------------------------------------------------- /configs/data/_base_augmented.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - motion_loader: guoh3dfeats 3 | - _base 4 | - _self_ 5 | 6 | _target_: src.data.augmented_text_motion.AugmentedTextMotionDataset 7 | 8 | paraphrase_filename: annotations_paraphrased.json 9 | summary_filename: annotations_summarized.json 10 | paraphrase_prob: 0.2 11 | summary_prob: 0.2 12 | averaging_prob: 0.4 13 | text_sampling_nbr: null 14 | 15 | text_to_token_emb: 16 | name: token_embeddings_all # TODO 17 | 18 | text_to_sent_emb: 19 | name: sent_embeddings_all 20 | -------------------------------------------------------------------------------- /configs/data/babel.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _base 3 | - _self_ 4 | -------------------------------------------------------------------------------- /configs/data/babel_actions_120.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _base 3 | - _self_ 4 | -------------------------------------------------------------------------------- /configs/data/babel_actions_60.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _base 3 | - _self_ 4 | -------------------------------------------------------------------------------- /configs/data/humanml3d.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _base 3 | - _self_ 4 | -------------------------------------------------------------------------------- /configs/data/humanml3d_kitml.yaml: 
-------------------------------------------------------------------------------- 1 | defaults: 2 | - _base 3 | - _self_ 4 | -------------------------------------------------------------------------------- /configs/data/humanml3d_kitml_babel.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _base 3 | - _self_ 4 | -------------------------------------------------------------------------------- /configs/data/kitml.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _base 3 | - _self_ 4 | -------------------------------------------------------------------------------- /configs/data/motion_loader/_base.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.data.motion.AMASSMotionLoader 2 | 3 | base_dir: ??? 4 | 5 | normalizer: 6 | _target_: src.data.motion.Normalizer 7 | base_dir: stats/${hydra:runtime.choices.data}/${hydra:runtime.choices.data/motion_loader} 8 | eps: 1e-12 9 | -------------------------------------------------------------------------------- /configs/data/motion_loader/guoh3dfeats.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _base 3 | - _self_ 4 | 5 | base_dir: datasets/motions/guoh3dfeats 6 | fps: 20.0 7 | nfeats: 263 8 | -------------------------------------------------------------------------------- /configs/debug/profiler.yaml: -------------------------------------------------------------------------------- 1 | # @package __global__ 2 | 3 | run_dir: debug/ 4 | 5 | trainer: 6 | max_epochs: 1 7 | check_val_every_n_epoch: 1 8 | callbacks: null 9 | profiler: simple 10 | -------------------------------------------------------------------------------- /configs/debug/train.yaml: -------------------------------------------------------------------------------- 1 | # @package __global__ 2 | 3 | run_dir: debug/ 4 | 5 | data: 6 | tiny: true 7 | preload: false 8 | 9 | dataloader: 10 | num_workers: 0 11 | shuffle: false 12 | 13 | trainer: 14 | enable_model_summary: false 15 | max_epochs: 3 16 | check_val_every_n_epoch: 1 17 | -------------------------------------------------------------------------------- /configs/defaults.yaml: -------------------------------------------------------------------------------- 1 | # @package __global__ 2 | 3 | run_dir: logs 4 | 5 | hydra: 6 | run: 7 | dir: ${run_dir} 8 | 9 | seed: 1234 10 | logger_level: INFO 11 | 12 | 13 | defaults: 14 | - _self_ 15 | - override hydra/job_logging: tqdm 16 | - override hydra/hydra_logging: tqdm 17 | -------------------------------------------------------------------------------- /configs/encode_dataset.yaml: -------------------------------------------------------------------------------- 1 | dataloader: 2 | _target_: torch.utils.data.DataLoader 3 | batch_size: 32 4 | num_workers: 8 5 | shuffle: true 6 | 7 | defaults: 8 | - data: humanml3d 9 | - defaults 10 | - _self_ 11 | 12 | run_dir: ??? 13 | 14 | ckpt_name: last 15 | device: cuda 16 | -------------------------------------------------------------------------------- /configs/encode_motion.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - defaults 3 | - _self_ 4 | 5 | run_dir: ??? 6 | npy: ??? 
7 | 8 | start: !!null 9 | end: !!null 10 | 11 | ckpt_name: last 12 | device: cuda 13 | -------------------------------------------------------------------------------- /configs/encode_text.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - defaults 3 | - _self_ 4 | 5 | run_dir: ??? 6 | text: ??? 7 | 8 | ckpt_name: last 9 | device: cuda 10 | -------------------------------------------------------------------------------- /configs/extract.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - defaults 3 | - _self_ 4 | 5 | run_dir: ??? 6 | ckpt: last 7 | -------------------------------------------------------------------------------- /configs/hydra/hydra_logging/tqdm.yaml: -------------------------------------------------------------------------------- 1 | version: 1 2 | 3 | formatters: 4 | verysimple: 5 | format: '%(message)s' 6 | 7 | handlers: 8 | console: 9 | class: src.logging.TqdmLoggingHandler 10 | formatter: verysimple 11 | 12 | root: 13 | level: ${logger_level} 14 | handlers: [console] 15 | 16 | 17 | disable_existing_loggers: false 18 | -------------------------------------------------------------------------------- /configs/hydra/job_logging/tqdm.yaml: -------------------------------------------------------------------------------- 1 | version: 1 2 | 3 | formatters: 4 | simple: 5 | format: '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s' 6 | datefmt: '%d/%m/%y %H:%M:%S' 7 | 8 | colorlog: 9 | (): colorlog.ColoredFormatter 10 | format: '[%(white)s%(asctime)s%(reset)s] %(log_color)s%(levelname)s%(reset)s %(message)s' 11 | datefmt: '%d/%m/%y %H:%M:%S' 12 | 13 | log_colors: 14 | DEBUG: purple 15 | INFO: blue 16 | WARNING: yellow 17 | ERROR: red 18 | CRITICAL: red 19 | 20 | handlers: 21 | console: 22 | class: src.logging.TqdmLoggingHandler 23 | formatter: colorlog 24 | file_out: 25 | class: logging.FileHandler 26 | formatter: simple 27 | filename: ${run_dir}/${hydra.job.name}.out 28 | 29 | root: 30 | level: ${logger_level} 31 | handlers: [console, file_out] 32 | 33 | disable_existing_loggers: false 34 | -------------------------------------------------------------------------------- /configs/load_model.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - defaults 3 | - _self_ 4 | 5 | run_dir: ??? 
6 | ckpt: last 7 | device: cuda 8 | eval_mode: true 9 | -------------------------------------------------------------------------------- /configs/model/temos.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.model.TEMOS 2 | 3 | motion_encoder: 4 | _target_: src.model.ACTORStyleEncoder 5 | nfeats: ${data.motion_loader.nfeats} 6 | vae: true 7 | latent_dim: 256 8 | ff_size: 1024 9 | num_layers: 6 10 | num_heads: 4 11 | dropout: 0.1 12 | activation: gelu 13 | 14 | text_encoder: 15 | _target_: src.model.ACTORStyleEncoder 16 | nfeats: 768 17 | vae: true 18 | latent_dim: 256 19 | ff_size: 1024 20 | num_layers: 6 21 | num_heads: 4 22 | dropout: 0.1 23 | activation: gelu 24 | 25 | motion_decoder: 26 | _target_: src.model.ACTORStyleDecoder 27 | nfeats: ${data.motion_loader.nfeats} 28 | latent_dim: 256 29 | ff_size: 1024 30 | num_layers: 6 31 | num_heads: 4 32 | dropout: 0.1 33 | activation: gelu 34 | 35 | vae: true 36 | 37 | lmd: 38 | recons: 1.0 39 | latent: 1.0e-5 40 | kl: 1.0e-5 41 | 42 | lr: 1e-4 43 | -------------------------------------------------------------------------------- /configs/model/tmr.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - temos 3 | - _self_ 4 | 5 | _target_: src.model.TMR 6 | 7 | lmd: 8 | recons: 1.0 9 | latent: 1.0e-5 10 | kl: 1.0e-5 11 | contrastive: 0.1 12 | 13 | lr: 1e-4 14 | threshold_selfsim_metrics: 0.95 15 | 16 | contrastive_loss: 17 | _target_: src.model.losses.InfoNCE_with_filtering 18 | temperature: 0.1 19 | threshold_selfsim: 0.80 20 | -------------------------------------------------------------------------------- /configs/model/tmr_hn.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - tmr 3 | - _self_ 4 | 5 | contrastive_loss: 6 | _target_: src.model.losses.HN_InfoNCE_with_filtering 7 | temperature: 0.1 8 | threshold_selfsim: 0.80 9 | alpha: 0.999 10 | beta: 0.5 11 | -------------------------------------------------------------------------------- /configs/model/tmr_text_averaging.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - tmr 3 | - _self_ 4 | 5 | _target_: src.model.tmr_text_averaging.TMRTextAveraging 6 | -------------------------------------------------------------------------------- /configs/model/tmr_text_averaging_hn.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - tmr_hn 3 | - _self_ 4 | 5 | _target_: src.model.tmr_text_averaging.TMRTextAveraging 6 | 7 | contrastive_loss: 8 | _target_: src.model.losses.HN_InfoNCE_with_filtering 9 | temperature: 0.1 10 | threshold_selfsim: 0.80 11 | alpha: 0.999 12 | beta: 0.5 13 | -------------------------------------------------------------------------------- /configs/motion_stats.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - data: humanml3d 3 | - defaults 4 | - _self_ 5 | 6 | data: 7 | preload: false 8 | motion_loader: 9 | normalizer: 10 | disable: true 11 | text_to_token_emb: 12 | disable: true 13 | text_to_sent_emb: 14 | disable: true 15 | -------------------------------------------------------------------------------- /configs/render.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - renderer: matplotlib 3 | - defaults 4 | - _self_ 5 | 6 | npy_path: ??? 
7 | title: "" 8 | 9 | swap_axis: false 10 | guofeats: false 11 | rifkefeats: false 12 | 13 | renderer: 14 | canonicalize: true 15 | -------------------------------------------------------------------------------- /configs/renderer/matplotlib.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.renderer.matplotlib.MatplotlibRender 2 | 3 | jointstype: "guoh3djoints" 4 | fps: 20.0 5 | colors: ['black', 'magenta', 'red', 'green', 'blue'] 6 | figsize: 4 7 | canonicalize: true 8 | -------------------------------------------------------------------------------- /configs/retrieval.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - defaults 3 | - _self_ 4 | - data: humanml3d 5 | 6 | device: cuda 7 | 8 | run_dir: ??? 9 | save_file_name: contrastive_metrics_${hydra:runtime.choices.data} 10 | protocol: all # (is all 4), normal (a), threshold (b), nsim (c), guo (d) 11 | threshold: 0.95 # threashold to compute (b) 12 | 13 | ckpt: last 14 | batch_size: 256 15 | 16 | split: test 17 | -------------------------------------------------------------------------------- /configs/retrieval_action_multi_labels.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - data: babel_actions_120 3 | - defaults 4 | - _self_ 5 | 6 | device: cuda 7 | 8 | run_dir: ??? 9 | save_file_name: contrastive_metrics_${hydra:runtime.choices.data} 10 | 11 | ckpt: last 12 | batch_size: 256 13 | 14 | split: test 15 | 16 | data: 17 | _target_: src.data.text_motion_multi_labels.TextMotionMultiLabelsDataset 18 | tiny: False 19 | 20 | text_to_token_emb: 21 | name: token_embeddings 22 | 23 | text_to_sent_emb: 24 | name: sent_embeddings 25 | -------------------------------------------------------------------------------- /configs/text_dataset_sim.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - data: humanml3d 3 | - defaults 4 | - _self_ 5 | 6 | run_dir: outputs/tmr_humanml3d 7 | feats: false 8 | text: ??? 9 | split: train 10 | 11 | ckpt_name: last 12 | device: cuda 13 | 14 | data: 15 | preload: false 16 | text_to_token_emb: 17 | preload: false 18 | disable: true 19 | text_to_sent_emb: 20 | preload: false 21 | disable: true 22 | -------------------------------------------------------------------------------- /configs/text_embeddings.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - data: humanml3d 3 | - defaults 4 | - _self_ 5 | 6 | device: cuda 7 | 8 | annotations_filename: annotations.json 9 | output_folder_name_token: token_embeddings 10 | output_folder_name_sent: sent_embeddings 11 | -------------------------------------------------------------------------------- /configs/text_embeddings_with_augmentation.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - data: humanml3d 3 | - defaults 4 | - _self_ 5 | 6 | device: cuda 7 | 8 | annotations_filename: annotations_all.json 9 | output_folder_name_token: token_embeddings_all 10 | output_folder_name_sent: sent_embeddings_all 11 | -------------------------------------------------------------------------------- /configs/text_motion_sim.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - defaults 3 | - _self_ 4 | 5 | run_dir: ??? 6 | npy: ??? 7 | feats: false 8 | text: ??? 
9 | 10 | ckpt_name: last 11 | device: cuda 12 | -------------------------------------------------------------------------------- /configs/train.yaml: -------------------------------------------------------------------------------- 1 | ckpt: last 2 | resume_dir: null 3 | ckpt_path: null 4 | 5 | run_dir: outputs/${hydra:runtime.choices.model}_${hydra:runtime.choices.data}_${hydra:runtime.choices.data/motion_loader} 6 | 7 | dataloader: 8 | _target_: torch.utils.data.DataLoader 9 | batch_size: 32 10 | num_workers: 8 11 | 12 | defaults: 13 | - data: humanml3d 14 | - data_val: null 15 | - model: tmr 16 | - trainer 17 | - defaults 18 | - _self_ 19 | -------------------------------------------------------------------------------- /configs/train_hn.yaml: -------------------------------------------------------------------------------- 1 | run_dir: outputs/${hydra:runtime.choices.model}_${hydra:runtime.choices.data}_${hydra:runtime.choices.data/motion_loader} 2 | 3 | defaults: 4 | - train 5 | - override /model: tmr_hn 6 | - _self_ 7 | -------------------------------------------------------------------------------- /configs/train_hn_with_augmentation.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - train_with_augmentation 3 | - override model: tmr_text_averaging_hn 4 | - _self_ 5 | -------------------------------------------------------------------------------- /configs/train_with_augmentation.yaml: -------------------------------------------------------------------------------- 1 | run_dir: outputs/${hydra:runtime.choices.model}_${hydra:runtime.choices.data}_augmented_${hydra:runtime.choices.data/motion_loader} 2 | 3 | defaults: 4 | - train 5 | - data_val: null 6 | - override data: humanml3d_kitml 7 | - override model: tmr_text_averaging 8 | - _self_ 9 | 10 | data: 11 | _target_: src.data.augmented_text_motion.AugmentedTextMotionDataset 12 | paraphrase_filename: annotations_paraphrases.json 13 | summary_filename: annotations_actions.json 14 | paraphrase_prob: 0.2 15 | summary_prob: 0.1 16 | averaging_prob: 0.3 17 | preload: True 18 | text_sampling_nbr: null 19 | 20 | text_to_token_emb: 21 | name: token_embeddings_all 22 | 23 | text_to_sent_emb: 24 | name: sent_embeddings_all 25 | 26 | -------------------------------------------------------------------------------- /configs/trainer.yaml: -------------------------------------------------------------------------------- 1 | trainer: 2 | _target_: pytorch_lightning.Trainer 3 | 4 | max_epochs: 500 5 | log_every_n_steps: 50 6 | num_sanity_val_steps: 0 7 | check_val_every_n_epoch: 1 8 | accelerator: gpu 9 | devices: 1 10 | 11 | callbacks: 12 | - _target_: pytorch_lightning.callbacks.ModelCheckpoint 13 | filename: latest-{epoch} 14 | every_n_epochs: 1 15 | save_top_k: 1 16 | save_last: true 17 | - _target_: pytorch_lightning.callbacks.ModelCheckpoint 18 | filename: latest-{epoch} 19 | monitor: step 20 | mode: max 21 | every_n_epochs: 100 22 | save_top_k: -1 23 | save_last: false 24 | - _target_: src.callback.progress.ProgressLogger 25 | precision: 3 26 | - _target_: src.callback.tqdmbar.TQDMProgressBar 27 | 28 | logger: 29 | _target_: src.logger.csv.CSVLogger 30 | save_dir: ${run_dir} 31 | name: logs 32 | -------------------------------------------------------------------------------- /datasets/annotations/humanml3d/splits/nsim_test.txt: -------------------------------------------------------------------------------- 1 | 007742 2 | 005935 3 | 008597 4 | 010546 5 | 005697 6 | 005668 7 | 000565 8 | 
000119 9 | 008877 10 | 006058 11 | 009566 12 | 005180 13 | 012578 14 | 011493 15 | 005946 16 | 006630 17 | 008900 18 | 004601 19 | 002459 20 | 006186 21 | 000552 22 | 005674 23 | 004545 24 | 001052 25 | 002635 26 | 005672 27 | 011004 28 | 013440 29 | 009455 30 | 003463 31 | 000824 32 | 006549 33 | 007655 34 | 006762 35 | 012222 36 | 012655 37 | 012956 38 | 004973 39 | 013403 40 | 008730 41 | 003439 42 | 008824 43 | 008340 44 | 010823 45 | 007806 46 | 013898 47 | 004996 48 | 010384 49 | 004344 50 | 005048 51 | 001152 52 | 012568 53 | M008664 54 | M007889 55 | M000389 56 | M011343 57 | M012558 58 | M010392 59 | M014283 60 | M001538 61 | M011643 62 | M003677 63 | M011972 64 | M009880 65 | M013023 66 | M012399 67 | M002761 68 | M014109 69 | M004319 70 | M001648 71 | M013778 72 | M008383 73 | M000178 74 | M009148 75 | M006433 76 | M011569 77 | M001577 78 | M008275 79 | M012813 80 | M012084 81 | M009123 82 | M000179 83 | M012639 84 | M010671 85 | M008583 86 | M000972 87 | M008349 88 | M002824 89 | M003301 90 | M008490 91 | M003902 92 | M002252 93 | M008668 94 | M000903 95 | M003689 96 | M003373 97 | M010964 98 | M001193 99 | M006533 100 | M014384 101 | -------------------------------------------------------------------------------- /datasets/annotations/humanml3d/splits/test_tiny.txt: -------------------------------------------------------------------------------- 1 | 000000 2 | 000019 3 | 000021 4 | 000022 5 | 000026 6 | 000048 7 | 000055 8 | 000063 9 | 000066 10 | 000067 11 | -------------------------------------------------------------------------------- /datasets/annotations/humanml3d/splits/train_tiny.txt: -------------------------------------------------------------------------------- 1 | 000001 2 | 000002 3 | 000003 4 | 000004 5 | 000005 6 | 000006 7 | 000007 8 | 000008 9 | 000009 10 | 000010 11 | -------------------------------------------------------------------------------- /datasets/annotations/humanml3d/splits/val_tiny.txt: -------------------------------------------------------------------------------- 1 | 012698 2 | 012808 3 | 008646 4 | 013022 5 | 003172 6 | 008859 7 | 005095 8 | 012044 9 | 002345 10 | 008039 11 | -------------------------------------------------------------------------------- /datasets/annotations/kitml/splits/nsim_test.txt: -------------------------------------------------------------------------------- 1 | 00355 2 | 01496 3 | 01344 4 | 03107 5 | 02245 6 | 01109 7 | 01349 8 | 03128 9 | 02906 10 | 03861 11 | 03670 12 | 00989 13 | 02070 14 | 02971 15 | 01727 16 | 03045 17 | 02510 18 | 01391 19 | 01532 20 | 03771 21 | 03036 22 | 01472 23 | 03691 24 | 01463 25 | 00978 26 | 01639 27 | 02407 28 | 03917 29 | 01450 30 | 01854 31 | 00594 32 | 00736 33 | 03087 34 | 01440 35 | 02021 36 | 01444 37 | 03683 38 | 01367 39 | 00390 40 | 03577 41 | 01485 42 | 02148 43 | 03190 44 | 01223 45 | 03215 46 | 03098 47 | 02139 48 | 02435 49 | 03532 50 | M00355 51 | M01496 52 | M02751 53 | M01344 54 | M03107 55 | M02245 56 | M00452 57 | M02556 58 | M01109 59 | M01027 60 | M01349 61 | M03128 62 | M02906 63 | M03861 64 | M03670 65 | M00989 66 | M02070 67 | M02971 68 | M01727 69 | M03045 70 | M02510 71 | M01391 72 | M01532 73 | M03036 74 | M01472 75 | M03691 76 | M01463 77 | M02407 78 | M03917 79 | M01450 80 | M01854 81 | M00594 82 | M00736 83 | M03087 84 | M01440 85 | M02021 86 | M01444 87 | M03683 88 | M00568 89 | M01298 90 | M01491 91 | M01367 92 | M03577 93 | M01485 94 | M02148 95 | M03190 96 | M03215 97 | M03098 98 | M02139 99 | M02435 100 | M03532 101 | M00669 102 | 
-------------------------------------------------------------------------------- /datasets/annotations/kitml/splits/test.txt: -------------------------------------------------------------------------------- 1 | 00004 2 | 00010 3 | 00019 4 | 00033 5 | 00035 6 | 00036 7 | 00044 8 | 00053 9 | 00059 10 | 00074 11 | 00104 12 | 00105 13 | 00122 14 | 00135 15 | 00136 16 | 00157 17 | 00158 18 | 00163 19 | 00172 20 | 00176 21 | 00186 22 | 00192 23 | 00198 24 | 00200 25 | 00216 26 | 00218 27 | 00228 28 | 00231 29 | 00235 30 | 00248 31 | 00255 32 | 00280 33 | 00287 34 | 00292 35 | 00303 36 | 00309 37 | 00316 38 | 00337 39 | 00342 40 | 00344 41 | 00352 42 | 00355 43 | 00358 44 | 00373 45 | 00384 46 | 00390 47 | 00396 48 | 00410 49 | 00424 50 | 00439 51 | 00441 52 | 00451 53 | 00452 54 | 00453 55 | 00455 56 | 00463 57 | 00465 58 | 00470 59 | 00507 60 | 00508 61 | 00509 62 | 00515 63 | 00522 64 | 00541 65 | 00543 66 | 00545 67 | 00547 68 | 00555 69 | 00556 70 | 00568 71 | 00578 72 | 00591 73 | 00594 74 | 00597 75 | 00648 76 | 00650 77 | 00651 78 | 00659 79 | 00669 80 | 00670 81 | 00675 82 | 00686 83 | 00687 84 | 00695 85 | 00704 86 | 00736 87 | 00755 88 | 00767 89 | 00771 90 | 00777 91 | 00784 92 | 00815 93 | 00842 94 | 00843 95 | 00846 96 | 00849 97 | 00868 98 | 00874 99 | 00876 100 | 00877 101 | 00881 102 | 00896 103 | 00902 104 | 00912 105 | 00931 106 | 00932 107 | 00947 108 | 00949 109 | 00958 110 | 00965 111 | 00971 112 | 00975 113 | 00978 114 | 00980 115 | 00989 116 | 00998 117 | 01004 118 | 01012 119 | 01021 120 | 01023 121 | 01026 122 | 01027 123 | 01029 124 | 01043 125 | 01044 126 | 01060 127 | 01061 128 | 01066 129 | 01067 130 | 01071 131 | 01074 132 | 01081 133 | 01108 134 | 01109 135 | 01115 136 | 01127 137 | 01128 138 | 01136 139 | 01139 140 | 01140 141 | 01145 142 | 01154 143 | 01168 144 | 01183 145 | 01194 146 | 01196 147 | 01214 148 | 01223 149 | 01240 150 | 01241 151 | 01259 152 | 01290 153 | 01295 154 | 01298 155 | 01309 156 | 01312 157 | 01315 158 | 01317 159 | 01320 160 | 01323 161 | 01328 162 | 01329 163 | 01332 164 | 01344 165 | 01349 166 | 01353 167 | 01367 168 | 01381 169 | 01391 170 | 01399 171 | 01406 172 | 01416 173 | 01437 174 | 01440 175 | 01444 176 | 01450 177 | 01463 178 | 01472 179 | 01477 180 | 01485 181 | 01487 182 | 01491 183 | 01496 184 | 01519 185 | 01525 186 | 01528 187 | 01532 188 | 01536 189 | 01548 190 | 01550 191 | 01563 192 | 01571 193 | 01640 194 | 01670 195 | 01671 196 | 01680 197 | 01681 198 | 01710 199 | 01724 200 | 01727 201 | 01732 202 | 01747 203 | 01751 204 | 01761 205 | 01763 206 | 01770 207 | 01780 208 | 01782 209 | 01802 210 | 01806 211 | 01808 212 | 01822 213 | 01824 214 | 01831 215 | 01832 216 | 01842 217 | 01852 218 | 01854 219 | 01861 220 | 01868 221 | 01874 222 | 01904 223 | 01908 224 | 01917 225 | 01924 226 | 01928 227 | 01941 228 | 01944 229 | 01950 230 | 01954 231 | 01963 232 | 01969 233 | 01970 234 | 01974 235 | 01978 236 | 01979 237 | 01997 238 | 01998 239 | 02000 240 | 02007 241 | 02011 242 | 02015 243 | 02021 244 | 02026 245 | 02038 246 | 02060 247 | 02070 248 | 02080 249 | 02083 250 | 02084 251 | 02107 252 | 02115 253 | 02116 254 | 02122 255 | 02125 256 | 02139 257 | 02140 258 | 02148 259 | 02157 260 | 02160 261 | 02165 262 | 02168 263 | 02180 264 | 02181 265 | 02193 266 | 02209 267 | 02234 268 | 02239 269 | 02242 270 | 02243 271 | 02245 272 | 02247 273 | 02248 274 | 02253 275 | 02278 276 | 02283 277 | 02285 278 | 02292 279 | 02298 280 | 02311 281 | 02328 282 | 02329 283 | 02339 284 | 02342 285 | 02344 286 | 02346 287 | 02364 288 | 02373 
289 | 02394 290 | 02405 291 | 02407 292 | 02420 293 | 02427 294 | 02435 295 | 02440 296 | 02441 297 | 02449 298 | 02470 299 | 02480 300 | 02483 301 | 02510 302 | 02523 303 | 02542 304 | 02553 305 | 02556 306 | 02574 307 | 02581 308 | 02582 309 | 02588 310 | 02594 311 | 02644 312 | 02666 313 | 02691 314 | 02708 315 | 02722 316 | 02749 317 | 02751 318 | 02826 319 | 02830 320 | 02866 321 | 02888 322 | 02891 323 | 02895 324 | 02903 325 | 02906 326 | 02934 327 | 02945 328 | 02964 329 | 02967 330 | 02971 331 | 02979 332 | 03036 333 | 03045 334 | 03082 335 | 03087 336 | 03098 337 | 03107 338 | 03121 339 | 03128 340 | 03151 341 | 03188 342 | 03190 343 | 03194 344 | 03215 345 | 03224 346 | 03228 347 | 03238 348 | 03240 349 | 03246 350 | 03248 351 | 03252 352 | 03277 353 | 03278 354 | 03281 355 | 03293 356 | 03297 357 | 03314 358 | 03381 359 | 03382 360 | 03396 361 | 03404 362 | 03411 363 | 03441 364 | 03459 365 | 03481 366 | 03489 367 | 03495 368 | 03577 369 | 03593 370 | 03614 371 | 03632 372 | 03643 373 | 03670 374 | 03683 375 | 03685 376 | 03691 377 | 03695 378 | 03771 379 | 03785 380 | 03787 381 | 03788 382 | 03795 383 | 03798 384 | 03825 385 | 03830 386 | 03861 387 | 03866 388 | 03879 389 | 03917 390 | 03930 391 | 03939 392 | 03944 393 | 03964 394 | M00004 395 | M00010 396 | M00019 397 | M00033 398 | M00035 399 | M00036 400 | M00044 401 | M00053 402 | M00059 403 | M00074 404 | M00104 405 | M00105 406 | M00122 407 | M00135 408 | M00136 409 | M00157 410 | M00158 411 | M00163 412 | M00172 413 | M00176 414 | M00186 415 | M00192 416 | M00198 417 | M00200 418 | M00216 419 | M00218 420 | M00228 421 | M00231 422 | M00235 423 | M00248 424 | M00255 425 | M00280 426 | M00287 427 | M00292 428 | M00303 429 | M00309 430 | M00316 431 | M00337 432 | M00342 433 | M00344 434 | M00352 435 | M00355 436 | M00358 437 | M00373 438 | M00384 439 | M00390 440 | M00396 441 | M00410 442 | M00424 443 | M00439 444 | M00441 445 | M00451 446 | M00452 447 | M00453 448 | M00455 449 | M00463 450 | M00465 451 | M00470 452 | M00507 453 | M00508 454 | M00509 455 | M00515 456 | M00522 457 | M00541 458 | M00543 459 | M00545 460 | M00547 461 | M00555 462 | M00556 463 | M00568 464 | M00578 465 | M00591 466 | M00594 467 | M00597 468 | M00648 469 | M00650 470 | M00651 471 | M00659 472 | M00669 473 | M00670 474 | M00675 475 | M00686 476 | M00687 477 | M00695 478 | M00704 479 | M00736 480 | M00755 481 | M00767 482 | M00771 483 | M00777 484 | M00784 485 | M00815 486 | M00842 487 | M00843 488 | M00846 489 | M00849 490 | M00868 491 | M00874 492 | M00876 493 | M00877 494 | M00881 495 | M00896 496 | M00902 497 | M00912 498 | M00931 499 | M00932 500 | M00947 501 | M00949 502 | M00958 503 | M00965 504 | M00971 505 | M00975 506 | M00978 507 | M00980 508 | M00989 509 | M00998 510 | M01004 511 | M01012 512 | M01021 513 | M01023 514 | M01026 515 | M01027 516 | M01029 517 | M01043 518 | M01044 519 | M01060 520 | M01061 521 | M01066 522 | M01067 523 | M01071 524 | M01074 525 | M01081 526 | M01108 527 | M01109 528 | M01115 529 | M01127 530 | M01128 531 | M01136 532 | M01139 533 | M01140 534 | M01145 535 | M01154 536 | M01168 537 | M01183 538 | M01194 539 | M01196 540 | M01214 541 | M01223 542 | M01240 543 | M01241 544 | M01259 545 | M01290 546 | M01295 547 | M01298 548 | M01309 549 | M01312 550 | M01315 551 | M01317 552 | M01320 553 | M01323 554 | M01328 555 | M01329 556 | M01332 557 | M01344 558 | M01349 559 | M01353 560 | M01367 561 | M01381 562 | M01391 563 | M01399 564 | M01406 565 | M01416 566 | M01437 567 | M01440 568 | M01444 569 | M01450 570 | 
M01463 571 | M01472 572 | M01477 573 | M01485 574 | M01487 575 | M01491 576 | M01496 577 | M01519 578 | M01525 579 | M01528 580 | M01532 581 | M01536 582 | M01548 583 | M01550 584 | M01563 585 | M01571 586 | M01640 587 | M01670 588 | M01671 589 | M01680 590 | M01681 591 | M01710 592 | M01724 593 | M01727 594 | M01732 595 | M01747 596 | M01751 597 | M01761 598 | M01763 599 | M01770 600 | M01780 601 | M01782 602 | M01802 603 | M01806 604 | M01808 605 | M01822 606 | M01824 607 | M01831 608 | M01832 609 | M01842 610 | M01852 611 | M01854 612 | M01861 613 | M01868 614 | M01874 615 | M01904 616 | M01908 617 | M01917 618 | M01924 619 | M01928 620 | M01941 621 | M01944 622 | M01950 623 | M01954 624 | M01963 625 | M01969 626 | M01970 627 | M01974 628 | M01978 629 | M01979 630 | M01997 631 | M01998 632 | M02000 633 | M02007 634 | M02011 635 | M02015 636 | M02021 637 | M02026 638 | M02038 639 | M02060 640 | M02070 641 | M02080 642 | M02083 643 | M02084 644 | M02107 645 | M02115 646 | M02116 647 | M02122 648 | M02125 649 | M02139 650 | M02140 651 | M02148 652 | M02157 653 | M02160 654 | M02165 655 | M02168 656 | M02180 657 | M02181 658 | M02193 659 | M02209 660 | M02234 661 | M02239 662 | M02242 663 | M02243 664 | M02245 665 | M02247 666 | M02248 667 | M02253 668 | M02278 669 | M02283 670 | M02285 671 | M02292 672 | M02298 673 | M02311 674 | M02328 675 | M02329 676 | M02339 677 | M02342 678 | M02344 679 | M02346 680 | M02364 681 | M02373 682 | M02394 683 | M02405 684 | M02407 685 | M02420 686 | M02427 687 | M02435 688 | M02440 689 | M02441 690 | M02449 691 | M02470 692 | M02480 693 | M02483 694 | M02510 695 | M02523 696 | M02542 697 | M02553 698 | M02556 699 | M02574 700 | M02581 701 | M02582 702 | M02588 703 | M02594 704 | M02644 705 | M02666 706 | M02691 707 | M02708 708 | M02722 709 | M02749 710 | M02751 711 | M02826 712 | M02830 713 | M02866 714 | M02888 715 | M02891 716 | M02895 717 | M02903 718 | M02906 719 | M02934 720 | M02945 721 | M02964 722 | M02967 723 | M02971 724 | M02979 725 | M03036 726 | M03045 727 | M03082 728 | M03087 729 | M03098 730 | M03107 731 | M03121 732 | M03128 733 | M03151 734 | M03188 735 | M03190 736 | M03194 737 | M03215 738 | M03224 739 | M03228 740 | M03238 741 | M03240 742 | M03246 743 | M03248 744 | M03252 745 | M03277 746 | M03278 747 | M03281 748 | M03293 749 | M03297 750 | M03314 751 | M03381 752 | M03382 753 | M03396 754 | M03404 755 | M03411 756 | M03441 757 | M03459 758 | M03481 759 | M03489 760 | M03495 761 | M03577 762 | M03593 763 | M03614 764 | M03632 765 | M03643 766 | M03670 767 | M03683 768 | M03685 769 | M03691 770 | M03695 771 | M03771 772 | M03785 773 | M03787 774 | M03788 775 | M03795 776 | M03798 777 | M03825 778 | M03830 779 | M03861 780 | M03866 781 | M03879 782 | M03917 783 | M03930 784 | M03939 785 | M03944 786 | M03964 787 | -------------------------------------------------------------------------------- /datasets/annotations/kitml/splits/test_tiny.txt: -------------------------------------------------------------------------------- 1 | 00004 2 | 00010 3 | 00019 4 | 00033 5 | 00035 6 | 00036 7 | 00044 8 | 00053 9 | 00059 10 | 00074 11 | -------------------------------------------------------------------------------- /datasets/annotations/kitml/splits/train_tiny.txt: -------------------------------------------------------------------------------- 1 | 00001 2 | 00002 3 | 00003 4 | 00005 5 | 00007 6 | 00008 7 | 00009 8 | 00011 9 | 00013 10 | 00014 11 | -------------------------------------------------------------------------------- 
/datasets/annotations/kitml/splits/val.txt: -------------------------------------------------------------------------------- 1 | 00006 2 | 00015 3 | 00025 4 | 00052 5 | 00068 6 | 00092 7 | 00093 8 | 00126 9 | 00143 10 | 00170 11 | 00183 12 | 00222 13 | 00244 14 | 00254 15 | 00270 16 | 00271 17 | 00306 18 | 00311 19 | 00331 20 | 00332 21 | 00374 22 | 00427 23 | 00432 24 | 00476 25 | 00491 26 | 00498 27 | 00510 28 | 00539 29 | 00549 30 | 00582 31 | 00589 32 | 00605 33 | 00612 34 | 00646 35 | 00685 36 | 00690 37 | 00691 38 | 00709 39 | 00728 40 | 00733 41 | 00754 42 | 00764 43 | 00834 44 | 00841 45 | 00873 46 | 00897 47 | 00905 48 | 00908 49 | 00909 50 | 00918 51 | 00922 52 | 00943 53 | 00954 54 | 00964 55 | 00967 56 | 00981 57 | 00983 58 | 00988 59 | 01001 60 | 01006 61 | 01016 62 | 01030 63 | 01070 64 | 01073 65 | 01092 66 | 01097 67 | 01118 68 | 01123 69 | 01176 70 | 01177 71 | 01209 72 | 01264 73 | 01282 74 | 01299 75 | 01319 76 | 01369 77 | 01407 78 | 01409 79 | 01412 80 | 01418 81 | 01428 82 | 01429 83 | 01443 84 | 01447 85 | 01479 86 | 01483 87 | 01494 88 | 01503 89 | 01655 90 | 01664 91 | 01700 92 | 01703 93 | 01745 94 | 01762 95 | 01794 96 | 01795 97 | 01809 98 | 01825 99 | 01834 100 | 01836 101 | 01877 102 | 01905 103 | 01931 104 | 01948 105 | 01975 106 | 01992 107 | 01995 108 | 02048 109 | 02110 110 | 02118 111 | 02121 112 | 02171 113 | 02184 114 | 02315 115 | 02322 116 | 02418 117 | 02499 118 | 02593 119 | 02598 120 | 02797 121 | 02904 122 | 02922 123 | 02940 124 | 02947 125 | 03009 126 | 03155 127 | 03227 128 | 03232 129 | 03262 130 | 03276 131 | 03284 132 | 03290 133 | 03365 134 | 03419 135 | 03462 136 | 03480 137 | 03631 138 | 03682 139 | 03708 140 | 03722 141 | 03732 142 | 03813 143 | 03817 144 | 03832 145 | 03877 146 | 03899 147 | M00006 148 | M00015 149 | M00025 150 | M00052 151 | M00068 152 | M00092 153 | M00093 154 | M00126 155 | M00143 156 | M00170 157 | M00183 158 | M00222 159 | M00244 160 | M00254 161 | M00270 162 | M00271 163 | M00306 164 | M00311 165 | M00331 166 | M00332 167 | M00374 168 | M00427 169 | M00432 170 | M00476 171 | M00491 172 | M00498 173 | M00510 174 | M00539 175 | M00549 176 | M00582 177 | M00589 178 | M00605 179 | M00612 180 | M00646 181 | M00685 182 | M00690 183 | M00691 184 | M00709 185 | M00728 186 | M00733 187 | M00754 188 | M00764 189 | M00834 190 | M00841 191 | M00873 192 | M00897 193 | M00905 194 | M00908 195 | M00909 196 | M00918 197 | M00922 198 | M00943 199 | M00954 200 | M00964 201 | M00967 202 | M00981 203 | M00983 204 | M00988 205 | M01001 206 | M01006 207 | M01016 208 | M01030 209 | M01070 210 | M01073 211 | M01092 212 | M01097 213 | M01118 214 | M01123 215 | M01176 216 | M01177 217 | M01209 218 | M01264 219 | M01282 220 | M01299 221 | M01319 222 | M01369 223 | M01407 224 | M01409 225 | M01412 226 | M01418 227 | M01428 228 | M01429 229 | M01443 230 | M01447 231 | M01479 232 | M01483 233 | M01494 234 | M01503 235 | M01655 236 | M01664 237 | M01700 238 | M01703 239 | M01745 240 | M01762 241 | M01794 242 | M01795 243 | M01809 244 | M01825 245 | M01834 246 | M01836 247 | M01877 248 | M01905 249 | M01931 250 | M01948 251 | M01975 252 | M01992 253 | M01995 254 | M02048 255 | M02110 256 | M02118 257 | M02121 258 | M02171 259 | M02184 260 | M02315 261 | M02322 262 | M02418 263 | M02499 264 | M02593 265 | M02598 266 | M02797 267 | M02904 268 | M02922 269 | M02940 270 | M02947 271 | M03009 272 | M03155 273 | M03227 274 | M03232 275 | M03262 276 | M03276 277 | M03284 278 | M03290 279 | M03365 280 | M03419 281 | M03462 282 | M03480 283 | M03631 284 
| M03682 285 | M03708 286 | M03722 287 | M03732 288 | M03813 289 | M03817 290 | M03832 291 | M03877 292 | M03899 293 | -------------------------------------------------------------------------------- /datasets/annotations/kitml/splits/val_tiny.txt: -------------------------------------------------------------------------------- 1 | 00006 2 | 00015 3 | 00025 4 | 00052 5 | 00068 6 | 00092 7 | 00093 8 | 00126 9 | 00143 10 | 00170 11 | -------------------------------------------------------------------------------- /demo/load.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import orjson 4 | import codecs as cs 5 | import torch 6 | 7 | 8 | def load_json(json_path): 9 | with open(json_path, "rb") as ff: 10 | return orjson.loads(ff.read()) 11 | 12 | 13 | def load_unit_embeddings(run_dir, dataset, device="cpu"): 14 | save_dir = os.path.join(run_dir, "latents") 15 | unit_emb_path = os.path.join(save_dir, f"{dataset}_all_unit.npy") 16 | motion_embs = torch.from_numpy(np.load(unit_emb_path)).to(device) 17 | 18 | # Loading the correspondance 19 | keyids_index = load_json(os.path.join(save_dir, f"{dataset}_keyids_index_all.json")) 20 | index_keyids = load_json(os.path.join(save_dir, f"{dataset}_index_keyids_all.json")) 21 | 22 | return motion_embs, keyids_index, index_keyids 23 | 24 | 25 | def load_split(path, split): 26 | split_file = os.path.join(path, "splits", split + ".txt") 27 | id_list = [] 28 | with cs.open(split_file, "r") as f: 29 | for line in f.readlines(): 30 | id_list.append(line.strip()) 31 | return id_list 32 | 33 | 34 | def load_splits(dataset, splits=["test", "all"]): 35 | path = f"datasets/annotations/{dataset}" 36 | return {split: load_split(path, split) for split in splits} 37 | -------------------------------------------------------------------------------- /demo/model.py: -------------------------------------------------------------------------------- 1 | # Text model + TMR text encoder only 2 | 3 | from typing import List 4 | import torch.nn as nn 5 | import os 6 | 7 | import torch 8 | import numpy as np 9 | from torch import Tensor 10 | from transformers import AutoTokenizer, AutoModel 11 | from torch.nn.functional import normalize 12 | from einops import repeat 13 | import json 14 | import warnings 15 | 16 | import logging 17 | 18 | logger = logging.getLogger("torch.distributed.nn.jit.instantiator") 19 | logger.setLevel(logging.ERROR) 20 | 21 | 22 | warnings.filterwarnings( 23 | "ignore", "The PyTorch API of nested tensors is in prototype stage*" 24 | ) 25 | 26 | warnings.filterwarnings("ignore", "Converting mask without torch.bool dtype to bool*") 27 | 28 | torch.set_float32_matmul_precision("high") 29 | 30 | 31 | class PositionalEncoding(nn.Module): 32 | def __init__(self, d_model, dropout=0.1, max_len=5000, batch_first=False) -> None: 33 | super().__init__() 34 | self.batch_first = batch_first 35 | 36 | self.dropout = nn.Dropout(p=dropout) 37 | 38 | pe = torch.zeros(max_len, d_model) 39 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) 40 | div_term = torch.exp( 41 | torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model) 42 | ) 43 | pe[:, 0::2] = torch.sin(position * div_term) 44 | pe[:, 1::2] = torch.cos(position * div_term) 45 | pe = pe.unsqueeze(0).transpose(0, 1) 46 | self.register_buffer("pe", pe, persistent=False) 47 | 48 | def forward(self, x: Tensor) -> Tensor: 49 | if self.batch_first: 50 | x = x + self.pe.permute(1, 0, 2)[:, : x.shape[1], :] 51 | else: 52 
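# (The "pe" buffer is registered with shape [max_len, 1, d_model] after the
#  transpose in __init__, so this non-batch_first branch slices the first
#  x.shape[0] positions and broadcasts over the batch dimension of x.)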
| x = x + self.pe[: x.shape[0], :] 53 | return self.dropout(x) 54 | 55 | 56 | def read_config(run_dir: str): 57 | path = os.path.join(run_dir, "config.json") 58 | with open(path, "r") as f: 59 | config = json.load(f) 60 | return config 61 | 62 | 63 | class TMR_text_encoder(nn.Module): 64 | def __init__(self, run_dir: str) -> None: 65 | config = read_config(run_dir) 66 | modelpath = config["data"]["text_to_token_emb"]["modelname"] 67 | 68 | text_encoder_conf = config["model"]["text_encoder"] 69 | 70 | vae = text_encoder_conf["vae"] 71 | latent_dim = text_encoder_conf["latent_dim"] 72 | ff_size = text_encoder_conf["ff_size"] 73 | num_layers = text_encoder_conf["num_layers"] 74 | num_heads = text_encoder_conf["num_heads"] 75 | activation = text_encoder_conf["activation"] 76 | nfeats = text_encoder_conf["nfeats"] 77 | 78 | super().__init__() 79 | 80 | # Projection of the text-outputs into the latent space 81 | self.projection = nn.Linear(nfeats, latent_dim) 82 | self.vae = vae 83 | self.nbtokens = 2 if vae else 1 84 | 85 | self.tokens = nn.Parameter(torch.randn(self.nbtokens, latent_dim)) 86 | self.sequence_pos_encoding = PositionalEncoding( 87 | latent_dim, dropout=0.0, batch_first=True 88 | ) 89 | 90 | seq_trans_encoder_layer = nn.TransformerEncoderLayer( 91 | d_model=latent_dim, 92 | nhead=num_heads, 93 | dim_feedforward=ff_size, 94 | dropout=0.0, 95 | activation=activation, 96 | batch_first=True, 97 | ) 98 | 99 | self.seqTransEncoder = nn.TransformerEncoder( 100 | seq_trans_encoder_layer, num_layers=num_layers 101 | ) 102 | 103 | text_encoder_pt_path = os.path.join(run_dir, "last_weights/text_encoder.pt") 104 | state_dict = torch.load(text_encoder_pt_path) 105 | self.load_state_dict(state_dict) 106 | 107 | from transformers import logging 108 | 109 | # load text model 110 | logging.set_verbosity_error() 111 | 112 | # Tokenizer 113 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 114 | self.tokenizer = AutoTokenizer.from_pretrained(modelpath) 115 | 116 | # Text model 117 | self.text_model = AutoModel.from_pretrained(modelpath) 118 | # Then configure the model 119 | self.text_encoded_dim = self.text_model.config.hidden_size 120 | self.eval() 121 | 122 | def get_last_hidden_state(self, texts: List[str], return_mask: bool = False): 123 | encoded_inputs = self.tokenizer(texts, return_tensors="pt", padding=True) 124 | output = self.text_model(**encoded_inputs.to(self.text_model.device)) 125 | if not return_mask: 126 | return output.last_hidden_state 127 | return output.last_hidden_state, encoded_inputs.attention_mask.to(dtype=bool) 128 | 129 | def forward(self, texts: List[str]) -> Tensor: 130 | text_encoded, mask = self.get_last_hidden_state(texts, return_mask=True) 131 | 132 | x = self.projection(text_encoded) 133 | 134 | device = x.device 135 | bs = len(x) 136 | 137 | tokens = repeat(self.tokens, "nbtoken dim -> bs nbtoken dim", bs=bs) 138 | xseq = torch.cat((tokens, x), 1) 139 | 140 | token_mask = torch.ones((bs, self.nbtokens), dtype=bool, device=device) 141 | aug_mask = torch.cat((token_mask, mask), 1) 142 | 143 | # add positional encoding 144 | xseq = self.sequence_pos_encoding(xseq) 145 | final = self.seqTransEncoder(xseq, src_key_padding_mask=~aug_mask) 146 | return final[:, 0] 147 | 148 | # compute score for retrieval 149 | def compute_scores(self, texts, unit_embs=None, embs=None): 150 | # not both empty 151 | assert not (unit_embs is None and embs is None) 152 | # not both filled 153 | assert not (unit_embs is not None and embs is not None) 154 | 155 | output_str = False 156 | # if 
one input, squeeze the output 157 | if isinstance(texts, str): 158 | texts = [texts] 159 | output_str = True 160 | 161 | # compute unit_embs from embs if not given 162 | if embs is not None: 163 | unit_embs = normalize(embs) 164 | 165 | with torch.no_grad(): 166 | latent_unit_texts = normalize(self(texts)) 167 | # compute cosine similarity between 0 and 1 168 | scores = (unit_embs @ latent_unit_texts.T).T / 2 + 0.5 169 | scores = scores.cpu().numpy() 170 | 171 | if output_str: 172 | scores = scores[0] 173 | 174 | return scores 175 | -------------------------------------------------------------------------------- /encode_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from omegaconf import DictConfig 3 | import logging 4 | import hydra 5 | import json 6 | from hydra.core.hydra_config import HydraConfig 7 | 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | def x_dict_to_device(x_dict, device): 13 | import torch 14 | 15 | for key, val in x_dict.items(): 16 | if isinstance(val, torch.Tensor): 17 | x_dict[key] = val.to(device) 18 | return x_dict 19 | 20 | 21 | def write_json(data, path): 22 | with open(path, "w") as ff: 23 | ff.write(json.dumps(data, indent=4)) 24 | 25 | 26 | @hydra.main(version_base=None, config_path="configs", config_name="encode_dataset") 27 | def encode_dataset(cfg: DictConfig) -> None: 28 | device = cfg.device 29 | run_dir = cfg.run_dir 30 | ckpt_name = cfg.ckpt_name 31 | cfg_data = cfg.data 32 | 33 | choices = HydraConfig.get().runtime.choices 34 | data_name = choices.data 35 | 36 | import src.prepare # noqa 37 | import torch 38 | import numpy as np 39 | from src.config import read_config 40 | from src.load import load_model_from_cfg 41 | from hydra.utils import instantiate 42 | from pytorch_lightning import seed_everything 43 | 44 | cfg = read_config(run_dir) 45 | 46 | logger.info("Loading the model") 47 | model = load_model_from_cfg(cfg, ckpt_name, eval_mode=True, device=device) 48 | 49 | save_dir = os.path.join(run_dir, "latents") 50 | os.makedirs(save_dir, exist_ok=True) 51 | 52 | dataset = instantiate(cfg_data, split="all") 53 | dataloader = instantiate( 54 | cfg.dataloader, 55 | dataset=dataset, 56 | collate_fn=dataset.collate_fn, 57 | shuffle=False, 58 | ) 59 | seed_everything(cfg.seed) 60 | 61 | all_latents = [] 62 | all_keyids = [] 63 | 64 | with torch.inference_mode(): 65 | for batch in dataloader: 66 | motion_x_dict = batch["motion_x_dict"] 67 | x_dict_to_device(motion_x_dict, device) 68 | latents = model.encode(motion_x_dict, sample_mean=True) 69 | all_latents.append(latents.cpu().numpy()) 70 | keyids = batch["keyid"] 71 | all_keyids.extend(keyids) 72 | 73 | latents = np.concatenate(all_latents) 74 | path = os.path.join(save_dir, f"{data_name}_all.npy") 75 | logger.info(f"Encoding the latents of all the splits in {path}") 76 | np.save(path, latents) 77 | 78 | path_unit = os.path.join(save_dir, f"{data_name}_all_unit.npy") 79 | logger.info(f"Encoding the unit latents of all the splits in {path_unit}") 80 | 81 | unit_latents = latents / np.linalg.norm(latents, axis=-1)[:, None] 82 | np.save(path_unit, unit_latents) 83 | 84 | # Writing the correspondance 85 | logger.info("Writing the correspondance files") 86 | keyids_index_path = os.path.join(save_dir, f"{data_name}_keyids_index_all.json") 87 | index_keyids_path = os.path.join(save_dir, f"{data_name}_index_keyids_all.json") 88 | 89 | keyids_index = {x: i for i, x in enumerate(all_keyids)} 90 | index_keyids = {i: x for i, x in 
enumerate(all_keyids)} 91 | 92 | write_json(keyids_index, keyids_index_path) 93 | write_json(index_keyids, index_keyids_path) 94 | 95 | 96 | if __name__ == "__main__": 97 | encode_dataset() 98 | -------------------------------------------------------------------------------- /encode_motion.py: -------------------------------------------------------------------------------- 1 | import os 2 | from omegaconf import DictConfig 3 | import logging 4 | import hydra 5 | 6 | logger = logging.getLogger(__name__) 7 | 8 | 9 | @hydra.main(version_base=None, config_path="configs", config_name="encode_motion") 10 | def encode_motion(cfg: DictConfig) -> None: 11 | device = cfg.device 12 | run_dir = cfg.run_dir 13 | ckpt_name = cfg.ckpt_name 14 | npy_path = cfg.npy 15 | 16 | import src.prepare # noqa 17 | import torch 18 | import numpy as np 19 | from src.config import read_config 20 | from src.load import load_model_from_cfg 21 | from hydra.utils import instantiate 22 | from pytorch_lightning import seed_everything 23 | from src.data.collate import collate_x_dict 24 | 25 | cfg = read_config(run_dir) 26 | 27 | logger.info("Loading the model") 28 | model = load_model_from_cfg(cfg, ckpt_name, eval_mode=True, device=device) 29 | normalizer = instantiate(cfg.data.motion_loader.normalizer) 30 | 31 | motion = torch.from_numpy(np.load(npy_path)).to(torch.float) 32 | motion = normalizer(motion) 33 | motion = motion.to(device) 34 | 35 | motion_x_dict = {"x": motion, "length": len(motion)} 36 | 37 | seed_everything(cfg.seed) 38 | with torch.inference_mode(): 39 | motion_x_dict = collate_x_dict([motion_x_dict]) 40 | latent = model.encode(motion_x_dict, sample_mean=True)[0] 41 | latent = latent.cpu().numpy() 42 | 43 | fname = os.path.split(npy_path)[1] 44 | output_folder = os.path.join(run_dir, "encoded") 45 | os.makedirs(output_folder, exist_ok=True) 46 | path = os.path.join(output_folder, fname) 47 | 48 | np.save(path, latent) 49 | logger.info(f"Encoding done, latent saved in:\n{path}") 50 | 51 | 52 | if __name__ == "__main__": 53 | encode_motion() 54 | -------------------------------------------------------------------------------- /encode_text.py: -------------------------------------------------------------------------------- 1 | import os 2 | from omegaconf import DictConfig 3 | import logging 4 | import hydra 5 | 6 | logger = logging.getLogger(__name__) 7 | 8 | 9 | @hydra.main(version_base=None, config_path="configs", config_name="encode_text") 10 | def encode_text(cfg: DictConfig) -> None: 11 | device = cfg.device 12 | run_dir = cfg.run_dir 13 | ckpt_name = cfg.ckpt_name 14 | text = cfg.text 15 | 16 | import src.prepare # noqa 17 | import torch 18 | import numpy as np 19 | from src.config import read_config 20 | from src.load import load_model_from_cfg 21 | from hydra.utils import instantiate 22 | from pytorch_lightning import seed_everything 23 | from src.data.collate import collate_x_dict 24 | 25 | cfg = read_config(run_dir) 26 | 27 | logger.info("Loading the text model") 28 | text_model = instantiate(cfg.data.text_to_token_emb, device=device) 29 | 30 | logger.info("Loading the model") 31 | model = load_model_from_cfg(cfg, ckpt_name, eval_mode=True, device=device) 32 | 33 | seed_everything(cfg.seed) 34 | with torch.inference_mode(): 35 | text_x_dict = collate_x_dict(text_model([text])) 36 | latent = model.encode(text_x_dict, sample_mean=True)[0] 37 | latent = latent.cpu().numpy() 38 | 39 | fname = text.lower().replace(" ", "_") + ".npy" 40 | 41 | output_folder = os.path.join(run_dir, "encoded") 42 | 
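The latents written by encode_dataset.py and encode_text.py can be combined for a quick nearest-neighbour text-to-motion lookup. A minimal standalone sketch, run from the repository root; the run directory name and query text are placeholders, and the "/ 2 + 0.5" rescaling mirrors the scoring in demo/model.py:

import json
import os
import numpy as np

run_dir = "models/tmr_humanml3d"  # hypothetical run directory
dataset = "humanml3d"
latents_dir = os.path.join(run_dir, "latents")

# Unit-normalized motion latents and index mapping written by encode_dataset.py
unit_latents = np.load(os.path.join(latents_dir, f"{dataset}_all_unit.npy"))
with open(os.path.join(latents_dir, f"{dataset}_index_keyids_all.json")) as f:
    index_keyids = json.load(f)

# Text latent written by encode_text.py for text="a person walks forward"
text_latent = np.load(os.path.join(run_dir, "encoded", "a_person_walks_forward.npy"))
text_unit = text_latent / np.linalg.norm(text_latent)

# Cosine similarity rescaled to [0, 1], then the top-5 motions
scores = unit_latents @ text_unit / 2 + 0.5
for i in np.argsort(-scores)[:5]:
    print(index_keyids[str(i)], float(scores[i]))

JSON keys are strings, hence the index_keyids[str(i)] lookup.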
os.makedirs(output_folder, exist_ok=True) 43 | path = os.path.join(output_folder, fname) 44 | 45 | np.save(path, latent) 46 | logger.info(f"Encoding done, latent saved in:\n{path}") 47 | 48 | 49 | if __name__ == "__main__": 50 | encode_text() 51 | -------------------------------------------------------------------------------- /extract.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import hydra 3 | from omegaconf import DictConfig 4 | 5 | logger = logging.getLogger(__name__) 6 | 7 | 8 | @hydra.main(config_path="configs", config_name="extract", version_base="1.3") 9 | def extract(cfg: DictConfig): 10 | run_dir = cfg.run_dir 11 | ckpt = cfg.ckpt 12 | 13 | from src.load import extract_ckpt 14 | 15 | logger.info("Extracting the checkpoint...") 16 | extract_ckpt(run_dir, ckpt_name=ckpt) 17 | logger.info("Done") 18 | 19 | 20 | if __name__ == "__main__": 21 | extract() 22 | -------------------------------------------------------------------------------- /prepare/combine_datasets.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import logging 3 | import hydra 4 | import json 5 | import numpy as np 6 | import os 7 | from omegaconf import DictConfig 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | SUFFIX_DICT = {"humanml3d": "h", "kitml": "k", "babel": "b"} 13 | 14 | @hydra.main(config_path="../configs", config_name="combine_datasets", version_base="1.3") 15 | def combine_datasets(cfg: DictConfig): 16 | 17 | train_datasets = cfg.datasets 18 | annotations_folder_path = cfg.annotations_path 19 | 20 | combined_dataset_name = "_".join(train_datasets) 21 | combined_dataset_folder = os.path.join(annotations_folder_path, combined_dataset_name) 22 | os.makedirs(combined_dataset_folder, exist_ok=True) 23 | 24 | annotations = {} 25 | annotations_paraphrases = {} 26 | annotations_actions = {} 27 | 28 | annotations = {} 29 | annotations_paraphrases = {} 30 | annotations_actions = {} 31 | annotations_all = {} 32 | 33 | dataset_annotations = {} 34 | 35 | for dataset in train_datasets: 36 | annotations_path = os.path.join(annotations_folder_path, dataset, "annotations.json") 37 | with open(annotations_path) as f: 38 | d = json.load(f) 39 | dataset_annotations[dataset] = d 40 | d_new = {f"{key}_{SUFFIX_DICT[dataset]}": val for key, val in d.items()} 41 | annotations.update(d_new) 42 | 43 | annotations_paraphrases_path = os.path.join(annotations_folder_path, dataset, "annotations_paraphrases.json") 44 | if os.path.exists(annotations_paraphrases_path): 45 | with open(annotations_paraphrases_path) as f: 46 | d = json.load(f) 47 | d = {f"{key}_{SUFFIX_DICT[dataset]}": val for key, val in d.items()} 48 | annotations_paraphrases.update(d) 49 | 50 | annotations_actions_path = os.path.join(annotations_folder_path, dataset, "annotations_actions.json") 51 | if os.path.exists(annotations_actions_path): 52 | with open(annotations_actions_path) as f: 53 | d = json.load(f) 54 | d = {f"{key}_{SUFFIX_DICT[dataset]}": val for key, val in d.items()} 55 | annotations_actions.update(d) 56 | 57 | annotations_all_path = os.path.join(annotations_folder_path, dataset, "annotations_all.json") 58 | if os.path.exists(annotations_all_path): 59 | with open(annotations_all_path) as f: 60 | d = json.load(f) 61 | d = {f"{key}_{SUFFIX_DICT[dataset]}": val for key, val in d.items()} 62 | annotations_all.update(d) 63 | 64 | with open(os.path.join(combined_dataset_folder, "annotations.json"), "w") as f: 65 | 
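# The merged "annotations" dict written here keeps every source keyid but makes
# it globally unique by appending the dataset suffix from SUFFIX_DICT, e.g.
# (hypothetical ids):
#   humanml3d "000001" -> "000001_h"
#   kitml     "00006"  -> "00006_k"
#   babel     "12345"  -> "12345_b"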
json.dump(annotations, f, indent=2) 66 | with open(os.path.join(combined_dataset_folder, "annotations_paraphrases.json"), "w") as f: 67 | json.dump(annotations_paraphrases, f, indent=2) 68 | with open(os.path.join(combined_dataset_folder, "annotations_actions.json"), "w") as f: 69 | json.dump(annotations_actions, f, indent=2) 70 | with open(os.path.join(combined_dataset_folder, "annotations_all.json"), "w") as f: 71 | json.dump(annotations_all, f, indent=2) 72 | 73 | test_datasets = cfg.test_sets 74 | 75 | for dataset in test_datasets: 76 | if dataset not in dataset_annotations: 77 | annotations_path = os.path.join(annotations_folder_path, dataset, "annotations.json") 78 | with open(annotations_path) as f: 79 | d = json.load(f) 80 | dataset_annotations[dataset] = d 81 | 82 | # Splits creation 83 | 84 | logger.info(f"Creating train split by combining train splits from: {', '.join(train_datasets)}") 85 | logger.info(f"Removing from the train splits samples overlapping with test split from: {', '.join(test_datasets)}") 86 | 87 | dataset_splits = {} 88 | 89 | splits = ["train", "val"] 90 | for dataset in train_datasets: 91 | if dataset not in dataset_splits: 92 | dataset_splits[dataset] = {} 93 | for split in splits: 94 | with open(os.path.join(annotations_folder_path, dataset, "splits", f"{split}.txt")) as f: 95 | str_inds = f.read() 96 | inds = str_inds.split("\n") 97 | if inds[-1] == "": 98 | inds.pop(-1) 99 | dic_ind_path = {ind: dataset_annotations[dataset][ind]["path"] for ind in inds} 100 | dataset_splits[dataset][split] = dic_ind_path 101 | 102 | split = 'test' 103 | for dataset in test_datasets: 104 | if dataset not in dataset_splits: 105 | dataset_splits[dataset] = {} 106 | with open(os.path.join(annotations_folder_path, dataset, "splits", f"{split}.txt")) as f: 107 | str_inds = f.read() 108 | inds = str_inds.split("\n") 109 | if inds[-1] == "": 110 | inds.pop(-1) 111 | dic_ind_path = {ind: dataset_annotations[dataset][ind]["path"] for ind in inds} 112 | dataset_splits[dataset][split] = dic_ind_path 113 | 114 | 115 | to_remove = {train_dataset: {test_dataset: {"train": [], "val": []} for test_dataset in test_datasets if test_dataset != train_dataset} for train_dataset in train_datasets} 116 | 117 | for train_dataset in train_datasets: 118 | 119 | for split in ["train", "val"]: 120 | for train_id, train_path in dataset_splits[train_dataset][split].items(): 121 | 122 | for test_dataset in set(test_datasets) - set([train_dataset]): 123 | 124 | for test_id, test_path in dataset_splits[test_dataset]["test"].items(): 125 | if train_path == test_path: 126 | 127 | if not cfg.filter_babel_seg: 128 | if test_dataset == "babel": 129 | test_duration = float(dataset_annotations[test_dataset][test_id]["duration"]) 130 | test_fragment_duration = float(dataset_annotations[test_dataset][test_id]["fragment_duration"]) 131 | 132 | if not np.isclose([test_duration], [test_fragment_duration], atol=0.1, rtol=0): 133 | continue 134 | 135 | if train_dataset == "babel": 136 | train_duration = float(dataset_annotations[train_dataset][train_id]["duration"]) 137 | train_fragment_duration = float(dataset_annotations[train_dataset][train_id]["fragment_duration"]) 138 | 139 | if not np.isclose([train_duration], [train_fragment_duration], atol=0.1, rtol=0): 140 | continue 141 | 142 | train_start = float(dataset_annotations[train_dataset][train_id]["annotations"][0]["start"]) 143 | train_end = float(dataset_annotations[train_dataset][train_id]["annotations"][0]["end"]) 144 | test_start = 
float(dataset_annotations[test_dataset][test_id]["annotations"][0]["start"]) 145 | test_end = float(dataset_annotations[test_dataset][test_id]["annotations"][0]["end"]) 146 | 147 | if not ((train_end <= test_start) or (test_end <= train_start)): 148 | to_remove[train_dataset][test_dataset][split].append(train_id) 149 | 150 | datasets_curated = {train_dataset: {split: list(dataset_splits[train_dataset][split].keys()) for split in dataset_splits[train_dataset].keys()} for train_dataset in train_datasets} 151 | 152 | for train_dataset in train_datasets: 153 | for test_dataset in set(test_datasets) - set([train_dataset]): 154 | for split in ["train", "val"]: 155 | for keyid in to_remove[train_dataset][test_dataset][split]: 156 | if keyid in datasets_curated[train_dataset][split]: 157 | datasets_curated[train_dataset][split].remove(keyid) 158 | 159 | splits_folder = os.path.join(annotations_folder_path, combined_dataset_name, "splits") 160 | os.makedirs(splits_folder, exist_ok=True) 161 | all_ids = [] 162 | for split in ["train", "val"]: 163 | ids = [] 164 | for train_dataset in train_datasets: 165 | ids += [f'{elt}_{SUFFIX_DICT[train_dataset]}' for elt in datasets_curated[train_dataset][split]] 166 | all_ids += ids 167 | ids_str = "\n".join(ids) 168 | filename = f"{split}{cfg.split_suffix}.txt" 169 | with open(os.path.join(splits_folder, filename), "w") as f: 170 | f.write(ids_str) 171 | 172 | all_ids_str = "\n".join(all_ids) 173 | with open(os.path.join(splits_folder, f"all{cfg.split_suffix}.txt"), "w") as f: 174 | f.write(all_ids_str) 175 | 176 | 177 | if __name__ == "__main__": 178 | combine_datasets() 179 | 180 | -------------------------------------------------------------------------------- /prepare/compute_guoh3dfeats.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import hydra 4 | from omegaconf import DictConfig 5 | 6 | import numpy as np 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | def extract_h3d(feats): 12 | from einops import unpack 13 | 14 | root_data, ric_data, rot_data, local_vel, feet_l, feet_r = unpack( 15 | feats, [[4], [63], [126], [66], [2], [2]], "i *" 16 | ) 17 | return root_data, ric_data, rot_data, local_vel, feet_l, feet_r 18 | 19 | 20 | def swap_left_right(data): 21 | assert len(data.shape) == 3 and data.shape[-1] == 3 22 | data = data.copy() 23 | data[..., 0] *= -1 24 | right_chain = [2, 5, 8, 11, 14, 17, 19, 21] 25 | left_chain = [1, 4, 7, 10, 13, 16, 18, 20] 26 | left_hand_chain = [22, 23, 24, 34, 35, 36, 25, 26, 27, 31, 32, 33, 28, 29, 30] 27 | right_hand_chain = [43, 44, 45, 46, 47, 48, 40, 41, 42, 37, 38, 39, 49, 50, 51] 28 | tmp = data[:, right_chain] 29 | data[:, right_chain] = data[:, left_chain] 30 | data[:, left_chain] = tmp 31 | if data.shape[1] > 24: 32 | tmp = data[:, right_hand_chain] 33 | data[:, right_hand_chain] = data[:, left_hand_chain] 34 | data[:, left_hand_chain] = tmp 35 | return data 36 | 37 | 38 | @hydra.main( 39 | config_path="../configs", config_name="compute_guoh3dfeats", version_base="1.3" 40 | ) 41 | def compute_guoh3dfeats(cfg: DictConfig): 42 | base_folder = cfg.base_folder 43 | output_folder = cfg.output_folder 44 | force_redo = cfg.force_redo 45 | 46 | from src.guofeats import joints_to_guofeats 47 | from .tools import loop_amass 48 | 49 | output_folder_M = os.path.join(output_folder, "M") 50 | 51 | print("Get h3d features from Guo et al.") 52 | print("The processed motions will be stored in this folder:") 53 | print(output_folder) 54 | 55 | 
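As a sanity check on the feature layout assumed by extract_h3d above, the Guo et al. representation concatenates root data (4), rotation-invariant joint positions (63), 6D joint rotations (126), local joint velocities (66) and two foot-contact channels (2 + 2), i.e. 263 features per frame. A standalone sketch with a dummy tensor (not part of this script):

import torch
from einops import unpack

feats = torch.randn(60, 263)  # 60 frames of dummy Guo-style features
root, ric, rot, vel, feet_l, feet_r = unpack(
    feats, [[4], [63], [126], [66], [2], [2]], "i *"
)
assert root.shape == (60, 4) and ric.shape == (60, 63)
assert rot.shape == (60, 126) and vel.shape == (60, 66)
assert feet_l.shape == (60, 2) and feet_r.shape == (60, 2)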
iterator = loop_amass( 56 | base_folder, output_folder, ext=".npy", newext=".npy", force_redo=force_redo 57 | ) 58 | 59 | for motion_path, new_motion_path in iterator: 60 | joints = np.load(motion_path) 61 | 62 | if "humanact12" not in motion_path: 63 | # This is because the authors of HumanML3D 64 | # save the motions by swapping Y and Z (det = -1) 65 | # which is not a proper rotation (det = 1) 66 | # so we should invert x, to make it a rotation 67 | # that is why the authors use "data[..., 0] *= -1" inside the "if" 68 | # before swapping left/right 69 | # https://github.com/EricGuo5513/HumanML3D/blob/main/raw_pose_processing.ipynb 70 | joints[..., 0] *= -1 71 | # the humanact12 motions are normally saved correctly, no need to swap again 72 | # (but in fact this may not be true and the orignal H3D features 73 | # corresponding to HumanAct12 appears to be left/right flipped..) 74 | # At least we are compatible with previous work :/ 75 | 76 | joints_m = swap_left_right(joints) 77 | 78 | # apply transformation 79 | try: 80 | features = joints_to_guofeats(joints) 81 | features_m = joints_to_guofeats(joints_m) 82 | except (IndexError, ValueError): 83 | # The sequence should be only 1 frame long 84 | # so we cannot compute features (which involve velocities etc) 85 | assert len(joints) == 1 86 | continue 87 | # save the features 88 | np.save(new_motion_path, features) 89 | 90 | # save the mirrored features as well 91 | new_motion_path_M = new_motion_path.replace(output_folder, output_folder_M) 92 | os.makedirs(os.path.split(new_motion_path_M)[0], exist_ok=True) 93 | np.save(new_motion_path_M, features_m) 94 | 95 | 96 | if __name__ == "__main__": 97 | compute_guoh3dfeats() 98 | -------------------------------------------------------------------------------- /prepare/download_pretrain_models.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | echo -e "The pretrained models will stored in the 'models' folder\n" 3 | mkdir -p models 4 | python -m gdown.cli "https://drive.google.com/uc?id=1n6kRb-d2gKsk8EXfFULFIpaUKYcnaYmm" 5 | 6 | echo -e "Please check that the md5sum is: 7b6d8814f9c1ca972f62852ebb6c7a6f" 7 | echo -e "+ md5sum tmr_models.tgz" 8 | md5sum tmr_models.tgz 9 | 10 | echo -e "If it is not, please rerun this script" 11 | 12 | sleep 5 13 | tar xfzv tmr_models.tgz 14 | 15 | echo -e "Cleaning\n" 16 | rm tmr_models.tgz 17 | 18 | echo -e "Downloading done!" 
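The script above only prints the checksum and asks the user to compare it by hand. A stricter variant, assuming the same archive name and expected checksum, could abort on mismatch before extracting:

expected="7b6d8814f9c1ca972f62852ebb6c7a6f"
actual=$(md5sum tmr_models.tgz | cut -d' ' -f1)
if [ "$actual" != "$expected" ]; then
    echo "Checksum mismatch ($actual), please re-download tmr_models.tgz"
    exit 1
fi
tar xzvf tmr_models.tgz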
19 | -------------------------------------------------------------------------------- /prepare/motion_stats.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import hydra 3 | from omegaconf import DictConfig 4 | from hydra.utils import instantiate 5 | from tqdm import tqdm 6 | 7 | logger = logging.getLogger(__name__) 8 | 9 | 10 | @hydra.main(config_path="../configs", config_name="motion_stats", version_base="1.3") 11 | def motion_stats(cfg: DictConfig): 12 | logger.info("Computing motion stats") 13 | import src.prepare # noqa 14 | 15 | train_dataset = instantiate(cfg.data, split="train") 16 | import torch 17 | 18 | feats = torch.cat([x["motion_x_dict"]["x"] for x in tqdm(train_dataset)]) 19 | mean = feats.mean(0) 20 | std = feats.std(0) 21 | 22 | normalizer = train_dataset.motion_loader.normalizer 23 | logger.info(f"Saving them in {normalizer.base_dir}") 24 | normalizer.save(mean, std) 25 | 26 | 27 | if __name__ == "__main__": 28 | motion_stats() 29 | -------------------------------------------------------------------------------- /prepare/text_embeddings.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import logging 3 | import hydra 4 | from omegaconf import DictConfig 5 | 6 | logger = logging.getLogger(__name__) 7 | 8 | 9 | @hydra.main(config_path="../configs", config_name="text_embeddings", version_base="1.3") 10 | def text_embeddings(cfg: DictConfig): 11 | device = cfg.device 12 | 13 | from src.data.text import save_token_embeddings, save_sent_embeddings 14 | 15 | annotations_filename = cfg.annotations_filename 16 | 17 | # Compute token embeddings 18 | modelname = cfg.data.text_to_token_emb.modelname 19 | modelpath = cfg.data.text_to_token_emb.modelpath 20 | logger.info(f"Compute token embeddings for {modelname}") 21 | path = cfg.data.text_to_token_emb.path 22 | output_folder_name = cfg.output_folder_name_token 23 | save_token_embeddings(path, annotations_filename=annotations_filename, 24 | output_folder_name=output_folder_name, 25 | modelname=modelname, modelpath=modelpath, 26 | device=device) 27 | 28 | # Compute sent embeddings 29 | modelname = cfg.data.text_to_sent_emb.modelname 30 | modelpath = cfg.data.text_to_sent_emb.modelpath 31 | logger.info(f"Compute sentence embeddings for {modelname}") 32 | path = cfg.data.text_to_sent_emb.path 33 | output_folder_name = cfg.output_folder_name_sent 34 | save_sent_embeddings(path, annotations_filename=annotations_filename, 35 | output_folder_name=output_folder_name, 36 | modelname=modelname, modelpath=modelpath, 37 | device=device) 38 | 39 | 40 | if __name__ == "__main__": 41 | text_embeddings() 42 | -------------------------------------------------------------------------------- /prepare/tools.py: -------------------------------------------------------------------------------- 1 | import os 2 | from glob import glob 3 | from tqdm import tqdm 4 | 5 | 6 | def loop_amass( 7 | base_folder, 8 | new_base_folder, 9 | ext=".npz", 10 | newext=".npz", 11 | force_redo=False, 12 | exclude=None, 13 | ): 14 | match_str = f"**/*{ext}" 15 | 16 | for motion_file in tqdm(glob(match_str, root_dir=base_folder, recursive=True)): 17 | if exclude and exclude in motion_file: 18 | continue 19 | 20 | motion_path = os.path.join(base_folder, motion_file) 21 | 22 | if motion_path.endswith("shape.npz"): 23 | continue 24 | 25 | new_motion_path = os.path.join( 26 | new_base_folder, motion_file.replace(ext, newext) 27 | ) 28 | if not force_redo and 
os.path.exists(new_motion_path): 29 | continue 30 | 31 | new_folder = os.path.split(new_motion_path)[0] 32 | os.makedirs(new_folder, exist_ok=True) 33 | 34 | yield motion_path, new_motion_path 35 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | aiohttp==3.8.5 2 | aiosignal==1.3.1 3 | antlr4-python3-runtime==4.9.3 4 | async-timeout==4.0.3 5 | attrs==23.1.0 6 | certifi==2023.7.22 7 | charset-normalizer==2.1.1 8 | cmake==3.25.0 9 | colorlog==6.7.0 10 | einops==0.6.1 11 | filelock==3.9.0 12 | frozenlist==1.4.0 13 | fsspec==2023.9.0 14 | hydra-colorlog==1.2.0 15 | hydra-core==1.3.2 16 | idna==3.4 17 | Jinja2==3.1.2 18 | lightning-utilities==0.9.0 19 | lit==15.0.7 20 | MarkupSafe==2.1.2 21 | mpmath==1.3.0 22 | multidict==6.0.4 23 | networkx==3.0 24 | numpy==1.24.1 25 | omegaconf==2.3.0 26 | orjson==3.9.7 27 | packaging==23.1 28 | Pillow==9.3.0 29 | pytorch-lightning==2.0.9 30 | PyYAML==6.0.1 31 | requests==2.31.0 32 | scipy==1.11.2 33 | sympy==1.11.1 34 | torchmetrics==1.1.2 35 | transformers==4.41.2 36 | tqdm==4.66.1 37 | triton==2.0.0 38 | typing_extensions==4.4.0 39 | urllib3==1.26.13 40 | yarl==1.9.2 41 | -------------------------------------------------------------------------------- /retrieval.py: -------------------------------------------------------------------------------- 1 | import os 2 | from omegaconf import DictConfig 3 | import logging 4 | import hydra 5 | import yaml 6 | from tqdm import tqdm 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | def save_metric(path, metrics): 12 | strings = yaml.dump(metrics, indent=4, sort_keys=False) 13 | with open(path, "w") as f: 14 | f.write(strings) 15 | 16 | 17 | def compute_sim_matrix(model, dataset, keyids, batch_size=256): 18 | import torch 19 | import numpy as np 20 | from src.data.collate import collate_text_motion 21 | from src.model.tmr import get_sim_matrix 22 | 23 | device = model.device 24 | 25 | nsplit = int(np.ceil(len(keyids) / min(batch_size, len(keyids)))) 26 | with torch.inference_mode(): 27 | all_data = [dataset.load_keyid(keyid) for keyid in keyids] 28 | all_data_splitted = np.array_split(all_data, nsplit) 29 | 30 | # by batch (can be too costly on cuda device otherwise) 31 | latent_texts = [] 32 | latent_motions = [] 33 | sent_embs = [] 34 | for data in tqdm(all_data_splitted, leave=True): 35 | batch = collate_text_motion(data, device=device) 36 | # Text is already encoded 37 | text_x_dict = batch["text_x_dict"] 38 | motion_x_dict = batch["motion_x_dict"] 39 | sent_emb = batch["sent_emb"] 40 | 41 | # Encode both motion and text 42 | latent_text = model.encode(text_x_dict, sample_mean=True) 43 | latent_motion = model.encode(motion_x_dict, sample_mean=True) 44 | 45 | latent_texts.append(latent_text) 46 | latent_motions.append(latent_motion) 47 | sent_embs.append(sent_emb) 48 | 49 | latent_texts = torch.cat(latent_texts) 50 | latent_motions = torch.cat(latent_motions) 51 | sent_embs = torch.cat(sent_embs) 52 | sim_matrix = get_sim_matrix(latent_texts, latent_motions) 53 | returned = { 54 | "sim_matrix": sim_matrix.cpu().numpy(), 55 | "sent_emb": sent_embs.cpu().numpy(), 56 | } 57 | return returned 58 | 59 | @hydra.main(version_base=None, config_path="configs", config_name="retrieval") 60 | def retrieval(newcfg: DictConfig) -> None: 61 | protocol = newcfg.protocol 62 | threshold_val = newcfg.threshold 63 | device = newcfg.device 64 | run_dir = newcfg.run_dir 65 | ckpt_name = 
newcfg.ckpt 66 | batch_size = newcfg.batch_size 67 | save_file_name = newcfg.save_file_name 68 | split = newcfg.split 69 | 70 | print("protocol : ", protocol) 71 | assert protocol in ["all", "normal", "threshold", "nsim", "guo", "normal_no_mirror", "threshold_no_mirror"] 72 | assert split == "test" or (protocol != "nsim" and protocol != "all") 73 | 74 | if protocol == "all": 75 | protocols = ["normal", "threshold", "nsim", "guo"] 76 | else: 77 | protocols = [protocol] 78 | 79 | save_dir = os.path.join(run_dir, save_file_name) 80 | os.makedirs(save_dir, exist_ok=True) 81 | 82 | # Load last config 83 | from src.config import read_config 84 | import src.prepare # noqa 85 | 86 | cfg = read_config(run_dir) 87 | 88 | import pytorch_lightning as pl 89 | import numpy as np 90 | from hydra.utils import instantiate 91 | from src.load import load_model_from_cfg 92 | from src.model.metrics import all_contrastive_metrics, print_latex_metrics 93 | 94 | pl.seed_everything(cfg.seed) 95 | 96 | logger.info("Loading the model") 97 | model = load_model_from_cfg(cfg, ckpt_name, eval_mode=True, device=device) 98 | 99 | 100 | data = newcfg.data 101 | if data is None: 102 | data = cfg.data 103 | 104 | datasets = {} 105 | results = {} 106 | for protocol in protocols: 107 | # Load the dataset if not already 108 | if protocol not in datasets: 109 | if protocol in ["normal", "threshold", "guo"]: 110 | dataset = instantiate(data, split=split) 111 | datasets.update( 112 | {key: dataset for key in ["normal", "threshold", "guo"]} 113 | ) 114 | elif protocol in ["normal_no_mirror", "threshold_no_mirror"]: 115 | datasets[protocol] = instantiate(data, split=split + "_no_mirror") 116 | elif protocol == "nsim": 117 | datasets[protocol] = instantiate(data, split="nsim_test") 118 | dataset = datasets[protocol] 119 | 120 | # Compute sim_matrix for each protocol 121 | if protocol not in results: 122 | if protocol in ["normal", "threshold"]: 123 | res = compute_sim_matrix( 124 | model, dataset, dataset.keyids, batch_size=batch_size 125 | ) 126 | results.update({key: res for key in ["normal", "threshold"]}) 127 | elif protocol in ["normal_no_mirror", "threshold_no_mirror"]: 128 | res = compute_sim_matrix( 129 | model, dataset, dataset.keyids, batch_size=batch_size 130 | ) 131 | results[protocol] = res 132 | elif protocol == "nsim": 133 | res = compute_sim_matrix( 134 | model, dataset, dataset.keyids, batch_size=batch_size 135 | ) 136 | results[protocol] = res 137 | elif protocol == "guo": 138 | keyids = sorted(dataset.keyids) 139 | N = len(keyids) 140 | 141 | # make batches of 32 142 | idx = np.arange(N) 143 | np.random.seed(0) 144 | np.random.shuffle(idx) 145 | idx_batches = [ 146 | idx[32 * i : 32 * (i + 1)] for i in range(len(keyids) // 32) 147 | ] 148 | 149 | # split into batches of 32 150 | # batched_keyids = [ [32], [32], [...]] 151 | results["guo"] = [ 152 | compute_sim_matrix( 153 | model, 154 | dataset, 155 | np.array(keyids)[idx_batch], 156 | batch_size=batch_size, 157 | ) 158 | for idx_batch in idx_batches 159 | ] 160 | result = results[protocol] 161 | 162 | # Compute the metrics 163 | if protocol == "guo": 164 | all_metrics = [] 165 | for x in result: 166 | sim_matrix = x["sim_matrix"] 167 | metrics = all_contrastive_metrics(sim_matrix, rounding=None) 168 | all_metrics.append(metrics) 169 | 170 | avg_metrics = {} 171 | for key in all_metrics[0].keys(): 172 | avg_metrics[key] = round( 173 | float(np.mean([metrics[key] for metrics in all_metrics])), 2 174 | ) 175 | 176 | metrics = avg_metrics 177 | protocol_name = 
protocol 178 | else: 179 | sim_matrix = result["sim_matrix"] 180 | 181 | protocol_name = protocol 182 | if protocol == "threshold": 183 | emb = result["sent_emb"] 184 | threshold = threshold_val 185 | protocol_name = protocol + f"_{threshold}" 186 | else: 187 | emb, threshold = None, None 188 | metrics = all_contrastive_metrics(sim_matrix, emb, threshold=threshold, t2m=True, m2t=False) 189 | 190 | print_latex_metrics(metrics, ranks=[1, 3, 10], t2m=True, m2t=False, MedR=False) 191 | 192 | print("protocol_name : ", protocol_name) 193 | metric_name = f"{protocol_name}.yaml" 194 | path = os.path.join(save_dir, metric_name) 195 | save_metric(path, metrics) 196 | 197 | logger.info(f"Testing done, metrics saved in:\n{path}") 198 | 199 | 200 | if __name__ == "__main__": 201 | retrieval() 202 | -------------------------------------------------------------------------------- /retrieval_action.py: -------------------------------------------------------------------------------- 1 | import os 2 | from omegaconf import DictConfig 3 | import logging 4 | import hydra 5 | import yaml 6 | from tqdm import tqdm 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | def save_metric(path, metrics): 12 | strings = yaml.dump(metrics, indent=4, sort_keys=False) 13 | with open(path, "w") as f: 14 | f.write(strings) 15 | 16 | 17 | def compute_sim_matrix(model, dataset, keyids, batch_size=256): 18 | import torch 19 | import numpy as np 20 | from src.data.collate import collate_text_motion 21 | from src.model.tmr import get_sim_matrix 22 | 23 | device = model.device 24 | 25 | nsplit = int(np.ceil(len(dataset) / batch_size)) 26 | with torch.inference_mode(): 27 | all_data = [dataset.load_keyid(keyid) for keyid in keyids] 28 | all_data_splitted = np.array_split(all_data, nsplit) 29 | 30 | # by batch (can be too costly on cuda device otherwise) 31 | latent_texts = [] 32 | latent_motions = [] 33 | sent_embs = [] 34 | 35 | for data in tqdm(all_data_splitted, leave=True): 36 | batch = collate_text_motion(data, device=device) 37 | 38 | # Text is already encoded 39 | text_x_dict = batch["text_x_dict"] 40 | motion_x_dict = batch["motion_x_dict"] 41 | sent_emb = batch["sent_emb"] 42 | 43 | # Encode both motion and text 44 | latent_text = model.encode(text_x_dict, sample_mean=True) 45 | latent_motion = model.encode(motion_x_dict, sample_mean=True) 46 | 47 | latent_texts.append(latent_text) 48 | latent_motions.append(latent_motion) 49 | sent_embs.append(sent_emb) 50 | 51 | latent_texts = torch.cat(latent_texts) 52 | action_latent_text = torch.unique(latent_texts, dim=0) 53 | 54 | action_latent_text_idx = {tuple(action_latent_text[i].to("cpu").numpy()): i for i in range(len(action_latent_text))} 55 | 56 | latent_motions = torch.cat(latent_motions) 57 | motion_cat_idx = [action_latent_text_idx[tuple(latent_texts[i].to("cpu").numpy())] for i in range(len(latent_motions))] 58 | 59 | #sent_embs = torch.cat(sent_embs) 60 | sim_matrix = get_sim_matrix(action_latent_text, latent_motions) 61 | returned = { 62 | "sim_matrix": sim_matrix.cpu().numpy(), 63 | "motion_cat_idx": motion_cat_idx 64 | } 65 | return returned 66 | 67 | @hydra.main(version_base=None, config_path="configs", config_name="retrieval") 68 | def retrieval(newcfg: DictConfig) -> None: 69 | device = newcfg.device 70 | run_dir = newcfg.run_dir 71 | ckpt_name = newcfg.ckpt 72 | batch_size = newcfg.batch_size 73 | save_file_name = newcfg.save_file_name 74 | split = newcfg.split 75 | 76 | assert split == "test" 77 | protocols = ["normal"] 78 | 79 | save_dir = 
os.path.join(run_dir, save_file_name) 80 | os.makedirs(save_dir, exist_ok=True) 81 | 82 | # Load last config 83 | from src.config import read_config 84 | import src.prepare # noqa 85 | 86 | cfg = read_config(run_dir) 87 | 88 | import pytorch_lightning as pl 89 | import numpy as np 90 | from hydra.utils import instantiate 91 | from src.load import load_model_from_cfg 92 | from src.model.metrics import all_contrastive_metrics_action_retrieval, print_latex_metrics 93 | 94 | pl.seed_everything(cfg.seed) 95 | 96 | logger.info("Loading the model") 97 | model = load_model_from_cfg(cfg, ckpt_name, eval_mode=True, device=device) 98 | 99 | 100 | data = newcfg.data 101 | if data is None: 102 | data = cfg.data 103 | 104 | datasets = {} 105 | for protocol in protocols: 106 | # Load the dataset if not already 107 | if protocol not in datasets: 108 | dataset = instantiate(data, split=split) 109 | datasets.update( 110 | {key: dataset for key in ["normal", "threshold", "guo"]} 111 | ) 112 | dataset = datasets[protocol] 113 | 114 | # Compute sim_matrix for each protocol 115 | protocol = "normal" 116 | result = compute_sim_matrix( 117 | model, dataset, dataset.keyids, batch_size=batch_size 118 | ) 119 | 120 | # Compute the metrics 121 | sim_matrix = result["sim_matrix"] 122 | motion_cat_idx = result["motion_cat_idx"] 123 | 124 | protocol_name = protocol 125 | metrics = all_contrastive_metrics_action_retrieval(sim_matrix, motion_cat_idx, norm_metrics=True) 126 | 127 | print_latex_metrics(metrics, ranks=[1, 2, 3, 5, 10], t2m=False, m2t=True, MedR=False) 128 | 129 | metric_name = f"{protocol_name}.yaml" 130 | path = os.path.join(save_dir, metric_name) 131 | save_metric(path, metrics) 132 | 133 | logger.info(f"Testing done, metrics saved in:\n{path}") 134 | 135 | 136 | if __name__ == "__main__": 137 | retrieval() 138 | -------------------------------------------------------------------------------- /retrieval_action_multi_labels.py: -------------------------------------------------------------------------------- 1 | import os 2 | from omegaconf import DictConfig 3 | import logging 4 | import hydra 5 | import yaml 6 | from tqdm import tqdm 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | def save_metric(path, metrics): 12 | strings = yaml.dump(metrics, indent=4, sort_keys=False) 13 | with open(path, "w") as f: 14 | f.write(strings) 15 | 16 | 17 | def compute_sim_matrix(model, dataset, keyids, batch_size=256): 18 | import torch 19 | import numpy as np 20 | from src.data.collate import collate_text_motion_multiple_texts 21 | from src.model.tmr import get_sim_matrix 22 | 23 | device = model.device 24 | 25 | nsplit = int(np.ceil(len(dataset) / batch_size)) 26 | with torch.inference_mode(): 27 | all_data = [dataset.load_keyid(keyid) for keyid in keyids] 28 | all_data_splitted = np.array_split(all_data, nsplit) 29 | 30 | # by batch (can be too costly on cuda device otherwise) 31 | latent_texts = [] 32 | latent_motions = [] 33 | sent_embs = [] 34 | text_indices = [] 35 | indices_shift = 0 36 | 37 | #for data in tqdm(all_data_splitted, leave=True): 38 | for data in all_data_splitted: 39 | 40 | batch = collate_text_motion_multiple_texts(data, device=device) 41 | # Text is already encoded 42 | text_x_dict = batch["text_x_dict"] 43 | motion_x_dict = batch["motion_x_dict"] 44 | sent_emb = batch["sent_emb"] 45 | 46 | # Encode both motion and text 47 | latent_text = model.encode(text_x_dict, sample_mean=True) 48 | latent_motion = model.encode(motion_x_dict, sample_mean=True) 49 | 50 | latent_texts.append(latent_text) 
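# Each dataset element can carry several text labels here, so the per-element
# text embeddings are flattened into one batch by the collate function.
# "text_slices" (used a few lines below) records, for every motion, the
# half-open [start, end) range of its texts inside that flattened batch, and
# "indices_shift" rebases those ranges across successive mini-batches.
# E.g. three motions with 2, 1 and 3 labels give slices [0, 2), [2, 3), [3, 6).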
51 | latent_motions.append(latent_motion) 52 | sent_embs.append(sent_emb) 53 | idx = batch["text_slices"] 54 | idx = [[elt[0] + indices_shift, elt[1] + indices_shift] for elt in idx] 55 | text_indices.extend(idx) 56 | indices_shift += len(latent_text) 57 | 58 | latent_texts = torch.cat(latent_texts) 59 | action_latent_text = torch.unique(latent_texts, dim=0) 60 | action_latent_text_idx = {tuple(action_latent_text[i].to("cpu").numpy()): i for i in range(len(action_latent_text))} 61 | text_cat_idx = [action_latent_text_idx[tuple(latent_texts[i].to("cpu").numpy())] for i in range(len(latent_texts))] 62 | 63 | latent_motions = torch.cat(latent_motions) 64 | motion_cat_idx = [] 65 | for start_ind, end_ind in text_indices: 66 | motion_cat_idx.append(text_cat_idx[start_ind:end_ind]) 67 | 68 | #sent_embs = torch.cat(sent_embs) 69 | sim_matrix = get_sim_matrix(action_latent_text, latent_motions) 70 | 71 | returned = { 72 | "sim_matrix": sim_matrix.cpu().numpy(), 73 | "motion_cat_idx": motion_cat_idx 74 | } 75 | return returned 76 | 77 | @hydra.main(version_base=None, config_path="configs", config_name="retrieval_action_multi_labels") 78 | def retrieval_action_multi_labels(newcfg: DictConfig) -> None: 79 | device = newcfg.device 80 | run_dir = newcfg.run_dir 81 | ckpt_name = newcfg.ckpt 82 | batch_size = newcfg.batch_size 83 | save_file_name = newcfg.save_file_name 84 | split = newcfg.split 85 | 86 | assert split == "test" 87 | 88 | save_dir = os.path.join(run_dir, save_file_name) 89 | os.makedirs(save_dir, exist_ok=True) 90 | 91 | # Load last config 92 | from src.config import read_config 93 | import src.prepare # noqa 94 | 95 | cfg = read_config(run_dir) 96 | 97 | import pytorch_lightning as pl 98 | import numpy as np 99 | from hydra.utils import instantiate 100 | from src.load import load_model_from_cfg 101 | from src.model.metrics import all_contrastive_metrics_action_retrieval_multi_labels, print_latex_metrics 102 | 103 | pl.seed_everything(cfg.seed) 104 | 105 | logger.info("Loading the model") 106 | model = load_model_from_cfg(cfg, ckpt_name, eval_mode=True, device=device) 107 | 108 | 109 | data = newcfg.data 110 | if data is None: 111 | data = cfg.data 112 | 113 | dataset = instantiate(data, split=split) 114 | 115 | # Compute sim_matrix for each protocol 116 | protocol = "normal" 117 | result = compute_sim_matrix( 118 | model, dataset, dataset.keyids, batch_size=batch_size 119 | ) 120 | 121 | # Compute the metrics 122 | sim_matrix = result["sim_matrix"] 123 | motion_cat_idx = result["motion_cat_idx"] 124 | 125 | protocol_name = protocol 126 | metrics = all_contrastive_metrics_action_retrieval_multi_labels(sim_matrix, motion_cat_idx, norm_metrics=True) 127 | 128 | print_latex_metrics(metrics, ranks=[1, 2, 3, 5, 10], t2m=False, m2t=True, MedR=False) 129 | 130 | metric_name = f"{protocol_name}.yaml" 131 | path = os.path.join(save_dir, metric_name) 132 | save_metric(path, metrics) 133 | 134 | logger.info(f"Testing done, metrics saved in:\n{path}") 135 | 136 | 137 | if __name__ == "__main__": 138 | retrieval_action_multi_labels() 139 | -------------------------------------------------------------------------------- /src/callback/__init__.py: -------------------------------------------------------------------------------- 1 | # from .render import RenderCallback 2 | from .progress import ProgressLogger 3 | -------------------------------------------------------------------------------- /src/callback/progress.py: -------------------------------------------------------------------------------- 1 | 
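# ProgressLogger prints one summary line per epoch. It assumes logged metric
# names follow the "{split}_{name}_epoch" convention (for example
# "train_loss_epoch" or "val_loss_epoch") and keeps a metric only if its name
# is a weighted loss term (in pl_module.lmd), the total "loss", or a t2m/m2t
# retrieval metric. Validation entries are prefixed with "v_", retrieval
# metrics are appended as percentages, so a typical line looks like:
#   Epoch 12: v_loss 1.45e-01 loss 1.23e-01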
import logging 2 | 3 | from pytorch_lightning import LightningModule, Trainer 4 | from pytorch_lightning.callbacks import Callback 5 | 6 | 7 | logger = logging.getLogger(__name__) 8 | 9 | 10 | class ProgressLogger(Callback): 11 | def __init__(self, precision: int = 2): 12 | self.precision = precision 13 | 14 | def on_train_start(self, trainer: Trainer, pl_module: LightningModule, **kwargs): 15 | logger.info("Training started") 16 | 17 | def on_train_end(self, trainer: Trainer, pl_module: LightningModule, **kwargs): 18 | logger.info("Training done") 19 | 20 | def on_validation_epoch_end( 21 | self, trainer: Trainer, pl_module: LightningModule, **kwargs 22 | ): 23 | if trainer.sanity_checking: 24 | logger.info("Sanity checking ok.") 25 | 26 | def on_train_epoch_end( 27 | self, trainer: Trainer, pl_module: LightningModule, **kwargs 28 | ): 29 | metric_format = f"{{:.{self.precision}e}}" 30 | line = f"Epoch {trainer.current_epoch}" 31 | metrics_str = [] 32 | 33 | losses_dict = trainer.callback_metrics 34 | 35 | def is_contrastive_metrics(x): 36 | return "t2m" in x or "m2t" in x 37 | 38 | losses_to_print = [ 39 | x 40 | for x in losses_dict.keys() 41 | for y in [x.split("_")] 42 | if len(y) == 3 43 | and y[2] == "epoch" 44 | and ( 45 | y[1] in pl_module.lmd or y[1] == "loss" or is_contrastive_metrics(y[1]) 46 | ) 47 | ] 48 | 49 | # Natual order for contrastive 50 | letters = "0123456789" 51 | mapping = str.maketrans(letters, letters[::-1]) 52 | 53 | def sort_losses(x): 54 | split, name, epoch_step = x.split("_") 55 | if is_contrastive_metrics(x): 56 | # put them at the end 57 | name = "a" + name.translate(mapping) 58 | return (name, split) 59 | 60 | losses_to_print = sorted(losses_to_print, key=sort_losses, reverse=True) 61 | for metric_name in losses_to_print: 62 | split, name, _ = metric_name.split("_") 63 | 64 | metric = losses_dict[metric_name].item() 65 | 66 | if is_contrastive_metrics(metric_name): 67 | if "len" in metric_name: 68 | metric = str(int(metric)) 69 | elif "MedR" in metric_name: 70 | metric = str(int(metric * 100) / 100) + "%" 71 | else: 72 | metric = str(int(metric * 100) / 100) + "%" 73 | else: 74 | metric = metric_format.format(metric) 75 | 76 | if split == "train": 77 | mname = name 78 | else: 79 | mname = f"v_{name}" 80 | 81 | metric = f"{mname} {metric}" 82 | metrics_str.append(metric) 83 | 84 | if len(metrics_str) == 0: 85 | return 86 | 87 | line = line + ": " + " ".join(metrics_str) 88 | logger.info(line) 89 | -------------------------------------------------------------------------------- /src/callback/tqdmbar.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import pytorch_lightning as pl 3 | from pytorch_lightning.callbacks import TQDMProgressBar as OriginalTQDMProgressBar 4 | 5 | 6 | def customize_bar(bar): 7 | if not sys.stdout.isatty(): 8 | bar.disable = True 9 | bar.leave = True # remove the bar after completion 10 | return bar 11 | 12 | 13 | class TQDMProgressBar(OriginalTQDMProgressBar): 14 | # remove the annoying v_num in the bar 15 | def get_metrics(self, trainer: pl.Trainer, pl_module: pl.LightningModule): 16 | items_dict = super().get_metrics(trainer, pl_module).copy() 17 | 18 | if "v_num" in items_dict: 19 | items_dict.pop("v_num") 20 | return items_dict 21 | 22 | def init_sanity_tqdm(self): 23 | bar = super().init_sanity_tqdm() 24 | return customize_bar(bar) 25 | 26 | def init_train_tqdm(self): 27 | bar = super().init_train_tqdm() 28 | return customize_bar(bar) 29 | 30 | def 
init_validation_tqdm(self): 31 | bar = super().init_validation_tqdm() 32 | bar.disable = True 33 | return bar 34 | 35 | def init_predict_tqdm(self): 36 | bar = super().init_predict_tqdm() 37 | return customize_bar(bar) 38 | 39 | def init_test_tqdm(self): 40 | bar = super().init_test_tqdm() 41 | return customize_bar(bar) 42 | -------------------------------------------------------------------------------- /src/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from omegaconf import DictConfig, OmegaConf 4 | 5 | 6 | def save_config(cfg: DictConfig) -> str: 7 | path = os.path.join(cfg.run_dir, "config.json") 8 | config = OmegaConf.to_container(cfg, resolve=True) 9 | with open(path, "w") as f: 10 | string = json.dumps(config, indent=4) 11 | f.write(string) 12 | return path 13 | 14 | 15 | def read_config(run_dir: str, return_json=False) -> DictConfig: 16 | path = os.path.join(run_dir, "config.json") 17 | with open(path, "r") as f: 18 | config = json.load(f) 19 | if return_json: 20 | return config 21 | cfg = OmegaConf.create(config) 22 | cfg.run_dir = run_dir 23 | return cfg 24 | -------------------------------------------------------------------------------- /src/data/augmented_text_motion.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import logging 3 | import numpy as np 4 | import random 5 | import torch 6 | from tqdm import tqdm 7 | 8 | from .collate import collate_text_motion_multiple_texts 9 | from .text_motion import load_annotations, TextMotionDataset 10 | 11 | 12 | class AugmentedTextMotionDataset(TextMotionDataset): 13 | def __init__( 14 | self, 15 | path: str, 16 | motion_loader, 17 | text_to_sent_emb, 18 | text_to_token_emb, 19 | split: str = "train", 20 | min_seconds: float = 2.0, 21 | max_seconds: float = 10.0, 22 | preload: bool = True, 23 | tiny: bool = False, 24 | paraphrase_filename: str = None, 25 | summary_filename: str = None, 26 | paraphrase_prob: float = 0.2, 27 | summary_prob: float = 0.2, 28 | averaging_prob: float = 0.4, 29 | text_sampling_nbr: int = 4 30 | ): 31 | super().__init__(path, motion_loader, text_to_sent_emb, text_to_token_emb, 32 | split=split, min_seconds=min_seconds, max_seconds=max_seconds, 33 | preload=False, tiny=tiny) 34 | 35 | self.collate_fn = collate_text_motion_multiple_texts 36 | 37 | assert paraphrase_prob == 0 or paraphrase_filename is not None 38 | assert summary_prob == 0 or summary_filename is not None 39 | 40 | self.text_sampling_nbr = text_sampling_nbr 41 | self.paraphrase_prob = 0 42 | if split=="train" and paraphrase_filename is not None: 43 | self.annotations_paraphrased = load_annotations(path, name=paraphrase_filename) 44 | self.paraphrase_prob = paraphrase_prob 45 | self.summary_prob = 0 46 | if split=="train" and summary_filename is not None: 47 | self.annotations_summary = load_annotations(path, name=summary_filename) 48 | self.summary_prob = summary_prob 49 | self.averaging_prob = 0 50 | if split=="train" and paraphrase_filename is not None: 51 | self.averaging_prob = averaging_prob 52 | 53 | # filter annotations (min/max) 54 | # but not for the test set 55 | # otherwise it is not fair for everyone 56 | if "test" not in split: 57 | if "train" in split and paraphrase_filename is not None: 58 | self.annotations_paraphrased = self.filter_annotations(self.annotations_paraphrased) 59 | if "train" in split and summary_filename is not None: 60 | self.annotations_summary = 
self.filter_annotations(self.annotations_summary) 61 | 62 | if preload: 63 | for _ in tqdm(self, desc="Preloading the dataset"): 64 | continue 65 | 66 | def load_keyid(self, keyid, text_idx=None, sent_emb_mode="first"): 67 | 68 | p = random.random() # Probability that will determine if we use data from augmentation, and with which config 69 | averaging = False 70 | if self.is_training and p < self.paraphrase_prob: 71 | annotations = self.annotations_paraphrased[keyid] 72 | elif self.is_training and p < self.summary_prob + self.paraphrase_prob: 73 | if keyid in self.annotations_summary: 74 | annotations = self.annotations_summary[keyid] 75 | else: 76 | annotations = self.annotations_paraphrased[keyid] # For Babel that has no summary 77 | elif self.is_training and p < self.averaging_prob + self.summary_prob + self.paraphrase_prob: 78 | annotations = copy.deepcopy(self.annotations[keyid]) 79 | if hasattr(self, "annotations_paraphrased") and keyid in self.annotations_paraphrased: 80 | annotations["annotations"] += self.annotations_paraphrased[keyid]["annotations"] 81 | if hasattr(self, "annotations_summary") and keyid in self.annotations_summary: 82 | annotations["annotations"] += self.annotations_summary[keyid]["annotations"] 83 | averaging = True 84 | else: 85 | annotations = self.annotations[keyid] 86 | 87 | # Take the first one for testing/validation 88 | # Otherwise take a random one 89 | index = 0 90 | if averaging: 91 | if isinstance(self.text_sampling_nbr, int): # If number of samples if provided 92 | n = min(self.text_sampling_nbr, len(annotations["annotations"])) 93 | else: # If number of sample not provided, it's chosen randomly 94 | n = random.randint(2, len(annotations["annotations"])) 95 | index = random.sample(range(0, len(annotations["annotations"])), n) 96 | elif text_idx is not None: 97 | index = text_idx % len(annotations["annotations"]) 98 | elif self.is_training: 99 | index = np.random.randint(len(annotations["annotations"])) 100 | 101 | if isinstance(index, int): 102 | index = [index] 103 | 104 | annotation_list = [annotations["annotations"][i] for i in index] 105 | text = [ann["text"] for ann in annotation_list] 106 | annotation0 = annotations["annotations"][index[0]] 107 | 108 | text_x_dict = [self.text_to_token_emb(t) for t in text] 109 | 110 | motion_x_dict = self.motion_loader( 111 | path=annotations["path"], 112 | start=annotation0["start"], 113 | end=annotation0["end"], 114 | ) 115 | 116 | if sent_emb_mode == "first": 117 | sent_emb = self.text_to_sent_emb(text[0]) 118 | elif sent_emb_mode == "average": 119 | sent_emb = torch.stack([self.text_to_sent_emb(t) for t in text]) 120 | sent_emb = torch.mean(sent_emb, axis=0) 121 | sent_emb = torch.nn.functional.normalize(sent_emb, dim=0) 122 | 123 | output = { 124 | "motion_x_dict": motion_x_dict, 125 | "text_x_dict": text_x_dict, 126 | "text": text, 127 | "keyid": keyid, 128 | "sent_emb": sent_emb, 129 | } 130 | 131 | # TODO 132 | #if device is not None: 133 | # output["motion_x_dict"]["x"] = output["motion_x_dict"]["x"].to(device) 134 | # for i in range(len(output["text_x_dict"][i])): 135 | # output["text_x_dict"][i]["x"] = output["text_x_dict"][i]["x"].to(device) 136 | 137 | return output 138 | -------------------------------------------------------------------------------- /src/data/collate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from typing import List, Dict, Optional 4 | from torch import Tensor 5 | from torch.utils.data import default_collate 6 | 
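A small standalone usage sketch of the padding and masking helpers defined just below (run from the repository root, with two dummy feature sequences of different lengths): sequences are zero-padded to the longest length and a boolean mask marks the valid frames.

import torch
from src.data.collate import collate_x_dict

x_dicts = [
    {"x": torch.randn(4, 263), "length": 4},
    {"x": torch.randn(7, 263), "length": 7},
]
batch = collate_x_dict(x_dicts)
assert batch["x"].shape == (2, 7, 263)   # zero-padded to the max length
assert batch["length"] == [4, 7]
assert batch["mask"].shape == (2, 7) and batch["mask"].dtype == torch.bool
assert batch["mask"][0].sum() == 4 and batch["mask"][1].sum() == 7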
7 | 8 | def length_to_mask(length, device: torch.device = None) -> Tensor: 9 | if device is None: 10 | device = "cpu" 11 | 12 | if isinstance(length, list): 13 | length = torch.tensor(length, device=device) 14 | 15 | max_len = max(length) 16 | mask = torch.arange(max_len, device=device).expand( 17 | len(length), max_len 18 | ) < length.unsqueeze(1) 19 | return mask 20 | 21 | 22 | def collate_tensor_with_padding(batch: List[Tensor]) -> Tensor: 23 | dims = batch[0].dim() 24 | max_size = [max([b.size(i) for b in batch]) for i in range(dims)] 25 | size = (len(batch),) + tuple(max_size) 26 | canvas = batch[0].new_zeros(size=size) 27 | for i, b in enumerate(batch): 28 | sub_tensor = canvas[i] 29 | for d in range(dims): 30 | sub_tensor = sub_tensor.narrow(d, 0, b.size(d)) 31 | sub_tensor.add_(b) 32 | return canvas 33 | 34 | 35 | def collate_x_dict(lst_x_dict: List, *, device: Optional[str] = None) -> Dict: 36 | x = collate_tensor_with_padding([x_dict["x"] for x_dict in lst_x_dict]) 37 | if device is not None: 38 | x = x.to(device) 39 | length = [x_dict["length"] for x_dict in lst_x_dict] 40 | mask = length_to_mask(length, device=x.device) 41 | batch = {"x": x, "length": length, "mask": mask} 42 | return batch 43 | 44 | 45 | def collate_text_motion(lst_elements: List, *, device: Optional[str] = None) -> Dict: 46 | one_el = lst_elements[0] 47 | keys = one_el.keys() 48 | 49 | x_dict_keys = [key for key in keys if "x_dict" in key] 50 | other_keys = [key for key in keys if "x_dict" not in key] 51 | 52 | batch = {key: default_collate([x[key] for x in lst_elements]) for key in other_keys} 53 | for key, val in batch.items(): 54 | if isinstance(val, torch.Tensor) and device is not None: 55 | batch[key] = val.to(device) 56 | 57 | for key in x_dict_keys: 58 | batch[key] = collate_x_dict([x[key] for x in lst_elements], device=device) 59 | return batch 60 | 61 | 62 | def collate_text_motion_multiple_texts(lst_elements: List, *, device: Optional[str] = None): 63 | other_keys = ['keyid', 'sent_emb'] 64 | 65 | batch = {key: default_collate([x[key] for x in lst_elements]) for key in other_keys} 66 | batch["text"] = [elt["text"] for elt in lst_elements] 67 | 68 | for key, val in batch.items(): 69 | if isinstance(val, torch.Tensor) and device is not None: 70 | batch[key] = val.to(device) 71 | 72 | batch["motion_x_dict"] = collate_x_dict([x["motion_x_dict"] for x in lst_elements], device=device) 73 | 74 | batch["text_slices"] = [] 75 | current_index = 0 76 | for elt in lst_elements: 77 | batch["text_slices"].append((current_index, current_index + len(elt["text"]))) 78 | current_index += len(elt["text"]) 79 | 80 | texts_concat = [x_dict for x in lst_elements for x_dict in x["text_x_dict"]] 81 | batch["text_x_dict"] = collate_x_dict( 82 | texts_concat, 83 | device=device 84 | ) 85 | return batch 86 | -------------------------------------------------------------------------------- /src/data/motion.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | 5 | 6 | class AMASSMotionLoader: 7 | def __init__( 8 | self, base_dir, fps, normalizer=None, disable: bool = False, nfeats=None 9 | ): 10 | self.fps = fps 11 | self.base_dir = base_dir 12 | self.motions = {} 13 | self.normalizer = normalizer 14 | self.disable = disable 15 | self.nfeats = nfeats 16 | 17 | def __call__(self, path, start, end): 18 | if self.disable: 19 | return {"x": path, "length": int(self.fps * (end - start))} 20 | 21 | begin = int(start * self.fps) 22 | end = int(end 
* self.fps) 23 | if path not in self.motions: 24 | motion_path = os.path.join(self.base_dir, path + ".npy") 25 | motion = np.load(motion_path) 26 | motion = torch.from_numpy(motion).to(torch.float) 27 | if self.normalizer is not None: 28 | motion = self.normalizer(motion) 29 | self.motions[path] = motion 30 | 31 | motion = self.motions[path][begin:end] 32 | x_dict = {"x": motion, "length": len(motion)} 33 | return x_dict 34 | 35 | 36 | class Normalizer: 37 | def __init__(self, base_dir: str, eps: float = 1e-12, disable: bool = False): 38 | self.base_dir = base_dir 39 | self.mean_path = os.path.join(base_dir, "mean.pt") 40 | self.std_path = os.path.join(base_dir, "std.pt") 41 | self.eps = eps 42 | 43 | self.disable = disable 44 | if not disable: 45 | self.load() 46 | 47 | def load(self): 48 | self.mean = torch.load(self.mean_path) 49 | self.std = torch.load(self.std_path) 50 | 51 | def save(self, mean, std): 52 | os.makedirs(self.base_dir, exist_ok=True) 53 | torch.save(mean, self.mean_path) 54 | torch.save(std, self.std_path) 55 | 56 | def __call__(self, x): 57 | if self.disable: 58 | return x 59 | x = (x - self.mean) / (self.std + self.eps) 60 | return x 61 | 62 | def inverse(self, x): 63 | if self.disable: 64 | return x 65 | x = x * (self.std + self.eps) + self.mean 66 | return x 67 | -------------------------------------------------------------------------------- /src/data/text.py: -------------------------------------------------------------------------------- 1 | import os 2 | import orjson 3 | import json 4 | import torch 5 | from torch import Tensor 6 | import numpy as np 7 | from tqdm import tqdm 8 | 9 | from abc import ABC, abstractmethod 10 | 11 | from src.model import TextToEmb 12 | 13 | 14 | class TextEmbeddings(ABC): 15 | name = ... 16 | 17 | def __init__( 18 | self, 19 | modelname: str, 20 | modelpath: str = None, 21 | path: str = "", 22 | device: str = "cpu", 23 | preload: bool = True, 24 | disable: bool = False, 25 | name: str = None # TODO check if needed keep this change compared to TMR 26 | ): 27 | if name is not None: 28 | self.name = name 29 | 30 | self.modelname = modelname 31 | self.modelpath = modelpath 32 | if modelpath is None: 33 | self.modelpath = modelname 34 | self.embeddings_folder = os.path.join(path, self.name) 35 | 36 | self.cache = {} 37 | self.device = device 38 | self.disable = disable 39 | 40 | if preload and not disable: 41 | self.load_embeddings() 42 | else: 43 | self.embeddings_index = {} 44 | 45 | @abstractmethod 46 | def load_model(self) -> None: 47 | ... 48 | 49 | @abstractmethod 50 | def load_embeddings(self) -> None: 51 | ... 52 | 53 | @abstractmethod 54 | def get_embedding(self, text: str) -> Tensor: 55 | ... 
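    # Resolution order used by __call__ below: (1) embeddings precomputed on disk
    # (load_embeddings / get_embedding), (2) embeddings already computed during this
    # session (self.cache), (3) otherwise the text model is lazily loaded and run on the fly.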
56 | 57 | def __contains__(self, text): 58 | return text in self.embeddings_index 59 | 60 | def get_model(self): 61 | model = getattr(self, "model", None) 62 | if model is None: 63 | model = self.load_model() 64 | return model 65 | 66 | def __call__(self, texts): 67 | if self.disable: 68 | return texts 69 | 70 | squeeze = False 71 | if isinstance(texts, str): 72 | texts = [texts] 73 | squeeze = True 74 | 75 | x_dict_lst = [] 76 | # one at a time here 77 | for text in texts: 78 | # Precomputed in advance 79 | if text in self: 80 | x_dict = self.get_embedding(text) 81 | # Already computed during the session 82 | elif text in self.cache: 83 | x_dict = self.cache[text] 84 | # Load the text model (if not already loaded) + compute on the fly 85 | else: 86 | model = self.get_model() 87 | x_dict = model(text) 88 | self.cache[text] = x_dict 89 | x_dict_lst.append(x_dict) 90 | 91 | if squeeze: 92 | return x_dict_lst[0] 93 | return x_dict_lst 94 | 95 | 96 | class TokenEmbeddings(TextEmbeddings): 97 | name = "token_embeddings" 98 | 99 | def load_model(self): 100 | self.model = TextToEmb(self.modelpath, mean_pooling=False, device=self.device) 101 | return self.model 102 | 103 | def load_embeddings(self): 104 | self.embeddings_big = torch.from_numpy( 105 | np.load(os.path.join(self.embeddings_folder, self.modelname + ".npy")) 106 | ).to(dtype=torch.float, device=self.device) 107 | self.embeddings_slice = np.load( 108 | os.path.join(self.embeddings_folder, self.modelname + "_slice.npy") 109 | ) 110 | self.embeddings_index = load_json( 111 | os.path.join(self.embeddings_folder, self.modelname + "_index.json") 112 | ) 113 | self.text_dim = self.embeddings_big.shape[-1] 114 | 115 | def get_embedding(self, text): 116 | # Precomputed in advance 117 | index = self.embeddings_index[text] 118 | begin, end = self.embeddings_slice[index] 119 | embedding = self.embeddings_big[begin:end] 120 | x_dict = {"x": embedding, "length": len(embedding)} 121 | return x_dict 122 | 123 | 124 | class SentenceEmbeddings(TextEmbeddings): 125 | name = "sent_embeddings" 126 | 127 | def load_model(self): 128 | self.model = TextToEmb(self.modelpath, mean_pooling=True, device=self.device) 129 | return self.model 130 | 131 | def load_embeddings(self): 132 | self.embeddings = torch.from_numpy( 133 | np.load(os.path.join(self.embeddings_folder, self.modelname + ".npy")) 134 | ).to(dtype=torch.float, device=self.device) 135 | self.embeddings_index = load_json( 136 | os.path.join(self.embeddings_folder, self.modelname + "_index.json") 137 | ) 138 | assert len(self.embeddings_index) == len(self.embeddings) 139 | 140 | self.text_dim = self.embeddings.shape[-1] 141 | 142 | def get_embedding(self, text): 143 | index = self.embeddings_index[text] 144 | embedding = self.embeddings[index] 145 | return embedding.to(self.device) 146 | 147 | 148 | def load_json(json_path): 149 | with open(json_path, "rb") as ff: 150 | return orjson.loads(ff.read()) 151 | 152 | 153 | def load_annotations(path, name="annotations.json"): 154 | json_path = os.path.join(path, name) 155 | return load_json(json_path) 156 | 157 | 158 | def write_json(data, path): 159 | with open(path, "w") as ff: 160 | ff.write(json.dumps(data, indent=4)) 161 | 162 | 163 | def save_token_embeddings( 164 | path, annotations_filename="annotations.json", output_folder_name=None, 165 | modelname="sentence-transformers/all-mpnet-base-v2", modelpath=None, device="cuda" 166 | ): 167 | if modelpath is None: 168 | modelpath = modelname 169 | model = TextToEmb(modelpath, device=device) 170 | 171 | 
annotations = load_annotations(path, name=annotations_filename) 172 | 173 | if output_folder_name is None: 174 | output_folder_name = TokenEmbeddings.name 175 | path = os.path.join(path, output_folder_name) 176 | 177 | ptpath = os.path.join(path, f"{modelname}.npy") 178 | slicepath = os.path.join(path, f"{modelname}_slice.npy") 179 | jsonpath = os.path.join(path, f"{modelname}_index.json") 180 | 181 | # modelname can have folders 182 | path = os.path.split(ptpath)[0] 183 | os.makedirs(path, exist_ok=True) 184 | 185 | # fetch all the texts 186 | all_texts = [] 187 | for dico in annotations.values(): 188 | for lst in dico["annotations"]: 189 | all_texts.append(lst["text"]) 190 | 191 | # remove duplicates 192 | all_texts = list(set(all_texts)) 193 | 194 | # batch of N/10 195 | nb_tokens = [] 196 | all_texts_batched = np.array_split(all_texts, min(len(all_texts), 100)) 197 | 198 | nb_tokens_so_far = 0 199 | big_tensor = [] 200 | index = [] 201 | for all_texts_batch in tqdm(all_texts_batched): 202 | x_dict = model(list(all_texts_batch)) 203 | 204 | tensor = x_dict["x"] 205 | nb_tokens = x_dict["length"] 206 | 207 | # remove padding 208 | tensor_no_padding = [x[:n].cpu() for x, n in zip(tensor, nb_tokens)] 209 | tensor_concat = torch.cat(tensor_no_padding) 210 | 211 | big_tensor.append(tensor_concat) 212 | # save where it is 213 | ends = torch.cumsum(nb_tokens, 0) 214 | begins = torch.cat((0 * ends[[0]], ends[:-1])) 215 | 216 | # offset 217 | ends += nb_tokens_so_far 218 | begins += nb_tokens_so_far 219 | nb_tokens_so_far += len(tensor_concat) 220 | 221 | index.append(torch.stack((begins, ends)).T) 222 | 223 | big_tensor = torch.cat(big_tensor).cpu().numpy() 224 | index = torch.cat(index).cpu().numpy() 225 | 226 | np.save(ptpath, big_tensor) 227 | np.save(slicepath, index) 228 | print(f"{ptpath} written") 229 | print(f"{slicepath} written") 230 | 231 | # correspondance 232 | dico = {txt: i for i, txt in enumerate(all_texts)} 233 | write_json(dico, jsonpath) 234 | print(f"{jsonpath} written") 235 | 236 | 237 | def save_sent_embeddings( 238 | path, annotations_filename="annotations.json", output_folder_name=None, 239 | modelname="sentence-transformers/all-mpnet-base-v2", modelpath=None, device="cuda" 240 | ): 241 | # Provide modelpath as a path to the local folder of the model if you can't access internet during training 242 | if modelpath is None: 243 | modelpath = modelname 244 | model = TextToEmb(modelpath, mean_pooling=True, device=device) 245 | annotations = load_annotations(path, name=annotations_filename) 246 | 247 | if output_folder_name is None: 248 | output_folder_name = SentenceEmbeddings.name 249 | path = os.path.join(path, output_folder_name) 250 | 251 | ptpath = os.path.join(path, f"{modelname}.npy") 252 | jsonpath = os.path.join(path, f"{modelname}_index.json") 253 | 254 | # modelname can have folders 255 | path = os.path.split(ptpath)[0] 256 | os.makedirs(path, exist_ok=True) 257 | 258 | # fetch all the texts 259 | all_texts = [] 260 | for dico in annotations.values(): 261 | for lst in dico["annotations"]: 262 | all_texts.append(lst["text"]) 263 | 264 | # remove duplicates 265 | all_texts = list(set(all_texts)) 266 | 267 | # batch of N/10 268 | all_texts_batched = np.array_split(all_texts, min(len(all_texts), 100)) 269 | embeddings = [] 270 | for all_texts_batch in tqdm(all_texts_batched): 271 | embedding = model(list(all_texts_batch)).cpu() 272 | embeddings.append(embedding) 273 | 274 | embeddings = torch.cat(embeddings).numpy() 275 | np.save(ptpath, embeddings) 276 | 
print(f"{ptpath} written") 277 | 278 | # correspondance 279 | dico = {txt: i for i, txt in enumerate(all_texts)} 280 | write_json(dico, jsonpath) 281 | print(f"{jsonpath} written") 282 | 283 | -------------------------------------------------------------------------------- /src/data/text_motion.py: -------------------------------------------------------------------------------- 1 | import os 2 | import codecs as cs 3 | import orjson # loading faster than json 4 | import json 5 | 6 | import numpy as np 7 | from torch.utils.data import Dataset 8 | from tqdm import tqdm 9 | 10 | from .collate import collate_text_motion 11 | 12 | 13 | def read_split(path, split): 14 | split_file = os.path.join(path, "splits", split + ".txt") 15 | id_list = [] 16 | with cs.open(split_file, "r") as f: 17 | for line in f.readlines(): 18 | id_list.append(line.strip()) 19 | return id_list 20 | 21 | 22 | def load_annotations(path, name="annotations.json"): 23 | json_path = os.path.join(path, name) 24 | with open(json_path, "rb") as ff: 25 | return orjson.loads(ff.read()) 26 | 27 | 28 | class TextMotionDataset(Dataset): 29 | def __init__( 30 | self, 31 | path: str, 32 | motion_loader, 33 | text_to_sent_emb, 34 | text_to_token_emb, 35 | split: str = "train", 36 | min_seconds: float = 2.0, 37 | max_seconds: float = 10.0, 38 | preload: bool = True, 39 | tiny: bool = False, 40 | ): 41 | if tiny: 42 | split = split + "_tiny" 43 | 44 | self.collate_fn = collate_text_motion 45 | self.split = split 46 | self.keyids = read_split(path, split) 47 | 48 | self.text_to_sent_emb = text_to_sent_emb 49 | self.text_to_token_emb = text_to_token_emb 50 | self.motion_loader = motion_loader 51 | 52 | self.min_seconds = min_seconds 53 | self.max_seconds = max_seconds 54 | 55 | # remove too short or too long annotations 56 | self.annotations = load_annotations(path) 57 | 58 | # filter annotations (min/max) 59 | # but not for the test set 60 | # otherwise it is not fair for everyone 61 | if "test" not in split: 62 | self.annotations = self.filter_annotations(self.annotations) 63 | 64 | self.is_training = "train" in split 65 | self.keyids = [keyid for keyid in self.keyids if keyid in self.annotations] 66 | self.nfeats = self.motion_loader.nfeats 67 | 68 | if preload: 69 | for _ in tqdm(self, desc="Preloading the dataset"): 70 | continue 71 | 72 | def __len__(self): 73 | return len(self.keyids) 74 | 75 | def __getitem__(self, index): 76 | keyid = self.keyids[index] 77 | return self.load_keyid(keyid) 78 | 79 | def load_keyid(self, keyid): 80 | annotations = self.annotations[keyid] 81 | 82 | # Take the first one for testing/validation 83 | # Otherwise take a random one 84 | index = 0 85 | if self.is_training: 86 | index = np.random.randint(len(annotations["annotations"])) 87 | annotation = annotations["annotations"][index] 88 | 89 | text = annotation["text"] 90 | text_x_dict = self.text_to_token_emb(text) 91 | motion_x_dict = self.motion_loader( 92 | path=annotations["path"], 93 | start=annotation["start"], 94 | end=annotation["end"], 95 | ) 96 | sent_emb = self.text_to_sent_emb(text) 97 | 98 | output = { 99 | "motion_x_dict": motion_x_dict, 100 | "text_x_dict": text_x_dict, 101 | "text": text, 102 | "keyid": keyid, 103 | "sent_emb": sent_emb, 104 | } 105 | return output 106 | 107 | def filter_annotations(self, annotations): 108 | filtered_annotations = {} 109 | for key, val in annotations.items(): 110 | annots = val.pop("annotations") 111 | filtered_annots = [] 112 | for annot in annots: 113 | duration = annot["end"] - annot["start"] 114 | if 
self.max_seconds >= duration >= self.min_seconds: 115 | filtered_annots.append(annot) 116 | 117 | if filtered_annots: 118 | val["annotations"] = filtered_annots 119 | filtered_annotations[key] = val 120 | 121 | return filtered_annotations 122 | 123 | 124 | def write_json(data, path): 125 | with open(path, "w") as ff: 126 | ff.write(json.dumps(data, indent=4)) 127 | -------------------------------------------------------------------------------- /src/data/text_motion_multi_labels.py: -------------------------------------------------------------------------------- 1 | import os 2 | import codecs as cs 3 | import orjson # loading faster than json 4 | import json 5 | import logging 6 | import random 7 | 8 | import torch 9 | import numpy as np 10 | from .text_motion import TextMotionDataset 11 | from tqdm import tqdm 12 | 13 | from .collate import collate_text_motion_multiple_texts 14 | 15 | 16 | 17 | def read_split(path, split): 18 | split_file = os.path.join(path, "splits", split + ".txt") 19 | id_list = [] 20 | with cs.open(split_file, "r") as f: 21 | for line in f.readlines(): 22 | id_list.append(line.strip()) 23 | return id_list 24 | 25 | 26 | def load_annotations(path, name="annotations.json"): 27 | json_path = os.path.join(path, name) 28 | with open(json_path, "rb") as ff: 29 | return orjson.loads(ff.read()) 30 | 31 | 32 | class TextMotionMultiLabelsDataset(TextMotionDataset): 33 | def __init__( 34 | self, 35 | path: str, 36 | motion_loader, 37 | text_to_sent_emb, 38 | text_to_token_emb, 39 | split: str = "train", 40 | min_seconds: float = 2.0, 41 | max_seconds: float = 10.0, 42 | preload: bool = True, 43 | tiny: bool = False, 44 | ): 45 | if tiny: 46 | split = split + "_tiny" 47 | 48 | self.collate_fn = collate_text_motion_multiple_texts 49 | self.split = split 50 | self.keyids = read_split(path, split) 51 | 52 | self.text_to_sent_emb = text_to_sent_emb 53 | self.text_to_token_emb = text_to_token_emb 54 | self.motion_loader = motion_loader 55 | 56 | self.min_seconds = min_seconds 57 | self.max_seconds = max_seconds 58 | 59 | # remove too short or too long annotations 60 | self.annotations = load_annotations(path) 61 | if "test" not in split: 62 | self.annotations = self.filter_annotations(self.annotations) 63 | 64 | self.is_training = "train" in split 65 | self.keyids = [keyid for keyid in self.keyids if keyid in self.annotations] 66 | 67 | self.nfeats = self.motion_loader.nfeats 68 | 69 | if preload: 70 | for _ in tqdm(self, desc="Preloading the dataset"): 71 | continue 72 | 73 | def load_keyid(self, keyid, device=None, text_idx=None, sent_emb_mode="first"): 74 | annotations = self.annotations[keyid] 75 | 76 | index = 0 77 | 78 | path = annotations["path"] 79 | annotation = annotations["annotations"][index] 80 | start = annotation["start"] 81 | end = annotation["end"] 82 | 83 | texts = [ann["text"] for ann in annotations["annotations"]] 84 | 85 | text_x_dicts = self.text_to_token_emb(texts) # [{"x": ..., "length": ...}, {"x": ..., "length"}: ..., ... 
] 86 | motion_x_dict = self.motion_loader( 87 | path=path, 88 | start=start, 89 | end=end, 90 | ) 91 | 92 | if sent_emb_mode == "first": 93 | sent_emb = self.text_to_sent_emb(texts[0]) 94 | elif sent_emb_mode == "average": 95 | sent_emb = torch.stack([self.text_to_sent_emb(text) for text in texts]) 96 | sent_emb = torch.mean(sent_emb, axis=0) 97 | sent_emb = torch.nn.functional.normalize(sent_emb, dim=0) 98 | 99 | output = { 100 | "motion_x_dict": motion_x_dict, 101 | "text_x_dict": text_x_dicts, 102 | "text": texts, 103 | "keyid": keyid, 104 | "sent_emb": sent_emb, 105 | } 106 | 107 | if device is not None: 108 | output["motion_x_dict"]["x"] = output["motion_x_dict"]["x"].to(device) 109 | for text_x_dict in output["text_x_dict"]: 110 | text_x_dict["x"] = text_x_dict["x"].to(device) 111 | 112 | return output 113 | 114 | 115 | def write_json(data, path): 116 | with open(path, "w") as ff: 117 | ff.write(json.dumps(data, indent=4)) 118 | 119 | -------------------------------------------------------------------------------- /src/guofeats/__init__.py: -------------------------------------------------------------------------------- 1 | from .motion_representation import joints_to_guofeats, guofeats_to_joints # noqa 2 | -------------------------------------------------------------------------------- /src/guofeats/paramUtil.py: -------------------------------------------------------------------------------- 1 | # Taken from 2 | # https://github.com/EricGuo5513/HumanML3D/blob/main/paramUtil.py 3 | 4 | import numpy as np 5 | 6 | # Define a kinematic tree for the skeletal struture 7 | kit_kinematic_chain = [ 8 | [0, 11, 12, 13, 14, 15], 9 | [0, 16, 17, 18, 19, 20], 10 | [0, 1, 2, 3, 4], 11 | [3, 5, 6, 7], 12 | [3, 8, 9, 10], 13 | ] 14 | 15 | kit_raw_offsets = np.array( 16 | [ 17 | [0, 0, 0], 18 | [0, 1, 0], 19 | [0, 1, 0], 20 | [0, 1, 0], 21 | [0, 1, 0], 22 | [1, 0, 0], 23 | [0, -1, 0], 24 | [0, -1, 0], 25 | [-1, 0, 0], 26 | [0, -1, 0], 27 | [0, -1, 0], 28 | [1, 0, 0], 29 | [0, -1, 0], 30 | [0, -1, 0], 31 | [0, 0, 1], 32 | [0, 0, 1], 33 | [-1, 0, 0], 34 | [0, -1, 0], 35 | [0, -1, 0], 36 | [0, 0, 1], 37 | [0, 0, 1], 38 | ] 39 | ) 40 | 41 | t2m_raw_offsets = np.array( 42 | [ 43 | [0, 0, 0], 44 | [1, 0, 0], 45 | [-1, 0, 0], 46 | [0, 1, 0], 47 | [0, -1, 0], 48 | [0, -1, 0], 49 | [0, 1, 0], 50 | [0, -1, 0], 51 | [0, -1, 0], 52 | [0, 1, 0], 53 | [0, 0, 1], 54 | [0, 0, 1], 55 | [0, 1, 0], 56 | [1, 0, 0], 57 | [-1, 0, 0], 58 | [0, 0, 1], 59 | [0, -1, 0], 60 | [0, -1, 0], 61 | [0, -1, 0], 62 | [0, -1, 0], 63 | [0, -1, 0], 64 | [0, -1, 0], 65 | ] 66 | ) 67 | 68 | t2m_kinematic_chain = [ 69 | [0, 2, 5, 8, 11], 70 | [0, 1, 4, 7, 10], 71 | [0, 3, 6, 9, 12, 15], 72 | [9, 14, 17, 19, 21], 73 | [9, 13, 16, 18, 20], 74 | ] 75 | t2m_left_hand_chain = [ 76 | [20, 22, 23, 24], 77 | [20, 34, 35, 36], 78 | [20, 25, 26, 27], 79 | [20, 31, 32, 33], 80 | [20, 28, 29, 30], 81 | ] 82 | t2m_right_hand_chain = [ 83 | [21, 43, 44, 45], 84 | [21, 46, 47, 48], 85 | [21, 40, 41, 42], 86 | [21, 37, 38, 39], 87 | [21, 49, 50, 51], 88 | ] 89 | 90 | 91 | kit_tgt_skel_id = "03950" 92 | t2m_tgt_skel_id = "000021" 93 | -------------------------------------------------------------------------------- /src/guofeats/skeleton_example_h3d.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leorebensabath/TMRPlusPlus/23e4deab9c4529bd22e3d40942a9a9a9d7a26db2/src/guofeats/skeleton_example_h3d.npy -------------------------------------------------------------------------------- 
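# Note on the tables above: the t2m_* offsets and kinematic chains index the 22-joint
# HumanML3D skeleton (the same order as "guoh3djoints" in src/joints.py below:
# 0 = pelvis, ..., 21 = right_wrist), while the kit_* variants describe the 21-joint
# KIT-ML skeleton.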
/src/joints.py: -------------------------------------------------------------------------------- 1 | JOINT_NAMES = { 2 | "smpljoints": [ 3 | "pelvis", 4 | "left_hip", 5 | "right_hip", 6 | "spine1", 7 | "left_knee", 8 | "right_knee", 9 | "spine2", 10 | "left_ankle", 11 | "right_ankle", 12 | "spine3", 13 | "left_foot", 14 | "right_foot", 15 | "neck", 16 | "left_collar", 17 | "right_collar", 18 | "head", 19 | "left_shoulder", 20 | "right_shoulder", 21 | "left_elbow", 22 | "right_elbow", 23 | "left_wrist", 24 | "right_wrist", 25 | "left_hand", 26 | "right_hand", 27 | ], 28 | "guoh3djoints": [ 29 | "pelvis", 30 | "left_hip", 31 | "right_hip", 32 | "spine1", 33 | "left_knee", 34 | "right_knee", 35 | "spine2", 36 | "left_ankle", 37 | "right_ankle", 38 | "spine3", 39 | "left_foot", 40 | "right_foot", 41 | "neck", 42 | "left_collar", 43 | "right_collar", 44 | "head", 45 | "left_shoulder", 46 | "right_shoulder", 47 | "left_elbow", 48 | "right_elbow", 49 | "left_wrist", 50 | "right_wrist", 51 | ], 52 | } 53 | 54 | INFOS = { 55 | "smpljoints": { 56 | "LM": JOINT_NAMES["smpljoints"].index("left_ankle"), 57 | "RM": JOINT_NAMES["smpljoints"].index("right_ankle"), 58 | "LF": JOINT_NAMES["smpljoints"].index("left_foot"), 59 | "RF": JOINT_NAMES["smpljoints"].index("right_foot"), 60 | "LS": JOINT_NAMES["smpljoints"].index("left_shoulder"), 61 | "RS": JOINT_NAMES["smpljoints"].index("right_shoulder"), 62 | "LH": JOINT_NAMES["smpljoints"].index("left_hip"), 63 | "RH": JOINT_NAMES["smpljoints"].index("right_hip"), 64 | "njoints": len(JOINT_NAMES["smpljoints"]), 65 | }, 66 | "guoh3djoints": { 67 | "LM": JOINT_NAMES["guoh3djoints"].index("left_ankle"), 68 | "RM": JOINT_NAMES["guoh3djoints"].index("right_ankle"), 69 | "LF": JOINT_NAMES["guoh3djoints"].index("left_foot"), 70 | "RF": JOINT_NAMES["guoh3djoints"].index("right_foot"), 71 | "LS": JOINT_NAMES["guoh3djoints"].index("left_shoulder"), 72 | "RS": JOINT_NAMES["guoh3djoints"].index("right_shoulder"), 73 | "LH": JOINT_NAMES["guoh3djoints"].index("left_hip"), 74 | "RH": JOINT_NAMES["guoh3djoints"].index("right_hip"), 75 | "njoints": len(JOINT_NAMES["guoh3djoints"]), 76 | }, 77 | } 78 | -------------------------------------------------------------------------------- /src/load.py: -------------------------------------------------------------------------------- 1 | import os 2 | from omegaconf import DictConfig 3 | import logging 4 | import hydra 5 | 6 | from src.config import read_config 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | # split the lightning checkpoint into 12 | # seperate state_dict modules for faster loading 13 | def extract_ckpt(run_dir, ckpt_name="last"): 14 | import torch 15 | 16 | ckpt_path = os.path.join(run_dir, f"logs/checkpoints/{ckpt_name}.ckpt") 17 | 18 | extracted_path = os.path.join(run_dir, f"{ckpt_name}_weights") 19 | os.makedirs(extracted_path, exist_ok=True) 20 | 21 | new_path_template = os.path.join(extracted_path, "{}.pt") 22 | ckpt_dict = torch.load(ckpt_path) 23 | state_dict = ckpt_dict["state_dict"] 24 | module_names = list(set([x.split(".")[0] for x in state_dict.keys()])) 25 | 26 | # should be ['motion_encoder', 'text_encoder', 'motion_decoder'] for example 27 | for module_name in module_names: 28 | path = new_path_template.format(module_name) 29 | sub_state_dict = { 30 | ".".join(x.split(".")[1:]): y.cpu() 31 | for x, y in state_dict.items() 32 | if x.split(".")[0] == module_name 33 | } 34 | torch.save(sub_state_dict, path) 35 | 36 | 37 | def load_model(run_dir, **params): 38 | # Load last config 39 | cfg = 
read_config(run_dir) 40 | cfg.run_dir = run_dir 41 | return load_model_from_cfg(cfg, **params) 42 | 43 | 44 | def load_model_from_cfg(cfg, ckpt_name="last", device="cpu", eval_mode=True): 45 | import src.prepare # noqa 46 | import torch 47 | 48 | run_dir = cfg.run_dir 49 | model = hydra.utils.instantiate(cfg.model) 50 | 51 | # Loading modules one by one 52 | # motion_encoder / text_encoder / text_decoder 53 | pt_path = os.path.join(run_dir, f"{ckpt_name}_weights") 54 | 55 | if not os.path.exists(pt_path): 56 | logger.info("The extracted model is not found. Split into submodules..") 57 | extract_ckpt(run_dir, ckpt_name) 58 | 59 | for fname in os.listdir(pt_path): 60 | module_name, ext = os.path.splitext(fname) 61 | 62 | if ext != ".pt": 63 | continue 64 | 65 | module = getattr(model, module_name, None) 66 | if module is None: 67 | continue 68 | 69 | module_path = os.path.join(pt_path, fname) 70 | state_dict = torch.load(module_path) 71 | module.load_state_dict(state_dict) 72 | logger.info(f" {module_name} loaded") 73 | 74 | logger.info("Loading previous checkpoint done") 75 | model = model.to(device) 76 | logger.info(f"Put the model on {device}") 77 | if eval_mode: 78 | model = model.eval() 79 | logger.info("Put the model in eval mode") 80 | return model 81 | 82 | 83 | @hydra.main(version_base=None, config_path="../configs", config_name="load_model") 84 | def hydra_load_model(cfg: DictConfig) -> None: 85 | run_dir = cfg.run_dir 86 | ckpt_name = cfg.ckpt 87 | device = cfg.device 88 | eval_mode = cfg.eval_mode 89 | return load_model(run_dir, ckpt_name, device, eval_mode) 90 | 91 | 92 | if __name__ == "__main__": 93 | hydra_load_model() 94 | -------------------------------------------------------------------------------- /src/logger/csv.py: -------------------------------------------------------------------------------- 1 | # from pytorch_lightning/loggers/csv_logs.py 2 | # of pytorch_lightning version 2.04 3 | 4 | # Copyright The Lightning AI team. 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 
17 | """ 18 | CSV logger 19 | ---------- 20 | 21 | CSV logger for basic experiment logging that does not require opening ports 22 | 23 | """ 24 | import logging 25 | import os 26 | from argparse import Namespace 27 | from typing import Any, Dict, Optional, Union 28 | 29 | # from lightning_fabric.loggers.csv_logs import _ExperimentWriter as _FabricExperimentWriter 30 | # from lightning_fabric.loggers.csv_logs import CSVLogger as FabricCSVLogger 31 | # Local replacement 32 | from .csv_fabric import _ExperimentWriter as _FabricExperimentWriter 33 | from .csv_fabric import CSVLogger as FabricCSVLogger 34 | 35 | from lightning_fabric.loggers.logger import rank_zero_experiment 36 | from lightning_fabric.utilities.logger import _convert_params 37 | from lightning_fabric.utilities.types import _PATH 38 | from pytorch_lightning.core.saving import save_hparams_to_yaml 39 | from pytorch_lightning.loggers.logger import Logger 40 | from pytorch_lightning.utilities.rank_zero import rank_zero_only 41 | 42 | log = logging.getLogger(__name__) 43 | 44 | 45 | class ExperimentWriter(_FabricExperimentWriter): 46 | r"""Experiment writer for CSVLogger. 47 | 48 | Currently, supports to log hyperparameters and metrics in YAML and CSV 49 | format, respectively. 50 | 51 | Args: 52 | log_dir: Directory for the experiment logs 53 | """ 54 | 55 | NAME_HPARAMS_FILE = "hparams.yaml" 56 | 57 | def __init__(self, log_dir: str) -> None: 58 | super().__init__(log_dir=log_dir) 59 | self.hparams: Dict[str, Any] = {} 60 | 61 | def log_hparams(self, params: Dict[str, Any]) -> None: 62 | """Record hparams.""" 63 | self.hparams.update(params) 64 | 65 | def save(self) -> None: 66 | """Save recorded hparams and metrics into files.""" 67 | hparams_file = os.path.join(self.log_dir, self.NAME_HPARAMS_FILE) 68 | save_hparams_to_yaml(hparams_file, self.hparams) 69 | return super().save() 70 | 71 | 72 | class CSVLogger(Logger, FabricCSVLogger): 73 | r"""Log to local file system in yaml and CSV format. 74 | 75 | Logs are saved to ``os.path.join(save_dir, name)``. 76 | 77 | Example: 78 | >>> from pytorch_lightning import Trainer 79 | >>> from pytorch_lightning.loggers import CSVLogger 80 | >>> logger = CSVLogger("logs", name="my_exp_name") 81 | >>> trainer = Trainer(logger=logger) 82 | 83 | Args: 84 | save_dir: Save directory 85 | name: Experiment name. Defaults to ``'lightning_logs'``. 86 | prefix: A string to put at the beginning of metric keys. 87 | flush_logs_every_n_steps: How often to flush logs to disk (defaults to every 100 steps). 88 | """ 89 | 90 | LOGGER_JOIN_CHAR = "-" 91 | 92 | def __init__( 93 | self, 94 | save_dir: _PATH, 95 | name: str = "lightning_logs", 96 | prefix: str = "", 97 | flush_logs_every_n_steps: int = 100, 98 | ): 99 | super().__init__( 100 | root_dir=save_dir, 101 | name=name, 102 | prefix=prefix, 103 | flush_logs_every_n_steps=flush_logs_every_n_steps, 104 | ) 105 | self._save_dir = os.fspath(save_dir) 106 | 107 | @property 108 | def root_dir(self) -> str: 109 | """Parent directory for all checkpoint subdirectories. 110 | 111 | If the experiment name parameter is an empty string, no experiment subdirectory is used and the checkpoint will 112 | be saved in "save_dir/" 113 | """ 114 | return os.path.join(self.save_dir, self.name) 115 | 116 | @property 117 | def log_dir(self) -> str: 118 | """The log directory for this run.""" 119 | return self.root_dir 120 | 121 | @property 122 | def save_dir(self) -> str: 123 | """The current directory where logs are saved. 
124 | 125 | Returns: 126 | The path to current directory where logs are saved. 127 | """ 128 | return self._save_dir 129 | 130 | @rank_zero_only 131 | def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: 132 | # don't log hyperparameters 133 | # already done in the config 134 | return 135 | 136 | @property 137 | @rank_zero_experiment 138 | def experiment(self) -> _FabricExperimentWriter: 139 | r""" 140 | 141 | Actual _ExperimentWriter object. To use _ExperimentWriter features in your 142 | :class:`~pytorch_lightning.core.module.LightningModule` do the following. 143 | 144 | Example:: 145 | 146 | self.logger.experiment.some_experiment_writer_function() 147 | 148 | """ 149 | if self._experiment is not None: 150 | return self._experiment 151 | 152 | self._fs.makedirs(self.root_dir, exist_ok=True) 153 | self._experiment = ExperimentWriter(log_dir=self.log_dir) 154 | return self._experiment 155 | -------------------------------------------------------------------------------- /src/logger/csv_fabric.py: -------------------------------------------------------------------------------- 1 | # from lightning_fabric/loggers/csv_logs.py 2 | # of lightning_fabric version 2.04 3 | 4 | # Copyright The Lightning AI team. 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | 18 | import csv 19 | import logging 20 | import os 21 | from argparse import Namespace 22 | from typing import Any, Dict, List, Optional, Union 23 | 24 | from torch import Tensor 25 | 26 | from lightning_fabric.loggers.logger import Logger, rank_zero_experiment 27 | from lightning_fabric.utilities.cloud_io import get_filesystem 28 | from lightning_fabric.utilities.logger import _add_prefix 29 | from lightning_fabric.utilities.rank_zero import rank_zero_only, rank_zero_warn 30 | from lightning_fabric.utilities.types import _PATH 31 | 32 | log = logging.getLogger(__name__) 33 | 34 | 35 | class CSVLogger(Logger): 36 | r"""Log to the local file system in CSV format. 37 | 38 | Logs are saved to ``os.path.join(root_dir, name)``. 39 | 40 | Args: 41 | root_dir: The root directory in which all your experiments with different names and versions will be stored. 42 | name: Experiment name. Defaults to ``'lightning_logs'``. 43 | prefix: A string to put at the beginning of metric keys. 44 | flush_logs_every_n_steps: How often to flush logs to disk (defaults to every 100 steps). 
45 | 46 | Example:: 47 | 48 | from lightning_fabric.loggers import CSVLogger 49 | 50 | logger = CSVLogger("path/to/logs/root", name="my_model") 51 | logger.log_metrics({"loss": 0.235, "acc": 0.75}) 52 | logger.finalize("success") 53 | """ 54 | 55 | LOGGER_JOIN_CHAR = "-" 56 | 57 | def __init__( 58 | self, 59 | root_dir: _PATH, 60 | name: str = "lightning_logs", 61 | prefix: str = "", 62 | flush_logs_every_n_steps: int = 100, 63 | ): 64 | super().__init__() 65 | root_dir = os.fspath(root_dir) 66 | self._root_dir = root_dir 67 | self._name = name or "" 68 | self._prefix = prefix 69 | self._fs = get_filesystem(root_dir) 70 | self._experiment: Optional[_ExperimentWriter] = None 71 | self._flush_logs_every_n_steps = flush_logs_every_n_steps 72 | 73 | @property 74 | def name(self) -> str: 75 | """Gets the name of the experiment. 76 | 77 | Returns: 78 | The name of the experiment. 79 | """ 80 | return self._name 81 | 82 | @property 83 | def version(self) -> str: 84 | return "" 85 | 86 | @property 87 | def root_dir(self) -> str: 88 | """Gets the save directory where the versioned CSV experiments are saved.""" 89 | return self._root_dir 90 | 91 | @property 92 | def log_dir(self) -> str: 93 | """The log directory for this run.""" 94 | # create a pseudo standard path 95 | return os.path.join(self.root_dir, self.name) 96 | 97 | @property 98 | @rank_zero_experiment 99 | def experiment(self) -> "_ExperimentWriter": 100 | """Actual ExperimentWriter object. To use ExperimentWriter features anywhere in your code, do the 101 | following. 102 | 103 | Example:: 104 | 105 | self.logger.experiment.some_experiment_writer_function() 106 | """ 107 | if self._experiment is not None: 108 | return self._experiment 109 | 110 | os.makedirs(self.root_dir, exist_ok=True) 111 | self._experiment = _ExperimentWriter(log_dir=self.log_dir) 112 | return self._experiment 113 | 114 | @rank_zero_only 115 | def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: 116 | raise NotImplementedError( 117 | "The `CSVLogger` does not yet support logging hyperparameters." 118 | ) 119 | 120 | @rank_zero_only 121 | def log_metrics( 122 | self, metrics: Dict[str, Union[Tensor, float]], step: Optional[int] = None 123 | ) -> None: 124 | metrics = _add_prefix(metrics, self._prefix, self.LOGGER_JOIN_CHAR) 125 | self.experiment.log_metrics(metrics, step) 126 | if step is not None and (step + 1) % self._flush_logs_every_n_steps == 0: 127 | self.save() 128 | 129 | @rank_zero_only 130 | def save(self) -> None: 131 | super().save() 132 | self.experiment.save() 133 | 134 | @rank_zero_only 135 | def finalize(self, status: str) -> None: 136 | if self._experiment is None: 137 | # When using multiprocessing, finalize() should be a no-op on the main process, as no experiment has been 138 | # initialized there 139 | return 140 | self.save() 141 | 142 | 143 | class _ExperimentWriter: 144 | r"""Experiment writer for CSVLogger. 
145 | 146 | Args: 147 | log_dir: Directory for the experiment logs 148 | """ 149 | 150 | NAME_METRICS_FILE = "metrics.csv" 151 | 152 | def __init__(self, log_dir: str) -> None: 153 | self.metrics: List[Dict[str, float]] = [] 154 | 155 | self._fs = get_filesystem(log_dir) 156 | self.log_dir = log_dir 157 | self._fs.makedirs(self.log_dir, exist_ok=True) 158 | self.metrics_file_path = os.path.join(self.log_dir, self.NAME_METRICS_FILE) 159 | 160 | if self._fs.exists(self.log_dir) and self._fs.listdir(self.log_dir): 161 | # Read previous logs 162 | if os.path.exists(self.metrics_file_path): 163 | with self._fs.open(self.metrics_file_path, "r") as f: 164 | reader = csv.DictReader(f) 165 | self.metrics = [x for x in reader] 166 | 167 | def log_metrics( 168 | self, metrics_dict: Dict[str, float], step: Optional[int] = None 169 | ) -> None: 170 | """Record metrics.""" 171 | 172 | def _handle_value(value: Union[Tensor, Any]) -> Any: 173 | if isinstance(value, Tensor): 174 | return value.item() 175 | return value 176 | 177 | if step is None: 178 | step = len(self.metrics) 179 | 180 | metrics = {k: _handle_value(v) for k, v in metrics_dict.items()} 181 | metrics["step"] = step 182 | self.metrics.append(metrics) 183 | 184 | def save(self) -> None: 185 | """Save recorded metrics into files.""" 186 | if not self.metrics: 187 | return 188 | 189 | last_m = {} 190 | for m in self.metrics: 191 | last_m.update(m) 192 | metrics_keys = list(last_m.keys()) 193 | 194 | with self._fs.open(self.metrics_file_path, "w", newline="") as f: 195 | writer = csv.DictWriter(f, fieldnames=metrics_keys) 196 | writer.writeheader() 197 | writer.writerows(self.metrics) 198 | -------------------------------------------------------------------------------- /src/logging.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import tqdm 3 | 4 | 5 | # from https://stackoverflow.com/questions/38543506/change-logging-print-function-to-tqdm-write-so-logging-doesnt-interfere-wit 6 | class TqdmLoggingHandler(logging.Handler): 7 | def __init__(self, level=logging.NOTSET): 8 | super().__init__(level) 9 | 10 | def emit(self, record): 11 | try: 12 | msg = self.format(record) 13 | tqdm.tqdm.write(msg) 14 | self.flush() 15 | except Exception: 16 | self.handleError(record) 17 | -------------------------------------------------------------------------------- /src/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .actor import PositionalEncoding, ACTORStyleEncoder, ACTORStyleDecoder # noqa 2 | from .temos import TEMOS # noqa 3 | from .tmr import TMR # noqa 4 | from .text_encoder import TextToEmb # noqa 5 | -------------------------------------------------------------------------------- /src/model/actor.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch import Tensor 6 | import numpy as np 7 | 8 | from einops import repeat 9 | 10 | 11 | class PositionalEncoding(nn.Module): 12 | def __init__(self, d_model, dropout=0.1, max_len=5000, batch_first=False) -> None: 13 | super().__init__() 14 | self.batch_first = batch_first 15 | 16 | self.dropout = nn.Dropout(p=dropout) 17 | 18 | pe = torch.zeros(max_len, d_model) 19 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) 20 | div_term = torch.exp( 21 | torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model) 22 | ) 23 | pe[:, 0::2] = 
torch.sin(position * div_term) 24 | pe[:, 1::2] = torch.cos(position * div_term) 25 | pe = pe.unsqueeze(0).transpose(0, 1) 26 | self.register_buffer("pe", pe, persistent=False) 27 | 28 | def forward(self, x: Tensor) -> Tensor: 29 | if self.batch_first: 30 | x = x + self.pe.permute(1, 0, 2)[:, : x.shape[1], :] 31 | else: 32 | x = x + self.pe[: x.shape[0], :] 33 | return self.dropout(x) 34 | 35 | 36 | class ACTORStyleEncoder(nn.Module): 37 | # Similar to ACTOR but "action agnostic" and more general 38 | def __init__( 39 | self, 40 | nfeats: int, 41 | vae: bool, 42 | latent_dim: int = 256, 43 | ff_size: int = 1024, 44 | num_layers: int = 4, 45 | num_heads: int = 4, 46 | dropout: float = 0.1, 47 | activation: str = "gelu", 48 | ) -> None: 49 | super().__init__() 50 | 51 | self.nfeats = nfeats 52 | self.projection = nn.Linear(nfeats, latent_dim) 53 | 54 | self.vae = vae 55 | self.nbtokens = 2 if vae else 1 56 | self.tokens = nn.Parameter(torch.randn(self.nbtokens, latent_dim)) 57 | 58 | self.sequence_pos_encoding = PositionalEncoding( 59 | latent_dim, dropout=dropout, batch_first=True 60 | ) 61 | 62 | seq_trans_encoder_layer = nn.TransformerEncoderLayer( 63 | d_model=latent_dim, 64 | nhead=num_heads, 65 | dim_feedforward=ff_size, 66 | dropout=dropout, 67 | activation=activation, 68 | batch_first=True, 69 | ) 70 | 71 | self.seqTransEncoder = nn.TransformerEncoder( 72 | seq_trans_encoder_layer, num_layers=num_layers 73 | ) 74 | 75 | def forward(self, x_dict: Dict) -> Tensor: 76 | x = x_dict["x"] 77 | mask = x_dict["mask"] 78 | 79 | x = self.projection(x) 80 | 81 | device = x.device 82 | bs = len(x) 83 | 84 | tokens = repeat(self.tokens, "nbtoken dim -> bs nbtoken dim", bs=bs) 85 | xseq = torch.cat((tokens, x), 1) 86 | 87 | token_mask = torch.ones((bs, self.nbtokens), dtype=bool, device=device) 88 | aug_mask = torch.cat((token_mask, mask), 1) 89 | 90 | # add positional encoding 91 | xseq = self.sequence_pos_encoding(xseq) 92 | final = self.seqTransEncoder(xseq, src_key_padding_mask=~aug_mask) 93 | return final[:, : self.nbtokens] 94 | 95 | 96 | class ACTORStyleDecoder(nn.Module): 97 | # Similar to ACTOR Decoder 98 | 99 | def __init__( 100 | self, 101 | nfeats: int, 102 | latent_dim: int = 256, 103 | ff_size: int = 1024, 104 | num_layers: int = 4, 105 | num_heads: int = 4, 106 | dropout: float = 0.1, 107 | activation: str = "gelu", 108 | ) -> None: 109 | super().__init__() 110 | output_feats = nfeats 111 | self.nfeats = nfeats 112 | 113 | self.sequence_pos_encoding = PositionalEncoding( 114 | latent_dim, dropout, batch_first=True 115 | ) 116 | 117 | seq_trans_decoder_layer = nn.TransformerDecoderLayer( 118 | d_model=latent_dim, 119 | nhead=num_heads, 120 | dim_feedforward=ff_size, 121 | dropout=dropout, 122 | activation=activation, 123 | batch_first=True, 124 | ) 125 | 126 | self.seqTransDecoder = nn.TransformerDecoder( 127 | seq_trans_decoder_layer, num_layers=num_layers 128 | ) 129 | 130 | self.final_layer = nn.Linear(latent_dim, output_feats) 131 | 132 | def forward(self, z_dict: Dict) -> Tensor: 133 | z = z_dict["z"] 134 | mask = z_dict["mask"] 135 | 136 | latent_dim = z.shape[1] 137 | bs, nframes = mask.shape 138 | 139 | z = z[:, None] # sequence of 1 element for the memory 140 | 141 | # Construct time queries 142 | time_queries = torch.zeros(bs, nframes, latent_dim, device=z.device) 143 | time_queries = self.sequence_pos_encoding(time_queries) 144 | 145 | # Pass through the transformer decoder 146 | # with the latent vector for memory 147 | output = self.seqTransDecoder( 148 | 
tgt=time_queries, memory=z, tgt_key_padding_mask=~mask 149 | ) 150 | 151 | output = self.final_layer(output) 152 | # zero for padded area 153 | output[~mask] = 0 154 | return output 155 | -------------------------------------------------------------------------------- /src/model/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | # For reference 6 | # https://stats.stackexchange.com/questions/7440/kl-divergence-between-two-univariate-gaussians 7 | # https://pytorch.org/docs/stable/_modules/torch/distributions/kl.html#kl_divergence 8 | class KLLoss: 9 | def __call__(self, q, p): 10 | mu_q, logvar_q = q 11 | mu_p, logvar_p = p 12 | 13 | log_var_ratio = logvar_q - logvar_p 14 | t1 = (mu_p - mu_q).pow(2) / logvar_p.exp() 15 | div = 0.5 * (log_var_ratio.exp() + t1 - 1 - log_var_ratio) 16 | return div.mean() 17 | 18 | def __repr__(self): 19 | return "KLLoss()" 20 | 21 | 22 | def get_sim_matrix(x, y): 23 | x_logits = torch.nn.functional.normalize(x, dim=-1) 24 | y_logits = torch.nn.functional.normalize(y, dim=-1) 25 | sim_matrix = x_logits @ y_logits.T 26 | return sim_matrix 27 | 28 | 29 | class InfoNCE_with_filtering: 30 | def __init__(self, temperature=0.7, threshold_selfsim=0.8): 31 | self.temperature = temperature 32 | self.threshold_selfsim = threshold_selfsim 33 | 34 | def filter_sim_mat_with_sent_emb(self, sim_matrix, sent_emb): 35 | # put the threshold value between -1 and 1 36 | real_threshold_selfsim = 2 * self.threshold_selfsim - 1 37 | # Filtering too close values 38 | # mask them by putting -inf in the sim_matrix 39 | selfsim = sent_emb @ sent_emb.T 40 | selfsim_nodiag = selfsim - selfsim.diag().diag() 41 | idx = torch.where(selfsim_nodiag > real_threshold_selfsim) 42 | sim_matrix[idx] = -torch.inf 43 | return sim_matrix # TODO check if return necessary or in place operation 44 | 45 | def __call__(self, x, y, sent_emb=None): 46 | bs, device = len(x), x.device 47 | sim_matrix = get_sim_matrix(x, y) / self.temperature 48 | 49 | if sent_emb is not None and self.threshold_selfsim: 50 | sim_matrix = self.filter_sim_mat_with_sent_emb(sim_matrix, sent_emb) 51 | 52 | labels = torch.arange(bs, device=device) 53 | 54 | total_loss = ( 55 | F.cross_entropy(sim_matrix, labels) + F.cross_entropy(sim_matrix.T, labels) 56 | ) / 2 57 | 58 | return total_loss 59 | 60 | def __repr__(self): 61 | return f"Constrastive(temp={self.temp})" 62 | 63 | 64 | class HN_InfoNCE_with_filtering(InfoNCE_with_filtering): 65 | def __init__(self, temperature=0.7, threshold_selfsim=0.8, alpha=1.0, beta=0.25): 66 | super().__init__(temperature=temperature, threshold_selfsim=threshold_selfsim) 67 | self.alpha = alpha 68 | self.beta = beta 69 | 70 | def cross_entropy_with_HN_weights(self, sim_matrix): 71 | n = sim_matrix.shape[0] 72 | 73 | labels = range(sim_matrix.shape[0]) 74 | exp_mat = torch.exp(sim_matrix) 75 | num = exp_mat[range(exp_mat.shape[0]), labels] 76 | 77 | exp_mat_beta = torch.exp(self.beta * sim_matrix) 78 | weights = (n - 1) * exp_mat_beta / torch.unsqueeze((torch.sum(exp_mat_beta, axis=1) - exp_mat_beta.diag()), dim=1) 79 | weights = weights.fill_diagonal_(self.alpha) 80 | denum = torch.sum(weights * exp_mat, axis=1) 81 | 82 | return -torch.mean(torch.log(num/denum)) 83 | 84 | def __call__(self, x, y, sent_emb=None): 85 | bs, device = len(x), x.device 86 | sim_matrix = get_sim_matrix(x, y) / self.temperature 87 | 88 | if sent_emb is not None and self.threshold_selfsim: 89 | sim_matrix = 
self.filter_sim_mat_with_sent_emb(sim_matrix, sent_emb) 90 | 91 | total_loss = ( 92 | self.cross_entropy_with_HN_weights(sim_matrix) + self.cross_entropy_with_HN_weights(sim_matrix.T) 93 | ) / 2 94 | 95 | return total_loss 96 | 97 | def __repr__(self): 98 | return f"Constrastive(temp={self.temp})" 99 | -------------------------------------------------------------------------------- /src/model/temos.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict, Optional 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch import Tensor 6 | from pytorch_lightning import LightningModule 7 | 8 | from src.model.losses import KLLoss 9 | 10 | 11 | def length_to_mask(length: List[int], device: torch.device = None) -> Tensor: 12 | if device is None: 13 | device = "cpu" 14 | 15 | if isinstance(length, list): 16 | length = torch.tensor(length, device=device) 17 | 18 | max_len = max(length) 19 | mask = torch.arange(max_len, device=device).expand( 20 | len(length), max_len 21 | ) < length.unsqueeze(1) 22 | return mask 23 | 24 | 25 | class TEMOS(LightningModule): 26 | r"""TEMOS: Generating diverse human motions 27 | from textual descriptions 28 | Find more information about the model on the following website: 29 | https://mathis.petrovich.fr/temos 30 | 31 | Args: 32 | motion_encoder: a module to encode the input motion features in the latent space (required). 33 | text_encoder: a module to encode the text embeddings in the latent space (required). 34 | motion_decoder: a module to decode the latent vector into motion features (required). 35 | vae: a boolean to make the model probabilistic (required). 36 | fact: a scaling factor for sampling the VAE (optional). 37 | sample_mean: sample the mean vector instead of random sampling (optional). 38 | lmd: dictionary of losses weights (optional). 39 | lr: learninig rate for the optimizer (optional). 40 | """ 41 | 42 | def __init__( 43 | self, 44 | motion_encoder: nn.Module, 45 | text_encoder: nn.Module, 46 | motion_decoder: nn.Module, 47 | vae: bool, 48 | fact: Optional[float] = None, 49 | sample_mean: Optional[bool] = False, 50 | lmd: Dict = {"recons": 1.0, "latent": 1.0e-5, "kl": 1.0e-5}, 51 | lr: float = 1e-4, 52 | ) -> None: 53 | super().__init__() 54 | 55 | self.motion_encoder = motion_encoder 56 | self.text_encoder = text_encoder 57 | self.motion_decoder = motion_decoder 58 | 59 | # sampling parameters 60 | self.vae = vae 61 | self.fact = fact if fact is not None else 1.0 62 | self.sample_mean = sample_mean 63 | 64 | # losses 65 | self.reconstruction_loss_fn = torch.nn.SmoothL1Loss(reduction="mean") 66 | self.latent_loss_fn = torch.nn.SmoothL1Loss(reduction="mean") 67 | self.kl_loss_fn = KLLoss() 68 | 69 | # lambda weighting for the losses 70 | self.lmd = lmd 71 | self.lr = lr 72 | 73 | def configure_optimizers(self) -> None: 74 | return {"optimizer": torch.optim.AdamW(lr=self.lr, params=self.parameters())} 75 | 76 | def _find_encoder(self, inputs, modality): 77 | assert modality in ["text", "motion", "auto"] 78 | 79 | if modality == "text": 80 | return self.text_encoder 81 | elif modality == "motion": 82 | return self.motion_encoder 83 | 84 | m_nfeats = self.motion_encoder.nfeats 85 | t_nfeats = self.text_encoder.nfeats 86 | 87 | if m_nfeats == t_nfeats: 88 | raise ValueError( 89 | "Cannot automatically find the encoder, as they share the same input space." 
90 | ) 91 | 92 | nfeats = inputs["x"].shape[-1] 93 | if nfeats == m_nfeats: 94 | return self.motion_encoder 95 | elif nfeats == t_nfeats: 96 | return self.text_encoder 97 | else: 98 | raise ValueError("The inputs is not recognized.") 99 | 100 | def encode( 101 | self, 102 | inputs, 103 | modality: str = "auto", 104 | sample_mean: Optional[bool] = None, 105 | fact: Optional[float] = None, 106 | return_distribution: bool = False, 107 | ): 108 | sample_mean = self.sample_mean if sample_mean is None else sample_mean 109 | fact = self.fact if fact is None else fact 110 | 111 | # Encode the inputs 112 | encoder = self._find_encoder(inputs, modality) 113 | encoded = encoder(inputs) 114 | 115 | # Sampling 116 | if self.vae: 117 | dists = encoded.unbind(1) 118 | mu, logvar = dists 119 | if sample_mean: 120 | latent_vectors = mu 121 | else: 122 | # Reparameterization trick 123 | std = logvar.exp().pow(0.5) 124 | eps = std.data.new(std.size()).normal_() 125 | latent_vectors = mu + fact * eps * std 126 | else: 127 | dists = None 128 | (latent_vectors, _) = encoded.unbind(1) 129 | 130 | if return_distribution: 131 | return latent_vectors, dists 132 | 133 | return latent_vectors 134 | 135 | def decode( 136 | self, 137 | latent_vectors: Tensor, 138 | lengths: Optional[List[int]] = None, 139 | mask: Optional[Tensor] = None, 140 | ): 141 | mask = mask if mask is not None else length_to_mask(lengths, device=self.device) 142 | z_dict = {"z": latent_vectors, "mask": mask} 143 | motions = self.motion_decoder(z_dict) 144 | return motions 145 | 146 | # Forward: X => motions 147 | def forward( 148 | self, 149 | inputs, 150 | lengths: Optional[List[int]] = None, 151 | mask: Optional[Tensor] = None, 152 | sample_mean: Optional[bool] = None, 153 | fact: Optional[float] = None, 154 | return_all: bool = False, 155 | ) -> List[Tensor]: 156 | # Encoding the inputs and sampling if needed 157 | latent_vectors, distributions = self.encode( 158 | inputs, sample_mean=sample_mean, fact=fact, return_distribution=True 159 | ) 160 | # Decoding the latent vector: generating motions 161 | motions = self.decode(latent_vectors, lengths, mask) 162 | 163 | if return_all: 164 | return {"motions": motions, 165 | "latent_vectors": latent_vectors, 166 | "distributions": distributions} 167 | 168 | return {"motions": motions} 169 | 170 | def call_models(self, batch): 171 | text_x_dict = batch["text_x_dict"] 172 | motion_x_dict = batch["motion_x_dict"] 173 | 174 | mask = motion_x_dict["mask"] 175 | 176 | # text -> motion 177 | t_results = self(text_x_dict, mask=mask, return_all=True) 178 | 179 | # motion -> motion 180 | m_results = self(motion_x_dict, mask=mask, return_all=True) 181 | 182 | return t_results, m_results 183 | 184 | def compute_loss(self, batch: Dict) -> Dict: 185 | t_results, m_results = self.call_models(batch) 186 | t_motions, t_latents, t_dists = t_results["motions"], t_results["latent_vectors"], t_results["distributions"] 187 | m_motions, m_latents, m_dists = m_results["motions"], m_results["latent_vectors"], m_results["distributions"] 188 | 189 | ref_motions = batch["motion_x_dict"]["x"] 190 | 191 | # Store all losses 192 | losses = {} 193 | 194 | # Reconstructions losses 195 | # fmt: off 196 | losses["recons"] = ( 197 | + self.reconstruction_loss_fn(t_motions, ref_motions) # text -> motion 198 | + self.reconstruction_loss_fn(m_motions, ref_motions) # motion -> motion 199 | ) 200 | # fmt: on 201 | 202 | # VAE losses 203 | if self.vae: 204 | # Create a centred normal distribution to compare with 205 | # logvar = 0 -> std = 
1 206 | ref_mus = torch.zeros_like(m_dists[0]) 207 | ref_logvar = torch.zeros_like(m_dists[1]) 208 | ref_dists = (ref_mus, ref_logvar) 209 | 210 | losses["kl"] = ( 211 | self.kl_loss_fn(t_dists, m_dists) # text_to_motion 212 | + self.kl_loss_fn(m_dists, t_dists) # motion_to_text 213 | + self.kl_loss_fn(m_dists, ref_dists) # motion 214 | + self.kl_loss_fn(t_dists, ref_dists) # text 215 | ) 216 | 217 | # Latent manifold loss 218 | losses["latent"] = self.latent_loss_fn(t_latents, m_latents) 219 | 220 | # Weighted average of the losses 221 | losses["loss"] = sum( 222 | self.lmd[x] * val for x, val in losses.items() if x in self.lmd 223 | ) 224 | return losses 225 | 226 | def training_step(self, batch: Dict, batch_idx: int) -> Tensor: 227 | bs = len(batch["motion_x_dict"]["x"]) 228 | losses = self.compute_loss(batch) 229 | 230 | for loss_name in sorted(losses): 231 | loss_val = losses[loss_name] 232 | self.log( 233 | f"train_{loss_name}", 234 | loss_val, 235 | on_epoch=True, 236 | on_step=True, 237 | batch_size=bs, 238 | ) 239 | return losses["loss"] 240 | 241 | def validation_step(self, batch: Dict, batch_idx: int) -> Tensor: 242 | bs = len(batch["motion_x_dict"]["x"]) 243 | losses = self.compute_loss(batch) 244 | 245 | for loss_name in sorted(losses): 246 | loss_val = losses[loss_name] 247 | self.log( 248 | f"val_{loss_name}", 249 | loss_val, 250 | on_epoch=True, 251 | on_step=True, 252 | batch_size=bs, 253 | ) 254 | return losses["loss"] 255 | -------------------------------------------------------------------------------- /src/model/text_encoder.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch.nn as nn 3 | import torch 4 | from torch import Tensor 5 | from typing import Dict, List 6 | import torch.nn.functional as F 7 | 8 | 9 | class TextToEmb(nn.Module): 10 | def __init__( 11 | self, modelpath: str, mean_pooling: bool = False, device: str = "cpu" 12 | ) -> None: 13 | super().__init__() 14 | 15 | self.device = device 16 | from transformers import AutoTokenizer, AutoModel 17 | from transformers import logging 18 | 19 | logging.set_verbosity_error() 20 | 21 | # Tokenizer 22 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 23 | self.tokenizer = AutoTokenizer.from_pretrained(modelpath) 24 | 25 | # Text model 26 | self.text_model = AutoModel.from_pretrained(modelpath) 27 | # Then configure the model 28 | self.text_encoded_dim = self.text_model.config.hidden_size 29 | 30 | if mean_pooling: 31 | self.forward = self.forward_pooling 32 | 33 | # put it in eval mode by default 34 | self.eval() 35 | 36 | # Freeze the weights just in case 37 | for param in self.parameters(): 38 | param.requires_grad = False 39 | 40 | self.to(device) 41 | 42 | def train(self, mode: bool = True) -> nn.Module: 43 | # override it to be always false 44 | self.training = False 45 | for module in self.children(): 46 | module.train(False) 47 | return self 48 | 49 | @torch.no_grad() 50 | def forward(self, texts: List[str], device=None) -> Dict: 51 | device = device if device is not None else self.device 52 | 53 | squeeze = False 54 | if isinstance(texts, str): 55 | texts = [texts] 56 | squeeze = True 57 | 58 | encoded_inputs = self.tokenizer(texts, return_tensors="pt", padding=True) 59 | output = self.text_model(**encoded_inputs.to(device)) 60 | length = encoded_inputs.attention_mask.to(dtype=bool).sum(1) 61 | 62 | if squeeze: 63 | x_dict = {"x": output.last_hidden_state[0], "length": length[0]} 64 | else: 65 | x_dict = {"x": output.last_hidden_state, "length": 
length} 66 | return x_dict 67 | 68 | @torch.no_grad() 69 | def forward_pooling(self, texts: List[str], device=None) -> Tensor: 70 | device = device if device is not None else self.device 71 | 72 | squeeze = False 73 | if isinstance(texts, str): 74 | texts = [texts] 75 | squeeze = True 76 | 77 | # From: https://huggingface.co/sentence-transformers/all-mpnet-base-v2 78 | encoded_inputs = self.tokenizer(texts, return_tensors="pt", padding=True) 79 | output = self.text_model(**encoded_inputs.to(device)) 80 | attention_mask = encoded_inputs["attention_mask"] 81 | 82 | # Mean Pooling - Take attention mask into account for correct averaging 83 | token_embeddings = output["last_hidden_state"] 84 | input_mask_expanded = ( 85 | attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() 86 | ) 87 | sentence_embeddings = torch.sum( 88 | token_embeddings * input_mask_expanded, 1 89 | ) / torch.clamp(input_mask_expanded.sum(1), min=1e-9) 90 | # Normalize embeddings 91 | sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1) 92 | if squeeze: 93 | sentence_embeddings = sentence_embeddings[0] 94 | return sentence_embeddings 95 | -------------------------------------------------------------------------------- /src/model/tmr.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional 2 | from torch import Tensor 3 | 4 | import torch 5 | import torch.nn as nn 6 | from .temos import TEMOS 7 | from .losses import InfoNCE_with_filtering 8 | from .metrics import all_contrastive_metrics 9 | 10 | 11 | # x.T will be deprecated in pytorch 12 | def transpose(x): 13 | return x.permute(*torch.arange(x.ndim - 1, -1, -1)) 14 | 15 | 16 | def get_sim_matrix(x, y): 17 | x_logits = torch.nn.functional.normalize(x, dim=-1) 18 | y_logits = torch.nn.functional.normalize(y, dim=-1) 19 | sim_matrix = x_logits @ transpose(y_logits) 20 | return sim_matrix 21 | 22 | 23 | # Scores are between 0 and 1 24 | def get_score_matrix(x, y): 25 | sim_matrix = get_sim_matrix(x, y) 26 | scores = sim_matrix / 2 + 0.5 27 | return scores 28 | 29 | 30 | class TMR(TEMOS): 31 | r"""TMR: Text-to-Motion Retrieval 32 | Using Contrastive 3D Human Motion Synthesis 33 | Find more information about the model on the following website: 34 | https://mathis.petrovich.fr/tmr 35 | 36 | Args: 37 | motion_encoder: a module to encode the input motion features in the latent space (required). 38 | text_encoder: a module to encode the text embeddings in the latent space (required). 39 | motion_decoder: a module to decode the latent vector into motion features (required). 40 | vae: a boolean to make the model probabilistic (required). 41 | fact: a scaling factor for sampling the VAE (optional). 42 | sample_mean: sample the mean vector instead of random sampling (optional). 43 | lmd: dictionary of loss weights (optional). 44 | lr: learning rate for the optimizer (optional). 45 | temperature: temperature of the softmax in the contrastive loss (optional). 46 | threshold_selfsim: threshold used to filter wrong negatives for the contrastive loss (optional). 47 | threshold_selfsim_metrics: threshold used to filter wrong negatives for the metrics (optional).
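contrastive_loss: a pre-built InfoNCE_with_filtering module; if None, one is created from temperature and threshold_selfsim (optional).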
48 | """ 49 | 50 | def __init__( 51 | self, 52 | motion_encoder: nn.Module, 53 | text_encoder: nn.Module, 54 | motion_decoder: nn.Module, 55 | vae: bool, 56 | contrastive_loss: Optional[InfoNCE_with_filtering] = None, 57 | temperature: float = 0.7, # For compatibility with TMR original code 58 | threshold_selfsim: float = 0.80, # For compatibility with TMR original code 59 | fact: Optional[float] = None, 60 | sample_mean: Optional[bool] = False, 61 | lmd: Dict = {"recons": 1.0, "latent": 1.0e-5, "kl": 1.0e-5, "contrastive": 0.1}, 62 | lr: float = 1e-4, 63 | threshold_selfsim_metrics: float = 0.95, 64 | ) -> None: 65 | # Initialize module like TEMOS 66 | super().__init__( 67 | motion_encoder=motion_encoder, 68 | text_encoder=text_encoder, 69 | motion_decoder=motion_decoder, 70 | vae=vae, 71 | fact=fact, 72 | sample_mean=sample_mean, 73 | lmd=lmd, 74 | lr=lr, 75 | ) 76 | 77 | # adding the contrastive loss 78 | self.contrastive_loss_fn = contrastive_loss 79 | if self.contrastive_loss_fn is None: # For compatibility with TMR original code 80 | self.contrastive_loss_fn = InfoNCE_with_filtering( 81 | temperature=temperature, threshold_selfsim=threshold_selfsim 82 | ) 83 | self.threshold_selfsim_metrics = threshold_selfsim_metrics 84 | 85 | # store validation values to compute retrieval metrics 86 | # on the whole validation set 87 | self.validation_step_t_latents = [] 88 | self.validation_step_m_latents = [] 89 | self.validation_step_sent_emb = [] 90 | 91 | def compute_loss(self, batch: Dict, return_all=False) -> Dict: 92 | t_results, m_results = self.call_models(batch) 93 | t_motions, t_latents, t_dists = t_results["motions"], t_results["latent_vectors"], t_results["distributions"] 94 | m_motions, m_latents, m_dists = m_results["motions"], m_results["latent_vectors"], m_results["distributions"] 95 | 96 | ref_motions = batch["motion_x_dict"]["x"] 97 | 98 | # sentence embeddings 99 | sent_emb = batch["sent_emb"] 100 | 101 | # Store all losses 102 | losses = {} 103 | 104 | # Reconstructions losses 105 | # fmt: off 106 | losses["recons"] = ( 107 | + self.reconstruction_loss_fn(t_motions, ref_motions) # text -> motion 108 | + self.reconstruction_loss_fn(m_motions, ref_motions) # motion -> motion 109 | ) 110 | # fmt: on 111 | 112 | # VAE losses 113 | if self.vae: 114 | # Create a centred normal distribution to compare with 115 | # logvar = 0 -> std = 1 116 | ref_mus = torch.zeros_like(m_dists[0]) 117 | ref_logvar = torch.zeros_like(m_dists[1]) 118 | ref_dists = (ref_mus, ref_logvar) 119 | 120 | losses["kl"] = ( 121 | self.kl_loss_fn(t_dists, m_dists) # text_to_motion 122 | + self.kl_loss_fn(m_dists, t_dists) # motion_to_text 123 | + self.kl_loss_fn(m_dists, ref_dists) # motion 124 | + self.kl_loss_fn(t_dists, ref_dists) # text 125 | ) 126 | 127 | # Latent manifold loss 128 | losses["latent"] = self.latent_loss_fn(t_latents, m_latents) 129 | 130 | # TMR: adding the contrastive loss 131 | losses["contrastive"] = self.contrastive_loss_fn(t_latents, m_latents, sent_emb) 132 | 133 | # Weighted average of the losses 134 | losses["loss"] = sum( 135 | self.lmd[x] * val for x, val in losses.items() if x in self.lmd 136 | ) 137 | 138 | # Used for the validation step 139 | if return_all: 140 | return losses, t_latents, m_latents 141 | 142 | return losses 143 | 144 | def validation_step(self, batch: Dict, batch_idx: int) -> Tensor: 145 | bs = len(batch["motion_x_dict"]["x"]) 146 | losses, t_latents, m_latents = self.compute_loss(batch, return_all=True) 147 | 148 | # Store the latent vectors 149 | 
self.validation_step_t_latents.append(t_latents) 150 | self.validation_step_m_latents.append(m_latents) 151 | self.validation_step_sent_emb.append(batch["sent_emb"]) 152 | 153 | for loss_name in sorted(losses): 154 | loss_val = losses[loss_name] 155 | self.log( 156 | f"val_{loss_name}", 157 | loss_val, 158 | on_epoch=True, 159 | on_step=True, 160 | batch_size=bs, 161 | ) 162 | 163 | return losses["loss"] 164 | 165 | def on_validation_epoch_end(self): 166 | # Compute contrastive metrics on the whole batch 167 | t_latents = torch.cat(self.validation_step_t_latents) 168 | m_latents = torch.cat(self.validation_step_m_latents) 169 | sent_emb = torch.cat(self.validation_step_sent_emb) 170 | 171 | # Compute the similarity matrix 172 | sim_matrix = get_sim_matrix(t_latents, m_latents).cpu().numpy() 173 | 174 | contrastive_metrics = all_contrastive_metrics( 175 | sim_matrix, 176 | emb=sent_emb.cpu().numpy(), 177 | threshold=self.threshold_selfsim_metrics, 178 | ) 179 | 180 | for loss_name in sorted(contrastive_metrics): 181 | loss_val = contrastive_metrics[loss_name] 182 | self.log( 183 | f"val_{loss_name}_epoch", 184 | loss_val, 185 | on_epoch=True, 186 | on_step=False, 187 | ) 188 | 189 | self.validation_step_t_latents.clear() 190 | self.validation_step_m_latents.clear() 191 | self.validation_step_sent_emb.clear() 192 | -------------------------------------------------------------------------------- /src/model/tmr_text_averaging.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Optional, Tuple 2 | from torch import Tensor 3 | 4 | import torch 5 | from .tmr import TMR 6 | 7 | class TMRTextAveraging(TMR): 8 | """Compatible with the AugmentedTextMotionDataset dataset and the collate_text_motion_multiple_texts collate function.""" 9 | 10 | # Forward: X => motions 11 | def forward( 12 | self, 13 | inputs, 14 | text_slices: Optional[List[Tuple[int, int]]] = None, 15 | lengths: Optional[List[int]] = None, 16 | mask: Optional[Tensor] = None, 17 | sample_mean: Optional[bool] = None, 18 | fact: Optional[float] = None, 19 | return_all: bool = False, 20 | ) -> Dict: 21 | 22 | # Encoding the inputs and sampling if needed 23 | latent_vectors, distributions = self.encode( 24 | inputs, sample_mean=sample_mean, fact=fact, return_distribution=True 25 | ) 26 | 27 | # Averages over the different text embeddings for each sample.
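# Each entry of text_slices is a (start, end) index pair built by the collate
# function: latent_vectors[start:end] holds the encodings of all the captions
# attached to one sample, and the mean over that slice yields a single latent
# (and a single mu/logvar) per sample, aligned with the motion batch.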
28 | if text_slices is not None: 29 | latent_vectors = [torch.mean(latent_vectors[i:j], dim=0) for i, j in text_slices] 30 | latent_vectors = torch.stack(latent_vectors, dim=0) 31 | distributions = list(distributions) 32 | distributions[0] = torch.stack([torch.mean(distributions[0][i:j], dim=0) for i, j in text_slices], dim=0) 33 | distributions[1] = torch.stack([torch.mean(distributions[1][i:j], dim=0) for i, j in text_slices], dim=0) 34 | distributions = tuple(distributions) 35 | 36 | # Decoding the latent vector: generating motions 37 | motions = self.decode(latent_vectors, lengths, mask) 38 | 39 | if return_all: 40 | return {"motions": motions, 41 | "latent_vectors": latent_vectors, 42 | "distributions": distributions} 43 | 44 | return {"motions": motions} 45 | 46 | def call_models(self, batch): 47 | text_x_dict = batch["text_x_dict"] 48 | motion_x_dict = batch["motion_x_dict"] 49 | text_slices = batch["text_slices"] 50 | 51 | mask = motion_x_dict["mask"] 52 | 53 | # text -> motion 54 | t_results = self(text_x_dict, mask=mask, return_all=True, text_slices=text_slices) 55 | 56 | # motion -> motion 57 | m_results = self(motion_x_dict, mask=mask, return_all=True) 58 | 59 | return t_results, m_results 60 | -------------------------------------------------------------------------------- /src/prepare.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch 3 | import warnings 4 | 5 | logger = logging.getLogger("torch.distributed.nn.jit.instantiator") 6 | logger.setLevel(logging.ERROR) 7 | 8 | 9 | warnings.filterwarnings( 10 | "ignore", "The PyTorch API of nested tensors is in prototype stage*" 11 | ) 12 | 13 | warnings.filterwarnings("ignore", "Converting mask without torch.bool dtype to bool*") 14 | 15 | torch.set_float32_matmul_precision("high") 16 | -------------------------------------------------------------------------------- /src/renderer/matplotlib.py: -------------------------------------------------------------------------------- 1 | # From TEMOS: temos/render/anim.py 2 | # Assume Z is the gravity axis 3 | # Inspired by 4 | # - https://github.com/anindita127/Complextext2animation/blob/main/src/utils/visualization.py 5 | # - https://github.com/facebookresearch/QuaterNet/blob/main/common/visualization.py 6 | 7 | import logging 8 | 9 | from dataclasses import dataclass 10 | from typing import List, Tuple, Optional 11 | import numpy as np 12 | from src.rifke import canonicalize_rotation 13 | 14 | logger = logging.getLogger("matplotlib.animation") 15 | logger.setLevel(logging.ERROR) 16 | 17 | colors = ("black", "magenta", "red", "green", "blue") 18 | 19 | KINEMATIC_TREES = { 20 | "smpljoints": [ 21 | [0, 3, 6, 9, 12, 15], 22 | [9, 13, 16, 18, 20], 23 | [9, 14, 17, 19, 21], 24 | [0, 1, 4, 7, 10], 25 | [0, 2, 5, 8, 11], 26 | ], 27 | "guoh3djoints": [ # no hands 28 | [0, 3, 6, 9, 12, 15], 29 | [9, 13, 16, 18, 20], 30 | [9, 14, 17, 19, 21], 31 | [0, 1, 4, 7, 10], 32 | [0, 2, 5, 8, 11], 33 | ], 34 | } 35 | 36 | 37 | @dataclass 38 | class MatplotlibRender: 39 | jointstype: str = "smpljoints" 40 | fps: float = 20.0 41 | colors: List[str] = colors 42 | figsize: int = 4 43 | fontsize: int = 15 44 | canonicalize: bool = False 45 | 46 | def __call__( 47 | self, 48 | joints, 49 | highlights=None, 50 | title: str = "", 51 | output: str = "notebook", 52 | jointstype=None, 53 | ): 54 | jointstype = jointstype if jointstype is not None else self.jointstype 55 | render_animation( 56 | joints, 57 | title=title, 58 | highlights=highlights, 59 
| output=output, 60 | jointstype=jointstype, 61 | fps=self.fps, 62 | colors=self.colors, 63 | figsize=(self.figsize, self.figsize), 64 | fontsize=self.fontsize, 65 | canonicalize=self.canonicalize, 66 | ) 67 | 68 | 69 | def init_axis(fig, title, radius=1.5): 70 | ax = fig.add_subplot(1, 1, 1, projection="3d") 71 | ax.view_init(elev=20.0, azim=-60) 72 | 73 | fact = 2 74 | ax.set_xlim3d([-radius / fact, radius / fact]) 75 | ax.set_ylim3d([-radius / fact, radius / fact]) 76 | ax.set_zlim3d([0, radius]) 77 | 78 | ax.set_aspect("auto") 79 | ax.set_xticklabels([]) 80 | ax.set_yticklabels([]) 81 | ax.set_zticklabels([]) 82 | 83 | ax.set_axis_off() 84 | ax.grid(b=False) 85 | 86 | ax.set_title(title, loc="center", wrap=True) 87 | return ax 88 | 89 | 90 | def plot_floor(ax, minx, maxx, miny, maxy, minz): 91 | from mpl_toolkits.mplot3d.art3d import Poly3DCollection 92 | 93 | # Plot a plane XZ 94 | verts = [ 95 | [minx, miny, minz], 96 | [minx, maxy, minz], 97 | [maxx, maxy, minz], 98 | [maxx, miny, minz], 99 | ] 100 | xz_plane = Poly3DCollection([verts], zorder=1) 101 | xz_plane.set_facecolor((0.5, 0.5, 0.5, 1)) 102 | ax.add_collection3d(xz_plane) 103 | 104 | # Plot a bigger square plane XZ 105 | radius = max((maxx - minx), (maxy - miny)) 106 | 107 | # center +- radius 108 | minx_all = (maxx + minx) / 2 - radius 109 | maxx_all = (maxx + minx) / 2 + radius 110 | 111 | miny_all = (maxy + miny) / 2 - radius 112 | maxy_all = (maxy + miny) / 2 + radius 113 | 114 | verts = [ 115 | [minx_all, miny_all, minz], 116 | [minx_all, maxy_all, minz], 117 | [maxx_all, maxy_all, minz], 118 | [maxx_all, miny_all, minz], 119 | ] 120 | xz_plane = Poly3DCollection([verts], zorder=1) 121 | xz_plane.set_facecolor((0.5, 0.5, 0.5, 0.5)) 122 | ax.add_collection3d(xz_plane) 123 | return ax 124 | 125 | 126 | def update_camera(ax, root, radius=1.5): 127 | fact = 2 128 | ax.set_xlim3d([-radius / fact + root[0], radius / fact + root[0]]) 129 | ax.set_ylim3d([-radius / fact + root[1], radius / fact + root[1]]) 130 | 131 | 132 | def render_animation( 133 | joints: np.ndarray, 134 | output: str = "notebook", 135 | highlights: Optional[np.ndarray] = None, 136 | jointstype: str = "smpljoints", 137 | title: str = "", 138 | fps: float = 20.0, 139 | colors: List[str] = colors, 140 | figsize: Tuple[int] = (4, 4), 141 | fontsize: int = 15, 142 | canonicalize: bool = False, 143 | agg=True, 144 | ): 145 | if agg: 146 | import matplotlib 147 | 148 | matplotlib.use("Agg") 149 | 150 | if highlights is not None: 151 | assert len(highlights) == len(joints) 152 | 153 | assert jointstype in KINEMATIC_TREES 154 | kinematic_tree = KINEMATIC_TREES[jointstype] 155 | 156 | import matplotlib.pyplot as plt 157 | from matplotlib.animation import FuncAnimation 158 | import matplotlib.patheffects as pe 159 | 160 | mean_fontsize = fontsize 161 | 162 | # heuristic to change fontsize 163 | fontsize = mean_fontsize - (len(title) - 30) / 20 164 | plt.rcParams.update({"font.size": fontsize}) 165 | 166 | # Z is gravity here 167 | x, y, z = 0, 1, 2 168 | 169 | joints = joints.copy() 170 | 171 | if canonicalize: 172 | joints = canonicalize_rotation(joints, jointstype=jointstype) 173 | 174 | # Create a figure and initialize 3d plot 175 | fig = plt.figure(figsize=figsize) 176 | ax = init_axis(fig, title) 177 | 178 | # Create spline line 179 | trajectory = joints[:, 0, [x, y]] 180 | avg_segment_length = ( 181 | np.mean(np.linalg.norm(np.diff(trajectory, axis=0), axis=1)) + 1e-3 182 | ) 183 | draw_offset = int(25 / avg_segment_length) 184 | (spline_line,) = 
ax.plot(*trajectory.T, zorder=10, color="white") 185 | 186 | # Create a floor 187 | minx, miny, _ = joints.min(axis=(0, 1)) 188 | maxx, maxy, _ = joints.max(axis=(0, 1)) 189 | plot_floor(ax, minx, maxx, miny, maxy, 0) 190 | 191 | # Put the character on the floor 192 | height_offset = np.min(joints[:, :, z]) # Min height 193 | joints = joints.copy() 194 | joints[:, :, z] -= height_offset 195 | 196 | # Initialization for redrawing 197 | lines = [] 198 | initialized = False 199 | 200 | def update(frame): 201 | nonlocal initialized 202 | skeleton = joints[frame] 203 | 204 | root = skeleton[0] 205 | update_camera(ax, root) 206 | 207 | hcolors = colors 208 | if highlights is not None and highlights[frame]: 209 | hcolors = ("red", "red", "red", "red", "red") 210 | 211 | for index, (chain, color) in enumerate( 212 | zip(reversed(kinematic_tree), reversed(hcolors)) 213 | ): 214 | if not initialized: 215 | lines.append( 216 | ax.plot( 217 | skeleton[chain, x], 218 | skeleton[chain, y], 219 | skeleton[chain, z], 220 | linewidth=6.0, 221 | color=color, 222 | zorder=20, 223 | path_effects=[pe.SimpleLineShadow(), pe.Normal()], 224 | ) 225 | ) 226 | 227 | else: 228 | lines[index][0].set_xdata(skeleton[chain, x]) 229 | lines[index][0].set_ydata(skeleton[chain, y]) 230 | lines[index][0].set_3d_properties(skeleton[chain, z]) 231 | lines[index][0].set_color(color) 232 | 233 | left = max(frame - draw_offset, 0) 234 | right = min(frame + draw_offset, trajectory.shape[0]) 235 | 236 | spline_line.set_xdata(trajectory[left:right, 0]) 237 | spline_line.set_ydata(trajectory[left:right, 1]) 238 | spline_line.set_3d_properties(np.zeros_like(trajectory[left:right, 0])) 239 | initialized = True 240 | 241 | fig.tight_layout() 242 | frames = joints.shape[0] 243 | anim = FuncAnimation(fig, update, frames=frames, interval=1000 / fps, repeat=False) 244 | 245 | if output == "notebook": 246 | from IPython.display import HTML 247 | 248 | HTML(anim.to_jshtml()) 249 | else: 250 | # anim.save(output, writer='ffmpeg', fps=fps) 251 | anim.save(output, fps=fps) 252 | 253 | plt.close() 254 | -------------------------------------------------------------------------------- /src/rifke.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from einops import rearrange 4 | import numpy as np 5 | 6 | from torch import Tensor 7 | 8 | from .geometry import axis_angle_rotation, matrix_to_axis_angle 9 | from .joints import INFOS 10 | 11 | 12 | def joints_to_rifke(joints, jointstype="smpljoints"): 13 | # Joints to rotation invariant poses (Holden et. al.) 
14 | # Similar function than fke2rifke in Language2Pose repository 15 | # Adapted from the pytorch version of TEMOS 16 | # https://github.com/Mathux/TEMOS 17 | # Estimate the last velocities based on acceleration 18 | # Difference of rotations are in SO3 space now 19 | 20 | # First remove the ground 21 | ground = joints[..., 2].min() 22 | poses = joints.clone() 23 | poses[..., 2] -= ground 24 | 25 | poses = joints.clone() 26 | translation = poses[..., 0, :].clone() 27 | 28 | # Let the root have the Z translation --> gravity axis 29 | root_grav_axis = translation[..., 2] 30 | 31 | # Trajectory => Translation without gravity axis (Z) 32 | trajectory = translation[..., [0, 1]] 33 | 34 | # Compute the forward direction (before removing the root joint) 35 | forward = get_forward_direction(poses, jointstype=jointstype) 36 | 37 | # Delete the root joints of the poses 38 | poses = poses[..., 1:, :] 39 | 40 | # Remove the trajectory of the poses 41 | poses[..., [0, 1]] -= trajectory[..., None, :] 42 | 43 | vel_trajectory = torch.diff(trajectory, dim=-2) 44 | 45 | # repeat the last acceleration 46 | # for the last (not seen) velocity 47 | last_acceleration = vel_trajectory[..., -1, :] - vel_trajectory[..., -2, :] 48 | 49 | future_velocity = vel_trajectory[..., -1, :] + last_acceleration 50 | vel_trajectory = torch.cat((vel_trajectory, future_velocity[..., None, :]), dim=-2) 51 | 52 | angles = torch.atan2(*(forward.transpose(0, -1))).transpose(0, -1) 53 | 54 | # True difference of angles 55 | mat_rotZ = axis_angle_rotation("Z", angles) 56 | vel_mat_rotZ = mat_rotZ[..., 1:, :, :] @ mat_rotZ.transpose(-1, -2)[..., :-1, :, :] 57 | # repeat the last acceleration (same as the trajectory but in the 3D rotation space) 58 | last_acc_rotZ = ( 59 | vel_mat_rotZ[..., -1, :, :] @ vel_mat_rotZ.transpose(-1, -2)[..., -2, :, :] 60 | ) 61 | future_vel_rotZ = vel_mat_rotZ[..., -1, :, :] @ last_acc_rotZ 62 | vel_mat_rotZ = torch.cat((vel_mat_rotZ, future_vel_rotZ[..., None, :, :]), dim=-3) 63 | vel_angles = matrix_to_axis_angle(vel_mat_rotZ)[..., 2] 64 | 65 | # Construct the inverse rotation matrix 66 | rotations_inv = mat_rotZ.transpose(-1, -2)[..., :2, :2] 67 | 68 | poses_local = torch.einsum("...lj,...jk->...lk", poses[..., [0, 1]], rotations_inv) 69 | poses_local = torch.stack( 70 | (poses_local[..., 0], poses_local[..., 1], poses[..., 2]), axis=-1 71 | ) 72 | 73 | # stack the xyz joints into feature vectors 74 | poses_features = rearrange(poses_local, "... joints xyz -> ... (joints xyz)") 75 | 76 | # Rotate the vel_trajectory 77 | vel_trajectory_local = torch.einsum( 78 | "...j,...jk->...k", vel_trajectory, rotations_inv 79 | ) 80 | 81 | # Stack things together 82 | features = group(root_grav_axis, poses_features, vel_angles, vel_trajectory_local) 83 | return features 84 | 85 | 86 | def rifke_to_joints(features: Tensor, jointstype="smpljoints") -> Tensor: 87 | root_grav_axis, poses_features, vel_angles, vel_trajectory_local = ungroup(features) 88 | 89 | # Remove the dummy last angle and integrate the angles 90 | angles = torch.cumsum(vel_angles[..., :-1], dim=-1) 91 | # The first angle is zero 92 | angles = torch.cat((0 * angles[..., [0]], angles), dim=-1) 93 | rotations = axis_angle_rotation("Z", angles)[..., :2, :2] 94 | 95 | # Get back the poses 96 | poses_local = rearrange(poses_features, "... (joints xyz) -> ... 
joints xyz", xyz=3) 97 | 98 | # Rotate the poses 99 | poses = torch.einsum("...lj,...jk->...lk", poses_local[..., [0, 1]], rotations) 100 | poses = torch.stack((poses[..., 0], poses[..., 1], poses_local[..., 2]), axis=-1) 101 | 102 | # Rotate the vel_trajectory 103 | vel_trajectory = torch.einsum("...j,...jk->...k", vel_trajectory_local, rotations) 104 | # Remove the dummy last velocity and integrate the trajectory 105 | trajectory = torch.cumsum(vel_trajectory[..., :-1, :], dim=-2) 106 | # The first position is zero 107 | trajectory = torch.cat((0 * trajectory[..., [0], :], trajectory), dim=-2) 108 | 109 | # Add the root joints (which is still zero) 110 | poses = torch.cat((0 * poses[..., [0], :], poses), -2) 111 | 112 | # put back the gravity offset 113 | poses[..., 0, 2] = root_grav_axis 114 | 115 | # Add the trajectory globally 116 | poses[..., [0, 1]] += trajectory[..., None, :] 117 | return poses 118 | 119 | 120 | def group(root_grav_axis, poses_features, vel_angles, vel_trajectory_local): 121 | # Stack things together 122 | features = torch.cat( 123 | ( 124 | root_grav_axis[..., None], 125 | poses_features, 126 | vel_angles[..., None], 127 | vel_trajectory_local, 128 | ), 129 | -1, 130 | ) 131 | return features 132 | 133 | 134 | def ungroup(features: Tensor) -> tuple[Tensor]: 135 | # Unbind things 136 | root_grav_axis = features[..., 0] 137 | poses_features = features[..., 1:-3] 138 | vel_angles = features[..., -3] 139 | vel_trajectory_local = features[..., -2:] 140 | return root_grav_axis, poses_features, vel_angles, vel_trajectory_local 141 | 142 | 143 | def get_forward_direction(poses, jointstype="smpljoints"): 144 | assert jointstype in INFOS 145 | infos = INFOS[jointstype] 146 | assert poses.shape[-2] == infos["njoints"] 147 | RH, LH, RS, LS = infos["RH"], infos["LH"], infos["RS"], infos["LS"] 148 | across = ( 149 | poses[..., RH, :] - poses[..., LH, :] + poses[..., RS, :] - poses[..., LS, :] 150 | ) 151 | forward = torch.stack((-across[..., 1], across[..., 0]), axis=-1) 152 | forward = torch.nn.functional.normalize(forward, dim=-1) 153 | return forward 154 | 155 | 156 | def canonicalize_rotation(joints, jointstype="smpljoints"): 157 | return_np = False 158 | if isinstance(joints, np.ndarray): 159 | joints = torch.from_numpy(joints) 160 | return_np = True 161 | 162 | features = joints_to_rifke(joints, jointstype=jointstype) 163 | joints_c = rifke_to_joints(features, jointstype=jointstype) 164 | if return_np: 165 | joints_c = joints_c.numpy() 166 | return joints_c 167 | -------------------------------------------------------------------------------- /stats/babel/guoh3dfeats/mean.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leorebensabath/TMRPlusPlus/23e4deab9c4529bd22e3d40942a9a9a9d7a26db2/stats/babel/guoh3dfeats/mean.pt -------------------------------------------------------------------------------- /stats/babel/guoh3dfeats/std.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leorebensabath/TMRPlusPlus/23e4deab9c4529bd22e3d40942a9a9a9d7a26db2/stats/babel/guoh3dfeats/std.pt -------------------------------------------------------------------------------- /stats/babel_actions_120/guoh3dfeats/mean.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leorebensabath/TMRPlusPlus/23e4deab9c4529bd22e3d40942a9a9a9d7a26db2/stats/babel_actions_120/guoh3dfeats/mean.pt 
-------------------------------------------------------------------------------- /stats/babel_actions_120/guoh3dfeats/std.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leorebensabath/TMRPlusPlus/23e4deab9c4529bd22e3d40942a9a9a9d7a26db2/stats/babel_actions_120/guoh3dfeats/std.pt -------------------------------------------------------------------------------- /stats/babel_actions_60/guoh3dfeats/mean.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leorebensabath/TMRPlusPlus/23e4deab9c4529bd22e3d40942a9a9a9d7a26db2/stats/babel_actions_60/guoh3dfeats/mean.pt -------------------------------------------------------------------------------- /stats/babel_actions_60/guoh3dfeats/std.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leorebensabath/TMRPlusPlus/23e4deab9c4529bd22e3d40942a9a9a9d7a26db2/stats/babel_actions_60/guoh3dfeats/std.pt -------------------------------------------------------------------------------- /stats/humanml3d/guoh3dfeats/mean.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leorebensabath/TMRPlusPlus/23e4deab9c4529bd22e3d40942a9a9a9d7a26db2/stats/humanml3d/guoh3dfeats/mean.pt -------------------------------------------------------------------------------- /stats/humanml3d/guoh3dfeats/std.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leorebensabath/TMRPlusPlus/23e4deab9c4529bd22e3d40942a9a9a9d7a26db2/stats/humanml3d/guoh3dfeats/std.pt -------------------------------------------------------------------------------- /stats/humanml3d_kitml/guoh3dfeats/mean.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leorebensabath/TMRPlusPlus/23e4deab9c4529bd22e3d40942a9a9a9d7a26db2/stats/humanml3d_kitml/guoh3dfeats/mean.pt -------------------------------------------------------------------------------- /stats/humanml3d_kitml/guoh3dfeats/std.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leorebensabath/TMRPlusPlus/23e4deab9c4529bd22e3d40942a9a9a9d7a26db2/stats/humanml3d_kitml/guoh3dfeats/std.pt -------------------------------------------------------------------------------- /stats/humanml3d_kitml_babel/guoh3dfeats/mean.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leorebensabath/TMRPlusPlus/23e4deab9c4529bd22e3d40942a9a9a9d7a26db2/stats/humanml3d_kitml_babel/guoh3dfeats/mean.pt -------------------------------------------------------------------------------- /stats/humanml3d_kitml_babel/guoh3dfeats/std.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leorebensabath/TMRPlusPlus/23e4deab9c4529bd22e3d40942a9a9a9d7a26db2/stats/humanml3d_kitml_babel/guoh3dfeats/std.pt -------------------------------------------------------------------------------- /stats/kitml/guoh3dfeats/mean.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leorebensabath/TMRPlusPlus/23e4deab9c4529bd22e3d40942a9a9a9d7a26db2/stats/kitml/guoh3dfeats/mean.pt 
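These stats folders store per-dimension mean.pt and std.pt tensors of the Guo-style motion features, one pair per dataset mix. Below is a minimal sketch of how such statistics are typically applied, assuming the repository's normalizer performs a standard (x - mean) / std transform (the actual Normalizer class lives in the data-loading code and is not shown here):

import torch

# Illustrative paths; every stats folder follows the same mean.pt / std.pt layout.
mean = torch.load("stats/humanml3d/guoh3dfeats/mean.pt")
std = torch.load("stats/humanml3d/guoh3dfeats/std.pt")

def normalize(features, eps=1e-12):
    # Standardize each feature dimension of a (num_frames, num_feats) tensor.
    return (features - mean) / (std + eps)

def denormalize(features, eps=1e-12):
    # Invert the transform, e.g. before converting features back to joints.
    return features * (std + eps) + mean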
-------------------------------------------------------------------------------- /stats/kitml/guoh3dfeats/std.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leorebensabath/TMRPlusPlus/23e4deab9c4529bd22e3d40942a9a9a9d7a26db2/stats/kitml/guoh3dfeats/std.pt -------------------------------------------------------------------------------- /text_motion_sim.py: -------------------------------------------------------------------------------- 1 | from omegaconf import DictConfig 2 | import logging 3 | import hydra 4 | 5 | logger = logging.getLogger(__name__) 6 | 7 | 8 | @hydra.main(version_base=None, config_path="configs", config_name="text_motion_sim") 9 | def text_motion_sim(cfg: DictConfig) -> None: 10 | device = cfg.device 11 | run_dir = cfg.run_dir 12 | ckpt_name = cfg.ckpt_name 13 | npy_path = cfg.npy 14 | text = cfg.text 15 | 16 | import src.prepare # noqa 17 | import torch 18 | import numpy as np 19 | from src.config import read_config 20 | from src.load import load_model_from_cfg 21 | from hydra.utils import instantiate 22 | from pytorch_lightning import seed_everything 23 | from src.data.collate import collate_x_dict 24 | from src.model.tmr import get_score_matrix 25 | 26 | cfg = read_config(run_dir) 27 | 28 | seed_everything(cfg.seed) 29 | 30 | logger.info("Loading the text model") 31 | text_model = instantiate(cfg.data.text_to_token_emb, device=device) 32 | 33 | logger.info("Loading the model") 34 | model = load_model_from_cfg(cfg, ckpt_name, eval_mode=True, device=device) 35 | 36 | normalizer = instantiate(cfg.data.motion_loader.normalizer) 37 | 38 | motion = torch.from_numpy(np.load(npy_path)).to(torch.float) 39 | motion = normalizer(motion) 40 | motion = motion.to(device) 41 | 42 | motion_x_dict = {"x": motion, "length": len(motion)} 43 | 44 | with torch.inference_mode(): 45 | # motion -> latent 46 | motion_x_dict = collate_x_dict([motion_x_dict]) 47 | lat_m = model.encode(motion_x_dict, sample_mean=True)[0] 48 | 49 | # text -> latent 50 | text_x_dict = collate_x_dict(text_model([text])) 51 | lat_t = model.encode(text_x_dict, sample_mean=True)[0] 52 | 53 | score = get_score_matrix(lat_t, lat_m).cpu() 54 | 55 | score_str = f"{score:.3}" 56 | logger.info( 57 | f"The similariy score s (0 <= s <= 1) between the text and the motion is: {score_str}" 58 | ) 59 | 60 | 61 | if __name__ == "__main__": 62 | text_motion_sim() 63 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | from hydra.utils import instantiate 3 | import logging 4 | from omegaconf import DictConfig 5 | import os 6 | import pytorch_lightning as pl 7 | 8 | from src.config import read_config, save_config 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | @hydra.main(config_path="configs", config_name="train", version_base="1.3") 14 | def train(cfg: DictConfig): 15 | # Resuming if needed 16 | ckpt = None 17 | 18 | if cfg.ckpt is not None: 19 | ckpt = cfg.ckpt 20 | 21 | if cfg.resume_dir is not None: 22 | assert cfg.ckpt is not None 23 | max_epochs = cfg.trainer.max_epochs 24 | ckpt = os.path.join(cfg.resume_dir, 'logs', 'checkpoints', f'{cfg.ckpt}.ckpt') 25 | cfg = read_config(cfg.resume_dir) 26 | cfg.trainer.max_epochs = max_epochs 27 | logger.info("Resuming training") 28 | logger.info(f"The config is loaded from: \n{cfg.run_dir}") 29 | else: 30 | if "ckpt_path" in cfg and cfg.ckpt_path is not None: 31 | 
ckpt = cfg.ckpt_path 32 | config_path = save_config(cfg) 33 | logger.info("Training script") 34 | logger.info(f"The config can be found here: \n{config_path}") 35 | 36 | pl.seed_everything(cfg.seed) 37 | 38 | text_to_token_emb = instantiate(cfg.data.text_to_token_emb) 39 | text_to_sent_emb = instantiate(cfg.data.text_to_sent_emb) 40 | 41 | logger.info("Loading the dataloaders") 42 | train_dataset = instantiate(cfg.data, split="train", 43 | text_to_token_emb=text_to_token_emb, 44 | text_to_sent_emb=text_to_sent_emb) 45 | 46 | if "data_val" not in cfg: 47 | data_val = cfg.data 48 | else: 49 | data_val = cfg.data_val 50 | text_to_token_emb = instantiate(cfg.data_val.text_to_token_emb) 51 | text_to_sent_emb = instantiate(cfg.data_val.text_to_sent_emb) 52 | 53 | val_dataset = instantiate(data_val, split="val", 54 | text_to_token_emb=text_to_token_emb, 55 | text_to_sent_emb=text_to_sent_emb) 56 | 57 | train_dataloader = instantiate( 58 | cfg.dataloader, 59 | dataset=train_dataset, 60 | collate_fn=train_dataset.collate_fn, 61 | shuffle=True, 62 | ) 63 | 64 | val_dataloader = instantiate( 65 | cfg.dataloader, 66 | dataset=val_dataset, 67 | collate_fn=val_dataset.collate_fn, 68 | shuffle=False, 69 | ) 70 | 71 | logger.info("Loading the model") 72 | model = instantiate(cfg.model) 73 | 74 | logger.info(f"Using checkpoint: {ckpt}") 75 | logger.info("Training") 76 | trainer = instantiate(cfg.trainer) 77 | trainer.fit(model, train_dataloader, val_dataloader, ckpt_path=ckpt) 78 | 79 | 80 | if __name__ == "__main__": 81 | train() 82 | --------------------------------------------------------------------------------
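Once a model has been trained with this script, retrieval comes down to comparing text and motion latents with the helpers defined in src/model/tmr.py (the same get_score_matrix used by text_motion_sim.py). A minimal sketch with random stand-in latents; the 256-dimensional latent size is only illustrative:

import torch
from src.model.tmr import get_sim_matrix, get_score_matrix

# Stand-in latents: 4 text queries and 6 motion candidates.
t_latents = torch.randn(4, 256)
m_latents = torch.randn(6, 256)

sim = get_sim_matrix(t_latents, m_latents)       # cosine similarities in [-1, 1]
scores = get_score_matrix(t_latents, m_latents)  # rescaled to [0, 1]

# Rank the motion candidates for each text query, best match first.
ranking = scores.argsort(dim=-1, descending=True)
print(scores.shape, ranking[0])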