├── .gitignore
├── DATASETS.md
├── LICENSE.md
├── README.md
├── app.py
├── configs
│   ├── combine_datasets.yaml
│   ├── compute_guoh3dfeats.yaml
│   ├── data
│   │   ├── _base.yaml
│   │   ├── _base_augmented.yaml
│   │   ├── babel.yaml
│   │   ├── babel_actions_120.yaml
│   │   ├── babel_actions_60.yaml
│   │   ├── humanml3d.yaml
│   │   ├── humanml3d_kitml.yaml
│   │   ├── humanml3d_kitml_babel.yaml
│   │   ├── kitml.yaml
│   │   └── motion_loader
│   │       ├── _base.yaml
│   │       └── guoh3dfeats.yaml
│   ├── debug
│   │   ├── profiler.yaml
│   │   └── train.yaml
│   ├── defaults.yaml
│   ├── encode_dataset.yaml
│   ├── encode_motion.yaml
│   ├── encode_text.yaml
│   ├── extract.yaml
│   ├── hydra
│   │   ├── hydra_logging
│   │   │   └── tqdm.yaml
│   │   └── job_logging
│   │       └── tqdm.yaml
│   ├── load_model.yaml
│   ├── model
│   │   ├── temos.yaml
│   │   ├── tmr.yaml
│   │   ├── tmr_hn.yaml
│   │   ├── tmr_text_averaging.yaml
│   │   └── tmr_text_averaging_hn.yaml
│   ├── motion_stats.yaml
│   ├── render.yaml
│   ├── renderer
│   │   └── matplotlib.yaml
│   ├── retrieval.yaml
│   ├── retrieval_action_multi_labels.yaml
│   ├── text_dataset_sim.yaml
│   ├── text_embeddings.yaml
│   ├── text_embeddings_with_augmentation.yaml
│   ├── text_motion_sim.yaml
│   ├── train.yaml
│   ├── train_hn.yaml
│   ├── train_hn_with_augmentation.yaml
│   ├── train_with_augmentation.yaml
│   └── trainer.yaml
├── datasets
│   └── annotations
│       ├── humanml3d
│       │   ├── annotations.json
│       │   └── splits
│       │       ├── all.txt
│       │       ├── nsim_test.txt
│       │       ├── test.txt
│       │       ├── test_tiny.txt
│       │       ├── train.txt
│       │       ├── train_tiny.txt
│       │       ├── val.txt
│       │       └── val_tiny.txt
│       └── kitml
│           ├── annotations.json
│           └── splits
│               ├── all.txt
│               ├── nsim_test.txt
│               ├── test.txt
│               ├── test_tiny.txt
│               ├── train.txt
│               ├── train_tiny.txt
│               ├── val.txt
│               └── val_tiny.txt
├── demo
│   ├── amass_to_babel.json
│   ├── load.py
│   └── model.py
├── encode_dataset.py
├── encode_motion.py
├── encode_text.py
├── extract.py
├── prepare
│   ├── combine_datasets.py
│   ├── compute_guoh3dfeats.py
│   ├── download_pretrain_models.sh
│   ├── motion_stats.py
│   ├── text_embeddings.py
│   └── tools.py
├── requirements.txt
├── retrieval.py
├── retrieval_action.py
├── retrieval_action_multi_labels.py
├── src
│   ├── callback
│   │   ├── __init__.py
│   │   ├── progress.py
│   │   └── tqdmbar.py
│   ├── config.py
│   ├── data
│   │   ├── augmented_text_motion.py
│   │   ├── collate.py
│   │   ├── motion.py
│   │   ├── text.py
│   │   ├── text_motion.py
│   │   └── text_motion_multi_labels.py
│   ├── geometry.py
│   ├── guofeats
│   │   ├── __init__.py
│   │   ├── common
│   │   │   ├── quaternion.py
│   │   │   └── skeleton.py
│   │   ├── motion_representation.py
│   │   ├── paramUtil.py
│   │   └── skeleton_example_h3d.npy
│   ├── joints.py
│   ├── load.py
│   ├── logger
│   │   ├── csv.py
│   │   └── csv_fabric.py
│   ├── logging.py
│   ├── model
│   │   ├── __init__.py
│   │   ├── actor.py
│   │   ├── losses.py
│   │   ├── metrics.py
│   │   ├── temos.py
│   │   ├── text_encoder.py
│   │   ├── tmr.py
│   │   └── tmr_text_averaging.py
│   ├── prepare.py
│   ├── renderer
│   │   └── matplotlib.py
│   └── rifke.py
├── stats
│   ├── babel
│   │   └── guoh3dfeats
│   │       ├── mean.pt
│   │       └── std.pt
│   ├── babel_actions_120
│   │   └── guoh3dfeats
│   │       ├── mean.pt
│   │       └── std.pt
│   ├── babel_actions_60
│   │   └── guoh3dfeats
│   │       ├── mean.pt
│   │       └── std.pt
│   ├── humanml3d
│   │   └── guoh3dfeats
│   │       ├── mean.pt
│   │       └── std.pt
│   ├── humanml3d_kitml
│   │   └── guoh3dfeats
│   │       ├── mean.pt
│   │       └── std.pt
│   ├── humanml3d_kitml_babel
│   │   └── guoh3dfeats
│   │       ├── mean.pt
│   │       └── std.pt
│   └── kitml
│       └── guoh3dfeats
│           ├── mean.pt
│           └── std.pt
├── text_motion_sim.py
└── train.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
--------------------------------------------------------------------------------
/DATASETS.md:
--------------------------------------------------------------------------------
1 | ## Note on datasets
2 |
3 | Currently, three datasets are widely used for 3D text-to-motion: [KIT-ML](https://motion-annotation.humanoids.kit.edu/dataset/), [HumanML3D](https://github.com/EricGuo5513/HumanML3D) and [BABEL](https://babel.is.tue.mpg.de).
4 |
5 | ### Unifying the datasets
6 |
7 | As explained on its website, the [AMASS](https://amass.is.tue.mpg.de) dataset is a large database of human motion that unifies different optical marker-based motion capture datasets by representing them within a common framework and parameterization.
8 |
9 | Except for a part of HumanML3D which is based on [HumanAct12](https://ericguo5513.github.io/action-to-motion/) (itself based on [PhSPD](https://drive.google.com/drive/folders/1ZGkpiI99J-4ygD9i3ytJdmyk_hkejKCd?usp=sharing)), almost all the motion data of KIT-ML, HumanML3D and BABEL is included in AMASS.
10 |
11 | Currently, the text-to-motion datasets are not compatible in terms of motion representation:
12 | - KIT-ML uses [Master Motor Map](https://mmm.humanoids.kit.edu) (robot-like joints)
13 | - HumanML3D takes motions from AMASS, extracts joints using the SMPL layer, rotates the joints (making Y the gravity axis), crops the motions, makes all skeletons similar to a reference, and computes motion features.
14 | - BABEL uses raw SMPL parameters from AMASS
15 |
16 | To be able to use any text-to-motion dataset with the same representation, I propose in [this repo](https://github.com/Mathux/AMASS-Annotation-Unifier) to unify the datasets into a common annotation format. With the agreement of the authors, I included the annotation files in the TMR repo, in this folder: [datasets/annotations](datasets/annotations) (for BABEL, please follow the instructions). For each dataset, I provide a .json file with:
17 | - The ID of the motion (as found in the original dataset)
18 | - The path of the motion in AMASS (or HumanAct12)
19 | - The duration in seconds
20 | - A list of annotations which contains:
21 | - An ID
22 | - The corresponding text
23 | - The start and end in seconds
24 |
25 | Like this one:
26 |
27 | ```json
28 | {
29 | "000000": {
30 | "path": "KIT/3/kick_high_left02_poses",
31 | "duration": 5.82,
32 | "annotations": [
33 | {
34 | "seg_id": "000000_0",
35 | "text": "a man kicks something or someone with his left leg.",
36 | "start": 0.0,
37 | "end": 5.82
38 | },
39 | ...
40 | ```
41 |
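For reference, here is a minimal sketch of loading and iterating one of these unified annotation files; ``orjson`` is what the repo uses elsewhere, but the standard ``json`` module works just as well:

```python
import orjson

# Load the unified annotations of a dataset (format shown above)
with open("datasets/annotations/humanml3d/annotations.json", "rb") as f:
    annotations = orjson.loads(f.read())

# Each key is a motion ID; each entry stores the AMASS (or HumanAct12) path,
# the duration in seconds, and a list of annotated text segments.
key, entry = next(iter(annotations.items()))
print(key, entry["path"], entry["duration"])
for ann in entry["annotations"]:
    print(ann["seg_id"], ann["start"], ann["end"], ann["text"])
```
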
42 | We are now free to use any motion representation.
43 |
44 |
45 | ### Motion representation
46 |
47 | Guo et al. use a motion representation which includes rotation-invariant forward kinematics features, 3D rotations, velocities and foot contacts. Currently, many works in 3D motion generation use these features. However, these features are not the same for HumanML3D and KIT-ML (different number of joints, different scale, different reference skeleton, etc.).
48 |
49 | To let people use TMR as an evaluator and be comparable with the Guo et al. feature extractor, I propose to process the whole AMASS (+HumanAct12) dataset into the HumanML3D Guo features (which I refer to as ``guoh3dfeats`` in the code). Then, we can crop each feature file according to any dataset. I also included the mirrored version of each motion.
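As an illustration, here is a minimal sketch of cropping a precomputed ``guoh3dfeats`` file to an annotated segment; it assumes the 20 fps frame rate and 263-dimensional features used in this repo, and the file path and segment times are placeholders:

```python
import numpy as np

fps = 20.0  # guoh3dfeats are sampled at 20 fps (263 features per frame)
start, end = 0.0, 5.82  # segment boundaries in seconds, taken from the annotation

# Features computed on the whole (untrimmed) AMASS sequence
feats = np.load("datasets/motions/guoh3dfeats/KIT/3/kick_high_left02_poses.npy")

# Crop the features to the annotated segment
segment = feats[int(start * fps) : int(end * fps)]
print(segment.shape)  # (n_frames, 263)
```
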
50 |
51 | ### Differences with the released version of HumanML3D
52 | For motions shorter than 10s, this process gives exactly the feature files of HumanML3D (for example "000000.npy").
53 | As a sanity check, you can verify in Python that both .npy files correspond to the same data:
54 |
55 | ```python
56 | import numpy as np
57 | new = np.load("datasets/motions/guoh3dfeats/humanact12/humanact12/P11G01R02F1812T1847A0402.npy")
58 | old = np.load("/path/to/HumanML3D/HumanML3D/new_joint_vecs/000001.npy")
59 | assert np.abs(new - old).mean() < 1e-10
60 | ```
61 |
62 | For motions longer than 10s which are cropped (like "000004.npy"), cropping the features gives slightly different results than computing the features of the cropped motion. That is because the ``uniform skeleton`` function takes the first frame as reference to compute bone lengths. However, the difference is quite small.
63 |
64 | ### Installation
65 | Go to the section "Installation - Set up the datasets" of the [README.md](README.md) to compute the features.
66 |
67 |
68 | ## Credits
69 | For all the datasets, be sure to read and follow their license agreements, and cite them accordingly.
70 |
71 | ### KIT-ML
72 | ```bibtex
73 | @article{Plappert2016,
74 | author = {Matthias Plappert and Christian Mandery and Tamim Asfour},
75 | title = {The {KIT} Motion-Language Dataset},
76 |   journal = {Big Data},
77 | year = 2016
78 | }
79 | ```
80 |
81 | ### HumanML3D
82 | ```bibtex
83 | @inproceedings{Guo_2022_CVPR,
84 | author = {Guo, Chuan and Zou, Shihao and Zuo, Xinxin and Wang, Sen and Ji, Wei and Li, Xingyu and Cheng, Li},
85 | title = {Generating Diverse and Natural 3D Human Motions From Text},
86 | booktitle = {Computer Vision and Pattern Recognition (CVPR)},
87 | year = 2022
88 | }
89 | ```
90 |
91 | ### BABEL
92 | ```bibtex
93 | @inproceedings{BABEL:CVPR:2021,
94 | title = {{BABEL}: Bodies, Action and Behavior with English Labels},
95 | author = {Punnakkal, Abhinanda R. and Chandrasekaran, Arjun and Athanasiou, Nikos and Quiros-Ramirez, Alejandra and Black, Michael J.},
96 | booktitle = {Computer Vision and Pattern Recognition (CVPR)},
97 | year = 2021
98 | }
99 | ```
100 |
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Léore Bensabath
4 |
5 | TMR LICENCE
6 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # TMR++
4 | ## A Cross-Dataset Study for Text-based 3D Human Motion Retrieval
5 |
6 | Léore Bensabath
7 | ·
8 | Mathis Petrovich
9 | ·
10 | Gül Varol
11 |
12 |
13 | [arXiv](https://arxiv.org/abs/2405.16909)
14 |
15 |
16 |
17 |
18 | ## Description
19 | Official PyTorch implementation of the paper:
20 |
21 |
22 | [**A Cross-Dataset Study for Text-based 3D Human Motion Retrieval**](https://arxiv.org/abs/2405.16909)
23 |
24 |
25 |
26 | This repo is based on the implementation of
27 | [**TMR: Text-to-Motion Retrieval Using Contrastive 3D Human Motion Synthesis**](https://github.com/Mathux/TMR/tree/master).
28 |
29 | Please visit our [**webpage**](https://imagine.enpc.fr/~leore.bensabath/TMR++) for more details.
30 |
31 | ### Bibtex
32 | If you find this code useful in your research, please cite:
33 |
34 | ```bibtex
35 | @inproceedings{lbensabath2024,
36 | title={TMR++: A Cross-Dataset Study for Text-based 3D Human Motion Retrieval},
37 | author={Bensabath, Léore and Petrovich, Mathis and Varol, G{\"u}l},
38 |   booktitle={CVPRW HuMoGen},
39 | year={2024}
40 | }
41 | ```
42 | and
43 | ```bibtex
44 | @inproceedings{petrovich23tmr,
45 | title = {{TMR}: Text-to-Motion Retrieval Using Contrastive {3D} Human Motion Synthesis},
46 | author = {Petrovich, Mathis and Black, Michael J. and Varol, G{\"u}l},
47 | booktitle = {International Conference on Computer Vision ({ICCV})},
48 | year = 2023
49 | }
50 | ```
51 |
52 | You can also give the repo a star :star: if the code is useful to you.
53 |
54 | ## Installation :construction_worker:
55 |
56 | ### Create environment
57 |
58 |
59 | Create a Python virtual environment:
60 | ```bash
61 | python -m venv ~/.venv/TMR
62 | source ~/.venv/TMR/bin/activate
63 | ```
64 |
65 | Install [PyTorch](https://pytorch.org/get-started/locally/)
66 | ```bash
67 | python -m pip install torch torchvision --index-url https://download.pytorch.org/whl/cu118
68 | ```
69 |
70 | Then install remaining packages:
71 | ```
72 | python -m pip install -r requirements.txt
73 | ```
74 |
75 | which corresponds to the packages: pytorch_lightning, einops, hydra-core, hydra-colorlog, orjson, tqdm, scipy.
76 | The code was tested on Python 3.10.12 and PyTorch 2.0.1.
77 |
78 |
79 |
80 | ### Set up the datasets
81 |
82 |
83 | Please first set up the datasets as explained in the same README section of https://github.com/Mathux/TMR/tree/master.
84 |
85 | In this repo, we provide the augmented versions of the humanml3d, kitml and babel datasets.
86 | For a given dataset ($DATASET), up to 3 new annotation files have been created:
87 | - ``datasets/annotations/$DATASET/annotations_paraphrases.json``: Includes all the paraphrases generated by an LLM
88 | - ``datasets/annotations/$DATASET/annotations_actions.json``: For humanml3d and kitml only, includes the action-type labels generated by an LLM
89 | - ``datasets/annotations/$DATASET/annotations_all.json``: Includes a concatenation, by key ID, of all the annotations (original and LLM-generated)
90 |
91 | Copy the data into your repo from [here](https://drive.google.com/drive/u/1/folders/1_SpOgtYCZBPAXoVz00Zhyk6tPRObUIiW)
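Once copied, you can sanity-check the augmented files. The following minimal sketch assumes that ``annotations_all.json`` follows the same key/``annotations`` structure as the original ``annotations.json`` (described in [DATASETS.md](DATASETS.md)):

```python
import json

dataset = "humanml3d"  # placeholder: any dataset with augmented annotations
with open(f"datasets/annotations/{dataset}/annotations.json") as f:
    original = json.load(f)
with open(f"datasets/annotations/{dataset}/annotations_all.json") as f:
    augmented = json.load(f)

# The augmented file should contain at least as many text segments per motion
key = next(iter(original))
print(len(original[key]["annotations"]), "original segments")
print(len(augmented[key]["annotations"]), "segments including LLM-generated texts")
```
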
92 |
93 | ### Compute the text embeddings for the data with text augmentation
94 |
95 | Run this command to compute the sentence embeddings and token embeddings for the annotations with text augmentation:
96 |
97 | ```
98 | python -m prepare.text_embeddings --config-name=text_embeddings_with_augmentation data=$DATASET
99 | ```
100 |
101 |
102 | ### Combine datasets
103 |
104 |
105 | To create a combination of any of the datasets, run:
106 |
107 | ```bash
108 | python -m prepare.combine_datasets datasets=$DATASETS test_sets=$TEST_DATASETS split_suffix=$SPLIT_SUFFIX [OPTIONS]
109 | ```
110 | Where:
111 | - ``datasets``: The list of datasets to combine
112 | - ``test_sets``: The list of datasets the combination is intended to be tested on. When generating the split files, this filters out of the training set any samples from the training datasets that overlap with samples from one of the provided test datasets.
113 | Note that you can create different splits for different intended test sets by leveraging the **split_suffix** parameter. The annotations file for a given combination stays the same regardless of the **test_sets** value.
114 | - ``split_suffix``: The split file suffix for this given combination of test sets. Training and validation split files will be saved under: ``datasets/annotations/splits/train{split_suffix}.txt`` and ``datasets/annotations/splits/val{split_suffix}.txt``
115 |
116 | The new dataset will be created inside the folder ``datasets/annotations/{dataset1}_{dataset2}(_{dataset3})``.
117 |
118 | **Example:**
119 | ```bash
120 | python -m prepare.combine_datasets datasets=["humanml3d","kitml"] test_sets=["babel"] split_suffix="_wo_hkb"
121 | ```
122 |
123 | Then run the ``python -m prepare.text_embeddings`` command, with or without text augmentation, on your new dataset combination.
124 |
125 | **Example:**
126 | ```bash
127 | python -m prepare.text_embeddings --config-name=text_embeddings_with_augmentation data=humanml3d_kitml
128 | ```
129 |
130 |
131 | ## Training :rocket:
132 |
133 | ### Training with a combination of datasets
134 |
135 | To train with a combination of datasets without any text augmentation, run the same command as in TMR with the relevant dataset name:
136 |
137 | **Example:**
138 | ```bash
139 | python train.py data=humanml3d_kitml
140 | ```
141 |
142 | ### Training with text augmentation
143 |
144 | ```bash
145 | python train.py --config-name=train_with_augmentation data=$DATASET
146 | ```
147 |
148 |
149 | In addition to the parameters available in TMR, you can modify the text augmentation picking probabilities detailed in the paper:
150 | **Example:**
151 | ```bash
152 | python train.py --config-name=train_with_augmentation data=humanml3d data.paraphrase_prob=0.2 data.summary_prob=0.2 data.averaging_prob=0.3 run_dir=outputs/tmr_humanml3d_w_textAugmentation_0.2_0.2_0.3
153 | ```
154 |
155 |
156 | ### Extracting weights
157 | After training, run the following command to extract the weights from the checkpoint:
158 |
159 | ```bash
160 | python extract.py run_dir=$RUN_DIR
161 | ```
162 |
163 | It will take the last checkpoint by default. This should create the folder ``RUN_DIR/last_weights`` and populate it with the files: ``motion_decoder.pt``, ``motion_encoder.pt`` and ``text_encoder.pt``.
164 | This process makes loading models faster, no longer depends on the file structure, and lets each module be loaded independently. This is already done for the pretrained models.
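For reference, here is a minimal sketch of inspecting one of the extracted files; it assumes the ``.pt`` files are regular PyTorch checkpoints (saved with ``torch.save``) and uses a placeholder run directory:

```python
import torch

# Placeholder path: replace with your own RUN_DIR
weights = torch.load(
    "outputs/tmr_humanml3d_guoh3dfeats/last_weights/text_encoder.pt",
    map_location="cpu",
)

# If the checkpoint is a state dict, list parameter names and shapes
if isinstance(weights, dict):
    for name, value in weights.items():
        print(name, getattr(value, "shape", type(value)))
else:
    print(type(weights))
```
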
165 |
166 |
167 |
168 |
169 | ## Pretrained models :dvd:
170 |
171 | You can find the different models used in the paper here:
172 | [pre-trained models](https://drive.google.com/drive/u/1/folders/1otB-B4m4okpD_0crGMcpg0hOsSRYH45t)
173 |
174 |
175 | ## Evaluation :bar_chart:
176 |
177 | ### Motion to text / Text to motion retrieval
178 |
179 | ```bash
180 | python retrieval.py run_dir=$RUN_DIR data=$DATA
181 | ```
182 |
183 | ### Action recognition
184 |
185 | For action recognition on datasets babel_actions_60 and babel_actions_120, run:
186 |
187 | ```bash
188 | python retrieval_action_multi_labels.py run_dir=$RUN_DIR data=$DATA
189 | ```
190 |
191 |
192 | It will compute the metrics, display them, and save them in the folder ``RUN_DIR/contrastive_metrics_$DATA/``.
193 | You can change the name of the output file using the ``save_file_name`` argument.
194 |
195 |
196 | ## Usage :computer:
197 |
198 | ### Encode a motion
199 | Note that the .npy file should correspond to HumanML3D Guo features.
200 |
201 | ```bash
202 | python encode_motion.py run_dir=RUN_DIR npy=/path/to/motion.npy
203 | ```
204 |
205 | ### Encode a text
206 |
207 | ```bash
208 | python encode_text.py run_dir=RUN_DIR text="A person is walking forward."
209 | ```
210 |
211 | ### Compute similarity between text and motion
212 | ```bash
213 | python text_motion_sim.py run_dir=RUN_DIR text=TEXT npy=/path/to/motion.npy
214 | ```
215 | For example with ``text="a man sets to do a backflips then fails back flip and falls to the ground"`` and ``npy=HumanML3D/HumanML3D/new_joint_vecs/001034.npy`` you should get around 0.96.
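The score is the cosine similarity between the text latent and the motion latent produced by the two encoders. Here is a minimal sketch of that computation, assuming you already have the two embeddings as 1D NumPy arrays (the file names are placeholders):

```python
import numpy as np

# Placeholder files: latent vectors produced by the text and motion encoders
text_latent = np.load("text_latent.npy")
motion_latent = np.load("motion_latent.npy")

# Cosine similarity between the two latent vectors
score = float(
    np.dot(text_latent, motion_latent)
    / (np.linalg.norm(text_latent) * np.linalg.norm(motion_latent))
)
print(f"text-motion similarity: {score:.2f}")
```
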
216 |
217 |
218 | ## Launch the demo
219 |
220 | ### Encode the whole motion dataset
221 | ```bash
222 | python encode_dataset.py run_dir=RUN_DIR
223 | ```
224 |
225 |
226 | ### Text-to-motion retrieval demo
227 | Run this command:
228 |
229 | ```bash
230 | python app.py
231 | ```
232 |
233 | and then open your web browser at the address: ``http://localhost:7860``.
234 |
235 | ## License :books:
236 | This code is distributed under an [MIT LICENSE](LICENSE).
237 |
238 | Note that our code depends on other libraries, including PyTorch, PyTorch3D, Hugging Face and Hydra, and uses datasets that each have their own respective licenses that must also be followed.
239 |
--------------------------------------------------------------------------------
/configs/combine_datasets.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - defaults
3 | - _self_
4 |
5 | annotations_path: datasets/annotations
6 |
7 | datasets:
8 | - humanml3d
9 | - kitml
10 |
11 | test_sets:
12 | - humanml3d
13 | - kitml
14 |
15 | filter_babel_seg: False
16 |
17 | split_suffix: ''
18 |
19 | min_duration: null
20 | max_duration: null
21 |
--------------------------------------------------------------------------------
/configs/compute_guoh3dfeats.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - defaults
3 | - _self_
4 |
5 | base_folder: datasets/motions/pose_data
6 | output_folder: datasets/motions/guoh3dfeats
7 |
8 | force_redo: false # true to recompute the features
9 |
--------------------------------------------------------------------------------
/configs/data/_base.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - motion_loader: guoh3dfeats
3 | - _self_
4 |
5 | _target_: src.data.text_motion.TextMotionDataset
6 |
7 | path: datasets/annotations/${hydra:runtime.choices.data}
8 |
9 | text_to_token_emb:
10 | _target_: src.data.text.TokenEmbeddings
11 | path: datasets/annotations/${hydra:runtime.choices.data}
12 | modelname: distilbert-base-uncased
13 | modelpath: null
14 | preload: true
15 |
16 | text_to_sent_emb:
17 | _target_: src.data.text.SentenceEmbeddings
18 | path: datasets/annotations/${hydra:runtime.choices.data}
19 | modelname: sentence-transformers/all-mpnet-base-v2
20 | modelpath: null
21 | preload: true
22 |
23 | preload: true
24 |
--------------------------------------------------------------------------------
/configs/data/_base_augmented.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - motion_loader: guoh3dfeats
3 | - _base
4 | - _self_
5 |
6 | _target_: src.data.augmented_text_motion.AugmentedTextMotionDataset
7 |
8 | paraphrase_filename: annotations_paraphrased.json
9 | summary_filename: annotations_summarized.json
10 | paraphrase_prob: 0.2
11 | summary_prob: 0.2
12 | averaging_prob: 0.4
13 | text_sampling_nbr: null
14 |
15 | text_to_token_emb:
16 | name: token_embeddings_all # TODO
17 |
18 | text_to_sent_emb:
19 | name: sent_embeddings_all
20 |
--------------------------------------------------------------------------------
/configs/data/babel.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - _base
3 | - _self_
4 |
--------------------------------------------------------------------------------
/configs/data/babel_actions_120.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - _base
3 | - _self_
4 |
--------------------------------------------------------------------------------
/configs/data/babel_actions_60.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - _base
3 | - _self_
4 |
--------------------------------------------------------------------------------
/configs/data/humanml3d.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - _base
3 | - _self_
4 |
--------------------------------------------------------------------------------
/configs/data/humanml3d_kitml.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - _base
3 | - _self_
4 |
--------------------------------------------------------------------------------
/configs/data/humanml3d_kitml_babel.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - _base
3 | - _self_
4 |
--------------------------------------------------------------------------------
/configs/data/kitml.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - _base
3 | - _self_
4 |
--------------------------------------------------------------------------------
/configs/data/motion_loader/_base.yaml:
--------------------------------------------------------------------------------
1 | _target_: src.data.motion.AMASSMotionLoader
2 |
3 | base_dir: ???
4 |
5 | normalizer:
6 | _target_: src.data.motion.Normalizer
7 | base_dir: stats/${hydra:runtime.choices.data}/${hydra:runtime.choices.data/motion_loader}
8 | eps: 1e-12
9 |
--------------------------------------------------------------------------------
/configs/data/motion_loader/guoh3dfeats.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - _base
3 | - _self_
4 |
5 | base_dir: datasets/motions/guoh3dfeats
6 | fps: 20.0
7 | nfeats: 263
8 |
--------------------------------------------------------------------------------
/configs/debug/profiler.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | run_dir: debug/
4 |
5 | trainer:
6 | max_epochs: 1
7 | check_val_every_n_epoch: 1
8 | callbacks: null
9 | profiler: simple
10 |
--------------------------------------------------------------------------------
/configs/debug/train.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | run_dir: debug/
4 |
5 | data:
6 | tiny: true
7 | preload: false
8 |
9 | dataloader:
10 | num_workers: 0
11 | shuffle: false
12 |
13 | trainer:
14 | enable_model_summary: false
15 | max_epochs: 3
16 | check_val_every_n_epoch: 1
17 |
--------------------------------------------------------------------------------
/configs/defaults.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | run_dir: logs
4 |
5 | hydra:
6 | run:
7 | dir: ${run_dir}
8 |
9 | seed: 1234
10 | logger_level: INFO
11 |
12 |
13 | defaults:
14 | - _self_
15 | - override hydra/job_logging: tqdm
16 | - override hydra/hydra_logging: tqdm
17 |
--------------------------------------------------------------------------------
/configs/encode_dataset.yaml:
--------------------------------------------------------------------------------
1 | dataloader:
2 | _target_: torch.utils.data.DataLoader
3 | batch_size: 32
4 | num_workers: 8
5 | shuffle: true
6 |
7 | defaults:
8 | - data: humanml3d
9 | - defaults
10 | - _self_
11 |
12 | run_dir: ???
13 |
14 | ckpt_name: last
15 | device: cuda
16 |
--------------------------------------------------------------------------------
/configs/encode_motion.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - defaults
3 | - _self_
4 |
5 | run_dir: ???
6 | npy: ???
7 |
8 | start: !!null
9 | end: !!null
10 |
11 | ckpt_name: last
12 | device: cuda
13 |
--------------------------------------------------------------------------------
/configs/encode_text.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - defaults
3 | - _self_
4 |
5 | run_dir: ???
6 | text: ???
7 |
8 | ckpt_name: last
9 | device: cuda
10 |
--------------------------------------------------------------------------------
/configs/extract.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - defaults
3 | - _self_
4 |
5 | run_dir: ???
6 | ckpt: last
7 |
--------------------------------------------------------------------------------
/configs/hydra/hydra_logging/tqdm.yaml:
--------------------------------------------------------------------------------
1 | version: 1
2 |
3 | formatters:
4 | verysimple:
5 | format: '%(message)s'
6 |
7 | handlers:
8 | console:
9 | class: src.logging.TqdmLoggingHandler
10 | formatter: verysimple
11 |
12 | root:
13 | level: ${logger_level}
14 | handlers: [console]
15 |
16 |
17 | disable_existing_loggers: false
18 |
--------------------------------------------------------------------------------
/configs/hydra/job_logging/tqdm.yaml:
--------------------------------------------------------------------------------
1 | version: 1
2 |
3 | formatters:
4 | simple:
5 | format: '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s'
6 | datefmt: '%d/%m/%y %H:%M:%S'
7 |
8 | colorlog:
9 | (): colorlog.ColoredFormatter
10 | format: '[%(white)s%(asctime)s%(reset)s] %(log_color)s%(levelname)s%(reset)s %(message)s'
11 | datefmt: '%d/%m/%y %H:%M:%S'
12 |
13 | log_colors:
14 | DEBUG: purple
15 | INFO: blue
16 | WARNING: yellow
17 | ERROR: red
18 | CRITICAL: red
19 |
20 | handlers:
21 | console:
22 | class: src.logging.TqdmLoggingHandler
23 | formatter: colorlog
24 | file_out:
25 | class: logging.FileHandler
26 | formatter: simple
27 | filename: ${run_dir}/${hydra.job.name}.out
28 |
29 | root:
30 | level: ${logger_level}
31 | handlers: [console, file_out]
32 |
33 | disable_existing_loggers: false
34 |
--------------------------------------------------------------------------------
/configs/load_model.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - defaults
3 | - _self_
4 |
5 | run_dir: ???
6 | ckpt: last
7 | device: cuda
8 | eval_mode: true
9 |
--------------------------------------------------------------------------------
/configs/model/temos.yaml:
--------------------------------------------------------------------------------
1 | _target_: src.model.TEMOS
2 |
3 | motion_encoder:
4 | _target_: src.model.ACTORStyleEncoder
5 | nfeats: ${data.motion_loader.nfeats}
6 | vae: true
7 | latent_dim: 256
8 | ff_size: 1024
9 | num_layers: 6
10 | num_heads: 4
11 | dropout: 0.1
12 | activation: gelu
13 |
14 | text_encoder:
15 | _target_: src.model.ACTORStyleEncoder
16 | nfeats: 768
17 | vae: true
18 | latent_dim: 256
19 | ff_size: 1024
20 | num_layers: 6
21 | num_heads: 4
22 | dropout: 0.1
23 | activation: gelu
24 |
25 | motion_decoder:
26 | _target_: src.model.ACTORStyleDecoder
27 | nfeats: ${data.motion_loader.nfeats}
28 | latent_dim: 256
29 | ff_size: 1024
30 | num_layers: 6
31 | num_heads: 4
32 | dropout: 0.1
33 | activation: gelu
34 |
35 | vae: true
36 |
37 | lmd:
38 | recons: 1.0
39 | latent: 1.0e-5
40 | kl: 1.0e-5
41 |
42 | lr: 1e-4
43 |
--------------------------------------------------------------------------------
/configs/model/tmr.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - temos
3 | - _self_
4 |
5 | _target_: src.model.TMR
6 |
7 | lmd:
8 | recons: 1.0
9 | latent: 1.0e-5
10 | kl: 1.0e-5
11 | contrastive: 0.1
12 |
13 | lr: 1e-4
14 | threshold_selfsim_metrics: 0.95
15 |
16 | contrastive_loss:
17 | _target_: src.model.losses.InfoNCE_with_filtering
18 | temperature: 0.1
19 | threshold_selfsim: 0.80
20 |
--------------------------------------------------------------------------------
/configs/model/tmr_hn.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - tmr
3 | - _self_
4 |
5 | contrastive_loss:
6 | _target_: src.model.losses.HN_InfoNCE_with_filtering
7 | temperature: 0.1
8 | threshold_selfsim: 0.80
9 | alpha: 0.999
10 | beta: 0.5
11 |
--------------------------------------------------------------------------------
/configs/model/tmr_text_averaging.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - tmr
3 | - _self_
4 |
5 | _target_: src.model.tmr_text_averaging.TMRTextAveraging
6 |
--------------------------------------------------------------------------------
/configs/model/tmr_text_averaging_hn.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - tmr_hn
3 | - _self_
4 |
5 | _target_: src.model.tmr_text_averaging.TMRTextAveraging
6 |
7 | contrastive_loss:
8 | _target_: src.model.losses.HN_InfoNCE_with_filtering
9 | temperature: 0.1
10 | threshold_selfsim: 0.80
11 | alpha: 0.999
12 | beta: 0.5
13 |
--------------------------------------------------------------------------------
/configs/motion_stats.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - data: humanml3d
3 | - defaults
4 | - _self_
5 |
6 | data:
7 | preload: false
8 | motion_loader:
9 | normalizer:
10 | disable: true
11 | text_to_token_emb:
12 | disable: true
13 | text_to_sent_emb:
14 | disable: true
15 |
--------------------------------------------------------------------------------
/configs/render.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - renderer: matplotlib
3 | - defaults
4 | - _self_
5 |
6 | npy_path: ???
7 | title: ""
8 |
9 | swap_axis: false
10 | guofeats: false
11 | rifkefeats: false
12 |
13 | renderer:
14 | canonicalize: true
15 |
--------------------------------------------------------------------------------
/configs/renderer/matplotlib.yaml:
--------------------------------------------------------------------------------
1 | _target_: src.renderer.matplotlib.MatplotlibRender
2 |
3 | jointstype: "guoh3djoints"
4 | fps: 20.0
5 | colors: ['black', 'magenta', 'red', 'green', 'blue']
6 | figsize: 4
7 | canonicalize: true
8 |
--------------------------------------------------------------------------------
/configs/retrieval.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - defaults
3 | - _self_
4 | - data: humanml3d
5 |
6 | device: cuda
7 |
8 | run_dir: ???
9 | save_file_name: contrastive_metrics_${hydra:runtime.choices.data}
10 | protocol: all # all (run all 4), normal (a), threshold (b), nsim (c), guo (d)
11 | threshold: 0.95 # threshold used to compute (b)
12 |
13 | ckpt: last
14 | batch_size: 256
15 |
16 | split: test
17 |
--------------------------------------------------------------------------------
/configs/retrieval_action_multi_labels.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - data: babel_actions_120
3 | - defaults
4 | - _self_
5 |
6 | device: cuda
7 |
8 | run_dir: ???
9 | save_file_name: contrastive_metrics_${hydra:runtime.choices.data}
10 |
11 | ckpt: last
12 | batch_size: 256
13 |
14 | split: test
15 |
16 | data:
17 | _target_: src.data.text_motion_multi_labels.TextMotionMultiLabelsDataset
18 | tiny: False
19 |
20 | text_to_token_emb:
21 | name: token_embeddings
22 |
23 | text_to_sent_emb:
24 | name: sent_embeddings
25 |
--------------------------------------------------------------------------------
/configs/text_dataset_sim.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - data: humanml3d
3 | - defaults
4 | - _self_
5 |
6 | run_dir: outputs/tmr_humanml3d
7 | feats: false
8 | text: ???
9 | split: train
10 |
11 | ckpt_name: last
12 | device: cuda
13 |
14 | data:
15 | preload: false
16 | text_to_token_emb:
17 | preload: false
18 | disable: true
19 | text_to_sent_emb:
20 | preload: false
21 | disable: true
22 |
--------------------------------------------------------------------------------
/configs/text_embeddings.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - data: humanml3d
3 | - defaults
4 | - _self_
5 |
6 | device: cuda
7 |
8 | annotations_filename: annotations.json
9 | output_folder_name_token: token_embeddings
10 | output_folder_name_sent: sent_embeddings
11 |
--------------------------------------------------------------------------------
/configs/text_embeddings_with_augmentation.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - data: humanml3d
3 | - defaults
4 | - _self_
5 |
6 | device: cuda
7 |
8 | annotations_filename: annotations_all.json
9 | output_folder_name_token: token_embeddings_all
10 | output_folder_name_sent: sent_embeddings_all
11 |
--------------------------------------------------------------------------------
/configs/text_motion_sim.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - defaults
3 | - _self_
4 |
5 | run_dir: ???
6 | npy: ???
7 | feats: false
8 | text: ???
9 |
10 | ckpt_name: last
11 | device: cuda
12 |
--------------------------------------------------------------------------------
/configs/train.yaml:
--------------------------------------------------------------------------------
1 | ckpt: last
2 | resume_dir: null
3 | ckpt_path: null
4 |
5 | run_dir: outputs/${hydra:runtime.choices.model}_${hydra:runtime.choices.data}_${hydra:runtime.choices.data/motion_loader}
6 |
7 | dataloader:
8 | _target_: torch.utils.data.DataLoader
9 | batch_size: 32
10 | num_workers: 8
11 |
12 | defaults:
13 | - data: humanml3d
14 | - data_val: null
15 | - model: tmr
16 | - trainer
17 | - defaults
18 | - _self_
19 |
--------------------------------------------------------------------------------
/configs/train_hn.yaml:
--------------------------------------------------------------------------------
1 | run_dir: outputs/${hydra:runtime.choices.model}_${hydra:runtime.choices.data}_${hydra:runtime.choices.data/motion_loader}
2 |
3 | defaults:
4 | - train
5 | - override /model: tmr_hn
6 | - _self_
7 |
--------------------------------------------------------------------------------
/configs/train_hn_with_augmentation.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - train_with_augmentation
3 | - override model: tmr_text_averaging_hn
4 | - _self_
5 |
--------------------------------------------------------------------------------
/configs/train_with_augmentation.yaml:
--------------------------------------------------------------------------------
1 | run_dir: outputs/${hydra:runtime.choices.model}_${hydra:runtime.choices.data}_augmented_${hydra:runtime.choices.data/motion_loader}
2 |
3 | defaults:
4 | - train
5 | - data_val: null
6 | - override data: humanml3d_kitml
7 | - override model: tmr_text_averaging
8 | - _self_
9 |
10 | data:
11 | _target_: src.data.augmented_text_motion.AugmentedTextMotionDataset
12 | paraphrase_filename: annotations_paraphrases.json
13 | summary_filename: annotations_actions.json
14 | paraphrase_prob: 0.2
15 | summary_prob: 0.1
16 | averaging_prob: 0.3
17 | preload: True
18 | text_sampling_nbr: null
19 |
20 | text_to_token_emb:
21 | name: token_embeddings_all
22 |
23 | text_to_sent_emb:
24 | name: sent_embeddings_all
25 |
26 |
--------------------------------------------------------------------------------
/configs/trainer.yaml:
--------------------------------------------------------------------------------
1 | trainer:
2 | _target_: pytorch_lightning.Trainer
3 |
4 | max_epochs: 500
5 | log_every_n_steps: 50
6 | num_sanity_val_steps: 0
7 | check_val_every_n_epoch: 1
8 | accelerator: gpu
9 | devices: 1
10 |
11 | callbacks:
12 | - _target_: pytorch_lightning.callbacks.ModelCheckpoint
13 | filename: latest-{epoch}
14 | every_n_epochs: 1
15 | save_top_k: 1
16 | save_last: true
17 | - _target_: pytorch_lightning.callbacks.ModelCheckpoint
18 | filename: latest-{epoch}
19 | monitor: step
20 | mode: max
21 | every_n_epochs: 100
22 | save_top_k: -1
23 | save_last: false
24 | - _target_: src.callback.progress.ProgressLogger
25 | precision: 3
26 | - _target_: src.callback.tqdmbar.TQDMProgressBar
27 |
28 | logger:
29 | _target_: src.logger.csv.CSVLogger
30 | save_dir: ${run_dir}
31 | name: logs
32 |
--------------------------------------------------------------------------------
/datasets/annotations/humanml3d/splits/nsim_test.txt:
--------------------------------------------------------------------------------
1 | 007742
2 | 005935
3 | 008597
4 | 010546
5 | 005697
6 | 005668
7 | 000565
8 | 000119
9 | 008877
10 | 006058
11 | 009566
12 | 005180
13 | 012578
14 | 011493
15 | 005946
16 | 006630
17 | 008900
18 | 004601
19 | 002459
20 | 006186
21 | 000552
22 | 005674
23 | 004545
24 | 001052
25 | 002635
26 | 005672
27 | 011004
28 | 013440
29 | 009455
30 | 003463
31 | 000824
32 | 006549
33 | 007655
34 | 006762
35 | 012222
36 | 012655
37 | 012956
38 | 004973
39 | 013403
40 | 008730
41 | 003439
42 | 008824
43 | 008340
44 | 010823
45 | 007806
46 | 013898
47 | 004996
48 | 010384
49 | 004344
50 | 005048
51 | 001152
52 | 012568
53 | M008664
54 | M007889
55 | M000389
56 | M011343
57 | M012558
58 | M010392
59 | M014283
60 | M001538
61 | M011643
62 | M003677
63 | M011972
64 | M009880
65 | M013023
66 | M012399
67 | M002761
68 | M014109
69 | M004319
70 | M001648
71 | M013778
72 | M008383
73 | M000178
74 | M009148
75 | M006433
76 | M011569
77 | M001577
78 | M008275
79 | M012813
80 | M012084
81 | M009123
82 | M000179
83 | M012639
84 | M010671
85 | M008583
86 | M000972
87 | M008349
88 | M002824
89 | M003301
90 | M008490
91 | M003902
92 | M002252
93 | M008668
94 | M000903
95 | M003689
96 | M003373
97 | M010964
98 | M001193
99 | M006533
100 | M014384
101 |
--------------------------------------------------------------------------------
/datasets/annotations/humanml3d/splits/test_tiny.txt:
--------------------------------------------------------------------------------
1 | 000000
2 | 000019
3 | 000021
4 | 000022
5 | 000026
6 | 000048
7 | 000055
8 | 000063
9 | 000066
10 | 000067
11 |
--------------------------------------------------------------------------------
/datasets/annotations/humanml3d/splits/train_tiny.txt:
--------------------------------------------------------------------------------
1 | 000001
2 | 000002
3 | 000003
4 | 000004
5 | 000005
6 | 000006
7 | 000007
8 | 000008
9 | 000009
10 | 000010
11 |
--------------------------------------------------------------------------------
/datasets/annotations/humanml3d/splits/val_tiny.txt:
--------------------------------------------------------------------------------
1 | 012698
2 | 012808
3 | 008646
4 | 013022
5 | 003172
6 | 008859
7 | 005095
8 | 012044
9 | 002345
10 | 008039
11 |
--------------------------------------------------------------------------------
/datasets/annotations/kitml/splits/nsim_test.txt:
--------------------------------------------------------------------------------
1 | 00355
2 | 01496
3 | 01344
4 | 03107
5 | 02245
6 | 01109
7 | 01349
8 | 03128
9 | 02906
10 | 03861
11 | 03670
12 | 00989
13 | 02070
14 | 02971
15 | 01727
16 | 03045
17 | 02510
18 | 01391
19 | 01532
20 | 03771
21 | 03036
22 | 01472
23 | 03691
24 | 01463
25 | 00978
26 | 01639
27 | 02407
28 | 03917
29 | 01450
30 | 01854
31 | 00594
32 | 00736
33 | 03087
34 | 01440
35 | 02021
36 | 01444
37 | 03683
38 | 01367
39 | 00390
40 | 03577
41 | 01485
42 | 02148
43 | 03190
44 | 01223
45 | 03215
46 | 03098
47 | 02139
48 | 02435
49 | 03532
50 | M00355
51 | M01496
52 | M02751
53 | M01344
54 | M03107
55 | M02245
56 | M00452
57 | M02556
58 | M01109
59 | M01027
60 | M01349
61 | M03128
62 | M02906
63 | M03861
64 | M03670
65 | M00989
66 | M02070
67 | M02971
68 | M01727
69 | M03045
70 | M02510
71 | M01391
72 | M01532
73 | M03036
74 | M01472
75 | M03691
76 | M01463
77 | M02407
78 | M03917
79 | M01450
80 | M01854
81 | M00594
82 | M00736
83 | M03087
84 | M01440
85 | M02021
86 | M01444
87 | M03683
88 | M00568
89 | M01298
90 | M01491
91 | M01367
92 | M03577
93 | M01485
94 | M02148
95 | M03190
96 | M03215
97 | M03098
98 | M02139
99 | M02435
100 | M03532
101 | M00669
102 |
--------------------------------------------------------------------------------
/datasets/annotations/kitml/splits/test.txt:
--------------------------------------------------------------------------------
1 | 00004
2 | 00010
3 | 00019
4 | 00033
5 | 00035
6 | 00036
7 | 00044
8 | 00053
9 | 00059
10 | 00074
11 | 00104
12 | 00105
13 | 00122
14 | 00135
15 | 00136
16 | 00157
17 | 00158
18 | 00163
19 | 00172
20 | 00176
21 | 00186
22 | 00192
23 | 00198
24 | 00200
25 | 00216
26 | 00218
27 | 00228
28 | 00231
29 | 00235
30 | 00248
31 | 00255
32 | 00280
33 | 00287
34 | 00292
35 | 00303
36 | 00309
37 | 00316
38 | 00337
39 | 00342
40 | 00344
41 | 00352
42 | 00355
43 | 00358
44 | 00373
45 | 00384
46 | 00390
47 | 00396
48 | 00410
49 | 00424
50 | 00439
51 | 00441
52 | 00451
53 | 00452
54 | 00453
55 | 00455
56 | 00463
57 | 00465
58 | 00470
59 | 00507
60 | 00508
61 | 00509
62 | 00515
63 | 00522
64 | 00541
65 | 00543
66 | 00545
67 | 00547
68 | 00555
69 | 00556
70 | 00568
71 | 00578
72 | 00591
73 | 00594
74 | 00597
75 | 00648
76 | 00650
77 | 00651
78 | 00659
79 | 00669
80 | 00670
81 | 00675
82 | 00686
83 | 00687
84 | 00695
85 | 00704
86 | 00736
87 | 00755
88 | 00767
89 | 00771
90 | 00777
91 | 00784
92 | 00815
93 | 00842
94 | 00843
95 | 00846
96 | 00849
97 | 00868
98 | 00874
99 | 00876
100 | 00877
101 | 00881
102 | 00896
103 | 00902
104 | 00912
105 | 00931
106 | 00932
107 | 00947
108 | 00949
109 | 00958
110 | 00965
111 | 00971
112 | 00975
113 | 00978
114 | 00980
115 | 00989
116 | 00998
117 | 01004
118 | 01012
119 | 01021
120 | 01023
121 | 01026
122 | 01027
123 | 01029
124 | 01043
125 | 01044
126 | 01060
127 | 01061
128 | 01066
129 | 01067
130 | 01071
131 | 01074
132 | 01081
133 | 01108
134 | 01109
135 | 01115
136 | 01127
137 | 01128
138 | 01136
139 | 01139
140 | 01140
141 | 01145
142 | 01154
143 | 01168
144 | 01183
145 | 01194
146 | 01196
147 | 01214
148 | 01223
149 | 01240
150 | 01241
151 | 01259
152 | 01290
153 | 01295
154 | 01298
155 | 01309
156 | 01312
157 | 01315
158 | 01317
159 | 01320
160 | 01323
161 | 01328
162 | 01329
163 | 01332
164 | 01344
165 | 01349
166 | 01353
167 | 01367
168 | 01381
169 | 01391
170 | 01399
171 | 01406
172 | 01416
173 | 01437
174 | 01440
175 | 01444
176 | 01450
177 | 01463
178 | 01472
179 | 01477
180 | 01485
181 | 01487
182 | 01491
183 | 01496
184 | 01519
185 | 01525
186 | 01528
187 | 01532
188 | 01536
189 | 01548
190 | 01550
191 | 01563
192 | 01571
193 | 01640
194 | 01670
195 | 01671
196 | 01680
197 | 01681
198 | 01710
199 | 01724
200 | 01727
201 | 01732
202 | 01747
203 | 01751
204 | 01761
205 | 01763
206 | 01770
207 | 01780
208 | 01782
209 | 01802
210 | 01806
211 | 01808
212 | 01822
213 | 01824
214 | 01831
215 | 01832
216 | 01842
217 | 01852
218 | 01854
219 | 01861
220 | 01868
221 | 01874
222 | 01904
223 | 01908
224 | 01917
225 | 01924
226 | 01928
227 | 01941
228 | 01944
229 | 01950
230 | 01954
231 | 01963
232 | 01969
233 | 01970
234 | 01974
235 | 01978
236 | 01979
237 | 01997
238 | 01998
239 | 02000
240 | 02007
241 | 02011
242 | 02015
243 | 02021
244 | 02026
245 | 02038
246 | 02060
247 | 02070
248 | 02080
249 | 02083
250 | 02084
251 | 02107
252 | 02115
253 | 02116
254 | 02122
255 | 02125
256 | 02139
257 | 02140
258 | 02148
259 | 02157
260 | 02160
261 | 02165
262 | 02168
263 | 02180
264 | 02181
265 | 02193
266 | 02209
267 | 02234
268 | 02239
269 | 02242
270 | 02243
271 | 02245
272 | 02247
273 | 02248
274 | 02253
275 | 02278
276 | 02283
277 | 02285
278 | 02292
279 | 02298
280 | 02311
281 | 02328
282 | 02329
283 | 02339
284 | 02342
285 | 02344
286 | 02346
287 | 02364
288 | 02373
289 | 02394
290 | 02405
291 | 02407
292 | 02420
293 | 02427
294 | 02435
295 | 02440
296 | 02441
297 | 02449
298 | 02470
299 | 02480
300 | 02483
301 | 02510
302 | 02523
303 | 02542
304 | 02553
305 | 02556
306 | 02574
307 | 02581
308 | 02582
309 | 02588
310 | 02594
311 | 02644
312 | 02666
313 | 02691
314 | 02708
315 | 02722
316 | 02749
317 | 02751
318 | 02826
319 | 02830
320 | 02866
321 | 02888
322 | 02891
323 | 02895
324 | 02903
325 | 02906
326 | 02934
327 | 02945
328 | 02964
329 | 02967
330 | 02971
331 | 02979
332 | 03036
333 | 03045
334 | 03082
335 | 03087
336 | 03098
337 | 03107
338 | 03121
339 | 03128
340 | 03151
341 | 03188
342 | 03190
343 | 03194
344 | 03215
345 | 03224
346 | 03228
347 | 03238
348 | 03240
349 | 03246
350 | 03248
351 | 03252
352 | 03277
353 | 03278
354 | 03281
355 | 03293
356 | 03297
357 | 03314
358 | 03381
359 | 03382
360 | 03396
361 | 03404
362 | 03411
363 | 03441
364 | 03459
365 | 03481
366 | 03489
367 | 03495
368 | 03577
369 | 03593
370 | 03614
371 | 03632
372 | 03643
373 | 03670
374 | 03683
375 | 03685
376 | 03691
377 | 03695
378 | 03771
379 | 03785
380 | 03787
381 | 03788
382 | 03795
383 | 03798
384 | 03825
385 | 03830
386 | 03861
387 | 03866
388 | 03879
389 | 03917
390 | 03930
391 | 03939
392 | 03944
393 | 03964
394 | M00004
395 | M00010
396 | M00019
397 | M00033
398 | M00035
399 | M00036
400 | M00044
401 | M00053
402 | M00059
403 | M00074
404 | M00104
405 | M00105
406 | M00122
407 | M00135
408 | M00136
409 | M00157
410 | M00158
411 | M00163
412 | M00172
413 | M00176
414 | M00186
415 | M00192
416 | M00198
417 | M00200
418 | M00216
419 | M00218
420 | M00228
421 | M00231
422 | M00235
423 | M00248
424 | M00255
425 | M00280
426 | M00287
427 | M00292
428 | M00303
429 | M00309
430 | M00316
431 | M00337
432 | M00342
433 | M00344
434 | M00352
435 | M00355
436 | M00358
437 | M00373
438 | M00384
439 | M00390
440 | M00396
441 | M00410
442 | M00424
443 | M00439
444 | M00441
445 | M00451
446 | M00452
447 | M00453
448 | M00455
449 | M00463
450 | M00465
451 | M00470
452 | M00507
453 | M00508
454 | M00509
455 | M00515
456 | M00522
457 | M00541
458 | M00543
459 | M00545
460 | M00547
461 | M00555
462 | M00556
463 | M00568
464 | M00578
465 | M00591
466 | M00594
467 | M00597
468 | M00648
469 | M00650
470 | M00651
471 | M00659
472 | M00669
473 | M00670
474 | M00675
475 | M00686
476 | M00687
477 | M00695
478 | M00704
479 | M00736
480 | M00755
481 | M00767
482 | M00771
483 | M00777
484 | M00784
485 | M00815
486 | M00842
487 | M00843
488 | M00846
489 | M00849
490 | M00868
491 | M00874
492 | M00876
493 | M00877
494 | M00881
495 | M00896
496 | M00902
497 | M00912
498 | M00931
499 | M00932
500 | M00947
501 | M00949
502 | M00958
503 | M00965
504 | M00971
505 | M00975
506 | M00978
507 | M00980
508 | M00989
509 | M00998
510 | M01004
511 | M01012
512 | M01021
513 | M01023
514 | M01026
515 | M01027
516 | M01029
517 | M01043
518 | M01044
519 | M01060
520 | M01061
521 | M01066
522 | M01067
523 | M01071
524 | M01074
525 | M01081
526 | M01108
527 | M01109
528 | M01115
529 | M01127
530 | M01128
531 | M01136
532 | M01139
533 | M01140
534 | M01145
535 | M01154
536 | M01168
537 | M01183
538 | M01194
539 | M01196
540 | M01214
541 | M01223
542 | M01240
543 | M01241
544 | M01259
545 | M01290
546 | M01295
547 | M01298
548 | M01309
549 | M01312
550 | M01315
551 | M01317
552 | M01320
553 | M01323
554 | M01328
555 | M01329
556 | M01332
557 | M01344
558 | M01349
559 | M01353
560 | M01367
561 | M01381
562 | M01391
563 | M01399
564 | M01406
565 | M01416
566 | M01437
567 | M01440
568 | M01444
569 | M01450
570 | M01463
571 | M01472
572 | M01477
573 | M01485
574 | M01487
575 | M01491
576 | M01496
577 | M01519
578 | M01525
579 | M01528
580 | M01532
581 | M01536
582 | M01548
583 | M01550
584 | M01563
585 | M01571
586 | M01640
587 | M01670
588 | M01671
589 | M01680
590 | M01681
591 | M01710
592 | M01724
593 | M01727
594 | M01732
595 | M01747
596 | M01751
597 | M01761
598 | M01763
599 | M01770
600 | M01780
601 | M01782
602 | M01802
603 | M01806
604 | M01808
605 | M01822
606 | M01824
607 | M01831
608 | M01832
609 | M01842
610 | M01852
611 | M01854
612 | M01861
613 | M01868
614 | M01874
615 | M01904
616 | M01908
617 | M01917
618 | M01924
619 | M01928
620 | M01941
621 | M01944
622 | M01950
623 | M01954
624 | M01963
625 | M01969
626 | M01970
627 | M01974
628 | M01978
629 | M01979
630 | M01997
631 | M01998
632 | M02000
633 | M02007
634 | M02011
635 | M02015
636 | M02021
637 | M02026
638 | M02038
639 | M02060
640 | M02070
641 | M02080
642 | M02083
643 | M02084
644 | M02107
645 | M02115
646 | M02116
647 | M02122
648 | M02125
649 | M02139
650 | M02140
651 | M02148
652 | M02157
653 | M02160
654 | M02165
655 | M02168
656 | M02180
657 | M02181
658 | M02193
659 | M02209
660 | M02234
661 | M02239
662 | M02242
663 | M02243
664 | M02245
665 | M02247
666 | M02248
667 | M02253
668 | M02278
669 | M02283
670 | M02285
671 | M02292
672 | M02298
673 | M02311
674 | M02328
675 | M02329
676 | M02339
677 | M02342
678 | M02344
679 | M02346
680 | M02364
681 | M02373
682 | M02394
683 | M02405
684 | M02407
685 | M02420
686 | M02427
687 | M02435
688 | M02440
689 | M02441
690 | M02449
691 | M02470
692 | M02480
693 | M02483
694 | M02510
695 | M02523
696 | M02542
697 | M02553
698 | M02556
699 | M02574
700 | M02581
701 | M02582
702 | M02588
703 | M02594
704 | M02644
705 | M02666
706 | M02691
707 | M02708
708 | M02722
709 | M02749
710 | M02751
711 | M02826
712 | M02830
713 | M02866
714 | M02888
715 | M02891
716 | M02895
717 | M02903
718 | M02906
719 | M02934
720 | M02945
721 | M02964
722 | M02967
723 | M02971
724 | M02979
725 | M03036
726 | M03045
727 | M03082
728 | M03087
729 | M03098
730 | M03107
731 | M03121
732 | M03128
733 | M03151
734 | M03188
735 | M03190
736 | M03194
737 | M03215
738 | M03224
739 | M03228
740 | M03238
741 | M03240
742 | M03246
743 | M03248
744 | M03252
745 | M03277
746 | M03278
747 | M03281
748 | M03293
749 | M03297
750 | M03314
751 | M03381
752 | M03382
753 | M03396
754 | M03404
755 | M03411
756 | M03441
757 | M03459
758 | M03481
759 | M03489
760 | M03495
761 | M03577
762 | M03593
763 | M03614
764 | M03632
765 | M03643
766 | M03670
767 | M03683
768 | M03685
769 | M03691
770 | M03695
771 | M03771
772 | M03785
773 | M03787
774 | M03788
775 | M03795
776 | M03798
777 | M03825
778 | M03830
779 | M03861
780 | M03866
781 | M03879
782 | M03917
783 | M03930
784 | M03939
785 | M03944
786 | M03964
787 |
--------------------------------------------------------------------------------
/datasets/annotations/kitml/splits/test_tiny.txt:
--------------------------------------------------------------------------------
1 | 00004
2 | 00010
3 | 00019
4 | 00033
5 | 00035
6 | 00036
7 | 00044
8 | 00053
9 | 00059
10 | 00074
11 |
--------------------------------------------------------------------------------
/datasets/annotations/kitml/splits/train_tiny.txt:
--------------------------------------------------------------------------------
1 | 00001
2 | 00002
3 | 00003
4 | 00005
5 | 00007
6 | 00008
7 | 00009
8 | 00011
9 | 00013
10 | 00014
11 |
--------------------------------------------------------------------------------
/datasets/annotations/kitml/splits/val.txt:
--------------------------------------------------------------------------------
1 | 00006
2 | 00015
3 | 00025
4 | 00052
5 | 00068
6 | 00092
7 | 00093
8 | 00126
9 | 00143
10 | 00170
11 | 00183
12 | 00222
13 | 00244
14 | 00254
15 | 00270
16 | 00271
17 | 00306
18 | 00311
19 | 00331
20 | 00332
21 | 00374
22 | 00427
23 | 00432
24 | 00476
25 | 00491
26 | 00498
27 | 00510
28 | 00539
29 | 00549
30 | 00582
31 | 00589
32 | 00605
33 | 00612
34 | 00646
35 | 00685
36 | 00690
37 | 00691
38 | 00709
39 | 00728
40 | 00733
41 | 00754
42 | 00764
43 | 00834
44 | 00841
45 | 00873
46 | 00897
47 | 00905
48 | 00908
49 | 00909
50 | 00918
51 | 00922
52 | 00943
53 | 00954
54 | 00964
55 | 00967
56 | 00981
57 | 00983
58 | 00988
59 | 01001
60 | 01006
61 | 01016
62 | 01030
63 | 01070
64 | 01073
65 | 01092
66 | 01097
67 | 01118
68 | 01123
69 | 01176
70 | 01177
71 | 01209
72 | 01264
73 | 01282
74 | 01299
75 | 01319
76 | 01369
77 | 01407
78 | 01409
79 | 01412
80 | 01418
81 | 01428
82 | 01429
83 | 01443
84 | 01447
85 | 01479
86 | 01483
87 | 01494
88 | 01503
89 | 01655
90 | 01664
91 | 01700
92 | 01703
93 | 01745
94 | 01762
95 | 01794
96 | 01795
97 | 01809
98 | 01825
99 | 01834
100 | 01836
101 | 01877
102 | 01905
103 | 01931
104 | 01948
105 | 01975
106 | 01992
107 | 01995
108 | 02048
109 | 02110
110 | 02118
111 | 02121
112 | 02171
113 | 02184
114 | 02315
115 | 02322
116 | 02418
117 | 02499
118 | 02593
119 | 02598
120 | 02797
121 | 02904
122 | 02922
123 | 02940
124 | 02947
125 | 03009
126 | 03155
127 | 03227
128 | 03232
129 | 03262
130 | 03276
131 | 03284
132 | 03290
133 | 03365
134 | 03419
135 | 03462
136 | 03480
137 | 03631
138 | 03682
139 | 03708
140 | 03722
141 | 03732
142 | 03813
143 | 03817
144 | 03832
145 | 03877
146 | 03899
147 | M00006
148 | M00015
149 | M00025
150 | M00052
151 | M00068
152 | M00092
153 | M00093
154 | M00126
155 | M00143
156 | M00170
157 | M00183
158 | M00222
159 | M00244
160 | M00254
161 | M00270
162 | M00271
163 | M00306
164 | M00311
165 | M00331
166 | M00332
167 | M00374
168 | M00427
169 | M00432
170 | M00476
171 | M00491
172 | M00498
173 | M00510
174 | M00539
175 | M00549
176 | M00582
177 | M00589
178 | M00605
179 | M00612
180 | M00646
181 | M00685
182 | M00690
183 | M00691
184 | M00709
185 | M00728
186 | M00733
187 | M00754
188 | M00764
189 | M00834
190 | M00841
191 | M00873
192 | M00897
193 | M00905
194 | M00908
195 | M00909
196 | M00918
197 | M00922
198 | M00943
199 | M00954
200 | M00964
201 | M00967
202 | M00981
203 | M00983
204 | M00988
205 | M01001
206 | M01006
207 | M01016
208 | M01030
209 | M01070
210 | M01073
211 | M01092
212 | M01097
213 | M01118
214 | M01123
215 | M01176
216 | M01177
217 | M01209
218 | M01264
219 | M01282
220 | M01299
221 | M01319
222 | M01369
223 | M01407
224 | M01409
225 | M01412
226 | M01418
227 | M01428
228 | M01429
229 | M01443
230 | M01447
231 | M01479
232 | M01483
233 | M01494
234 | M01503
235 | M01655
236 | M01664
237 | M01700
238 | M01703
239 | M01745
240 | M01762
241 | M01794
242 | M01795
243 | M01809
244 | M01825
245 | M01834
246 | M01836
247 | M01877
248 | M01905
249 | M01931
250 | M01948
251 | M01975
252 | M01992
253 | M01995
254 | M02048
255 | M02110
256 | M02118
257 | M02121
258 | M02171
259 | M02184
260 | M02315
261 | M02322
262 | M02418
263 | M02499
264 | M02593
265 | M02598
266 | M02797
267 | M02904
268 | M02922
269 | M02940
270 | M02947
271 | M03009
272 | M03155
273 | M03227
274 | M03232
275 | M03262
276 | M03276
277 | M03284
278 | M03290
279 | M03365
280 | M03419
281 | M03462
282 | M03480
283 | M03631
284 | M03682
285 | M03708
286 | M03722
287 | M03732
288 | M03813
289 | M03817
290 | M03832
291 | M03877
292 | M03899
293 |
--------------------------------------------------------------------------------
/datasets/annotations/kitml/splits/val_tiny.txt:
--------------------------------------------------------------------------------
1 | 00006
2 | 00015
3 | 00025
4 | 00052
5 | 00068
6 | 00092
7 | 00093
8 | 00126
9 | 00143
10 | 00170
11 |
--------------------------------------------------------------------------------
/demo/load.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import orjson
4 | import codecs as cs
5 | import torch
6 |
7 |
8 | def load_json(json_path):
9 | with open(json_path, "rb") as ff:
10 | return orjson.loads(ff.read())
11 |
12 |
13 | def load_unit_embeddings(run_dir, dataset, device="cpu"):
14 | save_dir = os.path.join(run_dir, "latents")
15 | unit_emb_path = os.path.join(save_dir, f"{dataset}_all_unit.npy")
16 | motion_embs = torch.from_numpy(np.load(unit_emb_path)).to(device)
17 |
18 | # Loading the correspondence
19 | keyids_index = load_json(os.path.join(save_dir, f"{dataset}_keyids_index_all.json"))
20 | index_keyids = load_json(os.path.join(save_dir, f"{dataset}_index_keyids_all.json"))
21 |
22 | return motion_embs, keyids_index, index_keyids
23 |
24 |
25 | def load_split(path, split):
26 | split_file = os.path.join(path, "splits", split + ".txt")
27 | id_list = []
28 | with cs.open(split_file, "r") as f:
29 | for line in f.readlines():
30 | id_list.append(line.strip())
31 | return id_list
32 |
33 |
34 | def load_splits(dataset, splits=["test", "all"]):
35 | path = f"datasets/annotations/{dataset}"
36 | return {split: load_split(path, split) for split in splits}
37 |
--------------------------------------------------------------------------------
/demo/model.py:
--------------------------------------------------------------------------------
1 | # Text model + TMR text encoder only
2 |
3 | from typing import List
4 | import torch.nn as nn
5 | import os
6 |
7 | import torch
8 | import numpy as np
9 | from torch import Tensor
10 | from transformers import AutoTokenizer, AutoModel
11 | from torch.nn.functional import normalize
12 | from einops import repeat
13 | import json
14 | import warnings
15 |
16 | import logging
17 |
18 | logger = logging.getLogger("torch.distributed.nn.jit.instantiator")
19 | logger.setLevel(logging.ERROR)
20 |
21 |
22 | warnings.filterwarnings(
23 | "ignore", "The PyTorch API of nested tensors is in prototype stage*"
24 | )
25 |
26 | warnings.filterwarnings("ignore", "Converting mask without torch.bool dtype to bool*")
27 |
28 | torch.set_float32_matmul_precision("high")
29 |
30 |
31 | class PositionalEncoding(nn.Module):
32 | def __init__(self, d_model, dropout=0.1, max_len=5000, batch_first=False) -> None:
33 | super().__init__()
34 | self.batch_first = batch_first
35 |
36 | self.dropout = nn.Dropout(p=dropout)
37 |
38 | pe = torch.zeros(max_len, d_model)
39 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
40 | div_term = torch.exp(
41 | torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model)
42 | )
43 | pe[:, 0::2] = torch.sin(position * div_term)
44 | pe[:, 1::2] = torch.cos(position * div_term)
45 | pe = pe.unsqueeze(0).transpose(0, 1)
46 | self.register_buffer("pe", pe, persistent=False)
47 |
48 | def forward(self, x: Tensor) -> Tensor:
49 | if self.batch_first:
50 | x = x + self.pe.permute(1, 0, 2)[:, : x.shape[1], :]
51 | else:
52 | x = x + self.pe[: x.shape[0], :]
53 | return self.dropout(x)
54 |
55 |
56 | def read_config(run_dir: str):
57 | path = os.path.join(run_dir, "config.json")
58 | with open(path, "r") as f:
59 | config = json.load(f)
60 | return config
61 |
62 |
63 | class TMR_text_encoder(nn.Module):
64 | def __init__(self, run_dir: str) -> None:
65 | config = read_config(run_dir)
66 | modelpath = config["data"]["text_to_token_emb"]["modelname"]
67 |
68 | text_encoder_conf = config["model"]["text_encoder"]
69 |
70 | vae = text_encoder_conf["vae"]
71 | latent_dim = text_encoder_conf["latent_dim"]
72 | ff_size = text_encoder_conf["ff_size"]
73 | num_layers = text_encoder_conf["num_layers"]
74 | num_heads = text_encoder_conf["num_heads"]
75 | activation = text_encoder_conf["activation"]
76 | nfeats = text_encoder_conf["nfeats"]
77 |
78 | super().__init__()
79 |
80 | # Projection of the text-outputs into the latent space
81 | self.projection = nn.Linear(nfeats, latent_dim)
82 | self.vae = vae
83 | self.nbtokens = 2 if vae else 1
84 |
85 | self.tokens = nn.Parameter(torch.randn(self.nbtokens, latent_dim))
86 | self.sequence_pos_encoding = PositionalEncoding(
87 | latent_dim, dropout=0.0, batch_first=True
88 | )
89 |
90 | seq_trans_encoder_layer = nn.TransformerEncoderLayer(
91 | d_model=latent_dim,
92 | nhead=num_heads,
93 | dim_feedforward=ff_size,
94 | dropout=0.0,
95 | activation=activation,
96 | batch_first=True,
97 | )
98 |
99 | self.seqTransEncoder = nn.TransformerEncoder(
100 | seq_trans_encoder_layer, num_layers=num_layers
101 | )
102 |
103 | text_encoder_pt_path = os.path.join(run_dir, "last_weights/text_encoder.pt")
104 | state_dict = torch.load(text_encoder_pt_path)
105 | self.load_state_dict(state_dict)
106 |
107 | from transformers import logging
108 |
109 | # silence transformers logging before loading the text model
110 | logging.set_verbosity_error()
111 |
112 | # Tokenizer
113 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
114 | self.tokenizer = AutoTokenizer.from_pretrained(modelpath)
115 |
116 | # Text model
117 | self.text_model = AutoModel.from_pretrained(modelpath)
118 | # Then configure the model
119 | self.text_encoded_dim = self.text_model.config.hidden_size
120 | self.eval()
121 |
122 | def get_last_hidden_state(self, texts: List[str], return_mask: bool = False):
123 | encoded_inputs = self.tokenizer(texts, return_tensors="pt", padding=True)
124 | output = self.text_model(**encoded_inputs.to(self.text_model.device))
125 | if not return_mask:
126 | return output.last_hidden_state
127 | return output.last_hidden_state, encoded_inputs.attention_mask.to(dtype=bool)
128 |
129 | def forward(self, texts: List[str]) -> Tensor:
130 | text_encoded, mask = self.get_last_hidden_state(texts, return_mask=True)
131 |
132 | x = self.projection(text_encoded)
133 |
134 | device = x.device
135 | bs = len(x)
136 |
137 | tokens = repeat(self.tokens, "nbtoken dim -> bs nbtoken dim", bs=bs)
138 | xseq = torch.cat((tokens, x), 1)
139 |
140 | token_mask = torch.ones((bs, self.nbtokens), dtype=bool, device=device)
141 | aug_mask = torch.cat((token_mask, mask), 1)
142 |
143 | # add positional encoding
144 | xseq = self.sequence_pos_encoding(xseq)
145 | final = self.seqTransEncoder(xseq, src_key_padding_mask=~aug_mask)
146 | return final[:, 0]
147 |
148 | # compute score for retrieval
149 | def compute_scores(self, texts, unit_embs=None, embs=None):
150 | # not both empty
151 | assert not (unit_embs is None and embs is None)
152 | # not both filled
153 | assert not (unit_embs is not None and embs is not None)
154 |
155 | output_str = False
156 | # if one input, squeeze the output
157 | if isinstance(texts, str):
158 | texts = [texts]
159 | output_str = True
160 |
161 | # compute unit_embs from embs if not given
162 | if embs is not None:
163 | unit_embs = normalize(embs)
164 |
165 | with torch.no_grad():
166 | latent_unit_texts = normalize(self(texts))
167 | # map the cosine similarity from [-1, 1] to [0, 1]
168 | scores = (unit_embs @ latent_unit_texts.T).T / 2 + 0.5
169 | scores = scores.cpu().numpy()
170 |
171 | if output_str:
172 | scores = scores[0]
173 |
174 | return scores
175 |
--------------------------------------------------------------------------------
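The two demo modules above work together: load_unit_embeddings reads the unit-normalized motion latents produced by encode_dataset.py, and TMR_text_encoder.compute_scores ranks them against a text query. A minimal sketch, assuming a hypothetical run directory containing config.json, last_weights/text_encoder.pt and a latents/ folder, run from the repository root:

    from demo.load import load_unit_embeddings
    from demo.model import TMR_text_encoder

    run_dir = "models/tmr_humanml3d_guoh3dfeats"  # hypothetical run directory
    dataset = "humanml3d"

    # pre-encoded unit motion embeddings + keyid <-> row index correspondence
    motion_embs, keyids_index, index_keyids = load_unit_embeddings(run_dir, dataset, device="cpu")

    # text side only; the motions are already encoded
    text_encoder = TMR_text_encoder(run_dir)
    scores = text_encoder.compute_scores("a person walks in a circle", unit_embs=motion_embs)

    # top-5 motions for the query (JSON keys are strings, hence str(idx))
    for idx in scores.argsort()[::-1][:5]:
        print(index_keyids[str(idx)], float(scores[idx]))
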
/encode_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | from omegaconf import DictConfig
3 | import logging
4 | import hydra
5 | import json
6 | from hydra.core.hydra_config import HydraConfig
7 |
8 |
9 | logger = logging.getLogger(__name__)
10 |
11 |
12 | def x_dict_to_device(x_dict, device):
13 | import torch
14 |
15 | for key, val in x_dict.items():
16 | if isinstance(val, torch.Tensor):
17 | x_dict[key] = val.to(device)
18 | return x_dict
19 |
20 |
21 | def write_json(data, path):
22 | with open(path, "w") as ff:
23 | ff.write(json.dumps(data, indent=4))
24 |
25 |
26 | @hydra.main(version_base=None, config_path="configs", config_name="encode_dataset")
27 | def encode_dataset(cfg: DictConfig) -> None:
28 | device = cfg.device
29 | run_dir = cfg.run_dir
30 | ckpt_name = cfg.ckpt_name
31 | cfg_data = cfg.data
32 |
33 | choices = HydraConfig.get().runtime.choices
34 | data_name = choices.data
35 |
36 | import src.prepare # noqa
37 | import torch
38 | import numpy as np
39 | from src.config import read_config
40 | from src.load import load_model_from_cfg
41 | from hydra.utils import instantiate
42 | from pytorch_lightning import seed_everything
43 |
44 | cfg = read_config(run_dir)
45 |
46 | logger.info("Loading the model")
47 | model = load_model_from_cfg(cfg, ckpt_name, eval_mode=True, device=device)
48 |
49 | save_dir = os.path.join(run_dir, "latents")
50 | os.makedirs(save_dir, exist_ok=True)
51 |
52 | dataset = instantiate(cfg_data, split="all")
53 | dataloader = instantiate(
54 | cfg.dataloader,
55 | dataset=dataset,
56 | collate_fn=dataset.collate_fn,
57 | shuffle=False,
58 | )
59 | seed_everything(cfg.seed)
60 |
61 | all_latents = []
62 | all_keyids = []
63 |
64 | with torch.inference_mode():
65 | for batch in dataloader:
66 | motion_x_dict = batch["motion_x_dict"]
67 | x_dict_to_device(motion_x_dict, device)
68 | latents = model.encode(motion_x_dict, sample_mean=True)
69 | all_latents.append(latents.cpu().numpy())
70 | keyids = batch["keyid"]
71 | all_keyids.extend(keyids)
72 |
73 | latents = np.concatenate(all_latents)
74 | path = os.path.join(save_dir, f"{data_name}_all.npy")
75 | logger.info(f"Saving the latents of all the splits to {path}")
76 | np.save(path, latents)
77 |
78 | path_unit = os.path.join(save_dir, f"{data_name}_all_unit.npy")
79 | logger.info(f"Saving the unit latents of all the splits to {path_unit}")
80 |
81 | unit_latents = latents / np.linalg.norm(latents, axis=-1)[:, None]
82 | np.save(path_unit, unit_latents)
83 |
84 | # Writing the correspondence
85 | logger.info("Writing the correspondence files")
86 | keyids_index_path = os.path.join(save_dir, f"{data_name}_keyids_index_all.json")
87 | index_keyids_path = os.path.join(save_dir, f"{data_name}_index_keyids_all.json")
88 |
89 | keyids_index = {x: i for i, x in enumerate(all_keyids)}
90 | index_keyids = {i: x for i, x in enumerate(all_keyids)}
91 |
92 | write_json(keyids_index, keyids_index_path)
93 | write_json(index_keyids, index_keyids_path)
94 |
95 |
96 | if __name__ == "__main__":
97 | encode_dataset()
98 |
--------------------------------------------------------------------------------
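The arrays and JSON files written by encode_dataset.py are meant to be read back together; a small sketch (hypothetical run_dir, humanml3d as the data choice) fetching the unit latent of a single motion by its keyid:

    import json
    import os
    import numpy as np

    run_dir = "outputs/tmr_run"  # hypothetical run directory
    data_name = "humanml3d"
    save_dir = os.path.join(run_dir, "latents")

    unit_latents = np.load(os.path.join(save_dir, f"{data_name}_all_unit.npy"))
    with open(os.path.join(save_dir, f"{data_name}_keyids_index_all.json")) as f:
        keyids_index = json.load(f)  # keyid -> row index

    keyid = next(iter(keyids_index))
    latent = unit_latents[keyids_index[keyid]]
    print(keyid, latent.shape, np.linalg.norm(latent))  # the norm should be ~1.0
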
/encode_motion.py:
--------------------------------------------------------------------------------
1 | import os
2 | from omegaconf import DictConfig
3 | import logging
4 | import hydra
5 |
6 | logger = logging.getLogger(__name__)
7 |
8 |
9 | @hydra.main(version_base=None, config_path="configs", config_name="encode_motion")
10 | def encode_motion(cfg: DictConfig) -> None:
11 | device = cfg.device
12 | run_dir = cfg.run_dir
13 | ckpt_name = cfg.ckpt_name
14 | npy_path = cfg.npy
15 |
16 | import src.prepare # noqa
17 | import torch
18 | import numpy as np
19 | from src.config import read_config
20 | from src.load import load_model_from_cfg
21 | from hydra.utils import instantiate
22 | from pytorch_lightning import seed_everything
23 | from src.data.collate import collate_x_dict
24 |
25 | cfg = read_config(run_dir)
26 |
27 | logger.info("Loading the model")
28 | model = load_model_from_cfg(cfg, ckpt_name, eval_mode=True, device=device)
29 | normalizer = instantiate(cfg.data.motion_loader.normalizer)
30 |
31 | motion = torch.from_numpy(np.load(npy_path)).to(torch.float)
32 | motion = normalizer(motion)
33 | motion = motion.to(device)
34 |
35 | motion_x_dict = {"x": motion, "length": len(motion)}
36 |
37 | seed_everything(cfg.seed)
38 | with torch.inference_mode():
39 | motion_x_dict = collate_x_dict([motion_x_dict])
40 | latent = model.encode(motion_x_dict, sample_mean=True)[0]
41 | latent = latent.cpu().numpy()
42 |
43 | fname = os.path.split(npy_path)[1]
44 | output_folder = os.path.join(run_dir, "encoded")
45 | os.makedirs(output_folder, exist_ok=True)
46 | path = os.path.join(output_folder, fname)
47 |
48 | np.save(path, latent)
49 | logger.info(f"Encoding done, latent saved in:\n{path}")
50 |
51 |
52 | if __name__ == "__main__":
53 | encode_motion()
54 |
--------------------------------------------------------------------------------
/encode_text.py:
--------------------------------------------------------------------------------
1 | import os
2 | from omegaconf import DictConfig
3 | import logging
4 | import hydra
5 |
6 | logger = logging.getLogger(__name__)
7 |
8 |
9 | @hydra.main(version_base=None, config_path="configs", config_name="encode_text")
10 | def encode_text(cfg: DictConfig) -> None:
11 | device = cfg.device
12 | run_dir = cfg.run_dir
13 | ckpt_name = cfg.ckpt_name
14 | text = cfg.text
15 |
16 | import src.prepare # noqa
17 | import torch
18 | import numpy as np
19 | from src.config import read_config
20 | from src.load import load_model_from_cfg
21 | from hydra.utils import instantiate
22 | from pytorch_lightning import seed_everything
23 | from src.data.collate import collate_x_dict
24 |
25 | cfg = read_config(run_dir)
26 |
27 | logger.info("Loading the text model")
28 | text_model = instantiate(cfg.data.text_to_token_emb, device=device)
29 |
30 | logger.info("Loading the model")
31 | model = load_model_from_cfg(cfg, ckpt_name, eval_mode=True, device=device)
32 |
33 | seed_everything(cfg.seed)
34 | with torch.inference_mode():
35 | text_x_dict = collate_x_dict(text_model([text]))
36 | latent = model.encode(text_x_dict, sample_mean=True)[0]
37 | latent = latent.cpu().numpy()
38 |
39 | fname = text.lower().replace(" ", "_") + ".npy"
40 |
41 | output_folder = os.path.join(run_dir, "encoded")
42 | os.makedirs(output_folder, exist_ok=True)
43 | path = os.path.join(output_folder, fname)
44 |
45 | np.save(path, latent)
46 | logger.info(f"Encoding done, latent saved in:\n{path}")
47 |
48 |
49 | if __name__ == "__main__":
50 | encode_text()
51 |
--------------------------------------------------------------------------------
/extract.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import hydra
3 | from omegaconf import DictConfig
4 |
5 | logger = logging.getLogger(__name__)
6 |
7 |
8 | @hydra.main(config_path="configs", config_name="extract", version_base="1.3")
9 | def extract(cfg: DictConfig):
10 | run_dir = cfg.run_dir
11 | ckpt = cfg.ckpt
12 |
13 | from src.load import extract_ckpt
14 |
15 | logger.info("Extracting the checkpoint...")
16 | extract_ckpt(run_dir, ckpt_name=ckpt)
17 | logger.info("Done")
18 |
19 |
20 | if __name__ == "__main__":
21 | extract()
22 |
--------------------------------------------------------------------------------
/prepare/combine_datasets.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import logging
3 | import hydra
4 | import json
5 | import numpy as np
6 | import os
7 | from omegaconf import DictConfig
8 |
9 | logger = logging.getLogger(__name__)
10 |
11 |
12 | SUFFIX_DICT = {"humanml3d": "h", "kitml": "k", "babel": "b"}
13 |
14 | @hydra.main(config_path="../configs", config_name="combine_datasets", version_base="1.3")
15 | def combine_datasets(cfg: DictConfig):
16 |
17 | train_datasets = cfg.datasets
18 | annotations_folder_path = cfg.annotations_path
19 |
20 | combined_dataset_name = "_".join(train_datasets)
21 | combined_dataset_folder = os.path.join(annotations_folder_path, combined_dataset_name)
22 | os.makedirs(combined_dataset_folder, exist_ok=True)
23 |
28 | annotations = {}
29 | annotations_paraphrases = {}
30 | annotations_actions = {}
31 | annotations_all = {}
32 |
33 | dataset_annotations = {}
34 |
35 | for dataset in train_datasets:
36 | annotations_path = os.path.join(annotations_folder_path, dataset, "annotations.json")
37 | with open(annotations_path) as f:
38 | d = json.load(f)
39 | dataset_annotations[dataset] = d
40 | d_new = {f"{key}_{SUFFIX_DICT[dataset]}": val for key, val in d.items()}
41 | annotations.update(d_new)
42 |
43 | annotations_paraphrases_path = os.path.join(annotations_folder_path, dataset, "annotations_paraphrases.json")
44 | if os.path.exists(annotations_paraphrases_path):
45 | with open(annotations_paraphrases_path) as f:
46 | d = json.load(f)
47 | d = {f"{key}_{SUFFIX_DICT[dataset]}": val for key, val in d.items()}
48 | annotations_paraphrases.update(d)
49 |
50 | annotations_actions_path = os.path.join(annotations_folder_path, dataset, "annotations_actions.json")
51 | if os.path.exists(annotations_actions_path):
52 | with open(annotations_actions_path) as f:
53 | d = json.load(f)
54 | d = {f"{key}_{SUFFIX_DICT[dataset]}": val for key, val in d.items()}
55 | annotations_actions.update(d)
56 |
57 | annotations_all_path = os.path.join(annotations_folder_path, dataset, "annotations_all.json")
58 | if os.path.exists(annotations_all_path):
59 | with open(annotations_all_path) as f:
60 | d = json.load(f)
61 | d = {f"{key}_{SUFFIX_DICT[dataset]}": val for key, val in d.items()}
62 | annotations_all.update(d)
63 |
64 | with open(os.path.join(combined_dataset_folder, "annotations.json"), "w") as f:
65 | json.dump(annotations, f, indent=2)
66 | with open(os.path.join(combined_dataset_folder, "annotations_paraphrases.json"), "w") as f:
67 | json.dump(annotations_paraphrases, f, indent=2)
68 | with open(os.path.join(combined_dataset_folder, "annotations_actions.json"), "w") as f:
69 | json.dump(annotations_actions, f, indent=2)
70 | with open(os.path.join(combined_dataset_folder, "annotations_all.json"), "w") as f:
71 | json.dump(annotations_all, f, indent=2)
72 |
73 | test_datasets = cfg.test_sets
74 |
75 | for dataset in test_datasets:
76 | if dataset not in dataset_annotations:
77 | annotations_path = os.path.join(annotations_folder_path, dataset, "annotations.json")
78 | with open(annotations_path) as f:
79 | d = json.load(f)
80 | dataset_annotations[dataset] = d
81 |
82 | # Splits creation
83 |
84 | logger.info(f"Creating train split by combining train splits from: {', '.join(train_datasets)}")
85 | logger.info(f"Removing from the train splits samples overlapping with test split from: {', '.join(test_datasets)}")
86 |
87 | dataset_splits = {}
88 |
89 | splits = ["train", "val"]
90 | for dataset in train_datasets:
91 | if dataset not in dataset_splits:
92 | dataset_splits[dataset] = {}
93 | for split in splits:
94 | with open(os.path.join(annotations_folder_path, dataset, "splits", f"{split}.txt")) as f:
95 | str_inds = f.read()
96 | inds = str_inds.split("\n")
97 | if inds[-1] == "":
98 | inds.pop(-1)
99 | dic_ind_path = {ind: dataset_annotations[dataset][ind]["path"] for ind in inds}
100 | dataset_splits[dataset][split] = dic_ind_path
101 |
102 | split = 'test'
103 | for dataset in test_datasets:
104 | if dataset not in dataset_splits:
105 | dataset_splits[dataset] = {}
106 | with open(os.path.join(annotations_folder_path, dataset, "splits", f"{split}.txt")) as f:
107 | str_inds = f.read()
108 | inds = str_inds.split("\n")
109 | if inds[-1] == "":
110 | inds.pop(-1)
111 | dic_ind_path = {ind: dataset_annotations[dataset][ind]["path"] for ind in inds}
112 | dataset_splits[dataset][split] = dic_ind_path
113 |
114 |
115 | to_remove = {train_dataset: {test_dataset: {"train": [], "val": []} for test_dataset in test_datasets if test_dataset != train_dataset} for train_dataset in train_datasets}
116 |
117 | for train_dataset in train_datasets:
118 |
119 | for split in ["train", "val"]:
120 | for train_id, train_path in dataset_splits[train_dataset][split].items():
121 |
122 | for test_dataset in set(test_datasets) - set([train_dataset]):
123 |
124 | for test_id, test_path in dataset_splits[test_dataset]["test"].items():
125 | if train_path == test_path:
126 |
127 | if not cfg.filter_babel_seg:
128 | if test_dataset == "babel":
129 | test_duration = float(dataset_annotations[test_dataset][test_id]["duration"])
130 | test_fragment_duration = float(dataset_annotations[test_dataset][test_id]["fragment_duration"])
131 |
132 | if not np.isclose([test_duration], [test_fragment_duration], atol=0.1, rtol=0):
133 | continue
134 |
135 | if train_dataset == "babel":
136 | train_duration = float(dataset_annotations[train_dataset][train_id]["duration"])
137 | train_fragment_duration = float(dataset_annotations[train_dataset][train_id]["fragment_duration"])
138 |
139 | if not np.isclose([train_duration], [train_fragment_duration], atol=0.1, rtol=0):
140 | continue
141 |
142 | train_start = float(dataset_annotations[train_dataset][train_id]["annotations"][0]["start"])
143 | train_end = float(dataset_annotations[train_dataset][train_id]["annotations"][0]["end"])
144 | test_start = float(dataset_annotations[test_dataset][test_id]["annotations"][0]["start"])
145 | test_end = float(dataset_annotations[test_dataset][test_id]["annotations"][0]["end"])
146 |
147 | if not ((train_end <= test_start) or (test_end <= train_start)):
148 | to_remove[train_dataset][test_dataset][split].append(train_id)
149 |
150 | datasets_curated = {train_dataset: {split: list(dataset_splits[train_dataset][split].keys()) for split in dataset_splits[train_dataset].keys()} for train_dataset in train_datasets}
151 |
152 | for train_dataset in train_datasets:
153 | for test_dataset in set(test_datasets) - set([train_dataset]):
154 | for split in ["train", "val"]:
155 | for keyid in to_remove[train_dataset][test_dataset][split]:
156 | if keyid in datasets_curated[train_dataset][split]:
157 | datasets_curated[train_dataset][split].remove(keyid)
158 |
159 | splits_folder = os.path.join(annotations_folder_path, combined_dataset_name, "splits")
160 | os.makedirs(splits_folder, exist_ok=True)
161 | all_ids = []
162 | for split in ["train", "val"]:
163 | ids = []
164 | for train_dataset in train_datasets:
165 | ids += [f'{elt}_{SUFFIX_DICT[train_dataset]}' for elt in datasets_curated[train_dataset][split]]
166 | all_ids += ids
167 | ids_str = "\n".join(ids)
168 | filename = f"{split}{cfg.split_suffix}.txt"
169 | with open(os.path.join(splits_folder, filename), "w") as f:
170 | f.write(ids_str)
171 |
172 | all_ids_str = "\n".join(all_ids)
173 | with open(os.path.join(splits_folder, f"all{cfg.split_suffix}.txt"), "w") as f:
174 | f.write(all_ids_str)
175 |
176 |
177 | if __name__ == "__main__":
178 | combine_datasets()
179 |
180 |
--------------------------------------------------------------------------------
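The combined annotation files rely on a simple key-suffixing convention: every key of a source dataset is made unique by appending the letter from SUFFIX_DICT. A tiny illustration with made-up keys:

    SUFFIX_DICT = {"humanml3d": "h", "kitml": "k", "babel": "b"}

    # hypothetical annotation keys per dataset
    per_dataset = {
        "humanml3d": {"000001": {"path": "A/A1"}},
        "kitml": {"00001": {"path": "B/B1"}},
    }

    combined = {}
    for dataset, annotations in per_dataset.items():
        combined.update(
            {f"{key}_{SUFFIX_DICT[dataset]}": val for key, val in annotations.items()}
        )

    print(sorted(combined))  # ['000001_h', '00001_k']
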
/prepare/compute_guoh3dfeats.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | import hydra
4 | from omegaconf import DictConfig
5 |
6 | import numpy as np
7 |
8 | logger = logging.getLogger(__name__)
9 |
10 |
11 | def extract_h3d(feats):
12 | from einops import unpack
13 |
14 | root_data, ric_data, rot_data, local_vel, feet_l, feet_r = unpack(
15 | feats, [[4], [63], [126], [66], [2], [2]], "i *"
16 | )
17 | return root_data, ric_data, rot_data, local_vel, feet_l, feet_r
18 |
19 |
20 | def swap_left_right(data):
21 | assert len(data.shape) == 3 and data.shape[-1] == 3
22 | data = data.copy()
23 | data[..., 0] *= -1
24 | right_chain = [2, 5, 8, 11, 14, 17, 19, 21]
25 | left_chain = [1, 4, 7, 10, 13, 16, 18, 20]
26 | left_hand_chain = [22, 23, 24, 34, 35, 36, 25, 26, 27, 31, 32, 33, 28, 29, 30]
27 | right_hand_chain = [43, 44, 45, 46, 47, 48, 40, 41, 42, 37, 38, 39, 49, 50, 51]
28 | tmp = data[:, right_chain]
29 | data[:, right_chain] = data[:, left_chain]
30 | data[:, left_chain] = tmp
31 | if data.shape[1] > 24:
32 | tmp = data[:, right_hand_chain]
33 | data[:, right_hand_chain] = data[:, left_hand_chain]
34 | data[:, left_hand_chain] = tmp
35 | return data
36 |
37 |
38 | @hydra.main(
39 | config_path="../configs", config_name="compute_guoh3dfeats", version_base="1.3"
40 | )
41 | def compute_guoh3dfeats(cfg: DictConfig):
42 | base_folder = cfg.base_folder
43 | output_folder = cfg.output_folder
44 | force_redo = cfg.force_redo
45 |
46 | from src.guofeats import joints_to_guofeats
47 | from .tools import loop_amass
48 |
49 | output_folder_M = os.path.join(output_folder, "M")
50 |
51 | print("Get h3d features from Guo et al.")
52 | print("The processed motions will be stored in this folder:")
53 | print(output_folder)
54 |
55 | iterator = loop_amass(
56 | base_folder, output_folder, ext=".npy", newext=".npy", force_redo=force_redo
57 | )
58 |
59 | for motion_path, new_motion_path in iterator:
60 | joints = np.load(motion_path)
61 |
62 | if "humanact12" not in motion_path:
63 | # This is because the authors of HumanML3D
64 | # save the motions by swapping Y and Z (det = -1)
65 | # which is not a proper rotation (det = 1)
66 | # so we should invert x to make it a proper rotation
67 | # that is why the authors use "data[..., 0] *= -1" inside the "if"
68 | # before swapping left/right
69 | # https://github.com/EricGuo5513/HumanML3D/blob/main/raw_pose_processing.ipynb
70 | joints[..., 0] *= -1
71 | # the humanact12 motions are normally saved correctly, no need to swap again
72 | # (but in fact this may not be true, and the original H3D features
73 | # corresponding to HumanAct12 appear to be left/right flipped...)
74 | # At least we are compatible with previous work :/
75 |
76 | joints_m = swap_left_right(joints)
77 |
78 | # apply transformation
79 | try:
80 | features = joints_to_guofeats(joints)
81 | features_m = joints_to_guofeats(joints_m)
82 | except (IndexError, ValueError):
83 | # The sequence should be only 1 frame long
84 | # so we cannot compute features (which involve velocities etc)
85 | assert len(joints) == 1
86 | continue
87 | # save the features
88 | np.save(new_motion_path, features)
89 |
90 | # save the mirrored features as well
91 | new_motion_path_M = new_motion_path.replace(output_folder, output_folder_M)
92 | os.makedirs(os.path.split(new_motion_path_M)[0], exist_ok=True)
93 | np.save(new_motion_path_M, features_m)
94 |
95 |
96 | if __name__ == "__main__":
97 | compute_guoh3dfeats()
98 |
--------------------------------------------------------------------------------
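For reference, the unpack pattern in extract_h3d implies a per-frame feature size of 4 + 63 + 126 + 66 + 2 + 2 = 263, the dimensionality of the HumanML3D ("Guo et al.") features computed by this script. A quick check on a dummy array:

    import numpy as np
    from einops import unpack

    feats = np.zeros((30, 263))  # 30 frames of dummy Guo-style features
    root_data, ric_data, rot_data, local_vel, feet_l, feet_r = unpack(
        feats, [[4], [63], [126], [66], [2], [2]], "i *"
    )
    print(root_data.shape, ric_data.shape, feet_r.shape)  # (30, 4) (30, 63) (30, 2)
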
/prepare/download_pretrain_models.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | echo -e "The pretrained models will be stored in the 'models' folder\n"
3 | mkdir -p models
4 | python -m gdown.cli "https://drive.google.com/uc?id=1n6kRb-d2gKsk8EXfFULFIpaUKYcnaYmm"
5 |
6 | echo -e "Please check that the md5sum is: 7b6d8814f9c1ca972f62852ebb6c7a6f"
7 | echo -e "+ md5sum tmr_models.tgz"
8 | md5sum tmr_models.tgz
9 |
10 | echo -e "If it is not, please rerun this script"
11 |
12 | sleep 5
13 | tar xfzv tmr_models.tgz
14 |
15 | echo -e "Cleaning\n"
16 | rm tmr_models.tgz
17 |
18 | echo -e "Downloading done!"
19 |
--------------------------------------------------------------------------------
/prepare/motion_stats.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import hydra
3 | from omegaconf import DictConfig
4 | from hydra.utils import instantiate
5 | from tqdm import tqdm
6 |
7 | logger = logging.getLogger(__name__)
8 |
9 |
10 | @hydra.main(config_path="../configs", config_name="motion_stats", version_base="1.3")
11 | def motion_stats(cfg: DictConfig):
12 | logger.info("Computing motion stats")
13 | import src.prepare # noqa
14 |
15 | train_dataset = instantiate(cfg.data, split="train")
16 | import torch
17 |
18 | feats = torch.cat([x["motion_x_dict"]["x"] for x in tqdm(train_dataset)])
19 | mean = feats.mean(0)
20 | std = feats.std(0)
21 |
22 | normalizer = train_dataset.motion_loader.normalizer
23 | logger.info(f"Saving them in {normalizer.base_dir}")
24 | normalizer.save(mean, std)
25 |
26 |
27 | if __name__ == "__main__":
28 | motion_stats()
29 |
--------------------------------------------------------------------------------
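The normalizer used above is instantiated from the config and its class is not shown in this file; as a rough sketch of what the saved statistics are for (standard per-feature normalization), assuming the mean.pt and std.pt files stored under stats/<dataset>/guoh3dfeats in this repository:

    import torch

    stats_dir = "stats/humanml3d/guoh3dfeats"
    mean = torch.load(f"{stats_dir}/mean.pt")
    std = torch.load(f"{stats_dir}/std.pt")

    def normalize(feats, eps=1e-12):
        # sketch only; the actual Normalizer class lives in the data code
        return (feats - mean) / (std + eps)
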
/prepare/text_embeddings.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import logging
3 | import hydra
4 | from omegaconf import DictConfig
5 |
6 | logger = logging.getLogger(__name__)
7 |
8 |
9 | @hydra.main(config_path="../configs", config_name="text_embeddings", version_base="1.3")
10 | def text_embeddings(cfg: DictConfig):
11 | device = cfg.device
12 |
13 | from src.data.text import save_token_embeddings, save_sent_embeddings
14 |
15 | annotations_filename = cfg.annotations_filename
16 |
17 | # Compute token embeddings
18 | modelname = cfg.data.text_to_token_emb.modelname
19 | modelpath = cfg.data.text_to_token_emb.modelpath
20 | logger.info(f"Compute token embeddings for {modelname}")
21 | path = cfg.data.text_to_token_emb.path
22 | output_folder_name = cfg.output_folder_name_token
23 | save_token_embeddings(path, annotations_filename=annotations_filename,
24 | output_folder_name=output_folder_name,
25 | modelname=modelname, modelpath=modelpath,
26 | device=device)
27 |
28 | # Compute sent embeddings
29 | modelname = cfg.data.text_to_sent_emb.modelname
30 | modelpath = cfg.data.text_to_sent_emb.modelpath
31 | logger.info(f"Compute sentence embeddings for {modelname}")
32 | path = cfg.data.text_to_sent_emb.path
33 | output_folder_name = cfg.output_folder_name_sent
34 | save_sent_embeddings(path, annotations_filename=annotations_filename,
35 | output_folder_name=output_folder_name,
36 | modelname=modelname, modelpath=modelpath,
37 | device=device)
38 |
39 |
40 | if __name__ == "__main__":
41 | text_embeddings()
42 |
--------------------------------------------------------------------------------
/prepare/tools.py:
--------------------------------------------------------------------------------
1 | import os
2 | from glob import glob
3 | from tqdm import tqdm
4 |
5 |
6 | def loop_amass(
7 | base_folder,
8 | new_base_folder,
9 | ext=".npz",
10 | newext=".npz",
11 | force_redo=False,
12 | exclude=None,
13 | ):
14 | match_str = f"**/*{ext}"
15 |
16 | for motion_file in tqdm(glob(match_str, root_dir=base_folder, recursive=True)):
17 | if exclude and exclude in motion_file:
18 | continue
19 |
20 | motion_path = os.path.join(base_folder, motion_file)
21 |
22 | if motion_path.endswith("shape.npz"):
23 | continue
24 |
25 | new_motion_path = os.path.join(
26 | new_base_folder, motion_file.replace(ext, newext)
27 | )
28 | if not force_redo and os.path.exists(new_motion_path):
29 | continue
30 |
31 | new_folder = os.path.split(new_motion_path)[0]
32 | os.makedirs(new_folder, exist_ok=True)
33 |
34 | yield motion_path, new_motion_path
35 |
--------------------------------------------------------------------------------
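loop_amass yields (source, destination) file pairs and creates the destination folders on the fly, skipping already-converted files unless force_redo is set. A minimal usage sketch with hypothetical folders (run from the repository root so that prepare is importable), which simply copies .npz files into a mirrored tree:

    import shutil
    from prepare.tools import loop_amass

    src_folder = "datasets/motions/AMASS"       # hypothetical input folder
    dst_folder = "datasets/motions/AMASS_copy"  # hypothetical output folder

    for motion_path, new_motion_path in loop_amass(src_folder, dst_folder):
        shutil.copy(motion_path, new_motion_path)
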
/requirements.txt:
--------------------------------------------------------------------------------
1 | aiohttp==3.8.5
2 | aiosignal==1.3.1
3 | antlr4-python3-runtime==4.9.3
4 | async-timeout==4.0.3
5 | attrs==23.1.0
6 | certifi==2023.7.22
7 | charset-normalizer==2.1.1
8 | cmake==3.25.0
9 | colorlog==6.7.0
10 | einops==0.6.1
11 | filelock==3.9.0
12 | frozenlist==1.4.0
13 | fsspec==2023.9.0
14 | hydra-colorlog==1.2.0
15 | hydra-core==1.3.2
16 | idna==3.4
17 | Jinja2==3.1.2
18 | lightning-utilities==0.9.0
19 | lit==15.0.7
20 | MarkupSafe==2.1.2
21 | mpmath==1.3.0
22 | multidict==6.0.4
23 | networkx==3.0
24 | numpy==1.24.1
25 | omegaconf==2.3.0
26 | orjson==3.9.7
27 | packaging==23.1
28 | Pillow==9.3.0
29 | pytorch-lightning==2.0.9
30 | PyYAML==6.0.1
31 | requests==2.31.0
32 | scipy==1.11.2
33 | sympy==1.11.1
34 | torchmetrics==1.1.2
35 | transformers==4.41.2
36 | tqdm==4.66.1
37 | triton==2.0.0
38 | typing_extensions==4.4.0
39 | urllib3==1.26.13
40 | yarl==1.9.2
41 |
--------------------------------------------------------------------------------
/retrieval.py:
--------------------------------------------------------------------------------
1 | import os
2 | from omegaconf import DictConfig
3 | import logging
4 | import hydra
5 | import yaml
6 | from tqdm import tqdm
7 |
8 | logger = logging.getLogger(__name__)
9 |
10 |
11 | def save_metric(path, metrics):
12 | strings = yaml.dump(metrics, indent=4, sort_keys=False)
13 | with open(path, "w") as f:
14 | f.write(strings)
15 |
16 |
17 | def compute_sim_matrix(model, dataset, keyids, batch_size=256):
18 | import torch
19 | import numpy as np
20 | from src.data.collate import collate_text_motion
21 | from src.model.tmr import get_sim_matrix
22 |
23 | device = model.device
24 |
25 | nsplit = int(np.ceil(len(keyids) / min(batch_size, len(keyids))))
26 | with torch.inference_mode():
27 | all_data = [dataset.load_keyid(keyid) for keyid in keyids]
28 | all_data_splitted = np.array_split(all_data, nsplit)
29 |
30 | # process by batches (otherwise it can be too costly on a CUDA device)
31 | latent_texts = []
32 | latent_motions = []
33 | sent_embs = []
34 | for data in tqdm(all_data_splitted, leave=True):
35 | batch = collate_text_motion(data, device=device)
36 | # Text is already encoded
37 | text_x_dict = batch["text_x_dict"]
38 | motion_x_dict = batch["motion_x_dict"]
39 | sent_emb = batch["sent_emb"]
40 |
41 | # Encode both motion and text
42 | latent_text = model.encode(text_x_dict, sample_mean=True)
43 | latent_motion = model.encode(motion_x_dict, sample_mean=True)
44 |
45 | latent_texts.append(latent_text)
46 | latent_motions.append(latent_motion)
47 | sent_embs.append(sent_emb)
48 |
49 | latent_texts = torch.cat(latent_texts)
50 | latent_motions = torch.cat(latent_motions)
51 | sent_embs = torch.cat(sent_embs)
52 | sim_matrix = get_sim_matrix(latent_texts, latent_motions)
53 | returned = {
54 | "sim_matrix": sim_matrix.cpu().numpy(),
55 | "sent_emb": sent_embs.cpu().numpy(),
56 | }
57 | return returned
58 |
59 | @hydra.main(version_base=None, config_path="configs", config_name="retrieval")
60 | def retrieval(newcfg: DictConfig) -> None:
61 | protocol = newcfg.protocol
62 | threshold_val = newcfg.threshold
63 | device = newcfg.device
64 | run_dir = newcfg.run_dir
65 | ckpt_name = newcfg.ckpt
66 | batch_size = newcfg.batch_size
67 | save_file_name = newcfg.save_file_name
68 | split = newcfg.split
69 |
70 | print("protocol : ", protocol)
71 | assert protocol in ["all", "normal", "threshold", "nsim", "guo", "normal_no_mirror", "threshold_no_mirror"]
72 | assert split == "test" or (protocol != "nsim" and protocol != "all")
73 |
74 | if protocol == "all":
75 | protocols = ["normal", "threshold", "nsim", "guo"]
76 | else:
77 | protocols = [protocol]
78 |
79 | save_dir = os.path.join(run_dir, save_file_name)
80 | os.makedirs(save_dir, exist_ok=True)
81 |
82 | # Load last config
83 | from src.config import read_config
84 | import src.prepare # noqa
85 |
86 | cfg = read_config(run_dir)
87 |
88 | import pytorch_lightning as pl
89 | import numpy as np
90 | from hydra.utils import instantiate
91 | from src.load import load_model_from_cfg
92 | from src.model.metrics import all_contrastive_metrics, print_latex_metrics
93 |
94 | pl.seed_everything(cfg.seed)
95 |
96 | logger.info("Loading the model")
97 | model = load_model_from_cfg(cfg, ckpt_name, eval_mode=True, device=device)
98 |
99 |
100 | data = newcfg.data
101 | if data is None:
102 | data = cfg.data
103 |
104 | datasets = {}
105 | results = {}
106 | for protocol in protocols:
107 | # Load the dataset if not already
108 | if protocol not in datasets:
109 | if protocol in ["normal", "threshold", "guo"]:
110 | dataset = instantiate(data, split=split)
111 | datasets.update(
112 | {key: dataset for key in ["normal", "threshold", "guo"]}
113 | )
114 | elif protocol in ["normal_no_mirror", "threshold_no_mirror"]:
115 | datasets[protocol] = instantiate(data, split=split + "_no_mirror")
116 | elif protocol == "nsim":
117 | datasets[protocol] = instantiate(data, split="nsim_test")
118 | dataset = datasets[protocol]
119 |
120 | # Compute sim_matrix for each protocol
121 | if protocol not in results:
122 | if protocol in ["normal", "threshold"]:
123 | res = compute_sim_matrix(
124 | model, dataset, dataset.keyids, batch_size=batch_size
125 | )
126 | results.update({key: res for key in ["normal", "threshold"]})
127 | elif protocol in ["normal_no_mirror", "threshold_no_mirror"]:
128 | res = compute_sim_matrix(
129 | model, dataset, dataset.keyids, batch_size=batch_size
130 | )
131 | results[protocol] = res
132 | elif protocol == "nsim":
133 | res = compute_sim_matrix(
134 | model, dataset, dataset.keyids, batch_size=batch_size
135 | )
136 | results[protocol] = res
137 | elif protocol == "guo":
138 | keyids = sorted(dataset.keyids)
139 | N = len(keyids)
140 |
141 | # make batches of 32
142 | idx = np.arange(N)
143 | np.random.seed(0)
144 | np.random.shuffle(idx)
145 | idx_batches = [
146 | idx[32 * i : 32 * (i + 1)] for i in range(len(keyids) // 32)
147 | ]
148 |
149 | # split into batches of 32
150 | # batched_keyids = [ [32], [32], [...]]
151 | results["guo"] = [
152 | compute_sim_matrix(
153 | model,
154 | dataset,
155 | np.array(keyids)[idx_batch],
156 | batch_size=batch_size,
157 | )
158 | for idx_batch in idx_batches
159 | ]
160 | result = results[protocol]
161 |
162 | # Compute the metrics
163 | if protocol == "guo":
164 | all_metrics = []
165 | for x in result:
166 | sim_matrix = x["sim_matrix"]
167 | metrics = all_contrastive_metrics(sim_matrix, rounding=None)
168 | all_metrics.append(metrics)
169 |
170 | avg_metrics = {}
171 | for key in all_metrics[0].keys():
172 | avg_metrics[key] = round(
173 | float(np.mean([metrics[key] for metrics in all_metrics])), 2
174 | )
175 |
176 | metrics = avg_metrics
177 | protocol_name = protocol
178 | else:
179 | sim_matrix = result["sim_matrix"]
180 |
181 | protocol_name = protocol
182 | if protocol == "threshold":
183 | emb = result["sent_emb"]
184 | threshold = threshold_val
185 | protocol_name = protocol + f"_{threshold}"
186 | else:
187 | emb, threshold = None, None
188 | metrics = all_contrastive_metrics(sim_matrix, emb, threshold=threshold, t2m=True, m2t=False)
189 |
190 | print_latex_metrics(metrics, ranks=[1, 3, 10], t2m=True, m2t=False, MedR=False)
191 |
192 | print("protocol_name : ", protocol_name)
193 | metric_name = f"{protocol_name}.yaml"
194 | path = os.path.join(save_dir, metric_name)
195 | save_metric(path, metrics)
196 |
197 | logger.info(f"Testing done, metrics saved in:\n{path}")
198 |
199 |
200 | if __name__ == "__main__":
201 | retrieval()
202 |
--------------------------------------------------------------------------------
/retrieval_action.py:
--------------------------------------------------------------------------------
1 | import os
2 | from omegaconf import DictConfig
3 | import logging
4 | import hydra
5 | import yaml
6 | from tqdm import tqdm
7 |
8 | logger = logging.getLogger(__name__)
9 |
10 |
11 | def save_metric(path, metrics):
12 | strings = yaml.dump(metrics, indent=4, sort_keys=False)
13 | with open(path, "w") as f:
14 | f.write(strings)
15 |
16 |
17 | def compute_sim_matrix(model, dataset, keyids, batch_size=256):
18 | import torch
19 | import numpy as np
20 | from src.data.collate import collate_text_motion
21 | from src.model.tmr import get_sim_matrix
22 |
23 | device = model.device
24 |
25 | nsplit = int(np.ceil(len(dataset) / batch_size))
26 | with torch.inference_mode():
27 | all_data = [dataset.load_keyid(keyid) for keyid in keyids]
28 | all_data_splitted = np.array_split(all_data, nsplit)
29 |
30 | # process by batches (otherwise it can be too costly on a CUDA device)
31 | latent_texts = []
32 | latent_motions = []
33 | sent_embs = []
34 |
35 | for data in tqdm(all_data_splitted, leave=True):
36 | batch = collate_text_motion(data, device=device)
37 |
38 | # Text is already encoded
39 | text_x_dict = batch["text_x_dict"]
40 | motion_x_dict = batch["motion_x_dict"]
41 | sent_emb = batch["sent_emb"]
42 |
43 | # Encode both motion and text
44 | latent_text = model.encode(text_x_dict, sample_mean=True)
45 | latent_motion = model.encode(motion_x_dict, sample_mean=True)
46 |
47 | latent_texts.append(latent_text)
48 | latent_motions.append(latent_motion)
49 | sent_embs.append(sent_emb)
50 |
51 | latent_texts = torch.cat(latent_texts)
52 | action_latent_text = torch.unique(latent_texts, dim=0)
53 |
54 | action_latent_text_idx = {tuple(action_latent_text[i].to("cpu").numpy()): i for i in range(len(action_latent_text))}
55 |
56 | latent_motions = torch.cat(latent_motions)
57 | motion_cat_idx = [action_latent_text_idx[tuple(latent_texts[i].to("cpu").numpy())] for i in range(len(latent_motions))]
58 |
59 | #sent_embs = torch.cat(sent_embs)
60 | sim_matrix = get_sim_matrix(action_latent_text, latent_motions)
61 | returned = {
62 | "sim_matrix": sim_matrix.cpu().numpy(),
63 | "motion_cat_idx": motion_cat_idx
64 | }
65 | return returned
66 |
67 | @hydra.main(version_base=None, config_path="configs", config_name="retrieval")
68 | def retrieval(newcfg: DictConfig) -> None:
69 | device = newcfg.device
70 | run_dir = newcfg.run_dir
71 | ckpt_name = newcfg.ckpt
72 | batch_size = newcfg.batch_size
73 | save_file_name = newcfg.save_file_name
74 | split = newcfg.split
75 |
76 | assert split == "test"
77 | protocols = ["normal"]
78 |
79 | save_dir = os.path.join(run_dir, save_file_name)
80 | os.makedirs(save_dir, exist_ok=True)
81 |
82 | # Load last config
83 | from src.config import read_config
84 | import src.prepare # noqa
85 |
86 | cfg = read_config(run_dir)
87 |
88 | import pytorch_lightning as pl
89 | import numpy as np
90 | from hydra.utils import instantiate
91 | from src.load import load_model_from_cfg
92 | from src.model.metrics import all_contrastive_metrics_action_retrieval, print_latex_metrics
93 |
94 | pl.seed_everything(cfg.seed)
95 |
96 | logger.info("Loading the model")
97 | model = load_model_from_cfg(cfg, ckpt_name, eval_mode=True, device=device)
98 |
99 |
100 | data = newcfg.data
101 | if data is None:
102 | data = cfg.data
103 |
104 | datasets = {}
105 | for protocol in protocols:
106 | # Load the dataset if not already
107 | if protocol not in datasets:
108 | dataset = instantiate(data, split=split)
109 | datasets.update(
110 | {key: dataset for key in ["normal", "threshold", "guo"]}
111 | )
112 | dataset = datasets[protocol]
113 |
114 | # Compute sim_matrix for each protocol
115 | protocol = "normal"
116 | result = compute_sim_matrix(
117 | model, dataset, dataset.keyids, batch_size=batch_size
118 | )
119 |
120 | # Compute the metrics
121 | sim_matrix = result["sim_matrix"]
122 | motion_cat_idx = result["motion_cat_idx"]
123 |
124 | protocol_name = protocol
125 | metrics = all_contrastive_metrics_action_retrieval(sim_matrix, motion_cat_idx, norm_metrics=True)
126 |
127 | print_latex_metrics(metrics, ranks=[1, 2, 3, 5, 10], t2m=False, m2t=True, MedR=False)
128 |
129 | metric_name = f"{protocol_name}.yaml"
130 | path = os.path.join(save_dir, metric_name)
131 | save_metric(path, metrics)
132 |
133 | logger.info(f"Testing done, metrics saved in:\n{path}")
134 |
135 |
136 | if __name__ == "__main__":
137 | retrieval()
138 |
--------------------------------------------------------------------------------
/retrieval_action_multi_labels.py:
--------------------------------------------------------------------------------
1 | import os
2 | from omegaconf import DictConfig
3 | import logging
4 | import hydra
5 | import yaml
6 | from tqdm import tqdm
7 |
8 | logger = logging.getLogger(__name__)
9 |
10 |
11 | def save_metric(path, metrics):
12 | strings = yaml.dump(metrics, indent=4, sort_keys=False)
13 | with open(path, "w") as f:
14 | f.write(strings)
15 |
16 |
17 | def compute_sim_matrix(model, dataset, keyids, batch_size=256):
18 | import torch
19 | import numpy as np
20 | from src.data.collate import collate_text_motion_multiple_texts
21 | from src.model.tmr import get_sim_matrix
22 |
23 | device = model.device
24 |
25 | nsplit = int(np.ceil(len(dataset) / batch_size))
26 | with torch.inference_mode():
27 | all_data = [dataset.load_keyid(keyid) for keyid in keyids]
28 | all_data_splitted = np.array_split(all_data, nsplit)
29 |
30 | # process by batches (otherwise it can be too costly on a CUDA device)
31 | latent_texts = []
32 | latent_motions = []
33 | sent_embs = []
34 | text_indices = []
35 | indices_shift = 0
36 |
37 | #for data in tqdm(all_data_splitted, leave=True):
38 | for data in all_data_splitted:
39 |
40 | batch = collate_text_motion_multiple_texts(data, device=device)
41 | # Text is already encoded
42 | text_x_dict = batch["text_x_dict"]
43 | motion_x_dict = batch["motion_x_dict"]
44 | sent_emb = batch["sent_emb"]
45 |
46 | # Encode both motion and text
47 | latent_text = model.encode(text_x_dict, sample_mean=True)
48 | latent_motion = model.encode(motion_x_dict, sample_mean=True)
49 |
50 | latent_texts.append(latent_text)
51 | latent_motions.append(latent_motion)
52 | sent_embs.append(sent_emb)
53 | idx = batch["text_slices"]
54 | idx = [[elt[0] + indices_shift, elt[1] + indices_shift] for elt in idx]
55 | text_indices.extend(idx)
56 | indices_shift += len(latent_text)
57 |
58 | latent_texts = torch.cat(latent_texts)
59 | action_latent_text = torch.unique(latent_texts, dim=0)
60 | action_latent_text_idx = {tuple(action_latent_text[i].to("cpu").numpy()): i for i in range(len(action_latent_text))}
61 | text_cat_idx = [action_latent_text_idx[tuple(latent_texts[i].to("cpu").numpy())] for i in range(len(latent_texts))]
62 |
63 | latent_motions = torch.cat(latent_motions)
64 | motion_cat_idx = []
65 | for start_ind, end_ind in text_indices:
66 | motion_cat_idx.append(text_cat_idx[start_ind:end_ind])
67 |
68 | #sent_embs = torch.cat(sent_embs)
69 | sim_matrix = get_sim_matrix(action_latent_text, latent_motions)
70 |
71 | returned = {
72 | "sim_matrix": sim_matrix.cpu().numpy(),
73 | "motion_cat_idx": motion_cat_idx
74 | }
75 | return returned
76 |
77 | @hydra.main(version_base=None, config_path="configs", config_name="retrieval_action_multi_labels")
78 | def retrieval_action_multi_labels(newcfg: DictConfig) -> None:
79 | device = newcfg.device
80 | run_dir = newcfg.run_dir
81 | ckpt_name = newcfg.ckpt
82 | batch_size = newcfg.batch_size
83 | save_file_name = newcfg.save_file_name
84 | split = newcfg.split
85 |
86 | assert split == "test"
87 |
88 | save_dir = os.path.join(run_dir, save_file_name)
89 | os.makedirs(save_dir, exist_ok=True)
90 |
91 | # Load last config
92 | from src.config import read_config
93 | import src.prepare # noqa
94 |
95 | cfg = read_config(run_dir)
96 |
97 | import pytorch_lightning as pl
98 | import numpy as np
99 | from hydra.utils import instantiate
100 | from src.load import load_model_from_cfg
101 | from src.model.metrics import all_contrastive_metrics_action_retrieval_multi_labels, print_latex_metrics
102 |
103 | pl.seed_everything(cfg.seed)
104 |
105 | logger.info("Loading the model")
106 | model = load_model_from_cfg(cfg, ckpt_name, eval_mode=True, device=device)
107 |
108 |
109 | data = newcfg.data
110 | if data is None:
111 | data = cfg.data
112 |
113 | dataset = instantiate(data, split=split)
114 |
115 | # Compute sim_matrix for each protocol
116 | protocol = "normal"
117 | result = compute_sim_matrix(
118 | model, dataset, dataset.keyids, batch_size=batch_size
119 | )
120 |
121 | # Compute the metrics
122 | sim_matrix = result["sim_matrix"]
123 | motion_cat_idx = result["motion_cat_idx"]
124 |
125 | protocol_name = protocol
126 | metrics = all_contrastive_metrics_action_retrieval_multi_labels(sim_matrix, motion_cat_idx, norm_metrics=True)
127 |
128 | print_latex_metrics(metrics, ranks=[1, 2, 3, 5, 10], t2m=False, m2t=True, MedR=False)
129 |
130 | metric_name = f"{protocol_name}.yaml"
131 | path = os.path.join(save_dir, metric_name)
132 | save_metric(path, metrics)
133 |
134 | logger.info(f"Testing done, metrics saved in:\n{path}")
135 |
136 |
137 | if __name__ == "__main__":
138 | retrieval_action_multi_labels()
139 |
--------------------------------------------------------------------------------
/src/callback/__init__.py:
--------------------------------------------------------------------------------
1 | # from .render import RenderCallback
2 | from .progress import ProgressLogger
3 |
--------------------------------------------------------------------------------
/src/callback/progress.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | from pytorch_lightning import LightningModule, Trainer
4 | from pytorch_lightning.callbacks import Callback
5 |
6 |
7 | logger = logging.getLogger(__name__)
8 |
9 |
10 | class ProgressLogger(Callback):
11 | def __init__(self, precision: int = 2):
12 | self.precision = precision
13 |
14 | def on_train_start(self, trainer: Trainer, pl_module: LightningModule, **kwargs):
15 | logger.info("Training started")
16 |
17 | def on_train_end(self, trainer: Trainer, pl_module: LightningModule, **kwargs):
18 | logger.info("Training done")
19 |
20 | def on_validation_epoch_end(
21 | self, trainer: Trainer, pl_module: LightningModule, **kwargs
22 | ):
23 | if trainer.sanity_checking:
24 | logger.info("Sanity checking ok.")
25 |
26 | def on_train_epoch_end(
27 | self, trainer: Trainer, pl_module: LightningModule, **kwargs
28 | ):
29 | metric_format = f"{{:.{self.precision}e}}"
30 | line = f"Epoch {trainer.current_epoch}"
31 | metrics_str = []
32 |
33 | losses_dict = trainer.callback_metrics
34 |
35 | def is_contrastive_metrics(x):
36 | return "t2m" in x or "m2t" in x
37 |
38 | losses_to_print = [
39 | x
40 | for x in losses_dict.keys()
41 | for y in [x.split("_")]
42 | if len(y) == 3
43 | and y[2] == "epoch"
44 | and (
45 | y[1] in pl_module.lmd or y[1] == "loss" or is_contrastive_metrics(y[1])
46 | )
47 | ]
48 |
49 | # Natural order for the contrastive metrics
50 | letters = "0123456789"
51 | mapping = str.maketrans(letters, letters[::-1])
52 |
53 | def sort_losses(x):
54 | split, name, epoch_step = x.split("_")
55 | if is_contrastive_metrics(x):
56 | # put them at the end
57 | name = "a" + name.translate(mapping)
58 | return (name, split)
59 |
60 | losses_to_print = sorted(losses_to_print, key=sort_losses, reverse=True)
61 | for metric_name in losses_to_print:
62 | split, name, _ = metric_name.split("_")
63 |
64 | metric = losses_dict[metric_name].item()
65 |
66 | if is_contrastive_metrics(metric_name):
67 | if "len" in metric_name:
68 | metric = str(int(metric))
69 | elif "MedR" in metric_name:
70 | metric = str(int(metric * 100) / 100)
71 | else:
72 | metric = str(int(metric * 100) / 100) + "%"
73 | else:
74 | metric = metric_format.format(metric)
75 |
76 | if split == "train":
77 | mname = name
78 | else:
79 | mname = f"v_{name}"
80 |
81 | metric = f"{mname} {metric}"
82 | metrics_str.append(metric)
83 |
84 | if len(metrics_str) == 0:
85 | return
86 |
87 | line = line + ": " + " ".join(metrics_str)
88 | logger.info(line)
89 |
--------------------------------------------------------------------------------
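The filtering in on_train_epoch_end keeps only metric names of the form "<split>_<name>_epoch" (exactly three underscore-separated parts), where <name> is a weighted loss term, "loss" itself, or a contrastive metric containing "t2m"/"m2t". A stand-alone illustration of that condition, with lmd_keys standing in for pl_module.lmd:

    def keep(metric_name, lmd_keys=("contrastive", "kl")):  # hypothetical loss names
        parts = metric_name.split("_")
        if len(parts) != 3 or parts[2] != "epoch":
            return False
        name = parts[1]
        return name in lmd_keys or name == "loss" or "t2m" in name or "m2t" in name

    print(keep("train_loss_epoch"))  # True: epoch-level training loss
    print(keep("val_loss_step"))     # False: per-step metrics are skipped
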
/src/callback/tqdmbar.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import pytorch_lightning as pl
3 | from pytorch_lightning.callbacks import TQDMProgressBar as OriginalTQDMProgressBar
4 |
5 |
6 | def customize_bar(bar):
7 | if not sys.stdout.isatty():
8 | bar.disable = True
9 | bar.leave = True  # keep the bar after completion
10 | return bar
11 |
12 |
13 | class TQDMProgressBar(OriginalTQDMProgressBar):
14 | # remove the annoying v_num in the bar
15 | def get_metrics(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
16 | items_dict = super().get_metrics(trainer, pl_module).copy()
17 |
18 | if "v_num" in items_dict:
19 | items_dict.pop("v_num")
20 | return items_dict
21 |
22 | def init_sanity_tqdm(self):
23 | bar = super().init_sanity_tqdm()
24 | return customize_bar(bar)
25 |
26 | def init_train_tqdm(self):
27 | bar = super().init_train_tqdm()
28 | return customize_bar(bar)
29 |
30 | def init_validation_tqdm(self):
31 | bar = super().init_validation_tqdm()
32 | bar.disable = True
33 | return bar
34 |
35 | def init_predict_tqdm(self):
36 | bar = super().init_predict_tqdm()
37 | return customize_bar(bar)
38 |
39 | def init_test_tqdm(self):
40 | bar = super().init_test_tqdm()
41 | return customize_bar(bar)
42 |
--------------------------------------------------------------------------------
/src/config.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | from omegaconf import DictConfig, OmegaConf
4 |
5 |
6 | def save_config(cfg: DictConfig) -> str:
7 | path = os.path.join(cfg.run_dir, "config.json")
8 | config = OmegaConf.to_container(cfg, resolve=True)
9 | with open(path, "w") as f:
10 | string = json.dumps(config, indent=4)
11 | f.write(string)
12 | return path
13 |
14 |
15 | def read_config(run_dir: str, return_json=False) -> DictConfig:
16 | path = os.path.join(run_dir, "config.json")
17 | with open(path, "r") as f:
18 | config = json.load(f)
19 | if return_json:
20 | return config
21 | cfg = OmegaConf.create(config)
22 | cfg.run_dir = run_dir
23 | return cfg
24 |
--------------------------------------------------------------------------------
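save_config and read_config form a simple round trip through <run_dir>/config.json; a short sketch with a hypothetical run directory:

    import os
    from omegaconf import OmegaConf
    from src.config import save_config, read_config

    cfg = OmegaConf.create({"run_dir": "outputs/example_run", "seed": 1234})  # hypothetical config
    os.makedirs(cfg.run_dir, exist_ok=True)

    path = save_config(cfg)          # writes outputs/example_run/config.json
    cfg2 = read_config(cfg.run_dir)  # reads it back as a DictConfig
    assert cfg2.seed == 1234
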
/src/data/augmented_text_motion.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import logging
3 | import numpy as np
4 | import random
5 | import torch
6 | from tqdm import tqdm
7 |
8 | from .collate import collate_text_motion_multiple_texts
9 | from .text_motion import load_annotations, TextMotionDataset
10 |
11 |
12 | class AugmentedTextMotionDataset(TextMotionDataset):
13 | def __init__(
14 | self,
15 | path: str,
16 | motion_loader,
17 | text_to_sent_emb,
18 | text_to_token_emb,
19 | split: str = "train",
20 | min_seconds: float = 2.0,
21 | max_seconds: float = 10.0,
22 | preload: bool = True,
23 | tiny: bool = False,
24 | paraphrase_filename: str = None,
25 | summary_filename: str = None,
26 | paraphrase_prob: float = 0.2,
27 | summary_prob: float = 0.2,
28 | averaging_prob: float = 0.4,
29 | text_sampling_nbr: int = 4
30 | ):
31 | super().__init__(path, motion_loader, text_to_sent_emb, text_to_token_emb,
32 | split=split, min_seconds=min_seconds, max_seconds=max_seconds,
33 | preload=False, tiny=tiny)
34 |
35 | self.collate_fn = collate_text_motion_multiple_texts
36 |
37 | assert paraphrase_prob == 0 or paraphrase_filename is not None
38 | assert summary_prob == 0 or summary_filename is not None
39 |
40 | self.text_sampling_nbr = text_sampling_nbr
41 | self.paraphrase_prob = 0
42 | if split=="train" and paraphrase_filename is not None:
43 | self.annotations_paraphrased = load_annotations(path, name=paraphrase_filename)
44 | self.paraphrase_prob = paraphrase_prob
45 | self.summary_prob = 0
46 | if split=="train" and summary_filename is not None:
47 | self.annotations_summary = load_annotations(path, name=summary_filename)
48 | self.summary_prob = summary_prob
49 | self.averaging_prob = 0
50 | if split=="train" and paraphrase_filename is not None:
51 | self.averaging_prob = averaging_prob
52 |
53 | # filter annotations (min/max)
54 | # but not for the test set
55 | # otherwise it is not fair for everyone
56 | if "test" not in split:
57 | if "train" in split and paraphrase_filename is not None:
58 | self.annotations_paraphrased = self.filter_annotations(self.annotations_paraphrased)
59 | if "train" in split and summary_filename is not None:
60 | self.annotations_summary = self.filter_annotations(self.annotations_summary)
61 |
62 | if preload:
63 | for _ in tqdm(self, desc="Preloading the dataset"):
64 | continue
65 |
66 | def load_keyid(self, keyid, text_idx=None, sent_emb_mode="first"):
67 |
68 |         p = random.random()  # random draw deciding whether to use augmented data, and which kind
69 | averaging = False
70 | if self.is_training and p < self.paraphrase_prob:
71 | annotations = self.annotations_paraphrased[keyid]
72 | elif self.is_training and p < self.summary_prob + self.paraphrase_prob:
73 | if keyid in self.annotations_summary:
74 | annotations = self.annotations_summary[keyid]
75 | else:
76 |                 annotations = self.annotations_paraphrased[keyid]  # fallback for BABEL, which has no summaries
77 | elif self.is_training and p < self.averaging_prob + self.summary_prob + self.paraphrase_prob:
78 | annotations = copy.deepcopy(self.annotations[keyid])
79 | if hasattr(self, "annotations_paraphrased") and keyid in self.annotations_paraphrased:
80 | annotations["annotations"] += self.annotations_paraphrased[keyid]["annotations"]
81 | if hasattr(self, "annotations_summary") and keyid in self.annotations_summary:
82 | annotations["annotations"] += self.annotations_summary[keyid]["annotations"]
83 | averaging = True
84 | else:
85 | annotations = self.annotations[keyid]
86 |
87 | # Take the first one for testing/validation
88 | # Otherwise take a random one
89 | index = 0
90 | if averaging:
91 |             if isinstance(self.text_sampling_nbr, int):  # if the number of texts to sample is provided
92 |                 n = min(self.text_sampling_nbr, len(annotations["annotations"]))
93 |             else:  # otherwise, choose the number of texts randomly
94 | n = random.randint(2, len(annotations["annotations"]))
95 | index = random.sample(range(0, len(annotations["annotations"])), n)
96 | elif text_idx is not None:
97 | index = text_idx % len(annotations["annotations"])
98 | elif self.is_training:
99 | index = np.random.randint(len(annotations["annotations"]))
100 |
101 | if isinstance(index, int):
102 | index = [index]
103 |
104 | annotation_list = [annotations["annotations"][i] for i in index]
105 | text = [ann["text"] for ann in annotation_list]
106 | annotation0 = annotations["annotations"][index[0]]
107 |
108 | text_x_dict = [self.text_to_token_emb(t) for t in text]
109 |
110 | motion_x_dict = self.motion_loader(
111 | path=annotations["path"],
112 | start=annotation0["start"],
113 | end=annotation0["end"],
114 | )
115 |
116 | if sent_emb_mode == "first":
117 | sent_emb = self.text_to_sent_emb(text[0])
118 | elif sent_emb_mode == "average":
119 | sent_emb = torch.stack([self.text_to_sent_emb(t) for t in text])
120 | sent_emb = torch.mean(sent_emb, axis=0)
121 | sent_emb = torch.nn.functional.normalize(sent_emb, dim=0)
122 |
123 | output = {
124 | "motion_x_dict": motion_x_dict,
125 | "text_x_dict": text_x_dict,
126 | "text": text,
127 | "keyid": keyid,
128 | "sent_emb": sent_emb,
129 | }
130 |
131 | # TODO
132 | #if device is not None:
133 | # output["motion_x_dict"]["x"] = output["motion_x_dict"]["x"].to(device)
134 | # for i in range(len(output["text_x_dict"][i])):
135 | # output["text_x_dict"][i]["x"] = output["text_x_dict"][i]["x"].to(device)
136 |
137 | return output
138 |
--------------------------------------------------------------------------------
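The single draw of p in load_keyid partitions [0, 1) into consecutive bands; a standalone sketch of that selection logic (the probabilities shown are the defaults above):

import random

paraphrase_prob, summary_prob, averaging_prob = 0.2, 0.2, 0.4

p = random.random()
if p < paraphrase_prob:
    source = "paraphrased annotations"
elif p < paraphrase_prob + summary_prob:
    source = "summary annotations"
elif p < paraphrase_prob + summary_prob + averaging_prob:
    source = "original + augmented texts, sampled together for averaging"
else:
    source = "original annotations"
print(source)
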
/src/data/collate.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from typing import List, Dict, Optional
4 | from torch import Tensor
5 | from torch.utils.data import default_collate
6 |
7 |
8 | def length_to_mask(length, device: torch.device = None) -> Tensor:
9 | if device is None:
10 | device = "cpu"
11 |
12 | if isinstance(length, list):
13 | length = torch.tensor(length, device=device)
14 |
15 | max_len = max(length)
16 | mask = torch.arange(max_len, device=device).expand(
17 | len(length), max_len
18 | ) < length.unsqueeze(1)
19 | return mask
20 |
21 |
22 | def collate_tensor_with_padding(batch: List[Tensor]) -> Tensor:
23 | dims = batch[0].dim()
24 | max_size = [max([b.size(i) for b in batch]) for i in range(dims)]
25 | size = (len(batch),) + tuple(max_size)
26 | canvas = batch[0].new_zeros(size=size)
27 | for i, b in enumerate(batch):
28 | sub_tensor = canvas[i]
29 | for d in range(dims):
30 | sub_tensor = sub_tensor.narrow(d, 0, b.size(d))
31 | sub_tensor.add_(b)
32 | return canvas
33 |
34 |
35 | def collate_x_dict(lst_x_dict: List, *, device: Optional[str] = None) -> Dict:
36 | x = collate_tensor_with_padding([x_dict["x"] for x_dict in lst_x_dict])
37 | if device is not None:
38 | x = x.to(device)
39 | length = [x_dict["length"] for x_dict in lst_x_dict]
40 | mask = length_to_mask(length, device=x.device)
41 | batch = {"x": x, "length": length, "mask": mask}
42 | return batch
43 |
44 |
45 | def collate_text_motion(lst_elements: List, *, device: Optional[str] = None) -> Dict:
46 | one_el = lst_elements[0]
47 | keys = one_el.keys()
48 |
49 | x_dict_keys = [key for key in keys if "x_dict" in key]
50 | other_keys = [key for key in keys if "x_dict" not in key]
51 |
52 | batch = {key: default_collate([x[key] for x in lst_elements]) for key in other_keys}
53 | for key, val in batch.items():
54 | if isinstance(val, torch.Tensor) and device is not None:
55 | batch[key] = val.to(device)
56 |
57 | for key in x_dict_keys:
58 | batch[key] = collate_x_dict([x[key] for x in lst_elements], device=device)
59 | return batch
60 |
61 |
62 | def collate_text_motion_multiple_texts(lst_elements: List, *, device: Optional[str] = None):
63 | other_keys = ['keyid', 'sent_emb']
64 |
65 | batch = {key: default_collate([x[key] for x in lst_elements]) for key in other_keys}
66 | batch["text"] = [elt["text"] for elt in lst_elements]
67 |
68 | for key, val in batch.items():
69 | if isinstance(val, torch.Tensor) and device is not None:
70 | batch[key] = val.to(device)
71 |
72 | batch["motion_x_dict"] = collate_x_dict([x["motion_x_dict"] for x in lst_elements], device=device)
73 |
74 | batch["text_slices"] = []
75 | current_index = 0
76 | for elt in lst_elements:
77 | batch["text_slices"].append((current_index, current_index + len(elt["text"])))
78 | current_index += len(elt["text"])
79 |
80 | texts_concat = [x_dict for x in lst_elements for x_dict in x["text_x_dict"]]
81 | batch["text_x_dict"] = collate_x_dict(
82 | texts_concat,
83 | device=device
84 | )
85 | return batch
86 |
--------------------------------------------------------------------------------
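A minimal sketch of the two padding helpers above on toy tensors:

import torch
from src.data.collate import collate_tensor_with_padding, length_to_mask

batch = [torch.ones(3, 4), torch.ones(5, 4)]   # two sequences of 3 and 5 frames
padded = collate_tensor_with_padding(batch)    # shape (2, 5, 4), zero-padded
mask = length_to_mask([3, 5])                  # shape (2, 5), True on valid frames
print(padded.shape, mask)
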
/src/data/motion.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 |
5 |
6 | class AMASSMotionLoader:
7 | def __init__(
8 | self, base_dir, fps, normalizer=None, disable: bool = False, nfeats=None
9 | ):
10 | self.fps = fps
11 | self.base_dir = base_dir
12 | self.motions = {}
13 | self.normalizer = normalizer
14 | self.disable = disable
15 | self.nfeats = nfeats
16 |
17 | def __call__(self, path, start, end):
18 | if self.disable:
19 | return {"x": path, "length": int(self.fps * (end - start))}
20 |
21 | begin = int(start * self.fps)
22 | end = int(end * self.fps)
23 | if path not in self.motions:
24 | motion_path = os.path.join(self.base_dir, path + ".npy")
25 | motion = np.load(motion_path)
26 | motion = torch.from_numpy(motion).to(torch.float)
27 | if self.normalizer is not None:
28 | motion = self.normalizer(motion)
29 | self.motions[path] = motion
30 |
31 | motion = self.motions[path][begin:end]
32 | x_dict = {"x": motion, "length": len(motion)}
33 | return x_dict
34 |
35 |
36 | class Normalizer:
37 | def __init__(self, base_dir: str, eps: float = 1e-12, disable: bool = False):
38 | self.base_dir = base_dir
39 | self.mean_path = os.path.join(base_dir, "mean.pt")
40 | self.std_path = os.path.join(base_dir, "std.pt")
41 | self.eps = eps
42 |
43 | self.disable = disable
44 | if not disable:
45 | self.load()
46 |
47 | def load(self):
48 | self.mean = torch.load(self.mean_path)
49 | self.std = torch.load(self.std_path)
50 |
51 | def save(self, mean, std):
52 | os.makedirs(self.base_dir, exist_ok=True)
53 | torch.save(mean, self.mean_path)
54 | torch.save(std, self.std_path)
55 |
56 | def __call__(self, x):
57 | if self.disable:
58 | return x
59 | x = (x - self.mean) / (self.std + self.eps)
60 | return x
61 |
62 | def inverse(self, x):
63 | if self.disable:
64 | return x
65 | x = x * (self.std + self.eps) + self.mean
66 | return x
67 |
--------------------------------------------------------------------------------
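A minimal sketch of the Normalizer round trip (the stats directory is illustrative):

import torch
from src.data.motion import Normalizer

normalizer = Normalizer(base_dir="stats/demo", disable=True)  # skip loading at init
normalizer.save(mean=torch.zeros(3), std=torch.ones(3))       # writes stats/demo/{mean,std}.pt
normalizer.load()
normalizer.disable = False

x = torch.tensor([[1.0, 2.0, 3.0]])
assert torch.allclose(normalizer.inverse(normalizer(x)), x)
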
/src/data/text.py:
--------------------------------------------------------------------------------
1 | import os
2 | import orjson
3 | import json
4 | import torch
5 | from torch import Tensor
6 | import numpy as np
7 | from tqdm import tqdm
8 |
9 | from abc import ABC, abstractmethod
10 |
11 | from src.model import TextToEmb
12 |
13 |
14 | class TextEmbeddings(ABC):
15 | name = ...
16 |
17 | def __init__(
18 | self,
19 | modelname: str,
20 | modelpath: str = None,
21 | path: str = "",
22 | device: str = "cpu",
23 | preload: bool = True,
24 | disable: bool = False,
25 |         name: str = None  # TODO: check whether this addition (relative to TMR) is still needed
26 | ):
27 | if name is not None:
28 | self.name = name
29 |
30 | self.modelname = modelname
31 | self.modelpath = modelpath
32 | if modelpath is None:
33 | self.modelpath = modelname
34 | self.embeddings_folder = os.path.join(path, self.name)
35 |
36 | self.cache = {}
37 | self.device = device
38 | self.disable = disable
39 |
40 | if preload and not disable:
41 | self.load_embeddings()
42 | else:
43 | self.embeddings_index = {}
44 |
45 | @abstractmethod
46 | def load_model(self) -> None:
47 | ...
48 |
49 | @abstractmethod
50 | def load_embeddings(self) -> None:
51 | ...
52 |
53 | @abstractmethod
54 | def get_embedding(self, text: str) -> Tensor:
55 | ...
56 |
57 | def __contains__(self, text):
58 | return text in self.embeddings_index
59 |
60 | def get_model(self):
61 | model = getattr(self, "model", None)
62 | if model is None:
63 | model = self.load_model()
64 | return model
65 |
66 | def __call__(self, texts):
67 | if self.disable:
68 | return texts
69 |
70 | squeeze = False
71 | if isinstance(texts, str):
72 | texts = [texts]
73 | squeeze = True
74 |
75 | x_dict_lst = []
76 | # one at a time here
77 | for text in texts:
78 | # Precomputed in advance
79 | if text in self:
80 | x_dict = self.get_embedding(text)
81 | # Already computed during the session
82 | elif text in self.cache:
83 | x_dict = self.cache[text]
84 | # Load the text model (if not already loaded) + compute on the fly
85 | else:
86 | model = self.get_model()
87 | x_dict = model(text)
88 | self.cache[text] = x_dict
89 | x_dict_lst.append(x_dict)
90 |
91 | if squeeze:
92 | return x_dict_lst[0]
93 | return x_dict_lst
94 |
95 |
96 | class TokenEmbeddings(TextEmbeddings):
97 | name = "token_embeddings"
98 |
99 | def load_model(self):
100 | self.model = TextToEmb(self.modelpath, mean_pooling=False, device=self.device)
101 | return self.model
102 |
103 | def load_embeddings(self):
104 | self.embeddings_big = torch.from_numpy(
105 | np.load(os.path.join(self.embeddings_folder, self.modelname + ".npy"))
106 | ).to(dtype=torch.float, device=self.device)
107 | self.embeddings_slice = np.load(
108 | os.path.join(self.embeddings_folder, self.modelname + "_slice.npy")
109 | )
110 | self.embeddings_index = load_json(
111 | os.path.join(self.embeddings_folder, self.modelname + "_index.json")
112 | )
113 | self.text_dim = self.embeddings_big.shape[-1]
114 |
115 | def get_embedding(self, text):
116 | # Precomputed in advance
117 | index = self.embeddings_index[text]
118 | begin, end = self.embeddings_slice[index]
119 | embedding = self.embeddings_big[begin:end]
120 | x_dict = {"x": embedding, "length": len(embedding)}
121 | return x_dict
122 |
123 |
124 | class SentenceEmbeddings(TextEmbeddings):
125 | name = "sent_embeddings"
126 |
127 | def load_model(self):
128 | self.model = TextToEmb(self.modelpath, mean_pooling=True, device=self.device)
129 | return self.model
130 |
131 | def load_embeddings(self):
132 | self.embeddings = torch.from_numpy(
133 | np.load(os.path.join(self.embeddings_folder, self.modelname + ".npy"))
134 | ).to(dtype=torch.float, device=self.device)
135 | self.embeddings_index = load_json(
136 | os.path.join(self.embeddings_folder, self.modelname + "_index.json")
137 | )
138 | assert len(self.embeddings_index) == len(self.embeddings)
139 |
140 | self.text_dim = self.embeddings.shape[-1]
141 |
142 | def get_embedding(self, text):
143 | index = self.embeddings_index[text]
144 | embedding = self.embeddings[index]
145 | return embedding.to(self.device)
146 |
147 |
148 | def load_json(json_path):
149 | with open(json_path, "rb") as ff:
150 | return orjson.loads(ff.read())
151 |
152 |
153 | def load_annotations(path, name="annotations.json"):
154 | json_path = os.path.join(path, name)
155 | return load_json(json_path)
156 |
157 |
158 | def write_json(data, path):
159 | with open(path, "w") as ff:
160 | ff.write(json.dumps(data, indent=4))
161 |
162 |
163 | def save_token_embeddings(
164 | path, annotations_filename="annotations.json", output_folder_name=None,
165 | modelname="sentence-transformers/all-mpnet-base-v2", modelpath=None, device="cuda"
166 | ):
167 | if modelpath is None:
168 | modelpath = modelname
169 | model = TextToEmb(modelpath, device=device)
170 |
171 | annotations = load_annotations(path, name=annotations_filename)
172 |
173 | if output_folder_name is None:
174 | output_folder_name = TokenEmbeddings.name
175 | path = os.path.join(path, output_folder_name)
176 |
177 | ptpath = os.path.join(path, f"{modelname}.npy")
178 | slicepath = os.path.join(path, f"{modelname}_slice.npy")
179 | jsonpath = os.path.join(path, f"{modelname}_index.json")
180 |
181 | # modelname can have folders
182 | path = os.path.split(ptpath)[0]
183 | os.makedirs(path, exist_ok=True)
184 |
185 | # fetch all the texts
186 | all_texts = []
187 | for dico in annotations.values():
188 | for lst in dico["annotations"]:
189 | all_texts.append(lst["text"])
190 |
191 | # remove duplicates
192 | all_texts = list(set(all_texts))
193 |
194 |     # split the texts into at most 100 batches
195 | nb_tokens = []
196 | all_texts_batched = np.array_split(all_texts, min(len(all_texts), 100))
197 |
198 | nb_tokens_so_far = 0
199 | big_tensor = []
200 | index = []
201 | for all_texts_batch in tqdm(all_texts_batched):
202 | x_dict = model(list(all_texts_batch))
203 |
204 | tensor = x_dict["x"]
205 | nb_tokens = x_dict["length"]
206 |
207 | # remove padding
208 | tensor_no_padding = [x[:n].cpu() for x, n in zip(tensor, nb_tokens)]
209 | tensor_concat = torch.cat(tensor_no_padding)
210 |
211 | big_tensor.append(tensor_concat)
212 | # save where it is
213 | ends = torch.cumsum(nb_tokens, 0)
214 | begins = torch.cat((0 * ends[[0]], ends[:-1]))
215 |
216 | # offset
217 | ends += nb_tokens_so_far
218 | begins += nb_tokens_so_far
219 | nb_tokens_so_far += len(tensor_concat)
220 |
221 | index.append(torch.stack((begins, ends)).T)
222 |
223 | big_tensor = torch.cat(big_tensor).cpu().numpy()
224 | index = torch.cat(index).cpu().numpy()
225 |
226 | np.save(ptpath, big_tensor)
227 | np.save(slicepath, index)
228 | print(f"{ptpath} written")
229 | print(f"{slicepath} written")
230 |
231 |     # correspondence: text -> index
232 | dico = {txt: i for i, txt in enumerate(all_texts)}
233 | write_json(dico, jsonpath)
234 | print(f"{jsonpath} written")
235 |
236 |
237 | def save_sent_embeddings(
238 | path, annotations_filename="annotations.json", output_folder_name=None,
239 | modelname="sentence-transformers/all-mpnet-base-v2", modelpath=None, device="cuda"
240 | ):
241 |     # Provide modelpath as a path to a local copy of the model if internet access is unavailable during training
242 | if modelpath is None:
243 | modelpath = modelname
244 | model = TextToEmb(modelpath, mean_pooling=True, device=device)
245 | annotations = load_annotations(path, name=annotations_filename)
246 |
247 | if output_folder_name is None:
248 | output_folder_name = SentenceEmbeddings.name
249 | path = os.path.join(path, output_folder_name)
250 |
251 | ptpath = os.path.join(path, f"{modelname}.npy")
252 | jsonpath = os.path.join(path, f"{modelname}_index.json")
253 |
254 | # modelname can have folders
255 | path = os.path.split(ptpath)[0]
256 | os.makedirs(path, exist_ok=True)
257 |
258 | # fetch all the texts
259 | all_texts = []
260 | for dico in annotations.values():
261 | for lst in dico["annotations"]:
262 | all_texts.append(lst["text"])
263 |
264 | # remove duplicates
265 | all_texts = list(set(all_texts))
266 |
267 |     # split the texts into at most 100 batches
268 | all_texts_batched = np.array_split(all_texts, min(len(all_texts), 100))
269 | embeddings = []
270 | for all_texts_batch in tqdm(all_texts_batched):
271 | embedding = model(list(all_texts_batch)).cpu()
272 | embeddings.append(embedding)
273 |
274 | embeddings = torch.cat(embeddings).numpy()
275 | np.save(ptpath, embeddings)
276 | print(f"{ptpath} written")
277 |
278 | # correspondance
279 | dico = {txt: i for i, txt in enumerate(all_texts)}
280 | write_json(dico, jsonpath)
281 | print(f"{jsonpath} written")
282 |
283 |
--------------------------------------------------------------------------------
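The precomputed token embeddings are stored as one concatenated array, a (begin, end) slice table and a text-to-row index; a toy sketch of the lookup performed by TokenEmbeddings.get_embedding (no files involved, values are illustrative):

import numpy as np
import torch

# two texts with 3 and 2 tokens, stacked into one big tensor of 4-d embeddings
embeddings_big = torch.arange(5 * 4, dtype=torch.float).reshape(5, 4)
embeddings_slice = np.array([[0, 3], [3, 5]])                  # (begin, end) per text
embeddings_index = {"a person walks": 0, "a person jumps": 1}

begin, end = embeddings_slice[embeddings_index["a person jumps"]]
x_dict = {"x": embeddings_big[begin:end], "length": end - begin}
print(x_dict["x"].shape)  # torch.Size([2, 4])
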
/src/data/text_motion.py:
--------------------------------------------------------------------------------
1 | import os
2 | import codecs as cs
3 | import orjson # loading faster than json
4 | import json
5 |
6 | import numpy as np
7 | from torch.utils.data import Dataset
8 | from tqdm import tqdm
9 |
10 | from .collate import collate_text_motion
11 |
12 |
13 | def read_split(path, split):
14 | split_file = os.path.join(path, "splits", split + ".txt")
15 | id_list = []
16 | with cs.open(split_file, "r") as f:
17 | for line in f.readlines():
18 | id_list.append(line.strip())
19 | return id_list
20 |
21 |
22 | def load_annotations(path, name="annotations.json"):
23 | json_path = os.path.join(path, name)
24 | with open(json_path, "rb") as ff:
25 | return orjson.loads(ff.read())
26 |
27 |
28 | class TextMotionDataset(Dataset):
29 | def __init__(
30 | self,
31 | path: str,
32 | motion_loader,
33 | text_to_sent_emb,
34 | text_to_token_emb,
35 | split: str = "train",
36 | min_seconds: float = 2.0,
37 | max_seconds: float = 10.0,
38 | preload: bool = True,
39 | tiny: bool = False,
40 | ):
41 | if tiny:
42 | split = split + "_tiny"
43 |
44 | self.collate_fn = collate_text_motion
45 | self.split = split
46 | self.keyids = read_split(path, split)
47 |
48 | self.text_to_sent_emb = text_to_sent_emb
49 | self.text_to_token_emb = text_to_token_emb
50 | self.motion_loader = motion_loader
51 |
52 | self.min_seconds = min_seconds
53 | self.max_seconds = max_seconds
54 |
55 |         # load the annotations (too short/long ones are filtered below)
56 | self.annotations = load_annotations(path)
57 |
58 | # filter annotations (min/max)
59 | # but not for the test set
60 | # otherwise it is not fair for everyone
61 | if "test" not in split:
62 | self.annotations = self.filter_annotations(self.annotations)
63 |
64 | self.is_training = "train" in split
65 | self.keyids = [keyid for keyid in self.keyids if keyid in self.annotations]
66 | self.nfeats = self.motion_loader.nfeats
67 |
68 | if preload:
69 | for _ in tqdm(self, desc="Preloading the dataset"):
70 | continue
71 |
72 | def __len__(self):
73 | return len(self.keyids)
74 |
75 | def __getitem__(self, index):
76 | keyid = self.keyids[index]
77 | return self.load_keyid(keyid)
78 |
79 | def load_keyid(self, keyid):
80 | annotations = self.annotations[keyid]
81 |
82 | # Take the first one for testing/validation
83 | # Otherwise take a random one
84 | index = 0
85 | if self.is_training:
86 | index = np.random.randint(len(annotations["annotations"]))
87 | annotation = annotations["annotations"][index]
88 |
89 | text = annotation["text"]
90 | text_x_dict = self.text_to_token_emb(text)
91 | motion_x_dict = self.motion_loader(
92 | path=annotations["path"],
93 | start=annotation["start"],
94 | end=annotation["end"],
95 | )
96 | sent_emb = self.text_to_sent_emb(text)
97 |
98 | output = {
99 | "motion_x_dict": motion_x_dict,
100 | "text_x_dict": text_x_dict,
101 | "text": text,
102 | "keyid": keyid,
103 | "sent_emb": sent_emb,
104 | }
105 | return output
106 |
107 | def filter_annotations(self, annotations):
108 | filtered_annotations = {}
109 | for key, val in annotations.items():
110 | annots = val.pop("annotations")
111 | filtered_annots = []
112 | for annot in annots:
113 | duration = annot["end"] - annot["start"]
114 | if self.max_seconds >= duration >= self.min_seconds:
115 | filtered_annots.append(annot)
116 |
117 | if filtered_annots:
118 | val["annotations"] = filtered_annots
119 | filtered_annotations[key] = val
120 |
121 | return filtered_annotations
122 |
123 |
124 | def write_json(data, path):
125 | with open(path, "w") as ff:
126 | ff.write(json.dumps(data, indent=4))
127 |
--------------------------------------------------------------------------------
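A sketch of the annotation structure this dataset expects (field names taken from load_keyid and filter_annotations; the values are illustrative) and the duration filter applied to it:

annotations = {
    "000001": {
        "path": "000001",
        "annotations": [
            {"text": "a person walks forward", "start": 0.0, "end": 4.5},
            {"text": "someone walks for a long time", "start": 0.0, "end": 12.0},
        ],
    }
}

min_seconds, max_seconds = 2.0, 10.0
for keyid, val in annotations.items():
    kept = [a for a in val["annotations"]
            if min_seconds <= a["end"] - a["start"] <= max_seconds]
    print(keyid, len(kept))  # 000001 1 -> the 12 s annotation is filtered out
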
/src/data/text_motion_multi_labels.py:
--------------------------------------------------------------------------------
1 | import os
2 | import codecs as cs
3 | import orjson # loading faster than json
4 | import json
5 | import logging
6 | import random
7 |
8 | import torch
9 | import numpy as np
10 | from .text_motion import TextMotionDataset
11 | from tqdm import tqdm
12 |
13 | from .collate import collate_text_motion_multiple_texts
14 |
15 |
16 |
17 | def read_split(path, split):
18 | split_file = os.path.join(path, "splits", split + ".txt")
19 | id_list = []
20 | with cs.open(split_file, "r") as f:
21 | for line in f.readlines():
22 | id_list.append(line.strip())
23 | return id_list
24 |
25 |
26 | def load_annotations(path, name="annotations.json"):
27 | json_path = os.path.join(path, name)
28 | with open(json_path, "rb") as ff:
29 | return orjson.loads(ff.read())
30 |
31 |
32 | class TextMotionMultiLabelsDataset(TextMotionDataset):
33 | def __init__(
34 | self,
35 | path: str,
36 | motion_loader,
37 | text_to_sent_emb,
38 | text_to_token_emb,
39 | split: str = "train",
40 | min_seconds: float = 2.0,
41 | max_seconds: float = 10.0,
42 | preload: bool = True,
43 | tiny: bool = False,
44 | ):
45 | if tiny:
46 | split = split + "_tiny"
47 |
48 | self.collate_fn = collate_text_motion_multiple_texts
49 | self.split = split
50 | self.keyids = read_split(path, split)
51 |
52 | self.text_to_sent_emb = text_to_sent_emb
53 | self.text_to_token_emb = text_to_token_emb
54 | self.motion_loader = motion_loader
55 |
56 | self.min_seconds = min_seconds
57 | self.max_seconds = max_seconds
58 |
59 |         # load the annotations (too short/long ones are filtered below)
60 | self.annotations = load_annotations(path)
61 | if "test" not in split:
62 | self.annotations = self.filter_annotations(self.annotations)
63 |
64 | self.is_training = "train" in split
65 | self.keyids = [keyid for keyid in self.keyids if keyid in self.annotations]
66 |
67 | self.nfeats = self.motion_loader.nfeats
68 |
69 | if preload:
70 | for _ in tqdm(self, desc="Preloading the dataset"):
71 | continue
72 |
73 | def load_keyid(self, keyid, device=None, text_idx=None, sent_emb_mode="first"):
74 | annotations = self.annotations[keyid]
75 |
76 | index = 0
77 |
78 | path = annotations["path"]
79 | annotation = annotations["annotations"][index]
80 | start = annotation["start"]
81 | end = annotation["end"]
82 |
83 | texts = [ann["text"] for ann in annotations["annotations"]]
84 |
85 |         text_x_dicts = self.text_to_token_emb(texts)  # [{"x": ..., "length": ...}, {"x": ..., "length": ...}, ...]
86 | motion_x_dict = self.motion_loader(
87 | path=path,
88 | start=start,
89 | end=end,
90 | )
91 |
92 | if sent_emb_mode == "first":
93 | sent_emb = self.text_to_sent_emb(texts[0])
94 | elif sent_emb_mode == "average":
95 | sent_emb = torch.stack([self.text_to_sent_emb(text) for text in texts])
96 | sent_emb = torch.mean(sent_emb, axis=0)
97 | sent_emb = torch.nn.functional.normalize(sent_emb, dim=0)
98 |
99 | output = {
100 | "motion_x_dict": motion_x_dict,
101 | "text_x_dict": text_x_dicts,
102 | "text": texts,
103 | "keyid": keyid,
104 | "sent_emb": sent_emb,
105 | }
106 |
107 | if device is not None:
108 | output["motion_x_dict"]["x"] = output["motion_x_dict"]["x"].to(device)
109 | for text_x_dict in output["text_x_dict"]:
110 | text_x_dict["x"] = text_x_dict["x"].to(device)
111 |
112 | return output
113 |
114 |
115 | def write_json(data, path):
116 | with open(path, "w") as ff:
117 | ff.write(json.dumps(data, indent=4))
118 |
119 |
--------------------------------------------------------------------------------
/src/guofeats/__init__.py:
--------------------------------------------------------------------------------
1 | from .motion_representation import joints_to_guofeats, guofeats_to_joints # noqa
2 |
--------------------------------------------------------------------------------
/src/guofeats/paramUtil.py:
--------------------------------------------------------------------------------
1 | # Taken from
2 | # https://github.com/EricGuo5513/HumanML3D/blob/main/paramUtil.py
3 |
4 | import numpy as np
5 |
6 | # Define a kinematic tree for the skeletal structure
7 | kit_kinematic_chain = [
8 | [0, 11, 12, 13, 14, 15],
9 | [0, 16, 17, 18, 19, 20],
10 | [0, 1, 2, 3, 4],
11 | [3, 5, 6, 7],
12 | [3, 8, 9, 10],
13 | ]
14 |
15 | kit_raw_offsets = np.array(
16 | [
17 | [0, 0, 0],
18 | [0, 1, 0],
19 | [0, 1, 0],
20 | [0, 1, 0],
21 | [0, 1, 0],
22 | [1, 0, 0],
23 | [0, -1, 0],
24 | [0, -1, 0],
25 | [-1, 0, 0],
26 | [0, -1, 0],
27 | [0, -1, 0],
28 | [1, 0, 0],
29 | [0, -1, 0],
30 | [0, -1, 0],
31 | [0, 0, 1],
32 | [0, 0, 1],
33 | [-1, 0, 0],
34 | [0, -1, 0],
35 | [0, -1, 0],
36 | [0, 0, 1],
37 | [0, 0, 1],
38 | ]
39 | )
40 |
41 | t2m_raw_offsets = np.array(
42 | [
43 | [0, 0, 0],
44 | [1, 0, 0],
45 | [-1, 0, 0],
46 | [0, 1, 0],
47 | [0, -1, 0],
48 | [0, -1, 0],
49 | [0, 1, 0],
50 | [0, -1, 0],
51 | [0, -1, 0],
52 | [0, 1, 0],
53 | [0, 0, 1],
54 | [0, 0, 1],
55 | [0, 1, 0],
56 | [1, 0, 0],
57 | [-1, 0, 0],
58 | [0, 0, 1],
59 | [0, -1, 0],
60 | [0, -1, 0],
61 | [0, -1, 0],
62 | [0, -1, 0],
63 | [0, -1, 0],
64 | [0, -1, 0],
65 | ]
66 | )
67 |
68 | t2m_kinematic_chain = [
69 | [0, 2, 5, 8, 11],
70 | [0, 1, 4, 7, 10],
71 | [0, 3, 6, 9, 12, 15],
72 | [9, 14, 17, 19, 21],
73 | [9, 13, 16, 18, 20],
74 | ]
75 | t2m_left_hand_chain = [
76 | [20, 22, 23, 24],
77 | [20, 34, 35, 36],
78 | [20, 25, 26, 27],
79 | [20, 31, 32, 33],
80 | [20, 28, 29, 30],
81 | ]
82 | t2m_right_hand_chain = [
83 | [21, 43, 44, 45],
84 | [21, 46, 47, 48],
85 | [21, 40, 41, 42],
86 | [21, 37, 38, 39],
87 | [21, 49, 50, 51],
88 | ]
89 |
90 |
91 | kit_tgt_skel_id = "03950"
92 | t2m_tgt_skel_id = "000021"
93 |
--------------------------------------------------------------------------------
/src/guofeats/skeleton_example_h3d.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leorebensabath/TMRPlusPlus/23e4deab9c4529bd22e3d40942a9a9a9d7a26db2/src/guofeats/skeleton_example_h3d.npy
--------------------------------------------------------------------------------
/src/joints.py:
--------------------------------------------------------------------------------
1 | JOINT_NAMES = {
2 | "smpljoints": [
3 | "pelvis",
4 | "left_hip",
5 | "right_hip",
6 | "spine1",
7 | "left_knee",
8 | "right_knee",
9 | "spine2",
10 | "left_ankle",
11 | "right_ankle",
12 | "spine3",
13 | "left_foot",
14 | "right_foot",
15 | "neck",
16 | "left_collar",
17 | "right_collar",
18 | "head",
19 | "left_shoulder",
20 | "right_shoulder",
21 | "left_elbow",
22 | "right_elbow",
23 | "left_wrist",
24 | "right_wrist",
25 | "left_hand",
26 | "right_hand",
27 | ],
28 | "guoh3djoints": [
29 | "pelvis",
30 | "left_hip",
31 | "right_hip",
32 | "spine1",
33 | "left_knee",
34 | "right_knee",
35 | "spine2",
36 | "left_ankle",
37 | "right_ankle",
38 | "spine3",
39 | "left_foot",
40 | "right_foot",
41 | "neck",
42 | "left_collar",
43 | "right_collar",
44 | "head",
45 | "left_shoulder",
46 | "right_shoulder",
47 | "left_elbow",
48 | "right_elbow",
49 | "left_wrist",
50 | "right_wrist",
51 | ],
52 | }
53 |
54 | INFOS = {
55 | "smpljoints": {
56 | "LM": JOINT_NAMES["smpljoints"].index("left_ankle"),
57 | "RM": JOINT_NAMES["smpljoints"].index("right_ankle"),
58 | "LF": JOINT_NAMES["smpljoints"].index("left_foot"),
59 | "RF": JOINT_NAMES["smpljoints"].index("right_foot"),
60 | "LS": JOINT_NAMES["smpljoints"].index("left_shoulder"),
61 | "RS": JOINT_NAMES["smpljoints"].index("right_shoulder"),
62 | "LH": JOINT_NAMES["smpljoints"].index("left_hip"),
63 | "RH": JOINT_NAMES["smpljoints"].index("right_hip"),
64 | "njoints": len(JOINT_NAMES["smpljoints"]),
65 | },
66 | "guoh3djoints": {
67 | "LM": JOINT_NAMES["guoh3djoints"].index("left_ankle"),
68 | "RM": JOINT_NAMES["guoh3djoints"].index("right_ankle"),
69 | "LF": JOINT_NAMES["guoh3djoints"].index("left_foot"),
70 | "RF": JOINT_NAMES["guoh3djoints"].index("right_foot"),
71 | "LS": JOINT_NAMES["guoh3djoints"].index("left_shoulder"),
72 | "RS": JOINT_NAMES["guoh3djoints"].index("right_shoulder"),
73 | "LH": JOINT_NAMES["guoh3djoints"].index("left_hip"),
74 | "RH": JOINT_NAMES["guoh3djoints"].index("right_hip"),
75 | "njoints": len(JOINT_NAMES["guoh3djoints"]),
76 | },
77 | }
78 |
--------------------------------------------------------------------------------
/src/load.py:
--------------------------------------------------------------------------------
1 | import os
2 | from omegaconf import DictConfig
3 | import logging
4 | import hydra
5 |
6 | from src.config import read_config
7 |
8 | logger = logging.getLogger(__name__)
9 |
10 |
11 | # split the lightning checkpoint into
12 | # separate state_dict modules for faster loading
13 | def extract_ckpt(run_dir, ckpt_name="last"):
14 | import torch
15 |
16 | ckpt_path = os.path.join(run_dir, f"logs/checkpoints/{ckpt_name}.ckpt")
17 |
18 | extracted_path = os.path.join(run_dir, f"{ckpt_name}_weights")
19 | os.makedirs(extracted_path, exist_ok=True)
20 |
21 | new_path_template = os.path.join(extracted_path, "{}.pt")
22 | ckpt_dict = torch.load(ckpt_path)
23 | state_dict = ckpt_dict["state_dict"]
24 | module_names = list(set([x.split(".")[0] for x in state_dict.keys()]))
25 |
26 | # should be ['motion_encoder', 'text_encoder', 'motion_decoder'] for example
27 | for module_name in module_names:
28 | path = new_path_template.format(module_name)
29 | sub_state_dict = {
30 | ".".join(x.split(".")[1:]): y.cpu()
31 | for x, y in state_dict.items()
32 | if x.split(".")[0] == module_name
33 | }
34 | torch.save(sub_state_dict, path)
35 |
36 |
37 | def load_model(run_dir, **params):
38 | # Load last config
39 | cfg = read_config(run_dir)
40 | cfg.run_dir = run_dir
41 | return load_model_from_cfg(cfg, **params)
42 |
43 |
44 | def load_model_from_cfg(cfg, ckpt_name="last", device="cpu", eval_mode=True):
45 | import src.prepare # noqa
46 | import torch
47 |
48 | run_dir = cfg.run_dir
49 | model = hydra.utils.instantiate(cfg.model)
50 |
51 | # Loading modules one by one
52 |     # motion_encoder / text_encoder / motion_decoder
53 | pt_path = os.path.join(run_dir, f"{ckpt_name}_weights")
54 |
55 | if not os.path.exists(pt_path):
56 |         logger.info("Extracted weights not found. Splitting the checkpoint into submodules...")
57 | extract_ckpt(run_dir, ckpt_name)
58 |
59 | for fname in os.listdir(pt_path):
60 | module_name, ext = os.path.splitext(fname)
61 |
62 | if ext != ".pt":
63 | continue
64 |
65 | module = getattr(model, module_name, None)
66 | if module is None:
67 | continue
68 |
69 | module_path = os.path.join(pt_path, fname)
70 | state_dict = torch.load(module_path)
71 | module.load_state_dict(state_dict)
72 | logger.info(f" {module_name} loaded")
73 |
74 | logger.info("Loading previous checkpoint done")
75 | model = model.to(device)
76 | logger.info(f"Put the model on {device}")
77 | if eval_mode:
78 | model = model.eval()
79 | logger.info("Put the model in eval mode")
80 | return model
81 |
82 |
83 | @hydra.main(version_base=None, config_path="../configs", config_name="load_model")
84 | def hydra_load_model(cfg: DictConfig) -> None:
85 | run_dir = cfg.run_dir
86 | ckpt_name = cfg.ckpt
87 | device = cfg.device
88 | eval_mode = cfg.eval_mode
89 |     return load_model(run_dir, ckpt_name=ckpt_name, device=device, eval_mode=eval_mode)
90 |
91 |
92 | if __name__ == "__main__":
93 | hydra_load_model()
94 |
--------------------------------------------------------------------------------
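A minimal usage sketch (the run directory is illustrative; it is expected to contain the config.json and checkpoints produced by training):

from src.load import load_model

# the first call splits the Lightning checkpoint into per-module weight files
model = load_model("outputs/tmr_humanml3d", ckpt_name="last", device="cpu", eval_mode=True)
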
/src/logger/csv.py:
--------------------------------------------------------------------------------
1 | # from pytorch_lightning/loggers/csv_logs.py
2 | # of pytorch_lightning version 2.0.4
3 |
4 | # Copyright The Lightning AI team.
5 | #
6 | # Licensed under the Apache License, Version 2.0 (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at
9 | #
10 | # http://www.apache.org/licenses/LICENSE-2.0
11 | #
12 | # Unless required by applicable law or agreed to in writing, software
13 | # distributed under the License is distributed on an "AS IS" BASIS,
14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | # See the License for the specific language governing permissions and
16 | # limitations under the License.
17 | """
18 | CSV logger
19 | ----------
20 |
21 | CSV logger for basic experiment logging that does not require opening ports
22 |
23 | """
24 | import logging
25 | import os
26 | from argparse import Namespace
27 | from typing import Any, Dict, Optional, Union
28 |
29 | # from lightning_fabric.loggers.csv_logs import _ExperimentWriter as _FabricExperimentWriter
30 | # from lightning_fabric.loggers.csv_logs import CSVLogger as FabricCSVLogger
31 | # Local replacement
32 | from .csv_fabric import _ExperimentWriter as _FabricExperimentWriter
33 | from .csv_fabric import CSVLogger as FabricCSVLogger
34 |
35 | from lightning_fabric.loggers.logger import rank_zero_experiment
36 | from lightning_fabric.utilities.logger import _convert_params
37 | from lightning_fabric.utilities.types import _PATH
38 | from pytorch_lightning.core.saving import save_hparams_to_yaml
39 | from pytorch_lightning.loggers.logger import Logger
40 | from pytorch_lightning.utilities.rank_zero import rank_zero_only
41 |
42 | log = logging.getLogger(__name__)
43 |
44 |
45 | class ExperimentWriter(_FabricExperimentWriter):
46 | r"""Experiment writer for CSVLogger.
47 |
48 |     Currently supports logging hyperparameters and metrics in YAML and CSV
49 |     format, respectively.
50 |
51 | Args:
52 | log_dir: Directory for the experiment logs
53 | """
54 |
55 | NAME_HPARAMS_FILE = "hparams.yaml"
56 |
57 | def __init__(self, log_dir: str) -> None:
58 | super().__init__(log_dir=log_dir)
59 | self.hparams: Dict[str, Any] = {}
60 |
61 | def log_hparams(self, params: Dict[str, Any]) -> None:
62 | """Record hparams."""
63 | self.hparams.update(params)
64 |
65 | def save(self) -> None:
66 | """Save recorded hparams and metrics into files."""
67 | hparams_file = os.path.join(self.log_dir, self.NAME_HPARAMS_FILE)
68 | save_hparams_to_yaml(hparams_file, self.hparams)
69 | return super().save()
70 |
71 |
72 | class CSVLogger(Logger, FabricCSVLogger):
73 | r"""Log to local file system in yaml and CSV format.
74 |
75 | Logs are saved to ``os.path.join(save_dir, name)``.
76 |
77 | Example:
78 | >>> from pytorch_lightning import Trainer
79 | >>> from pytorch_lightning.loggers import CSVLogger
80 | >>> logger = CSVLogger("logs", name="my_exp_name")
81 | >>> trainer = Trainer(logger=logger)
82 |
83 | Args:
84 | save_dir: Save directory
85 | name: Experiment name. Defaults to ``'lightning_logs'``.
86 | prefix: A string to put at the beginning of metric keys.
87 | flush_logs_every_n_steps: How often to flush logs to disk (defaults to every 100 steps).
88 | """
89 |
90 | LOGGER_JOIN_CHAR = "-"
91 |
92 | def __init__(
93 | self,
94 | save_dir: _PATH,
95 | name: str = "lightning_logs",
96 | prefix: str = "",
97 | flush_logs_every_n_steps: int = 100,
98 | ):
99 | super().__init__(
100 | root_dir=save_dir,
101 | name=name,
102 | prefix=prefix,
103 | flush_logs_every_n_steps=flush_logs_every_n_steps,
104 | )
105 | self._save_dir = os.fspath(save_dir)
106 |
107 | @property
108 | def root_dir(self) -> str:
109 | """Parent directory for all checkpoint subdirectories.
110 |
111 | If the experiment name parameter is an empty string, no experiment subdirectory is used and the checkpoint will
112 | be saved in "save_dir/"
113 | """
114 | return os.path.join(self.save_dir, self.name)
115 |
116 | @property
117 | def log_dir(self) -> str:
118 | """The log directory for this run."""
119 | return self.root_dir
120 |
121 | @property
122 | def save_dir(self) -> str:
123 | """The current directory where logs are saved.
124 |
125 | Returns:
126 | The path to current directory where logs are saved.
127 | """
128 | return self._save_dir
129 |
130 | @rank_zero_only
131 | def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
132 | # don't log hyperparameters
133 | # already done in the config
134 | return
135 |
136 | @property
137 | @rank_zero_experiment
138 | def experiment(self) -> _FabricExperimentWriter:
139 | r"""
140 |
141 | Actual _ExperimentWriter object. To use _ExperimentWriter features in your
142 | :class:`~pytorch_lightning.core.module.LightningModule` do the following.
143 |
144 | Example::
145 |
146 | self.logger.experiment.some_experiment_writer_function()
147 |
148 | """
149 | if self._experiment is not None:
150 | return self._experiment
151 |
152 | self._fs.makedirs(self.root_dir, exist_ok=True)
153 | self._experiment = ExperimentWriter(log_dir=self.log_dir)
154 | return self._experiment
155 |
--------------------------------------------------------------------------------
/src/logger/csv_fabric.py:
--------------------------------------------------------------------------------
1 | # from lightning_fabric/loggers/csv_logs.py
2 | # of lightning_fabric version 2.0.4
3 |
4 | # Copyright The Lightning AI team.
5 | #
6 | # Licensed under the Apache License, Version 2.0 (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at
9 | #
10 | # http://www.apache.org/licenses/LICENSE-2.0
11 | #
12 | # Unless required by applicable law or agreed to in writing, software
13 | # distributed under the License is distributed on an "AS IS" BASIS,
14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | # See the License for the specific language governing permissions and
16 | # limitations under the License.
17 |
18 | import csv
19 | import logging
20 | import os
21 | from argparse import Namespace
22 | from typing import Any, Dict, List, Optional, Union
23 |
24 | from torch import Tensor
25 |
26 | from lightning_fabric.loggers.logger import Logger, rank_zero_experiment
27 | from lightning_fabric.utilities.cloud_io import get_filesystem
28 | from lightning_fabric.utilities.logger import _add_prefix
29 | from lightning_fabric.utilities.rank_zero import rank_zero_only, rank_zero_warn
30 | from lightning_fabric.utilities.types import _PATH
31 |
32 | log = logging.getLogger(__name__)
33 |
34 |
35 | class CSVLogger(Logger):
36 | r"""Log to the local file system in CSV format.
37 |
38 | Logs are saved to ``os.path.join(root_dir, name)``.
39 |
40 | Args:
41 | root_dir: The root directory in which all your experiments with different names and versions will be stored.
42 | name: Experiment name. Defaults to ``'lightning_logs'``.
43 | prefix: A string to put at the beginning of metric keys.
44 | flush_logs_every_n_steps: How often to flush logs to disk (defaults to every 100 steps).
45 |
46 | Example::
47 |
48 | from lightning_fabric.loggers import CSVLogger
49 |
50 | logger = CSVLogger("path/to/logs/root", name="my_model")
51 | logger.log_metrics({"loss": 0.235, "acc": 0.75})
52 | logger.finalize("success")
53 | """
54 |
55 | LOGGER_JOIN_CHAR = "-"
56 |
57 | def __init__(
58 | self,
59 | root_dir: _PATH,
60 | name: str = "lightning_logs",
61 | prefix: str = "",
62 | flush_logs_every_n_steps: int = 100,
63 | ):
64 | super().__init__()
65 | root_dir = os.fspath(root_dir)
66 | self._root_dir = root_dir
67 | self._name = name or ""
68 | self._prefix = prefix
69 | self._fs = get_filesystem(root_dir)
70 | self._experiment: Optional[_ExperimentWriter] = None
71 | self._flush_logs_every_n_steps = flush_logs_every_n_steps
72 |
73 | @property
74 | def name(self) -> str:
75 | """Gets the name of the experiment.
76 |
77 | Returns:
78 | The name of the experiment.
79 | """
80 | return self._name
81 |
82 | @property
83 | def version(self) -> str:
84 | return ""
85 |
86 | @property
87 | def root_dir(self) -> str:
88 | """Gets the save directory where the versioned CSV experiments are saved."""
89 | return self._root_dir
90 |
91 | @property
92 | def log_dir(self) -> str:
93 | """The log directory for this run."""
94 | # create a pseudo standard path
95 | return os.path.join(self.root_dir, self.name)
96 |
97 | @property
98 | @rank_zero_experiment
99 | def experiment(self) -> "_ExperimentWriter":
100 | """Actual ExperimentWriter object. To use ExperimentWriter features anywhere in your code, do the
101 | following.
102 |
103 | Example::
104 |
105 | self.logger.experiment.some_experiment_writer_function()
106 | """
107 | if self._experiment is not None:
108 | return self._experiment
109 |
110 | os.makedirs(self.root_dir, exist_ok=True)
111 | self._experiment = _ExperimentWriter(log_dir=self.log_dir)
112 | return self._experiment
113 |
114 | @rank_zero_only
115 | def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
116 | raise NotImplementedError(
117 | "The `CSVLogger` does not yet support logging hyperparameters."
118 | )
119 |
120 | @rank_zero_only
121 | def log_metrics(
122 | self, metrics: Dict[str, Union[Tensor, float]], step: Optional[int] = None
123 | ) -> None:
124 | metrics = _add_prefix(metrics, self._prefix, self.LOGGER_JOIN_CHAR)
125 | self.experiment.log_metrics(metrics, step)
126 | if step is not None and (step + 1) % self._flush_logs_every_n_steps == 0:
127 | self.save()
128 |
129 | @rank_zero_only
130 | def save(self) -> None:
131 | super().save()
132 | self.experiment.save()
133 |
134 | @rank_zero_only
135 | def finalize(self, status: str) -> None:
136 | if self._experiment is None:
137 | # When using multiprocessing, finalize() should be a no-op on the main process, as no experiment has been
138 | # initialized there
139 | return
140 | self.save()
141 |
142 |
143 | class _ExperimentWriter:
144 | r"""Experiment writer for CSVLogger.
145 |
146 | Args:
147 | log_dir: Directory for the experiment logs
148 | """
149 |
150 | NAME_METRICS_FILE = "metrics.csv"
151 |
152 | def __init__(self, log_dir: str) -> None:
153 | self.metrics: List[Dict[str, float]] = []
154 |
155 | self._fs = get_filesystem(log_dir)
156 | self.log_dir = log_dir
157 | self._fs.makedirs(self.log_dir, exist_ok=True)
158 | self.metrics_file_path = os.path.join(self.log_dir, self.NAME_METRICS_FILE)
159 |
160 | if self._fs.exists(self.log_dir) and self._fs.listdir(self.log_dir):
161 | # Read previous logs
162 | if os.path.exists(self.metrics_file_path):
163 | with self._fs.open(self.metrics_file_path, "r") as f:
164 | reader = csv.DictReader(f)
165 | self.metrics = [x for x in reader]
166 |
167 | def log_metrics(
168 | self, metrics_dict: Dict[str, float], step: Optional[int] = None
169 | ) -> None:
170 | """Record metrics."""
171 |
172 | def _handle_value(value: Union[Tensor, Any]) -> Any:
173 | if isinstance(value, Tensor):
174 | return value.item()
175 | return value
176 |
177 | if step is None:
178 | step = len(self.metrics)
179 |
180 | metrics = {k: _handle_value(v) for k, v in metrics_dict.items()}
181 | metrics["step"] = step
182 | self.metrics.append(metrics)
183 |
184 | def save(self) -> None:
185 | """Save recorded metrics into files."""
186 | if not self.metrics:
187 | return
188 |
189 | last_m = {}
190 | for m in self.metrics:
191 | last_m.update(m)
192 | metrics_keys = list(last_m.keys())
193 |
194 | with self._fs.open(self.metrics_file_path, "w", newline="") as f:
195 | writer = csv.DictWriter(f, fieldnames=metrics_keys)
196 | writer.writeheader()
197 | writer.writerows(self.metrics)
198 |
--------------------------------------------------------------------------------
/src/logging.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import tqdm
3 |
4 |
5 | # from https://stackoverflow.com/questions/38543506/change-logging-print-function-to-tqdm-write-so-logging-doesnt-interfere-wit
6 | class TqdmLoggingHandler(logging.Handler):
7 | def __init__(self, level=logging.NOTSET):
8 | super().__init__(level)
9 |
10 | def emit(self, record):
11 | try:
12 | msg = self.format(record)
13 | tqdm.tqdm.write(msg)
14 | self.flush()
15 | except Exception:
16 | self.handleError(record)
17 |
--------------------------------------------------------------------------------
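A minimal sketch showing how the handler routes standard logging through tqdm.write, so log lines do not break active progress bars:

import logging
from src.logging import TqdmLoggingHandler

logger = logging.getLogger("demo")
logger.setLevel(logging.INFO)
logger.addHandler(TqdmLoggingHandler())
logger.info("written via tqdm.write, compatible with running progress bars")
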
/src/model/__init__.py:
--------------------------------------------------------------------------------
1 | from .actor import PositionalEncoding, ACTORStyleEncoder, ACTORStyleDecoder # noqa
2 | from .temos import TEMOS # noqa
3 | from .tmr import TMR # noqa
4 | from .text_encoder import TextToEmb # noqa
5 |
--------------------------------------------------------------------------------
/src/model/actor.py:
--------------------------------------------------------------------------------
1 | from typing import Dict
2 |
3 | import torch
4 | import torch.nn as nn
5 | from torch import Tensor
6 | import numpy as np
7 |
8 | from einops import repeat
9 |
10 |
11 | class PositionalEncoding(nn.Module):
12 | def __init__(self, d_model, dropout=0.1, max_len=5000, batch_first=False) -> None:
13 | super().__init__()
14 | self.batch_first = batch_first
15 |
16 | self.dropout = nn.Dropout(p=dropout)
17 |
18 | pe = torch.zeros(max_len, d_model)
19 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
20 | div_term = torch.exp(
21 | torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model)
22 | )
23 | pe[:, 0::2] = torch.sin(position * div_term)
24 | pe[:, 1::2] = torch.cos(position * div_term)
25 | pe = pe.unsqueeze(0).transpose(0, 1)
26 | self.register_buffer("pe", pe, persistent=False)
27 |
28 | def forward(self, x: Tensor) -> Tensor:
29 | if self.batch_first:
30 | x = x + self.pe.permute(1, 0, 2)[:, : x.shape[1], :]
31 | else:
32 | x = x + self.pe[: x.shape[0], :]
33 | return self.dropout(x)
34 |
35 |
36 | class ACTORStyleEncoder(nn.Module):
37 | # Similar to ACTOR but "action agnostic" and more general
38 | def __init__(
39 | self,
40 | nfeats: int,
41 | vae: bool,
42 | latent_dim: int = 256,
43 | ff_size: int = 1024,
44 | num_layers: int = 4,
45 | num_heads: int = 4,
46 | dropout: float = 0.1,
47 | activation: str = "gelu",
48 | ) -> None:
49 | super().__init__()
50 |
51 | self.nfeats = nfeats
52 | self.projection = nn.Linear(nfeats, latent_dim)
53 |
54 | self.vae = vae
55 | self.nbtokens = 2 if vae else 1
56 | self.tokens = nn.Parameter(torch.randn(self.nbtokens, latent_dim))
57 |
58 | self.sequence_pos_encoding = PositionalEncoding(
59 | latent_dim, dropout=dropout, batch_first=True
60 | )
61 |
62 | seq_trans_encoder_layer = nn.TransformerEncoderLayer(
63 | d_model=latent_dim,
64 | nhead=num_heads,
65 | dim_feedforward=ff_size,
66 | dropout=dropout,
67 | activation=activation,
68 | batch_first=True,
69 | )
70 |
71 | self.seqTransEncoder = nn.TransformerEncoder(
72 | seq_trans_encoder_layer, num_layers=num_layers
73 | )
74 |
75 | def forward(self, x_dict: Dict) -> Tensor:
76 | x = x_dict["x"]
77 | mask = x_dict["mask"]
78 |
79 | x = self.projection(x)
80 |
81 | device = x.device
82 | bs = len(x)
83 |
84 | tokens = repeat(self.tokens, "nbtoken dim -> bs nbtoken dim", bs=bs)
85 | xseq = torch.cat((tokens, x), 1)
86 |
87 | token_mask = torch.ones((bs, self.nbtokens), dtype=bool, device=device)
88 | aug_mask = torch.cat((token_mask, mask), 1)
89 |
90 | # add positional encoding
91 | xseq = self.sequence_pos_encoding(xseq)
92 | final = self.seqTransEncoder(xseq, src_key_padding_mask=~aug_mask)
93 | return final[:, : self.nbtokens]
94 |
95 |
96 | class ACTORStyleDecoder(nn.Module):
97 | # Similar to ACTOR Decoder
98 |
99 | def __init__(
100 | self,
101 | nfeats: int,
102 | latent_dim: int = 256,
103 | ff_size: int = 1024,
104 | num_layers: int = 4,
105 | num_heads: int = 4,
106 | dropout: float = 0.1,
107 | activation: str = "gelu",
108 | ) -> None:
109 | super().__init__()
110 | output_feats = nfeats
111 | self.nfeats = nfeats
112 |
113 | self.sequence_pos_encoding = PositionalEncoding(
114 | latent_dim, dropout, batch_first=True
115 | )
116 |
117 | seq_trans_decoder_layer = nn.TransformerDecoderLayer(
118 | d_model=latent_dim,
119 | nhead=num_heads,
120 | dim_feedforward=ff_size,
121 | dropout=dropout,
122 | activation=activation,
123 | batch_first=True,
124 | )
125 |
126 | self.seqTransDecoder = nn.TransformerDecoder(
127 | seq_trans_decoder_layer, num_layers=num_layers
128 | )
129 |
130 | self.final_layer = nn.Linear(latent_dim, output_feats)
131 |
132 | def forward(self, z_dict: Dict) -> Tensor:
133 | z = z_dict["z"]
134 | mask = z_dict["mask"]
135 |
136 | latent_dim = z.shape[1]
137 | bs, nframes = mask.shape
138 |
139 | z = z[:, None] # sequence of 1 element for the memory
140 |
141 | # Construct time queries
142 | time_queries = torch.zeros(bs, nframes, latent_dim, device=z.device)
143 | time_queries = self.sequence_pos_encoding(time_queries)
144 |
145 | # Pass through the transformer decoder
146 | # with the latent vector for memory
147 | output = self.seqTransDecoder(
148 | tgt=time_queries, memory=z, tgt_key_padding_mask=~mask
149 | )
150 |
151 | output = self.final_layer(output)
152 | # zero for padded area
153 | output[~mask] = 0
154 | return output
155 |
--------------------------------------------------------------------------------
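A minimal sketch of a forward pass through the encoder (the batch and feature sizes are illustrative):

import torch
from src.model.actor import ACTORStyleEncoder

encoder = ACTORStyleEncoder(nfeats=263, vae=True, latent_dim=256)
x = torch.randn(2, 30, 263)                 # (batch, frames, features)
mask = torch.ones(2, 30, dtype=torch.bool)  # True on valid frames
tokens = encoder({"x": x, "mask": mask})
print(tokens.shape)                         # (2, 2, 256): one token each for mu and logvar
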
/src/model/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 |
5 | # For reference
6 | # https://stats.stackexchange.com/questions/7440/kl-divergence-between-two-univariate-gaussians
7 | # https://pytorch.org/docs/stable/_modules/torch/distributions/kl.html#kl_divergence
8 | class KLLoss:
9 | def __call__(self, q, p):
10 | mu_q, logvar_q = q
11 | mu_p, logvar_p = p
12 |
13 | log_var_ratio = logvar_q - logvar_p
14 | t1 = (mu_p - mu_q).pow(2) / logvar_p.exp()
15 | div = 0.5 * (log_var_ratio.exp() + t1 - 1 - log_var_ratio)
16 | return div.mean()
17 |
18 | def __repr__(self):
19 | return "KLLoss()"
20 |
21 |
22 | def get_sim_matrix(x, y):
23 | x_logits = torch.nn.functional.normalize(x, dim=-1)
24 | y_logits = torch.nn.functional.normalize(y, dim=-1)
25 | sim_matrix = x_logits @ y_logits.T
26 | return sim_matrix
27 |
28 |
29 | class InfoNCE_with_filtering:
30 | def __init__(self, temperature=0.7, threshold_selfsim=0.8):
31 | self.temperature = temperature
32 | self.threshold_selfsim = threshold_selfsim
33 |
34 | def filter_sim_mat_with_sent_emb(self, sim_matrix, sent_emb):
35 | # put the threshold value between -1 and 1
36 | real_threshold_selfsim = 2 * self.threshold_selfsim - 1
37 | # Filtering too close values
38 | # mask them by putting -inf in the sim_matrix
39 | selfsim = sent_emb @ sent_emb.T
40 | selfsim_nodiag = selfsim - selfsim.diag().diag()
41 | idx = torch.where(selfsim_nodiag > real_threshold_selfsim)
42 | sim_matrix[idx] = -torch.inf
43 |         return sim_matrix  # TODO: check whether the return is needed or if the in-place update suffices
44 |
45 | def __call__(self, x, y, sent_emb=None):
46 | bs, device = len(x), x.device
47 | sim_matrix = get_sim_matrix(x, y) / self.temperature
48 |
49 | if sent_emb is not None and self.threshold_selfsim:
50 | sim_matrix = self.filter_sim_mat_with_sent_emb(sim_matrix, sent_emb)
51 |
52 | labels = torch.arange(bs, device=device)
53 |
54 | total_loss = (
55 | F.cross_entropy(sim_matrix, labels) + F.cross_entropy(sim_matrix.T, labels)
56 | ) / 2
57 |
58 | return total_loss
59 |
60 | def __repr__(self):
61 |         return f"Contrastive(temp={self.temperature})"
62 |
63 |
64 | class HN_InfoNCE_with_filtering(InfoNCE_with_filtering):
65 | def __init__(self, temperature=0.7, threshold_selfsim=0.8, alpha=1.0, beta=0.25):
66 | super().__init__(temperature=temperature, threshold_selfsim=threshold_selfsim)
67 | self.alpha = alpha
68 | self.beta = beta
69 |
70 | def cross_entropy_with_HN_weights(self, sim_matrix):
71 | n = sim_matrix.shape[0]
72 |
73 | labels = range(sim_matrix.shape[0])
74 | exp_mat = torch.exp(sim_matrix)
75 | num = exp_mat[range(exp_mat.shape[0]), labels]
76 |
77 | exp_mat_beta = torch.exp(self.beta * sim_matrix)
78 | weights = (n - 1) * exp_mat_beta / torch.unsqueeze((torch.sum(exp_mat_beta, axis=1) - exp_mat_beta.diag()), dim=1)
79 | weights = weights.fill_diagonal_(self.alpha)
80 | denum = torch.sum(weights * exp_mat, axis=1)
81 |
82 | return -torch.mean(torch.log(num/denum))
83 |
84 | def __call__(self, x, y, sent_emb=None):
85 | bs, device = len(x), x.device
86 | sim_matrix = get_sim_matrix(x, y) / self.temperature
87 |
88 | if sent_emb is not None and self.threshold_selfsim:
89 | sim_matrix = self.filter_sim_mat_with_sent_emb(sim_matrix, sent_emb)
90 |
91 | total_loss = (
92 | self.cross_entropy_with_HN_weights(sim_matrix) + self.cross_entropy_with_HN_weights(sim_matrix.T)
93 | ) / 2
94 |
95 | return total_loss
96 |
97 | def __repr__(self):
98 |         return f"HN_Contrastive(temp={self.temperature})"
99 |
--------------------------------------------------------------------------------
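A minimal sketch of the filtered InfoNCE loss on random latents (the dimensions are illustrative; sent_emb is assumed to be normalized sentence embeddings):

import torch
from src.model.losses import InfoNCE_with_filtering

loss_fn = InfoNCE_with_filtering(temperature=0.1, threshold_selfsim=0.8)
text_latents = torch.randn(8, 256)
motion_latents = torch.randn(8, 256)
sent_emb = torch.nn.functional.normalize(torch.randn(8, 768), dim=-1)

loss = loss_fn(text_latents, motion_latents, sent_emb=sent_emb)
print(loss.item())
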
/src/model/temos.py:
--------------------------------------------------------------------------------
1 | from typing import List, Dict, Optional
2 |
3 | import torch
4 | import torch.nn as nn
5 | from torch import Tensor
6 | from pytorch_lightning import LightningModule
7 |
8 | from src.model.losses import KLLoss
9 |
10 |
11 | def length_to_mask(length: List[int], device: torch.device = None) -> Tensor:
12 | if device is None:
13 | device = "cpu"
14 |
15 | if isinstance(length, list):
16 | length = torch.tensor(length, device=device)
17 |
18 | max_len = max(length)
19 | mask = torch.arange(max_len, device=device).expand(
20 | len(length), max_len
21 | ) < length.unsqueeze(1)
22 | return mask
23 |
24 |
25 | class TEMOS(LightningModule):
26 | r"""TEMOS: Generating diverse human motions
27 | from textual descriptions
28 | Find more information about the model on the following website:
29 | https://mathis.petrovich.fr/temos
30 |
31 | Args:
32 | motion_encoder: a module to encode the input motion features in the latent space (required).
33 | text_encoder: a module to encode the text embeddings in the latent space (required).
34 | motion_decoder: a module to decode the latent vector into motion features (required).
35 | vae: a boolean to make the model probabilistic (required).
36 | fact: a scaling factor for sampling the VAE (optional).
37 | sample_mean: sample the mean vector instead of random sampling (optional).
38 | lmd: dictionary of losses weights (optional).
39 |         lr: learning rate for the optimizer (optional).
40 | """
41 |
42 | def __init__(
43 | self,
44 | motion_encoder: nn.Module,
45 | text_encoder: nn.Module,
46 | motion_decoder: nn.Module,
47 | vae: bool,
48 | fact: Optional[float] = None,
49 | sample_mean: Optional[bool] = False,
50 | lmd: Dict = {"recons": 1.0, "latent": 1.0e-5, "kl": 1.0e-5},
51 | lr: float = 1e-4,
52 | ) -> None:
53 | super().__init__()
54 |
55 | self.motion_encoder = motion_encoder
56 | self.text_encoder = text_encoder
57 | self.motion_decoder = motion_decoder
58 |
59 | # sampling parameters
60 | self.vae = vae
61 | self.fact = fact if fact is not None else 1.0
62 | self.sample_mean = sample_mean
63 |
64 | # losses
65 | self.reconstruction_loss_fn = torch.nn.SmoothL1Loss(reduction="mean")
66 | self.latent_loss_fn = torch.nn.SmoothL1Loss(reduction="mean")
67 | self.kl_loss_fn = KLLoss()
68 |
69 | # lambda weighting for the losses
70 | self.lmd = lmd
71 | self.lr = lr
72 |
73 |     def configure_optimizers(self) -> Dict:
74 | return {"optimizer": torch.optim.AdamW(lr=self.lr, params=self.parameters())}
75 |
76 | def _find_encoder(self, inputs, modality):
77 | assert modality in ["text", "motion", "auto"]
78 |
79 | if modality == "text":
80 | return self.text_encoder
81 | elif modality == "motion":
82 | return self.motion_encoder
83 |
84 | m_nfeats = self.motion_encoder.nfeats
85 | t_nfeats = self.text_encoder.nfeats
86 |
87 | if m_nfeats == t_nfeats:
88 | raise ValueError(
89 | "Cannot automatically find the encoder, as they share the same input space."
90 | )
91 |
92 | nfeats = inputs["x"].shape[-1]
93 | if nfeats == m_nfeats:
94 | return self.motion_encoder
95 | elif nfeats == t_nfeats:
96 | return self.text_encoder
97 | else:
98 |             raise ValueError("The input does not match the feature size of either encoder.")
99 |
100 | def encode(
101 | self,
102 | inputs,
103 | modality: str = "auto",
104 | sample_mean: Optional[bool] = None,
105 | fact: Optional[float] = None,
106 | return_distribution: bool = False,
107 | ):
108 | sample_mean = self.sample_mean if sample_mean is None else sample_mean
109 | fact = self.fact if fact is None else fact
110 |
111 | # Encode the inputs
112 | encoder = self._find_encoder(inputs, modality)
113 | encoded = encoder(inputs)
114 |
115 | # Sampling
116 | if self.vae:
117 | dists = encoded.unbind(1)
118 | mu, logvar = dists
119 | if sample_mean:
120 | latent_vectors = mu
121 | else:
122 | # Reparameterization trick
123 | std = logvar.exp().pow(0.5)
124 | eps = std.data.new(std.size()).normal_()
125 | latent_vectors = mu + fact * eps * std
126 | else:
127 | dists = None
128 | (latent_vectors, _) = encoded.unbind(1)
129 |
130 | if return_distribution:
131 | return latent_vectors, dists
132 |
133 | return latent_vectors
134 |
135 | def decode(
136 | self,
137 | latent_vectors: Tensor,
138 | lengths: Optional[List[int]] = None,
139 | mask: Optional[Tensor] = None,
140 | ):
141 | mask = mask if mask is not None else length_to_mask(lengths, device=self.device)
142 | z_dict = {"z": latent_vectors, "mask": mask}
143 | motions = self.motion_decoder(z_dict)
144 | return motions
145 |
146 | # Forward: X => motions
147 | def forward(
148 | self,
149 | inputs,
150 | lengths: Optional[List[int]] = None,
151 | mask: Optional[Tensor] = None,
152 | sample_mean: Optional[bool] = None,
153 | fact: Optional[float] = None,
154 | return_all: bool = False,
155 | ) -> List[Tensor]:
156 | # Encoding the inputs and sampling if needed
157 | latent_vectors, distributions = self.encode(
158 | inputs, sample_mean=sample_mean, fact=fact, return_distribution=True
159 | )
160 | # Decoding the latent vector: generating motions
161 | motions = self.decode(latent_vectors, lengths, mask)
162 |
163 | if return_all:
164 | return {"motions": motions,
165 | "latent_vectors": latent_vectors,
166 | "distributions": distributions}
167 |
168 | return {"motions": motions}
169 |
170 | def call_models(self, batch):
171 | text_x_dict = batch["text_x_dict"]
172 | motion_x_dict = batch["motion_x_dict"]
173 |
174 | mask = motion_x_dict["mask"]
175 |
176 | # text -> motion
177 | t_results = self(text_x_dict, mask=mask, return_all=True)
178 |
179 | # motion -> motion
180 | m_results = self(motion_x_dict, mask=mask, return_all=True)
181 |
182 | return t_results, m_results
183 |
184 | def compute_loss(self, batch: Dict) -> Dict:
185 | t_results, m_results = self.call_models(batch)
186 | t_motions, t_latents, t_dists = t_results["motions"], t_results["latent_vectors"], t_results["distributions"]
187 | m_motions, m_latents, m_dists = m_results["motions"], m_results["latent_vectors"], m_results["distributions"]
188 |
189 | ref_motions = batch["motion_x_dict"]["x"]
190 |
191 | # Store all losses
192 | losses = {}
193 |
194 | # Reconstructions losses
195 | # fmt: off
196 | losses["recons"] = (
197 | + self.reconstruction_loss_fn(t_motions, ref_motions) # text -> motion
198 | + self.reconstruction_loss_fn(m_motions, ref_motions) # motion -> motion
199 | )
200 | # fmt: on
201 |
202 | # VAE losses
203 | if self.vae:
204 | # Create a centred normal distribution to compare with
205 | # logvar = 0 -> std = 1
206 | ref_mus = torch.zeros_like(m_dists[0])
207 | ref_logvar = torch.zeros_like(m_dists[1])
208 | ref_dists = (ref_mus, ref_logvar)
209 |
210 | losses["kl"] = (
211 | self.kl_loss_fn(t_dists, m_dists) # text_to_motion
212 | + self.kl_loss_fn(m_dists, t_dists) # motion_to_text
213 | + self.kl_loss_fn(m_dists, ref_dists) # motion
214 | + self.kl_loss_fn(t_dists, ref_dists) # text
215 | )
216 |
217 | # Latent manifold loss
218 | losses["latent"] = self.latent_loss_fn(t_latents, m_latents)
219 |
220 | # Weighted average of the losses
221 | losses["loss"] = sum(
222 | self.lmd[x] * val for x, val in losses.items() if x in self.lmd
223 | )
224 | return losses
225 |
226 | def training_step(self, batch: Dict, batch_idx: int) -> Tensor:
227 | bs = len(batch["motion_x_dict"]["x"])
228 | losses = self.compute_loss(batch)
229 |
230 | for loss_name in sorted(losses):
231 | loss_val = losses[loss_name]
232 | self.log(
233 | f"train_{loss_name}",
234 | loss_val,
235 | on_epoch=True,
236 | on_step=True,
237 | batch_size=bs,
238 | )
239 | return losses["loss"]
240 |
241 | def validation_step(self, batch: Dict, batch_idx: int) -> Tensor:
242 | bs = len(batch["motion_x_dict"]["x"])
243 | losses = self.compute_loss(batch)
244 |
245 | for loss_name in sorted(losses):
246 | loss_val = losses[loss_name]
247 | self.log(
248 | f"val_{loss_name}",
249 | loss_val,
250 | on_epoch=True,
251 | on_step=True,
252 | batch_size=bs,
253 | )
254 | return losses["loss"]
255 |
--------------------------------------------------------------------------------
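A small sketch of length_to_mask, the padding-mask helper used before decoding; the lengths are arbitrary examples:

from src.model.temos import length_to_mask

mask = length_to_mask([3, 5, 2])
# tensor([[ True,  True,  True, False, False],
#         [ True,  True,  True,  True,  True],
#         [ True,  True, False, False, False]])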
/src/model/text_encoder.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch.nn as nn
3 | import torch
4 | from torch import Tensor
5 | from typing import Dict, List
6 | import torch.nn.functional as F
7 |
8 |
9 | class TextToEmb(nn.Module):
10 | def __init__(
11 | self, modelpath: str, mean_pooling: bool = False, device: str = "cpu"
12 | ) -> None:
13 | super().__init__()
14 |
15 | self.device = device
16 | from transformers import AutoTokenizer, AutoModel
17 | from transformers import logging
18 |
19 | logging.set_verbosity_error()
20 |
21 | # Tokenizer
22 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
23 | self.tokenizer = AutoTokenizer.from_pretrained(modelpath)
24 |
25 | # Text model
26 | self.text_model = AutoModel.from_pretrained(modelpath)
27 | # Then configure the model
28 | self.text_encoded_dim = self.text_model.config.hidden_size
29 |
30 | if mean_pooling:
31 | self.forward = self.forward_pooling
32 |
33 | # put it in eval mode by default
34 | self.eval()
35 |
36 | # Freeze the weights just in case
37 | for param in self.parameters():
38 | param.requires_grad = False
39 |
40 | self.to(device)
41 |
42 | def train(self, mode: bool = True) -> nn.Module:
43 |         # override it so the module always stays in eval mode
44 | self.training = False
45 | for module in self.children():
46 | module.train(False)
47 | return self
48 |
49 | @torch.no_grad()
50 | def forward(self, texts: List[str], device=None) -> Dict:
51 | device = device if device is not None else self.device
52 |
53 | squeeze = False
54 | if isinstance(texts, str):
55 | texts = [texts]
56 | squeeze = True
57 |
58 | encoded_inputs = self.tokenizer(texts, return_tensors="pt", padding=True)
59 | output = self.text_model(**encoded_inputs.to(device))
60 | length = encoded_inputs.attention_mask.to(dtype=bool).sum(1)
61 |
62 | if squeeze:
63 | x_dict = {"x": output.last_hidden_state[0], "length": length[0]}
64 | else:
65 | x_dict = {"x": output.last_hidden_state, "length": length}
66 | return x_dict
67 |
68 | @torch.no_grad()
69 | def forward_pooling(self, texts: List[str], device=None) -> Tensor:
70 | device = device if device is not None else self.device
71 |
72 | squeeze = False
73 | if isinstance(texts, str):
74 | texts = [texts]
75 | squeeze = True
76 |
77 | # From: https://huggingface.co/sentence-transformers/all-mpnet-base-v2
78 | encoded_inputs = self.tokenizer(texts, return_tensors="pt", padding=True)
79 | output = self.text_model(**encoded_inputs.to(device))
80 | attention_mask = encoded_inputs["attention_mask"]
81 |
82 | # Mean Pooling - Take attention mask into account for correct averaging
83 | token_embeddings = output["last_hidden_state"]
84 | input_mask_expanded = (
85 | attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
86 | )
87 | sentence_embeddings = torch.sum(
88 | token_embeddings * input_mask_expanded, 1
89 | ) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
90 | # Normalize embeddings
91 | sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
92 | if squeeze:
93 | sentence_embeddings = sentence_embeddings[0]
94 | return sentence_embeddings
95 |
--------------------------------------------------------------------------------
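A usage sketch for TextToEmb. The checkpoint name follows the sentence-transformers model mentioned in forward_pooling and is an assumption (the configs may point to a different model); it also assumes the transformers weights are available locally or downloadable:

from src.model.text_encoder import TextToEmb

model_name = "sentence-transformers/all-mpnet-base-v2"  # assumed checkpoint

token_model = TextToEmb(model_name, device="cpu")
x_dict = token_model("a person walks forward")  # {"x": (num_tokens, hidden), "length": num_tokens}

sent_model = TextToEmb(model_name, mean_pooling=True, device="cpu")
sent_emb = sent_model("a person walks forward")  # L2-normalized sentence embedding of size hidden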
/src/model/tmr.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Optional
2 | from torch import Tensor
3 |
4 | import torch
5 | import torch.nn as nn
6 | from .temos import TEMOS
7 | from .losses import InfoNCE_with_filtering
8 | from .metrics import all_contrastive_metrics
9 |
10 |
11 | # x.T on tensors with more than 2 dimensions is deprecated in PyTorch
12 | def transpose(x):
13 | return x.permute(*torch.arange(x.ndim - 1, -1, -1))
14 |
15 |
16 | def get_sim_matrix(x, y):
17 | x_logits = torch.nn.functional.normalize(x, dim=-1)
18 | y_logits = torch.nn.functional.normalize(y, dim=-1)
19 | sim_matrix = x_logits @ transpose(y_logits)
20 | return sim_matrix
21 |
22 |
23 | # Scores are between 0 and 1
24 | def get_score_matrix(x, y):
25 | sim_matrix = get_sim_matrix(x, y)
26 | scores = sim_matrix / 2 + 0.5
27 | return scores
28 |
29 |
30 | class TMR(TEMOS):
31 | r"""TMR: Text-to-Motion Retrieval
32 | Using Contrastive 3D Human Motion Synthesis
33 | Find more information about the model on the following website:
34 | https://mathis.petrovich.fr/tmr
35 |
36 | Args:
37 | motion_encoder: a module to encode the input motion features in the latent space (required).
38 | text_encoder: a module to encode the text embeddings in the latent space (required).
39 | motion_decoder: a module to decode the latent vector into motion features (required).
40 | vae: a boolean to make the model probabilistic (required).
41 | fact: a scaling factor for sampling the VAE (optional).
42 | sample_mean: sample the mean vector instead of random sampling (optional).
43 | lmd: dictionary of losses weights (optional).
44 |         lr: learning rate for the optimizer (optional).
45 | temperature: temperature of the softmax in the contrastive loss (optional).
46 |         threshold_selfsim: threshold used to filter out false negatives in the contrastive loss (optional).
47 |         threshold_selfsim_metrics: threshold used to filter out false negatives when computing the metrics (optional).
48 | """
49 |
50 | def __init__(
51 | self,
52 | motion_encoder: nn.Module,
53 | text_encoder: nn.Module,
54 | motion_decoder: nn.Module,
55 | vae: bool,
56 | contrastive_loss: Optional[InfoNCE_with_filtering] = None,
57 | temperature: float = 0.7, # For compatibility with TMR original code
58 | threshold_selfsim: float = 0.80, # For compatibility with TMR original code
59 | fact: Optional[float] = None,
60 | sample_mean: Optional[bool] = False,
61 | lmd: Dict = {"recons": 1.0, "latent": 1.0e-5, "kl": 1.0e-5, "contrastive": 0.1},
62 | lr: float = 1e-4,
63 | threshold_selfsim_metrics: float = 0.95,
64 | ) -> None:
65 | # Initialize module like TEMOS
66 | super().__init__(
67 | motion_encoder=motion_encoder,
68 | text_encoder=text_encoder,
69 | motion_decoder=motion_decoder,
70 | vae=vae,
71 | fact=fact,
72 | sample_mean=sample_mean,
73 | lmd=lmd,
74 | lr=lr,
75 | )
76 |
77 | # adding the contrastive loss
78 | self.contrastive_loss_fn = contrastive_loss
79 | if self.contrastive_loss_fn is None: # For compatibility with TMR original code
80 | self.contrastive_loss_fn = InfoNCE_with_filtering(
81 | temperature=temperature, threshold_selfsim=threshold_selfsim
82 | )
83 | self.threshold_selfsim_metrics = threshold_selfsim_metrics
84 |
85 | # store validation values to compute retrieval metrics
86 | # on the whole validation set
87 | self.validation_step_t_latents = []
88 | self.validation_step_m_latents = []
89 | self.validation_step_sent_emb = []
90 |
91 | def compute_loss(self, batch: Dict, return_all=False) -> Dict:
92 | t_results, m_results = self.call_models(batch)
93 | t_motions, t_latents, t_dists = t_results["motions"], t_results["latent_vectors"], t_results["distributions"]
94 | m_motions, m_latents, m_dists = m_results["motions"], m_results["latent_vectors"], m_results["distributions"]
95 |
96 | ref_motions = batch["motion_x_dict"]["x"]
97 |
98 | # sentence embeddings
99 | sent_emb = batch["sent_emb"]
100 |
101 | # Store all losses
102 | losses = {}
103 |
104 | # Reconstructions losses
105 | # fmt: off
106 | losses["recons"] = (
107 | + self.reconstruction_loss_fn(t_motions, ref_motions) # text -> motion
108 | + self.reconstruction_loss_fn(m_motions, ref_motions) # motion -> motion
109 | )
110 | # fmt: on
111 |
112 | # VAE losses
113 | if self.vae:
114 | # Create a centred normal distribution to compare with
115 | # logvar = 0 -> std = 1
116 | ref_mus = torch.zeros_like(m_dists[0])
117 | ref_logvar = torch.zeros_like(m_dists[1])
118 | ref_dists = (ref_mus, ref_logvar)
119 |
120 | losses["kl"] = (
121 | self.kl_loss_fn(t_dists, m_dists) # text_to_motion
122 | + self.kl_loss_fn(m_dists, t_dists) # motion_to_text
123 | + self.kl_loss_fn(m_dists, ref_dists) # motion
124 | + self.kl_loss_fn(t_dists, ref_dists) # text
125 | )
126 |
127 | # Latent manifold loss
128 | losses["latent"] = self.latent_loss_fn(t_latents, m_latents)
129 |
130 | # TMR: adding the contrastive loss
131 | losses["contrastive"] = self.contrastive_loss_fn(t_latents, m_latents, sent_emb)
132 |
133 | # Weighted average of the losses
134 | losses["loss"] = sum(
135 | self.lmd[x] * val for x, val in losses.items() if x in self.lmd
136 | )
137 |
138 | # Used for the validation step
139 | if return_all:
140 | return losses, t_latents, m_latents
141 |
142 | return losses
143 |
144 | def validation_step(self, batch: Dict, batch_idx: int) -> Tensor:
145 | bs = len(batch["motion_x_dict"]["x"])
146 | losses, t_latents, m_latents = self.compute_loss(batch, return_all=True)
147 |
148 | # Store the latent vectors
149 | self.validation_step_t_latents.append(t_latents)
150 | self.validation_step_m_latents.append(m_latents)
151 | self.validation_step_sent_emb.append(batch["sent_emb"])
152 |
153 | for loss_name in sorted(losses):
154 | loss_val = losses[loss_name]
155 | self.log(
156 | f"val_{loss_name}",
157 | loss_val,
158 | on_epoch=True,
159 | on_step=True,
160 | batch_size=bs,
161 | )
162 |
163 | return losses["loss"]
164 |
165 | def on_validation_epoch_end(self):
166 | # Compute contrastive metrics on the whole batch
167 | t_latents = torch.cat(self.validation_step_t_latents)
168 | m_latents = torch.cat(self.validation_step_m_latents)
169 | sent_emb = torch.cat(self.validation_step_sent_emb)
170 |
171 | # Compute the similarity matrix
172 | sim_matrix = get_sim_matrix(t_latents, m_latents).cpu().numpy()
173 |
174 | contrastive_metrics = all_contrastive_metrics(
175 | sim_matrix,
176 | emb=sent_emb.cpu().numpy(),
177 | threshold=self.threshold_selfsim_metrics,
178 | )
179 |
180 | for loss_name in sorted(contrastive_metrics):
181 | loss_val = contrastive_metrics[loss_name]
182 | self.log(
183 | f"val_{loss_name}_epoch",
184 | loss_val,
185 | on_epoch=True,
186 | on_step=False,
187 | )
188 |
189 | self.validation_step_t_latents.clear()
190 | self.validation_step_m_latents.clear()
191 | self.validation_step_sent_emb.clear()
192 |
--------------------------------------------------------------------------------
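A sketch of how the similarity helpers above are used for retrieval; the latent dimension and batch sizes are placeholders:

import torch
from src.model.tmr import get_sim_matrix, get_score_matrix

text_latents = torch.randn(4, 256)     # e.g. latents from model.encode(text_x_dict)
motion_latents = torch.randn(10, 256)  # e.g. latents from model.encode(motion_x_dict)

sim = get_sim_matrix(text_latents, motion_latents)       # (4, 10) cosine similarities in [-1, 1]
scores = get_score_matrix(text_latents, motion_latents)  # (4, 10) rescaled to [0, 1]
best_motion_per_text = sim.argmax(dim=1)                 # nearest motion for each text query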
/src/model/tmr_text_averaging.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List, Optional
2 | from torch import Tensor
3 |
4 | import torch
5 | from .tmr import TMR
6 |
7 | class TMRTextAveraging(TMR):
8 |     """Compatible with the AugmentedTextMotionDataset dataset and the collate_text_motion_multiple_texts collate function."""
9 |
10 | # Forward: X => motions
11 | def forward(
12 | self,
13 | inputs,
14 | text_slices: Optional[List[int]] = None,
15 | lengths: Optional[List[int]] = None,
16 | mask: Optional[Tensor] = None,
17 | sample_mean: Optional[bool] = None,
18 | fact: Optional[float] = None,
19 | return_all: bool = False,
20 | ) -> List[Tensor]:
21 |
22 | # Encoding the inputs and sampling if needed
23 | latent_vectors, distributions = self.encode(
24 | inputs, sample_mean=sample_mean, fact=fact, return_distribution=True
25 | )
26 |
27 |         # Average over the different text embeddings for each sample.
28 | if text_slices is not None:
29 | latent_vectors = [torch.mean(latent_vectors[i:j], dim=0) for i, j in text_slices]
30 | latent_vectors = torch.stack(latent_vectors, dim=0)
31 | distributions = list(distributions)
32 | distributions[0] = torch.stack([torch.mean(distributions[0][i:j], dim=0) for i, j in text_slices], dim=0)
33 | distributions[1] = torch.stack([torch.mean(distributions[1][i:j], dim=0) for i, j in text_slices], dim=0)
34 | distributions = tuple(distributions)
35 |
36 | # Decoding the latent vector: generating motions
37 | motions = self.decode(latent_vectors, lengths, mask)
38 |
39 | if return_all:
40 | return {"motions": motions,
41 | "latent_vectors": latent_vectors,
42 | "distributions": distributions}
43 |
44 | return {"motions": motions}
45 |
46 | def call_models(self, batch):
47 | text_x_dict = batch["text_x_dict"]
48 | motion_x_dict = batch["motion_x_dict"]
49 | text_slices = batch["text_slices"]
50 |
51 | mask = motion_x_dict["mask"]
52 |
53 | # text -> motion
54 | t_results = self(text_x_dict, mask=mask, return_all=True, text_slices=text_slices)
55 |
56 | # motion -> motion
57 | m_results = self(motion_x_dict, mask=mask, return_all=True)
58 |
59 | return t_results, m_results
60 |
--------------------------------------------------------------------------------
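A standalone sketch of the slice-averaging step in forward; the text_slices format (one [start, end) pair per motion over the flattened text batch) is inferred from the loop above:

import torch

latent_vectors = torch.randn(5, 256)   # 5 text latents for 2 motions (flattened batch)
text_slices = [(0, 3), (3, 5)]         # motion 0 has 3 captions, motion 1 has 2

averaged = torch.stack(
    [latent_vectors[i:j].mean(dim=0) for i, j in text_slices], dim=0
)
print(averaged.shape)  # torch.Size([2, 256]): one averaged text latent per motion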
/src/prepare.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import torch
3 | import warnings
4 |
5 | logger = logging.getLogger("torch.distributed.nn.jit.instantiator")
6 | logger.setLevel(logging.ERROR)
7 |
8 |
9 | warnings.filterwarnings(
10 | "ignore", "The PyTorch API of nested tensors is in prototype stage*"
11 | )
12 |
13 | warnings.filterwarnings("ignore", "Converting mask without torch.bool dtype to bool*")
14 |
15 | torch.set_float32_matmul_precision("high")
16 |
--------------------------------------------------------------------------------
/src/renderer/matplotlib.py:
--------------------------------------------------------------------------------
1 | # From TEMOS: temos/render/anim.py
2 | # Assume Z is the gravity axis
3 | # Inspired by
4 | # - https://github.com/anindita127/Complextext2animation/blob/main/src/utils/visualization.py
5 | # - https://github.com/facebookresearch/QuaterNet/blob/main/common/visualization.py
6 |
7 | import logging
8 |
9 | from dataclasses import dataclass
10 | from typing import List, Tuple, Optional
11 | import numpy as np
12 | from src.rifke import canonicalize_rotation
13 |
14 | logger = logging.getLogger("matplotlib.animation")
15 | logger.setLevel(logging.ERROR)
16 |
17 | colors = ("black", "magenta", "red", "green", "blue")
18 |
19 | KINEMATIC_TREES = {
20 | "smpljoints": [
21 | [0, 3, 6, 9, 12, 15],
22 | [9, 13, 16, 18, 20],
23 | [9, 14, 17, 19, 21],
24 | [0, 1, 4, 7, 10],
25 | [0, 2, 5, 8, 11],
26 | ],
27 | "guoh3djoints": [ # no hands
28 | [0, 3, 6, 9, 12, 15],
29 | [9, 13, 16, 18, 20],
30 | [9, 14, 17, 19, 21],
31 | [0, 1, 4, 7, 10],
32 | [0, 2, 5, 8, 11],
33 | ],
34 | }
35 |
36 |
37 | @dataclass
38 | class MatplotlibRender:
39 | jointstype: str = "smpljoints"
40 | fps: float = 20.0
41 | colors: List[str] = colors
42 | figsize: int = 4
43 | fontsize: int = 15
44 | canonicalize: bool = False
45 |
46 | def __call__(
47 | self,
48 | joints,
49 | highlights=None,
50 | title: str = "",
51 | output: str = "notebook",
52 | jointstype=None,
53 | ):
54 | jointstype = jointstype if jointstype is not None else self.jointstype
55 | render_animation(
56 | joints,
57 | title=title,
58 | highlights=highlights,
59 | output=output,
60 | jointstype=jointstype,
61 | fps=self.fps,
62 | colors=self.colors,
63 | figsize=(self.figsize, self.figsize),
64 | fontsize=self.fontsize,
65 | canonicalize=self.canonicalize,
66 | )
67 |
68 |
69 | def init_axis(fig, title, radius=1.5):
70 | ax = fig.add_subplot(1, 1, 1, projection="3d")
71 | ax.view_init(elev=20.0, azim=-60)
72 |
73 | fact = 2
74 | ax.set_xlim3d([-radius / fact, radius / fact])
75 | ax.set_ylim3d([-radius / fact, radius / fact])
76 | ax.set_zlim3d([0, radius])
77 |
78 | ax.set_aspect("auto")
79 | ax.set_xticklabels([])
80 | ax.set_yticklabels([])
81 | ax.set_zticklabels([])
82 |
83 | ax.set_axis_off()
84 |     ax.grid(False)
85 |
86 | ax.set_title(title, loc="center", wrap=True)
87 | return ax
88 |
89 |
90 | def plot_floor(ax, minx, maxx, miny, maxy, minz):
91 | from mpl_toolkits.mplot3d.art3d import Poly3DCollection
92 |
93 |     # Plot the floor: a rectangle in the XY plane at z = minz
94 | verts = [
95 | [minx, miny, minz],
96 | [minx, maxy, minz],
97 | [maxx, maxy, minz],
98 | [maxx, miny, minz],
99 | ]
100 | xz_plane = Poly3DCollection([verts], zorder=1)
101 | xz_plane.set_facecolor((0.5, 0.5, 0.5, 1))
102 | ax.add_collection3d(xz_plane)
103 |
104 |     # Plot a bigger square in the same XY plane
105 | radius = max((maxx - minx), (maxy - miny))
106 |
107 | # center +- radius
108 | minx_all = (maxx + minx) / 2 - radius
109 | maxx_all = (maxx + minx) / 2 + radius
110 |
111 | miny_all = (maxy + miny) / 2 - radius
112 | maxy_all = (maxy + miny) / 2 + radius
113 |
114 | verts = [
115 | [minx_all, miny_all, minz],
116 | [minx_all, maxy_all, minz],
117 | [maxx_all, maxy_all, minz],
118 | [maxx_all, miny_all, minz],
119 | ]
120 | xz_plane = Poly3DCollection([verts], zorder=1)
121 | xz_plane.set_facecolor((0.5, 0.5, 0.5, 0.5))
122 | ax.add_collection3d(xz_plane)
123 | return ax
124 |
125 |
126 | def update_camera(ax, root, radius=1.5):
127 | fact = 2
128 | ax.set_xlim3d([-radius / fact + root[0], radius / fact + root[0]])
129 | ax.set_ylim3d([-radius / fact + root[1], radius / fact + root[1]])
130 |
131 |
132 | def render_animation(
133 | joints: np.ndarray,
134 | output: str = "notebook",
135 | highlights: Optional[np.ndarray] = None,
136 | jointstype: str = "smpljoints",
137 | title: str = "",
138 | fps: float = 20.0,
139 | colors: List[str] = colors,
140 | figsize: Tuple[int] = (4, 4),
141 | fontsize: int = 15,
142 | canonicalize: bool = False,
143 | agg=True,
144 | ):
145 | if agg:
146 | import matplotlib
147 |
148 | matplotlib.use("Agg")
149 |
150 | if highlights is not None:
151 | assert len(highlights) == len(joints)
152 |
153 | assert jointstype in KINEMATIC_TREES
154 | kinematic_tree = KINEMATIC_TREES[jointstype]
155 |
156 | import matplotlib.pyplot as plt
157 | from matplotlib.animation import FuncAnimation
158 | import matplotlib.patheffects as pe
159 |
160 | mean_fontsize = fontsize
161 |
162 | # heuristic to change fontsize
163 | fontsize = mean_fontsize - (len(title) - 30) / 20
164 | plt.rcParams.update({"font.size": fontsize})
165 |
166 | # Z is gravity here
167 | x, y, z = 0, 1, 2
168 |
169 | joints = joints.copy()
170 |
171 | if canonicalize:
172 | joints = canonicalize_rotation(joints, jointstype=jointstype)
173 |
174 | # Create a figure and initialize 3d plot
175 | fig = plt.figure(figsize=figsize)
176 | ax = init_axis(fig, title)
177 |
178 | # Create spline line
179 | trajectory = joints[:, 0, [x, y]]
180 | avg_segment_length = (
181 | np.mean(np.linalg.norm(np.diff(trajectory, axis=0), axis=1)) + 1e-3
182 | )
183 | draw_offset = int(25 / avg_segment_length)
184 | (spline_line,) = ax.plot(*trajectory.T, zorder=10, color="white")
185 |
186 | # Create a floor
187 | minx, miny, _ = joints.min(axis=(0, 1))
188 | maxx, maxy, _ = joints.max(axis=(0, 1))
189 | plot_floor(ax, minx, maxx, miny, maxy, 0)
190 |
191 | # Put the character on the floor
192 | height_offset = np.min(joints[:, :, z]) # Min height
193 | joints = joints.copy()
194 | joints[:, :, z] -= height_offset
195 |
196 | # Initialization for redrawing
197 | lines = []
198 | initialized = False
199 |
200 | def update(frame):
201 | nonlocal initialized
202 | skeleton = joints[frame]
203 |
204 | root = skeleton[0]
205 | update_camera(ax, root)
206 |
207 | hcolors = colors
208 | if highlights is not None and highlights[frame]:
209 | hcolors = ("red", "red", "red", "red", "red")
210 |
211 | for index, (chain, color) in enumerate(
212 | zip(reversed(kinematic_tree), reversed(hcolors))
213 | ):
214 | if not initialized:
215 | lines.append(
216 | ax.plot(
217 | skeleton[chain, x],
218 | skeleton[chain, y],
219 | skeleton[chain, z],
220 | linewidth=6.0,
221 | color=color,
222 | zorder=20,
223 | path_effects=[pe.SimpleLineShadow(), pe.Normal()],
224 | )
225 | )
226 |
227 | else:
228 | lines[index][0].set_xdata(skeleton[chain, x])
229 | lines[index][0].set_ydata(skeleton[chain, y])
230 | lines[index][0].set_3d_properties(skeleton[chain, z])
231 | lines[index][0].set_color(color)
232 |
233 | left = max(frame - draw_offset, 0)
234 | right = min(frame + draw_offset, trajectory.shape[0])
235 |
236 | spline_line.set_xdata(trajectory[left:right, 0])
237 | spline_line.set_ydata(trajectory[left:right, 1])
238 | spline_line.set_3d_properties(np.zeros_like(trajectory[left:right, 0]))
239 | initialized = True
240 |
241 | fig.tight_layout()
242 | frames = joints.shape[0]
243 | anim = FuncAnimation(fig, update, frames=frames, interval=1000 / fps, repeat=False)
244 |
245 | if output == "notebook":
246 | from IPython.display import HTML
247 |
248 | HTML(anim.to_jshtml())
249 | else:
250 | # anim.save(output, writer='ffmpeg', fps=fps)
251 | anim.save(output, fps=fps)
252 |
253 | plt.close()
254 |
--------------------------------------------------------------------------------
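A rendering sketch with placeholder joints; it assumes a 22-joint input, the Agg backend, and a GIF writer such as Pillow being available:

import numpy as np
from src.renderer.matplotlib import MatplotlibRender

joints = 0.1 * np.random.randn(60, 22, 3)  # placeholder motion: 60 frames, 22 joints, XYZ with Z up
render = MatplotlibRender(jointstype="guoh3djoints", fps=20.0, canonicalize=False)
render(joints, title="placeholder motion", output="example.gif")  # writes a GIF next to the script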
/src/rifke.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from einops import rearrange
4 | import numpy as np
5 |
6 | from torch import Tensor
7 |
8 | from .geometry import axis_angle_rotation, matrix_to_axis_angle
9 | from .joints import INFOS
10 |
11 |
12 | def joints_to_rifke(joints, jointstype="smpljoints"):
13 |     # Joints to rotation-invariant poses (Holden et al.)
14 |     # Similar to the fke2rifke function in the Language2Pose repository
15 |     # Adapted from the pytorch version of TEMOS
16 |     # https://github.com/Mathux/TEMOS
17 |     # Estimate the last velocities based on acceleration
18 |     # Differences of rotations are now computed in SO(3) space
19 |
20 | # First remove the ground
21 | ground = joints[..., 2].min()
22 | poses = joints.clone()
23 | poses[..., 2] -= ground
24 |
25 |     # Root translation (X, Y and the gravity axis Z)
26 | translation = poses[..., 0, :].clone()
27 |
28 | # Let the root have the Z translation --> gravity axis
29 | root_grav_axis = translation[..., 2]
30 |
31 | # Trajectory => Translation without gravity axis (Z)
32 | trajectory = translation[..., [0, 1]]
33 |
34 | # Compute the forward direction (before removing the root joint)
35 | forward = get_forward_direction(poses, jointstype=jointstype)
36 |
37 | # Delete the root joints of the poses
38 | poses = poses[..., 1:, :]
39 |
40 | # Remove the trajectory of the poses
41 | poses[..., [0, 1]] -= trajectory[..., None, :]
42 |
43 | vel_trajectory = torch.diff(trajectory, dim=-2)
44 |
45 | # repeat the last acceleration
46 | # for the last (not seen) velocity
47 | last_acceleration = vel_trajectory[..., -1, :] - vel_trajectory[..., -2, :]
48 |
49 | future_velocity = vel_trajectory[..., -1, :] + last_acceleration
50 | vel_trajectory = torch.cat((vel_trajectory, future_velocity[..., None, :]), dim=-2)
51 |
52 | angles = torch.atan2(*(forward.transpose(0, -1))).transpose(0, -1)
53 |
54 | # True difference of angles
55 | mat_rotZ = axis_angle_rotation("Z", angles)
56 | vel_mat_rotZ = mat_rotZ[..., 1:, :, :] @ mat_rotZ.transpose(-1, -2)[..., :-1, :, :]
57 | # repeat the last acceleration (same as the trajectory but in the 3D rotation space)
58 | last_acc_rotZ = (
59 | vel_mat_rotZ[..., -1, :, :] @ vel_mat_rotZ.transpose(-1, -2)[..., -2, :, :]
60 | )
61 | future_vel_rotZ = vel_mat_rotZ[..., -1, :, :] @ last_acc_rotZ
62 | vel_mat_rotZ = torch.cat((vel_mat_rotZ, future_vel_rotZ[..., None, :, :]), dim=-3)
63 | vel_angles = matrix_to_axis_angle(vel_mat_rotZ)[..., 2]
64 |
65 | # Construct the inverse rotation matrix
66 | rotations_inv = mat_rotZ.transpose(-1, -2)[..., :2, :2]
67 |
68 | poses_local = torch.einsum("...lj,...jk->...lk", poses[..., [0, 1]], rotations_inv)
69 | poses_local = torch.stack(
70 | (poses_local[..., 0], poses_local[..., 1], poses[..., 2]), axis=-1
71 | )
72 |
73 | # stack the xyz joints into feature vectors
74 | poses_features = rearrange(poses_local, "... joints xyz -> ... (joints xyz)")
75 |
76 | # Rotate the vel_trajectory
77 | vel_trajectory_local = torch.einsum(
78 | "...j,...jk->...k", vel_trajectory, rotations_inv
79 | )
80 |
81 | # Stack things together
82 | features = group(root_grav_axis, poses_features, vel_angles, vel_trajectory_local)
83 | return features
84 |
85 |
86 | def rifke_to_joints(features: Tensor, jointstype="smpljoints") -> Tensor:
87 | root_grav_axis, poses_features, vel_angles, vel_trajectory_local = ungroup(features)
88 |
89 | # Remove the dummy last angle and integrate the angles
90 | angles = torch.cumsum(vel_angles[..., :-1], dim=-1)
91 | # The first angle is zero
92 | angles = torch.cat((0 * angles[..., [0]], angles), dim=-1)
93 | rotations = axis_angle_rotation("Z", angles)[..., :2, :2]
94 |
95 | # Get back the poses
96 | poses_local = rearrange(poses_features, "... (joints xyz) -> ... joints xyz", xyz=3)
97 |
98 | # Rotate the poses
99 | poses = torch.einsum("...lj,...jk->...lk", poses_local[..., [0, 1]], rotations)
100 | poses = torch.stack((poses[..., 0], poses[..., 1], poses_local[..., 2]), axis=-1)
101 |
102 | # Rotate the vel_trajectory
103 | vel_trajectory = torch.einsum("...j,...jk->...k", vel_trajectory_local, rotations)
104 | # Remove the dummy last velocity and integrate the trajectory
105 | trajectory = torch.cumsum(vel_trajectory[..., :-1, :], dim=-2)
106 | # The first position is zero
107 | trajectory = torch.cat((0 * trajectory[..., [0], :], trajectory), dim=-2)
108 |
109 | # Add the root joints (which is still zero)
110 | poses = torch.cat((0 * poses[..., [0], :], poses), -2)
111 |
112 | # put back the gravity offset
113 | poses[..., 0, 2] = root_grav_axis
114 |
115 | # Add the trajectory globally
116 | poses[..., [0, 1]] += trajectory[..., None, :]
117 | return poses
118 |
119 |
120 | def group(root_grav_axis, poses_features, vel_angles, vel_trajectory_local):
121 | # Stack things together
122 | features = torch.cat(
123 | (
124 | root_grav_axis[..., None],
125 | poses_features,
126 | vel_angles[..., None],
127 | vel_trajectory_local,
128 | ),
129 | -1,
130 | )
131 | return features
132 |
133 |
134 | def ungroup(features: Tensor) -> tuple[Tensor]:
135 | # Unbind things
136 | root_grav_axis = features[..., 0]
137 | poses_features = features[..., 1:-3]
138 | vel_angles = features[..., -3]
139 | vel_trajectory_local = features[..., -2:]
140 | return root_grav_axis, poses_features, vel_angles, vel_trajectory_local
141 |
142 |
143 | def get_forward_direction(poses, jointstype="smpljoints"):
144 | assert jointstype in INFOS
145 | infos = INFOS[jointstype]
146 | assert poses.shape[-2] == infos["njoints"]
147 | RH, LH, RS, LS = infos["RH"], infos["LH"], infos["RS"], infos["LS"]
148 | across = (
149 | poses[..., RH, :] - poses[..., LH, :] + poses[..., RS, :] - poses[..., LS, :]
150 | )
151 | forward = torch.stack((-across[..., 1], across[..., 0]), axis=-1)
152 | forward = torch.nn.functional.normalize(forward, dim=-1)
153 | return forward
154 |
155 |
156 | def canonicalize_rotation(joints, jointstype="smpljoints"):
157 | return_np = False
158 | if isinstance(joints, np.ndarray):
159 | joints = torch.from_numpy(joints)
160 | return_np = True
161 |
162 | features = joints_to_rifke(joints, jointstype=jointstype)
163 | joints_c = rifke_to_joints(features, jointstype=jointstype)
164 | if return_np:
165 | joints_c = joints_c.numpy()
166 | return joints_c
167 |
--------------------------------------------------------------------------------
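A round-trip sketch for the rotation-invariant features; it assumes the smpljoints layout has 22 joints (matching the kinematic tree in the renderer), which would give 1 + 21*3 + 1 + 2 = 67 features per frame:

import torch
from src.rifke import joints_to_rifke, rifke_to_joints

joints = torch.randn(40, 22, 3)  # placeholder motion (frames, joints, XYZ), Z is the gravity axis
features = joints_to_rifke(joints, jointstype="smpljoints")      # (40, 67) if smpljoints has 22 joints
recovered = rifke_to_joints(features, jointstype="smpljoints")   # (40, 22, 3) canonicalized joints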
/stats/babel/guoh3dfeats/mean.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leorebensabath/TMRPlusPlus/23e4deab9c4529bd22e3d40942a9a9a9d7a26db2/stats/babel/guoh3dfeats/mean.pt
--------------------------------------------------------------------------------
/stats/babel/guoh3dfeats/std.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leorebensabath/TMRPlusPlus/23e4deab9c4529bd22e3d40942a9a9a9d7a26db2/stats/babel/guoh3dfeats/std.pt
--------------------------------------------------------------------------------
/stats/babel_actions_120/guoh3dfeats/mean.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leorebensabath/TMRPlusPlus/23e4deab9c4529bd22e3d40942a9a9a9d7a26db2/stats/babel_actions_120/guoh3dfeats/mean.pt
--------------------------------------------------------------------------------
/stats/babel_actions_120/guoh3dfeats/std.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leorebensabath/TMRPlusPlus/23e4deab9c4529bd22e3d40942a9a9a9d7a26db2/stats/babel_actions_120/guoh3dfeats/std.pt
--------------------------------------------------------------------------------
/stats/babel_actions_60/guoh3dfeats/mean.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leorebensabath/TMRPlusPlus/23e4deab9c4529bd22e3d40942a9a9a9d7a26db2/stats/babel_actions_60/guoh3dfeats/mean.pt
--------------------------------------------------------------------------------
/stats/babel_actions_60/guoh3dfeats/std.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leorebensabath/TMRPlusPlus/23e4deab9c4529bd22e3d40942a9a9a9d7a26db2/stats/babel_actions_60/guoh3dfeats/std.pt
--------------------------------------------------------------------------------
/stats/humanml3d/guoh3dfeats/mean.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leorebensabath/TMRPlusPlus/23e4deab9c4529bd22e3d40942a9a9a9d7a26db2/stats/humanml3d/guoh3dfeats/mean.pt
--------------------------------------------------------------------------------
/stats/humanml3d/guoh3dfeats/std.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leorebensabath/TMRPlusPlus/23e4deab9c4529bd22e3d40942a9a9a9d7a26db2/stats/humanml3d/guoh3dfeats/std.pt
--------------------------------------------------------------------------------
/stats/humanml3d_kitml/guoh3dfeats/mean.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leorebensabath/TMRPlusPlus/23e4deab9c4529bd22e3d40942a9a9a9d7a26db2/stats/humanml3d_kitml/guoh3dfeats/mean.pt
--------------------------------------------------------------------------------
/stats/humanml3d_kitml/guoh3dfeats/std.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leorebensabath/TMRPlusPlus/23e4deab9c4529bd22e3d40942a9a9a9d7a26db2/stats/humanml3d_kitml/guoh3dfeats/std.pt
--------------------------------------------------------------------------------
/stats/humanml3d_kitml_babel/guoh3dfeats/mean.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leorebensabath/TMRPlusPlus/23e4deab9c4529bd22e3d40942a9a9a9d7a26db2/stats/humanml3d_kitml_babel/guoh3dfeats/mean.pt
--------------------------------------------------------------------------------
/stats/humanml3d_kitml_babel/guoh3dfeats/std.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leorebensabath/TMRPlusPlus/23e4deab9c4529bd22e3d40942a9a9a9d7a26db2/stats/humanml3d_kitml_babel/guoh3dfeats/std.pt
--------------------------------------------------------------------------------
/stats/kitml/guoh3dfeats/mean.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leorebensabath/TMRPlusPlus/23e4deab9c4529bd22e3d40942a9a9a9d7a26db2/stats/kitml/guoh3dfeats/mean.pt
--------------------------------------------------------------------------------
/stats/kitml/guoh3dfeats/std.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leorebensabath/TMRPlusPlus/23e4deab9c4529bd22e3d40942a9a9a9d7a26db2/stats/kitml/guoh3dfeats/std.pt
--------------------------------------------------------------------------------
/text_motion_sim.py:
--------------------------------------------------------------------------------
1 | from omegaconf import DictConfig
2 | import logging
3 | import hydra
4 |
5 | logger = logging.getLogger(__name__)
6 |
7 |
8 | @hydra.main(version_base=None, config_path="configs", config_name="text_motion_sim")
9 | def text_motion_sim(cfg: DictConfig) -> None:
10 | device = cfg.device
11 | run_dir = cfg.run_dir
12 | ckpt_name = cfg.ckpt_name
13 | npy_path = cfg.npy
14 | text = cfg.text
15 |
16 | import src.prepare # noqa
17 | import torch
18 | import numpy as np
19 | from src.config import read_config
20 | from src.load import load_model_from_cfg
21 | from hydra.utils import instantiate
22 | from pytorch_lightning import seed_everything
23 | from src.data.collate import collate_x_dict
24 | from src.model.tmr import get_score_matrix
25 |
26 | cfg = read_config(run_dir)
27 |
28 | seed_everything(cfg.seed)
29 |
30 | logger.info("Loading the text model")
31 | text_model = instantiate(cfg.data.text_to_token_emb, device=device)
32 |
33 | logger.info("Loading the model")
34 | model = load_model_from_cfg(cfg, ckpt_name, eval_mode=True, device=device)
35 |
36 | normalizer = instantiate(cfg.data.motion_loader.normalizer)
37 |
38 | motion = torch.from_numpy(np.load(npy_path)).to(torch.float)
39 | motion = normalizer(motion)
40 | motion = motion.to(device)
41 |
42 | motion_x_dict = {"x": motion, "length": len(motion)}
43 |
44 | with torch.inference_mode():
45 | # motion -> latent
46 | motion_x_dict = collate_x_dict([motion_x_dict])
47 | lat_m = model.encode(motion_x_dict, sample_mean=True)[0]
48 |
49 | # text -> latent
50 | text_x_dict = collate_x_dict(text_model([text]))
51 | lat_t = model.encode(text_x_dict, sample_mean=True)[0]
52 |
53 | score = get_score_matrix(lat_t, lat_m).cpu()
54 |
55 | score_str = f"{score:.3}"
56 | logger.info(
57 |         f"The similarity score s (0 <= s <= 1) between the text and the motion is: {score_str}"
58 | )
59 |
60 |
61 | if __name__ == "__main__":
62 | text_motion_sim()
63 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import hydra
2 | from hydra.utils import instantiate
3 | import logging
4 | from omegaconf import DictConfig
5 | import os
6 | import pytorch_lightning as pl
7 |
8 | from src.config import read_config, save_config
9 |
10 | logger = logging.getLogger(__name__)
11 |
12 |
13 | @hydra.main(config_path="configs", config_name="train", version_base="1.3")
14 | def train(cfg: DictConfig):
15 | # Resuming if needed
16 | ckpt = None
17 |
18 | if cfg.ckpt is not None:
19 | ckpt = cfg.ckpt
20 |
21 | if cfg.resume_dir is not None:
22 | assert cfg.ckpt is not None
23 | max_epochs = cfg.trainer.max_epochs
24 | ckpt = os.path.join(cfg.resume_dir, 'logs', 'checkpoints', f'{cfg.ckpt}.ckpt')
25 | cfg = read_config(cfg.resume_dir)
26 | cfg.trainer.max_epochs = max_epochs
27 | logger.info("Resuming training")
28 | logger.info(f"The config is loaded from: \n{cfg.run_dir}")
29 | else:
30 | if "ckpt_path" in cfg and cfg.ckpt_path is not None:
31 | ckpt = cfg.ckpt_path
32 | config_path = save_config(cfg)
33 | logger.info("Training script")
34 | logger.info(f"The config can be found here: \n{config_path}")
35 |
36 | pl.seed_everything(cfg.seed)
37 |
38 | text_to_token_emb = instantiate(cfg.data.text_to_token_emb)
39 | text_to_sent_emb = instantiate(cfg.data.text_to_sent_emb)
40 |
41 | logger.info("Loading the dataloaders")
42 | train_dataset = instantiate(cfg.data, split="train",
43 | text_to_token_emb=text_to_token_emb,
44 | text_to_sent_emb=text_to_sent_emb)
45 |
46 | if "data_val" not in cfg:
47 | data_val = cfg.data
48 | else:
49 | data_val = cfg.data_val
50 | text_to_token_emb = instantiate(cfg.data_val.text_to_token_emb)
51 | text_to_sent_emb = instantiate(cfg.data_val.text_to_sent_emb)
52 |
53 | val_dataset = instantiate(data_val, split="val",
54 | text_to_token_emb=text_to_token_emb,
55 | text_to_sent_emb=text_to_sent_emb)
56 |
57 | train_dataloader = instantiate(
58 | cfg.dataloader,
59 | dataset=train_dataset,
60 | collate_fn=train_dataset.collate_fn,
61 | shuffle=True,
62 | )
63 |
64 | val_dataloader = instantiate(
65 | cfg.dataloader,
66 | dataset=val_dataset,
67 | collate_fn=val_dataset.collate_fn,
68 | shuffle=False,
69 | )
70 |
71 | logger.info("Loading the model")
72 | model = instantiate(cfg.model)
73 |
74 | logger.info(f"Using checkpoint: {ckpt}")
75 | logger.info("Training")
76 | trainer = instantiate(cfg.trainer)
77 | trainer.fit(model, train_dataloader, val_dataloader, ckpt_path=ckpt)
78 |
79 |
80 | if __name__ == "__main__":
81 | train()
82 |
--------------------------------------------------------------------------------