├── .gitignore
├── LICENSE
├── README.md
├── configs
│   ├── _base_
│   │   └── datasets
│   │       ├── human_ml3d_bs128.py
│   │       └── kit_ml_bs128.py
│   ├── mdm
│   │   └── mdm_t2m_official.py
│   ├── motiondiffuse
│   │   ├── motiondiffuse_kit.py
│   │   └── motiondiffuse_t2m.py
│   └── remodiffuse
│       ├── remodiffuse_kit.py
│       └── remodiffuse_t2m.py
├── imgs
│   ├── pipeline.png
│   └── teaser.png
├── mogen
│   ├── __init__.py
│   ├── apis
│   │   ├── __init__.py
│   │   ├── test.py
│   │   └── train.py
│   ├── core
│   │   ├── __init__.py
│   │   ├── distributed_wrapper.py
│   │   ├── evaluation
│   │   │   ├── __init__.py
│   │   │   ├── builder.py
│   │   │   ├── eval_hooks.py
│   │   │   ├── evaluators
│   │   │   │   ├── __init__.py
│   │   │   │   ├── base_evaluator.py
│   │   │   │   ├── diversity_evaluator.py
│   │   │   │   ├── fid_evaluator.py
│   │   │   │   ├── matching_score_evaluator.py
│   │   │   │   ├── multimodality_evaluator.py
│   │   │   │   └── precision_evaluator.py
│   │   │   ├── get_model.py
│   │   │   └── utils.py
│   │   └── optimizer
│   │       ├── __init__.py
│   │       └── builder.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── base_dataset.py
│   │   ├── builder.py
│   │   ├── dataset_wrappers.py
│   │   ├── pipelines
│   │   │   ├── __init__.py
│   │   │   ├── compose.py
│   │   │   ├── formatting.py
│   │   │   └── transforms.py
│   │   ├── samplers
│   │   │   ├── __init__.py
│   │   │   └── distributed_sampler.py
│   │   └── text_motion_dataset.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── architectures
│   │   │   ├── __init__.py
│   │   │   ├── base_architecture.py
│   │   │   ├── diffusion_architecture.py
│   │   │   └── vae_architecture.py
│   │   ├── attentions
│   │   │   ├── __init__.py
│   │   │   ├── base_attention.py
│   │   │   ├── efficient_attention.py
│   │   │   └── semantics_modulated.py
│   │   ├── builder.py
│   │   ├── losses
│   │   │   ├── __init__.py
│   │   │   ├── mse_loss.py
│   │   │   └── utils.py
│   │   ├── rnns
│   │   │   ├── __init__.py
│   │   │   └── t2m_bigru.py
│   │   ├── transformers
│   │   │   ├── __init__.py
│   │   │   ├── actor.py
│   │   │   ├── diffusion_transformer.py
│   │   │   ├── mdm.py
│   │   │   ├── motiondiffuse.py
│   │   │   ├── position_encoding.py
│   │   │   └── remodiffuse.py
│   │   └── utils
│   │       ├── __init__.py
│   │       ├── gaussian_diffusion.py
│   │       ├── mlp.py
│   │       ├── stylization_block.py
│   │       └── word_vectorizer.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── collect_env.py
│   │   ├── dist_utils.py
│   │   ├── logger.py
│   │   ├── misc.py
│   │   ├── path_utils.py
│   │   └── plot_utils.py
│   └── version.py
├── requirements.txt
└── tools
    ├── dist_train.sh
    ├── slurm_test.sh
    ├── slurm_train.sh
    ├── test.py
    ├── train.py
    └── visualize.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | **/*.pyc
6 |
7 | # C extensions
8 | *.so
9 |
10 | # Distribution / packaging
11 | .Python
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .coverage
43 | .coverage.*
44 | .cache
45 | nosetests.xml
46 | coverage.xml
47 | *.cover
48 | .hypothesis/
49 | .pytest_cache/
50 |
51 | # Translations
52 | *.mo
53 | *.pot
54 |
55 | # Django stuff:
56 | *.log
57 | local_settings.py
58 | db.sqlite3
59 |
60 | # Flask stuff:
61 | instance/
62 | .webassets-cache
63 |
64 | # Scrapy stuff:
65 | .scrapy
66 |
67 | # Sphinx documentation
68 | docs/_build/
69 |
70 | # PyBuilder
71 | target/
72 |
73 | # Jupyter Notebook
74 | .ipynb_checkpoints
75 |
76 | # pyenv
77 | .python-version
78 |
79 | # celery beat schedule file
80 | celerybeat-schedule
81 |
82 | # SageMath parsed files
83 | *.sage.py
84 |
85 | # Environments
86 | .env
87 | .venv
88 | env/
89 | venv/
90 | ENV/
91 | env.bak/
92 | venv.bak/
93 |
94 | # Spyder project settings
95 | .spyderproject
96 | .spyproject
97 |
98 | # Rope project settings
99 | .ropeproject
100 |
101 | # mkdocs documentation
102 | /site
103 |
104 | # mypy
105 | .mypy_cache/
106 |
107 | # custom
108 | data
109 | # data for pytest moved to http server
110 | # !tests/data
111 | .vscode
112 | .idea
113 | *.pkl
114 | *.pkl.json
115 | *.log.json
116 | work_dirs/
117 | logs/
118 |
119 | # Pytorch
120 | *.pth
121 | *.pt
122 |
123 |
124 | # Visualization
125 | *.mp4
126 | *.png
127 | *.gif
128 | *.jpg
129 | *.obj
130 | *.ply
131 | !demo/resources/*
132 |
133 | # Resources as exception
134 | !resources/*
135 |
136 | # Loaded/Saved data files
137 | *.npz
138 | *.npy
139 | *.pickle
140 |
141 | # MacOS
142 | *DS_Store*
143 | # git
144 | *.orig
145 |
146 | env.sh
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | S-Lab License 1.0
2 |
3 | Copyright 2023 S-Lab
4 |
5 | Redistribution and use for non-commercial purpose in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
6 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
7 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
8 | 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
9 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
10 | 4. In the event that redistribution and/or use for commercial purpose in source or binary forms, with or without modification is required, please contact the contributor(s) of the work.
11 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # ReMoDiffuse: Retrieval-Augmented Motion Diffusion Model
4 |
5 |
15 |
16 | 1S-Lab, Nanyang Technological University
17 | 2SenseTime Research
18 |
19 |
20 | +corresponding author
21 |
22 |
23 |
24 | ---
25 |
26 |
36 |
37 |
38 |
39 |
40 | >**Abstract:** 3D human motion generation is crucial for creative industry. Recent advances rely on generative models with domain knowledge for text-driven motion generation, leading to substantial progress in capturing common motions. However, the performance on more diverse motions remains unsatisfactory. In this work, we propose **ReMoDiffuse**, a diffusion-model-based motion generation framework that integrates a retrieval mechanism to refine the denoising process.
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 | >**Pipeline Overview:** ReMoDiffuse is a retrieval-augmented 3D human motion diffusion model. Benefiting from the extra knowledge in the retrieved samples, ReMoDiffuse is able to achieve high fidelity on the given prompts. It contains three core components: a) a **Hybrid Retrieval** database stores multi-modality features of each motion sequence. b) A semantics-modulated transformer stacks several identical decoder layers, each including a **Semantics-Modulated Attention (SMA)** layer and an FFN layer. The SMA layer adaptively absorbs knowledge from both the retrieved samples and the given prompts. c) A **Condition Mixture** technique is proposed to better mix the model's outputs under different combinations of conditions.
50 |
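A minimal sketch of the Condition Mixture idea (pseudocode under assumptions, not the repository's exact implementation; the coefficient names mirror `scale_func_cfg` in the ReMoDiffuse configs, and the blending form, including how the unconditioned output is weighted, is an assumption):

```python
# Illustrative sketch only. out_both / out_text / out_retr / out_none are the
# denoiser outputs under (text + retrieval), text-only, retrieval-only and
# unconditioned inputs; cfg mirrors scale_func_cfg (both_coef, text_coef, retr_coef).
def mix_conditions(out_both, out_text, out_retr, out_none, cfg):
    w_both, w_text, w_retr = cfg['both_coef'], cfg['text_coef'], cfg['retr_coef']
    w_none = 1.0 - w_both - w_text - w_retr  # assumed: weights sum to one
    return (w_both * out_both + w_text * out_text
            + w_retr * out_retr + w_none * out_none)
```
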
51 | ## Updates
52 |
53 | [09/2023] Add a [🤗Hugging Face Demo](https://huggingface.co/spaces/mingyuan/ReMoDiffuse)!
54 |
55 | [09/2023] Add a [Colab Demo](https://colab.research.google.com/drive/1jztE7c8js3P4YFbw5cGNPJAsCVrreTov?usp=sharing)!
56 |
57 | [09/2023] Release code for [ReMoDiffuse](https://mingyuan-zhang.github.io/projects/ReMoDiffuse.html) and [MotionDiffuse](https://mingyuan-zhang.github.io/projects/MotionDiffuse.html)
58 |
59 | ## Benchmark and Model Zoo
60 |
61 | #### Supported methods
62 |
63 | - [x] [MotionDiffuse](https://mingyuan-zhang.github.io/projects/MotionDiffuse.html)
64 | - [x] [MDM](https://guytevet.github.io/mdm-page/)
65 | - [x] [ReMoDiffuse](https://mingyuan-zhang.github.io/projects/ReMoDiffuse.html)
66 |
67 |
68 | ## Citation
69 |
70 | If you find our work useful for your research, please consider citing the paper:
71 |
72 | ```
73 | @article{zhang2023remodiffuse,
74 | title={ReMoDiffuse: Retrieval-Augmented Motion Diffusion Model},
75 | author={Zhang, Mingyuan and Guo, Xinying and Pan, Liang and Cai, Zhongang and Hong, Fangzhou and Li, Huirong and Yang, Lei and Liu, Ziwei},
76 | journal={arXiv preprint arXiv:2304.01116},
77 | year={2023}
78 | }
79 | @article{zhang2022motiondiffuse,
80 | title={MotionDiffuse: Text-Driven Human Motion Generation with Diffusion Model},
81 | author={Zhang, Mingyuan and Cai, Zhongang and Pan, Liang and Hong, Fangzhou and Guo, Xinying and Yang, Lei and Liu, Ziwei},
82 | journal={arXiv preprint arXiv:2208.15001},
83 | year={2022}
84 | }
85 | ```
86 |
87 | ## Installation
88 |
89 | ```shell
90 | # Create Conda Environment
91 | conda create -n mogen python=3.9 -y
92 | conda activate mogen
93 |
94 | # C++ Environment
95 | export PATH=/mnt/lustre/share/gcc/gcc-8.5.0/bin:$PATH
96 | export LD_LIBRARY_PATH=/mnt/lustre/share/gcc/gcc-8.5.0/lib:/mnt/lustre/share/gcc/gcc-8.5.0/lib64:/mnt/lustre/share/gcc/gmp-4.3.2/lib:/mnt/lustre/share/gcc/mpc-0.8.1/lib:/mnt/lustre/share/gcc/mpfr-2.4.2/lib:$LD_LIBRARY_PATH
97 |
98 | # Install Pytorch
99 | conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.3 -c pytorch -y
100 |
101 | # Install MMCV
102 | pip install "mmcv-full>=1.4.2,<=1.9.0" -f https://download.openmmlab.com/mmcv/dist/cu113/torch1.12.1/index.html
103 |
104 | # Install Pytorch3d
105 | conda install -c bottler nvidiacub -y
106 | conda install -c fvcore -c iopath -c conda-forge fvcore iopath -y
107 | conda install pytorch3d -c pytorch3d -y
108 |
109 | # Install other requirements
110 | pip install -r requirements.txt
111 | ```
112 |
113 | ## Data Preparation
114 |
115 | Download the data files from Google Drive [link](https://drive.google.com/drive/folders/13kwahiktQ2GMVKfVH3WT-VGAQ6JHbvUv?usp=sharing) or Baidu Netdisk [link](https://pan.baidu.com/s/1604jks-9PtBUtqCpQQmeEg) (access code: vprc). Unzip all files and arrange them in the following file structure:
116 |
117 | ```text
118 | ReMoDiffuse
119 | ├── mogen
120 | ├── tools
121 | ├── configs
122 | ├── logs
123 | │ ├── motiondiffuse
124 | │ ├── remodiffuse
125 | │ └── mdm
126 | └── data
127 | ├── database
128 | ├── datasets
129 | ├── evaluators
130 | └── glove
131 | ```
132 |
133 | ## Training
134 |
135 | ### Training with a single / multiple GPUs
136 |
137 | ```shell
138 | PYTHONPATH=".":$PYTHONPATH python tools/train.py ${CONFIG_FILE} ${WORK_DIR} --no-validate
139 | ```
140 |
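Example: train ReMoDiffuse on HumanML3D with a single GPU (the work directory is illustrative):

```shell
PYTHONPATH=".":$PYTHONPATH python tools/train.py \
    configs/remodiffuse/remodiffuse_t2m.py \
    logs/remodiffuse/remodiffuse_t2m \
    --no-validate
```
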
141 | **Note:** The provided config files are designed for training with 8 GPUs. If you want to train on a single GPU, you can reduce the number of epochs to one-fourth of the original, as in the sketch below.
142 |
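For example, [remodiffuse_t2m.py](configs/remodiffuse/remodiffuse_t2m.py) trains for 40 epochs by default, so a single-GPU run would set roughly:

```python
# single-GPU sketch: roughly one-fourth of the original 40 epochs
runner = dict(type='EpochBasedRunner', max_epochs=10)
```
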
143 | ### Training with Slurm
144 |
145 | ```shell
146 | ./tools/slurm_train.sh ${PARTITION} ${JOB_NAME} ${CONFIG_FILE} ${WORK_DIR} ${GPU_NUM} --no-validate
147 | ```
148 |
149 | Common optional arguments include:
150 | - `--resume-from ${CHECKPOINT_FILE}`: Resume from a previous checkpoint file.
151 | - `--no-validate`: Do not evaluate the checkpoint during training.
152 |
153 | Example: using 8 GPUs to train ReMoDiffuse on a Slurm cluster:
154 | ```shell
155 | ./tools/slurm_train.sh my_partition my_job configs/remodiffuse/remodiffuse_kit.py logs/remodiffuse_kit 8 --no-validate
156 | ```
157 |
158 | ## Evaluation
159 |
160 | ### Evaluate with a single GPU / multiple GPUs
161 |
162 | ```shell
163 | PYTHONPATH=".":$PYTHONPATH python tools/test.py ${CONFIG} --work-dir=${WORK_DIR} ${CHECKPOINT}
164 | ```
165 |
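Example (single GPU, using the HumanML3D config and the checkpoint layout shown in the Visualization section):

```shell
PYTHONPATH=".":$PYTHONPATH python tools/test.py \
    configs/remodiffuse/remodiffuse_t2m.py \
    --work-dir=logs/remodiffuse/remodiffuse_t2m \
    logs/remodiffuse/remodiffuse_t2m/latest.pth
```
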
166 | ### Evaluate with Slurm
167 |
168 | ```shell
169 | ./tools/slurm_test.sh ${PARTITION} ${JOB_NAME} ${CONFIG} ${WORK_DIR} ${CHECKPOINT}
170 | ```
171 | Example:
172 | ```shell
173 | ./tools/slurm_test.sh my_partition test_remodiffuse configs/remodiffuse/remodiffuse_kit.py logs/remodiffuse_kit logs/remodiffuse_kit/latest.pth
174 | ```
175 |
176 | **Note:** Running the full evaluation on the HumanML3D dataset is very slow. You can change `replication_times` in [human_ml3d_bs128.py](configs/_base_/datasets/human_ml3d_bs128.py) to `1` for a quick evaluation; see the sketch below.
177 |
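One convenient way to do this without editing the base file is a small override config (a sketch; the filename is hypothetical):

```python
# configs/remodiffuse/remodiffuse_t2m_quick_eval.py  (hypothetical helper config)
_base_ = ['remodiffuse_t2m.py']

# reduce the number of evaluation replications from 20 to 1 for a quick run
data = dict(test=dict(eval_cfg=dict(replication_times=1)))
```
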
178 | ## Visualization
179 |
180 | ```shell
181 | PYTHONPATH=".":$PYTHONPATH python tools/visualize.py ${CONFIG} ${CHECKPOINT} \
182 | --text ${TEXT} \
183 | --motion_length ${MOTION_LENGTH} \
184 | --out ${OUTPUT_ANIMATION_PATH} \
185 | --device cpu
186 | ```
187 |
188 | Example:
189 | ```shell
190 | PYTHONPATH=".":$PYTHONPATH python tools/visualize.py \
191 | configs/remodiffuse/remodiffuse_t2m.py \
192 | logs/remodiffuse/remodiffuse_t2m/latest.pth \
193 | --text "a person is running quickly" \
194 | --motion_length 120 \
195 | --out "test.gif" \
196 | --device cpu
197 | ```
198 |
199 | ## Acknowledgement
200 |
201 | This study is supported by the Ministry of Education, Singapore, under its MOE AcRF Tier 2 (MOE-T2EP20221-0012), NTU NAP, and under the RIE2020 Industry Alignment Fund – Industry Collaboration Projects (IAF-ICP) Funding Initiative, as well as cash and in-kind contribution from the industry partner(s).
202 |
203 | The visualization tool is developed on top of [Generating Diverse and Natural 3D Human Motions from Text](https://github.com/EricGuo5513/text-to-motion).
204 |
--------------------------------------------------------------------------------
/configs/_base_/datasets/human_ml3d_bs128.py:
--------------------------------------------------------------------------------
1 | # dataset settings
2 | data_keys = ['motion', 'motion_mask', 'motion_length', 'clip_feat']
3 | meta_keys = ['text', 'token']
4 | train_pipeline = [
5 | dict(
6 | type='Normalize',
7 | mean_path='data/datasets/human_ml3d/mean.npy',
8 | std_path='data/datasets/human_ml3d/std.npy'),
9 | dict(type='Crop', crop_size=196),
10 | dict(type='ToTensor', keys=data_keys),
11 | dict(type='Collect', keys=data_keys, meta_keys=meta_keys)
12 | ]
13 |
14 | data = dict(
15 | samples_per_gpu=128,
16 | workers_per_gpu=1,
17 | train=dict(
18 | type='RepeatDataset',
19 | dataset=dict(
20 | type='TextMotionDataset',
21 | dataset_name='human_ml3d',
22 | data_prefix='data',
23 | pipeline=train_pipeline,
24 | ann_file='train.txt',
25 | motion_dir='motions',
26 | text_dir='texts',
27 | token_dir='tokens',
28 | clip_feat_dir='clip_feats',
29 | ),
30 | times=200
31 | ),
32 | test=dict(
33 | type='TextMotionDataset',
34 | dataset_name='human_ml3d',
35 | data_prefix='data',
36 | pipeline=train_pipeline,
37 | ann_file='test.txt',
38 | motion_dir='motions',
39 | text_dir='texts',
40 | token_dir='tokens',
41 | clip_feat_dir='clip_feats',
42 | eval_cfg=dict(
43 | shuffle_indexes=True,
44 | replication_times=20,
45 | replication_reduction='statistics',
46 | text_encoder_name='human_ml3d',
47 | text_encoder_path='data/evaluators/human_ml3d/finest.tar',
48 | motion_encoder_name='human_ml3d',
49 | motion_encoder_path='data/evaluators/human_ml3d/finest.tar',
50 | metrics=[
51 | dict(type='R Precision', batch_size=32, top_k=3),
52 | dict(type='Matching Score', batch_size=32),
53 | dict(type='FID'),
54 | dict(type='Diversity', num_samples=300),
55 | dict(type='MultiModality', num_samples=100, num_repeats=30, num_picks=10)
56 | ]
57 | ),
58 | test_mode=True
59 | )
60 | )
--------------------------------------------------------------------------------
/configs/_base_/datasets/kit_ml_bs128.py:
--------------------------------------------------------------------------------
1 | # dataset settings
2 | data_keys = ['motion', 'motion_mask', 'motion_length', 'clip_feat']
3 | meta_keys = ['text', 'token']
4 | train_pipeline = [
5 | dict(type='Crop', crop_size=196),
6 | dict(
7 | type='Normalize',
8 | mean_path='data/datasets/kit_ml/mean.npy',
9 | std_path='data/datasets/kit_ml/std.npy'),
10 | dict(type='ToTensor', keys=data_keys),
11 | dict(type='Collect', keys=data_keys, meta_keys=meta_keys)
12 | ]
13 |
14 | data = dict(
15 | samples_per_gpu=128,
16 | workers_per_gpu=1,
17 | train=dict(
18 | type='RepeatDataset',
19 | dataset=dict(
20 | type='TextMotionDataset',
21 | dataset_name='kit_ml',
22 | data_prefix='data',
23 | pipeline=train_pipeline,
24 | ann_file='train.txt',
25 | motion_dir='motions',
26 | text_dir='texts',
27 | token_dir='tokens',
28 | clip_feat_dir='clip_feats',
29 | ),
30 | times=100
31 | ),
32 | test=dict(
33 | type='TextMotionDataset',
34 | dataset_name='kit_ml',
35 | data_prefix='data',
36 | pipeline=train_pipeline,
37 | ann_file='test.txt',
38 | motion_dir='motions',
39 | text_dir='texts',
40 | token_dir='tokens',
41 | clip_feat_dir='clip_feats',
42 | eval_cfg=dict(
43 | shuffle_indexes=True,
44 | replication_times=20,
45 | replication_reduction='statistics',
46 | text_encoder_name='kit_ml',
47 | text_encoder_path='data/evaluators/kit_ml/finest.tar',
48 | motion_encoder_name='kit_ml',
49 | motion_encoder_path='data/evaluators/kit_ml/finest.tar',
50 | metrics=[
51 | dict(type='R Precision', batch_size=32, top_k=3),
52 | dict(type='Matching Score', batch_size=32),
53 | dict(type='FID'),
54 | dict(type='Diversity', num_samples=300),
55 | dict(type='MultiModality', num_samples=50, num_repeats=30, num_picks=10)
56 | ]
57 | ),
58 | test_mode=True
59 | )
60 | )
--------------------------------------------------------------------------------
/configs/mdm/mdm_t2m_official.py:
--------------------------------------------------------------------------------
1 | _base_ = ['../_base_/datasets/human_ml3d_bs128.py']
2 |
3 | # checkpoint saving
4 | checkpoint_config = dict(interval=1)
5 |
6 | dist_params = dict(backend='nccl')
7 | log_level = 'INFO'
8 | load_from = None
9 | resume_from = None
10 | workflow = [('train', 1)]
11 |
12 | # optimizer
13 | optimizer = dict(type='Adam', lr=1e-4)
14 | optimizer_config = dict(grad_clip=None)
15 | # learning policy
16 | lr_config = dict(policy='step', step=[])
17 | runner = dict(type='EpochBasedRunner', max_epochs=50)
18 |
19 | log_config = dict(
20 | interval=50,
21 | hooks=[
22 | dict(type='TextLoggerHook'),
23 | # dict(type='TensorboardLoggerHook')
24 | ])
25 |
26 | input_feats = 263
27 | max_seq_len = 196
28 | latent_dim = 512
29 | time_embed_dim = 2048
30 | text_latent_dim = 256
31 | ff_size = 1024
32 | num_layers = 8
33 | num_heads = 4
34 | dropout = 0.1
35 | cond_mask_prob = 0.1
36 | # model settings
37 | model = dict(
38 | type='MotionDiffusion',
39 | model=dict(
40 | type='MDMTransformer',
41 | input_feats=input_feats,
42 | latent_dim=latent_dim,
43 | ff_size=ff_size,
44 | num_layers=num_layers,
45 | num_heads=num_heads,
46 | dropout=dropout,
47 | time_embed_dim=time_embed_dim,
48 | cond_mask_prob=cond_mask_prob,
49 | guide_scale=2.5,
50 | clip_version='ViT-B/32',
51 | use_official_ckpt=True
52 | ),
53 | loss_recon=dict(type='MSELoss', loss_weight=1, reduction='none'),
54 | diffusion_train=dict(
55 | beta_scheduler='cosine',
56 | diffusion_steps=1000,
57 | model_mean_type='start_x',
58 | model_var_type='fixed_small',
59 | ),
60 | diffusion_test=dict(
61 | beta_scheduler='cosine',
62 | diffusion_steps=1000,
63 | model_mean_type='start_x',
64 | model_var_type='fixed_small',
65 | ),
66 | inference_type='ddpm'
67 | )
--------------------------------------------------------------------------------
/configs/motiondiffuse/motiondiffuse_kit.py:
--------------------------------------------------------------------------------
1 | _base_ = ['../_base_/datasets/kit_ml_bs128.py']
2 |
3 | # checkpoint saving
4 | checkpoint_config = dict(interval=1)
5 |
6 | dist_params = dict(backend='nccl')
7 | log_level = 'INFO'
8 | load_from = None
9 | resume_from = None
10 | workflow = [('train', 1)]
11 |
12 | # optimizer
13 | optimizer = dict(type='Adam', lr=2e-4)
14 | optimizer_config = dict(grad_clip=None)
15 | # learning policy
16 | lr_config = dict(policy='step', step=[])
17 | runner = dict(type='EpochBasedRunner', max_epochs=50)
18 |
19 | log_config = dict(
20 | interval=50,
21 | hooks=[
22 | dict(type='TextLoggerHook'),
23 | # dict(type='TensorboardLoggerHook')
24 | ])
25 |
26 | input_feats = 251
27 | max_seq_len = 196
28 | latent_dim = 512
29 | time_embed_dim = 2048
30 | text_latent_dim = 256
31 | ff_size = 1024
32 | num_heads = 8
33 | dropout = 0
34 | # model settings
35 | model = dict(
36 | type='MotionDiffusion',
37 | model=dict(
38 | type='MotionDiffuseTransformer',
39 | input_feats=input_feats,
40 | max_seq_len=max_seq_len,
41 | latent_dim=latent_dim,
42 | time_embed_dim=time_embed_dim,
43 | num_layers=8,
44 | sa_block_cfg=dict(
45 | type='EfficientSelfAttention',
46 | latent_dim=latent_dim,
47 | num_heads=num_heads,
48 | dropout=dropout,
49 | time_embed_dim=time_embed_dim
50 | ),
51 | ca_block_cfg=dict(
52 | type='EfficientCrossAttention',
53 | latent_dim=latent_dim,
54 | text_latent_dim=text_latent_dim,
55 | num_heads=num_heads,
56 | dropout=dropout,
57 | time_embed_dim=time_embed_dim
58 | ),
59 | ffn_cfg=dict(
60 | latent_dim=latent_dim,
61 | ffn_dim=ff_size,
62 | dropout=dropout,
63 | time_embed_dim=time_embed_dim
64 | ),
65 | text_encoder=dict(
66 | pretrained_model='clip',
67 | latent_dim=text_latent_dim,
68 | num_layers=4,
69 | num_heads=4,
70 | ff_size=2048,
71 | dropout=dropout,
72 | use_text_proj=True
73 | )
74 | ),
75 | loss_recon=dict(type='MSELoss', loss_weight=1, reduction='none'),
76 | diffusion_train=dict(
77 | beta_scheduler='linear',
78 | diffusion_steps=1000,
79 | model_mean_type='epsilon',
80 | model_var_type='fixed_small',
81 | ),
82 | diffusion_test=dict(
83 | beta_scheduler='linear',
84 | diffusion_steps=1000,
85 | model_mean_type='epsilon',
86 | model_var_type='fixed_small',
87 | ),
88 | inference_type='ddpm'
89 | )
90 |
--------------------------------------------------------------------------------
/configs/motiondiffuse/motiondiffuse_t2m.py:
--------------------------------------------------------------------------------
1 | _base_ = ['../_base_/datasets/human_ml3d_bs128.py']
2 |
3 | # checkpoint saving
4 | checkpoint_config = dict(interval=1)
5 |
6 | dist_params = dict(backend='nccl')
7 | log_level = 'INFO'
8 | load_from = None
9 | resume_from = None
10 | workflow = [('train', 1)]
11 |
12 | # optimizer
13 | optimizer = dict(type='Adam', lr=2e-4)
14 | optimizer_config = dict(grad_clip=None)
15 | # learning policy
16 | lr_config = dict(policy='step', step=[])
17 | runner = dict(type='EpochBasedRunner', max_epochs=50)
18 |
19 | log_config = dict(
20 | interval=50,
21 | hooks=[
22 | dict(type='TextLoggerHook'),
23 | # dict(type='TensorboardLoggerHook')
24 | ])
25 |
26 | input_feats = 263
27 | max_seq_len = 196
28 | latent_dim = 512
29 | time_embed_dim = 2048
30 | text_latent_dim = 256
31 | ff_size = 1024
32 | num_heads = 8
33 | dropout = 0
34 | # model settings
35 | model = dict(
36 | type='MotionDiffusion',
37 | model=dict(
38 | type='MotionDiffuseTransformer',
39 | input_feats=input_feats,
40 | max_seq_len=max_seq_len,
41 | latent_dim=latent_dim,
42 | time_embed_dim=time_embed_dim,
43 | num_layers=8,
44 | sa_block_cfg=dict(
45 | type='EfficientSelfAttention',
46 | latent_dim=latent_dim,
47 | num_heads=num_heads,
48 | dropout=dropout,
49 | time_embed_dim=time_embed_dim
50 | ),
51 | ca_block_cfg=dict(
52 | type='EfficientCrossAttention',
53 | latent_dim=latent_dim,
54 | text_latent_dim=text_latent_dim,
55 | num_heads=num_heads,
56 | dropout=dropout,
57 | time_embed_dim=time_embed_dim
58 | ),
59 | ffn_cfg=dict(
60 | latent_dim=latent_dim,
61 | ffn_dim=ff_size,
62 | dropout=dropout,
63 | time_embed_dim=time_embed_dim
64 | ),
65 | text_encoder=dict(
66 | pretrained_model='clip',
67 | latent_dim=text_latent_dim,
68 | num_layers=4,
69 | num_heads=4,
70 | ff_size=2048,
71 | dropout=dropout,
72 | use_text_proj=True
73 | )
74 | ),
75 | loss_recon=dict(type='MSELoss', loss_weight=1, reduction='none'),
76 | diffusion_train=dict(
77 | beta_scheduler='linear',
78 | diffusion_steps=1000,
79 | model_mean_type='epsilon',
80 | model_var_type='fixed_small',
81 | ),
82 | diffusion_test=dict(
83 | beta_scheduler='linear',
84 | diffusion_steps=1000,
85 | model_mean_type='epsilon',
86 | model_var_type='fixed_small',
87 | ),
88 | inference_type='ddpm'
89 | )
90 | data = dict(samples_per_gpu=128)
--------------------------------------------------------------------------------
/configs/remodiffuse/remodiffuse_kit.py:
--------------------------------------------------------------------------------
1 | _base_ = ['../_base_/datasets/kit_ml_bs128.py']
2 |
3 | # checkpoint saving
4 | checkpoint_config = dict(interval=1)
5 |
6 | dist_params = dict(backend='nccl')
7 | log_level = 'INFO'
8 | load_from = None
9 | resume_from = None
10 | workflow = [('train', 1)]
11 |
12 | # optimizer
13 | optimizer = dict(type='Adam', lr=2e-4)
14 | optimizer_config = dict(grad_clip=None)
15 | # learning policy
16 | lr_config = dict(policy='CosineAnnealing', min_lr_ratio=2e-5, by_epoch=False)
17 | runner = dict(type='EpochBasedRunner', max_epochs=20)
18 |
19 | log_config = dict(
20 | interval=50,
21 | hooks=[
22 | dict(type='TextLoggerHook'),
23 | # dict(type='TensorboardLoggerHook')
24 | ])
25 |
26 | input_feats = 251
27 | max_seq_len = 196
28 | latent_dim = 512
29 | time_embed_dim = 2048
30 | text_latent_dim = 256
31 | ff_size = 1024
32 | num_heads = 8
33 | dropout = 0
34 |
35 | # model settings
36 | model = dict(
37 | type='MotionDiffusion',
38 | model=dict(
39 | type='ReMoDiffuseTransformer',
40 | input_feats=input_feats,
41 | max_seq_len=max_seq_len,
42 | latent_dim=latent_dim,
43 | time_embed_dim=time_embed_dim,
44 | num_layers=4,
45 | ca_block_cfg=dict(
46 | type='SemanticsModulatedAttention',
47 | latent_dim=latent_dim,
48 | text_latent_dim=text_latent_dim,
49 | num_heads=num_heads,
50 | dropout=dropout,
51 | time_embed_dim=time_embed_dim
52 | ),
53 | ffn_cfg=dict(
54 | latent_dim=latent_dim,
55 | ffn_dim=ff_size,
56 | dropout=dropout,
57 | time_embed_dim=time_embed_dim
58 | ),
59 | text_encoder=dict(
60 | pretrained_model='clip',
61 | latent_dim=text_latent_dim,
62 | num_layers=2,
63 | ff_size=2048,
64 | dropout=dropout,
65 | use_text_proj=False
66 | ),
67 | retrieval_cfg=dict(
68 | num_retrieval=2,
69 | stride=4,
70 | num_layers=2,
71 | num_motion_layers=2,
72 | kinematic_coef=0.1,
73 | topk=2,
74 | retrieval_file='data/database/kit_text_train.npz',
75 | latent_dim=latent_dim,
76 | output_dim=latent_dim,
77 | max_seq_len=max_seq_len,
78 | num_heads=num_heads,
79 | ff_size=ff_size,
80 | dropout=dropout,
81 | ffn_cfg=dict(
82 | latent_dim=latent_dim,
83 | ffn_dim=ff_size,
84 | dropout=dropout,
85 | ),
86 | sa_block_cfg=dict(
87 | type='EfficientSelfAttention',
88 | latent_dim=latent_dim,
89 | num_heads=num_heads,
90 | dropout=dropout
91 | ),
92 | ),
93 | scale_func_cfg=dict(
94 | coarse_scale=4.0,
95 | both_coef=0.78123,
96 | text_coef=0.39284,
97 | retr_coef=-0.12475
98 | )
99 | ),
100 | loss_recon=dict(type='MSELoss', loss_weight=1, reduction='none'),
101 | diffusion_train=dict(
102 | beta_scheduler='linear',
103 | diffusion_steps=1000,
104 | model_mean_type='start_x',
105 | model_var_type='fixed_large',
106 | ),
107 | diffusion_test=dict(
108 | beta_scheduler='linear',
109 | diffusion_steps=1000,
110 | model_mean_type='start_x',
111 | model_var_type='fixed_large',
112 | respace='15,15,8,6,6',
113 | ),
114 | inference_type='ddim'
115 | )
--------------------------------------------------------------------------------
/configs/remodiffuse/remodiffuse_t2m.py:
--------------------------------------------------------------------------------
1 | _base_ = ['../_base_/datasets/human_ml3d_bs128.py']
2 |
3 | # checkpoint saving
4 | checkpoint_config = dict(interval=1)
5 |
6 | dist_params = dict(backend='nccl')
7 | log_level = 'INFO'
8 | load_from = None
9 | resume_from = None
10 | workflow = [('train', 1)]
11 |
12 | # optimizer
13 | optimizer = dict(type='Adam', lr=2e-4)
14 | optimizer_config = dict(grad_clip=None)
15 | # learning policy
16 | lr_config = dict(policy='CosineAnnealing', min_lr_ratio=2e-5, by_epoch=False)
17 | runner = dict(type='EpochBasedRunner', max_epochs=40)
18 |
19 | log_config = dict(
20 | interval=50,
21 | hooks=[
22 | dict(type='TextLoggerHook'),
23 | # dict(type='TensorboardLoggerHook')
24 | ])
25 |
26 | input_feats = 263
27 | max_seq_len = 196
28 | latent_dim = 512
29 | time_embed_dim = 2048
30 | text_latent_dim = 256
31 | ff_size = 1024
32 | num_heads = 8
33 | dropout = 0
34 |
35 | # model settings
36 | model = dict(
37 | type='MotionDiffusion',
38 | model=dict(
39 | type='ReMoDiffuseTransformer',
40 | input_feats=input_feats,
41 | max_seq_len=max_seq_len,
42 | latent_dim=latent_dim,
43 | time_embed_dim=time_embed_dim,
44 | num_layers=4,
45 | ca_block_cfg=dict(
46 | type='SemanticsModulatedAttention',
47 | latent_dim=latent_dim,
48 | text_latent_dim=text_latent_dim,
49 | num_heads=num_heads,
50 | dropout=dropout,
51 | time_embed_dim=time_embed_dim
52 | ),
53 | ffn_cfg=dict(
54 | latent_dim=latent_dim,
55 | ffn_dim=ff_size,
56 | dropout=dropout,
57 | time_embed_dim=time_embed_dim
58 | ),
59 | text_encoder=dict(
60 | pretrained_model='clip',
61 | latent_dim=text_latent_dim,
62 | num_layers=2,
63 | ff_size=2048,
64 | dropout=dropout,
65 | use_text_proj=False
66 | ),
67 | retrieval_cfg=dict(
68 | num_retrieval=2,
69 | stride=4,
70 | num_layers=2,
71 | num_motion_layers=2,
72 | kinematic_coef=0.1,
73 | topk=2,
74 | retrieval_file='data/database/t2m_text_train.npz',
75 | latent_dim=latent_dim,
76 | output_dim=latent_dim,
77 | max_seq_len=max_seq_len,
78 | num_heads=num_heads,
79 | ff_size=ff_size,
80 | dropout=dropout,
81 | ffn_cfg=dict(
82 | latent_dim=latent_dim,
83 | ffn_dim=ff_size,
84 | dropout=dropout,
85 | ),
86 | sa_block_cfg=dict(
87 | type='EfficientSelfAttention',
88 | latent_dim=latent_dim,
89 | num_heads=num_heads,
90 | dropout=dropout
91 | ),
92 | ),
93 | scale_func_cfg=dict(
94 | coarse_scale=6.5,
95 | both_coef=0.52351,
96 | text_coef=-0.28419,
97 | retr_coef=2.39872
98 | )
99 | ),
100 | loss_recon=dict(type='MSELoss', loss_weight=1, reduction='none'),
101 | diffusion_train=dict(
102 | beta_scheduler='linear',
103 | diffusion_steps=1000,
104 | model_mean_type='start_x',
105 | model_var_type='fixed_large',
106 | ),
107 | diffusion_test=dict(
108 | beta_scheduler='linear',
109 | diffusion_steps=1000,
110 | model_mean_type='start_x',
111 | model_var_type='fixed_large',
112 | respace='15,15,8,6,6',
113 | ),
114 | inference_type='ddim'
115 | )
--------------------------------------------------------------------------------
/imgs/pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mingyuan-zhang/ReMoDiffuse/d81c83ddbf72a989dc334cd48cb7d46fb6feba63/imgs/pipeline.png
--------------------------------------------------------------------------------
/imgs/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mingyuan-zhang/ReMoDiffuse/d81c83ddbf72a989dc334cd48cb7d46fb6feba63/imgs/teaser.png
--------------------------------------------------------------------------------
/mogen/__init__.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 | import mmcv
4 | from packaging.version import parse
5 |
6 | from .version import __version__
7 |
8 |
9 | def digit_version(version_str: str, length: int = 4):
10 | """Convert a version string into a tuple of integers.
11 | This method is usually used for comparing two versions. For pre-release
12 | versions: alpha < beta < rc.
13 | Args:
14 | version_str (str): The version string.
15 | length (int): The maximum number of version levels. Default: 4.
16 | Returns:
17 | tuple[int]: The version info in digits (integers).
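Example (illustrative):
    digit_version('1.12.1')   -> (1, 12, 1, 0, 0, 0)
    digit_version('1.4.2rc1') -> (1, 4, 2, 0, -1, 1)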
18 | """
19 | version = parse(version_str)
20 | assert version.release, f'failed to parse version {version_str}'
21 | release = list(version.release)
22 | release = release[:length]
23 | if len(release) < length:
24 | release = release + [0] * (length - len(release))
25 | if version.is_prerelease:
26 | mapping = {'a': -3, 'b': -2, 'rc': -1}
27 | val = -4
28 | # version.pre can be None
29 | if version.pre:
30 | if version.pre[0] not in mapping:
31 | warnings.warn(f'unknown prerelease version {version.pre[0]}, '
32 | 'version checking may go wrong')
33 | else:
34 | val = mapping[version.pre[0]]
35 | release.extend([val, version.pre[-1]])
36 | else:
37 | release.extend([val, 0])
38 |
39 | elif version.is_postrelease:
40 | release.extend([1, version.post])
41 | else:
42 | release.extend([0, 0])
43 | return tuple(release)
44 |
45 |
46 | mmcv_minimum_version = '1.4.2'
47 | mmcv_maximum_version = '1.9.0'
48 | mmcv_version = digit_version(mmcv.__version__)
49 |
50 |
51 | assert (mmcv_version >= digit_version(mmcv_minimum_version)
52 | and mmcv_version <= digit_version(mmcv_maximum_version)), \
53 | f'MMCV=={mmcv.__version__} is used but incompatible. ' \
54 | f'Please install mmcv>={mmcv_minimum_version}, <={mmcv_maximum_version}.'
55 |
56 | __all__ = ['__version__', 'digit_version']
--------------------------------------------------------------------------------
/mogen/apis/__init__.py:
--------------------------------------------------------------------------------
1 | from mogen.apis import test, train
2 | from mogen.apis.test import (
3 | collect_results_cpu,
4 | collect_results_gpu,
5 | multi_gpu_test,
6 | single_gpu_test,
7 | )
8 | from mogen.apis.train import set_random_seed, train_model
9 |
10 | __all__ = [
11 | 'collect_results_cpu', 'collect_results_gpu', 'multi_gpu_test',
12 | 'single_gpu_test', 'set_random_seed', 'train_model'
13 | ]
--------------------------------------------------------------------------------
/mogen/apis/test.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | import pickle
3 | import shutil
4 | import tempfile
5 | import time
6 |
7 | import mmcv
8 | import torch
9 | import torch.distributed as dist
10 | from mmcv.runner import get_dist_info
11 |
12 |
13 | def single_gpu_test(model, data_loader):
14 | """Test with single gpu."""
15 | model.eval()
16 | results = []
17 | dataset = data_loader.dataset
18 | prog_bar = mmcv.ProgressBar(len(dataset))
19 | for i, data in enumerate(data_loader):
20 | with torch.no_grad():
21 | result = model(return_loss=False, **data)
22 |
23 | batch_size = len(result)
24 | if isinstance(result, list):
25 | results.extend(result)
26 | else:
27 | results.append(result)
28 |
29 | batch_size = data['motion'].size(0)
30 | for _ in range(batch_size):
31 | prog_bar.update()
32 | return results
33 |
34 |
35 | def multi_gpu_test(model, data_loader, tmpdir=None, gpu_collect=False):
36 | """Test model with multiple gpus.
37 | This method tests model with multiple gpus and collects the results
38 | under two different modes: gpu and cpu modes. By setting 'gpu_collect=True'
39 | it encodes results to gpu tensors and use gpu communication for results
40 | collection. On cpu mode it saves the results on different gpus to 'tmpdir'
41 | and collects them by the rank 0 worker.
42 | Args:
43 | model (nn.Module): Model to be tested.
44 | data_loader (nn.Dataloader): Pytorch data loader.
45 | tmpdir (str): Path of directory to save the temporary results from
46 | different gpus under cpu mode.
47 | gpu_collect (bool): Option to use either gpu or cpu to collect results.
48 | Returns:
49 | list: The prediction results.
50 | """
51 | model.eval()
52 | results = []
53 | dataset = data_loader.dataset
54 | rank, world_size = get_dist_info()
55 | if rank == 0:
56 | # Check if tmpdir is valid for cpu_collect
57 | if (not gpu_collect) and (tmpdir is not None and osp.exists(tmpdir)):
58 | raise OSError((f'The tmpdir {tmpdir} already exists.',
59 | ' Since tmpdir will be deleted after testing,',
60 | ' please make sure you specify an empty one.'))
61 | prog_bar = mmcv.ProgressBar(len(dataset))
62 | time.sleep(2) # This line can prevent deadlock problem in some cases.
63 | for i, data in enumerate(data_loader):
64 | with torch.no_grad():
65 | result = model(return_loss=False, **data)
66 | if isinstance(result, list):
67 | results.extend(result)
68 | else:
69 | results.append(result)
70 |
71 | if rank == 0:
72 | batch_size = data['motion'].size(0)
73 | for _ in range(batch_size * world_size):
74 | prog_bar.update()
75 |
76 | # collect results from all ranks
77 | if gpu_collect:
78 | results = collect_results_gpu(results, len(dataset))
79 | else:
80 | results = collect_results_cpu(results, len(dataset), tmpdir)
81 | return results
82 |
83 |
84 | def collect_results_cpu(result_part, size, tmpdir=None):
85 | """Collect results in cpu."""
86 | rank, world_size = get_dist_info()
87 | # create a tmp dir if it is not specified
88 | if tmpdir is None:
89 | MAX_LEN = 512
90 | # 32 is whitespace
91 | dir_tensor = torch.full((MAX_LEN, ),
92 | 32,
93 | dtype=torch.uint8,
94 | device='cuda')
95 | if rank == 0:
96 | mmcv.mkdir_or_exist('.dist_test')
97 | tmpdir = tempfile.mkdtemp(dir='.dist_test')
98 | tmpdir = torch.tensor(
99 | bytearray(tmpdir.encode()), dtype=torch.uint8, device='cuda')
100 | dir_tensor[:len(tmpdir)] = tmpdir
101 | dist.broadcast(dir_tensor, 0)
102 | tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip()
103 | else:
104 | mmcv.mkdir_or_exist(tmpdir)
105 | # dump the part result to the dir
106 | mmcv.dump(result_part, osp.join(tmpdir, f'part_{rank}.pkl'))
107 | dist.barrier()
108 | # collect all parts
109 | if rank != 0:
110 | return None
111 | else:
112 | # load results of all parts from tmp dir
113 | part_list = []
114 | for i in range(world_size):
115 | part_file = osp.join(tmpdir, f'part_{i}.pkl')
116 | part_result = mmcv.load(part_file)
117 | part_list.append(part_result)
118 | # sort the results
119 | ordered_results = []
120 | for res in zip(*part_list):
121 | ordered_results.extend(list(res))
122 | # the dataloader may pad some samples
123 | ordered_results = ordered_results[:size]
124 | # remove tmp dir
125 | shutil.rmtree(tmpdir)
126 | return ordered_results
127 |
128 |
129 | def collect_results_gpu(result_part, size):
130 | """Collect results in gpu."""
131 | rank, world_size = get_dist_info()
132 | # dump result part to tensor with pickle
133 | part_tensor = torch.tensor(
134 | bytearray(pickle.dumps(result_part)), dtype=torch.uint8, device='cuda')
135 | # gather all result part tensor shape
136 | shape_tensor = torch.tensor(part_tensor.shape, device='cuda')
137 | shape_list = [shape_tensor.clone() for _ in range(world_size)]
138 | dist.all_gather(shape_list, shape_tensor)
139 | # padding result part tensor to max length
140 | shape_max = torch.tensor(shape_list).max()
141 | part_send = torch.zeros(shape_max, dtype=torch.uint8, device='cuda')
142 | part_send[:shape_tensor[0]] = part_tensor
143 | part_recv_list = [
144 | part_tensor.new_zeros(shape_max) for _ in range(world_size)
145 | ]
146 | # gather all result part
147 | dist.all_gather(part_recv_list, part_send)
148 |
149 | if rank == 0:
150 | part_list = []
151 | for recv, shape in zip(part_recv_list, shape_list):
152 | part_result = pickle.loads(recv[:shape[0]].cpu().numpy().tobytes())
153 | part_list.append(part_result)
154 | # sort the results
155 | ordered_results = []
156 | for res in zip(*part_list):
157 | ordered_results.extend(list(res))
158 | # the dataloader may pad some samples
159 | ordered_results = ordered_results[:size]
160 | return ordered_results
--------------------------------------------------------------------------------
/mogen/apis/train.py:
--------------------------------------------------------------------------------
1 | import random
2 | import warnings
3 |
4 | import numpy as np
5 | import torch
6 | from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
7 | from mmcv.runner import (
8 | DistSamplerSeedHook,
9 | Fp16OptimizerHook,
10 | OptimizerHook,
11 | build_runner,
12 | )
13 |
14 | from mogen.core.distributed_wrapper import DistributedDataParallelWrapper
15 | from mogen.core.evaluation import DistEvalHook, EvalHook
16 | from mogen.core.optimizer import build_optimizers
17 | from mogen.datasets import build_dataloader, build_dataset
18 | from mogen.utils import get_root_logger
19 |
20 |
21 | def set_random_seed(seed, deterministic=False):
22 | """Set random seed.
23 | Args:
24 | seed (int): Seed to be used.
25 | deterministic (bool): Whether to set the deterministic option for
26 | CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
27 | to True and `torch.backends.cudnn.benchmark` to False.
28 | Default: False.
29 | """
30 | random.seed(seed)
31 | np.random.seed(seed)
32 | torch.manual_seed(seed)
33 | torch.cuda.manual_seed_all(seed)
34 | if deterministic:
35 | torch.backends.cudnn.deterministic = True
36 | torch.backends.cudnn.benchmark = False
37 |
38 |
39 | def train_model(model,
40 | dataset,
41 | cfg,
42 | distributed=False,
43 | validate=False,
44 | timestamp=None,
45 | device='cuda',
46 | meta=None):
47 | """Main api for training model."""
48 | logger = get_root_logger(cfg.log_level)
49 |
50 | # prepare data loaders
51 | dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
52 |
53 | data_loaders = [
54 | build_dataloader(
55 | ds,
56 | cfg.data.samples_per_gpu,
57 | cfg.data.workers_per_gpu,
58 | # cfg.gpus will be ignored if distributed
59 | num_gpus=len(cfg.gpu_ids),
60 | dist=distributed,
61 | round_up=True,
62 | seed=cfg.seed) for ds in dataset
63 | ]
64 |
65 | # determine whether to use adversarial training or not
66 | use_adverserial_train = cfg.get('use_adversarial_train', False)
67 |
68 | # put model on gpus
69 | if distributed:
70 | find_unused_parameters = cfg.get('find_unused_parameters', True)
71 | # Sets the `find_unused_parameters` parameter in
72 | # torch.nn.parallel.DistributedDataParallel
73 | if use_adverserial_train:
74 | # Use DistributedDataParallelWrapper for adversarial training
75 | model = DistributedDataParallelWrapper(
76 | model,
77 | device_ids=[torch.cuda.current_device()],
78 | broadcast_buffers=False,
79 | find_unused_parameters=find_unused_parameters)
80 | else:
81 | model = MMDistributedDataParallel(
82 | model.cuda(),
83 | device_ids=[torch.cuda.current_device()],
84 | broadcast_buffers=False,
85 | find_unused_parameters=find_unused_parameters)
86 | else:
87 | if device == 'cuda':
88 | model = MMDataParallel(
89 | model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)
90 | elif device == 'cpu':
91 | model = model.cpu()
92 | else:
93 | raise ValueError(F'unsupported device name {device}.')
94 |
95 | # build runner
96 | optimizer = build_optimizers(model, cfg.optimizer)
97 |
98 | if cfg.get('runner') is None:
99 | cfg.runner = {
100 | 'type': 'EpochBasedRunner',
101 | 'max_epochs': cfg.total_epochs
102 | }
103 | warnings.warn(
104 | 'config is now expected to have a `runner` section, '
105 | 'please set `runner` in your config.', UserWarning)
106 |
107 | runner = build_runner(
108 | cfg.runner,
109 | default_args=dict(
110 | model=model,
111 | batch_processor=None,
112 | optimizer=optimizer,
113 | work_dir=cfg.work_dir,
114 | logger=logger,
115 | meta=meta))
116 |
117 | # an ugly workaround to make the .log and .log.json filenames the same
118 | runner.timestamp = timestamp
119 |
120 | if use_adverserial_train:
121 | # The optimizer step process is included in the train_step function
122 | # of the model, so the runner should NOT include optimizer hook.
123 | optimizer_config = None
124 | else:
125 | # fp16 setting
126 | fp16_cfg = cfg.get('fp16', None)
127 | if fp16_cfg is not None:
128 | optimizer_config = Fp16OptimizerHook(
129 | **cfg.optimizer_config, **fp16_cfg, distributed=distributed)
130 | elif distributed and 'type' not in cfg.optimizer_config:
131 | optimizer_config = OptimizerHook(**cfg.optimizer_config)
132 | else:
133 | optimizer_config = cfg.optimizer_config
134 |
135 | # register hooks
136 | runner.register_training_hooks(
137 | cfg.lr_config,
138 | optimizer_config,
139 | cfg.checkpoint_config,
140 | cfg.log_config,
141 | cfg.get('momentum_config', None),
142 | custom_hooks_config=cfg.get('custom_hooks', None))
143 | if distributed:
144 | runner.register_hook(DistSamplerSeedHook())
145 |
146 | # register eval hooks
147 | if validate:
148 | val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
149 | val_dataloader = build_dataloader(
150 | val_dataset,
151 | samples_per_gpu=cfg.data.samples_per_gpu,
152 | workers_per_gpu=cfg.data.workers_per_gpu,
153 | dist=distributed,
154 | shuffle=False,
155 | round_up=True)
156 | eval_cfg = cfg.get('evaluation', {})
157 | eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner'
158 | eval_hook = DistEvalHook if distributed else EvalHook
159 | runner.register_hook(eval_hook(val_dataloader, **eval_cfg))
160 |
161 | if cfg.resume_from:
162 | runner.resume(cfg.resume_from)
163 | elif cfg.load_from:
164 | runner.load_checkpoint(cfg.load_from)
165 | runner.run(data_loaders, cfg.workflow)
--------------------------------------------------------------------------------
/mogen/core/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mingyuan-zhang/ReMoDiffuse/d81c83ddbf72a989dc334cd48cb7d46fb6feba63/mogen/core/__init__.py
--------------------------------------------------------------------------------
/mogen/core/distributed_wrapper.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import torch
3 | import torch.nn as nn
4 | from mmcv.parallel import MODULE_WRAPPERS, MMDistributedDataParallel
5 | from mmcv.parallel.scatter_gather import scatter_kwargs
6 | from torch.cuda._utils import _get_device_index
7 |
8 |
9 | @MODULE_WRAPPERS.register_module()
10 | class DistributedDataParallelWrapper(nn.Module):
11 | """A DistributedDataParallel wrapper for models in 3D mesh estimation task.
12 |
13 | In 3D mesh estimation task, there is a need to wrap different modules in
14 | the models with separate DistributedDataParallel. Otherwise, it will cause
15 | errors for GAN training.
16 | More specifically, a GAN model usually has two sub-modules:
17 | generator and discriminator. If we wrap both of them in one
18 | standard DistributedDataParallel, it will cause errors during training,
19 | because when we update the parameters of the generator (or discriminator),
20 | the parameters of the discriminator (or generator) are not updated, which is
21 | not allowed for DistributedDataParallel.
22 | So we design this wrapper to separately wrap DistributedDataParallel
23 | for generator and discriminator.
24 | In this wrapper, we perform two operations:
25 | 1. Wrap the modules in the models with separate MMDistributedDataParallel.
26 | Note that only modules with parameters will be wrapped.
27 | 2. Do scatter operation for 'forward', 'train_step' and 'val_step'.
28 | Note that the arguments of this wrapper are the same as those in
29 | `torch.nn.parallel.distributed.DistributedDataParallel`.
30 | Args:
31 | module (nn.Module): Module that needs to be wrapped.
32 | device_ids (list[int | `torch.device`]): Same as that in
33 | `torch.nn.parallel.distributed.DistributedDataParallel`.
34 | dim (int, optional): Same as that in the official scatter function in
35 | pytorch. Defaults to 0.
36 | broadcast_buffers (bool): Same as that in
37 | `torch.nn.parallel.distributed.DistributedDataParallel`.
38 | Defaults to False.
39 | find_unused_parameters (bool, optional): Same as that in
40 | `torch.nn.parallel.distributed.DistributedDataParallel`.
41 | Traverse the autograd graph of all tensors contained in returned
42 | value of the wrapped module’s forward function. Defaults to False.
43 | kwargs (dict): Other arguments used in
44 | `torch.nn.parallel.distributed.DistributedDataParallel`.
45 | """
46 |
47 | def __init__(self,
48 | module,
49 | device_ids,
50 | dim=0,
51 | broadcast_buffers=False,
52 | find_unused_parameters=False,
53 | **kwargs):
54 | super().__init__()
55 | assert len(device_ids) == 1, (
56 | 'Currently, DistributedDataParallelWrapper only supports one '
57 | 'single CUDA device for each process. '
58 | f'The length of device_ids must be 1, but got {len(device_ids)}.')
59 | self.module = module
60 | self.dim = dim
61 | self.to_ddp(
62 | device_ids=device_ids,
63 | dim=dim,
64 | broadcast_buffers=broadcast_buffers,
65 | find_unused_parameters=find_unused_parameters,
66 | **kwargs)
67 | self.output_device = _get_device_index(device_ids[0], True)
68 |
69 | def to_ddp(self, device_ids, dim, broadcast_buffers,
70 | find_unused_parameters, **kwargs):
71 | """Wrap models with separate MMDistributedDataParallel.
72 |
73 | It only wraps the modules with parameters.
74 | """
75 | for name, module in self.module._modules.items():
76 | if next(module.parameters(), None) is None:
77 | module = module.cuda()
78 | elif all(not p.requires_grad for p in module.parameters()):
79 | module = module.cuda()
80 | else:
81 | module = MMDistributedDataParallel(
82 | module.cuda(),
83 | device_ids=device_ids,
84 | dim=dim,
85 | broadcast_buffers=broadcast_buffers,
86 | find_unused_parameters=find_unused_parameters,
87 | **kwargs)
88 | self.module._modules[name] = module
89 |
90 | def scatter(self, inputs, kwargs, device_ids):
91 | """Scatter function.
92 |
93 | Args:
94 | inputs (Tensor): Input Tensor.
95 | kwargs (dict): Args for
96 | ``mmcv.parallel.scatter_gather.scatter_kwargs``.
97 | device_ids (int): Device id.
98 | """
99 | return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
100 |
101 | def forward(self, *inputs, **kwargs):
102 | """Forward function.
103 |
104 | Args:
105 | inputs (tuple): Input data.
106 | kwargs (dict): Args for
107 | ``mmcv.parallel.scatter_gather.scatter_kwargs``.
108 | """
109 | inputs, kwargs = self.scatter(inputs, kwargs,
110 | [torch.cuda.current_device()])
111 | return self.module(*inputs[0], **kwargs[0])
112 |
113 | def train_step(self, *inputs, **kwargs):
114 | """Train step function.
115 |
116 | Args:
117 | inputs (Tensor): Input Tensor.
118 | kwargs (dict): Args for
119 | ``mmcv.parallel.scatter_gather.scatter_kwargs``.
120 | """
121 | inputs, kwargs = self.scatter(inputs, kwargs,
122 | [torch.cuda.current_device()])
123 | output = self.module.train_step(*inputs[0], **kwargs[0])
124 | return output
125 |
126 | def val_step(self, *inputs, **kwargs):
127 | """Validation step function.
128 |
129 | Args:
130 | inputs (tuple): Input data.
131 | kwargs (dict): Args for ``scatter_kwargs``.
132 | """
133 | inputs, kwargs = self.scatter(inputs, kwargs,
134 | [torch.cuda.current_device()])
135 | output = self.module.val_step(*inputs[0], **kwargs[0])
136 | return output
--------------------------------------------------------------------------------
/mogen/core/evaluation/__init__.py:
--------------------------------------------------------------------------------
1 | from mogen.core.evaluation.eval_hooks import DistEvalHook, EvalHook
2 | from mogen.core.evaluation.builder import build_evaluator
3 |
4 | __all__ = ["DistEvalHook", "EvalHook", "build_evaluator"]
--------------------------------------------------------------------------------
/mogen/core/evaluation/builder.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import numpy as np
3 | from mmcv.utils import Registry
4 | from .evaluators.precision_evaluator import PrecisionEvaluator
5 | from .evaluators.matching_score_evaluator import MatchingScoreEvaluator
6 | from .evaluators.fid_evaluator import FIDEvaluator
7 | from .evaluators.diversity_evaluator import DiversityEvaluator
8 | from .evaluators.multimodality_evaluator import MultiModalityEvaluator
9 |
10 | EVALUATORS = Registry('evaluators')
11 |
12 | EVALUATORS.register_module(name='R Precision', module=PrecisionEvaluator)
13 | EVALUATORS.register_module(name='Matching Score', module=MatchingScoreEvaluator)
14 | EVALUATORS.register_module(name='FID', module=FIDEvaluator)
15 | EVALUATORS.register_module(name='Diversity', module=DiversityEvaluator)
16 | EVALUATORS.register_module(name='MultiModality', module=MultiModalityEvaluator)
17 |
18 |
19 | def build_evaluator(metric, eval_cfg, data_len, eval_indexes):
20 | cfg = copy.deepcopy(eval_cfg)
21 | cfg.update(metric)
22 | cfg.pop('metrics')
23 | cfg['data_len'] = data_len
24 | cfg['eval_indexes'] = eval_indexes
25 | evaluator = EVALUATORS.build(cfg)
26 | if evaluator.append_indexes is not None:
27 | for i in range(eval_cfg['replication_times']):
28 | eval_indexes[i] = np.concatenate((eval_indexes[i], evaluator.append_indexes[i]), axis=0)
29 | return evaluator, eval_indexes
30 |
--------------------------------------------------------------------------------
/mogen/core/evaluation/eval_hooks.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import tempfile
3 | import warnings
4 |
5 | from mmcv.runner import DistEvalHook as BaseDistEvalHook
6 | from mmcv.runner import EvalHook as BaseEvalHook
7 |
8 | mogen_GREATER_KEYS = []
9 | mogen_LESS_KEYS = []
10 |
11 |
12 | class EvalHook(BaseEvalHook):
13 |
14 | def __init__(self,
15 | dataloader,
16 | start=None,
17 | interval=1,
18 | by_epoch=True,
19 | save_best=None,
20 | rule=None,
21 | test_fn=None,
22 | greater_keys=mogen_GREATER_KEYS,
23 | less_keys=mogen_LESS_KEYS,
24 | **eval_kwargs):
25 | if test_fn is None:
26 | from mogen.apis import single_gpu_test
27 | test_fn = single_gpu_test
28 |
29 | # remove "gpu_collect" from eval_kwargs
30 | if 'gpu_collect' in eval_kwargs:
31 | warnings.warn(
32 | '"gpu_collect" will be deprecated in EvalHook.'
33 | 'Please remove it from the config.', DeprecationWarning)
34 | _ = eval_kwargs.pop('gpu_collect')
35 |
36 | # update "save_best" according to "key_indicator" and remove the
37 | # latter from eval_kwargs
38 | if 'key_indicator' in eval_kwargs or isinstance(save_best, bool):
39 | warnings.warn(
40 | '"key_indicator" will be deprecated in EvalHook.'
41 | 'Please use "save_best" to specify the metric key,'
42 | 'e.g., save_best="pa-mpjpe".', DeprecationWarning)
43 |
44 | key_indicator = eval_kwargs.pop('key_indicator', None)
45 | if save_best is True and key_indicator is None:
46 | raise ValueError('key_indicator should not be None, when '
47 | 'save_best is set to True.')
48 | save_best = key_indicator
49 |
50 | super().__init__(dataloader, start, interval, by_epoch, save_best,
51 | rule, test_fn, greater_keys, less_keys, **eval_kwargs)
52 |
53 | def evaluate(self, runner, results):
54 |
55 | with tempfile.TemporaryDirectory() as tmp_dir:
56 | eval_res = self.dataloader.dataset.evaluate(
57 | results,
58 | work_dir=tmp_dir,
59 | logger=runner.logger,
60 | **self.eval_kwargs)
61 |
62 | for name, val in eval_res.items():
63 | runner.log_buffer.output[name] = val
64 | runner.log_buffer.ready = True
65 |
66 | if self.save_best is not None:
67 | if self.key_indicator == 'auto':
68 | self._init_rule(self.rule, list(eval_res.keys())[0])
69 |
70 | return eval_res[self.key_indicator]
71 |
72 | return None
73 |
74 |
75 | class DistEvalHook(BaseDistEvalHook):
76 |
77 | def __init__(self,
78 | dataloader,
79 | start=None,
80 | interval=1,
81 | by_epoch=True,
82 | save_best=None,
83 | rule=None,
84 | test_fn=None,
85 | greater_keys=mogen_GREATER_KEYS,
86 | less_keys=mogen_LESS_KEYS,
87 | broadcast_bn_buffer=True,
88 | tmpdir=None,
89 | gpu_collect=False,
90 | **eval_kwargs):
91 |
92 | if test_fn is None:
93 | from mogen.apis import multi_gpu_test
94 | test_fn = multi_gpu_test
95 |
96 | # update "save_best" according to "key_indicator" and remove the
97 | # latter from eval_kwargs
98 | if 'key_indicator' in eval_kwargs or isinstance(save_best, bool):
99 | warnings.warn(
100 | '"key_indicator" will be deprecated in EvalHook.'
101 | 'Please use "save_best" to specify the metric key,'
102 | 'e.g., save_best="pa-mpjpe".', DeprecationWarning)
103 |
104 | key_indicator = eval_kwargs.pop('key_indicator', None)
105 | if save_best is True and key_indicator is None:
106 | raise ValueError('key_indicator should not be None, when '
107 | 'save_best is set to True.')
108 | save_best = key_indicator
109 |
110 | super().__init__(dataloader, start, interval, by_epoch, save_best,
111 | rule, test_fn, greater_keys, less_keys,
112 | broadcast_bn_buffer, tmpdir, gpu_collect,
113 | **eval_kwargs)
114 |
115 | def evaluate(self, runner, results):
116 | """Evaluate the results.
117 | Args:
118 |             runner (:obj:`mmcv.Runner`): The underlying training runner.
119 | results (list): Output results.
120 | """
121 | with tempfile.TemporaryDirectory() as tmp_dir:
122 | eval_res = self.dataloader.dataset.evaluate(
123 | results,
124 | work_dir=tmp_dir,
125 | logger=runner.logger,
126 | **self.eval_kwargs)
127 |
128 | for name, val in eval_res.items():
129 | runner.log_buffer.output[name] = val
130 | runner.log_buffer.ready = True
131 |
132 | if self.save_best is not None:
133 | if self.key_indicator == 'auto':
134 | # infer from eval_results
135 | self._init_rule(self.rule, list(eval_res.keys())[0])
136 | return eval_res[self.key_indicator]
137 |
138 | return None
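
Both hooks reuse mmcv's evaluation machinery and only swap in mogen's test functions and (empty) metric-key lists. A hedged wiring sketch, not part of the file above; the config object `cfg`, the `runner`, and the batch size are assumptions:

    from mogen.core.evaluation.eval_hooks import EvalHook
    from mogen.datasets import build_dataset, build_dataloader

    val_dataset = build_dataset(cfg.data.test)        # test-mode dataset config, assumed
    val_loader = build_dataloader(val_dataset,
                                  samples_per_gpu=32,
                                  workers_per_gpu=1,
                                  dist=False,
                                  shuffle=False)
    # 'FID (mean)' matches a key emitted by FIDEvaluator.parse_values further below;
    # rule must be given explicitly because mogen_GREATER_KEYS/mogen_LESS_KEYS are empty.
    hook = EvalHook(val_loader, interval=1, save_best='FID (mean)', rule='less')
    runner.register_hook(hook, priority='LOW')        # `runner` assumed to exist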
--------------------------------------------------------------------------------
/mogen/core/evaluation/evaluators/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mingyuan-zhang/ReMoDiffuse/d81c83ddbf72a989dc334cd48cb7d46fb6feba63/mogen/core/evaluation/evaluators/__init__.py
--------------------------------------------------------------------------------
/mogen/core/evaluation/evaluators/base_evaluator.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from ..utils import get_metric_statistics
4 |
5 |
6 | class BaseEvaluator(object):
7 |
8 | def __init__(self,
9 | batch_size=None,
10 | drop_last=False,
11 | replication_times=1,
12 | replication_reduction='statistics',
13 | eval_begin_idx=None,
14 | eval_end_idx=None):
15 | self.batch_size = batch_size
16 | self.drop_last = drop_last
17 | self.replication_times = replication_times
18 | self.replication_reduction = replication_reduction
19 | assert replication_reduction in ['statistics', 'mean', 'concat']
20 | self.eval_begin_idx = eval_begin_idx
21 | self.eval_end_idx = eval_end_idx
22 |
23 | def evaluate(self, results):
24 | total_len = len(results)
25 | partial_len = total_len // self.replication_times
26 | all_metrics = []
27 | for replication_idx in range(self.replication_times):
28 | partial_results = results[
29 | replication_idx * partial_len: (replication_idx + 1) * partial_len]
30 | if self.batch_size is not None:
31 | batch_metrics = []
32 | for batch_start in range(self.eval_begin_idx, self.eval_end_idx, self.batch_size):
33 | batch_results = partial_results[batch_start: batch_start + self.batch_size]
34 | if len(batch_results) < self.batch_size and self.drop_last:
35 | continue
36 | batch_metrics.append(self.single_evaluate(batch_results))
37 | all_metrics.append(self.concat_batch_metrics(batch_metrics))
38 | else:
39 | batch_results = partial_results[self.eval_begin_idx: self.eval_end_idx]
40 | all_metrics.append(self.single_evaluate(batch_results))
41 | all_metrics = np.stack(all_metrics, axis=0)
42 | if self.replication_reduction == 'statistics':
43 | values = get_metric_statistics(all_metrics, self.replication_times)
44 | elif self.replication_reduction == 'mean':
45 | values = np.mean(all_metrics, axis=0)
46 | elif self.replication_reduction == 'concat':
47 | values = all_metrics
48 | return self.parse_values(values)
49 |
50 | def prepare_results(self, results):
51 | text = []
52 | pred_motion = []
53 | pred_motion_length = []
54 | pred_motion_mask = []
55 | motion = []
56 | motion_length = []
57 | motion_mask = []
58 | token = []
59 | # count the maximum motion length
60 | T = max([result['motion'].shape[0] for result in results])
61 | for result in results:
62 | cur_motion = result['motion']
63 | if cur_motion.shape[0] < T:
64 | padding_values = torch.zeros((T - cur_motion.shape[0], cur_motion.shape[1]))
65 |                 padding_values = padding_values.type_as(cur_motion)
66 | cur_motion = torch.cat([cur_motion, padding_values], dim=0)
67 | motion.append(cur_motion)
68 | cur_pred_motion = result['pred_motion']
69 | if cur_pred_motion.shape[0] < T:
70 | padding_values = torch.zeros((T - cur_pred_motion.shape[0], cur_pred_motion.shape[1]))
71 | padding_values = padding_values.type_as(cur_pred_motion)
72 | cur_pred_motion = torch.cat([cur_pred_motion, padding_values], dim=0)
73 | pred_motion.append(cur_pred_motion)
74 | cur_motion_mask = result['motion_mask']
75 | if cur_motion_mask.shape[0] < T:
76 | padding_values = torch.zeros((T - cur_motion_mask.shape[0]))
77 | padding_values = padding_values.type_as(cur_motion_mask)
78 |                 cur_motion_mask = torch.cat([cur_motion_mask, padding_values], dim=0)
79 | motion_mask.append(cur_motion_mask)
80 | cur_pred_motion_mask = result['pred_motion_mask']
81 | if cur_pred_motion_mask.shape[0] < T:
82 | padding_values = torch.zeros((T - cur_pred_motion_mask.shape[0]))
83 | padding_values = padding_values.type_as(cur_pred_motion_mask)
84 |                 cur_pred_motion_mask = torch.cat([cur_pred_motion_mask, padding_values], dim=0)
85 | pred_motion_mask.append(cur_pred_motion_mask)
86 | motion_length.append(result['motion_length'].item())
87 | pred_motion_length.append(result['pred_motion_length'].item())
88 | if 'text' in result.keys():
89 | text.append(result['text'])
90 | if 'token' in result.keys():
91 | token.append(result['token'])
92 |
93 | motion = torch.stack(motion, dim=0)
94 | pred_motion = torch.stack(pred_motion, dim=0)
95 | motion_mask = torch.stack(motion_mask, dim=0)
96 | pred_motion_mask = torch.stack(pred_motion_mask, dim=0)
97 | motion_length = torch.Tensor(motion_length).to(motion.device).long()
98 | pred_motion_length = torch.Tensor(pred_motion_length).to(motion.device).long()
99 | output = {
100 | 'pred_motion': pred_motion,
101 | 'pred_motion_mask': pred_motion_mask,
102 | 'pred_motion_length': pred_motion_length,
103 | 'motion': motion,
104 | 'motion_mask': motion_mask,
105 | 'motion_length': motion_length,
106 | 'text': text,
107 | 'token': token
108 | }
109 | return output
110 |
111 | def to_device(self, device):
112 | for model in self.model_list:
113 | model.to(device)
114 |
115 | def motion_encode(self, motion, motion_length, motion_mask, device):
116 | N = motion.shape[0]
117 | motion_emb = []
118 | batch_size = 32
119 | cur_idx = 0
120 | with torch.no_grad():
121 | while cur_idx < N:
122 | cur_motion = motion[cur_idx: cur_idx + batch_size].to(device)
123 | cur_motion_length = motion_length[cur_idx: cur_idx + batch_size].to(device)
124 | cur_motion_mask = motion_mask[cur_idx: cur_idx + batch_size].to(device)
125 | cur_motion_emb = self.motion_encoder(cur_motion, cur_motion_length, cur_motion_mask)
126 | motion_emb.append(cur_motion_emb)
127 | cur_idx += batch_size
128 | motion_emb = torch.cat(motion_emb, dim=0)
129 | return motion_emb
130 |
131 | def text_encode(self, text, token, device):
132 | N = len(text)
133 | text_emb = []
134 | batch_size = 32
135 | cur_idx = 0
136 | with torch.no_grad():
137 | while cur_idx < N:
138 | cur_text = text[cur_idx: cur_idx + batch_size]
139 | cur_token = token[cur_idx: cur_idx + batch_size]
140 | cur_text_emb = self.text_encoder(cur_text, cur_token, device)
141 | text_emb.append(cur_text_emb)
142 | cur_idx += batch_size
143 | text_emb = torch.cat(text_emb, dim=0)
144 | return text_emb
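
The contract a concrete evaluator has to fulfil is small: set `append_indexes` and `model_list`, implement `single_evaluate` (plus `concat_batch_metrics` when `batch_size` is used) and `parse_values`. A toy subclass, purely illustrative and not part of the repository:

    from mogen.core.evaluation.evaluators.base_evaluator import BaseEvaluator

    class MeanNormEvaluator(BaseEvaluator):
        """Toy metric: mean L2 norm of the predicted motions."""

        def __init__(self, data_len=0, **kwargs):
            super().__init__(eval_begin_idx=0, eval_end_idx=data_len, **kwargs)
            self.append_indexes = None   # no extra generations requested
            self.model_list = []         # no pretrained encoders needed

        def single_evaluate(self, results):
            results = self.prepare_results(results)
            pred = results['pred_motion']            # (N, T, D) tensor after padding
            return pred.reshape(pred.shape[0], -1).norm(dim=1).mean().item()

        def parse_values(self, values):
            # with replication_reduction='statistics', values is (mean, conf_interval)
            return {'MeanNorm (mean)': values[0], 'MeanNorm (conf)': values[1]}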
--------------------------------------------------------------------------------
/mogen/core/evaluation/evaluators/diversity_evaluator.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 | from ..get_model import get_motion_model
5 | from .base_evaluator import BaseEvaluator
6 | from ..utils import calculate_diversity
7 |
8 |
9 | class DiversityEvaluator(BaseEvaluator):
10 |
11 | def __init__(self,
12 | data_len=0,
13 | motion_encoder_name=None,
14 | motion_encoder_path=None,
15 | num_samples=300,
16 | batch_size=None,
17 | drop_last=False,
18 | replication_times=1,
19 | replication_reduction='statistics',
20 | **kwargs):
21 | super().__init__(
22 | replication_times=replication_times,
23 | replication_reduction=replication_reduction,
24 | batch_size=batch_size,
25 | drop_last=drop_last,
26 | eval_begin_idx=0,
27 | eval_end_idx=data_len
28 | )
29 | self.num_samples = num_samples
30 | self.append_indexes = None
31 | self.motion_encoder = get_motion_model(motion_encoder_name, motion_encoder_path)
32 | self.model_list = [self.motion_encoder]
33 |
34 | def single_evaluate(self, results):
35 | results = self.prepare_results(results)
36 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
37 | motion = results['motion']
38 | pred_motion = results['pred_motion']
39 | pred_motion_length = results['pred_motion_length']
40 | pred_motion_mask = results['pred_motion_mask']
41 | self.motion_encoder.to(device)
42 | self.motion_encoder.eval()
43 | with torch.no_grad():
44 | pred_motion_emb = self.motion_encode(pred_motion, pred_motion_length, pred_motion_mask, device).cpu().detach().numpy()
45 | diversity = calculate_diversity(pred_motion_emb, self.num_samples)
46 | return diversity
47 |
48 | def parse_values(self, values):
49 | metrics = {}
50 | metrics['Diversity (mean)'] = values[0]
51 | metrics['Diversity (conf)'] = values[1]
52 | return metrics
53 |
--------------------------------------------------------------------------------
/mogen/core/evaluation/evaluators/fid_evaluator.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 | from ..get_model import get_motion_model
5 | from .base_evaluator import BaseEvaluator
6 | from ..utils import (
7 | calculate_activation_statistics,
8 | calculate_frechet_distance)
9 |
10 |
11 | class FIDEvaluator(BaseEvaluator):
12 |
13 | def __init__(self,
14 | data_len=0,
15 | motion_encoder_name=None,
16 | motion_encoder_path=None,
17 | batch_size=None,
18 | drop_last=False,
19 | replication_times=1,
20 | replication_reduction='statistics',
21 | **kwargs):
22 | super().__init__(
23 | replication_times=replication_times,
24 | replication_reduction=replication_reduction,
25 | batch_size=batch_size,
26 | drop_last=drop_last,
27 | eval_begin_idx=0,
28 | eval_end_idx=data_len
29 | )
30 | self.append_indexes = None
31 | self.motion_encoder = get_motion_model(motion_encoder_name, motion_encoder_path)
32 | self.model_list = [self.motion_encoder]
33 |
34 | def single_evaluate(self, results):
35 | results = self.prepare_results(results)
36 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
37 | pred_motion = results['pred_motion']
38 |
39 | pred_motion_length = results['pred_motion_length']
40 | pred_motion_mask = results['pred_motion_mask']
41 | motion = results['motion']
42 | motion_length = results['motion_length']
43 | motion_mask = results['motion_mask']
44 | self.motion_encoder.to(device)
45 | self.motion_encoder.eval()
46 | with torch.no_grad():
47 | pred_motion_emb = self.motion_encode(pred_motion, pred_motion_length, pred_motion_mask, device).cpu().detach().numpy()
48 | gt_motion_emb = self.motion_encode(motion, motion_length, motion_mask, device).cpu().detach().numpy()
49 | gt_mu, gt_cov = calculate_activation_statistics(gt_motion_emb)
50 | pred_mu, pred_cov = calculate_activation_statistics(pred_motion_emb)
51 | fid = calculate_frechet_distance(gt_mu, gt_cov, pred_mu, pred_cov)
52 | return fid
53 |
54 | def parse_values(self, values):
55 | metrics = {}
56 | metrics['FID (mean)'] = values[0]
57 | metrics['FID (conf)'] = values[1]
58 | return metrics
59 |
--------------------------------------------------------------------------------
/mogen/core/evaluation/evaluators/matching_score_evaluator.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 | from ..get_model import get_motion_model, get_text_model
5 | from .base_evaluator import BaseEvaluator
6 | from ..utils import calculate_top_k, euclidean_distance_matrix
7 |
8 |
9 | class MatchingScoreEvaluator(BaseEvaluator):
10 |
11 | def __init__(self,
12 | data_len=0,
13 | text_encoder_name=None,
14 | text_encoder_path=None,
15 | motion_encoder_name=None,
16 | motion_encoder_path=None,
17 | top_k=3,
18 | batch_size=32,
19 | drop_last=False,
20 | replication_times=1,
21 | replication_reduction='statistics',
22 | **kwargs):
23 | super().__init__(
24 | replication_times=replication_times,
25 | replication_reduction=replication_reduction,
26 | batch_size=batch_size,
27 | drop_last=drop_last,
28 | eval_begin_idx=0,
29 | eval_end_idx=data_len
30 | )
31 | self.append_indexes = None
32 | self.text_encoder = get_text_model(text_encoder_name, text_encoder_path)
33 | self.motion_encoder = get_motion_model(motion_encoder_name, motion_encoder_path)
34 | self.top_k = top_k
35 | self.model_list = [self.text_encoder, self.motion_encoder]
36 |
37 | def single_evaluate(self, results):
38 | results = self.prepare_results(results)
39 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
40 | motion = results['motion']
41 | pred_motion = results['pred_motion']
42 | pred_motion_length = results['pred_motion_length']
43 | pred_motion_mask = results['pred_motion_mask']
44 | text = results['text']
45 | token = results['token']
46 | self.text_encoder.to(device)
47 | self.motion_encoder.to(device)
48 | self.text_encoder.eval()
49 | self.motion_encoder.eval()
50 | with torch.no_grad():
51 | word_emb = self.text_encode(text, token, device=device).cpu().detach().numpy()
52 | motion_emb = self.motion_encode(pred_motion, pred_motion_length, pred_motion_mask, device).cpu().detach().numpy()
53 | dist_mat = euclidean_distance_matrix(word_emb, motion_emb)
54 | matching_score = dist_mat.trace()
55 | all_size = word_emb.shape[0]
56 | return matching_score, all_size
57 |
58 | def concat_batch_metrics(self, batch_metrics):
59 | matching_score_sum = 0
60 | all_size = 0
61 | for batch_matching_score, batch_all_size in batch_metrics:
62 | matching_score_sum += batch_matching_score
63 | all_size += batch_all_size
64 | matching_score = matching_score_sum / all_size
65 | return matching_score
66 |
67 | def parse_values(self, values):
68 | metrics = {}
69 | metrics['Matching Score (mean)'] = values[0]
70 | metrics['Matching Score (conf)'] = values[1]
71 | return metrics
72 |
--------------------------------------------------------------------------------
/mogen/core/evaluation/evaluators/multimodality_evaluator.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 | from ..get_model import get_motion_model
5 | from .base_evaluator import BaseEvaluator
6 | from ..utils import calculate_multimodality
7 |
8 |
9 | class MultiModalityEvaluator(BaseEvaluator):
10 |
11 | def __init__(self,
12 | data_len=0,
13 | motion_encoder_name=None,
14 | motion_encoder_path=None,
15 | num_samples=100,
16 | num_repeats=30,
17 | num_picks=10,
18 | batch_size=None,
19 | drop_last=False,
20 | replication_times=1,
21 | replication_reduction='statistics',
22 | **kwargs):
23 | super().__init__(
24 | replication_times=replication_times,
25 | replication_reduction=replication_reduction,
26 | batch_size=batch_size,
27 | drop_last=drop_last,
28 | eval_begin_idx=data_len,
29 | eval_end_idx=data_len + num_samples * num_repeats
30 | )
31 | self.num_samples = num_samples
32 | self.num_repeats = num_repeats
33 | self.num_picks = num_picks
34 | self.append_indexes = []
35 | for i in range(replication_times):
36 | append_indexes = []
37 | selected_indexs = np.random.choice(data_len, self.num_samples)
38 | for index in selected_indexs:
39 | append_indexes = append_indexes + [index] * self.num_repeats
40 | self.append_indexes.append(np.array(append_indexes))
41 | self.motion_encoder = get_motion_model(motion_encoder_name, motion_encoder_path)
42 | self.model_list = [self.motion_encoder]
43 |
44 | def single_evaluate(self, results):
45 | results = self.prepare_results(results)
46 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
47 | motion = results['motion']
48 | pred_motion = results['pred_motion']
49 | pred_motion_length = results['pred_motion_length']
50 | pred_motion_mask = results['pred_motion_mask']
51 | self.motion_encoder.to(device)
52 | self.motion_encoder.eval()
53 | with torch.no_grad():
54 | pred_motion_emb = self.motion_encode(pred_motion, pred_motion_length, pred_motion_mask, device).cpu().detach().numpy()
55 | pred_motion_emb = pred_motion_emb.reshape((self.num_samples, self.num_repeats, -1))
56 | multimodality = calculate_multimodality(pred_motion_emb, self.num_picks)
57 | return multimodality
58 |
59 | def parse_values(self, values):
60 | metrics = {}
61 | metrics['MultiModality (mean)'] = values[0]
62 | metrics['MultiModality (conf)'] = values[1]
63 | return metrics
64 |
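
MultiModality is the only metric that asks for extra generations: per replication it appends `num_samples * num_repeats` indexes, so the same prompt is synthesized `num_repeats` times and the results after `data_len` can be reshaped to (num_samples, num_repeats, -1). A small numpy sketch of the index bookkeeping (values are arbitrary):

    import numpy as np

    data_len, num_samples, num_repeats = 5, 2, 3
    eval_indexes = np.arange(data_len)                  # base pass over the dataset
    picked = np.random.choice(data_len, num_samples)    # e.g. array([4, 1])
    append = np.repeat(picked, num_repeats)             # [4, 4, 4, 1, 1, 1]
    eval_indexes = np.concatenate((eval_indexes, append))
    # build_evaluator concatenates these extra indexes onto each replication's list,
    # exactly as in the loop at the top of mogen/core/evaluation/builder.py.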
--------------------------------------------------------------------------------
/mogen/core/evaluation/evaluators/precision_evaluator.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 | from ..get_model import get_motion_model, get_text_model
5 | from .base_evaluator import BaseEvaluator
6 | from ..utils import calculate_top_k, euclidean_distance_matrix
7 |
8 |
9 | class PrecisionEvaluator(BaseEvaluator):
10 |
11 | def __init__(self,
12 | data_len=0,
13 | text_encoder_name=None,
14 | text_encoder_path=None,
15 | motion_encoder_name=None,
16 | motion_encoder_path=None,
17 | top_k=3,
18 | batch_size=32,
19 | drop_last=False,
20 | replication_times=1,
21 | replication_reduction='statistics',
22 | **kwargs):
23 | super().__init__(
24 | replication_times=replication_times,
25 | replication_reduction=replication_reduction,
26 | batch_size=batch_size,
27 | drop_last=drop_last,
28 | eval_begin_idx=0,
29 | eval_end_idx=data_len
30 | )
31 | self.append_indexes = None
32 | self.text_encoder = get_text_model(text_encoder_name, text_encoder_path)
33 | self.motion_encoder = get_motion_model(motion_encoder_name, motion_encoder_path)
34 | self.top_k = top_k
35 | self.model_list = [self.text_encoder, self.motion_encoder]
36 |
37 | def single_evaluate(self, results):
38 | results = self.prepare_results(results)
39 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
40 | motion = results['motion']
41 | pred_motion = results['pred_motion']
42 | pred_motion_length = results['pred_motion_length']
43 | pred_motion_mask = results['pred_motion_mask']
44 | text = results['text']
45 | token = results['token']
46 | self.text_encoder.to(device)
47 | self.motion_encoder.to(device)
48 | self.text_encoder.eval()
49 | self.motion_encoder.eval()
50 | with torch.no_grad():
51 | word_emb = self.text_encode(text, token, device=device).cpu().detach().numpy()
52 | motion_emb = self.motion_encode(pred_motion, pred_motion_length, pred_motion_mask, device).cpu().detach().numpy()
53 | dist_mat = euclidean_distance_matrix(word_emb, motion_emb)
54 | argsmax = np.argsort(dist_mat, axis=1)
55 | top_k_mat = calculate_top_k(argsmax, top_k=self.top_k)
56 | top_k_count = top_k_mat.sum(axis=0)
57 | all_size = word_emb.shape[0]
58 | return top_k_count, all_size
59 |
60 | def concat_batch_metrics(self, batch_metrics):
61 | top_k_count = 0
62 | all_size = 0
63 | for batch_top_k_count, batch_all_size in batch_metrics:
64 | top_k_count += batch_top_k_count
65 | all_size += batch_all_size
66 | R_precision = top_k_count / all_size
67 | return R_precision
68 |
69 | def parse_values(self, values):
70 | metrics = {}
71 | for top_k in range(self.top_k):
72 | metrics['R_precision Top %d (mean)' % (top_k + 1)] = values[0][top_k]
73 | metrics['R_precision Top %d (conf)' % (top_k + 1)] = values[1][top_k]
74 | return metrics
75 |
--------------------------------------------------------------------------------
/mogen/core/evaluation/get_model.py:
--------------------------------------------------------------------------------
1 | from mogen.models import build_submodule
2 |
3 |
4 | def get_motion_model(name, ckpt_path):
5 | if name == 'kit_ml':
6 | model = build_submodule(dict(
7 | type='T2MMotionEncoder',
8 | input_size=251,
9 | movement_hidden_size=512,
10 | movement_latent_size=512,
11 | motion_hidden_size=1024,
12 | motion_latent_size=512,
13 | ))
14 | else:
15 | model = build_submodule(dict(
16 | type='T2MMotionEncoder',
17 | input_size=263,
18 | movement_hidden_size=512,
19 | movement_latent_size=512,
20 | motion_hidden_size=1024,
21 | motion_latent_size=512,
22 | ))
23 | model.load_pretrained(ckpt_path)
24 | return model
25 |
26 | def get_text_model(name, ckpt_path):
27 | if name == 'kit_ml':
28 | model = build_submodule(dict(
29 | type='T2MTextEncoder',
30 | word_size=300,
31 | pos_size=15,
32 | hidden_size=512,
33 | output_size=512,
34 | max_text_len=20
35 | ))
36 | else:
37 | model = build_submodule(dict(
38 | type='T2MTextEncoder',
39 | word_size=300,
40 | pos_size=15,
41 | hidden_size=512,
42 | output_size=512,
43 | max_text_len=20
44 | ))
45 | model.load_pretrained(ckpt_path)
46 | return model
47 |
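
A hedged construction sketch (the checkpoint paths are placeholders, not files shipped with the repo): any dataset name other than 'kit_ml' selects the 263-dimensional HumanML3D configuration.

    from mogen.core.evaluation.get_model import get_motion_model, get_text_model

    motion_enc = get_motion_model('human_ml3d', 'data/evaluators/motion_encoder.pth')
    text_enc = get_text_model('human_ml3d', 'data/evaluators/text_encoder.pth')
    motion_enc.eval()
    text_enc.eval()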
--------------------------------------------------------------------------------
/mogen/core/evaluation/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy import linalg
3 |
4 |
5 | def get_metric_statistics(values, replication_times):
6 | mean = np.mean(values, axis=0)
7 | std = np.std(values, axis=0)
8 | conf_interval = 1.96 * std / np.sqrt(replication_times)
9 | return mean, conf_interval
10 |
11 |
12 | # (X - X_train)*(X - X_train) = -2X*X_train + X*X + X_train*X_train
13 | def euclidean_distance_matrix(matrix1, matrix2):
14 | """
15 | Params:
16 | -- matrix1: N1 x D
17 | -- matrix2: N2 x D
18 | Returns:
19 | -- dist: N1 x N2
20 | dist[i, j] == distance(matrix1[i], matrix2[j])
21 | """
22 | assert matrix1.shape[1] == matrix2.shape[1]
23 | d1 = -2 * np.dot(matrix1, matrix2.T) # shape (num_test, num_train)
24 | d2 = np.sum(np.square(matrix1), axis=1, keepdims=True) # shape (num_test, 1)
25 | d3 = np.sum(np.square(matrix2), axis=1) # shape (num_train, )
26 | dists = np.sqrt(d1 + d2 + d3) # broadcasting
27 | return dists
28 |
29 |
30 | def calculate_top_k(mat, top_k):
31 | size = mat.shape[0]
32 | gt_mat = np.expand_dims(np.arange(size), 1).repeat(size, 1)
33 | bool_mat = (mat == gt_mat)
34 | correct_vec = False
35 | top_k_list = []
36 | for i in range(top_k):
37 | # print(correct_vec, bool_mat[:, i])
38 | correct_vec = (correct_vec | bool_mat[:, i])
39 | # print(correct_vec)
40 | top_k_list.append(correct_vec[:, None])
41 | top_k_mat = np.concatenate(top_k_list, axis=1)
42 | return top_k_mat
43 |
44 |
45 | def calculate_activation_statistics(activations):
46 | """
47 | Params:
48 | -- activation: num_samples x dim_feat
49 | Returns:
50 | -- mu: dim_feat
51 | -- sigma: dim_feat x dim_feat
52 | """
53 | mu = np.mean(activations, axis=0)
54 | cov = np.cov(activations, rowvar=False)
55 | return mu, cov
56 |
57 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
58 | """Numpy implementation of the Frechet Distance.
59 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
60 | and X_2 ~ N(mu_2, C_2) is
61 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
62 | Stable version by Dougal J. Sutherland.
63 | Params:
64 | -- mu1 : Numpy array containing the activations of a layer of the
65 | inception net (like returned by the function 'get_predictions')
66 | for generated samples.
67 |     -- mu2   : The sample mean over activations, precalculated on a
68 | representative data set.
69 | -- sigma1: The covariance matrix over activations for generated samples.
70 |     -- sigma2: The covariance matrix over activations, precalculated on a
71 | representative data set.
72 | Returns:
73 | -- : The Frechet Distance.
74 | """
75 |
76 | mu1 = np.atleast_1d(mu1)
77 | mu2 = np.atleast_1d(mu2)
78 |
79 | sigma1 = np.atleast_2d(sigma1)
80 | sigma2 = np.atleast_2d(sigma2)
81 |
82 | assert mu1.shape == mu2.shape, \
83 | 'Training and test mean vectors have different lengths'
84 | assert sigma1.shape == sigma2.shape, \
85 | 'Training and test covariances have different dimensions'
86 |
87 | diff = mu1 - mu2
88 |
89 | # Product might be almost singular
90 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
91 | if not np.isfinite(covmean).all():
92 | msg = ('fid calculation produces singular product; '
93 | 'adding %s to diagonal of cov estimates') % eps
94 | print(msg)
95 | offset = np.eye(sigma1.shape[0]) * eps
96 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
97 |
98 | # Numerical error might give slight imaginary component
99 | if np.iscomplexobj(covmean):
100 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
101 | m = np.max(np.abs(covmean.imag))
102 | raise ValueError('Imaginary component {}'.format(m))
103 | covmean = covmean.real
104 |
105 | tr_covmean = np.trace(covmean)
106 |
107 | return (diff.dot(diff) + np.trace(sigma1) +
108 | np.trace(sigma2) - 2 * tr_covmean)
109 |
110 |
111 | def calculate_diversity(activation, diversity_times):
112 | assert len(activation.shape) == 2
113 | assert activation.shape[0] > diversity_times
114 | num_samples = activation.shape[0]
115 |
116 | first_indices = np.random.choice(num_samples, diversity_times, replace=False)
117 | second_indices = np.random.choice(num_samples, diversity_times, replace=False)
118 | dist = linalg.norm(activation[first_indices] - activation[second_indices], axis=1)
119 | return dist.mean()
120 |
121 |
122 | def calculate_multimodality(activation, multimodality_times):
123 | assert len(activation.shape) == 3
124 | assert activation.shape[1] > multimodality_times
125 | num_per_sent = activation.shape[1]
126 |
127 | first_dices = np.random.choice(num_per_sent, multimodality_times, replace=False)
128 | second_dices = np.random.choice(num_per_sent, multimodality_times, replace=False)
129 | dist = linalg.norm(activation[:, first_dices] - activation[:, second_dices], axis=2)
130 | return dist.mean()
131 |
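
Two of these helpers are easy to sanity-check by hand; a small numpy example, not from the repo (the numbers are illustrative only):

    import numpy as np
    from mogen.core.evaluation.utils import get_metric_statistics, calculate_top_k

    # per-replication metric values -> (mean, 95% confidence interval)
    mean, conf = get_metric_statistics(np.array([0.52, 0.48, 0.50]), replication_times=3)

    # row i is a top-k hit if index i appears among the first k columns of the
    # argsorted distance matrix
    argsorted = np.array([[0, 2, 1],
                          [2, 1, 0],
                          [1, 0, 2]])
    print(calculate_top_k(argsorted, top_k=3).astype(int))
    # [[1 1 1]
    #  [0 1 1]
    #  [0 0 1]]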
--------------------------------------------------------------------------------
/mogen/core/optimizer/__init__.py:
--------------------------------------------------------------------------------
1 | from .builder import OPTIMIZERS, build_optimizers
2 |
3 | __all__ = ['build_optimizers', 'OPTIMIZERS']
--------------------------------------------------------------------------------
/mogen/core/optimizer/builder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from mmcv.runner import build_optimizer
3 | from mmcv.utils import Registry
4 |
5 | OPTIMIZERS = Registry('optimizers')
6 |
7 |
8 | def build_optimizers(model, cfgs):
9 | """Build multiple optimizers from configs. If `cfgs` contains several dicts
10 |     for optimizers, then a dict of the constructed optimizers will be
11 | returned. If `cfgs` only contains one optimizer config, the constructed
12 | optimizer itself will be returned. For example,
13 |
14 | 1) Multiple optimizer configs:
15 |
16 | .. code-block:: python
17 |
18 | optimizer_cfg = dict(
19 | model1=dict(type='SGD', lr=lr),
20 | model2=dict(type='SGD', lr=lr))
21 |
22 | The return dict is
23 | ``dict('model1': torch.optim.Optimizer, 'model2': torch.optim.Optimizer)``
24 |
25 | 2) Single optimizer config:
26 |
27 | .. code-block:: python
28 |
29 | optimizer_cfg = dict(type='SGD', lr=lr)
30 |
31 | The return is ``torch.optim.Optimizer``.
32 |
33 | Args:
34 | model (:obj:`nn.Module`): The model with parameters to be optimized.
35 | cfgs (dict): The config dict of the optimizer.
36 |
37 | Returns:
38 | dict[:obj:`torch.optim.Optimizer`] | :obj:`torch.optim.Optimizer`:
39 | The initialized optimizers.
40 | """
41 | optimizers = {}
42 | if hasattr(model, 'module'):
43 | model = model.module
44 | # determine whether 'cfgs' has several dicts for optimizers
45 | if all(isinstance(v, dict) for v in cfgs.values()):
46 | for key, cfg in cfgs.items():
47 | cfg_ = cfg.copy()
48 | module = getattr(model, key)
49 | optimizers[key] = build_optimizer(module, cfg_)
50 | return optimizers
51 |
52 | return build_optimizer(model, cfgs)
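
A minimal sketch of the two call patterns described in the docstring; the toy model and learning rates are made up:

    import torch.nn as nn
    from mogen.core.optimizer import build_optimizers

    class ToyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.encoder = nn.Linear(8, 8)
            self.decoder = nn.Linear(8, 8)

    model = ToyModel()

    # one optimizer per named submodule -> a dict of optimizers is returned
    optims = build_optimizers(model, dict(
        encoder=dict(type='Adam', lr=1e-4),
        decoder=dict(type='Adam', lr=1e-5)))

    # a single optimizer config -> a single torch.optim.Optimizer is returned
    optim = build_optimizers(model, dict(type='Adam', lr=2e-4))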
--------------------------------------------------------------------------------
/mogen/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .base_dataset import BaseMotionDataset
2 | from .text_motion_dataset import TextMotionDataset
3 | from .builder import DATASETS, PIPELINES, build_dataloader, build_dataset
4 | from .pipelines import Compose
5 | from .samplers import DistributedSampler
6 |
7 |
8 | __all__ = [
9 | 'BaseMotionDataset', 'TextMotionDataset', 'DATASETS', 'PIPELINES', 'build_dataloader',
10 | 'build_dataset', 'Compose', 'DistributedSampler'
11 | ]
--------------------------------------------------------------------------------
/mogen/datasets/base_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import copy
3 | from typing import Optional, Union
4 |
5 | import numpy as np
6 | from torch.utils.data import Dataset
7 |
8 | from .pipelines import Compose
9 | from .builder import DATASETS
10 | from mogen.core.evaluation import build_evaluator
11 |
12 |
13 | @DATASETS.register_module()
14 | class BaseMotionDataset(Dataset):
15 | """Base motion dataset.
16 | Args:
17 | data_prefix (str): the prefix of data path.
18 | pipeline (list): a list of dict, where each element represents
19 |             an operation defined in `mogen.datasets.pipelines`.
20 | ann_file (str | None, optional): the annotation file. When ann_file is
21 | str, the subclass is expected to read from the ann_file. When
22 | ann_file is None, the subclass is expected to read according
23 | to data_prefix.
24 |         test_mode (bool): whether the dataset is used in test mode. Default: False.
25 | dataset_name (str | None, optional): the name of dataset. It is used
26 | to identify the type of evaluation metric. Default: None.
27 | """
28 |
29 | def __init__(self,
30 | data_prefix: str,
31 | pipeline: list,
32 | dataset_name: Optional[Union[str, None]] = None,
33 | fixed_length: Optional[Union[int, None]] = None,
34 | ann_file: Optional[Union[str, None]] = None,
35 | motion_dir: Optional[Union[str, None]] = None,
36 | eval_cfg: Optional[Union[dict, None]] = None,
37 | test_mode: Optional[bool] = False):
38 | super(BaseMotionDataset, self).__init__()
39 |
40 | self.data_prefix = data_prefix
41 | self.pipeline = Compose(pipeline)
42 | self.dataset_name = dataset_name
43 | self.fixed_length = fixed_length
44 | self.ann_file = os.path.join(data_prefix, 'datasets', dataset_name, ann_file)
45 | self.motion_dir = os.path.join(data_prefix, 'datasets', dataset_name, motion_dir)
46 | self.eval_cfg = copy.deepcopy(eval_cfg)
47 | self.test_mode = test_mode
48 |
49 | self.load_annotations()
50 | if self.test_mode:
51 | self.prepare_evaluation()
52 |
53 | def load_anno(self, name):
54 | motion_path = os.path.join(self.motion_dir, name + '.npy')
55 | motion_data = np.load(motion_path)
56 | return {'motion': motion_data}
57 |
58 |
59 | def load_annotations(self):
60 | """Load annotations from ``ann_file`` to ``data_infos``"""
61 | self.data_infos = []
62 | for line in open(self.ann_file, 'r').readlines():
63 | line = line.strip()
64 | self.data_infos.append(self.load_anno(line))
65 |
66 |
67 | def prepare_data(self, idx: int):
68 |         """Prepare raw data for the ``idx``-th data."""
69 | results = copy.deepcopy(self.data_infos[idx])
70 | results['dataset_name'] = self.dataset_name
71 | results['sample_idx'] = idx
72 | return self.pipeline(results)
73 |
74 | def __len__(self):
75 | """Return the length of current dataset."""
76 | if self.test_mode:
77 | return len(self.eval_indexes)
78 | elif self.fixed_length is not None:
79 | return self.fixed_length
80 | return len(self.data_infos)
81 |
82 | def __getitem__(self, idx: int):
83 | """Prepare data for the ``idx``-th data.
84 |         In test mode the index is first mapped through ``eval_indexes``; when
85 |         ``fixed_length`` is set, the index wraps around ``data_infos`` so the
86 |         dataset can expose an arbitrary virtual length.
87 | """
88 | if self.test_mode:
89 | idx = self.eval_indexes[idx]
90 | elif self.fixed_length is not None:
91 | idx = idx % len(self.data_infos)
92 | return self.prepare_data(idx)
93 |
94 | def prepare_evaluation(self):
95 | self.evaluators = []
96 | self.eval_indexes = []
97 | for _ in range(self.eval_cfg['replication_times']):
98 | eval_indexes = np.arange(len(self.data_infos))
99 | if self.eval_cfg.get('shuffle_indexes', False):
100 | np.random.shuffle(eval_indexes)
101 | self.eval_indexes.append(eval_indexes)
102 | for metric in self.eval_cfg['metrics']:
103 | evaluator, self.eval_indexes = build_evaluator(
104 | metric, self.eval_cfg, len(self.data_infos), self.eval_indexes)
105 | self.evaluators.append(evaluator)
106 |
107 | self.eval_indexes = np.concatenate(self.eval_indexes)
108 |
109 | def evaluate(self, results, work_dir, logger=None):
110 | metrics = {}
111 | device = results[0]['motion'].device
112 | for evaluator in self.evaluators:
113 | evaluator.to_device(device)
114 | metrics.update(evaluator.evaluate(results))
115 | if logger is not None:
116 | logger.info(metrics)
117 | return metrics
118 |
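
prepare_evaluation only reads three keys from eval_cfg. A hedged sketch of its shape; the metric entries are handed straight to build_evaluator, whose registered names live in mogen/core/evaluation/builder.py and are not reproduced in this section, so the entry below is a placeholder:

    eval_cfg = dict(
        replication_times=20,     # each metric is evaluated this many times
        shuffle_indexes=True,     # shuffle every replication's index order
        metrics=[
            # placeholder entry; build_evaluator resolves the real evaluator classes
            dict(motion_encoder_name='human_ml3d', motion_encoder_path='...'),
        ])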
--------------------------------------------------------------------------------
/mogen/datasets/builder.py:
--------------------------------------------------------------------------------
1 | import platform
2 | import random
3 | from functools import partial
4 | from typing import Optional, Union
5 |
6 | import numpy as np
7 | from mmcv.parallel import collate
8 | from mmcv.runner import get_dist_info
9 | from mmcv.utils import Registry, build_from_cfg
10 | from torch.utils.data import DataLoader
11 | from torch.utils.data.dataset import Dataset
12 |
13 | from .samplers import DistributedSampler
14 |
15 | if platform.system() != 'Windows':
16 | # https://github.com/pytorch/pytorch/issues/973
17 | import resource
18 | rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
19 | base_soft_limit = rlimit[0]
20 | hard_limit = rlimit[1]
21 | soft_limit = min(max(4096, base_soft_limit), hard_limit)
22 | resource.setrlimit(resource.RLIMIT_NOFILE, (soft_limit, hard_limit))
23 |
24 | DATASETS = Registry('dataset')
25 | PIPELINES = Registry('pipeline')
26 |
27 |
28 | def build_dataset(cfg: Union[dict, list, tuple],
29 | default_args: Optional[Union[dict, None]] = None):
30 |     """Build dataset by the given config."""
31 | from .dataset_wrappers import (
32 | ConcatDataset,
33 | RepeatDataset,
34 | )
35 | if isinstance(cfg, (list, tuple)):
36 | dataset = ConcatDataset([build_dataset(c, default_args) for c in cfg])
37 | elif cfg['type'] == 'RepeatDataset':
38 | dataset = RepeatDataset(
39 | build_dataset(cfg['dataset'], default_args), cfg['times'])
40 | else:
41 | dataset = build_from_cfg(cfg, DATASETS, default_args)
42 |
43 | return dataset
44 |
45 |
46 | def build_dataloader(dataset: Dataset,
47 | samples_per_gpu: int,
48 | workers_per_gpu: int,
49 | num_gpus: Optional[int] = 1,
50 | dist: Optional[bool] = True,
51 | shuffle: Optional[bool] = True,
52 | round_up: Optional[bool] = True,
53 | seed: Optional[Union[int, None]] = None,
54 | persistent_workers: Optional[bool] = True,
55 | **kwargs):
56 | """Build PyTorch DataLoader.
57 | In distributed training, each GPU/process has a dataloader.
58 | In non-distributed training, there is only one dataloader for all GPUs.
59 | Args:
60 | dataset (:obj:`Dataset`): A PyTorch dataset.
61 | samples_per_gpu (int): Number of training samples on each GPU, i.e.,
62 | batch size of each GPU.
63 | workers_per_gpu (int): How many subprocesses to use for data loading
64 | for each GPU.
65 | num_gpus (int, optional): Number of GPUs. Only used in non-distributed
66 | training.
67 | dist (bool, optional): Distributed training/test or not. Default: True.
68 | shuffle (bool, optional): Whether to shuffle the data at every epoch.
69 | Default: True.
70 | round_up (bool, optional): Whether to round up the length of dataset by
71 | adding extra samples to make it evenly divisible. Default: True.
72 | kwargs: any keyword argument to be used to initialize DataLoader
73 | Returns:
74 | DataLoader: A PyTorch dataloader.
75 | """
76 | rank, world_size = get_dist_info()
77 | if dist:
78 | sampler = DistributedSampler(
79 | dataset, world_size, rank, shuffle=shuffle, round_up=round_up)
80 | shuffle = False
81 | batch_size = samples_per_gpu
82 | num_workers = workers_per_gpu
83 | else:
84 | sampler = None
85 | batch_size = num_gpus * samples_per_gpu
86 | num_workers = num_gpus * workers_per_gpu
87 |
88 | init_fn = partial(
89 | worker_init_fn, num_workers=num_workers, rank=rank,
90 | seed=seed) if seed is not None else None
91 |
92 | data_loader = DataLoader(
93 | dataset,
94 | batch_size=batch_size,
95 | sampler=sampler,
96 | num_workers=num_workers,
97 | collate_fn=partial(collate, samples_per_gpu=samples_per_gpu),
98 | pin_memory=False,
99 | shuffle=shuffle,
100 | worker_init_fn=init_fn,
101 | persistent_workers=persistent_workers,
102 | **kwargs)
103 |
104 | return data_loader
105 |
106 |
107 | def worker_init_fn(worker_id: int, num_workers: int, rank: int, seed: int):
108 | """Init random seed for each worker."""
109 | # The seed of each worker equals to
110 | # num_worker * rank + worker_id + user_seed
111 | worker_seed = num_workers * rank + worker_id + seed
112 | np.random.seed(worker_seed)
113 | random.seed(worker_seed)
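
A hedged single-process usage sketch (the tiny tensor dataset, batch size and seed are arbitrary); with dist=False no DistributedSampler is created and each worker is seeded as num_workers * rank + worker_id + seed:

    import torch
    from torch.utils.data import TensorDataset
    from mogen.datasets import build_dataloader

    dataset = TensorDataset(torch.randn(10, 4))   # any torch Dataset works here
    loader = build_dataloader(dataset,
                              samples_per_gpu=4,
                              workers_per_gpu=2,
                              dist=False,
                              shuffle=True,
                              seed=42)
    batch = next(iter(loader))                    # collated by mmcv's collate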
--------------------------------------------------------------------------------
/mogen/datasets/dataset_wrappers.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data.dataset import ConcatDataset as _ConcatDataset
2 | from torch.utils.data.dataset import Dataset
3 |
4 | from .builder import DATASETS
5 |
6 |
7 | @DATASETS.register_module()
8 | class ConcatDataset(_ConcatDataset):
9 | """A wrapper of concatenated dataset.
10 | Same as :obj:`torch.utils.data.dataset.ConcatDataset`, but
11 | add `get_cat_ids` function.
12 | Args:
13 | datasets (list[:obj:`Dataset`]): A list of datasets.
14 | """
15 |
16 | def __init__(self, datasets: list):
17 | super(ConcatDataset, self).__init__(datasets)
18 |
19 |
20 | @DATASETS.register_module()
21 | class RepeatDataset(object):
22 | """A wrapper of repeated dataset.
23 | The length of repeated dataset will be `times` larger than the original
24 | dataset. This is useful when the data loading time is long but the dataset
25 | is small. Using RepeatDataset can reduce the data loading time between
26 | epochs.
27 | Args:
28 | dataset (:obj:`Dataset`): The dataset to be repeated.
29 | times (int): Repeat times.
30 | """
31 |
32 | def __init__(self, dataset: Dataset, times: int):
33 | self.dataset = dataset
34 | self.times = times
35 |
36 | self._ori_len = len(self.dataset)
37 |
38 | def __getitem__(self, idx: int):
39 | return self.dataset[idx % self._ori_len]
40 |
41 | def __len__(self):
42 | return self.times * self._ori_len
--------------------------------------------------------------------------------
/mogen/datasets/pipelines/__init__.py:
--------------------------------------------------------------------------------
1 | from .compose import Compose
2 | from .formatting import (
3 | to_tensor,
4 | ToTensor,
5 | Transpose,
6 | Collect,
7 | WrapFieldsToLists
8 | )
9 | from .transforms import (
10 | Crop,
11 | RandomCrop,
12 | Normalize
13 | )
14 |
15 | __all__ = [
16 | 'Compose', 'to_tensor', 'Transpose', 'Collect', 'WrapFieldsToLists', 'ToTensor',
17 | 'Crop', 'RandomCrop', 'Normalize'
18 | ]
--------------------------------------------------------------------------------
/mogen/datasets/pipelines/compose.py:
--------------------------------------------------------------------------------
1 | from collections.abc import Sequence
2 |
3 | from mmcv.utils import build_from_cfg
4 |
5 | from ..builder import PIPELINES
6 |
7 |
8 | @PIPELINES.register_module()
9 | class Compose(object):
10 | """Compose a data pipeline with a sequence of transforms.
11 |
12 | Args:
13 | transforms (list[dict | callable]):
14 | Either config dicts of transforms or transform objects.
15 | """
16 |
17 | def __init__(self, transforms):
18 | assert isinstance(transforms, Sequence)
19 | self.transforms = []
20 | for transform in transforms:
21 | if isinstance(transform, dict):
22 | transform = build_from_cfg(transform, PIPELINES)
23 | self.transforms.append(transform)
24 | elif callable(transform):
25 | self.transforms.append(transform)
26 | else:
27 | raise TypeError('transform must be callable or a dict, but got'
28 | f' {type(transform)}')
29 |
30 | def __call__(self, data):
31 | for t in self.transforms:
32 | data = t(data)
33 | if data is None:
34 | return None
35 | return data
36 |
37 | def __repr__(self):
38 | format_string = self.__class__.__name__ + '('
39 | for t in self.transforms:
40 | format_string += f'\n {t}'
41 | format_string += '\n)'
42 | return format_string
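
A hedged end-to-end pipeline example using the transforms defined in this package; the crop size and feature dimension are placeholders:

    import numpy as np
    from mogen.datasets.pipelines import Compose

    pipeline = Compose([
        dict(type='Crop', crop_size=196),
        dict(type='ToTensor', keys=['motion', 'motion_mask']),
        dict(type='Collect', keys=['motion', 'motion_length', 'motion_mask'],
             meta_keys=['motion_shape']),
    ])
    data = pipeline({'motion': np.random.randn(120, 263).astype(np.float32)})
    # data['motion'] is a (196, 263) tensor zero-padded after frame 120,
    # data['motion_mask'] marks the 120 valid frames, and the collected meta
    # information sits in data['motion_metas'].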
--------------------------------------------------------------------------------
/mogen/datasets/pipelines/formatting.py:
--------------------------------------------------------------------------------
1 | from collections.abc import Sequence
2 |
3 | import mmcv
4 | import numpy as np
5 | import torch
6 | from mmcv.parallel import DataContainer as DC
7 | from PIL import Image
8 |
9 | from ..builder import PIPELINES
10 |
11 |
12 | def to_tensor(data):
13 | """Convert objects of various python types to :obj:`torch.Tensor`.
14 |
15 | Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,
16 | :class:`Sequence`, :class:`int` and :class:`float`.
17 | """
18 | if isinstance(data, torch.Tensor):
19 | return data
20 | elif isinstance(data, np.ndarray):
21 | return torch.from_numpy(data)
22 | elif isinstance(data, Sequence) and not mmcv.is_str(data):
23 | return torch.tensor(data)
24 | elif isinstance(data, int):
25 | return torch.LongTensor([data])
26 | elif isinstance(data, float):
27 | return torch.FloatTensor([data])
28 | else:
29 | raise TypeError(
30 | f'Type {type(data)} cannot be converted to tensor.'
31 | 'Supported types are: `numpy.ndarray`, `torch.Tensor`, '
32 | '`Sequence`, `int` and `float`')
33 |
34 |
35 | @PIPELINES.register_module()
36 | class ToTensor(object):
37 |
38 | def __init__(self, keys):
39 | self.keys = keys
40 |
41 | def __call__(self, results):
42 | for key in self.keys:
43 | results[key] = to_tensor(results[key])
44 | return results
45 |
46 | def __repr__(self):
47 | return self.__class__.__name__ + f'(keys={self.keys})'
48 |
49 |
50 | @PIPELINES.register_module()
51 | class Transpose(object):
52 |
53 | def __init__(self, keys, order):
54 | self.keys = keys
55 | self.order = order
56 |
57 | def __call__(self, results):
58 | for key in self.keys:
59 | results[key] = results[key].transpose(self.order)
60 | return results
61 |
62 | def __repr__(self):
63 | return self.__class__.__name__ + \
64 | f'(keys={self.keys}, order={self.order})'
65 |
66 |
67 | @PIPELINES.register_module()
68 | class Collect(object):
69 | """Collect data from the loader relevant to the specific task.
70 |
71 | This is usually the last stage of the data loader pipeline.
72 |
73 | Args:
74 | keys (Sequence[str]): Keys of results to be collected in ``data``.
75 | meta_keys (Sequence[str], optional): Meta keys to be converted to
76 | ``mmcv.DataContainer`` and collected in ``data[motion_metas]``.
77 | Default: ``('filename', 'ori_filename', 'ori_shape', 'motion_shape', 'motion_mask')``
78 |
79 | Returns:
80 | dict: The result dict contains the following keys
81 |             - keys in ``self.keys``
82 | - ``motion_metas`` if available
83 | """
84 |
85 | def __init__(self,
86 | keys,
87 | meta_keys=('filename', 'ori_filename', 'ori_shape', 'motion_shape', 'motion_mask')):
88 | self.keys = keys
89 | self.meta_keys = meta_keys
90 |
91 | def __call__(self, results):
92 | data = {}
93 | motion_meta = {}
94 | for key in self.meta_keys:
95 | if key in results:
96 | motion_meta[key] = results[key]
97 | data['motion_metas'] = DC(motion_meta, cpu_only=True)
98 | for key in self.keys:
99 | data[key] = results[key]
100 | return data
101 |
102 | def __repr__(self):
103 | return self.__class__.__name__ + \
104 | f'(keys={self.keys}, meta_keys={self.meta_keys})'
105 |
106 |
107 | @PIPELINES.register_module()
108 | class WrapFieldsToLists(object):
109 | """Wrap fields of the data dictionary into lists for evaluation.
110 |
111 | This class can be used as a last step of a test or validation
112 | pipeline for single image evaluation or inference.
113 |
114 | Example:
115 | >>> test_pipeline = [
116 | >>> dict(type='LoadImageFromFile'),
117 | >>> dict(type='Normalize',
118 | mean=[123.675, 116.28, 103.53],
119 | std=[58.395, 57.12, 57.375],
120 | to_rgb=True),
121 | >>> dict(type='ImageToTensor', keys=['img']),
122 | >>> dict(type='Collect', keys=['img']),
123 |     >>>     dict(type='WrapFieldsToLists')
124 | >>> ]
125 | """
126 |
127 | def __call__(self, results):
128 | # Wrap dict fields into lists
129 | for key, val in results.items():
130 | results[key] = [val]
131 | return results
132 |
133 | def __repr__(self):
134 | return f'{self.__class__.__name__}()'
--------------------------------------------------------------------------------
/mogen/datasets/pipelines/transforms.py:
--------------------------------------------------------------------------------
1 | import math
2 | import random
3 |
4 | import mmcv
5 | import numpy as np
6 |
7 | from ..builder import PIPELINES
8 | import torch
9 | from typing import Optional, Tuple, Union
10 |
11 |
12 | @PIPELINES.register_module()
13 | class Crop(object):
14 | r"""Crop motion sequences.
15 |
16 | Args:
17 | crop_size (int): The size of the cropped motion sequence.
18 | """
19 | def __init__(self,
20 | crop_size: Optional[Union[int, None]] = None):
21 | self.crop_size = crop_size
22 | assert self.crop_size is not None
23 |
24 | def __call__(self, results):
25 | motion = results['motion']
26 | length = len(motion)
27 | if length >= self.crop_size:
28 | idx = random.randint(0, length - self.crop_size)
29 | motion = motion[idx: idx + self.crop_size]
30 | results['motion_length'] = self.crop_size
31 | else:
32 | padding_length = self.crop_size - length
33 | D = motion.shape[1:]
34 | padding_zeros = np.zeros((padding_length, *D), dtype=np.float32)
35 | motion = np.concatenate([motion, padding_zeros], axis=0)
36 | results['motion_length'] = length
37 | assert len(motion) == self.crop_size
38 | results['motion'] = motion
39 | results['motion_shape'] = motion.shape
40 | if length >= self.crop_size:
41 | results['motion_mask'] = torch.ones(self.crop_size).numpy()
42 | else:
43 | results['motion_mask'] = torch.cat(
44 | (torch.ones(length), torch.zeros(self.crop_size - length))).numpy()
45 | return results
46 |
47 |
48 | def __repr__(self):
49 | repr_str = self.__class__.__name__ + f'(crop_size={self.crop_size})'
50 | return repr_str
51 |
52 | @PIPELINES.register_module()
53 | class RandomCrop(object):
54 | r"""Random crop motion sequences. Each sequence will be padded with zeros to the maximum length.
55 |
56 | Args:
57 | min_size (int or None): The minimum size of the cropped motion sequence (inclusive).
58 | max_size (int or None): The maximum size of the cropped motion sequence (inclusive).
59 | """
60 | def __init__(self,
61 | min_size: Optional[Union[int, None]] = None,
62 | max_size: Optional[Union[int, None]] = None):
63 | self.min_size = min_size
64 | self.max_size = max_size
65 | assert self.min_size is not None
66 | assert self.max_size is not None
67 |
68 | def __call__(self, results):
69 | motion = results['motion']
70 | length = len(motion)
71 | crop_size = random.randint(self.min_size, self.max_size)
72 | if length > crop_size:
73 | idx = random.randint(0, length - crop_size)
74 | motion = motion[idx: idx + crop_size]
75 | results['motion_length'] = crop_size
76 | else:
77 | results['motion_length'] = length
78 | padding_length = self.max_size - min(crop_size, length)
79 | if padding_length > 0:
80 | D = motion.shape[1:]
81 | padding_zeros = np.zeros((padding_length, *D), dtype=np.float32)
82 | motion = np.concatenate([motion, padding_zeros], axis=0)
83 | results['motion'] = motion
84 | results['motion_shape'] = motion.shape
85 | if length >= self.max_size and crop_size == self.max_size:
86 | results['motion_mask'] = torch.ones(self.max_size).numpy()
87 | else:
88 | results['motion_mask'] = torch.cat((
89 | torch.ones(min(length, crop_size)),
90 | torch.zeros(self.max_size - min(length, crop_size))), dim=0).numpy()
91 | assert len(motion) == self.max_size
92 | return results
93 |
94 |
95 | def __repr__(self):
96 | repr_str = self.__class__.__name__ + f'(min_size={self.min_size}'
97 | repr_str += f', max_size={self.max_size})'
98 | return repr_str
99 |
100 | @PIPELINES.register_module()
101 | class Normalize(object):
102 | """Normalize motion sequences.
103 |
104 | Args:
105 | mean_path (str): Path of mean file.
106 | std_path (str): Path of std file.
107 | """
108 |
109 | def __init__(self, mean_path, std_path, eps=1e-9):
110 | self.mean = np.load(mean_path)
111 | self.std = np.load(std_path)
112 | self.eps = eps
113 |
114 | def __call__(self, results):
115 | motion = results['motion']
116 | motion = (motion - self.mean) / (self.std + self.eps)
117 | results['motion'] = motion
118 | results['motion_norm_mean'] = self.mean
119 | results['motion_norm_std'] = self.std
120 | return results
121 |
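
Normalize loads its statistics from disk, so a runnable sketch has to write them first; the temporary files and the 263-dim feature size are assumptions:

    import os
    import tempfile

    import numpy as np
    from mogen.datasets.pipelines import Normalize

    tmp = tempfile.mkdtemp()
    np.save(os.path.join(tmp, 'mean.npy'), np.zeros(263, dtype=np.float32))
    np.save(os.path.join(tmp, 'std.npy'), np.ones(263, dtype=np.float32))
    norm = Normalize(mean_path=os.path.join(tmp, 'mean.npy'),
                     std_path=os.path.join(tmp, 'std.npy'))
    out = norm({'motion': np.random.randn(60, 263).astype(np.float32)})
    # out['motion'] == (motion - mean) / (std + 1e-9); the statistics stay in
    # out['motion_norm_mean'] / out['motion_norm_std'] for later de-normalization.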
--------------------------------------------------------------------------------
/mogen/datasets/samplers/__init__.py:
--------------------------------------------------------------------------------
1 | from .distributed_sampler import DistributedSampler
2 |
3 | __all__ = ['DistributedSampler']
--------------------------------------------------------------------------------
/mogen/datasets/samplers/distributed_sampler.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import DistributedSampler as _DistributedSampler
3 |
4 |
5 | class DistributedSampler(_DistributedSampler):
6 |
7 | def __init__(self,
8 | dataset,
9 | num_replicas=None,
10 | rank=None,
11 | shuffle=True,
12 | round_up=True):
13 | super().__init__(dataset, num_replicas=num_replicas, rank=rank)
14 | self.shuffle = shuffle
15 | self.round_up = round_up
16 | if self.round_up:
17 | self.total_size = self.num_samples * self.num_replicas
18 | else:
19 | self.total_size = len(self.dataset)
20 |
21 | def __iter__(self):
22 | # deterministically shuffle based on epoch
23 | if self.shuffle:
24 | g = torch.Generator()
25 | g.manual_seed(self.epoch)
26 | indices = torch.randperm(len(self.dataset), generator=g).tolist()
27 | else:
28 | indices = torch.arange(len(self.dataset)).tolist()
29 |
30 | # add extra samples to make it evenly divisible
31 | if self.round_up:
32 | indices = (
33 | indices *
34 | int(self.total_size / len(indices) + 1))[:self.total_size]
35 | assert len(indices) == self.total_size
36 |
37 | # subsample
38 | indices = indices[self.rank:self.total_size:self.num_replicas]
39 | if self.round_up:
40 | assert len(indices) == self.num_samples
41 |
42 | return iter(indices)
43 |
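
The round_up branch repeats indices until they divide evenly across replicas, then every replica takes a strided slice. A small illustration (explicit num_replicas/rank, so no process group is needed):

    from mogen.datasets.samplers import DistributedSampler

    dataset = list(range(10))                  # only len() is used by the sampler
    sampler = DistributedSampler(dataset, num_replicas=4, rank=0, shuffle=False)
    print(sampler.total_size)                  # 12 = ceil(10 / 4) * 4
    print(list(iter(sampler)))                 # [0, 4, 8]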
--------------------------------------------------------------------------------
/mogen/datasets/text_motion_dataset.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import os.path
4 | from abc import ABCMeta
5 | from collections import OrderedDict
6 | from typing import Any, List, Optional, Union
7 |
8 | import mmcv
9 | import copy
10 | import numpy as np
11 | import torch
12 | import torch.distributed as dist
13 | from mmcv.runner import get_dist_info
14 |
15 | from .base_dataset import BaseMotionDataset
16 | from .builder import DATASETS
17 |
18 |
19 | @DATASETS.register_module()
20 | class TextMotionDataset(BaseMotionDataset):
21 | """TextMotion dataset.
22 |
23 | Args:
24 | text_dir (str): Path to the directory containing the text files.
25 | """
26 | def __init__(self,
27 | data_prefix: str,
28 | pipeline: list,
29 | dataset_name: Optional[Union[str, None]] = None,
30 | fixed_length: Optional[Union[int, None]] = None,
31 | ann_file: Optional[Union[str, None]] = None,
32 | motion_dir: Optional[Union[str, None]] = None,
33 | text_dir: Optional[Union[str, None]] = None,
34 | token_dir: Optional[Union[str, None]] = None,
35 | clip_feat_dir: Optional[Union[str, None]] = None,
36 | eval_cfg: Optional[Union[dict, None]] = None,
37 | fine_mode: Optional[bool] = False,
38 | test_mode: Optional[bool] = False):
39 | self.text_dir = os.path.join(data_prefix, 'datasets', dataset_name, text_dir)
40 | if token_dir is not None:
41 | self.token_dir = os.path.join(data_prefix, 'datasets', dataset_name, token_dir)
42 | else:
43 | self.token_dir = None
44 | if clip_feat_dir is not None:
45 | self.clip_feat_dir = os.path.join(data_prefix, 'datasets', dataset_name, clip_feat_dir)
46 | else:
47 | self.clip_feat_dir = None
48 | self.fine_mode = fine_mode
49 | super(TextMotionDataset, self).__init__(
50 | data_prefix=data_prefix,
51 | pipeline=pipeline,
52 | dataset_name=dataset_name,
53 | fixed_length=fixed_length,
54 | ann_file=ann_file,
55 | motion_dir=motion_dir,
56 | eval_cfg=eval_cfg,
57 | test_mode=test_mode)
58 |
59 | def load_anno(self, name):
60 | results = super().load_anno(name)
61 | text_path = os.path.join(self.text_dir, name + '.txt')
62 | text_data = []
63 | for line in open(text_path, 'r'):
64 | text_data.append(line.strip())
65 | results['text'] = text_data
66 | if self.token_dir is not None:
67 | token_path = os.path.join(self.token_dir, name + '.txt')
68 | token_data = []
69 | for line in open(token_path, 'r'):
70 | token_data.append(line.strip())
71 | results['token'] = token_data
72 | if self.clip_feat_dir is not None:
73 | clip_feat_path = os.path.join(self.clip_feat_dir, name + '.npy')
74 | clip_feat = torch.from_numpy(np.load(clip_feat_path))
75 | results['clip_feat'] = clip_feat
76 | return results
77 |
78 | def prepare_data(self, idx: int):
79 |         """Prepare raw data for the ``idx``-th data."""
80 | results = copy.deepcopy(self.data_infos[idx])
81 | text_list = results['text']
82 | idx = np.random.randint(0, len(text_list))
83 | if self.fine_mode:
84 | results['text'] = json.loads(text_list[idx])
85 | else:
86 | results['text'] = text_list[idx]
87 | if 'clip_feat' in results.keys():
88 | results['clip_feat'] = results['clip_feat'][idx]
89 | if 'token' in results.keys():
90 | results['token'] = results['token'][idx]
91 | results['dataset_name'] = self.dataset_name
92 | results['sample_idx'] = idx
93 | return self.pipeline(results)
94 |
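
The constructors above imply a fixed on-disk layout rooted at data_prefix/datasets/<dataset_name>/. A hedged sketch with placeholder directory and file names (only the 'datasets' level is hard-coded), assuming those files exist:

    from mogen.datasets import TextMotionDataset

    # data/datasets/human_ml3d/train.txt            one sample name per line
    # data/datasets/human_ml3d/motions/<name>.npy   motion arrays
    # data/datasets/human_ml3d/texts/<name>.txt     one caption per line
    dataset = TextMotionDataset(
        data_prefix='data',
        dataset_name='human_ml3d',
        ann_file='train.txt',
        motion_dir='motions',
        text_dir='texts',
        pipeline=[dict(type='Crop', crop_size=196),
                  dict(type='ToTensor', keys=['motion', 'motion_mask'])])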
--------------------------------------------------------------------------------
/mogen/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .architectures import *
2 | from .losses import *
3 | from .rnns import *
4 | from .transformers import *
5 | from .attentions import *
6 | from .builder import *
7 | from .utils import *
--------------------------------------------------------------------------------
/mogen/models/architectures/__init__.py:
--------------------------------------------------------------------------------
1 | from .vae_architecture import MotionVAE
2 | from .diffusion_architecture import MotionDiffusion
3 |
4 | __all__ = [
5 | 'MotionVAE', 'MotionDiffusion'
6 | ]
--------------------------------------------------------------------------------
/mogen/models/architectures/base_architecture.py:
--------------------------------------------------------------------------------
1 | from abc import ABCMeta, abstractmethod
2 | from collections import OrderedDict
3 |
4 | import torch
5 | import torch.distributed as dist
6 | from mmcv.runner import BaseModule
7 |
8 |
9 | def to_cpu(x):
10 | if isinstance(x, torch.Tensor):
11 | return x.detach().cpu()
12 | return x
13 |
14 |
15 | class BaseArchitecture(BaseModule):
16 | """Base class for mogen architecture."""
17 |
18 | def __init__(self, init_cfg=None):
19 | super(BaseArchitecture, self).__init__(init_cfg)
20 |
21 | def forward_train(self, **kwargs):
22 | pass
23 |
24 | def forward_test(self, **kwargs):
25 | pass
26 |
27 | def _parse_losses(self, losses):
28 | """Parse the raw outputs (losses) of the network.
29 | Args:
30 | losses (dict): Raw output of the network, which usually contain
31 | losses and other necessary information.
32 | Returns:
33 | tuple[Tensor, dict]: (loss, log_vars), loss is the loss tensor \
34 | which may be a weighted sum of all losses, log_vars contains \
35 | all the variables to be sent to the logger.
36 | """
37 | log_vars = OrderedDict()
38 | for loss_name, loss_value in losses.items():
39 | if isinstance(loss_value, torch.Tensor):
40 | log_vars[loss_name] = loss_value.mean()
41 | elif isinstance(loss_value, list):
42 | log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
43 | else:
44 | raise TypeError(
45 | f'{loss_name} is not a tensor or list of tensors')
46 |
47 | loss = sum(_value for _key, _value in log_vars.items()
48 | if 'loss' in _key)
49 |
50 | log_vars['loss'] = loss
51 | for loss_name, loss_value in log_vars.items():
52 | # reduce loss when distributed training
53 | if dist.is_available() and dist.is_initialized():
54 | loss_value = loss_value.data.clone()
55 | dist.all_reduce(loss_value.div_(dist.get_world_size()))
56 | log_vars[loss_name] = loss_value.item()
57 |
58 | return loss, log_vars
59 |
60 | def train_step(self, data, optimizer):
61 | """The iteration step during training.
62 | This method defines an iteration step during training, except for the
63 | back propagation and optimizer updating, which are done in an optimizer
64 | hook. Note that in some complicated cases or models, the whole process
65 | including back propagation and optimizer updating is also defined in
66 | this method, such as GAN.
67 | Args:
68 | data (dict): The output of dataloader.
69 | optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer of
70 | runner is passed to ``train_step()``. This argument is unused
71 | and reserved.
72 | Returns:
73 | dict: It should contain at least 3 keys: ``loss``, ``log_vars``, \
74 | ``num_samples``.
75 | - ``loss`` is a tensor for back propagation, which can be a
76 | weighted sum of multiple losses.
77 | - ``log_vars`` contains all the variables to be sent to the
78 | logger.
79 | - ``num_samples`` indicates the batch size (when the model is
80 | DDP, it means the batch size on each GPU), which is used for
81 | averaging the logs.
82 | """
83 | losses = self(**data)
84 | loss, log_vars = self._parse_losses(losses)
85 |
86 | outputs = dict(
87 | loss=loss, log_vars=log_vars, num_samples=len(data['motion']))
88 |
89 | return outputs
90 |
91 | def val_step(self, data, optimizer=None):
92 | """The iteration step during validation.
93 | This method shares the same signature as :func:`train_step`, but used
94 | during val epochs. Note that the evaluation after training epochs is
95 | not implemented with this method, but an evaluation hook.
96 | """
97 | losses = self(**data)
98 | loss, log_vars = self._parse_losses(losses)
99 |
100 | outputs = dict(
101 | loss=loss, log_vars=log_vars, num_samples=len(data['motion']))
102 |
103 | return outputs
104 |
105 | def forward(self, **kwargs):
106 | if self.training:
107 | return self.forward_train(**kwargs)
108 | else:
109 | return self.forward_test(**kwargs)
110 |
111 | def split_results(self, results):
112 | B = results['motion'].shape[0]
113 | output = []
114 | for i in range(B):
115 | batch_output = dict()
116 | batch_output['motion'] = to_cpu(results['motion'][i])
117 | batch_output['pred_motion'] = to_cpu(results['pred_motion'][i])
118 | batch_output['motion_length'] = to_cpu(results['motion_length'][i])
119 | batch_output['motion_mask'] = to_cpu(results['motion_mask'][i])
120 | if 'pred_motion_length' in results.keys():
121 | batch_output['pred_motion_length'] = to_cpu(results['pred_motion_length'][i])
122 | else:
123 | batch_output['pred_motion_length'] = to_cpu(results['motion_length'][i])
124 | if 'pred_motion_mask' in results:
125 | batch_output['pred_motion_mask'] = to_cpu(results['pred_motion_mask'][i])
126 | else:
127 | batch_output['pred_motion_mask'] = to_cpu(results['motion_mask'][i])
128 | if 'motion_metas' in results.keys():
129 | motion_metas = results['motion_metas'][i]
130 | if 'text' in motion_metas.keys():
131 | batch_output['text'] = motion_metas['text']
132 | if 'token' in motion_metas.keys():
133 | batch_output['token'] = motion_metas['token']
134 | output.append(batch_output)
135 | return output
136 |
--------------------------------------------------------------------------------
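A minimal sketch of how `_parse_losses` above reduces a loss dict (assuming `mogen` and its `mmcv` dependency are installed; the numbers are made up so the reduction is easy to follow):

import torch
from mogen.models.architectures.base_architecture import BaseArchitecture

arch = BaseArchitecture()
losses = {
    'recon_loss': torch.tensor([0.5, 0.75]),                   # tensor -> mean() = 0.625
    'kl_div_loss': [torch.tensor(0.25), torch.tensor(0.125)],  # list -> sum of means = 0.375
}
loss, log_vars = arch._parse_losses(losses)
# every key containing 'loss' is summed into the total: 0.625 + 0.375 = 1.0
print(loss)      # tensor(1.)
print(log_vars)  # OrderedDict of plain floats, ready for the logger
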
/mogen/models/architectures/diffusion_architecture.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from .base_architecture import BaseArchitecture
6 | from ..builder import (
7 | ARCHITECTURES,
8 | build_architecture,
9 | build_submodule,
10 | build_loss
11 | )
12 | from ..utils.gaussian_diffusion import (
13 | GaussianDiffusion, get_named_beta_schedule, create_named_schedule_sampler,
14 | ModelMeanType, ModelVarType, LossType, space_timesteps, SpacedDiffusion
15 | )
16 |
17 | def build_diffusion(cfg):
18 | beta_scheduler = cfg['beta_scheduler']
19 | diffusion_steps = cfg['diffusion_steps']
20 |
21 | betas = get_named_beta_schedule(beta_scheduler, diffusion_steps)
22 | model_mean_type = {
23 | 'start_x': ModelMeanType.START_X,
24 | 'previous_x': ModelMeanType.PREVIOUS_X,
25 | 'epsilon': ModelMeanType.EPSILON
26 | }[cfg['model_mean_type']]
27 | model_var_type = {
28 | 'learned': ModelVarType.LEARNED,
29 | 'fixed_small': ModelVarType.FIXED_SMALL,
30 | 'fixed_large': ModelVarType.FIXED_LARGE,
31 | 'learned_range': ModelVarType.LEARNED_RANGE
32 | }[cfg['model_var_type']]
33 | if cfg.get('respace', None) is not None:
34 | diffusion = SpacedDiffusion(
35 | use_timesteps=space_timesteps(diffusion_steps, cfg['respace']),
36 | betas=betas,
37 | model_mean_type=model_mean_type,
38 | model_var_type=model_var_type,
39 | loss_type=LossType.MSE
40 | )
41 | else:
42 | diffusion = GaussianDiffusion(
43 | betas=betas,
44 | model_mean_type=model_mean_type,
45 | model_var_type=model_var_type,
46 | loss_type=LossType.MSE)
47 | return diffusion
48 |
49 |
50 | @ARCHITECTURES.register_module()
51 | class MotionDiffusion(BaseArchitecture):
52 |
53 | def __init__(self,
54 | model=None,
55 | loss_recon=None,
56 | diffusion_train=None,
57 | diffusion_test=None,
58 | init_cfg=None,
59 | inference_type='ddpm',
60 | **kwargs):
61 | super().__init__(init_cfg=init_cfg, **kwargs)
62 | self.model = build_submodule(model)
63 | self.loss_recon = build_loss(loss_recon)
64 | self.diffusion_train = build_diffusion(diffusion_train)
65 | self.diffusion_test = build_diffusion(diffusion_test)
66 | self.sampler = create_named_schedule_sampler('uniform', self.diffusion_train)
67 | self.inference_type = inference_type
68 |
69 | def forward(self, **kwargs):
70 | motion, motion_mask = kwargs['motion'].float(), kwargs['motion_mask'].float()
71 | sample_idx = kwargs.get('sample_idx', None)
72 | clip_feat = kwargs.get('clip_feat', None)
73 | B, T = motion.shape[:2]
74 | text = []
75 | for i in range(B):
76 | text.append(kwargs['motion_metas'][i]['text'])
77 |
78 | if self.training:
79 | t, _ = self.sampler.sample(B, motion.device)
80 | output = self.diffusion_train.training_losses(
81 | model=self.model,
82 | x_start=motion,
83 | t=t,
84 | model_kwargs={
85 | 'motion_mask': motion_mask,
86 | 'motion_length': kwargs['motion_length'],
87 | 'text': text,
88 | 'clip_feat': clip_feat,
89 | 'sample_idx': sample_idx}
90 | )
91 | pred, target = output['pred'], output['target']
92 | recon_loss = self.loss_recon(pred, target, reduction_override='none')
93 | recon_loss = (recon_loss.mean(dim=-1) * motion_mask).sum() / motion_mask.sum()
94 | loss = {'recon_loss': recon_loss}
95 | return loss
96 | else:
97 | dim_pose = kwargs['motion'].shape[-1]
98 | model_kwargs = self.model.get_precompute_condition(device=motion.device, text=text, **kwargs)
99 | model_kwargs['motion_mask'] = motion_mask
100 | model_kwargs['sample_idx'] = sample_idx
101 | inference_kwargs = kwargs.get('inference_kwargs', {})
102 | if self.inference_type == 'ddpm':
103 | output = self.diffusion_test.p_sample_loop(
104 | self.model,
105 | (B, T, dim_pose),
106 | clip_denoised=False,
107 | progress=False,
108 | model_kwargs=model_kwargs,
109 | **inference_kwargs
110 | )
111 | else:
112 | output = self.diffusion_test.ddim_sample_loop(
113 | self.model,
114 | (B, T, dim_pose),
115 | clip_denoised=False,
116 | progress=False,
117 | model_kwargs=model_kwargs,
118 | eta=0,
119 | **inference_kwargs
120 | )
121 |             if getattr(self.model, "post_process", None) is not None:
122 | output = self.model.post_process(output)
123 | results = kwargs
124 | results['pred_motion'] = output
125 | results = self.split_results(results)
126 | return results
127 |
128 |
--------------------------------------------------------------------------------
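For reference, a hedged example of the config dict that `build_diffusion` above expects; the keys come straight from the function body, while the concrete values are only illustrative (the `respace` entry is optional and switches to `SpacedDiffusion`):

diffusion_cfg = dict(
    beta_scheduler='linear',       # name passed to get_named_beta_schedule
    diffusion_steps=1000,
    model_mean_type='start_x',     # 'start_x' | 'previous_x' | 'epsilon'
    model_var_type='fixed_large',  # 'learned' | 'fixed_small' | 'fixed_large' | 'learned_range'
    respace='15,15,8,6,6',         # optional respacing spec for faster sampling
)

Two such dicts are passed to `MotionDiffusion` as `diffusion_train` and `diffusion_test`.
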
/mogen/models/architectures/vae_architecture.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from .base_architecture import BaseArchitecture
6 | from ..builder import (
7 | ARCHITECTURES,
8 | build_architecture,
9 | build_submodule,
10 | build_loss
11 | )
12 |
13 |
14 | @ARCHITECTURES.register_module()
15 | class PoseVAE(BaseArchitecture):
16 |
17 | def __init__(self,
18 | encoder=None,
19 | decoder=None,
20 | loss_recon=None,
21 | kl_div_loss_weight=None,
22 | init_cfg=None,
23 | **kwargs):
24 | super().__init__(init_cfg=init_cfg, **kwargs)
25 | self.encoder = build_submodule(encoder)
26 | self.decoder = build_submodule(decoder)
27 | self.loss_recon = build_loss(loss_recon)
28 | self.kl_div_loss_weight = kl_div_loss_weight
29 |
30 | def reparameterize(self, mu, logvar):
31 | std = torch.exp(logvar / 2)
32 |
33 | eps = std.data.new(std.size()).normal_()
34 | latent_code = eps.mul(std).add_(mu)
35 | return latent_code
36 |
37 | def encode(self, pose):
38 | mu, logvar = self.encoder(pose)
39 | return mu
40 |
41 | def forward(self, **kwargs):
42 | motion = kwargs['motion'].float()
43 | B, T = motion.shape[:2]
44 | pose = motion.reshape(B * T, -1)
45 | pose = pose[:, :-4]
46 |
47 | mu, logvar = self.encoder(pose)
48 | z = self.reparameterize(mu, logvar)
49 | pred = self.decoder(z)
50 |
51 | loss = dict()
52 | recon_loss = self.loss_recon(pred, pose, reduction_override='none')
53 | loss['recon_loss'] = recon_loss
54 | if self.kl_div_loss_weight is not None:
55 | loss_kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
56 | loss['kl_div_loss'] = (loss_kl * self.kl_div_loss_weight)
57 |
58 | return loss
59 |
60 |
61 | @ARCHITECTURES.register_module()
62 | class MotionVAE(BaseArchitecture):
63 |
64 | def __init__(self,
65 | encoder=None,
66 | decoder=None,
67 | loss_recon=None,
68 | kl_div_loss_weight=None,
69 | init_cfg=None,
70 | **kwargs):
71 | super().__init__(init_cfg=init_cfg, **kwargs)
72 | self.encoder = build_submodule(encoder)
73 | self.decoder = build_submodule(decoder)
74 | self.loss_recon = build_loss(loss_recon)
75 | self.kl_div_loss_weight = kl_div_loss_weight
76 |
77 | def sample(self, std=1, latent_code=None):
78 | if latent_code is not None:
79 | z = latent_code
80 | else:
81 | z = torch.randn(1, 7, self.decoder.latent_dim).cuda() * std
82 | output = self.decoder(z)
83 |         if getattr(self, 'use_normalization', False):
84 | output = output * self.motion_std
85 | output = output + self.motion_mean
86 | return output
87 |
88 | def reparameterize(self, mu, logvar):
89 | std = torch.exp(logvar / 2)
90 |
91 | eps = std.data.new(std.size()).normal_()
92 | latent_code = eps.mul(std).add_(mu)
93 | return latent_code
94 |
95 | def encode(self, motion, motion_mask):
96 | mu, logvar = self.encoder(motion, motion_mask)
97 | return self.reparameterize(mu, logvar)
98 |
99 | def decode(self, z, motion_mask):
100 | return self.decoder(z, motion_mask)
101 |
102 | def forward(self, **kwargs):
103 | motion, motion_mask = kwargs['motion'].float(), kwargs['motion_mask']
104 | B, T = motion.shape[:2]
105 |
106 | mu, logvar = self.encoder(motion, motion_mask)
107 | z = self.reparameterize(mu, logvar)
108 | pred = self.decoder(z, motion_mask)
109 |
110 | loss = dict()
111 | recon_loss = self.loss_recon(pred, motion, reduction_override='none')
112 | recon_loss = (recon_loss.mean(dim=-1) * motion_mask).sum() / motion_mask.sum()
113 | loss['recon_loss'] = recon_loss
114 | if self.kl_div_loss_weight is not None:
115 | loss_kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
116 | loss['kl_div_loss'] = (loss_kl * self.kl_div_loss_weight)
117 |
118 | return loss
--------------------------------------------------------------------------------
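A short sketch of the reparameterization trick used by both VAEs above: drawing z = mu + eps * std keeps the sample differentiable with respect to mu and logvar (standalone PyTorch, no repo code needed):

import torch

mu = torch.zeros(2, 4, requires_grad=True)
logvar = torch.zeros(2, 4, requires_grad=True)

std = torch.exp(logvar / 2)              # same as the reparameterize() methods above
eps = torch.randn_like(std)              # noise independent of the parameters
z = mu + eps * std                       # equivalent to eps.mul(std).add_(mu)
z.sum().backward()                       # gradients flow back into mu and logvar
print(mu.grad.shape, logvar.grad.shape)  # torch.Size([2, 4]) torch.Size([2, 4])
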
/mogen/models/attentions/__init__.py:
--------------------------------------------------------------------------------
1 | from .efficient_attention import (
2 | EfficientSelfAttention,
3 | EfficientCrossAttention
4 | )
5 | from .semantics_modulated import SemanticsModulatedAttention
6 | from .base_attention import BaseMixedAttention
--------------------------------------------------------------------------------
/mogen/models/attentions/base_attention.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from ..utils.stylization_block import StylizationBlock
5 | from ..builder import ATTENTIONS
6 |
7 |
8 | @ATTENTIONS.register_module()
9 | class BaseMixedAttention(nn.Module):
10 |
11 | def __init__(self, latent_dim,
12 | text_latent_dim,
13 | num_heads,
14 | dropout,
15 | time_embed_dim):
16 | super().__init__()
17 | self.num_heads = num_heads
18 |
19 | self.norm = nn.LayerNorm(latent_dim)
20 | self.text_norm = nn.LayerNorm(text_latent_dim)
21 |
22 | self.query = nn.Linear(latent_dim, latent_dim)
23 | self.key_text = nn.Linear(text_latent_dim, latent_dim)
24 | self.value_text = nn.Linear(text_latent_dim, latent_dim)
25 | self.key_motion = nn.Linear(latent_dim, latent_dim)
26 | self.value_motion = nn.Linear(latent_dim, latent_dim)
27 |
28 | self.dropout = nn.Dropout(dropout)
29 | self.proj_out = StylizationBlock(latent_dim, time_embed_dim, dropout)
30 |
31 | def forward(self, x, xf, emb, src_mask, cond_type, **kwargs):
32 | """
33 | x: B, T, D
34 | xf: B, N, L
35 | """
36 | B, T, D = x.shape
37 | N = xf.shape[1] + x.shape[1]
38 | H = self.num_heads
39 | # B, T, D
40 | query = self.query(self.norm(x)).view(B, T, H, -1)
41 | # B, N, D
42 | text_cond_type = ((cond_type % 10) > 0).float().view(B, 1, 1).repeat(1, xf.shape[1], 1)
43 | key = torch.cat((
44 | self.key_text(self.text_norm(xf)),
45 | self.key_motion(self.norm(x))
46 | ), dim=1).view(B, N, H, -1)
47 |
48 | attention = torch.einsum('bnhl,bmhl->bnmh', query, key)
49 | motion_mask = src_mask.view(B, 1, T, 1)
50 | text_mask = text_cond_type.view(B, 1, -1, 1)
51 | mask = torch.cat((text_mask, motion_mask), dim=2)
52 | attention = attention + (1 - mask) * -1000000
53 | attention = F.softmax(attention, dim=2)
54 |
55 | value = torch.cat((
56 | self.value_text(self.text_norm(xf)) * text_cond_type,
57 | self.value_motion(self.norm(x)) * src_mask,
58 | ), dim=1).view(B, N, H, -1)
59 |
60 | y = torch.einsum('bnmh,bmhl->bnhl', attention, value).reshape(B, T, D)
61 | y = x + self.proj_out(y, emb)
62 | return y
63 |
64 |
65 | @ATTENTIONS.register_module()
66 | class BaseSelfAttention(nn.Module):
67 |
68 | def __init__(self, latent_dim,
69 | num_heads,
70 | dropout,
71 | time_embed_dim):
72 | super().__init__()
73 | self.num_heads = num_heads
74 |
75 | self.norm = nn.LayerNorm(latent_dim)
76 | self.query = nn.Linear(latent_dim, latent_dim)
77 | self.key = nn.Linear(latent_dim, latent_dim)
78 | self.value = nn.Linear(latent_dim, latent_dim)
79 |
80 | self.dropout = nn.Dropout(dropout)
81 | self.proj_out = StylizationBlock(latent_dim, time_embed_dim, dropout)
82 |
83 | def forward(self, x, emb, src_mask, **kwargs):
84 | """
85 | x: B, T, D
86 | """
87 | B, T, D = x.shape
88 | H = self.num_heads
89 | # B, T, D
90 | query = self.query(self.norm(x)).view(B, T, H, -1)
91 | # B, N, D
92 | key = self.key(self.norm(x)).view(B, T, H, -1)
93 |
94 | attention = torch.einsum('bnhl,bmhl->bnmh', query, key)
95 | mask = src_mask.view(B, 1, T, 1)
96 | attention = attention + (1 - mask) * -1000000
97 | attention = F.softmax(attention, dim=2)
98 | value = (self.value(self.norm(x)) * src_mask).view(B, T, H, -1)
99 | y = torch.einsum('bnmh,bmhl->bnhl', attention, value).reshape(B, T, D)
100 | y = x + self.proj_out(y, emb)
101 | return y
102 |
103 |
104 | @ATTENTIONS.register_module()
105 | class BaseCrossAttention(nn.Module):
106 |
107 | def __init__(self, latent_dim,
108 | text_latent_dim,
109 | num_heads,
110 | dropout,
111 | time_embed_dim):
112 | super().__init__()
113 | self.num_heads = num_heads
114 |
115 | self.norm = nn.LayerNorm(latent_dim)
116 | self.text_norm = nn.LayerNorm(text_latent_dim)
117 |
118 | self.query = nn.Linear(latent_dim, latent_dim)
119 | self.key = nn.Linear(text_latent_dim, latent_dim)
120 | self.value = nn.Linear(text_latent_dim, latent_dim)
121 |
122 | self.dropout = nn.Dropout(dropout)
123 | self.proj_out = StylizationBlock(latent_dim, time_embed_dim, dropout)
124 |
125 | def forward(self, x, xf, emb, src_mask, cond_type, **kwargs):
126 | """
127 | x: B, T, D
128 | xf: B, N, L
129 | """
130 | B, T, D = x.shape
131 | N = xf.shape[1]
132 | H = self.num_heads
133 | # B, T, D
134 | query = self.query(self.norm(x)).view(B, T, H, -1)
135 | # B, N, D
136 | text_cond_type = ((cond_type % 10) > 0).float().view(B, 1, 1).repeat(1, xf.shape[1], 1)
137 | key = self.key(self.text_norm(xf)).view(B, N, H, -1)
138 | attention = torch.einsum('bnhl,bmhl->bnmh', query, key)
139 | mask = text_cond_type.view(B, 1, -1, 1)
140 | attention = attention + (1 - mask) * -1000000
141 | attention = F.softmax(attention, dim=2)
142 |
143 | value = (self.value(self.text_norm(xf)) * text_cond_type).view(B, N, H, -1)
144 | y = torch.einsum('bnmh,bmhl->bnhl', attention, value).reshape(B, T, D)
145 | y = x + self.proj_out(y, emb)
146 | return y
147 |
--------------------------------------------------------------------------------
/mogen/models/attentions/efficient_attention.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from ..utils.stylization_block import StylizationBlock
5 | from ..builder import ATTENTIONS
6 |
7 |
8 | @ATTENTIONS.register_module()
9 | class EfficientSelfAttention(nn.Module):
10 |
11 | def __init__(self, latent_dim, num_heads, dropout, time_embed_dim=None):
12 | super().__init__()
13 | self.num_heads = num_heads
14 | self.norm = nn.LayerNorm(latent_dim)
15 | self.query = nn.Linear(latent_dim, latent_dim)
16 | self.key = nn.Linear(latent_dim, latent_dim)
17 | self.value = nn.Linear(latent_dim, latent_dim)
18 | self.dropout = nn.Dropout(dropout)
19 | self.time_embed_dim = time_embed_dim
20 | if time_embed_dim is not None:
21 | self.proj_out = StylizationBlock(latent_dim, time_embed_dim, dropout)
22 |
23 | def forward(self, x, src_mask, emb=None, **kwargs):
24 | """
25 | x: B, T, D
26 | """
27 | B, T, D = x.shape
28 | H = self.num_heads
29 | # B, T, D
30 | query = self.query(self.norm(x))
31 | # B, T, D
32 | key = (self.key(self.norm(x)) + (1 - src_mask) * -1000000)
33 | query = F.softmax(query.view(B, T, H, -1), dim=-1)
34 | key = F.softmax(key.view(B, T, H, -1), dim=1)
35 | # B, T, H, HD
36 | value = (self.value(self.norm(x)) * src_mask).view(B, T, H, -1)
37 | # B, H, HD, HD
38 | attention = torch.einsum('bnhd,bnhl->bhdl', key, value)
39 | y = torch.einsum('bnhd,bhdl->bnhl', query, attention).reshape(B, T, D)
40 | if self.time_embed_dim is None:
41 | y = x + y
42 | else:
43 | y = x + self.proj_out(y, emb)
44 | return y
45 |
46 |
47 | @ATTENTIONS.register_module()
48 | class EfficientCrossAttention(nn.Module):
49 |
50 | def __init__(self, latent_dim, text_latent_dim, num_heads, dropout, time_embed_dim):
51 | super().__init__()
52 | self.num_heads = num_heads
53 | self.norm = nn.LayerNorm(latent_dim)
54 | self.text_norm = nn.LayerNorm(text_latent_dim)
55 | self.query = nn.Linear(latent_dim, latent_dim)
56 | self.key = nn.Linear(text_latent_dim, latent_dim)
57 | self.value = nn.Linear(text_latent_dim, latent_dim)
58 | self.dropout = nn.Dropout(dropout)
59 | self.proj_out = StylizationBlock(latent_dim, time_embed_dim, dropout)
60 |
61 | def forward(self, x, xf, emb, cond_type=None, **kwargs):
62 | """
63 | x: B, T, D
64 | xf: B, N, L
65 | """
66 | B, T, D = x.shape
67 | N = xf.shape[1]
68 | H = self.num_heads
69 | # B, T, D
70 | query = self.query(self.norm(x))
71 | # B, N, D
72 | key = self.key(self.text_norm(xf))
73 | query = F.softmax(query.view(B, T, H, -1), dim=-1)
74 | if cond_type is None:
75 | key = F.softmax(key.view(B, N, H, -1), dim=1)
76 | # B, N, H, HD
77 | value = self.value(self.text_norm(xf)).view(B, N, H, -1)
78 | else:
79 | text_cond_type = ((cond_type % 10) > 0).float().view(B, 1, 1).repeat(1, xf.shape[1], 1)
80 | key = key + (1 - text_cond_type) * -1000000
81 | key = F.softmax(key.view(B, N, H, -1), dim=1)
82 | value = self.value(self.text_norm(xf) * text_cond_type).view(B, N, H, -1)
83 | # B, H, HD, HD
84 | attention = torch.einsum('bnhd,bnhl->bhdl', key, value)
85 | y = torch.einsum('bnhd,bhdl->bnhl', query, attention).reshape(B, T, D)
86 | y = x + self.proj_out(y, emb)
87 | return y
88 |
--------------------------------------------------------------------------------
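A minimal shape check for `EfficientSelfAttention` above, using the `time_embed_dim=None` branch so no time embedding is needed (assumes `mogen` and its dependencies are importable; sizes are illustrative):

import torch
from mogen.models.attentions.efficient_attention import EfficientSelfAttention

attn = EfficientSelfAttention(latent_dim=64, num_heads=4, dropout=0.1)
x = torch.randn(2, 16, 64)       # B, T, D
src_mask = torch.ones(2, 16, 1)  # 1 = valid frame, 0 = padding
y = attn(x, src_mask)
print(y.shape)                   # torch.Size([2, 16, 64])

The attention is linear in T: keys and values are contracted first ('bnhd,bnhl->bhdl'), so no T x T attention matrix is ever materialized.
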
/mogen/models/attentions/semantics_modulated.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from ..utils.stylization_block import StylizationBlock
5 | from ..builder import ATTENTIONS
6 |
7 |
8 | def zero_module(module):
9 | """
10 | Zero out the parameters of a module and return it.
11 | """
12 | for p in module.parameters():
13 | p.detach().zero_()
14 | return module
15 |
16 |
17 | @ATTENTIONS.register_module()
18 | class SemanticsModulatedAttention(nn.Module):
19 |
20 | def __init__(self, latent_dim,
21 | text_latent_dim,
22 | num_heads,
23 | dropout,
24 | time_embed_dim):
25 | super().__init__()
26 | self.num_heads = num_heads
27 |
28 | self.norm = nn.LayerNorm(latent_dim)
29 | self.text_norm = nn.LayerNorm(text_latent_dim)
30 |
31 | self.query = nn.Linear(latent_dim, latent_dim)
32 | self.key_text = nn.Linear(text_latent_dim, latent_dim)
33 | self.value_text = nn.Linear(text_latent_dim, latent_dim)
34 | self.key_motion = nn.Linear(latent_dim, latent_dim)
35 | self.value_motion = nn.Linear(latent_dim, latent_dim)
36 |
37 | self.retr_norm1 = nn.LayerNorm(2 * latent_dim)
38 | self.retr_norm2 = nn.LayerNorm(latent_dim)
39 | self.key_retr = nn.Linear(2 * latent_dim, latent_dim)
40 | self.value_retr = zero_module(nn.Linear(latent_dim, latent_dim))
41 |
42 | self.dropout = nn.Dropout(dropout)
43 | self.proj_out = StylizationBlock(latent_dim, time_embed_dim, dropout)
44 |
45 | def forward(self, x, xf, emb, src_mask, cond_type, re_dict=None):
46 | """
47 | x: B, T, D
48 | xf: B, N, L
49 | """
50 | B, T, D = x.shape
51 | re_motion = re_dict['re_motion']
52 | re_text = re_dict['re_text']
53 | re_mask = re_dict['re_mask']
54 | re_mask = re_mask.reshape(B, -1, 1)
55 | N = xf.shape[1] + x.shape[1] + re_motion.shape[1] * re_motion.shape[2]
56 | H = self.num_heads
57 | # B, T, D
58 | query = self.query(self.norm(x))
59 | # B, N, D
60 | text_cond_type = (cond_type % 10 > 0).float()
61 | retr_cond_type = (cond_type // 10 > 0).float()
62 | re_text = re_text.repeat(1, 1, re_motion.shape[2], 1)
63 | re_feat_key = torch.cat((re_motion, re_text), dim=-1).reshape(B, -1, 2 * D)
64 | key = torch.cat((
65 | self.key_text(self.text_norm(xf)) + (1 - text_cond_type) * -1000000,
66 | self.key_retr(self.retr_norm1(re_feat_key)) + (1 - retr_cond_type) * -1000000 + (1 - re_mask) * -1000000,
67 | self.key_motion(self.norm(x)) + (1 - src_mask) * -1000000
68 | ), dim=1)
69 | query = F.softmax(query.view(B, T, H, -1), dim=-1)
70 | key = F.softmax(key.view(B, N, H, -1), dim=1)
71 | # B, N, H, HD
72 | re_feat_value = re_motion.reshape(B, -1, D)
73 | value = torch.cat((
74 | self.value_text(self.text_norm(xf)) * text_cond_type,
75 | self.value_retr(self.retr_norm2(re_feat_value)) * retr_cond_type * re_mask,
76 | self.value_motion(self.norm(x)) * src_mask,
77 | ), dim=1).view(B, N, H, -1)
78 | # B, H, HD, HD
79 | attention = torch.einsum('bnhd,bnhl->bhdl', key, value)
80 | y = torch.einsum('bnhd,bhdl->bnhl', query, attention).reshape(B, T, D)
81 | y = x + self.proj_out(y, emb)
82 | return y
--------------------------------------------------------------------------------
/mogen/models/builder.py:
--------------------------------------------------------------------------------
1 | from mmcv.cnn import MODELS as MMCV_MODELS
2 | from mmcv.utils import Registry
3 |
4 |
5 | def build_from_cfg(cfg, registry, default_args=None):
6 | if cfg is None:
7 | return None
8 | return MMCV_MODELS.build_func(cfg, registry, default_args)
9 |
10 |
11 | MODELS = Registry('models', parent=MMCV_MODELS, build_func=build_from_cfg)
12 |
13 | LOSSES = MODELS
14 | ARCHITECTURES = MODELS
15 | SUBMODULES = MODELS
16 | ATTENTIONS = MODELS
17 |
18 | def build_loss(cfg):
19 | """Build loss."""
20 | return LOSSES.build(cfg)
21 |
22 | def build_architecture(cfg):
23 | """Build framework."""
24 | return ARCHITECTURES.build(cfg)
25 |
26 | def build_submodule(cfg):
27 | """Build submodule."""
28 | return SUBMODULES.build(cfg)
29 |
30 | def build_attention(cfg):
31 | """Build attention."""
32 | return ATTENTIONS.build(cfg)
33 |
--------------------------------------------------------------------------------
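A hedged sketch of how these registries are meant to be used: a config dict whose `type` field names a registered class is turned into an instance by the matching build function, and `None` configs pass through thanks to the custom `build_from_cfg` wrapper.

from mogen.models.builder import build_loss

assert build_loss(None) is None  # optional components can simply be omitted
loss = build_loss(dict(type='MSELoss', reduction='none', loss_weight=1.0))
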
/mogen/models/losses/__init__.py:
--------------------------------------------------------------------------------
1 | from .mse_loss import MSELoss
2 | from .utils import (
3 | convert_to_one_hot,
4 | reduce_loss,
5 | weight_reduce_loss,
6 | weighted_loss,
7 | )
8 |
9 |
10 | __all__ = [
11 | 'convert_to_one_hot', 'reduce_loss', 'weight_reduce_loss', 'weighted_loss',
12 | 'MSELoss'
13 | ]
--------------------------------------------------------------------------------
/mogen/models/losses/mse_loss.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 |
4 | from ..builder import LOSSES
5 | from .utils import weighted_loss
6 |
7 |
8 | def gmof(x, sigma):
9 | """Geman-McClure error function."""
10 | x_squared = x**2
11 | sigma_squared = sigma**2
12 | return (sigma_squared * x_squared) / (sigma_squared + x_squared)
13 |
14 |
15 | @weighted_loss
16 | def mse_loss(pred, target):
17 |     """Wrapper of MSE loss."""
18 | return F.mse_loss(pred, target, reduction='none')
19 |
20 |
21 | @weighted_loss
22 | def mse_loss_with_gmof(pred, target, sigma):
23 | """Extended MSE Loss with GMOF."""
24 | loss = F.mse_loss(pred, target, reduction='none')
25 | loss = gmof(loss, sigma)
26 | return loss
27 |
28 |
29 | @LOSSES.register_module()
30 | class MSELoss(nn.Module):
31 | """MSELoss.
32 | Args:
33 | reduction (str, optional): The method that reduces the loss to a
34 | scalar. Options are "none", "mean" and "sum".
35 | loss_weight (float, optional): The weight of the loss. Defaults to 1.0
36 | """
37 |
38 | def __init__(self, reduction='mean', loss_weight=1.0):
39 | super().__init__()
40 | assert reduction in (None, 'none', 'mean', 'sum')
41 | reduction = 'none' if reduction is None else reduction
42 | self.reduction = reduction
43 | self.loss_weight = loss_weight
44 |
45 | def forward(self,
46 | pred,
47 | target,
48 | weight=None,
49 | avg_factor=None,
50 | reduction_override=None):
51 | """Forward function of loss.
52 | Args:
53 | pred (torch.Tensor): The prediction.
54 | target (torch.Tensor): The learning target of the prediction.
55 | weight (torch.Tensor, optional): Weight of the loss for each
56 | prediction. Defaults to None.
57 | avg_factor (int, optional): Average factor that is used to average
58 | the loss. Defaults to None.
59 | reduction_override (str, optional): The reduction method used to
60 | override the original reduction method of the loss.
61 | Defaults to None.
62 | Returns:
63 | torch.Tensor: The calculated loss
64 | """
65 | assert reduction_override in (None, 'none', 'mean', 'sum')
66 | reduction = (
67 | reduction_override if reduction_override else self.reduction)
68 | loss = self.loss_weight * mse_loss(
69 | pred, target, weight, reduction=reduction, avg_factor=avg_factor)
70 | return loss
--------------------------------------------------------------------------------
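A short usage sketch mirroring how the architectures above call `MSELoss` with `reduction_override='none'` and then apply their own frame masking (assumes `mogen` is installed; shapes are illustrative):

import torch
from mogen.models.losses import MSELoss

loss_fn = MSELoss(reduction='mean', loss_weight=1.0)
pred = torch.randn(2, 16, 263)
target = torch.randn(2, 16, 263)
motion_mask = torch.ones(2, 16)

elementwise = loss_fn(pred, target, reduction_override='none')              # B, T, D
recon = (elementwise.mean(dim=-1) * motion_mask).sum() / motion_mask.sum()  # masked mean
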
/mogen/models/losses/utils.py:
--------------------------------------------------------------------------------
1 | import functools
2 |
3 | import torch
4 | import torch.nn.functional as F
5 |
6 |
7 | def reduce_loss(loss, reduction):
8 | """Reduce loss as specified.
9 | Args:
10 | loss (Tensor): Elementwise loss tensor.
11 | reduction (str): Options are "none", "mean" and "sum".
12 | Return:
13 | Tensor: Reduced loss tensor.
14 | """
15 | reduction_enum = F._Reduction.get_enum(reduction)
16 | # none: 0, elementwise_mean:1, sum: 2
17 | if reduction_enum == 0:
18 | return loss
19 | elif reduction_enum == 1:
20 | return loss.mean()
21 | elif reduction_enum == 2:
22 | return loss.sum()
23 |
24 |
25 | def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None):
26 | """Apply element-wise weight and reduce loss.
27 | Args:
28 | loss (Tensor): Element-wise loss.
29 | weight (Tensor): Element-wise weights.
30 | reduction (str): Same as built-in losses of PyTorch.
31 | avg_factor (float): Average factor when computing the mean of losses.
32 | Returns:
33 | Tensor: Processed loss values.
34 | """
35 | # if weight is specified, apply element-wise weight
36 | if weight is not None:
37 | loss = loss * weight
38 |
39 | # if avg_factor is not specified, just reduce the loss
40 | if avg_factor is None:
41 | loss = reduce_loss(loss, reduction)
42 | else:
43 | # if reduction is mean, then average the loss by avg_factor
44 | if reduction == 'mean':
45 | loss = loss.sum() / avg_factor
46 | # if reduction is 'none', then do nothing, otherwise raise an error
47 | elif reduction != 'none':
48 | raise ValueError('avg_factor can not be used with reduction="sum"')
49 | return loss
50 |
51 |
52 | def weighted_loss(loss_func):
53 | """Create a weighted version of a given loss function.
54 | To use this decorator, the loss function must have the signature like
55 | `loss_func(pred, target, **kwargs)`. The function only needs to compute
56 | element-wise loss without any reduction. This decorator will add weight
57 | and reduction arguments to the function. The decorated function will have
58 | the signature like `loss_func(pred, target, weight=None, reduction='mean',
59 | avg_factor=None, **kwargs)`.
60 | :Example:
61 | >>> import torch
62 | >>> @weighted_loss
63 | >>> def l1_loss(pred, target):
64 | >>> return (pred - target).abs()
65 | >>> pred = torch.Tensor([0, 2, 3])
66 | >>> target = torch.Tensor([1, 1, 1])
67 | >>> weight = torch.Tensor([1, 0, 1])
68 | >>> l1_loss(pred, target)
69 | tensor(1.3333)
70 | >>> l1_loss(pred, target, weight)
71 | tensor(1.)
72 | >>> l1_loss(pred, target, reduction='none')
73 | tensor([1., 1., 2.])
74 | >>> l1_loss(pred, target, weight, avg_factor=2)
75 | tensor(1.5000)
76 | """
77 |
78 | @functools.wraps(loss_func)
79 | def wrapper(pred,
80 | target,
81 | weight=None,
82 | reduction='mean',
83 | avg_factor=None,
84 | **kwargs):
85 | # get element-wise loss
86 | loss = loss_func(pred, target, **kwargs)
87 | loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
88 | return loss
89 |
90 | return wrapper
91 |
92 |
93 | def convert_to_one_hot(targets: torch.Tensor, classes) -> torch.Tensor:
94 | """This function converts target class indices to one-hot vectors, given
95 | the number of classes.
96 | Args:
97 | targets (Tensor): The ground truth label of the prediction
98 | with shape (N, 1)
99 |         classes (int): The number of classes.
100 |     Returns:
101 |         Tensor: One-hot encoded targets with shape (N, classes).
102 | """
103 | assert (torch.max(targets).item() <
104 | classes), 'Class Index must be less than number of classes'
105 | one_hot_targets = torch.zeros((targets.shape[0], classes),
106 | dtype=torch.long,
107 | device=targets.device)
108 | one_hot_targets.scatter_(1, targets.long(), 1)
109 | return one_hot_targets
110 |
--------------------------------------------------------------------------------
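A tiny example of `convert_to_one_hot` above, with targets of shape (N, 1):

import torch
from mogen.models.losses.utils import convert_to_one_hot

targets = torch.tensor([[0], [2]])
print(convert_to_one_hot(targets, classes=3))
# tensor([[1, 0, 0],
#         [0, 0, 1]])
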
/mogen/models/rnns/__init__.py:
--------------------------------------------------------------------------------
1 | from .t2m_bigru import T2MMotionEncoder, T2MTextEncoder
--------------------------------------------------------------------------------
/mogen/models/transformers/__init__.py:
--------------------------------------------------------------------------------
1 | from .actor import ACTOREncoder, ACTORDecoder
2 | from .motiondiffuse import MotionDiffuseTransformer
3 | from .remodiffuse import ReMoDiffuseTransformer
4 | from .mdm import MDMTransformer
5 | from .position_encoding import SinusoidalPositionalEncoding, LearnedPositionalEncoding
--------------------------------------------------------------------------------
/mogen/models/transformers/actor.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | 
4 | from mmcv.runner import BaseModule
5 | 
6 | from ..builder import SUBMODULES
7 | from ..utils.mlp import build_MLP
8 | from .position_encoding import (SinusoidalPositionalEncoding,
9 |                                 LearnedPositionalEncoding)
10 |
11 |
12 | @SUBMODULES.register_module()
13 | class ACTOREncoder(BaseModule):
14 | def __init__(self,
15 | max_seq_len=16,
16 | njoints=None,
17 | nfeats=None,
18 | input_feats=None,
19 | latent_dim=256,
20 | output_dim=256,
21 | condition_dim=None,
22 | num_heads=4,
23 | ff_size=1024,
24 | num_layers=8,
25 | activation='gelu',
26 | dropout=0.1,
27 | use_condition=False,
28 | num_class=None,
29 | use_final_proj=False,
30 | output_var=False,
31 | pos_embedding='sinusoidal',
32 | init_cfg=None):
33 | super().__init__(init_cfg=init_cfg)
34 | self.njoints = njoints
35 | self.nfeats = nfeats
36 | if input_feats is None:
37 | assert self.njoints is not None and self.nfeats is not None
38 | self.input_feats = njoints * nfeats
39 | else:
40 | self.input_feats = input_feats
41 | self.max_seq_len = max_seq_len
42 | self.latent_dim = latent_dim
43 | self.condition_dim = condition_dim
44 | self.use_condition = use_condition
45 | self.num_class = num_class
46 | self.use_final_proj = use_final_proj
47 | self.output_var = output_var
48 | self.skelEmbedding = nn.Linear(self.input_feats, self.latent_dim)
49 | if self.use_condition:
50 | if num_class is None:
51 | self.mu_layer = build_MLP(self.condition_dim, self.latent_dim)
52 | if self.output_var:
53 | self.sigma_layer = build_MLP(self.condition_dim, self.latent_dim)
54 | else:
55 | self.mu_layer = nn.Parameter(torch.randn(num_class, self.latent_dim))
56 | if self.output_var:
57 | self.sigma_layer = nn.Parameter(torch.randn(num_class, self.latent_dim))
58 | else:
59 | if self.output_var:
60 | self.query = nn.Parameter(torch.randn(2, self.latent_dim))
61 | else:
62 | self.query = nn.Parameter(torch.randn(1, self.latent_dim))
63 | if pos_embedding == 'sinusoidal':
64 | self.pos_encoder = SinusoidalPositionalEncoding(latent_dim, dropout)
65 | else:
66 | self.pos_encoder = LearnedPositionalEncoding(latent_dim, dropout, max_len=max_seq_len + 2)
67 | seqTransEncoderLayer = nn.TransformerEncoderLayer(
68 | d_model=self.latent_dim,
69 | nhead=num_heads,
70 | dim_feedforward=ff_size,
71 | dropout=dropout,
72 | activation=activation)
73 | self.seqTransEncoder = nn.TransformerEncoder(
74 | seqTransEncoderLayer,
75 | num_layers=num_layers)
76 |
77 | def forward(self, motion, motion_mask=None, condition=None):
78 | B, T = motion.shape[:2]
79 | motion = motion.view(B, T, -1)
80 | feature = self.skelEmbedding(motion)
81 | if self.use_condition:
82 | if self.output_var:
83 | if self.num_class is None:
84 | sigma_query = self.sigma_layer(condition).view(B, 1, -1)
85 | else:
86 | sigma_query = self.sigma_layer[condition.long()].view(B, 1, -1)
87 | feature = torch.cat((sigma_query, feature), dim=1)
88 | if self.num_class is None:
89 | mu_query = self.mu_layer(condition).view(B, 1, -1)
90 | else:
91 | mu_query = self.mu_layer[condition.long()].view(B, 1, -1)
92 | feature = torch.cat((mu_query, feature), dim=1)
93 | else:
94 | query = self.query.view(1, -1, self.latent_dim).repeat(B, 1, 1)
95 | feature = torch.cat((query, feature), dim=1)
96 | if self.output_var:
97 | motion_mask = torch.cat((torch.zeros(B, 2).to(motion.device), 1 - motion_mask), dim=1).bool()
98 | else:
99 | motion_mask = torch.cat((torch.zeros(B, 1).to(motion.device), 1 - motion_mask), dim=1).bool()
100 | feature = feature.permute(1, 0, 2).contiguous()
101 | feature = self.pos_encoder(feature)
102 | feature = self.seqTransEncoder(feature, src_key_padding_mask=motion_mask)
103 | if self.use_final_proj:
104 | mu = self.final_mu(feature[0])
105 | if self.output_var:
106 | sigma = self.final_sigma(feature[1])
107 | return mu, sigma
108 | return mu
109 | else:
110 | if self.output_var:
111 | return feature[0], feature[1]
112 | else:
113 | return feature[0]
114 |
115 |
116 | @SUBMODULES.register_module()
117 | class ACTORDecoder(BaseModule):
118 |
119 | def __init__(self,
120 | max_seq_len=16,
121 | njoints=None,
122 | nfeats=None,
123 | input_feats=None,
124 | input_dim=256,
125 | latent_dim=256,
126 | condition_dim=None,
127 | num_heads=4,
128 | ff_size=1024,
129 | num_layers=8,
130 | activation='gelu',
131 | dropout=0.1,
132 | use_condition=False,
133 | num_class=None,
134 | pos_embedding='sinusoidal',
135 | init_cfg=None):
136 | super().__init__(init_cfg=init_cfg)
137 | if input_dim != latent_dim:
138 | self.linear = nn.Linear(input_dim, latent_dim)
139 | else:
140 | self.linear = nn.Identity()
141 | self.njoints = njoints
142 | self.nfeats = nfeats
143 | if input_feats is None:
144 | assert self.njoints is not None and self.nfeats is not None
145 | self.input_feats = njoints * nfeats
146 | else:
147 | self.input_feats = input_feats
148 | self.max_seq_len = max_seq_len
149 | self.input_dim = input_dim
150 | self.latent_dim = latent_dim
151 | self.condition_dim = condition_dim
152 | self.use_condition = use_condition
153 | self.num_class = num_class
154 | if self.use_condition:
155 | if num_class is None:
156 | self.condition_bias = build_MLP(condition_dim, latent_dim)
157 | else:
158 | self.condition_bias = nn.Parameter(torch.randn(num_class, latent_dim))
159 | if pos_embedding == 'sinusoidal':
160 | self.pos_encoder = SinusoidalPositionalEncoding(latent_dim, dropout)
161 | else:
162 | self.pos_encoder = LearnedPositionalEncoding(latent_dim, dropout, max_len=max_seq_len)
163 | seqTransDecoderLayer = nn.TransformerDecoderLayer(
164 | d_model=self.latent_dim,
165 | nhead=num_heads,
166 | dim_feedforward=ff_size,
167 | dropout=dropout,
168 | activation=activation)
169 | self.seqTransDecoder = nn.TransformerDecoder(
170 | seqTransDecoderLayer,
171 | num_layers=num_layers)
172 |
173 | self.final = nn.Linear(self.latent_dim, self.input_feats)
174 |
175 | def forward(self, input, motion_mask=None, condition=None):
176 | B = input.shape[0]
177 | T = self.max_seq_len
178 | input = self.linear(input)
179 | if self.use_condition:
180 | if self.num_class is None:
181 | condition = self.condition_bias(condition)
182 | else:
183 | condition = self.condition_bias[condition.long()].squeeze(1)
184 | input = input + condition
185 | query = self.pos_encoder.pe[:T, :].view(T, 1, -1).repeat(1, B, 1)
186 | input = input.view(1, B, -1)
187 | feature = self.seqTransDecoder(tgt=query, memory=input, tgt_key_padding_mask=(1 - motion_mask).bool())
188 | pose = self.final(feature).permute(1, 0, 2).contiguous()
189 | return pose
190 |
--------------------------------------------------------------------------------
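A hedged shape sketch for the unconditional ACTOR pair above (`use_condition=False`, `output_var=False`, and `use_final_proj` left at its default of False, since the final projection layers are not defined in this file); the dimensions are illustrative only:

import torch
from mogen.models.transformers.actor import ACTOREncoder, ACTORDecoder

enc = ACTOREncoder(input_feats=12, latent_dim=32, max_seq_len=16, num_heads=4, num_layers=2)
dec = ACTORDecoder(input_feats=12, input_dim=32, latent_dim=32, max_seq_len=16, num_heads=4, num_layers=2)

motion = torch.randn(2, 16, 12)    # B, T, input_feats
mask = torch.ones(2, 16)           # 1 = valid frame
z = enc(motion, motion_mask=mask)  # B, latent_dim
rec = dec(z, motion_mask=mask)     # B, max_seq_len, input_feats
print(z.shape, rec.shape)          # torch.Size([2, 32]) torch.Size([2, 16, 12])
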
/mogen/models/transformers/mdm.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import clip
6 |
7 | from ..builder import SUBMODULES
8 |
9 |
10 | def convert_weights(model: nn.Module):
11 | """Convert applicable model parameters to fp32"""
12 |
13 | def _convert_weights_to_fp32(l):
14 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
15 | l.weight.data = l.weight.data.float()
16 | if l.bias is not None:
17 | l.bias.data = l.bias.data.float()
18 |
19 | if isinstance(l, nn.MultiheadAttention):
20 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
21 | tensor = getattr(l, attr)
22 | if tensor is not None:
23 | tensor.data = tensor.data.float()
24 |
25 | for name in ["text_projection", "proj"]:
26 | if hasattr(l, name):
27 | attr = getattr(l, name)
28 | if attr is not None:
29 | attr.data = attr.data.float()
30 |
31 | model.apply(_convert_weights_to_fp32)
32 |
33 |
34 | @SUBMODULES.register_module()
35 | class MDMTransformer(nn.Module):
36 | def __init__(self,
37 | input_feats=263,
38 | latent_dim=256,
39 | ff_size=1024,
40 | num_layers=8,
41 | num_heads=4,
42 | dropout=0.1,
43 | activation="gelu",
44 | clip_dim=512,
45 | clip_version=None,
46 | guide_scale=1.0,
47 | cond_mask_prob=0.1,
48 | use_official_ckpt=False,
49 | **kwargs):
50 | super().__init__()
51 |
52 | self.latent_dim = latent_dim
53 | self.ff_size = ff_size
54 | self.num_layers = num_layers
55 | self.num_heads = num_heads
56 | self.dropout = dropout
57 | self.activation = activation
58 | self.clip_dim = clip_dim
59 | self.input_feats = input_feats
60 | self.guide_scale = guide_scale
61 | self.use_official_ckpt = use_official_ckpt
62 |
63 | self.cond_mask_prob = cond_mask_prob
64 | self.poseEmbedding = nn.Linear(self.input_feats, self.latent_dim)
65 | self.sequence_pos_encoder = PositionalEncoding(self.latent_dim, self.dropout)
66 |
67 | seqTransEncoderLayer = nn.TransformerEncoderLayer(d_model=self.latent_dim,
68 | nhead=self.num_heads,
69 | dim_feedforward=self.ff_size,
70 | dropout=self.dropout,
71 | activation=self.activation)
72 |
73 | self.seqTransEncoder = nn.TransformerEncoder(seqTransEncoderLayer,
74 | num_layers=self.num_layers)
75 |
76 | self.embed_timestep = TimestepEmbedder(self.latent_dim, self.sequence_pos_encoder)
77 |
78 |
79 | self.embed_text = nn.Linear(self.clip_dim, self.latent_dim)
80 | self.clip_version = clip_version
81 | self.clip_model = self.load_and_freeze_clip(clip_version)
82 |
83 | self.poseFinal = nn.Linear(self.latent_dim, self.input_feats)
84 |
85 |
86 | def load_and_freeze_clip(self, clip_version):
87 | clip_model, clip_preprocess = clip.load(clip_version, device='cpu',
88 | jit=False) # Must set jit=False for training
89 | clip.model.convert_weights(
90 |             clip_model)  # this call is redundant: CLIP weights already default to float16
91 |
92 | clip_model.eval()
93 | for p in clip_model.parameters():
94 | p.requires_grad = False
95 |
96 | return clip_model
97 |
98 |
99 | def mask_cond(self, cond, force_mask=False):
100 | bs, d = cond.shape
101 | if force_mask:
102 | return torch.zeros_like(cond)
103 | elif self.training and self.cond_mask_prob > 0.:
104 | mask = torch.bernoulli(torch.ones(bs, device=cond.device) * self.cond_mask_prob).view(bs, 1) # 1-> use null_cond, 0-> use real cond
105 | return cond * (1. - mask)
106 | else:
107 | return cond
108 |
109 | def encode_text(self, raw_text):
110 | # raw_text - list (batch_size length) of strings with input text prompts
111 | device = next(self.parameters()).device
112 | max_text_len = 20
113 | if max_text_len is not None:
114 | default_context_length = 77
115 | context_length = max_text_len + 2 # start_token + 20 + end_token
116 | assert context_length < default_context_length
117 | texts = clip.tokenize(raw_text, context_length=context_length, truncate=True).to(device)
118 | zero_pad = torch.zeros([texts.shape[0], default_context_length-context_length], dtype=texts.dtype, device=texts.device)
119 | texts = torch.cat([texts, zero_pad], dim=1)
120 | return self.clip_model.encode_text(texts).float()
121 |
122 | def get_precompute_condition(self, text, device=None, **kwargs):
123 | if not self.training and device == torch.device('cpu'):
124 | convert_weights(self.clip_model)
125 | text_feat = self.encode_text(text)
126 | return {'text_feat': text_feat}
127 |
128 | def post_process(self, motion):
129 | assert len(motion.shape) == 3
130 | if self.use_official_ckpt:
131 | motion[:, :, :4] = motion[:, :, :4] * 25
132 | return motion
133 |
134 | def forward(self, motion, timesteps, text_feat=None, **kwargs):
135 | """
136 | motion: B, T, D
137 | timesteps: [batch_size] (int)
138 | """
139 | B, T, D = motion.shape
140 | device = motion.device
141 | if text_feat is None:
142 |             enc_text = self.get_precompute_condition(**kwargs)['text_feat']
143 | else:
144 | enc_text = text_feat
145 | if self.training:
146 | # T, B, D
147 | motion = self.poseEmbedding(motion).permute(1, 0, 2)
148 |
149 | emb = self.embed_timestep(timesteps) # [1, bs, d]
150 | emb += self.embed_text(self.mask_cond(enc_text, force_mask=False))
151 |
152 | xseq = self.sequence_pos_encoder(torch.cat((emb, motion), axis=0))
153 | output = self.seqTransEncoder(xseq)[1:]
154 |
155 | # B, T, D
156 | output = self.poseFinal(output).permute(1, 0, 2)
157 | return output
158 | else:
159 | # T, B, D
160 | motion = self.poseEmbedding(motion).permute(1, 0, 2)
161 |
162 | emb = self.embed_timestep(timesteps) # [1, bs, d]
163 | emb_uncond = emb + self.embed_text(self.mask_cond(enc_text, force_mask=True))
164 | emb_text = emb + self.embed_text(self.mask_cond(enc_text, force_mask=False))
165 |
166 | xseq = self.sequence_pos_encoder(torch.cat((emb_uncond, motion), axis=0))
167 | xseq_text = self.sequence_pos_encoder(torch.cat((emb_text, motion), axis=0))
168 | output = self.seqTransEncoder(xseq)[1:]
169 | output_text = self.seqTransEncoder(xseq_text)[1:]
170 | # B, T, D
171 | output = self.poseFinal(output).permute(1, 0, 2)
172 | output_text = self.poseFinal(output_text).permute(1, 0, 2)
173 | scale = self.guide_scale
174 | output = output + scale * (output_text - output)
175 | return output
176 |
177 |
178 | class PositionalEncoding(nn.Module):
179 | def __init__(self, d_model, dropout=0.1, max_len=5000):
180 | super(PositionalEncoding, self).__init__()
181 | self.dropout = nn.Dropout(p=dropout)
182 |
183 | pe = torch.zeros(max_len, d_model)
184 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
185 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
186 | pe[:, 0::2] = torch.sin(position * div_term)
187 | pe[:, 1::2] = torch.cos(position * div_term)
188 | pe = pe.unsqueeze(0).transpose(0, 1)
189 |
190 | self.register_buffer('pe', pe)
191 |
192 | def forward(self, x):
193 |         # x: T, B, D; add positional encoding along the time dimension
194 | x = x + self.pe[:x.shape[0], :]
195 | return self.dropout(x)
196 |
197 |
198 | class TimestepEmbedder(nn.Module):
199 | def __init__(self, latent_dim, sequence_pos_encoder):
200 | super().__init__()
201 | self.latent_dim = latent_dim
202 | self.sequence_pos_encoder = sequence_pos_encoder
203 |
204 | time_embed_dim = self.latent_dim
205 | self.time_embed = nn.Sequential(
206 | nn.Linear(self.latent_dim, time_embed_dim),
207 | nn.SiLU(),
208 | nn.Linear(time_embed_dim, time_embed_dim),
209 | )
210 |
211 | def forward(self, timesteps):
212 | return self.time_embed(self.sequence_pos_encoder.pe[timesteps]).permute(1, 0, 2)
213 |
--------------------------------------------------------------------------------
/mogen/models/transformers/motiondiffuse.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 |
6 | from ..builder import SUBMODULES
7 | from .diffusion_transformer import DiffusionTransformer
8 |
9 |
10 | @SUBMODULES.register_module()
11 | class MotionDiffuseTransformer(DiffusionTransformer):
12 | def __init__(self, **kwargs):
13 | super().__init__(**kwargs)
14 |
15 | def get_precompute_condition(self,
16 | text=None,
17 | xf_proj=None,
18 | xf_out=None,
19 | device=None,
20 | clip_feat=None,
21 | **kwargs):
22 | if xf_proj is None or xf_out is None:
23 | xf_proj, xf_out = self.encode_text(text, clip_feat, device)
24 | return {'xf_proj': xf_proj, 'xf_out': xf_out}
25 |
26 | def post_process(self, motion):
27 | return motion
28 |
29 | def forward_train(self, h=None, src_mask=None, emb=None, xf_out=None, **kwargs):
30 | B, T = h.shape[0], h.shape[1]
31 | for module in self.temporal_decoder_blocks:
32 | h = module(x=h, xf=xf_out, emb=emb, src_mask=src_mask)
33 | output = self.out(h).view(B, T, -1).contiguous()
34 | return output
35 |
36 | def forward_test(self, h=None, src_mask=None, emb=None, xf_out=None, **kwargs):
37 | B, T = h.shape[0], h.shape[1]
38 | for module in self.temporal_decoder_blocks:
39 | h = module(x=h, xf=xf_out, emb=emb, src_mask=src_mask)
40 | output = self.out(h).view(B, T, -1).contiguous()
41 | return output
42 |
--------------------------------------------------------------------------------
/mogen/models/transformers/position_encoding.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 |
5 |
6 | class SinusoidalPositionalEncoding(nn.Module):
7 | def __init__(self, d_model, dropout=0.1, max_len=5000):
8 | super(SinusoidalPositionalEncoding, self).__init__()
9 | self.dropout = nn.Dropout(p=dropout)
10 |
11 | pe = torch.zeros(max_len, d_model)
12 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
13 | div_term = torch.arange(0, d_model, 2).float()
14 | div_term = div_term * (-np.log(10000.0) / d_model)
15 | div_term = torch.exp(div_term)
16 | pe[:, 0::2] = torch.sin(position * div_term)
17 | pe[:, 1::2] = torch.cos(position * div_term)
18 | pe = pe.unsqueeze(0).transpose(0, 1)
19 | # T, 1, D
20 | self.register_buffer('pe', pe)
21 |
22 | def forward(self, x):
23 | x = x + self.pe[:x.shape[0]]
24 | return self.dropout(x)
25 |
26 |
27 | class LearnedPositionalEncoding(nn.Module):
28 | def __init__(self, d_model, dropout=0.1, max_len=5000):
29 | super(LearnedPositionalEncoding, self).__init__()
30 | self.dropout = nn.Dropout(p=dropout)
31 | self.pe = nn.Parameter(torch.randn(max_len, 1, d_model))
32 |
33 | def forward(self, x):
34 | x = x + self.pe[:x.shape[0]]
35 | return self.dropout(x)
36 |
--------------------------------------------------------------------------------
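Both encodings above assume the sequence-first (T, B, D) layout used by `nn.Transformer`; a quick sketch:

import torch
from mogen.models.transformers.position_encoding import SinusoidalPositionalEncoding

pos_enc = SinusoidalPositionalEncoding(d_model=64, dropout=0.0, max_len=100)
x = torch.randn(20, 2, 64)  # T, B, D
out = pos_enc(x)            # adds pe[:20] of shape (20, 1, 64), broadcast over the batch
print(out.shape)            # torch.Size([20, 2, 64])
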
/mogen/models/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mingyuan-zhang/ReMoDiffuse/d81c83ddbf72a989dc334cd48cb7d46fb6feba63/mogen/models/utils/__init__.py
--------------------------------------------------------------------------------
/mogen/models/utils/mlp.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
4 | def build_MLP(dim_list, latent_dim):
5 | model_list = []
6 | prev = dim_list[0]
7 | for cur in dim_list[1:]:
8 | model_list.append(nn.Linear(prev, cur))
9 | model_list.append(nn.GELU())
10 | prev = cur
11 | model_list.append(nn.Linear(prev, latent_dim))
12 | model = nn.Sequential(*model_list)
13 | return model
14 |
--------------------------------------------------------------------------------
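In `build_MLP` above, `dim_list` gives the input and hidden sizes and `latent_dim` is the output width; a quick illustrative call:

import torch
from mogen.models.utils.mlp import build_MLP

mlp = build_MLP([263, 512, 512], latent_dim=256)
# Linear(263->512), GELU, Linear(512->512), GELU, Linear(512->256)
out = mlp(torch.randn(8, 263))
print(out.shape)  # torch.Size([8, 256])
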
/mogen/models/utils/stylization_block.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | def zero_module(module):
6 | """
7 | Zero out the parameters of a module and return it.
8 | """
9 | for p in module.parameters():
10 | p.detach().zero_()
11 | return module
12 |
13 |
14 | class StylizationBlock(nn.Module):
15 |
16 | def __init__(self, latent_dim, time_embed_dim, dropout):
17 | super().__init__()
18 | self.emb_layers = nn.Sequential(
19 | nn.SiLU(),
20 | nn.Linear(time_embed_dim, 2 * latent_dim),
21 | )
22 | self.norm = nn.LayerNorm(latent_dim)
23 | self.out_layers = nn.Sequential(
24 | nn.SiLU(),
25 | nn.Dropout(p=dropout),
26 | zero_module(nn.Linear(latent_dim, latent_dim)),
27 | )
28 |
29 | def forward(self, h, emb):
30 | """
31 | h: B, T, D
32 | emb: B, D
33 | """
34 | # B, 1, 2D
35 | emb_out = self.emb_layers(emb).unsqueeze(1)
36 | # scale: B, 1, D / shift: B, 1, D
37 | scale, shift = torch.chunk(emb_out, 2, dim=2)
38 | h = self.norm(h) * (1 + scale) + shift
39 | h = self.out_layers(h)
40 | return h
41 |
--------------------------------------------------------------------------------
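A minimal sketch of `StylizationBlock`: the time embedding is mapped to a per-channel scale and shift for the normalized features, and because the last Linear is zero-initialized the block outputs zeros at init, so a residual branch using it starts as an identity:

import torch
from mogen.models.utils.stylization_block import StylizationBlock

block = StylizationBlock(latent_dim=64, time_embed_dim=128, dropout=0.0)
h = torch.randn(2, 16, 64)  # B, T, D
emb = torch.randn(2, 128)   # B, time_embed_dim
out = block(h, emb)
print(out.shape)            # torch.Size([2, 16, 64])
print(out.abs().max())      # tensor(0.) right after initialization
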
/mogen/models/utils/word_vectorizer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pickle
3 | from os.path import join as pjoin
4 |
5 | POS_enumerator = {
6 | 'VERB': 0,
7 | 'NOUN': 1,
8 | 'DET': 2,
9 | 'ADP': 3,
10 | 'NUM': 4,
11 | 'AUX': 5,
12 | 'PRON': 6,
13 | 'ADJ': 7,
14 | 'ADV': 8,
15 | 'Loc_VIP': 9,
16 | 'Body_VIP': 10,
17 | 'Obj_VIP': 11,
18 | 'Act_VIP': 12,
19 | 'Desc_VIP': 13,
20 | 'OTHER': 14,
21 | }
22 |
23 | Loc_list = ('left', 'right', 'clockwise', 'counterclockwise', 'anticlockwise', 'forward', 'back', 'backward',
24 | 'up', 'down', 'straight', 'curve')
25 |
26 | Body_list = ('arm', 'chin', 'foot', 'feet', 'face', 'hand', 'mouth', 'leg', 'waist', 'eye', 'knee', 'shoulder', 'thigh')
27 |
28 | Obj_List = ('stair', 'dumbbell', 'chair', 'window', 'floor', 'car', 'ball', 'handrail', 'baseball', 'basketball')
29 |
30 | Act_list = ('walk', 'run', 'swing', 'pick', 'bring', 'kick', 'put', 'squat', 'throw', 'hop', 'dance', 'jump', 'turn',
31 |             'stumble', 'stop', 'sit', 'lift', 'lower', 'raise', 'wash', 'stand', 'kneel', 'stroll',
32 | 'rub', 'bend', 'balance', 'flap', 'jog', 'shuffle', 'lean', 'rotate', 'spin', 'spread', 'climb')
33 |
34 | Desc_list = ('slowly', 'carefully', 'fast', 'careful', 'slow', 'quickly', 'happy', 'angry', 'sad', 'happily',
35 | 'angrily', 'sadly')
36 |
37 | VIP_dict = {
38 | 'Loc_VIP': Loc_list,
39 | 'Body_VIP': Body_list,
40 | 'Obj_VIP': Obj_List,
41 | 'Act_VIP': Act_list,
42 | 'Desc_VIP': Desc_list,
43 | }
44 |
45 |
46 | class WordVectorizer(object):
47 | def __init__(self, meta_root, prefix):
48 | vectors = np.load(pjoin(meta_root, '%s_data.npy'%prefix))
49 | words = pickle.load(open(pjoin(meta_root, '%s_words.pkl'%prefix), 'rb'))
50 | word2idx = pickle.load(open(pjoin(meta_root, '%s_idx.pkl'%prefix), 'rb'))
51 | self.word2vec = {w: vectors[word2idx[w]] for w in words}
52 |
53 | def _get_pos_ohot(self, pos):
54 | pos_vec = np.zeros(len(POS_enumerator))
55 | if pos in POS_enumerator:
56 | pos_vec[POS_enumerator[pos]] = 1
57 | else:
58 | pos_vec[POS_enumerator['OTHER']] = 1
59 | return pos_vec
60 |
61 | def __len__(self):
62 | return len(self.word2vec)
63 |
64 | def __getitem__(self, item):
65 | word, pos = item.split('/')
66 | if word in self.word2vec:
67 | word_vec = self.word2vec[word]
68 | vip_pos = None
69 | for key, values in VIP_dict.items():
70 | if word in values:
71 | vip_pos = key
72 | break
73 | if vip_pos is not None:
74 | pos_vec = self._get_pos_ohot(vip_pos)
75 | else:
76 | pos_vec = self._get_pos_ohot(pos)
77 | else:
78 | word_vec = self.word2vec['unk']
79 | pos_vec = self._get_pos_ohot('OTHER')
80 | return word_vec, pos_vec
--------------------------------------------------------------------------------
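A hedged usage sketch for `WordVectorizer`; the meta-file location and prefix are assumptions (they follow the HumanML3D GloVe export convention), but the 'word/POS' key format comes from `__getitem__` above:

# assumes the GloVe meta files (<prefix>_data.npy, <prefix>_words.pkl, <prefix>_idx.pkl)
# live under ./glove with the prefix 'our_vab'
w_vectorizer = WordVectorizer('./glove', 'our_vab')
word_emb, pos_onehot = w_vectorizer['walk/VERB']
# 'walk' is in Act_list, so the POS one-hot marks Act_VIP rather than plain VERB
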
/mogen/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from mogen.utils.collect_env import collect_env
2 | from mogen.utils.dist_utils import DistOptimizerHook, allreduce_grads
3 | from mogen.utils.logger import get_root_logger
4 | from mogen.utils.misc import multi_apply, torch_to_numpy
5 | from mogen.utils.path_utils import (
6 | Existence,
7 | check_input_path,
8 | check_path_existence,
9 | check_path_suffix,
10 | prepare_output_path,
11 | )
12 |
13 |
14 | __all__ = [
15 | 'collect_env', 'DistOptimizerHook', 'allreduce_grads', 'get_root_logger',
16 | 'multi_apply', 'torch_to_numpy', 'Existence', 'check_input_path',
17 | 'check_path_existence', 'check_path_suffix', 'prepare_output_path'
18 | ]
--------------------------------------------------------------------------------
/mogen/utils/collect_env.py:
--------------------------------------------------------------------------------
1 | from mmcv.utils import collect_env as collect_base_env
2 | from mmcv.utils import get_git_hash
3 |
4 | import mogen
5 |
6 |
7 | def collect_env():
8 | """Collect the information of the running environments."""
9 | env_info = collect_base_env()
10 | env_info['mogen'] = mogen.__version__ + '+' + get_git_hash()[:7]
11 | return env_info
12 |
13 |
14 | if __name__ == '__main__':
15 | for name, val in collect_env().items():
16 | print(f'{name}: {val}')
--------------------------------------------------------------------------------
/mogen/utils/dist_utils.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | import torch.distributed as dist
4 | from mmcv.runner import OptimizerHook
5 | from torch._utils import (
6 | _flatten_dense_tensors,
7 | _take_tensors,
8 | _unflatten_dense_tensors,
9 | )
10 |
11 |
12 | def _allreduce_coalesced(tensors, world_size, bucket_size_mb=-1):
13 | if bucket_size_mb > 0:
14 | bucket_size_bytes = bucket_size_mb * 1024 * 1024
15 | buckets = _take_tensors(tensors, bucket_size_bytes)
16 | else:
17 | buckets = OrderedDict()
18 | for tensor in tensors:
19 | tp = tensor.type()
20 | if tp not in buckets:
21 | buckets[tp] = []
22 | buckets[tp].append(tensor)
23 | buckets = buckets.values()
24 |
25 | for bucket in buckets:
26 | flat_tensors = _flatten_dense_tensors(bucket)
27 | dist.all_reduce(flat_tensors)
28 | flat_tensors.div_(world_size)
29 | for tensor, synced in zip(
30 | bucket, _unflatten_dense_tensors(flat_tensors, bucket)):
31 | tensor.copy_(synced)
32 |
33 |
34 | def allreduce_grads(params, coalesce=True, bucket_size_mb=-1):
35 | grads = [
36 | param.grad.data for param in params
37 | if param.requires_grad and param.grad is not None
38 | ]
39 | world_size = dist.get_world_size()
40 | if coalesce:
41 | _allreduce_coalesced(grads, world_size, bucket_size_mb)
42 | else:
43 | for tensor in grads:
44 | dist.all_reduce(tensor.div_(world_size))
45 |
46 |
47 | class DistOptimizerHook(OptimizerHook):
48 |
49 | def __init__(self, grad_clip=None, coalesce=True, bucket_size_mb=-1):
50 | self.grad_clip = grad_clip
51 | self.coalesce = coalesce
52 | self.bucket_size_mb = bucket_size_mb
53 |
54 | def after_train_iter(self, runner):
55 | runner.optimizer.zero_grad()
56 | runner.outputs['loss'].backward()
57 | if self.grad_clip is not None:
58 | self.clip_grads(runner.model.parameters())
59 | runner.optimizer.step()
--------------------------------------------------------------------------------
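Usage note: allreduce_grads averages parameter gradients across all ranks after the backward pass, and DistOptimizerHook wires that step into an mmcv runner. A minimal sketch of calling it directly in a hand-written training step (assumes torch.distributed has already been initialized and that a model, optimizer and scalar loss exist):

from mogen.utils import allreduce_grads

def distributed_step(model, optimizer, loss):
    # Backward locally, then average gradients over all ranks before stepping.
    optimizer.zero_grad()
    loss.backward()
    allreduce_grads(model.parameters(), coalesce=True, bucket_size_mb=-1)
    optimizer.step()
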
/mogen/utils/logger.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | from mmcv.utils import get_logger
4 |
5 |
6 | def get_root_logger(log_file=None, log_level=logging.INFO):
7 | return get_logger('mogen', log_file, log_level)
--------------------------------------------------------------------------------
/mogen/utils/misc.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 |
3 | import torch
4 |
5 |
6 | def multi_apply(func, *args, **kwargs):
7 | pfunc = partial(func, **kwargs) if kwargs else func
8 | map_results = map(pfunc, *args)
9 | return tuple(map(list, zip(*map_results)))
10 |
11 |
12 | def torch_to_numpy(x):
13 | assert isinstance(x, torch.Tensor)
14 | return x.detach().cpu().numpy()
--------------------------------------------------------------------------------
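Usage note: multi_apply maps a function over several parallel sequences and transposes the per-call tuples into per-output lists. A small worked example:

from mogen.utils import multi_apply

def div_mod(a, b):
    return a // b, a % b

quotients, remainders = multi_apply(div_mod, [7, 9, 11], [2, 4, 3])
# quotients == [3, 2, 3], remainders == [1, 1, 2]
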
/mogen/utils/path_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import warnings
3 | from enum import Enum
4 | from pathlib import Path
5 | from typing import List, Union
6 |
7 | try:
8 | from typing import Literal
9 | except ImportError:
10 | from typing_extensions import Literal
11 |
12 |
13 | def check_path_suffix(path_str: str,
14 | allowed_suffix: Union[str, List[str]] = '') -> bool:
15 | """Check whether the suffix of the path is allowed.
16 |
17 | Args:
18 | path_str (str):
19 | Path to check.
20 |         allowed_suffix (Union[str, List[str]], optional):
21 |             Which extension names are allowed.
22 |             Offer a list like ['.jpg', '.jpeg'].
23 |             When it's [], all suffixes are accepted.
24 |             Use [''] if a directory is allowed.
25 |             Defaults to ''.
26 |
27 | Returns:
28 | bool:
29 | True: suffix test passed
30 | False: suffix test failed
31 | """
32 | if isinstance(allowed_suffix, str):
33 | allowed_suffix = [allowed_suffix]
34 | pathinfo = Path(path_str)
35 | suffix = pathinfo.suffix.lower()
36 | if len(allowed_suffix) == 0:
37 | return True
38 | if pathinfo.is_dir():
39 | if '' in allowed_suffix:
40 | return True
41 | else:
42 | return False
43 | else:
44 | for index, tmp_suffix in enumerate(allowed_suffix):
45 | if not tmp_suffix.startswith('.'):
46 | tmp_suffix = '.' + tmp_suffix
47 | allowed_suffix[index] = tmp_suffix.lower()
48 | if suffix in allowed_suffix:
49 | return True
50 | else:
51 | return False
52 |
53 |
54 | class Existence(Enum):
55 | """State of file existence."""
56 | FileExist = 0
57 | DirectoryExistEmpty = 1
58 | DirectoryExistNotEmpty = 2
59 | MissingParent = 3
60 | DirectoryNotExist = 4
61 | FileNotExist = 5
62 |
63 |
64 | def check_path_existence(
65 | path_str: str,
66 | path_type: Literal['file', 'dir', 'auto'] = 'auto',
67 | ) -> Existence:
68 | """Check whether a file or a directory exists at the expected path.
69 |
70 | Args:
71 | path_str (str):
72 | Path to check.
73 |         path_type (Literal['file', 'dir', 'auto'], optional):
74 |             What kind of file do we expect at the path.
75 |             Choose among `file`, `dir`, `auto`.
76 |             Defaults to 'auto'.
77 |
78 | Raises:
79 | KeyError: if `path_type` conflicts with `path_str`
80 |
81 | Returns:
82 | Existence:
83 | 0. FileExist: file at path_str exists.
84 |             1. DirectoryExistEmpty: folder at path_str exists and is empty.
85 | 2. DirectoryExistNotEmpty: folder at path_str exists and not empty.
86 | 3. MissingParent: its parent doesn't exist.
87 | 4. DirectoryNotExist: expect a folder at path_str, but not found.
88 | 5. FileNotExist: expect a file at path_str, but not found.
89 | """
90 | path_type = path_type.lower()
91 | assert path_type in {'file', 'dir', 'auto'}
92 | pathinfo = Path(path_str)
93 | if not pathinfo.parent.is_dir():
94 | return Existence.MissingParent
95 | suffix = pathinfo.suffix.lower()
96 | if path_type == 'dir' or\
97 | path_type == 'auto' and suffix == '':
98 | if pathinfo.is_dir():
99 | if len(os.listdir(path_str)) == 0:
100 | return Existence.DirectoryExistEmpty
101 | else:
102 | return Existence.DirectoryExistNotEmpty
103 | else:
104 | return Existence.DirectoryNotExist
105 | elif path_type == 'file' or\
106 | path_type == 'auto' and suffix != '':
107 | if pathinfo.is_file():
108 | return Existence.FileExist
109 | elif pathinfo.is_dir():
110 | if len(os.listdir(path_str)) == 0:
111 | return Existence.DirectoryExistEmpty
112 | else:
113 | return Existence.DirectoryExistNotEmpty
114 | if path_str.endswith('/'):
115 | return Existence.DirectoryNotExist
116 | else:
117 | return Existence.FileNotExist
118 |
119 |
120 | def prepare_output_path(output_path: str,
121 | allowed_suffix: List[str] = [],
122 | tag: str = 'output file',
123 | path_type: Literal['file', 'dir', 'auto'] = 'auto',
124 | overwrite: bool = True) -> None:
125 | """Check output folder or file.
126 |
127 | Args:
128 | output_path (str): could be folder or file.
129 | allowed_suffix (List[str], optional):
130 | Check the suffix of `output_path`. If folder, should be [] or [''].
131 |             If it could be either a folder or a file, should be [suffixes..., ''].
132 | Defaults to [].
133 | tag (str, optional): The `string` tag to specify the output type.
134 | Defaults to 'output file'.
135 |         path_type (Literal['file', 'dir', 'auto'], optional):
136 | Choose `file` for file and `dir` for folder.
137 | Choose `auto` if allowed to be both.
138 | Defaults to 'auto'.
139 | overwrite (bool, optional):
140 | Whether overwrite the existing file or folder.
141 | Defaults to True.
142 |
143 | Raises:
144 | FileNotFoundError: suffix does not match.
145 | FileExistsError: file or folder already exists and `overwrite` is
146 | False.
147 |
148 | Returns:
149 | None
150 | """
151 | if path_type.lower() == 'dir':
152 | allowed_suffix = []
153 | exist_result = check_path_existence(output_path, path_type=path_type)
154 | if exist_result == Existence.MissingParent:
155 | warnings.warn(
156 | f'The parent folder of {tag} does not exist: {output_path},' +
157 | f' will make dir {Path(output_path).parent.absolute().__str__()}')
158 | os.makedirs(
159 | Path(output_path).parent.absolute().__str__(), exist_ok=True)
160 |
161 | elif exist_result == Existence.DirectoryNotExist:
162 | os.mkdir(output_path)
163 | print(f'Making directory {output_path} for saving results.')
164 | elif exist_result == Existence.FileNotExist:
165 | suffix_matched = \
166 | check_path_suffix(output_path, allowed_suffix=allowed_suffix)
167 | if not suffix_matched:
168 | raise FileNotFoundError(
169 | f'The {tag} should be {", ".join(allowed_suffix)}: '
170 | f'{output_path}.')
171 | elif exist_result == Existence.FileExist:
172 | if not overwrite:
173 | raise FileExistsError(
174 | f'{output_path} exists (set overwrite = True to overwrite).')
175 | else:
176 | print(f'Overwriting {output_path}.')
177 | elif exist_result == Existence.DirectoryExistEmpty:
178 | pass
179 | elif exist_result == Existence.DirectoryExistNotEmpty:
180 | if not overwrite:
181 | raise FileExistsError(
182 | f'{output_path} is not empty (set overwrite = '
183 | 'True to overwrite the files).')
184 | else:
185 | print(f'Overwriting {output_path} and its files.')
186 | else:
187 | raise FileNotFoundError(f'No Existence type for {output_path}.')
188 |
189 |
190 | def check_input_path(
191 | input_path: str,
192 | allowed_suffix: List[str] = [],
193 | tag: str = 'input file',
194 | path_type: Literal['file', 'dir', 'auto'] = 'auto',
195 | ):
196 | """Check input folder or file.
197 |
198 | Args:
199 | input_path (str): input folder or file path.
200 | allowed_suffix (List[str], optional):
201 | Check the suffix of `input_path`. If folder, should be [] or [''].
202 |             If it could be either a folder or a file, should be [suffixes..., ''].
203 | Defaults to [].
204 |         tag (str, optional): The `string` tag to specify the input type.
205 |             Defaults to 'input file'.
206 |         path_type (Literal['file', 'dir', 'auto'], optional):
207 |             Choose `file` for file and `dir` for folder.
208 | Choose `auto` if allowed to be both.
209 | Defaults to 'auto'.
210 |
211 | Raises:
212 |         FileNotFoundError: file does not exist or suffix does not match.
213 |
214 | Returns:
215 | None
216 | """
217 | if path_type.lower() == 'dir':
218 | allowed_suffix = []
219 | exist_result = check_path_existence(input_path, path_type=path_type)
220 |
221 | if exist_result in [
222 | Existence.FileExist, Existence.DirectoryExistEmpty,
223 | Existence.DirectoryExistNotEmpty
224 | ]:
225 | suffix_matched = \
226 | check_path_suffix(input_path, allowed_suffix=allowed_suffix)
227 | if not suffix_matched:
228 | raise FileNotFoundError(
229 |             f'The {tag} should be {", ".join(allowed_suffix)}: ' +
230 | f'{input_path}.')
231 | else:
232 | raise FileNotFoundError(f'The {tag} does not exist: {input_path}.')
--------------------------------------------------------------------------------
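Usage note: the helpers above are typically chained: validate an input path before reading it and prepare an output path before writing to it. A minimal sketch with placeholder paths (not files shipped with this repo):

from mogen.utils import check_input_path, prepare_output_path

try:
    # Raises FileNotFoundError if the file is missing or the suffix is wrong.
    check_input_path('data/motion.npy', allowed_suffix=['.npy'],
                     tag='motion file', path_type='file')
except FileNotFoundError as err:
    print(err)

# Creates the missing parent folder (work_dirs/vis) if needed and checks the
# '.mp4' suffix; raises FileExistsError on an existing file when overwrite=False.
prepare_output_path('work_dirs/vis/result.mp4', allowed_suffix=['.mp4'],
                    tag='animation file', path_type='file', overwrite=True)
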
/mogen/utils/plot_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | This code is borrowed from https://github.com/EricGuo5513/text-to-motion
3 | """
4 |
5 | import torch
6 | import numpy as np
7 |
8 | import math
9 | import matplotlib
10 | import matplotlib.pyplot as plt
11 | from mpl_toolkits.mplot3d import Axes3D
12 | from matplotlib.animation import FuncAnimation, FFMpegFileWriter
13 | from mpl_toolkits.mplot3d.art3d import Poly3DCollection
14 | import mpl_toolkits.mplot3d.axes3d as p3
15 |
16 | # Define a kinematic tree for the skeletal structure
17 | kit_kinematic_chain = [[0, 11, 12, 13, 14, 15], [0, 16, 17, 18, 19, 20], [0, 1, 2, 3, 4], [3, 5, 6, 7], [3, 8, 9, 10]]
18 | t2m_kinematic_chain = [[0, 2, 5, 8, 11], [0, 1, 4, 7, 10], [0, 3, 6, 9, 12, 15], [9, 14, 17, 19, 21], [9, 13, 16, 18, 20]]
19 | t2m_left_hand_chain = [[20, 22, 23, 24], [20, 34, 35, 36], [20, 25, 26, 27], [20, 31, 32, 33], [20, 28, 29, 30]]
20 | t2m_right_hand_chain = [[21, 43, 44, 45], [21, 46, 47, 48], [21, 40, 41, 42], [21, 37, 38, 39], [21, 49, 50, 51]]
21 |
22 |
23 | def qinv(q):
24 | assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)'
25 | mask = torch.ones_like(q)
26 | mask[..., 1:] = -mask[..., 1:]
27 | return q * mask
28 |
29 |
30 | def qrot(q, v):
31 | """
32 | Rotate vector(s) v about the rotation described by quaternion(s) q.
33 | Expects a tensor of shape (*, 4) for q and a tensor of shape (*, 3) for v,
34 | where * denotes any number of dimensions.
35 | Returns a tensor of shape (*, 3).
36 | """
37 | assert q.shape[-1] == 4
38 | assert v.shape[-1] == 3
39 | assert q.shape[:-1] == v.shape[:-1]
40 |
41 | original_shape = list(v.shape)
42 | # print(q.shape)
43 | q = q.contiguous().view(-1, 4)
44 | v = v.contiguous().view(-1, 3)
45 |
46 | qvec = q[:, 1:]
47 | uv = torch.cross(qvec, v, dim=1)
48 | uuv = torch.cross(qvec, uv, dim=1)
49 | return (v + 2 * (q[:, :1] * uv + uuv)).view(original_shape)
50 |
51 |
52 | def recover_root_rot_pos(data):
53 | rot_vel = data[..., 0]
54 | r_rot_ang = torch.zeros_like(rot_vel).to(data.device)
55 | '''Get Y-axis rotation from rotation velocity'''
56 | r_rot_ang[..., 1:] = rot_vel[..., :-1]
57 | r_rot_ang = torch.cumsum(r_rot_ang, dim=-1)
58 |
59 | r_rot_quat = torch.zeros(data.shape[:-1] + (4,)).to(data.device)
60 | r_rot_quat[..., 0] = torch.cos(r_rot_ang)
61 | r_rot_quat[..., 2] = torch.sin(r_rot_ang)
62 |
63 | r_pos = torch.zeros(data.shape[:-1] + (3,)).to(data.device)
64 | r_pos[..., 1:, [0, 2]] = data[..., :-1, 1:3]
65 | '''Add Y-axis rotation to root position'''
66 | r_pos = qrot(qinv(r_rot_quat), r_pos)
67 |
68 | r_pos = torch.cumsum(r_pos, dim=-2)
69 |
70 | r_pos[..., 1] = data[..., 3]
71 | return r_rot_quat, r_pos
72 |
73 |
74 | def recover_from_ric(data, joints_num):
75 | r_rot_quat, r_pos = recover_root_rot_pos(data)
76 | positions = data[..., 4:(joints_num - 1) * 3 + 4]
77 | positions = positions.view(positions.shape[:-1] + (-1, 3))
78 |
79 | '''Add Y-axis rotation to local joints'''
80 | positions = qrot(qinv(r_rot_quat[..., None, :]).expand(positions.shape[:-1] + (4,)), positions)
81 |
82 | '''Add root XZ to joints'''
83 | positions[..., 0] += r_pos[..., 0:1]
84 | positions[..., 2] += r_pos[..., 2:3]
85 |
86 | '''Concate root and joints'''
87 | positions = torch.cat([r_pos.unsqueeze(-2), positions], dim=-2)
88 |
89 | return positions
90 |
91 |
92 | def plot_3d_motion(save_path, kinematic_tree, joints, title, figsize=(10, 10), fps=120, radius=4):
93 | matplotlib.use('Agg')
94 |
95 | title_sp = title.split(' ')
96 | if len(title_sp) > 20:
97 | title = '\n'.join([' '.join(title_sp[:10]), ' '.join(title_sp[10:20]), ' '.join(title_sp[20:])])
98 | elif len(title_sp) > 10:
99 | title = '\n'.join([' '.join(title_sp[:10]), ' '.join(title_sp[10:])])
100 |
101 | def init():
102 | ax.set_xlim3d([-radius / 4, radius / 4])
103 | ax.set_ylim3d([0, radius / 2])
104 | ax.set_zlim3d([0, radius / 2])
105 | fig.suptitle(title, fontsize=20)
106 | ax.grid(b=False)
107 |
108 | def plot_xzPlane(minx, maxx, miny, minz, maxz):
109 | verts = [
110 | [minx, miny, minz],
111 | [minx, miny, maxz],
112 | [maxx, miny, maxz],
113 | [maxx, miny, minz]
114 | ]
115 | xz_plane = Poly3DCollection([verts])
116 | xz_plane.set_facecolor((0.5, 0.5, 0.5, 0.5))
117 | ax.add_collection3d(xz_plane)
118 |
119 | # (seq_len, joints_num, 3)
120 | data = joints.copy().reshape(len(joints), -1, 3)
121 | fig = plt.figure(figsize=figsize)
122 | ax = p3.Axes3D(fig)
123 | init()
124 | MINS = data.min(axis=0).min(axis=0)
125 | MAXS = data.max(axis=0).max(axis=0)
126 | colors = ['red', 'blue', 'black', 'red', 'blue',
127 | 'yellow', 'yellow', 'darkblue', 'darkblue', 'darkblue',
128 | 'darkred', 'darkred', 'darkred', 'darkred', 'darkred']
129 |
130 | frame_number = data.shape[0]
131 |
132 | height_offset = MINS[1]
133 | data[:, :, 1] -= height_offset
134 | trajec = data[:, 0, [0, 2]]
135 |
136 | data[..., 0] -= data[:, 0:1, 0]
137 | data[..., 2] -= data[:, 0:1, 2]
138 |
139 | def update(index):
140 | ax.lines = []
141 | ax.collections = []
142 | ax.view_init(elev=120, azim=-90)
143 | ax.dist = 7.5
144 | plot_xzPlane(MINS[0] - trajec[index, 0], MAXS[0] - trajec[index, 0], 0, MINS[2] - trajec[index, 1],
145 | MAXS[2] - trajec[index, 1])
146 |
147 | if index > 1:
148 | ax.plot3D(trajec[:index, 0] - trajec[index, 0], np.zeros_like(trajec[:index, 0]),
149 | trajec[:index, 1] - trajec[index, 1], linewidth=1.0,
150 | color='blue')
151 |
152 | for i, (chain, color) in enumerate(zip(kinematic_tree, colors)):
153 | if i < 5:
154 | linewidth = 4.0
155 | else:
156 | linewidth = 2.0
157 | ax.plot3D(data[index, chain, 0], data[index, chain, 1], data[index, chain, 2], linewidth=linewidth,
158 | color=color)
159 |
160 | plt.axis('off')
161 | ax.set_xticklabels([])
162 | ax.set_yticklabels([])
163 | ax.set_zticklabels([])
164 |
165 | ani = FuncAnimation(fig, update, frames=frame_number, interval=1000 / fps, repeat=False)
166 | ani.save(save_path, fps=fps)
167 | plt.close()
168 |
169 |
170 | def plot_3d_motion_kit(save_path, kinematic_tree, joints, title, figsize=(10, 10), fps=120, radius=4):
171 | matplotlib.use('Agg')
172 |
173 | title_sp = title.split(' ')
174 | if len(title_sp) > 20:
175 | title = '\n'.join([' '.join(title_sp[:10]), ' '.join(title_sp[10:20]), ' '.join(title_sp[20:])])
176 | elif len(title_sp) > 10:
177 | title = '\n'.join([' '.join(title_sp[:10]), ' '.join(title_sp[10:])])
178 |
179 | def init():
180 | ax.set_xlim3d([-radius / 0.125, radius / 0.125])
181 | ax.set_ylim3d([0, radius / 0.0625])
182 | ax.set_zlim3d([0, radius / 0.0625])
183 | fig.suptitle(title, fontsize=20)
184 | ax.grid(b=False)
185 |
186 | def plot_xzPlane(minx, maxx, miny, minz, maxz):
187 | verts = [
188 | [minx, miny, minz],
189 | [minx, miny, maxz],
190 | [maxx, miny, maxz],
191 | [maxx, miny, minz]
192 | ]
193 | xz_plane = Poly3DCollection([verts])
194 | xz_plane.set_facecolor((0.5, 0.5, 0.5, 0.5))
195 | ax.add_collection3d(xz_plane)
196 |
197 | # (seq_len, joints_num, 3)
198 | data = joints.copy().reshape(len(joints), -1, 3)
199 | fig = plt.figure(figsize=figsize)
200 | ax = p3.Axes3D(fig)
201 | init()
202 | MINS = data.min(axis=0).min(axis=0)
203 | MAXS = data.max(axis=0).max(axis=0)
204 | colors = ['red', 'blue', 'black', 'red', 'blue',]
205 |
206 | frame_number = data.shape[0]
207 |
208 | height_offset = MINS[1]
209 | data[:, :, 1] -= height_offset
210 | trajec = data[:, 0, [0, 2]]
211 |
212 | data[..., 0] -= data[:, 0:1, 0]
213 | data[..., 2] -= data[:, 0:1, 2]
214 |
215 | def update(index):
216 | ax.lines = []
217 | ax.collections = []
218 | ax.view_init(elev=100, azim=-90)
219 | ax.dist = 225
220 | plot_xzPlane(MINS[0] - trajec[index, 0], MAXS[0] - trajec[index, 0], 0, MINS[2] - trajec[index, 1],
221 | MAXS[2] - trajec[index, 1])
222 |
223 | if index > 1:
224 | ax.plot3D(trajec[:index, 0] - trajec[index, 0], np.zeros_like(trajec[:index, 0]),
225 | trajec[:index, 1] - trajec[index, 1], linewidth=1.0,
226 | color='blue')
227 |
228 | for i, (chain, color) in enumerate(zip(kinematic_tree, colors)):
229 | if i < 5:
230 | linewidth = 4.0
231 | else:
232 | linewidth = 2.0
233 | ax.plot3D(data[index, chain, 0], data[index, chain, 1], data[index, chain, 2], linewidth=linewidth,
234 | color=color)
235 |
236 | plt.axis('off')
237 | ax.set_xticklabels([])
238 | ax.set_yticklabels([])
239 | ax.set_zticklabels([])
240 |
241 | ani = FuncAnimation(fig, update, frames=frame_number, interval=500 / fps, repeat=False)
242 | ani.save(save_path, fps=fps)
243 | plt.close()
244 |
--------------------------------------------------------------------------------
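Usage note: recover_from_ric converts the rotation-invariant motion features back to 3D joint positions, which plot_3d_motion then animates along the kinematic chains defined at the top of the file (tools/visualize.py uses exactly this pair for HumanML3D, with 22 joints and 263-dim features). A minimal sketch with a placeholder input file:

import numpy as np
import torch
from mogen.utils.plot_utils import (recover_from_ric, plot_3d_motion,
                                    t2m_kinematic_chain)

features = np.load('pred_motion.npy')    # hypothetical (seq_len, 263) de-normalized array
joints = recover_from_ric(torch.from_numpy(features).float(), 22).numpy()
plot_3d_motion('motion.mp4', t2m_kinematic_chain, joints,
               title='a person walks forward', fps=20)
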
/mogen/version.py:
--------------------------------------------------------------------------------
1 | __version__ = '0.0.1'
2 |
3 |
4 | def parse_version_info(version_str):
5 | """Parse a version string into a tuple.
6 | Args:
7 | version_str (str): The version string.
8 | Returns:
9 | tuple[int | str]: The version info, e.g., "1.3.0" is parsed into
10 | (1, 3, 0), and "2.0.0rc1" is parsed into (2, 0, 0, 'rc1').
11 | """
12 | version_info = []
13 | for x in version_str.split('.'):
14 | if x.isdigit():
15 | version_info.append(int(x))
16 | elif x.find('rc') != -1:
17 | patch_version = x.split('rc')
18 | version_info.append(int(patch_version[0]))
19 | version_info.append(f'rc{patch_version[1]}')
20 | return tuple(version_info)
21 |
22 |
23 | version_info = parse_version_info(__version__)
24 |
25 | __all__ = ['__version__', 'version_info', 'parse_version_info']
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | ftfy
2 | regex
3 | tqdm
4 | scipy
5 | matplotlib==3.3.1
6 | git+https://github.com/openai/CLIP.git
--------------------------------------------------------------------------------
/tools/dist_train.sh:
--------------------------------------------------------------------------------
1 | CONFIG=$1
2 | WORK_DIR=$2
3 | GPUS=$3
4 | PORT=${PORT:-29500}
5 |
6 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
7 | python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \
8 | $(dirname "$0")/train.py $CONFIG --work-dir=${WORK_DIR} --launcher pytorch ${@:4}
--------------------------------------------------------------------------------
/tools/slurm_test.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # Copyright (c) OpenMMLab. All rights reserved.
3 |
4 | set -x
5 |
6 | PARTITION=$1
7 | JOB_NAME=$2
8 | CONFIG=$3
9 | WORK_DIR=$4
10 | CHECKPOINT=$5
11 | GPUS=1
12 | GPUS_PER_NODE=$((${GPUS}<8?${GPUS}:8))
13 | CPUS_PER_TASK=${CPUS_PER_TASK:-2}
14 | SRUN_ARGS=${SRUN_ARGS:-""}
15 | PY_ARGS=${@:6}
16 |
17 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
18 | srun -p ${PARTITION} \
19 | --job-name=${JOB_NAME} \
20 | --gres=gpu:${GPUS_PER_NODE} \
21 | --ntasks=${GPUS} \
22 | --ntasks-per-node=${GPUS_PER_NODE} \
23 | --cpus-per-task=${CPUS_PER_TASK} \
24 | --kill-on-bad-exit=1 \
25 | ${SRUN_ARGS} \
26 | python -u tools/test.py ${CONFIG} --work-dir=${WORK_DIR} ${CHECKPOINT} --launcher="slurm" ${PY_ARGS}
--------------------------------------------------------------------------------
/tools/slurm_train.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # Copyright (c) OpenMMLab. All rights reserved.
3 |
4 | set -x
5 |
6 | PARTITION=$1
7 | JOB_NAME=$2
8 | CONFIG=$3
9 | WORK_DIR=$4
10 | GPUS=$5
11 | GPUS_PER_NODE=$((${GPUS}<8?${GPUS}:8))
12 | CPUS_PER_TASK=${CPUS_PER_TASK:-2}
13 | SRUN_ARGS=${SRUN_ARGS:-""}
14 | PY_ARGS=${@:6}
15 |
16 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
17 | srun -p ${PARTITION} \
18 | --job-name=${JOB_NAME} \
19 | --gres=gpu:${GPUS_PER_NODE} \
20 | --ntasks=${GPUS} \
21 | --ntasks-per-node=${GPUS_PER_NODE} \
22 | --cpus-per-task=${CPUS_PER_TASK} \
23 | --kill-on-bad-exit=1 \
24 | ${SRUN_ARGS} \
25 | python -u tools/train.py ${CONFIG} --work-dir=${WORK_DIR} --launcher="slurm" ${PY_ARGS}
--------------------------------------------------------------------------------
/tools/test.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import os.path as osp
4 |
5 | import mmcv
6 | import torch
7 | from mmcv import DictAction
8 | from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
9 | from mmcv.runner import (
10 | get_dist_info,
11 | init_dist,
12 | load_checkpoint,
13 | wrap_fp16_model,
14 | )
15 |
16 | from mogen.apis import multi_gpu_test, single_gpu_test
17 | from mogen.datasets import build_dataloader, build_dataset
18 | from mogen.models import build_architecture
19 |
20 |
21 | def parse_args():
22 | parser = argparse.ArgumentParser(description='mogen evaluation')
23 | parser.add_argument('config', help='test config file path')
24 | parser.add_argument(
25 | '--work-dir', help='the dir to save evaluation results')
26 | parser.add_argument('checkpoint', help='checkpoint file')
27 | parser.add_argument('--out', help='output result file')
28 | parser.add_argument(
29 | '--gpu_collect',
30 | action='store_true',
31 | help='whether to use gpu to collect results')
32 | parser.add_argument('--tmpdir', help='tmp dir for writing some results')
33 | parser.add_argument(
34 | '--cfg-options',
35 | nargs='+',
36 | action=DictAction,
37 |         help='override some settings in the used config; the key-value pairs '
38 |         'in xxx=yyy format will be merged into the config file.')
39 | parser.add_argument(
40 | '--launcher',
41 | choices=['none', 'pytorch', 'slurm', 'mpi'],
42 | default='none',
43 | help='job launcher')
44 | parser.add_argument('--local_rank', type=int, default=0)
45 | parser.add_argument(
46 | '--device',
47 | choices=['cpu', 'cuda'],
48 | default='cuda',
49 | help='device used for testing')
50 | args = parser.parse_args()
51 | if 'LOCAL_RANK' not in os.environ:
52 | os.environ['LOCAL_RANK'] = str(args.local_rank)
53 | return args
54 |
55 |
56 | def main():
57 | args = parse_args()
58 |
59 | cfg = mmcv.Config.fromfile(args.config)
60 | if args.cfg_options is not None:
61 | cfg.merge_from_dict(args.cfg_options)
62 | # set cudnn_benchmark
63 | if cfg.get('cudnn_benchmark', False):
64 | torch.backends.cudnn.benchmark = True
65 | cfg.data.test.test_mode = True
66 |
67 | # init distributed env first, since logger depends on the dist info.
68 | if args.launcher == 'none':
69 | distributed = False
70 | else:
71 | distributed = True
72 | init_dist(args.launcher, **cfg.dist_params)
73 |
74 | # build the dataloader
75 | dataset = build_dataset(cfg.data.test)
76 | # the extra round_up data will be removed during gpu/cpu collect
77 | data_loader = build_dataloader(
78 | dataset,
79 | samples_per_gpu=cfg.data.samples_per_gpu,
80 | workers_per_gpu=cfg.data.workers_per_gpu,
81 | dist=distributed,
82 | shuffle=False,
83 | round_up=False)
84 |
85 | # build the model and load checkpoint
86 | model = build_architecture(cfg.model)
87 | fp16_cfg = cfg.get('fp16', None)
88 | if fp16_cfg is not None:
89 | wrap_fp16_model(model)
90 | load_checkpoint(model, args.checkpoint, map_location='cpu')
91 |
92 | if not distributed:
93 | if args.device == 'cpu':
94 | model = model.cpu()
95 | else:
96 | model = MMDataParallel(model, device_ids=[0])
97 | outputs = single_gpu_test(model, data_loader)
98 | else:
99 | model = MMDistributedDataParallel(
100 | model.cuda(),
101 | device_ids=[torch.cuda.current_device()],
102 | broadcast_buffers=False)
103 | outputs = multi_gpu_test(model, data_loader, args.tmpdir,
104 | args.gpu_collect)
105 |
106 | rank, _ = get_dist_info()
107 | if rank == 0:
108 | mmcv.mkdir_or_exist(osp.abspath(args.work_dir))
109 | results = dataset.evaluate(outputs, args.work_dir)
110 | for k, v in results.items():
111 | print(f'\n{k} : {v:.4f}')
112 |
113 | if args.out and rank == 0:
114 | print(f'\nwriting results to {args.out}')
115 | mmcv.dump(results, args.out)
116 |
117 |
118 | if __name__ == '__main__':
119 | main()
--------------------------------------------------------------------------------
/tools/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import copy
3 | import os
4 | import os.path as osp
5 | import time
6 |
7 | import mmcv
8 | import torch
9 | from mmcv import Config, DictAction
10 | from mmcv.runner import get_dist_info, init_dist
11 |
12 | from mogen import __version__
13 | from mogen.apis import set_random_seed, train_model
14 | from mogen.datasets import build_dataset
15 | from mogen.models import build_architecture
16 | from mogen.utils import collect_env, get_root_logger
17 |
18 |
19 | def parse_args():
20 | parser = argparse.ArgumentParser(description='Train a model')
21 | parser.add_argument('config', help='train config file path')
22 | parser.add_argument('--work-dir', help='the dir to save logs and models')
23 | parser.add_argument(
24 | '--resume-from', help='the checkpoint file to resume from')
25 | parser.add_argument(
26 | '--no-validate',
27 | action='store_true',
28 | help='whether not to evaluate the checkpoint during training')
29 | group_gpus = parser.add_mutually_exclusive_group()
30 | group_gpus.add_argument('--device', help='device used for training')
31 | group_gpus.add_argument(
32 | '--gpus',
33 | type=int,
34 | help='number of gpus to use '
35 | '(only applicable to non-distributed training)')
36 | group_gpus.add_argument(
37 | '--gpu-ids',
38 | type=int,
39 | nargs='+',
40 | help='ids of gpus to use '
41 | '(only applicable to non-distributed training)')
42 | parser.add_argument('--seed', type=int, default=None, help='random seed')
43 | parser.add_argument(
44 | '--deterministic',
45 | action='store_true',
46 | help='whether to set deterministic options for CUDNN backend.')
47 | parser.add_argument(
48 | '--options', nargs='+', action=DictAction, help='arguments in dict')
49 | parser.add_argument(
50 | '--launcher',
51 | choices=['none', 'pytorch', 'slurm', 'mpi'],
52 | default='none',
53 | help='job launcher')
54 | parser.add_argument('--local_rank', type=int, default=0)
55 | args = parser.parse_args()
56 | if 'LOCAL_RANK' not in os.environ:
57 | os.environ['LOCAL_RANK'] = str(args.local_rank)
58 |
59 | return args
60 |
61 |
62 | def main():
63 | args = parse_args()
64 |
65 | cfg = Config.fromfile(args.config)
66 | if args.options is not None:
67 | cfg.merge_from_dict(args.options)
68 | # set cudnn_benchmark
69 | if cfg.get('cudnn_benchmark', False):
70 | torch.backends.cudnn.benchmark = True
71 |
72 | # work_dir is determined in this priority: CLI > segment in file > filename
73 | if args.work_dir is not None:
74 | # update configs according to CLI args if args.work_dir is not None
75 | cfg.work_dir = args.work_dir
76 | elif cfg.get('work_dir', None) is None:
77 | # use config filename as default work_dir if cfg.work_dir is None
78 | cfg.work_dir = osp.join('./work_dirs',
79 | osp.splitext(osp.basename(args.config))[0])
80 | if args.resume_from is not None:
81 | cfg.resume_from = args.resume_from
82 | if args.gpu_ids is not None:
83 | cfg.gpu_ids = args.gpu_ids
84 | else:
85 | cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus)
86 |
87 | # init distributed env first, since logger depends on the dist info.
88 | if args.launcher == 'none':
89 | distributed = False
90 | else:
91 | distributed = True
92 | init_dist(args.launcher, **cfg.dist_params)
93 | _, world_size = get_dist_info()
94 | cfg.gpu_ids = range(world_size)
95 |
96 | # create work_dir
97 | mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
98 | # dump config
99 | cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config)))
100 | # init the logger before other steps
101 | timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
102 | log_file = osp.join(cfg.work_dir, f'{timestamp}.log')
103 | logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)
104 |
105 | # init the meta dict to record some important information such as
106 | # environment info and seed, which will be logged
107 | meta = dict()
108 | # log env info
109 | env_info_dict = collect_env()
110 | env_info = '\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()])
111 | dash_line = '-' * 60 + '\n'
112 | logger.info('Environment info:\n' + dash_line + env_info + '\n' +
113 | dash_line)
114 | meta['env_info'] = env_info
115 |
116 | # log some basic info
117 | logger.info(f'Distributed training: {distributed}')
118 | logger.info(f'Config:\n{cfg.pretty_text}')
119 |
120 | # set random seeds
121 | if args.seed is not None:
122 | logger.info(f'Set random seed to {args.seed}, '
123 | f'deterministic: {args.deterministic}')
124 | set_random_seed(args.seed, deterministic=args.deterministic)
125 | cfg.seed = args.seed
126 | meta['seed'] = args.seed
127 |
128 | model = build_architecture(cfg.model)
129 | model.init_weights()
130 |
131 | datasets = [build_dataset(cfg.data.train)]
132 | if len(cfg.workflow) == 2:
133 | val_dataset = copy.deepcopy(cfg.data.val)
134 | val_dataset.pipeline = cfg.data.train.pipeline
135 | datasets.append(build_dataset(val_dataset))
136 | # add an attribute for visualization convenience
137 | train_model(
138 | model,
139 | datasets,
140 | cfg,
141 | distributed=distributed,
142 | validate=(not args.no_validate),
143 | timestamp=timestamp,
144 | device='cpu' if args.device == 'cpu' else 'cuda',
145 | meta=meta)
146 |
147 |
148 | if __name__ == '__main__':
149 | main()
--------------------------------------------------------------------------------
/tools/visualize.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import os.path as osp
4 | import mmcv
5 | import numpy as np
6 | import torch
7 | from mogen.models import build_architecture
8 | from mmcv.runner import init_dist, load_checkpoint, wrap_fp16_model
9 | from mmcv.parallel import MMDataParallel
10 | from mogen.utils.plot_utils import (
11 | recover_from_ric,
12 | plot_3d_motion,
13 | t2m_kinematic_chain,
14 | plot_3d_motion_kit,
15 | kit_kinematic_chain
16 | )
17 | from scipy.ndimage import gaussian_filter
18 |
19 |
20 | def motion_temporal_filter(motion, sigma=1):
21 | motion = motion.reshape(motion.shape[0], -1)
22 | for i in range(motion.shape[1]):
23 | motion[:, i] = gaussian_filter(motion[:, i], sigma=sigma, mode="nearest")
24 | return motion.reshape(motion.shape[0], -1, 3)
25 |
26 |
27 | def plot_t2m(data, result_path, npy_path, caption):
28 | joint = recover_from_ric(torch.from_numpy(data).float(), 22).numpy()
29 | joint = motion_temporal_filter(joint, sigma=2.5)
30 | plot_3d_motion(result_path, t2m_kinematic_chain, joint, title=caption, fps=20)
31 | if npy_path is not None:
32 | np.save(npy_path, joint)
33 |
34 | def plot_kit(data, result_path, npy_path, caption):
35 |     joint = recover_from_ric(torch.from_numpy(data).float(), 21).numpy()  # KIT-ML skeletons have 21 joints
36 | joint = motion_temporal_filter(joint, sigma=2.5)
37 | plot_3d_motion_kit(result_path, kit_kinematic_chain, joint, title=caption, fps=20)
38 | if npy_path is not None:
39 | np.save(npy_path, joint)
40 |
41 | def parse_args():
42 | parser = argparse.ArgumentParser(description='mogen evaluation')
43 | parser.add_argument('config', help='test config file path') # kit(configs/remodiffuse/remodiffuse_kit.py)
44 | parser.add_argument('checkpoint', help='checkpoint file')
45 | parser.add_argument('--text', help='motion description')
46 | parser.add_argument('--motion_length', type=int, help='expected motion length')
47 | parser.add_argument('--out', help='output animation file')
48 | parser.add_argument('--pose_npy', help='output pose sequence file', default=None)
49 | parser.add_argument(
50 | '--launcher',
51 | choices=['none', 'pytorch', 'slurm', 'mpi'],
52 | default='none',
53 | help='job launcher')
54 | parser.add_argument('--local_rank', type=int, default=0)
55 | parser.add_argument(
56 | '--device',
57 | choices=['cpu', 'cuda'],
58 | default='cuda',
59 | help='device used for testing')
60 | args = parser.parse_args()
61 | if 'LOCAL_RANK' not in os.environ:
62 | os.environ['LOCAL_RANK'] = str(args.local_rank)
63 | return args
64 |
65 |
66 | def main():
67 | args = parse_args()
68 |
69 | cfg = mmcv.Config.fromfile(args.config)
70 | # set cudnn_benchmark
71 | if cfg.get('cudnn_benchmark', False):
72 | torch.backends.cudnn.benchmark = True
73 | cfg.data.test.test_mode = True
74 |
75 | # init distributed env first, since logger depends on the dist info.
76 | if args.launcher == 'none':
77 | distributed = False
78 | else:
79 | distributed = True
80 | init_dist(args.launcher, **cfg.dist_params)
81 |
82 | assert args.motion_length >= 16 and args.motion_length <= 196
83 |
84 | # build the model and load checkpoint
85 | model = build_architecture(cfg.model)
86 | fp16_cfg = cfg.get('fp16', None)
87 | if fp16_cfg is not None:
88 | wrap_fp16_model(model)
89 | load_checkpoint(model, args.checkpoint, map_location='cpu')
90 |
91 | if args.device == 'cpu':
92 | model = model.cpu()
93 | else:
94 | model = MMDataParallel(model, device_ids=[0])
95 | model.eval()
96 |
97 | dataset_name = cfg.data.test.dataset_name
98 |     print('dataset_name:', dataset_name)
99 | if dataset_name == 'human_ml3d':
100 | #assert dataset_name == "human_ml3d"
101 | mean_path = "data/datasets/human_ml3d/mean.npy"
102 | std_path = "data/datasets/human_ml3d/std.npy"
103 | mean = np.load(mean_path)
104 | std = np.load(std_path)
105 | else:
106 | #assert dataset_name == "kit_ml"
107 | mean_path = "data/datasets/kit_ml/mean.npy"
108 | std_path = "data/datasets/kit_ml/std.npy"
109 | mean = np.load(mean_path)
110 | std = np.load(std_path)
111 |
112 | device = args.device
113 | text = args.text
114 | motion_length = args.motion_length
115 | if dataset_name == 'human_ml3d':
116 | motion = torch.zeros(1, motion_length, 263).to(device)
117 | else:
118 | motion = torch.zeros(1, motion_length, 251).to(device)
119 | motion_mask = torch.ones(1, motion_length).to(device)
120 | motion_length = torch.Tensor([motion_length]).long().to(device)
121 | model = model.to(device)
122 |
123 | input = {
124 | 'motion': motion,
125 | 'motion_mask': motion_mask,
126 | 'motion_length': motion_length,
127 | 'motion_metas': [{'text': text}],
128 | }
129 |
130 |     # Run a single sampling pass with the trained model.
131 |     with torch.no_grad():
132 |         input['inference_kwargs'] = {}
133 |         output = model(**input)[0]['pred_motion']
134 |         pred_motion = output.detach().cpu().numpy()
135 |         # De-normalize back to the raw feature scale of the dataset.
136 |         pred_motion = pred_motion * std + mean
137 |
138 | if dataset_name == 'human_ml3d':
139 | plot_t2m(pred_motion, args.out, args.pose_npy, text)
140 | else:
141 | plot_kit(pred_motion, args.out, args.pose_npy, text)
142 |
143 |
144 | if __name__ == '__main__':
145 | main()
--------------------------------------------------------------------------------