├── .gitignore ├── LICENSE ├── README.md ├── configs ├── _base_ │ └── datasets │ │ ├── human_ml3d_bs128.py │ │ └── kit_ml_bs128.py ├── mdm │ └── mdm_t2m_official.py ├── motiondiffuse │ ├── motiondiffuse_kit.py │ └── motiondiffuse_t2m.py └── remodiffuse │ ├── remodiffuse_kit.py │ └── remodiffuse_t2m.py ├── imgs ├── pipeline.png └── teaser.png ├── mogen ├── __init__.py ├── apis │ ├── __init__.py │ ├── test.py │ └── train.py ├── core │ ├── __init__.py │ ├── distributed_wrapper.py │ ├── evaluation │ │ ├── __init__.py │ │ ├── builder.py │ │ ├── eval_hooks.py │ │ ├── evaluators │ │ │ ├── __init__.py │ │ │ ├── base_evaluator.py │ │ │ ├── diversity_evaluator.py │ │ │ ├── fid_evaluator.py │ │ │ ├── matching_score_evaluator.py │ │ │ ├── multimodality_evaluator.py │ │ │ └── precision_evaluator.py │ │ ├── get_model.py │ │ └── utils.py │ └── optimizer │ │ ├── __init__.py │ │ └── builder.py ├── datasets │ ├── __init__.py │ ├── base_dataset.py │ ├── builder.py │ ├── dataset_wrappers.py │ ├── pipelines │ │ ├── __init__.py │ │ ├── compose.py │ │ ├── formatting.py │ │ └── transforms.py │ ├── samplers │ │ ├── __init__.py │ │ └── distributed_sampler.py │ └── text_motion_dataset.py ├── models │ ├── __init__.py │ ├── architectures │ │ ├── __init__.py │ │ ├── base_architecture.py │ │ ├── diffusion_architecture.py │ │ └── vae_architecture.py │ ├── attentions │ │ ├── __init__.py │ │ ├── base_attention.py │ │ ├── efficient_attention.py │ │ └── semantics_modulated.py │ ├── builder.py │ ├── losses │ │ ├── __init__.py │ │ ├── mse_loss.py │ │ └── utils.py │ ├── rnns │ │ ├── __init__.py │ │ └── t2m_bigru.py │ ├── transformers │ │ ├── __init__.py │ │ ├── actor.py │ │ ├── diffusion_transformer.py │ │ ├── mdm.py │ │ ├── motiondiffuse.py │ │ ├── position_encoding.py │ │ └── remodiffuse.py │ └── utils │ │ ├── __init__.py │ │ ├── gaussian_diffusion.py │ │ ├── mlp.py │ │ ├── stylization_block.py │ │ └── word_vectorizer.py ├── utils │ ├── __init__.py │ ├── collect_env.py │ ├── dist_utils.py │ ├── logger.py │ ├── misc.py │ ├── path_utils.py │ └── plot_utils.py └── version.py ├── requirements.txt └── tools ├── dist_train.sh ├── slurm_test.sh ├── slurm_train.sh ├── test.py ├── train.py └── visualize.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | **/*.pyc 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .coverage 43 | .coverage.* 44 | .cache 45 | nosetests.xml 46 | coverage.xml 47 | *.cover 48 | .hypothesis/ 49 | .pytest_cache/ 50 | 51 | # Translations 52 | *.mo 53 | *.pot 54 | 55 | # Django stuff: 56 | *.log 57 | local_settings.py 58 | db.sqlite3 59 | 60 | # Flask stuff: 61 | instance/ 62 | .webassets-cache 63 | 64 | # Scrapy stuff: 65 | .scrapy 66 | 67 | # Sphinx documentation 68 | docs/_build/ 69 | 70 | # PyBuilder 71 | target/ 72 | 73 | # Jupyter Notebook 74 | .ipynb_checkpoints 75 | 76 | # pyenv 77 | .python-version 78 | 79 | # celery beat schedule file 80 | celerybeat-schedule 81 | 82 | # SageMath parsed files 83 | *.sage.py 84 | 85 | # Environments 86 | .env 87 | .venv 88 | env/ 89 | venv/ 90 | ENV/ 91 | env.bak/ 92 | venv.bak/ 93 | 94 | # Spyder project settings 95 | .spyderproject 96 | .spyproject 97 | 98 | # Rope project settings 99 | .ropeproject 100 | 101 | # mkdocs documentation 102 | /site 103 | 104 | # mypy 105 | .mypy_cache/ 106 | 107 | # custom 108 | data 109 | # data for pytest moved to http server 110 | # !tests/data 111 | .vscode 112 | .idea 113 | *.pkl 114 | *.pkl.json 115 | *.log.json 116 | work_dirs/ 117 | logs/ 118 | 119 | # Pytorch 120 | *.pth 121 | *.pt 122 | 123 | 124 | # Visualization 125 | *.mp4 126 | *.png 127 | *.gif 128 | *.jpg 129 | *.obj 130 | *.ply 131 | !demo/resources/* 132 | 133 | # Resources as exception 134 | !resources/* 135 | 136 | # Loaded/Saved data files 137 | *.npz 138 | *.npy 139 | *.pickle 140 | 141 | # MacOS 142 | *DS_Store* 143 | # git 144 | *.orig 145 | 146 | env.sh -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | S-Lab License 1.0 2 | 3 | Copyright 2023 S-Lab 4 | 5 | Redistribution and use for non-commercial purpose in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 6 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 7 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 8 | 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 9 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 10 | 4. 
In the event that redistribution and/or use for commercial purpose in source or binary forms, with or without modification is required, please contact the contributor(s) of the work. 11 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 |

ReMoDiffuse: Retrieval-Augmented Motion Diffusion Model

4 | 5 |
6 | Mingyuan Zhang1  7 | Xinying Guo1  8 | Liang Pan1  9 | Zhongang Cai1,2  10 | Fangzhou Hong1  11 | Huirong Li1
12 | Lei Yang2  13 | Ziwei Liu1+ 14 |
15 |
16 | 1S-Lab, Nanyang Technological University  17 | 2SenseTime Research  18 |
19 |
20 | +corresponding author 21 |
22 | 23 | 24 | --- 25 | 26 |

27 | [Project Page] • 28 | [arXiv] • 29 | [Video] • 30 | [Colab Demo] • 31 | [Hugging Face Demo] 32 |

33 | Accepted to ICCV 2023

34 | visitor badge 35 | 36 | 37 |
38 | 39 | 40 | >**Abstract:** 3D human motion generation is crucial for the creative industry. Recent advances rely on generative models with domain knowledge for text-driven motion generation, leading to substantial progress in capturing common motions. However, the performance on more diverse motions remains unsatisfactory. In this work, we propose **ReMoDiffuse**, a diffusion-model-based motion generation framework that integrates a retrieval mechanism to refine the denoising process. 41 | 42 |
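ReMoDiffuse conditions its denoiser on both the text prompt and the retrieved motion samples, and the model configs later in this document expose a `scale_func_cfg` (`coarse_scale`, `both_coef`, `text_coef`, `retr_coef`) for mixing outputs produced under different condition combinations. A minimal sketch of this kind of classifier-free-guidance-style mixing, with a hypothetical function name and an assumed weighting scheme rather than the exact `mogen` implementation, might look like:

```python
import torch


def mix_condition_outputs(out_both: torch.Tensor,
                          out_text: torch.Tensor,
                          out_retr: torch.Tensor,
                          out_none: torch.Tensor,
                          scale: float = 6.5,
                          both_coef: float = 0.5,
                          text_coef: float = -0.3,
                          retr_coef: float = 2.4) -> torch.Tensor:
    """Blend denoiser outputs obtained under four condition settings.

    out_both is conditioned on text + retrieved motions, out_text on text
    only, out_retr on retrieved motions only, and out_none is unconditional.
    The unconditional output anchors the prediction; the conditional outputs
    push it away from that anchor with per-condition coefficients.
    """
    guidance = (both_coef * (out_both - out_none)
                + text_coef * (out_text - out_none)
                + retr_coef * (out_retr - out_none))
    return out_none + scale * guidance
```

In the released configs the three coefficients differ between the KIT-ML and HumanML3D models (see `configs/remodiffuse/remodiffuse_kit.py` and `configs/remodiffuse/remodiffuse_t2m.py` below).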
43 | 44 | 45 | 46 | 47 |
48 | 49 | >**Pipeline Overview:** ReMoDiffuse is a retrieval-augmented 3D human motion diffusion model. Benefiting from the extra knowledge in the retrieved samples, ReMoDiffuse is able to achieve high-fidelity motion generation for the given prompts. It contains three core components: a) A **Hybrid Retrieval** database stores multi-modality features of each motion sequence. b) A semantics-modulated transformer incorporates several identical decoder layers, each containing a **Semantics-Modulated Attention (SMA)** layer and an FFN layer. The SMA layer adaptively absorbs knowledge from both the retrieved samples and the given prompts. c) A **Condition Mixture** technique is proposed to better mix the model's outputs under different combinations of conditions. 50 | 51 | ## Updates 52 | 53 | [09/2023] Add a [🤗Hugging Face Demo](https://huggingface.co/spaces/mingyuan/ReMoDiffuse)! 54 | 55 | [09/2023] Add a [Colab Demo](https://colab.research.google.com/drive/1jztE7c8js3P4YFbw5cGNPJAsCVrreTov?usp=sharing)! [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1jztE7c8js3P4YFbw5cGNPJAsCVrreTov?usp=sharing) 56 | 57 | [09/2023] Release code for [ReMoDiffuse](https://mingyuan-zhang.github.io/projects/ReMoDiffuse.html) and [MotionDiffuse](https://mingyuan-zhang.github.io/projects/MotionDiffuse.html) 58 | 59 | ## Benchmark and Model Zoo 60 | 61 | #### Supported methods 62 | 63 | - [x] [MotionDiffuse](https://mingyuan-zhang.github.io/projects/MotionDiffuse.html) 64 | - [x] [MDM](https://guytevet.github.io/mdm-page/) 65 | - [x] [ReMoDiffuse](https://mingyuan-zhang.github.io/projects/ReMoDiffuse.html) 66 | 67 | 68 | ## Citation 69 | 70 | If you find our work useful for your research, please consider citing the papers: 71 | 72 | ``` 73 | @article{zhang2023remodiffuse, 74 | title={ReMoDiffuse: Retrieval-Augmented Motion Diffusion Model}, 75 | author={Zhang, Mingyuan and Guo, Xinying and Pan, Liang and Cai, Zhongang and Hong, Fangzhou and Li, Huirong and Yang, Lei and Liu, Ziwei}, 76 | journal={arXiv preprint arXiv:2304.01116}, 77 | year={2023} 78 | } 79 | @article{zhang2022motiondiffuse, 80 | title={MotionDiffuse: Text-Driven Human Motion Generation with Diffusion Model}, 81 | author={Zhang, Mingyuan and Cai, Zhongang and Pan, Liang and Hong, Fangzhou and Guo, Xinying and Yang, Lei and Liu, Ziwei}, 82 | journal={arXiv preprint arXiv:2208.15001}, 83 | year={2022} 84 | } 85 | ``` 86 | 87 | ## Installation 88 | 89 | ```shell 90 | # Create Conda Environment 91 | conda create -n mogen python=3.9 -y 92 | conda activate mogen 93 | 94 | # C++ Environment 95 | export PATH=/mnt/lustre/share/gcc/gcc-8.5.0/bin:$PATH 96 | export LD_LIBRARY_PATH=/mnt/lustre/share/gcc/gcc-8.5.0/lib:/mnt/lustre/share/gcc/gcc-8.5.0/lib64:/mnt/lustre/share/gcc/gmp-4.3.2/lib:/mnt/lustre/share/gcc/mpc-0.8.1/lib:/mnt/lustre/share/gcc/mpfr-2.4.2/lib:$LD_LIBRARY_PATH 97 | 98 | # Install PyTorch 99 | conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.3 -c pytorch -y 100 | 101 | # Install MMCV 102 | pip install "mmcv-full>=1.4.2,<=1.9.0" -f https://download.openmmlab.com/mmcv/dist/cu113/torch1.12.1/index.html 103 | 104 | # Install PyTorch3D 105 | conda install -c bottler nvidiacub -y 106 | conda install -c fvcore -c iopath -c conda-forge fvcore iopath -y 107 | conda install pytorch3d -c pytorch3d -y 108 | 109 | # Install other requirements 110 | pip install -r requirements.txt 111 | ``` 112 | 113 | ## Data Preparation 114 | 115 | Download data files from Google Drive 
[link](https://drive.google.com/drive/folders/13kwahiktQ2GMVKfVH3WT-VGAQ6JHbvUv?usp=sharing) or Baidu Netdisk [link](https://pan.baidu.com/s/1604jks-9PtBUtqCpQQmeEg) (access code: vprc). Unzip all files and arrange them in the following file structure: 116 | 117 | ```text 118 | ReMoDiffuse 119 | ├── mogen 120 | ├── tools 121 | ├── configs 122 | ├── logs 123 | │ ├── motiondiffuse 124 | │ ├── remodiffuse 125 | │ └── mdm 126 | └── data 127 | ├── database 128 | ├── datasets 129 | ├── evaluators 130 | └── glove 131 | ``` 132 | 133 | ## Training 134 | 135 | ### Training with a single / multiple GPUs 136 | 137 | ```shell 138 | PYTHONPATH=".":$PYTHONPATH python tools/train.py ${CONFIG_FILE} ${WORK_DIR} --no-validate 139 | ``` 140 | 141 | **Note:** The provided config files are designed for training with 8 GPUs. If you want to train on a single GPU, you can reduce the number of epochs to one-fourth of the original. 142 | 143 | ### Training with Slurm 144 | 145 | ```shell 146 | ./tools/slurm_train.sh ${PARTITION} ${JOB_NAME} ${CONFIG_FILE} ${WORK_DIR} ${GPU_NUM} --no-validate 147 | ``` 148 | 149 | Common optional arguments include: 150 | - `--resume-from ${CHECKPOINT_FILE}`: Resume from a previous checkpoint file. 151 | - `--no-validate`: Skip evaluation of the checkpoint during training. 152 | 153 | Example: using 8 GPUs to train ReMoDiffuse on a Slurm cluster. 154 | ```shell 155 | ./tools/slurm_train.sh my_partition my_job configs/remodiffuse/remodiffuse_kit.py logs/remodiffuse_kit 8 --no-validate 156 | ``` 157 | 158 | ## Evaluation 159 | 160 | ### Evaluate with a single GPU / multiple GPUs 161 | 162 | ```shell 163 | PYTHONPATH=".":$PYTHONPATH python tools/test.py ${CONFIG} --work-dir=${WORK_DIR} ${CHECKPOINT} 164 | ``` 165 | 166 | ### Evaluate with Slurm 167 | 168 | ```shell 169 | ./tools/slurm_test.sh ${PARTITION} ${JOB_NAME} ${CONFIG} ${WORK_DIR} ${CHECKPOINT} 170 | ``` 171 | Example: 172 | ```shell 173 | ./tools/slurm_test.sh my_partition test_remodiffuse configs/remodiffuse/remodiffuse_kit.py logs/remodiffuse_kit logs/remodiffuse_kit/latest.pth 174 | ``` 175 | 176 | **Note:** Running the full evaluation on the HumanML3D dataset is very slow. You can change `replication_times` in [human_ml3d_bs128.py](configs/_base_/datasets/human_ml3d_bs128.py) to $1$ for a quick evaluation. 177 | 178 | ## Visualization 179 | 180 | ```shell 181 | PYTHONPATH=".":$PYTHONPATH python tools/visualize.py ${CONFIG} ${CHECKPOINT} \ 182 | --text ${TEXT} \ 183 | --motion_length ${MOTION_LENGTH} \ 184 | --out ${OUTPUT_ANIMATION_PATH} \ 185 | --device cpu 186 | ``` 187 | 188 | Example: 189 | ```shell 190 | PYTHONPATH=".":$PYTHONPATH python tools/visualize.py \ 191 | configs/remodiffuse/remodiffuse_t2m.py \ 192 | logs/remodiffuse/remodiffuse_t2m/latest.pth \ 193 | --text "a person is running quickly" \ 194 | --motion_length 120 \ 195 | --out "test.gif" \ 196 | --device cpu 197 | ``` 198 | 199 | ## Acknowledgement 200 | 201 | This study is supported by the Ministry of Education, Singapore, under its MOE AcRF Tier 2 (MOE-T2EP20221-0012), NTU NAP, and under the RIE2020 Industry Alignment Fund – Industry Collaboration Projects (IAF-ICP) Funding Initiative, as well as cash and in-kind contribution from the industry partner(s). 
202 | 203 | The visualization tool is developed on top of [Generating Diverse and Natural 3D Human Motions from Text](https://github.com/EricGuo5513/text-to-motion) 204 | -------------------------------------------------------------------------------- /configs/_base_/datasets/human_ml3d_bs128.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | data_keys = ['motion', 'motion_mask', 'motion_length', 'clip_feat'] 3 | meta_keys = ['text', 'token'] 4 | train_pipeline = [ 5 | dict( 6 | type='Normalize', 7 | mean_path='data/datasets/human_ml3d/mean.npy', 8 | std_path='data/datasets/human_ml3d/std.npy'), 9 | dict(type='Crop', crop_size=196), 10 | dict(type='ToTensor', keys=data_keys), 11 | dict(type='Collect', keys=data_keys, meta_keys=meta_keys) 12 | ] 13 | 14 | data = dict( 15 | samples_per_gpu=128, 16 | workers_per_gpu=1, 17 | train=dict( 18 | type='RepeatDataset', 19 | dataset=dict( 20 | type='TextMotionDataset', 21 | dataset_name='human_ml3d', 22 | data_prefix='data', 23 | pipeline=train_pipeline, 24 | ann_file='train.txt', 25 | motion_dir='motions', 26 | text_dir='texts', 27 | token_dir='tokens', 28 | clip_feat_dir='clip_feats', 29 | ), 30 | times=200 31 | ), 32 | test=dict( 33 | type='TextMotionDataset', 34 | dataset_name='human_ml3d', 35 | data_prefix='data', 36 | pipeline=train_pipeline, 37 | ann_file='test.txt', 38 | motion_dir='motions', 39 | text_dir='texts', 40 | token_dir='tokens', 41 | clip_feat_dir='clip_feats', 42 | eval_cfg=dict( 43 | shuffle_indexes=True, 44 | replication_times=20, 45 | replication_reduction='statistics', 46 | text_encoder_name='human_ml3d', 47 | text_encoder_path='data/evaluators/human_ml3d/finest.tar', 48 | motion_encoder_name='human_ml3d', 49 | motion_encoder_path='data/evaluators/human_ml3d/finest.tar', 50 | metrics=[ 51 | dict(type='R Precision', batch_size=32, top_k=3), 52 | dict(type='Matching Score', batch_size=32), 53 | dict(type='FID'), 54 | dict(type='Diversity', num_samples=300), 55 | dict(type='MultiModality', num_samples=100, num_repeats=30, num_picks=10) 56 | ] 57 | ), 58 | test_mode=True 59 | ) 60 | ) -------------------------------------------------------------------------------- /configs/_base_/datasets/kit_ml_bs128.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | data_keys = ['motion', 'motion_mask', 'motion_length', 'clip_feat'] 3 | meta_keys = ['text', 'token'] 4 | train_pipeline = [ 5 | dict(type='Crop', crop_size=196), 6 | dict( 7 | type='Normalize', 8 | mean_path='data/datasets/kit_ml/mean.npy', 9 | std_path='data/datasets/kit_ml/std.npy'), 10 | dict(type='ToTensor', keys=data_keys), 11 | dict(type='Collect', keys=data_keys, meta_keys=meta_keys) 12 | ] 13 | 14 | data = dict( 15 | samples_per_gpu=128, 16 | workers_per_gpu=1, 17 | train=dict( 18 | type='RepeatDataset', 19 | dataset=dict( 20 | type='TextMotionDataset', 21 | dataset_name='kit_ml', 22 | data_prefix='data', 23 | pipeline=train_pipeline, 24 | ann_file='train.txt', 25 | motion_dir='motions', 26 | text_dir='texts', 27 | token_dir='tokens', 28 | clip_feat_dir='clip_feats', 29 | ), 30 | times=100 31 | ), 32 | test=dict( 33 | type='TextMotionDataset', 34 | dataset_name='kit_ml', 35 | data_prefix='data', 36 | pipeline=train_pipeline, 37 | ann_file='test.txt', 38 | motion_dir='motions', 39 | text_dir='texts', 40 | token_dir='tokens', 41 | clip_feat_dir='clip_feats', 42 | eval_cfg=dict( 43 | shuffle_indexes=True, 44 | replication_times=20, 45 | 
replication_reduction='statistics', 46 | text_encoder_name='kit_ml', 47 | text_encoder_path='data/evaluators/kit_ml/finest.tar', 48 | motion_encoder_name='kit_ml', 49 | motion_encoder_path='data/evaluators/kit_ml/finest.tar', 50 | metrics=[ 51 | dict(type='R Precision', batch_size=32, top_k=3), 52 | dict(type='Matching Score', batch_size=32), 53 | dict(type='FID'), 54 | dict(type='Diversity', num_samples=300), 55 | dict(type='MultiModality', num_samples=50, num_repeats=30, num_picks=10) 56 | ] 57 | ), 58 | test_mode=True 59 | ) 60 | ) -------------------------------------------------------------------------------- /configs/mdm/mdm_t2m_official.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/datasets/human_ml3d_bs128.py'] 2 | 3 | # checkpoint saving 4 | checkpoint_config = dict(interval=1) 5 | 6 | dist_params = dict(backend='nccl') 7 | log_level = 'INFO' 8 | load_from = None 9 | resume_from = None 10 | workflow = [('train', 1)] 11 | 12 | # optimizer 13 | optimizer = dict(type='Adam', lr=1e-4) 14 | optimizer_config = dict(grad_clip=None) 15 | # learning policy 16 | lr_config = dict(policy='step', step=[]) 17 | runner = dict(type='EpochBasedRunner', max_epochs=50) 18 | 19 | log_config = dict( 20 | interval=50, 21 | hooks=[ 22 | dict(type='TextLoggerHook'), 23 | # dict(type='TensorboardLoggerHook') 24 | ]) 25 | 26 | input_feats = 263 27 | max_seq_len = 196 28 | latent_dim = 512 29 | time_embed_dim = 2048 30 | text_latent_dim = 256 31 | ff_size = 1024 32 | num_layers = 8 33 | num_heads = 4 34 | dropout = 0.1 35 | cond_mask_prob = 0.1 36 | # model settings 37 | model = dict( 38 | type='MotionDiffusion', 39 | model=dict( 40 | type='MDMTransformer', 41 | input_feats=input_feats, 42 | latent_dim=latent_dim, 43 | ff_size=ff_size, 44 | num_layers=num_layers, 45 | num_heads=num_heads, 46 | dropout=dropout, 47 | time_embed_dim=time_embed_dim, 48 | cond_mask_prob=cond_mask_prob, 49 | guide_scale=2.5, 50 | clip_version='ViT-B/32', 51 | use_official_ckpt=True 52 | ), 53 | loss_recon=dict(type='MSELoss', loss_weight=1, reduction='none'), 54 | diffusion_train=dict( 55 | beta_scheduler='cosine', 56 | diffusion_steps=1000, 57 | model_mean_type='start_x', 58 | model_var_type='fixed_small', 59 | ), 60 | diffusion_test=dict( 61 | beta_scheduler='cosine', 62 | diffusion_steps=1000, 63 | model_mean_type='start_x', 64 | model_var_type='fixed_small', 65 | ), 66 | inference_type='ddpm' 67 | ) -------------------------------------------------------------------------------- /configs/motiondiffuse/motiondiffuse_kit.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/datasets/kit_ml_bs128.py'] 2 | 3 | # checkpoint saving 4 | checkpoint_config = dict(interval=1) 5 | 6 | dist_params = dict(backend='nccl') 7 | log_level = 'INFO' 8 | load_from = None 9 | resume_from = None 10 | workflow = [('train', 1)] 11 | 12 | # optimizer 13 | optimizer = dict(type='Adam', lr=2e-4) 14 | optimizer_config = dict(grad_clip=None) 15 | # learning policy 16 | lr_config = dict(policy='step', step=[]) 17 | runner = dict(type='EpochBasedRunner', max_epochs=50) 18 | 19 | log_config = dict( 20 | interval=50, 21 | hooks=[ 22 | dict(type='TextLoggerHook'), 23 | # dict(type='TensorboardLoggerHook') 24 | ]) 25 | 26 | input_feats = 251 27 | max_seq_len = 196 28 | latent_dim = 512 29 | time_embed_dim = 2048 30 | text_latent_dim = 256 31 | ff_size = 1024 32 | num_heads = 8 33 | dropout = 0 34 | # model settings 35 | model = dict( 36 | 
type='MotionDiffusion', 37 | model=dict( 38 | type='MotionDiffuseTransformer', 39 | input_feats=input_feats, 40 | max_seq_len=max_seq_len, 41 | latent_dim=latent_dim, 42 | time_embed_dim=time_embed_dim, 43 | num_layers=8, 44 | sa_block_cfg=dict( 45 | type='EfficientSelfAttention', 46 | latent_dim=latent_dim, 47 | num_heads=num_heads, 48 | dropout=dropout, 49 | time_embed_dim=time_embed_dim 50 | ), 51 | ca_block_cfg=dict( 52 | type='EfficientCrossAttention', 53 | latent_dim=latent_dim, 54 | text_latent_dim=text_latent_dim, 55 | num_heads=num_heads, 56 | dropout=dropout, 57 | time_embed_dim=time_embed_dim 58 | ), 59 | ffn_cfg=dict( 60 | latent_dim=latent_dim, 61 | ffn_dim=ff_size, 62 | dropout=dropout, 63 | time_embed_dim=time_embed_dim 64 | ), 65 | text_encoder=dict( 66 | pretrained_model='clip', 67 | latent_dim=text_latent_dim, 68 | num_layers=4, 69 | num_heads=4, 70 | ff_size=2048, 71 | dropout=dropout, 72 | use_text_proj=True 73 | ) 74 | ), 75 | loss_recon=dict(type='MSELoss', loss_weight=1, reduction='none'), 76 | diffusion_train=dict( 77 | beta_scheduler='linear', 78 | diffusion_steps=1000, 79 | model_mean_type='epsilon', 80 | model_var_type='fixed_small', 81 | ), 82 | diffusion_test=dict( 83 | beta_scheduler='linear', 84 | diffusion_steps=1000, 85 | model_mean_type='epsilon', 86 | model_var_type='fixed_small', 87 | ), 88 | inference_type='ddpm' 89 | ) 90 | -------------------------------------------------------------------------------- /configs/motiondiffuse/motiondiffuse_t2m.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/datasets/human_ml3d_bs128.py'] 2 | 3 | # checkpoint saving 4 | checkpoint_config = dict(interval=1) 5 | 6 | dist_params = dict(backend='nccl') 7 | log_level = 'INFO' 8 | load_from = None 9 | resume_from = None 10 | workflow = [('train', 1)] 11 | 12 | # optimizer 13 | optimizer = dict(type='Adam', lr=2e-4) 14 | optimizer_config = dict(grad_clip=None) 15 | # learning policy 16 | lr_config = dict(policy='step', step=[]) 17 | runner = dict(type='EpochBasedRunner', max_epochs=50) 18 | 19 | log_config = dict( 20 | interval=50, 21 | hooks=[ 22 | dict(type='TextLoggerHook'), 23 | # dict(type='TensorboardLoggerHook') 24 | ]) 25 | 26 | input_feats = 263 27 | max_seq_len = 196 28 | latent_dim = 512 29 | time_embed_dim = 2048 30 | text_latent_dim = 256 31 | ff_size = 1024 32 | num_heads = 8 33 | dropout = 0 34 | # model settings 35 | model = dict( 36 | type='MotionDiffusion', 37 | model=dict( 38 | type='MotionDiffuseTransformer', 39 | input_feats=input_feats, 40 | max_seq_len=max_seq_len, 41 | latent_dim=latent_dim, 42 | time_embed_dim=time_embed_dim, 43 | num_layers=8, 44 | sa_block_cfg=dict( 45 | type='EfficientSelfAttention', 46 | latent_dim=latent_dim, 47 | num_heads=num_heads, 48 | dropout=dropout, 49 | time_embed_dim=time_embed_dim 50 | ), 51 | ca_block_cfg=dict( 52 | type='EfficientCrossAttention', 53 | latent_dim=latent_dim, 54 | text_latent_dim=text_latent_dim, 55 | num_heads=num_heads, 56 | dropout=dropout, 57 | time_embed_dim=time_embed_dim 58 | ), 59 | ffn_cfg=dict( 60 | latent_dim=latent_dim, 61 | ffn_dim=ff_size, 62 | dropout=dropout, 63 | time_embed_dim=time_embed_dim 64 | ), 65 | text_encoder=dict( 66 | pretrained_model='clip', 67 | latent_dim=text_latent_dim, 68 | num_layers=4, 69 | num_heads=4, 70 | ff_size=2048, 71 | dropout=dropout, 72 | use_text_proj=True 73 | ) 74 | ), 75 | loss_recon=dict(type='MSELoss', loss_weight=1, reduction='none'), 76 | diffusion_train=dict( 77 | beta_scheduler='linear', 
78 | diffusion_steps=1000, 79 | model_mean_type='epsilon', 80 | model_var_type='fixed_small', 81 | ), 82 | diffusion_test=dict( 83 | beta_scheduler='linear', 84 | diffusion_steps=1000, 85 | model_mean_type='epsilon', 86 | model_var_type='fixed_small', 87 | ), 88 | inference_type='ddpm' 89 | ) 90 | data = dict(samples_per_gpu=128) -------------------------------------------------------------------------------- /configs/remodiffuse/remodiffuse_kit.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/datasets/kit_ml_bs128.py'] 2 | 3 | # checkpoint saving 4 | checkpoint_config = dict(interval=1) 5 | 6 | dist_params = dict(backend='nccl') 7 | log_level = 'INFO' 8 | load_from = None 9 | resume_from = None 10 | workflow = [('train', 1)] 11 | 12 | # optimizer 13 | optimizer = dict(type='Adam', lr=2e-4) 14 | optimizer_config = dict(grad_clip=None) 15 | # learning policy 16 | lr_config = dict(policy='CosineAnnealing', min_lr_ratio=2e-5, by_epoch=False) 17 | runner = dict(type='EpochBasedRunner', max_epochs=20) 18 | 19 | log_config = dict( 20 | interval=50, 21 | hooks=[ 22 | dict(type='TextLoggerHook'), 23 | # dict(type='TensorboardLoggerHook') 24 | ]) 25 | 26 | input_feats = 251 27 | max_seq_len = 196 28 | latent_dim = 512 29 | time_embed_dim = 2048 30 | text_latent_dim = 256 31 | ff_size = 1024 32 | num_heads = 8 33 | dropout = 0 34 | 35 | # model settings 36 | model = dict( 37 | type='MotionDiffusion', 38 | model=dict( 39 | type='ReMoDiffuseTransformer', 40 | input_feats=input_feats, 41 | max_seq_len=max_seq_len, 42 | latent_dim=latent_dim, 43 | time_embed_dim=time_embed_dim, 44 | num_layers=4, 45 | ca_block_cfg=dict( 46 | type='SemanticsModulatedAttention', 47 | latent_dim=latent_dim, 48 | text_latent_dim=text_latent_dim, 49 | num_heads=num_heads, 50 | dropout=dropout, 51 | time_embed_dim=time_embed_dim 52 | ), 53 | ffn_cfg=dict( 54 | latent_dim=latent_dim, 55 | ffn_dim=ff_size, 56 | dropout=dropout, 57 | time_embed_dim=time_embed_dim 58 | ), 59 | text_encoder=dict( 60 | pretrained_model='clip', 61 | latent_dim=text_latent_dim, 62 | num_layers=2, 63 | ff_size=2048, 64 | dropout=dropout, 65 | use_text_proj=False 66 | ), 67 | retrieval_cfg=dict( 68 | num_retrieval=2, 69 | stride=4, 70 | num_layers=2, 71 | num_motion_layers=2, 72 | kinematic_coef=0.1, 73 | topk=2, 74 | retrieval_file='data/database/kit_text_train.npz', 75 | latent_dim=latent_dim, 76 | output_dim=latent_dim, 77 | max_seq_len=max_seq_len, 78 | num_heads=num_heads, 79 | ff_size=ff_size, 80 | dropout=dropout, 81 | ffn_cfg=dict( 82 | latent_dim=latent_dim, 83 | ffn_dim=ff_size, 84 | dropout=dropout, 85 | ), 86 | sa_block_cfg=dict( 87 | type='EfficientSelfAttention', 88 | latent_dim=latent_dim, 89 | num_heads=num_heads, 90 | dropout=dropout 91 | ), 92 | ), 93 | scale_func_cfg=dict( 94 | coarse_scale=4.0, 95 | both_coef=0.78123, 96 | text_coef=0.39284, 97 | retr_coef=-0.12475 98 | ) 99 | ), 100 | loss_recon=dict(type='MSELoss', loss_weight=1, reduction='none'), 101 | diffusion_train=dict( 102 | beta_scheduler='linear', 103 | diffusion_steps=1000, 104 | model_mean_type='start_x', 105 | model_var_type='fixed_large', 106 | ), 107 | diffusion_test=dict( 108 | beta_scheduler='linear', 109 | diffusion_steps=1000, 110 | model_mean_type='start_x', 111 | model_var_type='fixed_large', 112 | respace='15,15,8,6,6', 113 | ), 114 | inference_type='ddim' 115 | ) -------------------------------------------------------------------------------- /configs/remodiffuse/remodiffuse_t2m.py: 
-------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/datasets/human_ml3d_bs128.py'] 2 | 3 | # checkpoint saving 4 | checkpoint_config = dict(interval=1) 5 | 6 | dist_params = dict(backend='nccl') 7 | log_level = 'INFO' 8 | load_from = None 9 | resume_from = None 10 | workflow = [('train', 1)] 11 | 12 | # optimizer 13 | optimizer = dict(type='Adam', lr=2e-4) 14 | optimizer_config = dict(grad_clip=None) 15 | # learning policy 16 | lr_config = dict(policy='CosineAnnealing', min_lr_ratio=2e-5, by_epoch=False) 17 | runner = dict(type='EpochBasedRunner', max_epochs=40) 18 | 19 | log_config = dict( 20 | interval=50, 21 | hooks=[ 22 | dict(type='TextLoggerHook'), 23 | # dict(type='TensorboardLoggerHook') 24 | ]) 25 | 26 | input_feats = 263 27 | max_seq_len = 196 28 | latent_dim = 512 29 | time_embed_dim = 2048 30 | text_latent_dim = 256 31 | ff_size = 1024 32 | num_heads = 8 33 | dropout = 0 34 | 35 | # model settings 36 | model = dict( 37 | type='MotionDiffusion', 38 | model=dict( 39 | type='ReMoDiffuseTransformer', 40 | input_feats=input_feats, 41 | max_seq_len=max_seq_len, 42 | latent_dim=latent_dim, 43 | time_embed_dim=time_embed_dim, 44 | num_layers=4, 45 | ca_block_cfg=dict( 46 | type='SemanticsModulatedAttention', 47 | latent_dim=latent_dim, 48 | text_latent_dim=text_latent_dim, 49 | num_heads=num_heads, 50 | dropout=dropout, 51 | time_embed_dim=time_embed_dim 52 | ), 53 | ffn_cfg=dict( 54 | latent_dim=latent_dim, 55 | ffn_dim=ff_size, 56 | dropout=dropout, 57 | time_embed_dim=time_embed_dim 58 | ), 59 | text_encoder=dict( 60 | pretrained_model='clip', 61 | latent_dim=text_latent_dim, 62 | num_layers=2, 63 | ff_size=2048, 64 | dropout=dropout, 65 | use_text_proj=False 66 | ), 67 | retrieval_cfg=dict( 68 | num_retrieval=2, 69 | stride=4, 70 | num_layers=2, 71 | num_motion_layers=2, 72 | kinematic_coef=0.1, 73 | topk=2, 74 | retrieval_file='data/database/t2m_text_train.npz', 75 | latent_dim=latent_dim, 76 | output_dim=latent_dim, 77 | max_seq_len=max_seq_len, 78 | num_heads=num_heads, 79 | ff_size=ff_size, 80 | dropout=dropout, 81 | ffn_cfg=dict( 82 | latent_dim=latent_dim, 83 | ffn_dim=ff_size, 84 | dropout=dropout, 85 | ), 86 | sa_block_cfg=dict( 87 | type='EfficientSelfAttention', 88 | latent_dim=latent_dim, 89 | num_heads=num_heads, 90 | dropout=dropout 91 | ), 92 | ), 93 | scale_func_cfg=dict( 94 | coarse_scale=6.5, 95 | both_coef=0.52351, 96 | text_coef=-0.28419, 97 | retr_coef=2.39872 98 | ) 99 | ), 100 | loss_recon=dict(type='MSELoss', loss_weight=1, reduction='none'), 101 | diffusion_train=dict( 102 | beta_scheduler='linear', 103 | diffusion_steps=1000, 104 | model_mean_type='start_x', 105 | model_var_type='fixed_large', 106 | ), 107 | diffusion_test=dict( 108 | beta_scheduler='linear', 109 | diffusion_steps=1000, 110 | model_mean_type='start_x', 111 | model_var_type='fixed_large', 112 | respace='15,15,8,6,6', 113 | ), 114 | inference_type='ddim' 115 | ) -------------------------------------------------------------------------------- /imgs/pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mingyuan-zhang/ReMoDiffuse/d81c83ddbf72a989dc334cd48cb7d46fb6feba63/imgs/pipeline.png -------------------------------------------------------------------------------- /imgs/teaser.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/mingyuan-zhang/ReMoDiffuse/d81c83ddbf72a989dc334cd48cb7d46fb6feba63/imgs/teaser.png -------------------------------------------------------------------------------- /mogen/__init__.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import mmcv 4 | from packaging.version import parse 5 | 6 | from .version import __version__ 7 | 8 | 9 | def digit_version(version_str: str, length: int = 4): 10 | """Convert a version string into a tuple of integers. 11 | This method is usually used for comparing two versions. For pre-release 12 | versions: alpha < beta < rc. 13 | Args: 14 | version_str (str): The version string. 15 | length (int): The maximum number of version levels. Default: 4. 16 | Returns: 17 | tuple[int]: The version info in digits (integers). 18 | """ 19 | version = parse(version_str) 20 | assert version.release, f'failed to parse version {version_str}' 21 | release = list(version.release) 22 | release = release[:length] 23 | if len(release) < length: 24 | release = release + [0] * (length - len(release)) 25 | if version.is_prerelease: 26 | mapping = {'a': -3, 'b': -2, 'rc': -1} 27 | val = -4 28 | # version.pre can be None 29 | if version.pre: 30 | if version.pre[0] not in mapping: 31 | warnings.warn(f'unknown prerelease version {version.pre[0]}, ' 32 | 'version checking may go wrong') 33 | else: 34 | val = mapping[version.pre[0]] 35 | release.extend([val, version.pre[-1]]) 36 | else: 37 | release.extend([val, 0]) 38 | 39 | elif version.is_postrelease: 40 | release.extend([1, version.post]) 41 | else: 42 | release.extend([0, 0]) 43 | return tuple(release) 44 | 45 | 46 | mmcv_minimum_version = '1.4.2' 47 | mmcv_maximum_version = '1.9.0' 48 | mmcv_version = digit_version(mmcv.__version__) 49 | 50 | 51 | assert (mmcv_version >= digit_version(mmcv_minimum_version) 52 | and mmcv_version <= digit_version(mmcv_maximum_version)), \ 53 | f'MMCV=={mmcv.__version__} is used but incompatible. ' \ 54 | f'Please install mmcv>={mmcv_minimum_version}, <={mmcv_maximum_version}.' 
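# (Illustrative annotation, not part of the original mogen/__init__.py.)
# digit_version() turns a version string into a fixed-length tuple, so the
# assert above reduces to plain tuple comparisons, e.g.:
#   digit_version('1.4.2')    -> (1, 4, 2, 0, 0, 0)
#   digit_version('1.9.0rc1') -> (1, 9, 0, 0, -1, 1)  # 'rc' pre-releases map to -1
# which keeps '1.9.0rc1' sorted below '1.9.0' under tuple ordering.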
55 | 56 | __all__ = ['__version__', 'digit_version'] -------------------------------------------------------------------------------- /mogen/apis/__init__.py: -------------------------------------------------------------------------------- 1 | from mogen.apis import test, train 2 | from mogen.apis.test import ( 3 | collect_results_cpu, 4 | collect_results_gpu, 5 | multi_gpu_test, 6 | single_gpu_test, 7 | ) 8 | from mogen.apis.train import set_random_seed, train_model 9 | 10 | __all__ = [ 11 | 'collect_results_cpu', 'collect_results_gpu', 'multi_gpu_test', 12 | 'single_gpu_test', 'set_random_seed', 'train_model' 13 | ] -------------------------------------------------------------------------------- /mogen/apis/test.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import pickle 3 | import shutil 4 | import tempfile 5 | import time 6 | 7 | import mmcv 8 | import torch 9 | import torch.distributed as dist 10 | from mmcv.runner import get_dist_info 11 | 12 | 13 | def single_gpu_test(model, data_loader): 14 | """Test with single gpu.""" 15 | model.eval() 16 | results = [] 17 | dataset = data_loader.dataset 18 | prog_bar = mmcv.ProgressBar(len(dataset)) 19 | for i, data in enumerate(data_loader): 20 | with torch.no_grad(): 21 | result = model(return_loss=False, **data) 22 | 23 | batch_size = len(result) 24 | if isinstance(result, list): 25 | results.extend(result) 26 | else: 27 | results.append(result) 28 | 29 | batch_size = data['motion'].size(0) 30 | for _ in range(batch_size): 31 | prog_bar.update() 32 | return results 33 | 34 | 35 | def multi_gpu_test(model, data_loader, tmpdir=None, gpu_collect=False): 36 | """Test model with multiple gpus. 37 | This method tests model with multiple gpus and collects the results 38 | under two different modes: gpu and cpu modes. By setting 'gpu_collect=True' 39 | it encodes results to gpu tensors and use gpu communication for results 40 | collection. On cpu mode it saves the results on different gpus to 'tmpdir' 41 | and collects them by the rank 0 worker. 42 | Args: 43 | model (nn.Module): Model to be tested. 44 | data_loader (nn.Dataloader): Pytorch data loader. 45 | tmpdir (str): Path of directory to save the temporary results from 46 | different gpus under cpu mode. 47 | gpu_collect (bool): Option to use either gpu or cpu to collect results. 48 | Returns: 49 | list: The prediction results. 50 | """ 51 | model.eval() 52 | results = [] 53 | dataset = data_loader.dataset 54 | rank, world_size = get_dist_info() 55 | if rank == 0: 56 | # Check if tmpdir is valid for cpu_collect 57 | if (not gpu_collect) and (tmpdir is not None and osp.exists(tmpdir)): 58 | raise OSError((f'The tmpdir {tmpdir} already exists.', 59 | ' Since tmpdir will be deleted after testing,', 60 | ' please make sure you specify an empty one.')) 61 | prog_bar = mmcv.ProgressBar(len(dataset)) 62 | time.sleep(2) # This line can prevent deadlock problem in some cases. 
63 | for i, data in enumerate(data_loader): 64 | with torch.no_grad(): 65 | result = model(return_loss=False, **data) 66 | if isinstance(result, list): 67 | results.extend(result) 68 | else: 69 | results.append(result) 70 | 71 | if rank == 0: 72 | batch_size = data['motion'].size(0) 73 | for _ in range(batch_size * world_size): 74 | prog_bar.update() 75 | 76 | # collect results from all ranks 77 | if gpu_collect: 78 | results = collect_results_gpu(results, len(dataset)) 79 | else: 80 | results = collect_results_cpu(results, len(dataset), tmpdir) 81 | return results 82 | 83 | 84 | def collect_results_cpu(result_part, size, tmpdir=None): 85 | """Collect results in cpu.""" 86 | rank, world_size = get_dist_info() 87 | # create a tmp dir if it is not specified 88 | if tmpdir is None: 89 | MAX_LEN = 512 90 | # 32 is whitespace 91 | dir_tensor = torch.full((MAX_LEN, ), 92 | 32, 93 | dtype=torch.uint8, 94 | device='cuda') 95 | if rank == 0: 96 | mmcv.mkdir_or_exist('.dist_test') 97 | tmpdir = tempfile.mkdtemp(dir='.dist_test') 98 | tmpdir = torch.tensor( 99 | bytearray(tmpdir.encode()), dtype=torch.uint8, device='cuda') 100 | dir_tensor[:len(tmpdir)] = tmpdir 101 | dist.broadcast(dir_tensor, 0) 102 | tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip() 103 | else: 104 | mmcv.mkdir_or_exist(tmpdir) 105 | # dump the part result to the dir 106 | mmcv.dump(result_part, osp.join(tmpdir, f'part_{rank}.pkl')) 107 | dist.barrier() 108 | # collect all parts 109 | if rank != 0: 110 | return None 111 | else: 112 | # load results of all parts from tmp dir 113 | part_list = [] 114 | for i in range(world_size): 115 | part_file = osp.join(tmpdir, f'part_{i}.pkl') 116 | part_result = mmcv.load(part_file) 117 | part_list.append(part_result) 118 | # sort the results 119 | ordered_results = [] 120 | for res in zip(*part_list): 121 | ordered_results.extend(list(res)) 122 | # the dataloader may pad some samples 123 | ordered_results = ordered_results[:size] 124 | # remove tmp dir 125 | shutil.rmtree(tmpdir) 126 | return ordered_results 127 | 128 | 129 | def collect_results_gpu(result_part, size): 130 | """Collect results in gpu.""" 131 | rank, world_size = get_dist_info() 132 | # dump result part to tensor with pickle 133 | part_tensor = torch.tensor( 134 | bytearray(pickle.dumps(result_part)), dtype=torch.uint8, device='cuda') 135 | # gather all result part tensor shape 136 | shape_tensor = torch.tensor(part_tensor.shape, device='cuda') 137 | shape_list = [shape_tensor.clone() for _ in range(world_size)] 138 | dist.all_gather(shape_list, shape_tensor) 139 | # padding result part tensor to max length 140 | shape_max = torch.tensor(shape_list).max() 141 | part_send = torch.zeros(shape_max, dtype=torch.uint8, device='cuda') 142 | part_send[:shape_tensor[0]] = part_tensor 143 | part_recv_list = [ 144 | part_tensor.new_zeros(shape_max) for _ in range(world_size) 145 | ] 146 | # gather all result part 147 | dist.all_gather(part_recv_list, part_send) 148 | 149 | if rank == 0: 150 | part_list = [] 151 | for recv, shape in zip(part_recv_list, shape_list): 152 | part_result = pickle.loads(recv[:shape[0]].cpu().numpy().tobytes()) 153 | part_list.append(part_result) 154 | # sort the results 155 | ordered_results = [] 156 | for res in zip(*part_list): 157 | ordered_results.extend(list(res)) 158 | # the dataloader may pad some samples 159 | ordered_results = ordered_results[:size] 160 | return ordered_results -------------------------------------------------------------------------------- /mogen/apis/train.py: 
-------------------------------------------------------------------------------- 1 | import random 2 | import warnings 3 | 4 | import numpy as np 5 | import torch 6 | from mmcv.parallel import MMDataParallel, MMDistributedDataParallel 7 | from mmcv.runner import ( 8 | DistSamplerSeedHook, 9 | Fp16OptimizerHook, 10 | OptimizerHook, 11 | build_runner, 12 | ) 13 | 14 | from mogen.core.distributed_wrapper import DistributedDataParallelWrapper 15 | from mogen.core.evaluation import DistEvalHook, EvalHook 16 | from mogen.core.optimizer import build_optimizers 17 | from mogen.datasets import build_dataloader, build_dataset 18 | from mogen.utils import get_root_logger 19 | 20 | 21 | def set_random_seed(seed, deterministic=False): 22 | """Set random seed. 23 | Args: 24 | seed (int): Seed to be used. 25 | deterministic (bool): Whether to set the deterministic option for 26 | CUDNN backend, i.e., set `torch.backends.cudnn.deterministic` 27 | to True and `torch.backends.cudnn.benchmark` to False. 28 | Default: False. 29 | """ 30 | random.seed(seed) 31 | np.random.seed(seed) 32 | torch.manual_seed(seed) 33 | torch.cuda.manual_seed_all(seed) 34 | if deterministic: 35 | torch.backends.cudnn.deterministic = True 36 | torch.backends.cudnn.benchmark = False 37 | 38 | 39 | def train_model(model, 40 | dataset, 41 | cfg, 42 | distributed=False, 43 | validate=False, 44 | timestamp=None, 45 | device='cuda', 46 | meta=None): 47 | """Main api for training model.""" 48 | logger = get_root_logger(cfg.log_level) 49 | 50 | # prepare data loaders 51 | dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset] 52 | 53 | data_loaders = [ 54 | build_dataloader( 55 | ds, 56 | cfg.data.samples_per_gpu, 57 | cfg.data.workers_per_gpu, 58 | # cfg.gpus will be ignored if distributed 59 | num_gpus=len(cfg.gpu_ids), 60 | dist=distributed, 61 | round_up=True, 62 | seed=cfg.seed) for ds in dataset 63 | ] 64 | 65 | # determine whether use adversarial training precess or not 66 | use_adverserial_train = cfg.get('use_adversarial_train', False) 67 | 68 | # put model on gpus 69 | if distributed: 70 | find_unused_parameters = cfg.get('find_unused_parameters', True) 71 | # Sets the `find_unused_parameters` parameter in 72 | # torch.nn.parallel.DistributedDataParallel 73 | if use_adverserial_train: 74 | # Use DistributedDataParallelWrapper for adversarial training 75 | model = DistributedDataParallelWrapper( 76 | model, 77 | device_ids=[torch.cuda.current_device()], 78 | broadcast_buffers=False, 79 | find_unused_parameters=find_unused_parameters) 80 | else: 81 | model = MMDistributedDataParallel( 82 | model.cuda(), 83 | device_ids=[torch.cuda.current_device()], 84 | broadcast_buffers=False, 85 | find_unused_parameters=find_unused_parameters) 86 | else: 87 | if device == 'cuda': 88 | model = MMDataParallel( 89 | model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids) 90 | elif device == 'cpu': 91 | model = model.cpu() 92 | else: 93 | raise ValueError(F'unsupported device name {device}.') 94 | 95 | # build runner 96 | optimizer = build_optimizers(model, cfg.optimizer) 97 | 98 | if cfg.get('runner') is None: 99 | cfg.runner = { 100 | 'type': 'EpochBasedRunner', 101 | 'max_epochs': cfg.total_epochs 102 | } 103 | warnings.warn( 104 | 'config is now expected to have a `runner` section, ' 105 | 'please set `runner` in your config.', UserWarning) 106 | 107 | runner = build_runner( 108 | cfg.runner, 109 | default_args=dict( 110 | model=model, 111 | batch_processor=None, 112 | optimizer=optimizer, 113 | work_dir=cfg.work_dir, 114 | 
logger=logger, 115 | meta=meta)) 116 | 117 | # an ugly walkaround to make the .log and .log.json filenames the same 118 | runner.timestamp = timestamp 119 | 120 | if use_adverserial_train: 121 | # The optimizer step process is included in the train_step function 122 | # of the model, so the runner should NOT include optimizer hook. 123 | optimizer_config = None 124 | else: 125 | # fp16 setting 126 | fp16_cfg = cfg.get('fp16', None) 127 | if fp16_cfg is not None: 128 | optimizer_config = Fp16OptimizerHook( 129 | **cfg.optimizer_config, **fp16_cfg, distributed=distributed) 130 | elif distributed and 'type' not in cfg.optimizer_config: 131 | optimizer_config = OptimizerHook(**cfg.optimizer_config) 132 | else: 133 | optimizer_config = cfg.optimizer_config 134 | 135 | # register hooks 136 | runner.register_training_hooks( 137 | cfg.lr_config, 138 | optimizer_config, 139 | cfg.checkpoint_config, 140 | cfg.log_config, 141 | cfg.get('momentum_config', None), 142 | custom_hooks_config=cfg.get('custom_hooks', None)) 143 | if distributed: 144 | runner.register_hook(DistSamplerSeedHook()) 145 | 146 | # register eval hooks 147 | if validate: 148 | val_dataset = build_dataset(cfg.data.val, dict(test_mode=True)) 149 | val_dataloader = build_dataloader( 150 | val_dataset, 151 | samples_per_gpu=cfg.data.samples_per_gpu, 152 | workers_per_gpu=cfg.data.workers_per_gpu, 153 | dist=distributed, 154 | shuffle=False, 155 | round_up=True) 156 | eval_cfg = cfg.get('evaluation', {}) 157 | eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner' 158 | eval_hook = DistEvalHook if distributed else EvalHook 159 | runner.register_hook(eval_hook(val_dataloader, **eval_cfg)) 160 | 161 | if cfg.resume_from: 162 | runner.resume(cfg.resume_from) 163 | elif cfg.load_from: 164 | runner.load_checkpoint(cfg.load_from) 165 | runner.run(data_loaders, cfg.workflow) -------------------------------------------------------------------------------- /mogen/core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mingyuan-zhang/ReMoDiffuse/d81c83ddbf72a989dc334cd48cb7d46fb6feba63/mogen/core/__init__.py -------------------------------------------------------------------------------- /mogen/core/distributed_wrapper.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | import torch.nn as nn 4 | from mmcv.parallel import MODULE_WRAPPERS, MMDistributedDataParallel 5 | from mmcv.parallel.scatter_gather import scatter_kwargs 6 | from torch.cuda._utils import _get_device_index 7 | 8 | 9 | @MODULE_WRAPPERS.register_module() 10 | class DistributedDataParallelWrapper(nn.Module): 11 | """A DistributedDataParallel wrapper for models in 3D mesh estimation task. 12 | 13 | In 3D mesh estimation task, there is a need to wrap different modules in 14 | the models with separate DistributedDataParallel. Otherwise, it will cause 15 | errors for GAN training. 16 | More specific, the GAN model, usually has two sub-modules: 17 | generator and discriminator. If we wrap both of them in one 18 | standard DistributedDataParallel, it will cause errors during training, 19 | because when we update the parameters of the generator (or discriminator), 20 | the parameters of the discriminator (or generator) is not updated, which is 21 | not allowed for DistributedDataParallel. 22 | So we design this wrapper to separately wrap DistributedDataParallel 23 | for generator and discriminator. 
24 | In this wrapper, we perform two operations: 25 | 1. Wrap the modules in the models with separate MMDistributedDataParallel. 26 | Note that only modules with parameters will be wrapped. 27 | 2. Do scatter operation for 'forward', 'train_step' and 'val_step'. 28 | Note that the arguments of this wrapper is the same as those in 29 | `torch.nn.parallel.distributed.DistributedDataParallel`. 30 | Args: 31 | module (nn.Module): Module that needs to be wrapped. 32 | device_ids (list[int | `torch.device`]): Same as that in 33 | `torch.nn.parallel.distributed.DistributedDataParallel`. 34 | dim (int, optional): Same as that in the official scatter function in 35 | pytorch. Defaults to 0. 36 | broadcast_buffers (bool): Same as that in 37 | `torch.nn.parallel.distributed.DistributedDataParallel`. 38 | Defaults to False. 39 | find_unused_parameters (bool, optional): Same as that in 40 | `torch.nn.parallel.distributed.DistributedDataParallel`. 41 | Traverse the autograd graph of all tensors contained in returned 42 | value of the wrapped module’s forward function. Defaults to False. 43 | kwargs (dict): Other arguments used in 44 | `torch.nn.parallel.distributed.DistributedDataParallel`. 45 | """ 46 | 47 | def __init__(self, 48 | module, 49 | device_ids, 50 | dim=0, 51 | broadcast_buffers=False, 52 | find_unused_parameters=False, 53 | **kwargs): 54 | super().__init__() 55 | assert len(device_ids) == 1, ( 56 | 'Currently, DistributedDataParallelWrapper only supports one' 57 | 'single CUDA device for each process.' 58 | f'The length of device_ids must be 1, but got {len(device_ids)}.') 59 | self.module = module 60 | self.dim = dim 61 | self.to_ddp( 62 | device_ids=device_ids, 63 | dim=dim, 64 | broadcast_buffers=broadcast_buffers, 65 | find_unused_parameters=find_unused_parameters, 66 | **kwargs) 67 | self.output_device = _get_device_index(device_ids[0], True) 68 | 69 | def to_ddp(self, device_ids, dim, broadcast_buffers, 70 | find_unused_parameters, **kwargs): 71 | """Wrap models with separate MMDistributedDataParallel. 72 | 73 | It only wraps the modules with parameters. 74 | """ 75 | for name, module in self.module._modules.items(): 76 | if next(module.parameters(), None) is None: 77 | module = module.cuda() 78 | elif all(not p.requires_grad for p in module.parameters()): 79 | module = module.cuda() 80 | else: 81 | module = MMDistributedDataParallel( 82 | module.cuda(), 83 | device_ids=device_ids, 84 | dim=dim, 85 | broadcast_buffers=broadcast_buffers, 86 | find_unused_parameters=find_unused_parameters, 87 | **kwargs) 88 | self.module._modules[name] = module 89 | 90 | def scatter(self, inputs, kwargs, device_ids): 91 | """Scatter function. 92 | 93 | Args: 94 | inputs (Tensor): Input Tensor. 95 | kwargs (dict): Args for 96 | ``mmcv.parallel.scatter_gather.scatter_kwargs``. 97 | device_ids (int): Device id. 98 | """ 99 | return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim) 100 | 101 | def forward(self, *inputs, **kwargs): 102 | """Forward function. 103 | 104 | Args: 105 | inputs (tuple): Input data. 106 | kwargs (dict): Args for 107 | ``mmcv.parallel.scatter_gather.scatter_kwargs``. 108 | """ 109 | inputs, kwargs = self.scatter(inputs, kwargs, 110 | [torch.cuda.current_device()]) 111 | return self.module(*inputs[0], **kwargs[0]) 112 | 113 | def train_step(self, *inputs, **kwargs): 114 | """Train step function. 115 | 116 | Args: 117 | inputs (Tensor): Input Tensor. 118 | kwargs (dict): Args for 119 | ``mmcv.parallel.scatter_gather.scatter_kwargs``. 
120 | """ 121 | inputs, kwargs = self.scatter(inputs, kwargs, 122 | [torch.cuda.current_device()]) 123 | output = self.module.train_step(*inputs[0], **kwargs[0]) 124 | return output 125 | 126 | def val_step(self, *inputs, **kwargs): 127 | """Validation step function. 128 | 129 | Args: 130 | inputs (tuple): Input data. 131 | kwargs (dict): Args for ``scatter_kwargs``. 132 | """ 133 | inputs, kwargs = self.scatter(inputs, kwargs, 134 | [torch.cuda.current_device()]) 135 | output = self.module.val_step(*inputs[0], **kwargs[0]) 136 | return output -------------------------------------------------------------------------------- /mogen/core/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | from mogen.core.evaluation.eval_hooks import DistEvalHook, EvalHook 2 | from mogen.core.evaluation.builder import build_evaluator 3 | 4 | __all__ = ["DistEvalHook", "EvalHook", "build_evaluator"] -------------------------------------------------------------------------------- /mogen/core/evaluation/builder.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import numpy as np 3 | from mmcv.utils import Registry 4 | from .evaluators.precision_evaluator import PrecisionEvaluator 5 | from .evaluators.matching_score_evaluator import MatchingScoreEvaluator 6 | from .evaluators.fid_evaluator import FIDEvaluator 7 | from .evaluators.diversity_evaluator import DiversityEvaluator 8 | from .evaluators.multimodality_evaluator import MultiModalityEvaluator 9 | 10 | EVALUATORS = Registry('evaluators') 11 | 12 | EVALUATORS.register_module(name='R Precision', module=PrecisionEvaluator) 13 | EVALUATORS.register_module(name='Matching Score', module=MatchingScoreEvaluator) 14 | EVALUATORS.register_module(name='FID', module=FIDEvaluator) 15 | EVALUATORS.register_module(name='Diversity', module=DiversityEvaluator) 16 | EVALUATORS.register_module(name='MultiModality', module=MultiModalityEvaluator) 17 | 18 | 19 | def build_evaluator(metric, eval_cfg, data_len, eval_indexes): 20 | cfg = copy.deepcopy(eval_cfg) 21 | cfg.update(metric) 22 | cfg.pop('metrics') 23 | cfg['data_len'] = data_len 24 | cfg['eval_indexes'] = eval_indexes 25 | evaluator = EVALUATORS.build(cfg) 26 | if evaluator.append_indexes is not None: 27 | for i in range(eval_cfg['replication_times']): 28 | eval_indexes[i] = np.concatenate((eval_indexes[i], evaluator.append_indexes[i]), axis=0) 29 | return evaluator, eval_indexes 30 | -------------------------------------------------------------------------------- /mogen/core/evaluation/eval_hooks.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import tempfile 3 | import warnings 4 | 5 | from mmcv.runner import DistEvalHook as BaseDistEvalHook 6 | from mmcv.runner import EvalHook as BaseEvalHook 7 | 8 | mogen_GREATER_KEYS = [] 9 | mogen_LESS_KEYS = [] 10 | 11 | 12 | class EvalHook(BaseEvalHook): 13 | 14 | def __init__(self, 15 | dataloader, 16 | start=None, 17 | interval=1, 18 | by_epoch=True, 19 | save_best=None, 20 | rule=None, 21 | test_fn=None, 22 | greater_keys=mogen_GREATER_KEYS, 23 | less_keys=mogen_LESS_KEYS, 24 | **eval_kwargs): 25 | if test_fn is None: 26 | from mogen.apis import single_gpu_test 27 | test_fn = single_gpu_test 28 | 29 | # remove "gpu_collect" from eval_kwargs 30 | if 'gpu_collect' in eval_kwargs: 31 | warnings.warn( 32 | '"gpu_collect" will be deprecated in EvalHook.' 
33 | 'Please remove it from the config.', DeprecationWarning) 34 | _ = eval_kwargs.pop('gpu_collect') 35 | 36 | # update "save_best" according to "key_indicator" and remove the 37 | # latter from eval_kwargs 38 | if 'key_indicator' in eval_kwargs or isinstance(save_best, bool): 39 | warnings.warn( 40 | '"key_indicator" will be deprecated in EvalHook.' 41 | 'Please use "save_best" to specify the metric key,' 42 | 'e.g., save_best="pa-mpjpe".', DeprecationWarning) 43 | 44 | key_indicator = eval_kwargs.pop('key_indicator', None) 45 | if save_best is True and key_indicator is None: 46 | raise ValueError('key_indicator should not be None, when ' 47 | 'save_best is set to True.') 48 | save_best = key_indicator 49 | 50 | super().__init__(dataloader, start, interval, by_epoch, save_best, 51 | rule, test_fn, greater_keys, less_keys, **eval_kwargs) 52 | 53 | def evaluate(self, runner, results): 54 | 55 | with tempfile.TemporaryDirectory() as tmp_dir: 56 | eval_res = self.dataloader.dataset.evaluate( 57 | results, 58 | work_dir=tmp_dir, 59 | logger=runner.logger, 60 | **self.eval_kwargs) 61 | 62 | for name, val in eval_res.items(): 63 | runner.log_buffer.output[name] = val 64 | runner.log_buffer.ready = True 65 | 66 | if self.save_best is not None: 67 | if self.key_indicator == 'auto': 68 | self._init_rule(self.rule, list(eval_res.keys())[0]) 69 | 70 | return eval_res[self.key_indicator] 71 | 72 | return None 73 | 74 | 75 | class DistEvalHook(BaseDistEvalHook): 76 | 77 | def __init__(self, 78 | dataloader, 79 | start=None, 80 | interval=1, 81 | by_epoch=True, 82 | save_best=None, 83 | rule=None, 84 | test_fn=None, 85 | greater_keys=mogen_GREATER_KEYS, 86 | less_keys=mogen_LESS_KEYS, 87 | broadcast_bn_buffer=True, 88 | tmpdir=None, 89 | gpu_collect=False, 90 | **eval_kwargs): 91 | 92 | if test_fn is None: 93 | from mogen.apis import multi_gpu_test 94 | test_fn = multi_gpu_test 95 | 96 | # update "save_best" according to "key_indicator" and remove the 97 | # latter from eval_kwargs 98 | if 'key_indicator' in eval_kwargs or isinstance(save_best, bool): 99 | warnings.warn( 100 | '"key_indicator" will be deprecated in EvalHook.' 101 | 'Please use "save_best" to specify the metric key,' 102 | 'e.g., save_best="pa-mpjpe".', DeprecationWarning) 103 | 104 | key_indicator = eval_kwargs.pop('key_indicator', None) 105 | if save_best is True and key_indicator is None: 106 | raise ValueError('key_indicator should not be None, when ' 107 | 'save_best is set to True.') 108 | save_best = key_indicator 109 | 110 | super().__init__(dataloader, start, interval, by_epoch, save_best, 111 | rule, test_fn, greater_keys, less_keys, 112 | broadcast_bn_buffer, tmpdir, gpu_collect, 113 | **eval_kwargs) 114 | 115 | def evaluate(self, runner, results): 116 | """Evaluate the results. 117 | Args: 118 | runner (:obj:`mmcv.Runner`): The underlined training runner. 119 | results (list): Output results. 
120 | """ 121 | with tempfile.TemporaryDirectory() as tmp_dir: 122 | eval_res = self.dataloader.dataset.evaluate( 123 | results, 124 | work_dir=tmp_dir, 125 | logger=runner.logger, 126 | **self.eval_kwargs) 127 | 128 | for name, val in eval_res.items(): 129 | runner.log_buffer.output[name] = val 130 | runner.log_buffer.ready = True 131 | 132 | if self.save_best is not None: 133 | if self.key_indicator == 'auto': 134 | # infer from eval_results 135 | self._init_rule(self.rule, list(eval_res.keys())[0]) 136 | return eval_res[self.key_indicator] 137 | 138 | return None -------------------------------------------------------------------------------- /mogen/core/evaluation/evaluators/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mingyuan-zhang/ReMoDiffuse/d81c83ddbf72a989dc334cd48cb7d46fb6feba63/mogen/core/evaluation/evaluators/__init__.py -------------------------------------------------------------------------------- /mogen/core/evaluation/evaluators/base_evaluator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from ..utils import get_metric_statistics 4 | 5 | 6 | class BaseEvaluator(object): 7 | 8 | def __init__(self, 9 | batch_size=None, 10 | drop_last=False, 11 | replication_times=1, 12 | replication_reduction='statistics', 13 | eval_begin_idx=None, 14 | eval_end_idx=None): 15 | self.batch_size = batch_size 16 | self.drop_last = drop_last 17 | self.replication_times = replication_times 18 | self.replication_reduction = replication_reduction 19 | assert replication_reduction in ['statistics', 'mean', 'concat'] 20 | self.eval_begin_idx = eval_begin_idx 21 | self.eval_end_idx = eval_end_idx 22 | 23 | def evaluate(self, results): 24 | total_len = len(results) 25 | partial_len = total_len // self.replication_times 26 | all_metrics = [] 27 | for replication_idx in range(self.replication_times): 28 | partial_results = results[ 29 | replication_idx * partial_len: (replication_idx + 1) * partial_len] 30 | if self.batch_size is not None: 31 | batch_metrics = [] 32 | for batch_start in range(self.eval_begin_idx, self.eval_end_idx, self.batch_size): 33 | batch_results = partial_results[batch_start: batch_start + self.batch_size] 34 | if len(batch_results) < self.batch_size and self.drop_last: 35 | continue 36 | batch_metrics.append(self.single_evaluate(batch_results)) 37 | all_metrics.append(self.concat_batch_metrics(batch_metrics)) 38 | else: 39 | batch_results = partial_results[self.eval_begin_idx: self.eval_end_idx] 40 | all_metrics.append(self.single_evaluate(batch_results)) 41 | all_metrics = np.stack(all_metrics, axis=0) 42 | if self.replication_reduction == 'statistics': 43 | values = get_metric_statistics(all_metrics, self.replication_times) 44 | elif self.replication_reduction == 'mean': 45 | values = np.mean(all_metrics, axis=0) 46 | elif self.replication_reduction == 'concat': 47 | values = all_metrics 48 | return self.parse_values(values) 49 | 50 | def prepare_results(self, results): 51 | text = [] 52 | pred_motion = [] 53 | pred_motion_length = [] 54 | pred_motion_mask = [] 55 | motion = [] 56 | motion_length = [] 57 | motion_mask = [] 58 | token = [] 59 | # count the maximum motion length 60 | T = max([result['motion'].shape[0] for result in results]) 61 | for result in results: 62 | cur_motion = result['motion'] 63 | if cur_motion.shape[0] < T: 64 | padding_values = torch.zeros((T - cur_motion.shape[0], cur_motion.shape[1])) 
65 | padding_values = padding_values.type_as(cur_motion) 66 | cur_motion = torch.cat([cur_motion, padding_values], dim=0) 67 | motion.append(cur_motion) 68 | cur_pred_motion = result['pred_motion'] 69 | if cur_pred_motion.shape[0] < T: 70 | padding_values = torch.zeros((T - cur_pred_motion.shape[0], cur_pred_motion.shape[1])) 71 | padding_values = padding_values.type_as(cur_pred_motion) 72 | cur_pred_motion = torch.cat([cur_pred_motion, padding_values], dim=0) 73 | pred_motion.append(cur_pred_motion) 74 | cur_motion_mask = result['motion_mask'] 75 | if cur_motion_mask.shape[0] < T: 76 | padding_values = torch.zeros((T - cur_motion_mask.shape[0])) 77 | padding_values = padding_values.type_as(cur_motion_mask) 78 | cur_motion_mask = torch.cat([cur_motion_mask, padding_values], dim=0) 79 | motion_mask.append(cur_motion_mask) 80 | cur_pred_motion_mask = result['pred_motion_mask'] 81 | if cur_pred_motion_mask.shape[0] < T: 82 | padding_values = torch.zeros((T - cur_pred_motion_mask.shape[0])) 83 | padding_values = padding_values.type_as(cur_pred_motion_mask) 84 | cur_pred_motion_mask = torch.cat([cur_pred_motion_mask, padding_values], dim=0) 85 | pred_motion_mask.append(cur_pred_motion_mask) 86 | motion_length.append(result['motion_length'].item()) 87 | pred_motion_length.append(result['pred_motion_length'].item()) 88 | if 'text' in result.keys(): 89 | text.append(result['text']) 90 | if 'token' in result.keys(): 91 | token.append(result['token']) 92 | 93 | motion = torch.stack(motion, dim=0) 94 | pred_motion = torch.stack(pred_motion, dim=0) 95 | motion_mask = torch.stack(motion_mask, dim=0) 96 | pred_motion_mask = torch.stack(pred_motion_mask, dim=0) 97 | motion_length = torch.Tensor(motion_length).to(motion.device).long() 98 | pred_motion_length = torch.Tensor(pred_motion_length).to(motion.device).long() 99 | output = { 100 | 'pred_motion': pred_motion, 101 | 'pred_motion_mask': pred_motion_mask, 102 | 'pred_motion_length': pred_motion_length, 103 | 'motion': motion, 104 | 'motion_mask': motion_mask, 105 | 'motion_length': motion_length, 106 | 'text': text, 107 | 'token': token 108 | } 109 | return output 110 | 111 | def to_device(self, device): 112 | for model in self.model_list: 113 | model.to(device) 114 | 115 | def motion_encode(self, motion, motion_length, motion_mask, device): 116 | N = motion.shape[0] 117 | motion_emb = [] 118 | batch_size = 32 119 | cur_idx = 0 120 | with torch.no_grad(): 121 | while cur_idx < N: 122 | cur_motion = motion[cur_idx: cur_idx + batch_size].to(device) 123 | cur_motion_length = motion_length[cur_idx: cur_idx + batch_size].to(device) 124 | cur_motion_mask = motion_mask[cur_idx: cur_idx + batch_size].to(device) 125 | cur_motion_emb = self.motion_encoder(cur_motion, cur_motion_length, cur_motion_mask) 126 | motion_emb.append(cur_motion_emb) 127 | cur_idx += batch_size 128 | motion_emb = torch.cat(motion_emb, dim=0) 129 | return motion_emb 130 | 131 | def text_encode(self, text, token, device): 132 | N = len(text) 133 | text_emb = [] 134 | batch_size = 32 135 | cur_idx = 0 136 | with torch.no_grad(): 137 | while cur_idx < N: 138 | cur_text = text[cur_idx: cur_idx + batch_size] 139 | cur_token = token[cur_idx: cur_idx + batch_size] 140 | cur_text_emb = self.text_encoder(cur_text, cur_token, device) 141 | text_emb.append(cur_text_emb) 142 | cur_idx += batch_size 143 | text_emb = torch.cat(text_emb, dim=0) 144 | return text_emb -------------------------------------------------------------------------------- /mogen/core/evaluation/evaluators/diversity_evaluator.py:
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from ..get_model import get_motion_model 5 | from .base_evaluator import BaseEvaluator 6 | from ..utils import calculate_diversity 7 | 8 | 9 | class DiversityEvaluator(BaseEvaluator): 10 | 11 | def __init__(self, 12 | data_len=0, 13 | motion_encoder_name=None, 14 | motion_encoder_path=None, 15 | num_samples=300, 16 | batch_size=None, 17 | drop_last=False, 18 | replication_times=1, 19 | replication_reduction='statistics', 20 | **kwargs): 21 | super().__init__( 22 | replication_times=replication_times, 23 | replication_reduction=replication_reduction, 24 | batch_size=batch_size, 25 | drop_last=drop_last, 26 | eval_begin_idx=0, 27 | eval_end_idx=data_len 28 | ) 29 | self.num_samples = num_samples 30 | self.append_indexes = None 31 | self.motion_encoder = get_motion_model(motion_encoder_name, motion_encoder_path) 32 | self.model_list = [self.motion_encoder] 33 | 34 | def single_evaluate(self, results): 35 | results = self.prepare_results(results) 36 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 37 | motion = results['motion'] 38 | pred_motion = results['pred_motion'] 39 | pred_motion_length = results['pred_motion_length'] 40 | pred_motion_mask = results['pred_motion_mask'] 41 | self.motion_encoder.to(device) 42 | self.motion_encoder.eval() 43 | with torch.no_grad(): 44 | pred_motion_emb = self.motion_encode(pred_motion, pred_motion_length, pred_motion_mask, device).cpu().detach().numpy() 45 | diversity = calculate_diversity(pred_motion_emb, self.num_samples) 46 | return diversity 47 | 48 | def parse_values(self, values): 49 | metrics = {} 50 | metrics['Diversity (mean)'] = values[0] 51 | metrics['Diversity (conf)'] = values[1] 52 | return metrics 53 | -------------------------------------------------------------------------------- /mogen/core/evaluation/evaluators/fid_evaluator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from ..get_model import get_motion_model 5 | from .base_evaluator import BaseEvaluator 6 | from ..utils import ( 7 | calculate_activation_statistics, 8 | calculate_frechet_distance) 9 | 10 | 11 | class FIDEvaluator(BaseEvaluator): 12 | 13 | def __init__(self, 14 | data_len=0, 15 | motion_encoder_name=None, 16 | motion_encoder_path=None, 17 | batch_size=None, 18 | drop_last=False, 19 | replication_times=1, 20 | replication_reduction='statistics', 21 | **kwargs): 22 | super().__init__( 23 | replication_times=replication_times, 24 | replication_reduction=replication_reduction, 25 | batch_size=batch_size, 26 | drop_last=drop_last, 27 | eval_begin_idx=0, 28 | eval_end_idx=data_len 29 | ) 30 | self.append_indexes = None 31 | self.motion_encoder = get_motion_model(motion_encoder_name, motion_encoder_path) 32 | self.model_list = [self.motion_encoder] 33 | 34 | def single_evaluate(self, results): 35 | results = self.prepare_results(results) 36 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 37 | pred_motion = results['pred_motion'] 38 | 39 | pred_motion_length = results['pred_motion_length'] 40 | pred_motion_mask = results['pred_motion_mask'] 41 | motion = results['motion'] 42 | motion_length = results['motion_length'] 43 | motion_mask = results['motion_mask'] 44 | self.motion_encoder.to(device) 45 | self.motion_encoder.eval() 46 | with torch.no_grad(): 47 | pred_motion_emb = self.motion_encode(pred_motion, 
pred_motion_length, pred_motion_mask, device).cpu().detach().numpy() 48 | gt_motion_emb = self.motion_encode(motion, motion_length, motion_mask, device).cpu().detach().numpy() 49 | gt_mu, gt_cov = calculate_activation_statistics(gt_motion_emb) 50 | pred_mu, pred_cov = calculate_activation_statistics(pred_motion_emb) 51 | fid = calculate_frechet_distance(gt_mu, gt_cov, pred_mu, pred_cov) 52 | return fid 53 | 54 | def parse_values(self, values): 55 | metrics = {} 56 | metrics['FID (mean)'] = values[0] 57 | metrics['FID (conf)'] = values[1] 58 | return metrics 59 | -------------------------------------------------------------------------------- /mogen/core/evaluation/evaluators/matching_score_evaluator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from ..get_model import get_motion_model, get_text_model 5 | from .base_evaluator import BaseEvaluator 6 | from ..utils import calculate_top_k, euclidean_distance_matrix 7 | 8 | 9 | class MatchingScoreEvaluator(BaseEvaluator): 10 | 11 | def __init__(self, 12 | data_len=0, 13 | text_encoder_name=None, 14 | text_encoder_path=None, 15 | motion_encoder_name=None, 16 | motion_encoder_path=None, 17 | top_k=3, 18 | batch_size=32, 19 | drop_last=False, 20 | replication_times=1, 21 | replication_reduction='statistics', 22 | **kwargs): 23 | super().__init__( 24 | replication_times=replication_times, 25 | replication_reduction=replication_reduction, 26 | batch_size=batch_size, 27 | drop_last=drop_last, 28 | eval_begin_idx=0, 29 | eval_end_idx=data_len 30 | ) 31 | self.append_indexes = None 32 | self.text_encoder = get_text_model(text_encoder_name, text_encoder_path) 33 | self.motion_encoder = get_motion_model(motion_encoder_name, motion_encoder_path) 34 | self.top_k = top_k 35 | self.model_list = [self.text_encoder, self.motion_encoder] 36 | 37 | def single_evaluate(self, results): 38 | results = self.prepare_results(results) 39 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 40 | motion = results['motion'] 41 | pred_motion = results['pred_motion'] 42 | pred_motion_length = results['pred_motion_length'] 43 | pred_motion_mask = results['pred_motion_mask'] 44 | text = results['text'] 45 | token = results['token'] 46 | self.text_encoder.to(device) 47 | self.motion_encoder.to(device) 48 | self.text_encoder.eval() 49 | self.motion_encoder.eval() 50 | with torch.no_grad(): 51 | word_emb = self.text_encode(text, token, device=device).cpu().detach().numpy() 52 | motion_emb = self.motion_encode(pred_motion, pred_motion_length, pred_motion_mask, device).cpu().detach().numpy() 53 | dist_mat = euclidean_distance_matrix(word_emb, motion_emb) 54 | matching_score = dist_mat.trace() 55 | all_size = word_emb.shape[0] 56 | return matching_score, all_size 57 | 58 | def concat_batch_metrics(self, batch_metrics): 59 | matching_score_sum = 0 60 | all_size = 0 61 | for batch_matching_score, batch_all_size in batch_metrics: 62 | matching_score_sum += batch_matching_score 63 | all_size += batch_all_size 64 | matching_score = matching_score_sum / all_size 65 | return matching_score 66 | 67 | def parse_values(self, values): 68 | metrics = {} 69 | metrics['Matching Score (mean)'] = values[0] 70 | metrics['Matching Score (conf)'] = values[1] 71 | return metrics 72 | -------------------------------------------------------------------------------- /mogen/core/evaluation/evaluators/multimodality_evaluator.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from ..get_model import get_motion_model 5 | from .base_evaluator import BaseEvaluator 6 | from ..utils import calculate_multimodality 7 | 8 | 9 | class MultiModalityEvaluator(BaseEvaluator): 10 | 11 | def __init__(self, 12 | data_len=0, 13 | motion_encoder_name=None, 14 | motion_encoder_path=None, 15 | num_samples=100, 16 | num_repeats=30, 17 | num_picks=10, 18 | batch_size=None, 19 | drop_last=False, 20 | replication_times=1, 21 | replication_reduction='statistics', 22 | **kwargs): 23 | super().__init__( 24 | replication_times=replication_times, 25 | replication_reduction=replication_reduction, 26 | batch_size=batch_size, 27 | drop_last=drop_last, 28 | eval_begin_idx=data_len, 29 | eval_end_idx=data_len + num_samples * num_repeats 30 | ) 31 | self.num_samples = num_samples 32 | self.num_repeats = num_repeats 33 | self.num_picks = num_picks 34 | self.append_indexes = [] 35 | for i in range(replication_times): 36 | append_indexes = [] 37 | selected_indexs = np.random.choice(data_len, self.num_samples) 38 | for index in selected_indexs: 39 | append_indexes = append_indexes + [index] * self.num_repeats 40 | self.append_indexes.append(np.array(append_indexes)) 41 | self.motion_encoder = get_motion_model(motion_encoder_name, motion_encoder_path) 42 | self.model_list = [self.motion_encoder] 43 | 44 | def single_evaluate(self, results): 45 | results = self.prepare_results(results) 46 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 47 | motion = results['motion'] 48 | pred_motion = results['pred_motion'] 49 | pred_motion_length = results['pred_motion_length'] 50 | pred_motion_mask = results['pred_motion_mask'] 51 | self.motion_encoder.to(device) 52 | self.motion_encoder.eval() 53 | with torch.no_grad(): 54 | pred_motion_emb = self.motion_encode(pred_motion, pred_motion_length, pred_motion_mask, device).cpu().detach().numpy() 55 | pred_motion_emb = pred_motion_emb.reshape((self.num_samples, self.num_repeats, -1)) 56 | multimodality = calculate_multimodality(pred_motion_emb, self.num_picks) 57 | return multimodality 58 | 59 | def parse_values(self, values): 60 | metrics = {} 61 | metrics['MultiModality (mean)'] = values[0] 62 | metrics['MultiModality (conf)'] = values[1] 63 | return metrics 64 | -------------------------------------------------------------------------------- /mogen/core/evaluation/evaluators/precision_evaluator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from ..get_model import get_motion_model, get_text_model 5 | from .base_evaluator import BaseEvaluator 6 | from ..utils import calculate_top_k, euclidean_distance_matrix 7 | 8 | 9 | class PrecisionEvaluator(BaseEvaluator): 10 | 11 | def __init__(self, 12 | data_len=0, 13 | text_encoder_name=None, 14 | text_encoder_path=None, 15 | motion_encoder_name=None, 16 | motion_encoder_path=None, 17 | top_k=3, 18 | batch_size=32, 19 | drop_last=False, 20 | replication_times=1, 21 | replication_reduction='statistics', 22 | **kwargs): 23 | super().__init__( 24 | replication_times=replication_times, 25 | replication_reduction=replication_reduction, 26 | batch_size=batch_size, 27 | drop_last=drop_last, 28 | eval_begin_idx=0, 29 | eval_end_idx=data_len 30 | ) 31 | self.append_indexes = None 32 | self.text_encoder = get_text_model(text_encoder_name, text_encoder_path) 33 | self.motion_encoder = 
get_motion_model(motion_encoder_name, motion_encoder_path) 34 | self.top_k = top_k 35 | self.model_list = [self.text_encoder, self.motion_encoder] 36 | 37 | def single_evaluate(self, results): 38 | results = self.prepare_results(results) 39 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 40 | motion = results['motion'] 41 | pred_motion = results['pred_motion'] 42 | pred_motion_length = results['pred_motion_length'] 43 | pred_motion_mask = results['pred_motion_mask'] 44 | text = results['text'] 45 | token = results['token'] 46 | self.text_encoder.to(device) 47 | self.motion_encoder.to(device) 48 | self.text_encoder.eval() 49 | self.motion_encoder.eval() 50 | with torch.no_grad(): 51 | word_emb = self.text_encode(text, token, device=device).cpu().detach().numpy() 52 | motion_emb = self.motion_encode(pred_motion, pred_motion_length, pred_motion_mask, device).cpu().detach().numpy() 53 | dist_mat = euclidean_distance_matrix(word_emb, motion_emb) 54 | argsmax = np.argsort(dist_mat, axis=1) 55 | top_k_mat = calculate_top_k(argsmax, top_k=self.top_k) 56 | top_k_count = top_k_mat.sum(axis=0) 57 | all_size = word_emb.shape[0] 58 | return top_k_count, all_size 59 | 60 | def concat_batch_metrics(self, batch_metrics): 61 | top_k_count = 0 62 | all_size = 0 63 | for batch_top_k_count, batch_all_size in batch_metrics: 64 | top_k_count += batch_top_k_count 65 | all_size += batch_all_size 66 | R_precision = top_k_count / all_size 67 | return R_precision 68 | 69 | def parse_values(self, values): 70 | metrics = {} 71 | for top_k in range(self.top_k): 72 | metrics['R_precision Top %d (mean)' % (top_k + 1)] = values[0][top_k] 73 | metrics['R_precision Top %d (conf)' % (top_k + 1)] = values[1][top_k] 74 | return metrics 75 | -------------------------------------------------------------------------------- /mogen/core/evaluation/get_model.py: -------------------------------------------------------------------------------- 1 | from mogen.models import build_submodule 2 | 3 | 4 | def get_motion_model(name, ckpt_path): 5 | if name == 'kit_ml': 6 | model = build_submodule(dict( 7 | type='T2MMotionEncoder', 8 | input_size=251, 9 | movement_hidden_size=512, 10 | movement_latent_size=512, 11 | motion_hidden_size=1024, 12 | motion_latent_size=512, 13 | )) 14 | else: 15 | model = build_submodule(dict( 16 | type='T2MMotionEncoder', 17 | input_size=263, 18 | movement_hidden_size=512, 19 | movement_latent_size=512, 20 | motion_hidden_size=1024, 21 | motion_latent_size=512, 22 | )) 23 | model.load_pretrained(ckpt_path) 24 | return model 25 | 26 | def get_text_model(name, ckpt_path): 27 | if name == 'kit_ml': 28 | model = build_submodule(dict( 29 | type='T2MTextEncoder', 30 | word_size=300, 31 | pos_size=15, 32 | hidden_size=512, 33 | output_size=512, 34 | max_text_len=20 35 | )) 36 | else: 37 | model = build_submodule(dict( 38 | type='T2MTextEncoder', 39 | word_size=300, 40 | pos_size=15, 41 | hidden_size=512, 42 | output_size=512, 43 | max_text_len=20 44 | )) 45 | model.load_pretrained(ckpt_path) 46 | return model 47 | -------------------------------------------------------------------------------- /mogen/core/evaluation/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy import linalg 3 | 4 | 5 | def get_metric_statistics(values, replication_times): 6 | mean = np.mean(values, axis=0) 7 | std = np.std(values, axis=0) 8 | conf_interval = 1.96 * std / np.sqrt(replication_times) 9 | return mean, conf_interval 10 | 11 | 12 | # (X 
- X_train)*(X - X_train) = -2X*X_train + X*X + X_train*X_train 13 | def euclidean_distance_matrix(matrix1, matrix2): 14 | """ 15 | Params: 16 | -- matrix1: N1 x D 17 | -- matrix2: N2 x D 18 | Returns: 19 | -- dist: N1 x N2 20 | dist[i, j] == distance(matrix1[i], matrix2[j]) 21 | """ 22 | assert matrix1.shape[1] == matrix2.shape[1] 23 | d1 = -2 * np.dot(matrix1, matrix2.T) # shape (num_test, num_train) 24 | d2 = np.sum(np.square(matrix1), axis=1, keepdims=True) # shape (num_test, 1) 25 | d3 = np.sum(np.square(matrix2), axis=1) # shape (num_train, ) 26 | dists = np.sqrt(d1 + d2 + d3) # broadcasting 27 | return dists 28 | 29 | 30 | def calculate_top_k(mat, top_k): 31 | size = mat.shape[0] 32 | gt_mat = np.expand_dims(np.arange(size), 1).repeat(size, 1) 33 | bool_mat = (mat == gt_mat) 34 | correct_vec = False 35 | top_k_list = [] 36 | for i in range(top_k): 37 | # print(correct_vec, bool_mat[:, i]) 38 | correct_vec = (correct_vec | bool_mat[:, i]) 39 | # print(correct_vec) 40 | top_k_list.append(correct_vec[:, None]) 41 | top_k_mat = np.concatenate(top_k_list, axis=1) 42 | return top_k_mat 43 | 44 | 45 | def calculate_activation_statistics(activations): 46 | """ 47 | Params: 48 | -- activation: num_samples x dim_feat 49 | Returns: 50 | -- mu: dim_feat 51 | -- sigma: dim_feat x dim_feat 52 | """ 53 | mu = np.mean(activations, axis=0) 54 | cov = np.cov(activations, rowvar=False) 55 | return mu, cov 56 | 57 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): 58 | """Numpy implementation of the Frechet Distance. 59 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) 60 | and X_2 ~ N(mu_2, C_2) is 61 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). 62 | Stable version by Dougal J. Sutherland. 63 | Params: 64 | -- mu1 : Numpy array containing the activations of a layer of the 65 | inception net (like returned by the function 'get_predictions') 66 | for generated samples. 67 | -- mu2 : The sample mean over activations, precalculated on an 68 | representative data set. 69 | -- sigma1: The covariance matrix over activations for generated samples. 70 | -- sigma2: The covariance matrix over activations, precalculated on an 71 | representative data set. 72 | Returns: 73 | -- : The Frechet Distance. 
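Worked sanity check (illustrative): when the two Gaussians coincide, i.e. mu1 == mu2 and sigma1 == sigma2 == C, the expression reduces to ||0||^2 + Tr(C + C - 2*sqrt(C*C)) = Tr(2C - 2C) = 0, so calculate_frechet_distance(mu, cov, mu, cov) should be numerically close to zero.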
74 | """ 75 | 76 | mu1 = np.atleast_1d(mu1) 77 | mu2 = np.atleast_1d(mu2) 78 | 79 | sigma1 = np.atleast_2d(sigma1) 80 | sigma2 = np.atleast_2d(sigma2) 81 | 82 | assert mu1.shape == mu2.shape, \ 83 | 'Training and test mean vectors have different lengths' 84 | assert sigma1.shape == sigma2.shape, \ 85 | 'Training and test covariances have different dimensions' 86 | 87 | diff = mu1 - mu2 88 | 89 | # Product might be almost singular 90 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) 91 | if not np.isfinite(covmean).all(): 92 | msg = ('fid calculation produces singular product; ' 93 | 'adding %s to diagonal of cov estimates') % eps 94 | print(msg) 95 | offset = np.eye(sigma1.shape[0]) * eps 96 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 97 | 98 | # Numerical error might give slight imaginary component 99 | if np.iscomplexobj(covmean): 100 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 101 | m = np.max(np.abs(covmean.imag)) 102 | raise ValueError('Imaginary component {}'.format(m)) 103 | covmean = covmean.real 104 | 105 | tr_covmean = np.trace(covmean) 106 | 107 | return (diff.dot(diff) + np.trace(sigma1) + 108 | np.trace(sigma2) - 2 * tr_covmean) 109 | 110 | 111 | def calculate_diversity(activation, diversity_times): 112 | assert len(activation.shape) == 2 113 | assert activation.shape[0] > diversity_times 114 | num_samples = activation.shape[0] 115 | 116 | first_indices = np.random.choice(num_samples, diversity_times, replace=False) 117 | second_indices = np.random.choice(num_samples, diversity_times, replace=False) 118 | dist = linalg.norm(activation[first_indices] - activation[second_indices], axis=1) 119 | return dist.mean() 120 | 121 | 122 | def calculate_multimodality(activation, multimodality_times): 123 | assert len(activation.shape) == 3 124 | assert activation.shape[1] > multimodality_times 125 | num_per_sent = activation.shape[1] 126 | 127 | first_dices = np.random.choice(num_per_sent, multimodality_times, replace=False) 128 | second_dices = np.random.choice(num_per_sent, multimodality_times, replace=False) 129 | dist = linalg.norm(activation[:, first_dices] - activation[:, second_dices], axis=2) 130 | return dist.mean() 131 | -------------------------------------------------------------------------------- /mogen/core/optimizer/__init__.py: -------------------------------------------------------------------------------- 1 | from .builder import OPTIMIZERS, build_optimizers 2 | 3 | __all__ = ['build_optimizers', 'OPTIMIZERS'] -------------------------------------------------------------------------------- /mogen/core/optimizer/builder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from mmcv.runner import build_optimizer 3 | from mmcv.utils import Registry 4 | 5 | OPTIMIZERS = Registry('optimizers') 6 | 7 | 8 | def build_optimizers(model, cfgs): 9 | """Build multiple optimizers from configs. If `cfgs` contains several dicts 10 | for optimizers, then a dict for each constructed optimizers will be 11 | returned. If `cfgs` only contains one optimizer config, the constructed 12 | optimizer itself will be returned. For example, 13 | 14 | 1) Multiple optimizer configs: 15 | 16 | .. 
code-block:: python 17 | 18 | optimizer_cfg = dict( 19 | model1=dict(type='SGD', lr=lr), 20 | model2=dict(type='SGD', lr=lr)) 21 | 22 | The return dict is 23 | ``dict('model1': torch.optim.Optimizer, 'model2': torch.optim.Optimizer)`` 24 | 25 | 2) Single optimizer config: 26 | 27 | .. code-block:: python 28 | 29 | optimizer_cfg = dict(type='SGD', lr=lr) 30 | 31 | The return is ``torch.optim.Optimizer``. 32 | 33 | Args: 34 | model (:obj:`nn.Module`): The model with parameters to be optimized. 35 | cfgs (dict): The config dict of the optimizer. 36 | 37 | Returns: 38 | dict[:obj:`torch.optim.Optimizer`] | :obj:`torch.optim.Optimizer`: 39 | The initialized optimizers. 40 | """ 41 | optimizers = {} 42 | if hasattr(model, 'module'): 43 | model = model.module 44 | # determine whether 'cfgs' has several dicts for optimizers 45 | if all(isinstance(v, dict) for v in cfgs.values()): 46 | for key, cfg in cfgs.items(): 47 | cfg_ = cfg.copy() 48 | module = getattr(model, key) 49 | optimizers[key] = build_optimizer(module, cfg_) 50 | return optimizers 51 | 52 | return build_optimizer(model, cfgs) -------------------------------------------------------------------------------- /mogen/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_dataset import BaseMotionDataset 2 | from .text_motion_dataset import TextMotionDataset 3 | from .builder import DATASETS, PIPELINES, build_dataloader, build_dataset 4 | from .pipelines import Compose 5 | from .samplers import DistributedSampler 6 | 7 | 8 | __all__ = [ 9 | 'BaseMotionDataset', 'TextMotionDataset', 'DATASETS', 'PIPELINES', 'build_dataloader', 10 | 'build_dataset', 'Compose', 'DistributedSampler' 11 | ] -------------------------------------------------------------------------------- /mogen/datasets/base_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | from typing import Optional, Union 4 | 5 | import numpy as np 6 | from torch.utils.data import Dataset 7 | 8 | from .pipelines import Compose 9 | from .builder import DATASETS 10 | from mogen.core.evaluation import build_evaluator 11 | 12 | 13 | @DATASETS.register_module() 14 | class BaseMotionDataset(Dataset): 15 | """Base motion dataset. 16 | Args: 17 | data_prefix (str): the prefix of data path. 18 | pipeline (list): a list of dict, where each element represents 19 | a operation defined in `mogen.datasets.pipelines`. 20 | ann_file (str | None, optional): the annotation file. When ann_file is 21 | str, the subclass is expected to read from the ann_file. When 22 | ann_file is None, the subclass is expected to read according 23 | to data_prefix. 24 | test_mode (bool): in train mode or test mode. Default: None. 25 | dataset_name (str | None, optional): the name of dataset. It is used 26 | to identify the type of evaluation metric. Default: None. 
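fixed_length (int | None, optional): if set, the reported dataset length equals this value and sample indices wrap around the real number of annotations. Default: None.
motion_dir (str | None, optional): the directory (under ``data_prefix/datasets/dataset_name``) storing the ``.npy`` motion files. Default: None.
eval_cfg (dict | None, optional): the evaluation config, e.g. metrics and replication times; required when ``test_mode`` is True. Default: None.
Example (an illustrative sketch; the paths, dataset name and pipeline entries below are placeholders rather than shipped values):
    >>> dataset = BaseMotionDataset(
    ...     data_prefix='data',
    ...     pipeline=[dict(type='Crop', crop_size=196),
    ...               dict(type='ToTensor', keys=['motion']),
    ...               dict(type='Collect', keys=['motion'])],
    ...     dataset_name='human_ml3d',
    ...     ann_file='train.txt',
    ...     motion_dir='motions')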
27 | """ 28 | 29 | def __init__(self, 30 | data_prefix: str, 31 | pipeline: list, 32 | dataset_name: Optional[Union[str, None]] = None, 33 | fixed_length: Optional[Union[int, None]] = None, 34 | ann_file: Optional[Union[str, None]] = None, 35 | motion_dir: Optional[Union[str, None]] = None, 36 | eval_cfg: Optional[Union[dict, None]] = None, 37 | test_mode: Optional[bool] = False): 38 | super(BaseMotionDataset, self).__init__() 39 | 40 | self.data_prefix = data_prefix 41 | self.pipeline = Compose(pipeline) 42 | self.dataset_name = dataset_name 43 | self.fixed_length = fixed_length 44 | self.ann_file = os.path.join(data_prefix, 'datasets', dataset_name, ann_file) 45 | self.motion_dir = os.path.join(data_prefix, 'datasets', dataset_name, motion_dir) 46 | self.eval_cfg = copy.deepcopy(eval_cfg) 47 | self.test_mode = test_mode 48 | 49 | self.load_annotations() 50 | if self.test_mode: 51 | self.prepare_evaluation() 52 | 53 | def load_anno(self, name): 54 | motion_path = os.path.join(self.motion_dir, name + '.npy') 55 | motion_data = np.load(motion_path) 56 | return {'motion': motion_data} 57 | 58 | 59 | def load_annotations(self): 60 | """Load annotations from ``ann_file`` to ``data_infos``""" 61 | self.data_infos = [] 62 | for line in open(self.ann_file, 'r').readlines(): 63 | line = line.strip() 64 | self.data_infos.append(self.load_anno(line)) 65 | 66 | 67 | def prepare_data(self, idx: int): 68 | """"Prepare raw data for the f'{idx'}-th data.""" 69 | results = copy.deepcopy(self.data_infos[idx]) 70 | results['dataset_name'] = self.dataset_name 71 | results['sample_idx'] = idx 72 | return self.pipeline(results) 73 | 74 | def __len__(self): 75 | """Return the length of current dataset.""" 76 | if self.test_mode: 77 | return len(self.eval_indexes) 78 | elif self.fixed_length is not None: 79 | return self.fixed_length 80 | return len(self.data_infos) 81 | 82 | def __getitem__(self, idx: int): 83 | """Prepare data for the ``idx``-th data. 84 | As for video dataset, we can first parse raw data for each frame. Then 85 | we combine annotations from all frames. This interface is used to 86 | simplify the logic of video dataset and other special datasets. 
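In test mode the incoming index is first mapped through ``self.eval_indexes``, so replicated evaluation samples are resolved transparently; when ``fixed_length`` is set, the index wraps around the number of loaded annotations.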
87 | """ 88 | if self.test_mode: 89 | idx = self.eval_indexes[idx] 90 | elif self.fixed_length is not None: 91 | idx = idx % len(self.data_infos) 92 | return self.prepare_data(idx) 93 | 94 | def prepare_evaluation(self): 95 | self.evaluators = [] 96 | self.eval_indexes = [] 97 | for _ in range(self.eval_cfg['replication_times']): 98 | eval_indexes = np.arange(len(self.data_infos)) 99 | if self.eval_cfg.get('shuffle_indexes', False): 100 | np.random.shuffle(eval_indexes) 101 | self.eval_indexes.append(eval_indexes) 102 | for metric in self.eval_cfg['metrics']: 103 | evaluator, self.eval_indexes = build_evaluator( 104 | metric, self.eval_cfg, len(self.data_infos), self.eval_indexes) 105 | self.evaluators.append(evaluator) 106 | 107 | self.eval_indexes = np.concatenate(self.eval_indexes) 108 | 109 | def evaluate(self, results, work_dir, logger=None): 110 | metrics = {} 111 | device = results[0]['motion'].device 112 | for evaluator in self.evaluators: 113 | evaluator.to_device(device) 114 | metrics.update(evaluator.evaluate(results)) 115 | if logger is not None: 116 | logger.info(metrics) 117 | return metrics 118 | -------------------------------------------------------------------------------- /mogen/datasets/builder.py: -------------------------------------------------------------------------------- 1 | import platform 2 | import random 3 | from functools import partial 4 | from typing import Optional, Union 5 | 6 | import numpy as np 7 | from mmcv.parallel import collate 8 | from mmcv.runner import get_dist_info 9 | from mmcv.utils import Registry, build_from_cfg 10 | from torch.utils.data import DataLoader 11 | from torch.utils.data.dataset import Dataset 12 | 13 | from .samplers import DistributedSampler 14 | 15 | if platform.system() != 'Windows': 16 | # https://github.com/pytorch/pytorch/issues/973 17 | import resource 18 | rlimit = resource.getrlimit(resource.RLIMIT_NOFILE) 19 | base_soft_limit = rlimit[0] 20 | hard_limit = rlimit[1] 21 | soft_limit = min(max(4096, base_soft_limit), hard_limit) 22 | resource.setrlimit(resource.RLIMIT_NOFILE, (soft_limit, hard_limit)) 23 | 24 | DATASETS = Registry('dataset') 25 | PIPELINES = Registry('pipeline') 26 | 27 | 28 | def build_dataset(cfg: Union[dict, list, tuple], 29 | default_args: Optional[Union[dict, None]] = None): 30 | """"Build dataset by the given config.""" 31 | from .dataset_wrappers import ( 32 | ConcatDataset, 33 | RepeatDataset, 34 | ) 35 | if isinstance(cfg, (list, tuple)): 36 | dataset = ConcatDataset([build_dataset(c, default_args) for c in cfg]) 37 | elif cfg['type'] == 'RepeatDataset': 38 | dataset = RepeatDataset( 39 | build_dataset(cfg['dataset'], default_args), cfg['times']) 40 | else: 41 | dataset = build_from_cfg(cfg, DATASETS, default_args) 42 | 43 | return dataset 44 | 45 | 46 | def build_dataloader(dataset: Dataset, 47 | samples_per_gpu: int, 48 | workers_per_gpu: int, 49 | num_gpus: Optional[int] = 1, 50 | dist: Optional[bool] = True, 51 | shuffle: Optional[bool] = True, 52 | round_up: Optional[bool] = True, 53 | seed: Optional[Union[int, None]] = None, 54 | persistent_workers: Optional[bool] = True, 55 | **kwargs): 56 | """Build PyTorch DataLoader. 57 | In distributed training, each GPU/process has a dataloader. 58 | In non-distributed training, there is only one dataloader for all GPUs. 59 | Args: 60 | dataset (:obj:`Dataset`): A PyTorch dataset. 61 | samples_per_gpu (int): Number of training samples on each GPU, i.e., 62 | batch size of each GPU. 
63 | workers_per_gpu (int): How many subprocesses to use for data loading 64 | for each GPU. 65 | num_gpus (int, optional): Number of GPUs. Only used in non-distributed 66 | training. 67 | dist (bool, optional): Distributed training/test or not. Default: True. 68 | shuffle (bool, optional): Whether to shuffle the data at every epoch. 69 | Default: True. 70 | round_up (bool, optional): Whether to round up the length of dataset by 71 | adding extra samples to make it evenly divisible. Default: True. 72 | kwargs: any keyword argument to be used to initialize DataLoader 73 | Returns: 74 | DataLoader: A PyTorch dataloader. 75 | """ 76 | rank, world_size = get_dist_info() 77 | if dist: 78 | sampler = DistributedSampler( 79 | dataset, world_size, rank, shuffle=shuffle, round_up=round_up) 80 | shuffle = False 81 | batch_size = samples_per_gpu 82 | num_workers = workers_per_gpu 83 | else: 84 | sampler = None 85 | batch_size = num_gpus * samples_per_gpu 86 | num_workers = num_gpus * workers_per_gpu 87 | 88 | init_fn = partial( 89 | worker_init_fn, num_workers=num_workers, rank=rank, 90 | seed=seed) if seed is not None else None 91 | 92 | data_loader = DataLoader( 93 | dataset, 94 | batch_size=batch_size, 95 | sampler=sampler, 96 | num_workers=num_workers, 97 | collate_fn=partial(collate, samples_per_gpu=samples_per_gpu), 98 | pin_memory=False, 99 | shuffle=shuffle, 100 | worker_init_fn=init_fn, 101 | persistent_workers=persistent_workers, 102 | **kwargs) 103 | 104 | return data_loader 105 | 106 | 107 | def worker_init_fn(worker_id: int, num_workers: int, rank: int, seed: int): 108 | """Init random seed for each worker.""" 109 | # The seed of each worker equals to 110 | # num_worker * rank + worker_id + user_seed 111 | worker_seed = num_workers * rank + worker_id + seed 112 | np.random.seed(worker_seed) 113 | random.seed(worker_seed) -------------------------------------------------------------------------------- /mogen/datasets/dataset_wrappers.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data.dataset import ConcatDataset as _ConcatDataset 2 | from torch.utils.data.dataset import Dataset 3 | 4 | from .builder import DATASETS 5 | 6 | 7 | @DATASETS.register_module() 8 | class ConcatDataset(_ConcatDataset): 9 | """A wrapper of concatenated dataset. 10 | Same as :obj:`torch.utils.data.dataset.ConcatDataset`, but 11 | add `get_cat_ids` function. 12 | Args: 13 | datasets (list[:obj:`Dataset`]): A list of datasets. 14 | """ 15 | 16 | def __init__(self, datasets: list): 17 | super(ConcatDataset, self).__init__(datasets) 18 | 19 | 20 | @DATASETS.register_module() 21 | class RepeatDataset(object): 22 | """A wrapper of repeated dataset. 23 | The length of repeated dataset will be `times` larger than the original 24 | dataset. This is useful when the data loading time is long but the dataset 25 | is small. Using RepeatDataset can reduce the data loading time between 26 | epochs. 27 | Args: 28 | dataset (:obj:`Dataset`): The dataset to be repeated. 29 | times (int): Repeat times. 
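Example (illustrative, for any wrapped ``dataset``):
    >>> repeated = RepeatDataset(dataset, times=10)
    >>> len(repeated) == 10 * len(dataset)
    True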
30 | """ 31 | 32 | def __init__(self, dataset: Dataset, times: int): 33 | self.dataset = dataset 34 | self.times = times 35 | 36 | self._ori_len = len(self.dataset) 37 | 38 | def __getitem__(self, idx: int): 39 | return self.dataset[idx % self._ori_len] 40 | 41 | def __len__(self): 42 | return self.times * self._ori_len -------------------------------------------------------------------------------- /mogen/datasets/pipelines/__init__.py: -------------------------------------------------------------------------------- 1 | from .compose import Compose 2 | from .formatting import ( 3 | to_tensor, 4 | ToTensor, 5 | Transpose, 6 | Collect, 7 | WrapFieldsToLists 8 | ) 9 | from .transforms import ( 10 | Crop, 11 | RandomCrop, 12 | Normalize 13 | ) 14 | 15 | __all__ = [ 16 | 'Compose', 'to_tensor', 'Transpose', 'Collect', 'WrapFieldsToLists', 'ToTensor', 17 | 'Crop', 'RandomCrop', 'Normalize' 18 | ] -------------------------------------------------------------------------------- /mogen/datasets/pipelines/compose.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Sequence 2 | 3 | from mmcv.utils import build_from_cfg 4 | 5 | from ..builder import PIPELINES 6 | 7 | 8 | @PIPELINES.register_module() 9 | class Compose(object): 10 | """Compose a data pipeline with a sequence of transforms. 11 | 12 | Args: 13 | transforms (list[dict | callable]): 14 | Either config dicts of transforms or transform objects. 15 | """ 16 | 17 | def __init__(self, transforms): 18 | assert isinstance(transforms, Sequence) 19 | self.transforms = [] 20 | for transform in transforms: 21 | if isinstance(transform, dict): 22 | transform = build_from_cfg(transform, PIPELINES) 23 | self.transforms.append(transform) 24 | elif callable(transform): 25 | self.transforms.append(transform) 26 | else: 27 | raise TypeError('transform must be callable or a dict, but got' 28 | f' {type(transform)}') 29 | 30 | def __call__(self, data): 31 | for t in self.transforms: 32 | data = t(data) 33 | if data is None: 34 | return None 35 | return data 36 | 37 | def __repr__(self): 38 | format_string = self.__class__.__name__ + '(' 39 | for t in self.transforms: 40 | format_string += f'\n {t}' 41 | format_string += '\n)' 42 | return format_string -------------------------------------------------------------------------------- /mogen/datasets/pipelines/formatting.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Sequence 2 | 3 | import mmcv 4 | import numpy as np 5 | import torch 6 | from mmcv.parallel import DataContainer as DC 7 | from PIL import Image 8 | 9 | from ..builder import PIPELINES 10 | 11 | 12 | def to_tensor(data): 13 | """Convert objects of various python types to :obj:`torch.Tensor`. 14 | 15 | Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`, 16 | :class:`Sequence`, :class:`int` and :class:`float`. 17 | """ 18 | if isinstance(data, torch.Tensor): 19 | return data 20 | elif isinstance(data, np.ndarray): 21 | return torch.from_numpy(data) 22 | elif isinstance(data, Sequence) and not mmcv.is_str(data): 23 | return torch.tensor(data) 24 | elif isinstance(data, int): 25 | return torch.LongTensor([data]) 26 | elif isinstance(data, float): 27 | return torch.FloatTensor([data]) 28 | else: 29 | raise TypeError( 30 | f'Type {type(data)} cannot be converted to tensor.' 
31 | 'Supported types are: `numpy.ndarray`, `torch.Tensor`, ' 32 | '`Sequence`, `int` and `float`') 33 | 34 | 35 | @PIPELINES.register_module() 36 | class ToTensor(object): 37 | 38 | def __init__(self, keys): 39 | self.keys = keys 40 | 41 | def __call__(self, results): 42 | for key in self.keys: 43 | results[key] = to_tensor(results[key]) 44 | return results 45 | 46 | def __repr__(self): 47 | return self.__class__.__name__ + f'(keys={self.keys})' 48 | 49 | 50 | @PIPELINES.register_module() 51 | class Transpose(object): 52 | 53 | def __init__(self, keys, order): 54 | self.keys = keys 55 | self.order = order 56 | 57 | def __call__(self, results): 58 | for key in self.keys: 59 | results[key] = results[key].transpose(self.order) 60 | return results 61 | 62 | def __repr__(self): 63 | return self.__class__.__name__ + \ 64 | f'(keys={self.keys}, order={self.order})' 65 | 66 | 67 | @PIPELINES.register_module() 68 | class Collect(object): 69 | """Collect data from the loader relevant to the specific task. 70 | 71 | This is usually the last stage of the data loader pipeline. 72 | 73 | Args: 74 | keys (Sequence[str]): Keys of results to be collected in ``data``. 75 | meta_keys (Sequence[str], optional): Meta keys to be converted to 76 | ``mmcv.DataContainer`` and collected in ``data[motion_metas]``. 77 | Default: ``('filename', 'ori_filename', 'ori_shape', 'motion_shape', 'motion_mask')`` 78 | 79 | Returns: 80 | dict: The result dict contains the following keys 81 | - keys in``self.keys`` 82 | - ``motion_metas`` if available 83 | """ 84 | 85 | def __init__(self, 86 | keys, 87 | meta_keys=('filename', 'ori_filename', 'ori_shape', 'motion_shape', 'motion_mask')): 88 | self.keys = keys 89 | self.meta_keys = meta_keys 90 | 91 | def __call__(self, results): 92 | data = {} 93 | motion_meta = {} 94 | for key in self.meta_keys: 95 | if key in results: 96 | motion_meta[key] = results[key] 97 | data['motion_metas'] = DC(motion_meta, cpu_only=True) 98 | for key in self.keys: 99 | data[key] = results[key] 100 | return data 101 | 102 | def __repr__(self): 103 | return self.__class__.__name__ + \ 104 | f'(keys={self.keys}, meta_keys={self.meta_keys})' 105 | 106 | 107 | @PIPELINES.register_module() 108 | class WrapFieldsToLists(object): 109 | """Wrap fields of the data dictionary into lists for evaluation. 110 | 111 | This class can be used as a last step of a test or validation 112 | pipeline for single image evaluation or inference. 113 | 114 | Example: 115 | >>> test_pipeline = [ 116 | >>> dict(type='LoadImageFromFile'), 117 | >>> dict(type='Normalize', 118 | mean=[123.675, 116.28, 103.53], 119 | std=[58.395, 57.12, 57.375], 120 | to_rgb=True), 121 | >>> dict(type='ImageToTensor', keys=['img']), 122 | >>> dict(type='Collect', keys=['img']), 123 | >>> dict(type='WrapIntoLists') 124 | >>> ] 125 | """ 126 | 127 | def __call__(self, results): 128 | # Wrap dict fields into lists 129 | for key, val in results.items(): 130 | results[key] = [val] 131 | return results 132 | 133 | def __repr__(self): 134 | return f'{self.__class__.__name__}()' -------------------------------------------------------------------------------- /mogen/datasets/pipelines/transforms.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | 4 | import mmcv 5 | import numpy as np 6 | 7 | from ..builder import PIPELINES 8 | import torch 9 | from typing import Optional, Tuple, Union 10 | 11 | 12 | @PIPELINES.register_module() 13 | class Crop(object): 14 | r"""Crop motion sequences. 
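Sequences longer than ``crop_size`` are randomly cropped to ``crop_size`` frames; shorter ones are zero-padded up to ``crop_size``, and ``motion_mask`` marks which frames are valid.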
15 | 16 | Args: 17 | crop_size (int): The size of the cropped motion sequence. 18 | """ 19 | def __init__(self, 20 | crop_size: Optional[Union[int, None]] = None): 21 | self.crop_size = crop_size 22 | assert self.crop_size is not None 23 | 24 | def __call__(self, results): 25 | motion = results['motion'] 26 | length = len(motion) 27 | if length >= self.crop_size: 28 | idx = random.randint(0, length - self.crop_size) 29 | motion = motion[idx: idx + self.crop_size] 30 | results['motion_length'] = self.crop_size 31 | else: 32 | padding_length = self.crop_size - length 33 | D = motion.shape[1:] 34 | padding_zeros = np.zeros((padding_length, *D), dtype=np.float32) 35 | motion = np.concatenate([motion, padding_zeros], axis=0) 36 | results['motion_length'] = length 37 | assert len(motion) == self.crop_size 38 | results['motion'] = motion 39 | results['motion_shape'] = motion.shape 40 | if length >= self.crop_size: 41 | results['motion_mask'] = torch.ones(self.crop_size).numpy() 42 | else: 43 | results['motion_mask'] = torch.cat( 44 | (torch.ones(length), torch.zeros(self.crop_size - length))).numpy() 45 | return results 46 | 47 | 48 | def __repr__(self): 49 | repr_str = self.__class__.__name__ + f'(crop_size={self.crop_size})' 50 | return repr_str 51 | 52 | @PIPELINES.register_module() 53 | class RandomCrop(object): 54 | r"""Random crop motion sequences. Each sequence will be padded with zeros to the maximum length. 55 | 56 | Args: 57 | min_size (int or None): The minimum size of the cropped motion sequence (inclusive). 58 | max_size (int or None): The maximum size of the cropped motion sequence (inclusive). 59 | """ 60 | def __init__(self, 61 | min_size: Optional[Union[int, None]] = None, 62 | max_size: Optional[Union[int, None]] = None): 63 | self.min_size = min_size 64 | self.max_size = max_size 65 | assert self.min_size is not None 66 | assert self.max_size is not None 67 | 68 | def __call__(self, results): 69 | motion = results['motion'] 70 | length = len(motion) 71 | crop_size = random.randint(self.min_size, self.max_size) 72 | if length > crop_size: 73 | idx = random.randint(0, length - crop_size) 74 | motion = motion[idx: idx + crop_size] 75 | results['motion_length'] = crop_size 76 | else: 77 | results['motion_length'] = length 78 | padding_length = self.max_size - min(crop_size, length) 79 | if padding_length > 0: 80 | D = motion.shape[1:] 81 | padding_zeros = np.zeros((padding_length, *D), dtype=np.float32) 82 | motion = np.concatenate([motion, padding_zeros], axis=0) 83 | results['motion'] = motion 84 | results['motion_shape'] = motion.shape 85 | if length >= self.max_size and crop_size == self.max_size: 86 | results['motion_mask'] = torch.ones(self.max_size).numpy() 87 | else: 88 | results['motion_mask'] = torch.cat(( 89 | torch.ones(min(length, crop_size)), 90 | torch.zeros(self.max_size - min(length, crop_size))), dim=0).numpy() 91 | assert len(motion) == self.max_size 92 | return results 93 | 94 | 95 | def __repr__(self): 96 | repr_str = self.__class__.__name__ + f'(min_size={self.min_size}' 97 | repr_str += f', max_size={self.max_size})' 98 | return repr_str 99 | 100 | @PIPELINES.register_module() 101 | class Normalize(object): 102 | """Normalize motion sequences. 103 | 104 | Args: 105 | mean_path (str): Path of mean file. 106 | std_path (str): Path of std file. 
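eps (float, optional): Small constant added to the std to avoid division by zero, i.e. motion = (motion - mean) / (std + eps). Default: 1e-9.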
107 | """ 108 | 109 | def __init__(self, mean_path, std_path, eps=1e-9): 110 | self.mean = np.load(mean_path) 111 | self.std = np.load(std_path) 112 | self.eps = eps 113 | 114 | def __call__(self, results): 115 | motion = results['motion'] 116 | motion = (motion - self.mean) / (self.std + self.eps) 117 | results['motion'] = motion 118 | results['motion_norm_mean'] = self.mean 119 | results['motion_norm_std'] = self.std 120 | return results 121 | -------------------------------------------------------------------------------- /mogen/datasets/samplers/__init__.py: -------------------------------------------------------------------------------- 1 | from .distributed_sampler import DistributedSampler 2 | 3 | __all__ = ['DistributedSampler'] -------------------------------------------------------------------------------- /mogen/datasets/samplers/distributed_sampler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DistributedSampler as _DistributedSampler 3 | 4 | 5 | class DistributedSampler(_DistributedSampler): 6 | 7 | def __init__(self, 8 | dataset, 9 | num_replicas=None, 10 | rank=None, 11 | shuffle=True, 12 | round_up=True): 13 | super().__init__(dataset, num_replicas=num_replicas, rank=rank) 14 | self.shuffle = shuffle 15 | self.round_up = round_up 16 | if self.round_up: 17 | self.total_size = self.num_samples * self.num_replicas 18 | else: 19 | self.total_size = len(self.dataset) 20 | 21 | def __iter__(self): 22 | # deterministically shuffle based on epoch 23 | if self.shuffle: 24 | g = torch.Generator() 25 | g.manual_seed(self.epoch) 26 | indices = torch.randperm(len(self.dataset), generator=g).tolist() 27 | else: 28 | indices = torch.arange(len(self.dataset)).tolist() 29 | 30 | # add extra samples to make it evenly divisible 31 | if self.round_up: 32 | indices = ( 33 | indices * 34 | int(self.total_size / len(indices) + 1))[:self.total_size] 35 | assert len(indices) == self.total_size 36 | 37 | # subsample 38 | indices = indices[self.rank:self.total_size:self.num_replicas] 39 | if self.round_up: 40 | assert len(indices) == self.num_samples 41 | 42 | return iter(indices) 43 | -------------------------------------------------------------------------------- /mogen/datasets/text_motion_dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import os.path 4 | from abc import ABCMeta 5 | from collections import OrderedDict 6 | from typing import Any, List, Optional, Union 7 | 8 | import mmcv 9 | import copy 10 | import numpy as np 11 | import torch 12 | import torch.distributed as dist 13 | from mmcv.runner import get_dist_info 14 | 15 | from .base_dataset import BaseMotionDataset 16 | from .builder import DATASETS 17 | 18 | 19 | @DATASETS.register_module() 20 | class TextMotionDataset(BaseMotionDataset): 21 | """TextMotion dataset. 22 | 23 | Args: 24 | text_dir (str): Path to the directory containing the text files. 
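token_dir (str | None, optional): Path to the directory containing the token files. Default: None.
clip_feat_dir (str | None, optional): Path to the directory containing pre-computed CLIP features stored as ``.npy`` files. Default: None.
fine_mode (bool, optional): Whether each text line is a JSON string with fine-grained annotations that is parsed via ``json.loads``. Default: False.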
25 | """ 26 | def __init__(self, 27 | data_prefix: str, 28 | pipeline: list, 29 | dataset_name: Optional[Union[str, None]] = None, 30 | fixed_length: Optional[Union[int, None]] = None, 31 | ann_file: Optional[Union[str, None]] = None, 32 | motion_dir: Optional[Union[str, None]] = None, 33 | text_dir: Optional[Union[str, None]] = None, 34 | token_dir: Optional[Union[str, None]] = None, 35 | clip_feat_dir: Optional[Union[str, None]] = None, 36 | eval_cfg: Optional[Union[dict, None]] = None, 37 | fine_mode: Optional[bool] = False, 38 | test_mode: Optional[bool] = False): 39 | self.text_dir = os.path.join(data_prefix, 'datasets', dataset_name, text_dir) 40 | if token_dir is not None: 41 | self.token_dir = os.path.join(data_prefix, 'datasets', dataset_name, token_dir) 42 | else: 43 | self.token_dir = None 44 | if clip_feat_dir is not None: 45 | self.clip_feat_dir = os.path.join(data_prefix, 'datasets', dataset_name, clip_feat_dir) 46 | else: 47 | self.clip_feat_dir = None 48 | self.fine_mode = fine_mode 49 | super(TextMotionDataset, self).__init__( 50 | data_prefix=data_prefix, 51 | pipeline=pipeline, 52 | dataset_name=dataset_name, 53 | fixed_length=fixed_length, 54 | ann_file=ann_file, 55 | motion_dir=motion_dir, 56 | eval_cfg=eval_cfg, 57 | test_mode=test_mode) 58 | 59 | def load_anno(self, name): 60 | results = super().load_anno(name) 61 | text_path = os.path.join(self.text_dir, name + '.txt') 62 | text_data = [] 63 | for line in open(text_path, 'r'): 64 | text_data.append(line.strip()) 65 | results['text'] = text_data 66 | if self.token_dir is not None: 67 | token_path = os.path.join(self.token_dir, name + '.txt') 68 | token_data = [] 69 | for line in open(token_path, 'r'): 70 | token_data.append(line.strip()) 71 | results['token'] = token_data 72 | if self.clip_feat_dir is not None: 73 | clip_feat_path = os.path.join(self.clip_feat_dir, name + '.npy') 74 | clip_feat = torch.from_numpy(np.load(clip_feat_path)) 75 | results['clip_feat'] = clip_feat 76 | return results 77 | 78 | def prepare_data(self, idx: int): 79 | """"Prepare raw data for the f'{idx'}-th data.""" 80 | results = copy.deepcopy(self.data_infos[idx]) 81 | text_list = results['text'] 82 | idx = np.random.randint(0, len(text_list)) 83 | if self.fine_mode: 84 | results['text'] = json.loads(text_list[idx]) 85 | else: 86 | results['text'] = text_list[idx] 87 | if 'clip_feat' in results.keys(): 88 | results['clip_feat'] = results['clip_feat'][idx] 89 | if 'token' in results.keys(): 90 | results['token'] = results['token'][idx] 91 | results['dataset_name'] = self.dataset_name 92 | results['sample_idx'] = idx 93 | return self.pipeline(results) 94 | -------------------------------------------------------------------------------- /mogen/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .architectures import * 2 | from .losses import * 3 | from .rnns import * 4 | from .transformers import * 5 | from .attentions import * 6 | from .builder import * 7 | from .utils import * -------------------------------------------------------------------------------- /mogen/models/architectures/__init__.py: -------------------------------------------------------------------------------- 1 | from .vae_architecture import MotionVAE 2 | from .diffusion_architecture import MotionDiffusion 3 | 4 | __all__ = [ 5 | 'MotionVAE', 'MotionDiffusion' 6 | ] -------------------------------------------------------------------------------- /mogen/models/architectures/base_architecture.py: 
-------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | from collections import OrderedDict 3 | 4 | import torch 5 | import torch.distributed as dist 6 | from mmcv.runner import BaseModule 7 | 8 | 9 | def to_cpu(x): 10 | if isinstance(x, torch.Tensor): 11 | return x.detach().cpu() 12 | return x 13 | 14 | 15 | class BaseArchitecture(BaseModule): 16 | """Base class for mogen architecture.""" 17 | 18 | def __init__(self, init_cfg=None): 19 | super(BaseArchitecture, self).__init__(init_cfg) 20 | 21 | def forward_train(self, **kwargs): 22 | pass 23 | 24 | def forward_test(self, **kwargs): 25 | pass 26 | 27 | def _parse_losses(self, losses): 28 | """Parse the raw outputs (losses) of the network. 29 | Args: 30 | losses (dict): Raw output of the network, which usually contain 31 | losses and other necessary information. 32 | Returns: 33 | tuple[Tensor, dict]: (loss, log_vars), loss is the loss tensor \ 34 | which may be a weighted sum of all losses, log_vars contains \ 35 | all the variables to be sent to the logger. 36 | """ 37 | log_vars = OrderedDict() 38 | for loss_name, loss_value in losses.items(): 39 | if isinstance(loss_value, torch.Tensor): 40 | log_vars[loss_name] = loss_value.mean() 41 | elif isinstance(loss_value, list): 42 | log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value) 43 | else: 44 | raise TypeError( 45 | f'{loss_name} is not a tensor or list of tensors') 46 | 47 | loss = sum(_value for _key, _value in log_vars.items() 48 | if 'loss' in _key) 49 | 50 | log_vars['loss'] = loss 51 | for loss_name, loss_value in log_vars.items(): 52 | # reduce loss when distributed training 53 | if dist.is_available() and dist.is_initialized(): 54 | loss_value = loss_value.data.clone() 55 | dist.all_reduce(loss_value.div_(dist.get_world_size())) 56 | log_vars[loss_name] = loss_value.item() 57 | 58 | return loss, log_vars 59 | 60 | def train_step(self, data, optimizer): 61 | """The iteration step during training. 62 | This method defines an iteration step during training, except for the 63 | back propagation and optimizer updating, which are done in an optimizer 64 | hook. Note that in some complicated cases or models, the whole process 65 | including back propagation and optimizer updating is also defined in 66 | this method, such as GAN. 67 | Args: 68 | data (dict): The output of dataloader. 69 | optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer of 70 | runner is passed to ``train_step()``. This argument is unused 71 | and reserved. 72 | Returns: 73 | dict: It should contain at least 3 keys: ``loss``, ``log_vars``, \ 74 | ``num_samples``. 75 | - ``loss`` is a tensor for back propagation, which can be a 76 | weighted sum of multiple losses. 77 | - ``log_vars`` contains all the variables to be sent to the 78 | logger. 79 | - ``num_samples`` indicates the batch size (when the model is 80 | DDP, it means the batch size on each GPU), which is used for 81 | averaging the logs. 82 | """ 83 | losses = self(**data) 84 | loss, log_vars = self._parse_losses(losses) 85 | 86 | outputs = dict( 87 | loss=loss, log_vars=log_vars, num_samples=len(data['motion'])) 88 | 89 | return outputs 90 | 91 | def val_step(self, data, optimizer=None): 92 | """The iteration step during validation. 93 | This method shares the same signature as :func:`train_step`, but used 94 | during val epochs. Note that the evaluation after training epochs is 95 | not implemented with this method, but an evaluation hook. 
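It returns the same keys as ``train_step``: ``loss``, ``log_vars`` and ``num_samples``.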
96 | """ 97 | losses = self(**data) 98 | loss, log_vars = self._parse_losses(losses) 99 | 100 | outputs = dict( 101 | loss=loss, log_vars=log_vars, num_samples=len(data['motion'])) 102 | 103 | return outputs 104 | 105 | def forward(self, **kwargs): 106 | if self.training: 107 | return self.forward_train(**kwargs) 108 | else: 109 | return self.forward_test(**kwargs) 110 | 111 | def split_results(self, results): 112 | B = results['motion'].shape[0] 113 | output = [] 114 | for i in range(B): 115 | batch_output = dict() 116 | batch_output['motion'] = to_cpu(results['motion'][i]) 117 | batch_output['pred_motion'] = to_cpu(results['pred_motion'][i]) 118 | batch_output['motion_length'] = to_cpu(results['motion_length'][i]) 119 | batch_output['motion_mask'] = to_cpu(results['motion_mask'][i]) 120 | if 'pred_motion_length' in results.keys(): 121 | batch_output['pred_motion_length'] = to_cpu(results['pred_motion_length'][i]) 122 | else: 123 | batch_output['pred_motion_length'] = to_cpu(results['motion_length'][i]) 124 | if 'pred_motion_mask' in results: 125 | batch_output['pred_motion_mask'] = to_cpu(results['pred_motion_mask'][i]) 126 | else: 127 | batch_output['pred_motion_mask'] = to_cpu(results['motion_mask'][i]) 128 | if 'motion_metas' in results.keys(): 129 | motion_metas = results['motion_metas'][i] 130 | if 'text' in motion_metas.keys(): 131 | batch_output['text'] = motion_metas['text'] 132 | if 'token' in motion_metas.keys(): 133 | batch_output['token'] = motion_metas['token'] 134 | output.append(batch_output) 135 | return output 136 | -------------------------------------------------------------------------------- /mogen/models/architectures/diffusion_architecture.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .base_architecture import BaseArchitecture 6 | from ..builder import ( 7 | ARCHITECTURES, 8 | build_architecture, 9 | build_submodule, 10 | build_loss 11 | ) 12 | from ..utils.gaussian_diffusion import ( 13 | GaussianDiffusion, get_named_beta_schedule, create_named_schedule_sampler, 14 | ModelMeanType, ModelVarType, LossType, space_timesteps, SpacedDiffusion 15 | ) 16 | 17 | def build_diffusion(cfg): 18 | beta_scheduler = cfg['beta_scheduler'] 19 | diffusion_steps = cfg['diffusion_steps'] 20 | 21 | betas = get_named_beta_schedule(beta_scheduler, diffusion_steps) 22 | model_mean_type = { 23 | 'start_x': ModelMeanType.START_X, 24 | 'previous_x': ModelMeanType.PREVIOUS_X, 25 | 'epsilon': ModelMeanType.EPSILON 26 | }[cfg['model_mean_type']] 27 | model_var_type = { 28 | 'learned': ModelVarType.LEARNED, 29 | 'fixed_small': ModelVarType.FIXED_SMALL, 30 | 'fixed_large': ModelVarType.FIXED_LARGE, 31 | 'learned_range': ModelVarType.LEARNED_RANGE 32 | }[cfg['model_var_type']] 33 | if cfg.get('respace', None) is not None: 34 | diffusion = SpacedDiffusion( 35 | use_timesteps=space_timesteps(diffusion_steps, cfg['respace']), 36 | betas=betas, 37 | model_mean_type=model_mean_type, 38 | model_var_type=model_var_type, 39 | loss_type=LossType.MSE 40 | ) 41 | else: 42 | diffusion = GaussianDiffusion( 43 | betas=betas, 44 | model_mean_type=model_mean_type, 45 | model_var_type=model_var_type, 46 | loss_type=LossType.MSE) 47 | return diffusion 48 | 49 | 50 | @ARCHITECTURES.register_module() 51 | class MotionDiffusion(BaseArchitecture): 52 | 53 | def __init__(self, 54 | model=None, 55 | loss_recon=None, 56 | diffusion_train=None, 57 | diffusion_test=None, 58 | init_cfg=None, 59 
| inference_type='ddpm', 60 | **kwargs): 61 | super().__init__(init_cfg=init_cfg, **kwargs) 62 | self.model = build_submodule(model) 63 | self.loss_recon = build_loss(loss_recon) 64 | self.diffusion_train = build_diffusion(diffusion_train) 65 | self.diffusion_test = build_diffusion(diffusion_test) 66 | self.sampler = create_named_schedule_sampler('uniform', self.diffusion_train) 67 | self.inference_type = inference_type 68 | 69 | def forward(self, **kwargs): 70 | motion, motion_mask = kwargs['motion'].float(), kwargs['motion_mask'].float() 71 | sample_idx = kwargs.get('sample_idx', None) 72 | clip_feat = kwargs.get('clip_feat', None) 73 | B, T = motion.shape[:2] 74 | text = [] 75 | for i in range(B): 76 | text.append(kwargs['motion_metas'][i]['text']) 77 | 78 | if self.training: 79 | t, _ = self.sampler.sample(B, motion.device) 80 | output = self.diffusion_train.training_losses( 81 | model=self.model, 82 | x_start=motion, 83 | t=t, 84 | model_kwargs={ 85 | 'motion_mask': motion_mask, 86 | 'motion_length': kwargs['motion_length'], 87 | 'text': text, 88 | 'clip_feat': clip_feat, 89 | 'sample_idx': sample_idx} 90 | ) 91 | pred, target = output['pred'], output['target'] 92 | recon_loss = self.loss_recon(pred, target, reduction_override='none') 93 | recon_loss = (recon_loss.mean(dim=-1) * motion_mask).sum() / motion_mask.sum() 94 | loss = {'recon_loss': recon_loss} 95 | return loss 96 | else: 97 | dim_pose = kwargs['motion'].shape[-1] 98 | model_kwargs = self.model.get_precompute_condition(device=motion.device, text=text, **kwargs) 99 | model_kwargs['motion_mask'] = motion_mask 100 | model_kwargs['sample_idx'] = sample_idx 101 | inference_kwargs = kwargs.get('inference_kwargs', {}) 102 | if self.inference_type == 'ddpm': 103 | output = self.diffusion_test.p_sample_loop( 104 | self.model, 105 | (B, T, dim_pose), 106 | clip_denoised=False, 107 | progress=False, 108 | model_kwargs=model_kwargs, 109 | **inference_kwargs 110 | ) 111 | else: 112 | output = self.diffusion_test.ddim_sample_loop( 113 | self.model, 114 | (B, T, dim_pose), 115 | clip_denoised=False, 116 | progress=False, 117 | model_kwargs=model_kwargs, 118 | eta=0, 119 | **inference_kwargs 120 | ) 121 | if getattr(self.model, "post_process") is not None: 122 | output = self.model.post_process(output) 123 | results = kwargs 124 | results['pred_motion'] = output 125 | results = self.split_results(results) 126 | return results 127 | 128 | -------------------------------------------------------------------------------- /mogen/models/architectures/vae_architecture.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .base_architecture import BaseArchitecture 6 | from ..builder import ( 7 | ARCHITECTURES, 8 | build_architecture, 9 | build_submodule, 10 | build_loss 11 | ) 12 | 13 | 14 | @ARCHITECTURES.register_module() 15 | class PoseVAE(BaseArchitecture): 16 | 17 | def __init__(self, 18 | encoder=None, 19 | decoder=None, 20 | loss_recon=None, 21 | kl_div_loss_weight=None, 22 | init_cfg=None, 23 | **kwargs): 24 | super().__init__(init_cfg=init_cfg, **kwargs) 25 | self.encoder = build_submodule(encoder) 26 | self.decoder = build_submodule(decoder) 27 | self.loss_recon = build_loss(loss_recon) 28 | self.kl_div_loss_weight = kl_div_loss_weight 29 | 30 | def reparameterize(self, mu, logvar): 31 | std = torch.exp(logvar / 2) 32 | 33 | eps = std.data.new(std.size()).normal_() 34 | latent_code = eps.mul(std).add_(mu) 35 | return 
latent_code 36 | 37 | def encode(self, pose): 38 | mu, logvar = self.encoder(pose) 39 | return mu 40 | 41 | def forward(self, **kwargs): 42 | motion = kwargs['motion'].float() 43 | B, T = motion.shape[:2] 44 | pose = motion.reshape(B * T, -1) 45 | pose = pose[:, :-4] 46 | 47 | mu, logvar = self.encoder(pose) 48 | z = self.reparameterize(mu, logvar) 49 | pred = self.decoder(z) 50 | 51 | loss = dict() 52 | recon_loss = self.loss_recon(pred, pose, reduction_override='none') 53 | loss['recon_loss'] = recon_loss 54 | if self.kl_div_loss_weight is not None: 55 | loss_kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) 56 | loss['kl_div_loss'] = (loss_kl * self.kl_div_loss_weight) 57 | 58 | return loss 59 | 60 | 61 | @ARCHITECTURES.register_module() 62 | class MotionVAE(BaseArchitecture): 63 | 64 | def __init__(self, 65 | encoder=None, 66 | decoder=None, 67 | loss_recon=None, 68 | kl_div_loss_weight=None, 69 | init_cfg=None, 70 | **kwargs): 71 | super().__init__(init_cfg=init_cfg, **kwargs) 72 | self.encoder = build_submodule(encoder) 73 | self.decoder = build_submodule(decoder) 74 | self.loss_recon = build_loss(loss_recon) 75 | self.kl_div_loss_weight = kl_div_loss_weight 76 | 77 | def sample(self, std=1, latent_code=None): 78 | if latent_code is not None: 79 | z = latent_code 80 | else: 81 | z = torch.randn(1, 7, self.decoder.latent_dim).cuda() * std 82 | output = self.decoder(z) 83 | if self.use_normalization: 84 | output = output * self.motion_std 85 | output = output + self.motion_mean 86 | return output 87 | 88 | def reparameterize(self, mu, logvar): 89 | std = torch.exp(logvar / 2) 90 | 91 | eps = std.data.new(std.size()).normal_() 92 | latent_code = eps.mul(std).add_(mu) 93 | return latent_code 94 | 95 | def encode(self, motion, motion_mask): 96 | mu, logvar = self.encoder(motion, motion_mask) 97 | return self.reparameterize(mu, logvar) 98 | 99 | def decode(self, z, motion_mask): 100 | return self.decoder(z, motion_mask) 101 | 102 | def forward(self, **kwargs): 103 | motion, motion_mask = kwargs['motion'].float(), kwargs['motion_mask'] 104 | B, T = motion.shape[:2] 105 | 106 | mu, logvar = self.encoder(motion, motion_mask) 107 | z = self.reparameterize(mu, logvar) 108 | pred = self.decoder(z, motion_mask) 109 | 110 | loss = dict() 111 | recon_loss = self.loss_recon(pred, motion, reduction_override='none') 112 | recon_loss = (recon_loss.mean(dim=-1) * motion_mask).sum() / motion_mask.sum() 113 | loss['recon_loss'] = recon_loss 114 | if self.kl_div_loss_weight is not None: 115 | loss_kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) 116 | loss['kl_div_loss'] = (loss_kl * self.kl_div_loss_weight) 117 | 118 | return loss -------------------------------------------------------------------------------- /mogen/models/attentions/__init__.py: -------------------------------------------------------------------------------- 1 | from .efficient_attention import ( 2 | EfficientSelfAttention, 3 | EfficientCrossAttention 4 | ) 5 | from .semantics_modulated import SemanticsModulatedAttention 6 | from .base_attention import BaseMixedAttention -------------------------------------------------------------------------------- /mogen/models/attentions/base_attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from ..utils.stylization_block import StylizationBlock 5 | from ..builder import ATTENTIONS 6 | 7 | 8 | @ATTENTIONS.register_module() 9 | class 
BaseMixedAttention(nn.Module): 10 | 11 | def __init__(self, latent_dim, 12 | text_latent_dim, 13 | num_heads, 14 | dropout, 15 | time_embed_dim): 16 | super().__init__() 17 | self.num_heads = num_heads 18 | 19 | self.norm = nn.LayerNorm(latent_dim) 20 | self.text_norm = nn.LayerNorm(text_latent_dim) 21 | 22 | self.query = nn.Linear(latent_dim, latent_dim) 23 | self.key_text = nn.Linear(text_latent_dim, latent_dim) 24 | self.value_text = nn.Linear(text_latent_dim, latent_dim) 25 | self.key_motion = nn.Linear(latent_dim, latent_dim) 26 | self.value_motion = nn.Linear(latent_dim, latent_dim) 27 | 28 | self.dropout = nn.Dropout(dropout) 29 | self.proj_out = StylizationBlock(latent_dim, time_embed_dim, dropout) 30 | 31 | def forward(self, x, xf, emb, src_mask, cond_type, **kwargs): 32 | """ 33 | x: B, T, D 34 | xf: B, N, L 35 | """ 36 | B, T, D = x.shape 37 | N = xf.shape[1] + x.shape[1] 38 | H = self.num_heads 39 | # B, T, D 40 | query = self.query(self.norm(x)).view(B, T, H, -1) 41 | # B, N, D 42 | text_cond_type = ((cond_type % 10) > 0).float().view(B, 1, 1).repeat(1, xf.shape[1], 1) 43 | key = torch.cat(( 44 | self.key_text(self.text_norm(xf)), 45 | self.key_motion(self.norm(x)) 46 | ), dim=1).view(B, N, H, -1) 47 | 48 | attention = torch.einsum('bnhl,bmhl->bnmh', query, key) 49 | motion_mask = src_mask.view(B, 1, T, 1) 50 | text_mask = text_cond_type.view(B, 1, -1, 1) 51 | mask = torch.cat((text_mask, motion_mask), dim=2) 52 | attention = attention + (1 - mask) * -1000000 53 | attention = F.softmax(attention, dim=2) 54 | 55 | value = torch.cat(( 56 | self.value_text(self.text_norm(xf)) * text_cond_type, 57 | self.value_motion(self.norm(x)) * src_mask, 58 | ), dim=1).view(B, N, H, -1) 59 | 60 | y = torch.einsum('bnmh,bmhl->bnhl', attention, value).reshape(B, T, D) 61 | y = x + self.proj_out(y, emb) 62 | return y 63 | 64 | 65 | @ATTENTIONS.register_module() 66 | class BaseSelfAttention(nn.Module): 67 | 68 | def __init__(self, latent_dim, 69 | num_heads, 70 | dropout, 71 | time_embed_dim): 72 | super().__init__() 73 | self.num_heads = num_heads 74 | 75 | self.norm = nn.LayerNorm(latent_dim) 76 | self.query = nn.Linear(latent_dim, latent_dim) 77 | self.key = nn.Linear(latent_dim, latent_dim) 78 | self.value = nn.Linear(latent_dim, latent_dim) 79 | 80 | self.dropout = nn.Dropout(dropout) 81 | self.proj_out = StylizationBlock(latent_dim, time_embed_dim, dropout) 82 | 83 | def forward(self, x, emb, src_mask, **kwargs): 84 | """ 85 | x: B, T, D 86 | """ 87 | B, T, D = x.shape 88 | H = self.num_heads 89 | # B, T, D 90 | query = self.query(self.norm(x)).view(B, T, H, -1) 91 | # B, N, D 92 | key = self.key(self.norm(x)).view(B, T, H, -1) 93 | 94 | attention = torch.einsum('bnhl,bmhl->bnmh', query, key) 95 | mask = src_mask.view(B, 1, T, 1) 96 | attention = attention + (1 - mask) * -1000000 97 | attention = F.softmax(attention, dim=2) 98 | value = (self.value(self.norm(x)) * src_mask).view(B, T, H, -1) 99 | y = torch.einsum('bnmh,bmhl->bnhl', attention, value).reshape(B, T, D) 100 | y = x + self.proj_out(y, emb) 101 | return y 102 | 103 | 104 | @ATTENTIONS.register_module() 105 | class BaseCrossAttention(nn.Module): 106 | 107 | def __init__(self, latent_dim, 108 | text_latent_dim, 109 | num_heads, 110 | dropout, 111 | time_embed_dim): 112 | super().__init__() 113 | self.num_heads = num_heads 114 | 115 | self.norm = nn.LayerNorm(latent_dim) 116 | self.text_norm = nn.LayerNorm(text_latent_dim) 117 | 118 | self.query = nn.Linear(latent_dim, latent_dim) 119 | self.key = nn.Linear(text_latent_dim, 
latent_dim) 120 | self.value = nn.Linear(text_latent_dim, latent_dim) 121 | 122 | self.dropout = nn.Dropout(dropout) 123 | self.proj_out = StylizationBlock(latent_dim, time_embed_dim, dropout) 124 | 125 | def forward(self, x, xf, emb, src_mask, cond_type, **kwargs): 126 | """ 127 | x: B, T, D 128 | xf: B, N, L 129 | """ 130 | B, T, D = x.shape 131 | N = xf.shape[1] 132 | H = self.num_heads 133 | # B, T, D 134 | query = self.query(self.norm(x)).view(B, T, H, -1) 135 | # B, N, D 136 | text_cond_type = ((cond_type % 10) > 0).float().view(B, 1, 1).repeat(1, xf.shape[1], 1) 137 | key = self.key(self.text_norm(xf)).view(B, N, H, -1) 138 | attention = torch.einsum('bnhl,bmhl->bnmh', query, key) 139 | mask = text_cond_type.view(B, 1, -1, 1) 140 | attention = attention + (1 - mask) * -1000000 141 | attention = F.softmax(attention, dim=2) 142 | 143 | value = (self.value(self.text_norm(xf)) * text_cond_type).view(B, N, H, -1) 144 | y = torch.einsum('bnmh,bmhl->bnhl', attention, value).reshape(B, T, D) 145 | y = x + self.proj_out(y, emb) 146 | return y 147 | -------------------------------------------------------------------------------- /mogen/models/attentions/efficient_attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from ..utils.stylization_block import StylizationBlock 5 | from ..builder import ATTENTIONS 6 | 7 | 8 | @ATTENTIONS.register_module() 9 | class EfficientSelfAttention(nn.Module): 10 | 11 | def __init__(self, latent_dim, num_heads, dropout, time_embed_dim=None): 12 | super().__init__() 13 | self.num_heads = num_heads 14 | self.norm = nn.LayerNorm(latent_dim) 15 | self.query = nn.Linear(latent_dim, latent_dim) 16 | self.key = nn.Linear(latent_dim, latent_dim) 17 | self.value = nn.Linear(latent_dim, latent_dim) 18 | self.dropout = nn.Dropout(dropout) 19 | self.time_embed_dim = time_embed_dim 20 | if time_embed_dim is not None: 21 | self.proj_out = StylizationBlock(latent_dim, time_embed_dim, dropout) 22 | 23 | def forward(self, x, src_mask, emb=None, **kwargs): 24 | """ 25 | x: B, T, D 26 | """ 27 | B, T, D = x.shape 28 | H = self.num_heads 29 | # B, T, D 30 | query = self.query(self.norm(x)) 31 | # B, T, D 32 | key = (self.key(self.norm(x)) + (1 - src_mask) * -1000000) 33 | query = F.softmax(query.view(B, T, H, -1), dim=-1) 34 | key = F.softmax(key.view(B, T, H, -1), dim=1) 35 | # B, T, H, HD 36 | value = (self.value(self.norm(x)) * src_mask).view(B, T, H, -1) 37 | # B, H, HD, HD 38 | attention = torch.einsum('bnhd,bnhl->bhdl', key, value) 39 | y = torch.einsum('bnhd,bhdl->bnhl', query, attention).reshape(B, T, D) 40 | if self.time_embed_dim is None: 41 | y = x + y 42 | else: 43 | y = x + self.proj_out(y, emb) 44 | return y 45 | 46 | 47 | @ATTENTIONS.register_module() 48 | class EfficientCrossAttention(nn.Module): 49 | 50 | def __init__(self, latent_dim, text_latent_dim, num_heads, dropout, time_embed_dim): 51 | super().__init__() 52 | self.num_heads = num_heads 53 | self.norm = nn.LayerNorm(latent_dim) 54 | self.text_norm = nn.LayerNorm(text_latent_dim) 55 | self.query = nn.Linear(latent_dim, latent_dim) 56 | self.key = nn.Linear(text_latent_dim, latent_dim) 57 | self.value = nn.Linear(text_latent_dim, latent_dim) 58 | self.dropout = nn.Dropout(dropout) 59 | self.proj_out = StylizationBlock(latent_dim, time_embed_dim, dropout) 60 | 61 | def forward(self, x, xf, emb, cond_type=None, **kwargs): 62 | """ 63 | x: B, T, D 64 | xf: B, N, L 65 | """ 66 | B, T, D = 
x.shape 67 | N = xf.shape[1] 68 | H = self.num_heads 69 | # B, T, D 70 | query = self.query(self.norm(x)) 71 | # B, N, D 72 | key = self.key(self.text_norm(xf)) 73 | query = F.softmax(query.view(B, T, H, -1), dim=-1) 74 | if cond_type is None: 75 | key = F.softmax(key.view(B, N, H, -1), dim=1) 76 | # B, N, H, HD 77 | value = self.value(self.text_norm(xf)).view(B, N, H, -1) 78 | else: 79 | text_cond_type = ((cond_type % 10) > 0).float().view(B, 1, 1).repeat(1, xf.shape[1], 1) 80 | key = key + (1 - text_cond_type) * -1000000 81 | key = F.softmax(key.view(B, N, H, -1), dim=1) 82 | value = self.value(self.text_norm(xf) * text_cond_type).view(B, N, H, -1) 83 | # B, H, HD, HD 84 | attention = torch.einsum('bnhd,bnhl->bhdl', key, value) 85 | y = torch.einsum('bnhd,bhdl->bnhl', query, attention).reshape(B, T, D) 86 | y = x + self.proj_out(y, emb) 87 | return y 88 | -------------------------------------------------------------------------------- /mogen/models/attentions/semantics_modulated.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from ..utils.stylization_block import StylizationBlock 5 | from ..builder import ATTENTIONS 6 | 7 | 8 | def zero_module(module): 9 | """ 10 | Zero out the parameters of a module and return it. 11 | """ 12 | for p in module.parameters(): 13 | p.detach().zero_() 14 | return module 15 | 16 | 17 | @ATTENTIONS.register_module() 18 | class SemanticsModulatedAttention(nn.Module): 19 | 20 | def __init__(self, latent_dim, 21 | text_latent_dim, 22 | num_heads, 23 | dropout, 24 | time_embed_dim): 25 | super().__init__() 26 | self.num_heads = num_heads 27 | 28 | self.norm = nn.LayerNorm(latent_dim) 29 | self.text_norm = nn.LayerNorm(text_latent_dim) 30 | 31 | self.query = nn.Linear(latent_dim, latent_dim) 32 | self.key_text = nn.Linear(text_latent_dim, latent_dim) 33 | self.value_text = nn.Linear(text_latent_dim, latent_dim) 34 | self.key_motion = nn.Linear(latent_dim, latent_dim) 35 | self.value_motion = nn.Linear(latent_dim, latent_dim) 36 | 37 | self.retr_norm1 = nn.LayerNorm(2 * latent_dim) 38 | self.retr_norm2 = nn.LayerNorm(latent_dim) 39 | self.key_retr = nn.Linear(2 * latent_dim, latent_dim) 40 | self.value_retr = zero_module(nn.Linear(latent_dim, latent_dim)) 41 | 42 | self.dropout = nn.Dropout(dropout) 43 | self.proj_out = StylizationBlock(latent_dim, time_embed_dim, dropout) 44 | 45 | def forward(self, x, xf, emb, src_mask, cond_type, re_dict=None): 46 | """ 47 | x: B, T, D 48 | xf: B, N, L 49 | """ 50 | B, T, D = x.shape 51 | re_motion = re_dict['re_motion'] 52 | re_text = re_dict['re_text'] 53 | re_mask = re_dict['re_mask'] 54 | re_mask = re_mask.reshape(B, -1, 1) 55 | N = xf.shape[1] + x.shape[1] + re_motion.shape[1] * re_motion.shape[2] 56 | H = self.num_heads 57 | # B, T, D 58 | query = self.query(self.norm(x)) 59 | # B, N, D 60 | text_cond_type = (cond_type % 10 > 0).float() 61 | retr_cond_type = (cond_type // 10 > 0).float() 62 | re_text = re_text.repeat(1, 1, re_motion.shape[2], 1) 63 | re_feat_key = torch.cat((re_motion, re_text), dim=-1).reshape(B, -1, 2 * D) 64 | key = torch.cat(( 65 | self.key_text(self.text_norm(xf)) + (1 - text_cond_type) * -1000000, 66 | self.key_retr(self.retr_norm1(re_feat_key)) + (1 - retr_cond_type) * -1000000 + (1 - re_mask) * -1000000, 67 | self.key_motion(self.norm(x)) + (1 - src_mask) * -1000000 68 | ), dim=1) 69 | query = F.softmax(query.view(B, T, H, -1), dim=-1) 70 | key = F.softmax(key.view(B, N, H, -1), 
dim=1) 71 | # B, N, H, HD 72 | re_feat_value = re_motion.reshape(B, -1, D) 73 | value = torch.cat(( 74 | self.value_text(self.text_norm(xf)) * text_cond_type, 75 | self.value_retr(self.retr_norm2(re_feat_value)) * retr_cond_type * re_mask, 76 | self.value_motion(self.norm(x)) * src_mask, 77 | ), dim=1).view(B, N, H, -1) 78 | # B, H, HD, HD 79 | attention = torch.einsum('bnhd,bnhl->bhdl', key, value) 80 | y = torch.einsum('bnhd,bhdl->bnhl', query, attention).reshape(B, T, D) 81 | y = x + self.proj_out(y, emb) 82 | return y -------------------------------------------------------------------------------- /mogen/models/builder.py: -------------------------------------------------------------------------------- 1 | from mmcv.cnn import MODELS as MMCV_MODELS 2 | from mmcv.utils import Registry 3 | 4 | 5 | def build_from_cfg(cfg, registry, default_args=None): 6 | if cfg is None: 7 | return None 8 | return MMCV_MODELS.build_func(cfg, registry, default_args) 9 | 10 | 11 | MODELS = Registry('models', parent=MMCV_MODELS, build_func=build_from_cfg) 12 | 13 | LOSSES = MODELS 14 | ARCHITECTURES = MODELS 15 | SUBMODULES = MODELS 16 | ATTENTIONS = MODELS 17 | 18 | def build_loss(cfg): 19 | """Build loss.""" 20 | return LOSSES.build(cfg) 21 | 22 | def build_architecture(cfg): 23 | """Build framework.""" 24 | return ARCHITECTURES.build(cfg) 25 | 26 | def build_submodule(cfg): 27 | """Build submodule.""" 28 | return SUBMODULES.build(cfg) 29 | 30 | def build_attention(cfg): 31 | """Build attention.""" 32 | return ATTENTIONS.build(cfg) 33 | -------------------------------------------------------------------------------- /mogen/models/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from .mse_loss import MSELoss 2 | from .utils import ( 3 | convert_to_one_hot, 4 | reduce_loss, 5 | weight_reduce_loss, 6 | weighted_loss, 7 | ) 8 | 9 | 10 | __all__ = [ 11 | 'convert_to_one_hot', 'reduce_loss', 'weight_reduce_loss', 'weighted_loss', 12 | 'MSELoss' 13 | ] -------------------------------------------------------------------------------- /mogen/models/losses/mse_loss.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | from ..builder import LOSSES 5 | from .utils import weighted_loss 6 | 7 | 8 | def gmof(x, sigma): 9 | """Geman-McClure error function.""" 10 | x_squared = x**2 11 | sigma_squared = sigma**2 12 | return (sigma_squared * x_squared) / (sigma_squared + x_squared) 13 | 14 | 15 | @weighted_loss 16 | def mse_loss(pred, target): 17 | """Warpper of mse loss.""" 18 | return F.mse_loss(pred, target, reduction='none') 19 | 20 | 21 | @weighted_loss 22 | def mse_loss_with_gmof(pred, target, sigma): 23 | """Extended MSE Loss with GMOF.""" 24 | loss = F.mse_loss(pred, target, reduction='none') 25 | loss = gmof(loss, sigma) 26 | return loss 27 | 28 | 29 | @LOSSES.register_module() 30 | class MSELoss(nn.Module): 31 | """MSELoss. 32 | Args: 33 | reduction (str, optional): The method that reduces the loss to a 34 | scalar. Options are "none", "mean" and "sum". 35 | loss_weight (float, optional): The weight of the loss. 
Defaults to 1.0 36 | """ 37 | 38 | def __init__(self, reduction='mean', loss_weight=1.0): 39 | super().__init__() 40 | assert reduction in (None, 'none', 'mean', 'sum') 41 | reduction = 'none' if reduction is None else reduction 42 | self.reduction = reduction 43 | self.loss_weight = loss_weight 44 | 45 | def forward(self, 46 | pred, 47 | target, 48 | weight=None, 49 | avg_factor=None, 50 | reduction_override=None): 51 | """Forward function of loss. 52 | Args: 53 | pred (torch.Tensor): The prediction. 54 | target (torch.Tensor): The learning target of the prediction. 55 | weight (torch.Tensor, optional): Weight of the loss for each 56 | prediction. Defaults to None. 57 | avg_factor (int, optional): Average factor that is used to average 58 | the loss. Defaults to None. 59 | reduction_override (str, optional): The reduction method used to 60 | override the original reduction method of the loss. 61 | Defaults to None. 62 | Returns: 63 | torch.Tensor: The calculated loss 64 | """ 65 | assert reduction_override in (None, 'none', 'mean', 'sum') 66 | reduction = ( 67 | reduction_override if reduction_override else self.reduction) 68 | loss = self.loss_weight * mse_loss( 69 | pred, target, weight, reduction=reduction, avg_factor=avg_factor) 70 | return loss -------------------------------------------------------------------------------- /mogen/models/losses/utils.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | 7 | def reduce_loss(loss, reduction): 8 | """Reduce loss as specified. 9 | Args: 10 | loss (Tensor): Elementwise loss tensor. 11 | reduction (str): Options are "none", "mean" and "sum". 12 | Return: 13 | Tensor: Reduced loss tensor. 14 | """ 15 | reduction_enum = F._Reduction.get_enum(reduction) 16 | # none: 0, elementwise_mean:1, sum: 2 17 | if reduction_enum == 0: 18 | return loss 19 | elif reduction_enum == 1: 20 | return loss.mean() 21 | elif reduction_enum == 2: 22 | return loss.sum() 23 | 24 | 25 | def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None): 26 | """Apply element-wise weight and reduce loss. 27 | Args: 28 | loss (Tensor): Element-wise loss. 29 | weight (Tensor): Element-wise weights. 30 | reduction (str): Same as built-in losses of PyTorch. 31 | avg_factor (float): Average factor when computing the mean of losses. 32 | Returns: 33 | Tensor: Processed loss values. 34 | """ 35 | # if weight is specified, apply element-wise weight 36 | if weight is not None: 37 | loss = loss * weight 38 | 39 | # if avg_factor is not specified, just reduce the loss 40 | if avg_factor is None: 41 | loss = reduce_loss(loss, reduction) 42 | else: 43 | # if reduction is mean, then average the loss by avg_factor 44 | if reduction == 'mean': 45 | loss = loss.sum() / avg_factor 46 | # if reduction is 'none', then do nothing, otherwise raise an error 47 | elif reduction != 'none': 48 | raise ValueError('avg_factor can not be used with reduction="sum"') 49 | return loss 50 | 51 | 52 | def weighted_loss(loss_func): 53 | """Create a weighted version of a given loss function. 54 | To use this decorator, the loss function must have the signature like 55 | `loss_func(pred, target, **kwargs)`. The function only needs to compute 56 | element-wise loss without any reduction. This decorator will add weight 57 | and reduction arguments to the function. 
The decorated function will have 58 | the signature like `loss_func(pred, target, weight=None, reduction='mean', 59 | avg_factor=None, **kwargs)`. 60 | :Example: 61 | >>> import torch 62 | >>> @weighted_loss 63 | >>> def l1_loss(pred, target): 64 | >>> return (pred - target).abs() 65 | >>> pred = torch.Tensor([0, 2, 3]) 66 | >>> target = torch.Tensor([1, 1, 1]) 67 | >>> weight = torch.Tensor([1, 0, 1]) 68 | >>> l1_loss(pred, target) 69 | tensor(1.3333) 70 | >>> l1_loss(pred, target, weight) 71 | tensor(1.) 72 | >>> l1_loss(pred, target, reduction='none') 73 | tensor([1., 1., 2.]) 74 | >>> l1_loss(pred, target, weight, avg_factor=2) 75 | tensor(1.5000) 76 | """ 77 | 78 | @functools.wraps(loss_func) 79 | def wrapper(pred, 80 | target, 81 | weight=None, 82 | reduction='mean', 83 | avg_factor=None, 84 | **kwargs): 85 | # get element-wise loss 86 | loss = loss_func(pred, target, **kwargs) 87 | loss = weight_reduce_loss(loss, weight, reduction, avg_factor) 88 | return loss 89 | 90 | return wrapper 91 | 92 | 93 | def convert_to_one_hot(targets: torch.Tensor, classes) -> torch.Tensor: 94 | """This function converts target class indices to one-hot vectors, given 95 | the number of classes. 96 | Args: 97 | targets (Tensor): The ground truth label of the prediction 98 | with shape (N, 1) 99 | classes (int): the number of classes. 100 | Returns: 101 | Tensor: Processed loss values. 102 | """ 103 | assert (torch.max(targets).item() < 104 | classes), 'Class Index must be less than number of classes' 105 | one_hot_targets = torch.zeros((targets.shape[0], classes), 106 | dtype=torch.long, 107 | device=targets.device) 108 | one_hot_targets.scatter_(1, targets.long(), 1) 109 | return one_hot_targets 110 | -------------------------------------------------------------------------------- /mogen/models/rnns/__init__.py: -------------------------------------------------------------------------------- 1 | from .t2m_bigru import T2MMotionEncoder, T2MTextEncoder -------------------------------------------------------------------------------- /mogen/models/transformers/__init__.py: -------------------------------------------------------------------------------- 1 | from .actor import ACTOREncoder, ACTORDecoder 2 | from .motiondiffuse import MotionDiffuseTransformer 3 | from .remodiffuse import ReMoDiffuseTransformer 4 | from .mdm import MDMTransformer 5 | from .position_encoding import SinusoidalPositionalEncoding, LearnedPositionalEncoding -------------------------------------------------------------------------------- /mogen/models/transformers/actor.py: -------------------------------------------------------------------------------- 1 | from cv2 import norm 2 | import torch 3 | from torch import layer_norm, nn 4 | from mmcv.runner import BaseModule 5 | import numpy as np 6 | 7 | from ..builder import SUBMODULES 8 | from .position_encoding import SinusoidalPositionalEncoding, LearnedPositionalEncoding 9 | import math 10 | 11 | 12 | @SUBMODULES.register_module() 13 | class ACTOREncoder(BaseModule): 14 | def __init__(self, 15 | max_seq_len=16, 16 | njoints=None, 17 | nfeats=None, 18 | input_feats=None, 19 | latent_dim=256, 20 | output_dim=256, 21 | condition_dim=None, 22 | num_heads=4, 23 | ff_size=1024, 24 | num_layers=8, 25 | activation='gelu', 26 | dropout=0.1, 27 | use_condition=False, 28 | num_class=None, 29 | use_final_proj=False, 30 | output_var=False, 31 | pos_embedding='sinusoidal', 32 | init_cfg=None): 33 | super().__init__(init_cfg=init_cfg) 34 | self.njoints = njoints 35 | self.nfeats = nfeats 36 | if 
input_feats is None: 37 | assert self.njoints is not None and self.nfeats is not None 38 | self.input_feats = njoints * nfeats 39 | else: 40 | self.input_feats = input_feats 41 | self.max_seq_len = max_seq_len 42 | self.latent_dim = latent_dim 43 | self.condition_dim = condition_dim 44 | self.use_condition = use_condition 45 | self.num_class = num_class 46 | self.use_final_proj = use_final_proj 47 | self.output_var = output_var 48 | self.skelEmbedding = nn.Linear(self.input_feats, self.latent_dim) 49 | if self.use_condition: 50 | if num_class is None: 51 | self.mu_layer = build_MLP(self.condition_dim, self.latent_dim) 52 | if self.output_var: 53 | self.sigma_layer = build_MLP(self.condition_dim, self.latent_dim) 54 | else: 55 | self.mu_layer = nn.Parameter(torch.randn(num_class, self.latent_dim)) 56 | if self.output_var: 57 | self.sigma_layer = nn.Parameter(torch.randn(num_class, self.latent_dim)) 58 | else: 59 | if self.output_var: 60 | self.query = nn.Parameter(torch.randn(2, self.latent_dim)) 61 | else: 62 | self.query = nn.Parameter(torch.randn(1, self.latent_dim)) 63 | if pos_embedding == 'sinusoidal': 64 | self.pos_encoder = SinusoidalPositionalEncoding(latent_dim, dropout) 65 | else: 66 | self.pos_encoder = LearnedPositionalEncoding(latent_dim, dropout, max_len=max_seq_len + 2) 67 | seqTransEncoderLayer = nn.TransformerEncoderLayer( 68 | d_model=self.latent_dim, 69 | nhead=num_heads, 70 | dim_feedforward=ff_size, 71 | dropout=dropout, 72 | activation=activation) 73 | self.seqTransEncoder = nn.TransformerEncoder( 74 | seqTransEncoderLayer, 75 | num_layers=num_layers) 76 | 77 | def forward(self, motion, motion_mask=None, condition=None): 78 | B, T = motion.shape[:2] 79 | motion = motion.view(B, T, -1) 80 | feature = self.skelEmbedding(motion) 81 | if self.use_condition: 82 | if self.output_var: 83 | if self.num_class is None: 84 | sigma_query = self.sigma_layer(condition).view(B, 1, -1) 85 | else: 86 | sigma_query = self.sigma_layer[condition.long()].view(B, 1, -1) 87 | feature = torch.cat((sigma_query, feature), dim=1) 88 | if self.num_class is None: 89 | mu_query = self.mu_layer(condition).view(B, 1, -1) 90 | else: 91 | mu_query = self.mu_layer[condition.long()].view(B, 1, -1) 92 | feature = torch.cat((mu_query, feature), dim=1) 93 | else: 94 | query = self.query.view(1, -1, self.latent_dim).repeat(B, 1, 1) 95 | feature = torch.cat((query, feature), dim=1) 96 | if self.output_var: 97 | motion_mask = torch.cat((torch.zeros(B, 2).to(motion.device), 1 - motion_mask), dim=1).bool() 98 | else: 99 | motion_mask = torch.cat((torch.zeros(B, 1).to(motion.device), 1 - motion_mask), dim=1).bool() 100 | feature = feature.permute(1, 0, 2).contiguous() 101 | feature = self.pos_encoder(feature) 102 | feature = self.seqTransEncoder(feature, src_key_padding_mask=motion_mask) 103 | if self.use_final_proj: 104 | mu = self.final_mu(feature[0]) 105 | if self.output_var: 106 | sigma = self.final_sigma(feature[1]) 107 | return mu, sigma 108 | return mu 109 | else: 110 | if self.output_var: 111 | return feature[0], feature[1] 112 | else: 113 | return feature[0] 114 | 115 | 116 | @SUBMODULES.register_module() 117 | class ACTORDecoder(BaseModule): 118 | 119 | def __init__(self, 120 | max_seq_len=16, 121 | njoints=None, 122 | nfeats=None, 123 | input_feats=None, 124 | input_dim=256, 125 | latent_dim=256, 126 | condition_dim=None, 127 | num_heads=4, 128 | ff_size=1024, 129 | num_layers=8, 130 | activation='gelu', 131 | dropout=0.1, 132 | use_condition=False, 133 | num_class=None, 134 | 
pos_embedding='sinusoidal', 135 | init_cfg=None): 136 | super().__init__(init_cfg=init_cfg) 137 | if input_dim != latent_dim: 138 | self.linear = nn.Linear(input_dim, latent_dim) 139 | else: 140 | self.linear = nn.Identity() 141 | self.njoints = njoints 142 | self.nfeats = nfeats 143 | if input_feats is None: 144 | assert self.njoints is not None and self.nfeats is not None 145 | self.input_feats = njoints * nfeats 146 | else: 147 | self.input_feats = input_feats 148 | self.max_seq_len = max_seq_len 149 | self.input_dim = input_dim 150 | self.latent_dim = latent_dim 151 | self.condition_dim = condition_dim 152 | self.use_condition = use_condition 153 | self.num_class = num_class 154 | if self.use_condition: 155 | if num_class is None: 156 | self.condition_bias = build_MLP(condition_dim, latent_dim) 157 | else: 158 | self.condition_bias = nn.Parameter(torch.randn(num_class, latent_dim)) 159 | if pos_embedding == 'sinusoidal': 160 | self.pos_encoder = SinusoidalPositionalEncoding(latent_dim, dropout) 161 | else: 162 | self.pos_encoder = LearnedPositionalEncoding(latent_dim, dropout, max_len=max_seq_len) 163 | seqTransDecoderLayer = nn.TransformerDecoderLayer( 164 | d_model=self.latent_dim, 165 | nhead=num_heads, 166 | dim_feedforward=ff_size, 167 | dropout=dropout, 168 | activation=activation) 169 | self.seqTransDecoder = nn.TransformerDecoder( 170 | seqTransDecoderLayer, 171 | num_layers=num_layers) 172 | 173 | self.final = nn.Linear(self.latent_dim, self.input_feats) 174 | 175 | def forward(self, input, motion_mask=None, condition=None): 176 | B = input.shape[0] 177 | T = self.max_seq_len 178 | input = self.linear(input) 179 | if self.use_condition: 180 | if self.num_class is None: 181 | condition = self.condition_bias(condition) 182 | else: 183 | condition = self.condition_bias[condition.long()].squeeze(1) 184 | input = input + condition 185 | query = self.pos_encoder.pe[:T, :].view(T, 1, -1).repeat(1, B, 1) 186 | input = input.view(1, B, -1) 187 | feature = self.seqTransDecoder(tgt=query, memory=input, tgt_key_padding_mask=(1 - motion_mask).bool()) 188 | pose = self.final(feature).permute(1, 0, 2).contiguous() 189 | return pose 190 | -------------------------------------------------------------------------------- /mogen/models/transformers/mdm.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import clip 6 | 7 | from ..builder import SUBMODULES 8 | 9 | 10 | def convert_weights(model: nn.Module): 11 | """Convert applicable model parameters to fp32""" 12 | 13 | def _convert_weights_to_fp32(l): 14 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): 15 | l.weight.data = l.weight.data.float() 16 | if l.bias is not None: 17 | l.bias.data = l.bias.data.float() 18 | 19 | if isinstance(l, nn.MultiheadAttention): 20 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: 21 | tensor = getattr(l, attr) 22 | if tensor is not None: 23 | tensor.data = tensor.data.float() 24 | 25 | for name in ["text_projection", "proj"]: 26 | if hasattr(l, name): 27 | attr = getattr(l, name) 28 | if attr is not None: 29 | attr.data = attr.data.float() 30 | 31 | model.apply(_convert_weights_to_fp32) 32 | 33 | 34 | @SUBMODULES.register_module() 35 | class MDMTransformer(nn.Module): 36 | def __init__(self, 37 | input_feats=263, 38 | latent_dim=256, 39 | ff_size=1024, 40 | num_layers=8, 41 | num_heads=4, 42 | dropout=0.1, 43 | 
activation="gelu", 44 | clip_dim=512, 45 | clip_version=None, 46 | guide_scale=1.0, 47 | cond_mask_prob=0.1, 48 | use_official_ckpt=False, 49 | **kwargs): 50 | super().__init__() 51 | 52 | self.latent_dim = latent_dim 53 | self.ff_size = ff_size 54 | self.num_layers = num_layers 55 | self.num_heads = num_heads 56 | self.dropout = dropout 57 | self.activation = activation 58 | self.clip_dim = clip_dim 59 | self.input_feats = input_feats 60 | self.guide_scale = guide_scale 61 | self.use_official_ckpt = use_official_ckpt 62 | 63 | self.cond_mask_prob = cond_mask_prob 64 | self.poseEmbedding = nn.Linear(self.input_feats, self.latent_dim) 65 | self.sequence_pos_encoder = PositionalEncoding(self.latent_dim, self.dropout) 66 | 67 | seqTransEncoderLayer = nn.TransformerEncoderLayer(d_model=self.latent_dim, 68 | nhead=self.num_heads, 69 | dim_feedforward=self.ff_size, 70 | dropout=self.dropout, 71 | activation=self.activation) 72 | 73 | self.seqTransEncoder = nn.TransformerEncoder(seqTransEncoderLayer, 74 | num_layers=self.num_layers) 75 | 76 | self.embed_timestep = TimestepEmbedder(self.latent_dim, self.sequence_pos_encoder) 77 | 78 | 79 | self.embed_text = nn.Linear(self.clip_dim, self.latent_dim) 80 | self.clip_version = clip_version 81 | self.clip_model = self.load_and_freeze_clip(clip_version) 82 | 83 | self.poseFinal = nn.Linear(self.latent_dim, self.input_feats) 84 | 85 | 86 | def load_and_freeze_clip(self, clip_version): 87 | clip_model, clip_preprocess = clip.load(clip_version, device='cpu', 88 | jit=False) # Must set jit=False for training 89 | clip.model.convert_weights( 90 | clip_model) # Actually this line is unnecessary since clip by default already on float16 91 | 92 | clip_model.eval() 93 | for p in clip_model.parameters(): 94 | p.requires_grad = False 95 | 96 | return clip_model 97 | 98 | 99 | def mask_cond(self, cond, force_mask=False): 100 | bs, d = cond.shape 101 | if force_mask: 102 | return torch.zeros_like(cond) 103 | elif self.training and self.cond_mask_prob > 0.: 104 | mask = torch.bernoulli(torch.ones(bs, device=cond.device) * self.cond_mask_prob).view(bs, 1) # 1-> use null_cond, 0-> use real cond 105 | return cond * (1. 
- mask) 106 | else: 107 | return cond 108 | 109 | def encode_text(self, raw_text): 110 | # raw_text - list (batch_size length) of strings with input text prompts 111 | device = next(self.parameters()).device 112 | max_text_len = 20 113 | if max_text_len is not None: 114 | default_context_length = 77 115 | context_length = max_text_len + 2 # start_token + 20 + end_token 116 | assert context_length < default_context_length 117 | texts = clip.tokenize(raw_text, context_length=context_length, truncate=True).to(device) 118 | zero_pad = torch.zeros([texts.shape[0], default_context_length-context_length], dtype=texts.dtype, device=texts.device) 119 | texts = torch.cat([texts, zero_pad], dim=1) 120 | return self.clip_model.encode_text(texts).float() 121 | 122 | def get_precompute_condition(self, text, device=None, **kwargs): 123 | if not self.training and device == torch.device('cpu'): 124 | convert_weights(self.clip_model) 125 | text_feat = self.encode_text(text) 126 | return {'text_feat': text_feat} 127 | 128 | def post_process(self, motion): 129 | assert len(motion.shape) == 3 130 | if self.use_official_ckpt: 131 | motion[:, :, :4] = motion[:, :, :4] * 25 132 | return motion 133 | 134 | def forward(self, motion, timesteps, text_feat=None, **kwargs): 135 | """ 136 | motion: B, T, D 137 | timesteps: [batch_size] (int) 138 | """ 139 | B, T, D = motion.shape 140 | device = motion.device 141 | if text_feat is None: 142 | enc_text = self.get_precompute_condition(**kwargs)['text_feat'] 143 | else: 144 | enc_text = text_feat 145 | if self.training: 146 | # T, B, D 147 | motion = self.poseEmbedding(motion).permute(1, 0, 2) 148 | 149 | emb = self.embed_timestep(timesteps) # [1, bs, d] 150 | emb += self.embed_text(self.mask_cond(enc_text, force_mask=False)) 151 | 152 | xseq = self.sequence_pos_encoder(torch.cat((emb, motion), axis=0)) 153 | output = self.seqTransEncoder(xseq)[1:] 154 | 155 | # B, T, D 156 | output = self.poseFinal(output).permute(1, 0, 2) 157 | return output 158 | else: 159 | # T, B, D 160 | motion = self.poseEmbedding(motion).permute(1, 0, 2) 161 | 162 | emb = self.embed_timestep(timesteps) # [1, bs, d] 163 | emb_uncond = emb + self.embed_text(self.mask_cond(enc_text, force_mask=True)) 164 | emb_text = emb + self.embed_text(self.mask_cond(enc_text, force_mask=False)) 165 | 166 | xseq = self.sequence_pos_encoder(torch.cat((emb_uncond, motion), axis=0)) 167 | xseq_text = self.sequence_pos_encoder(torch.cat((emb_text, motion), axis=0)) 168 | output = self.seqTransEncoder(xseq)[1:] 169 | output_text = self.seqTransEncoder(xseq_text)[1:] 170 | # B, T, D 171 | output = self.poseFinal(output).permute(1, 0, 2) 172 | output_text = self.poseFinal(output_text).permute(1, 0, 2) 173 | scale = self.guide_scale 174 | output = output + scale * (output_text - output) 175 | return output 176 | 177 | 178 | class PositionalEncoding(nn.Module): 179 | def __init__(self, d_model, dropout=0.1, max_len=5000): 180 | super(PositionalEncoding, self).__init__() 181 | self.dropout = nn.Dropout(p=dropout) 182 | 183 | pe = torch.zeros(max_len, d_model) 184 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) 185 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model)) 186 | pe[:, 0::2] = torch.sin(position * div_term) 187 | pe[:, 1::2] = torch.cos(position * div_term) 188 | pe = pe.unsqueeze(0).transpose(0, 1) 189 | 190 | self.register_buffer('pe', pe) 191 | 192 | def forward(self, x): 193 | # not used in the final model 194 | x = x + self.pe[:x.shape[0], :] 195 | return
self.dropout(x) 196 | 197 | 198 | class TimestepEmbedder(nn.Module): 199 | def __init__(self, latent_dim, sequence_pos_encoder): 200 | super().__init__() 201 | self.latent_dim = latent_dim 202 | self.sequence_pos_encoder = sequence_pos_encoder 203 | 204 | time_embed_dim = self.latent_dim 205 | self.time_embed = nn.Sequential( 206 | nn.Linear(self.latent_dim, time_embed_dim), 207 | nn.SiLU(), 208 | nn.Linear(time_embed_dim, time_embed_dim), 209 | ) 210 | 211 | def forward(self, timesteps): 212 | return self.time_embed(self.sequence_pos_encoder.pe[timesteps]).permute(1, 0, 2) 213 | -------------------------------------------------------------------------------- /mogen/models/transformers/motiondiffuse.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | from ..builder import SUBMODULES 7 | from .diffusion_transformer import DiffusionTransformer 8 | 9 | 10 | @SUBMODULES.register_module() 11 | class MotionDiffuseTransformer(DiffusionTransformer): 12 | def __init__(self, **kwargs): 13 | super().__init__(**kwargs) 14 | 15 | def get_precompute_condition(self, 16 | text=None, 17 | xf_proj=None, 18 | xf_out=None, 19 | device=None, 20 | clip_feat=None, 21 | **kwargs): 22 | if xf_proj is None or xf_out is None: 23 | xf_proj, xf_out = self.encode_text(text, clip_feat, device) 24 | return {'xf_proj': xf_proj, 'xf_out': xf_out} 25 | 26 | def post_process(self, motion): 27 | return motion 28 | 29 | def forward_train(self, h=None, src_mask=None, emb=None, xf_out=None, **kwargs): 30 | B, T = h.shape[0], h.shape[1] 31 | for module in self.temporal_decoder_blocks: 32 | h = module(x=h, xf=xf_out, emb=emb, src_mask=src_mask) 33 | output = self.out(h).view(B, T, -1).contiguous() 34 | return output 35 | 36 | def forward_test(self, h=None, src_mask=None, emb=None, xf_out=None, **kwargs): 37 | B, T = h.shape[0], h.shape[1] 38 | for module in self.temporal_decoder_blocks: 39 | h = module(x=h, xf=xf_out, emb=emb, src_mask=src_mask) 40 | output = self.out(h).view(B, T, -1).contiguous() 41 | return output 42 | -------------------------------------------------------------------------------- /mogen/models/transformers/position_encoding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | 6 | class SinusoidalPositionalEncoding(nn.Module): 7 | def __init__(self, d_model, dropout=0.1, max_len=5000): 8 | super(SinusoidalPositionalEncoding, self).__init__() 9 | self.dropout = nn.Dropout(p=dropout) 10 | 11 | pe = torch.zeros(max_len, d_model) 12 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) 13 | div_term = torch.arange(0, d_model, 2).float() 14 | div_term = div_term * (-np.log(10000.0) / d_model) 15 | div_term = torch.exp(div_term) 16 | pe[:, 0::2] = torch.sin(position * div_term) 17 | pe[:, 1::2] = torch.cos(position * div_term) 18 | pe = pe.unsqueeze(0).transpose(0, 1) 19 | # T, 1, D 20 | self.register_buffer('pe', pe) 21 | 22 | def forward(self, x): 23 | x = x + self.pe[:x.shape[0]] 24 | return self.dropout(x) 25 | 26 | 27 | class LearnedPositionalEncoding(nn.Module): 28 | def __init__(self, d_model, dropout=0.1, max_len=5000): 29 | super(LearnedPositionalEncoding, self).__init__() 30 | self.dropout = nn.Dropout(p=dropout) 31 | self.pe = nn.Parameter(torch.randn(max_len, 1, d_model)) 32 | 33 | def forward(self, x): 34 | x = x + self.pe[:x.shape[0]] 35 | 
return self.dropout(x) 36 | -------------------------------------------------------------------------------- /mogen/models/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mingyuan-zhang/ReMoDiffuse/d81c83ddbf72a989dc334cd48cb7d46fb6feba63/mogen/models/utils/__init__.py -------------------------------------------------------------------------------- /mogen/models/utils/mlp.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | def build_MLP(dim_list, latent_dim): 5 | model_list = [] 6 | prev = dim_list[0] 7 | for cur in dim_list[1:]: 8 | model_list.append(nn.Linear(prev, cur)) 9 | model_list.append(nn.GELU()) 10 | prev = cur 11 | model_list.append(nn.Linear(prev, latent_dim)) 12 | model = nn.Sequential(*model_list) 13 | return model 14 | -------------------------------------------------------------------------------- /mogen/models/utils/stylization_block.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def zero_module(module): 6 | """ 7 | Zero out the parameters of a module and return it. 8 | """ 9 | for p in module.parameters(): 10 | p.detach().zero_() 11 | return module 12 | 13 | 14 | class StylizationBlock(nn.Module): 15 | 16 | def __init__(self, latent_dim, time_embed_dim, dropout): 17 | super().__init__() 18 | self.emb_layers = nn.Sequential( 19 | nn.SiLU(), 20 | nn.Linear(time_embed_dim, 2 * latent_dim), 21 | ) 22 | self.norm = nn.LayerNorm(latent_dim) 23 | self.out_layers = nn.Sequential( 24 | nn.SiLU(), 25 | nn.Dropout(p=dropout), 26 | zero_module(nn.Linear(latent_dim, latent_dim)), 27 | ) 28 | 29 | def forward(self, h, emb): 30 | """ 31 | h: B, T, D 32 | emb: B, D 33 | """ 34 | # B, 1, 2D 35 | emb_out = self.emb_layers(emb).unsqueeze(1) 36 | # scale: B, 1, D / shift: B, 1, D 37 | scale, shift = torch.chunk(emb_out, 2, dim=2) 38 | h = self.norm(h) * (1 + scale) + shift 39 | h = self.out_layers(h) 40 | return h 41 | -------------------------------------------------------------------------------- /mogen/models/utils/word_vectorizer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pickle 3 | from os.path import join as pjoin 4 | 5 | POS_enumerator = { 6 | 'VERB': 0, 7 | 'NOUN': 1, 8 | 'DET': 2, 9 | 'ADP': 3, 10 | 'NUM': 4, 11 | 'AUX': 5, 12 | 'PRON': 6, 13 | 'ADJ': 7, 14 | 'ADV': 8, 15 | 'Loc_VIP': 9, 16 | 'Body_VIP': 10, 17 | 'Obj_VIP': 11, 18 | 'Act_VIP': 12, 19 | 'Desc_VIP': 13, 20 | 'OTHER': 14, 21 | } 22 | 23 | Loc_list = ('left', 'right', 'clockwise', 'counterclockwise', 'anticlockwise', 'forward', 'back', 'backward', 24 | 'up', 'down', 'straight', 'curve') 25 | 26 | Body_list = ('arm', 'chin', 'foot', 'feet', 'face', 'hand', 'mouth', 'leg', 'waist', 'eye', 'knee', 'shoulder', 'thigh') 27 | 28 | Obj_List = ('stair', 'dumbbell', 'chair', 'window', 'floor', 'car', 'ball', 'handrail', 'baseball', 'basketball') 29 | 30 | Act_list = ('walk', 'run', 'swing', 'pick', 'bring', 'kick', 'put', 'squat', 'throw', 'hop', 'dance', 'jump', 'turn', 31 | 'stumble', 'dance', 'stop', 'sit', 'lift', 'lower', 'raise', 'wash', 'stand', 'kneel', 'stroll', 32 | 'rub', 'bend', 'balance', 'flap', 'jog', 'shuffle', 'lean', 'rotate', 'spin', 'spread', 'climb') 33 | 34 | Desc_list = ('slowly', 'carefully', 'fast', 'careful', 'slow', 'quickly', 'happy', 'angry', 'sad', 'happily', 35 | 'angrily', 
'sadly') 36 | 37 | VIP_dict = { 38 | 'Loc_VIP': Loc_list, 39 | 'Body_VIP': Body_list, 40 | 'Obj_VIP': Obj_List, 41 | 'Act_VIP': Act_list, 42 | 'Desc_VIP': Desc_list, 43 | } 44 | 45 | 46 | class WordVectorizer(object): 47 | def __init__(self, meta_root, prefix): 48 | vectors = np.load(pjoin(meta_root, '%s_data.npy'%prefix)) 49 | words = pickle.load(open(pjoin(meta_root, '%s_words.pkl'%prefix), 'rb')) 50 | word2idx = pickle.load(open(pjoin(meta_root, '%s_idx.pkl'%prefix), 'rb')) 51 | self.word2vec = {w: vectors[word2idx[w]] for w in words} 52 | 53 | def _get_pos_ohot(self, pos): 54 | pos_vec = np.zeros(len(POS_enumerator)) 55 | if pos in POS_enumerator: 56 | pos_vec[POS_enumerator[pos]] = 1 57 | else: 58 | pos_vec[POS_enumerator['OTHER']] = 1 59 | return pos_vec 60 | 61 | def __len__(self): 62 | return len(self.word2vec) 63 | 64 | def __getitem__(self, item): 65 | word, pos = item.split('/') 66 | if word in self.word2vec: 67 | word_vec = self.word2vec[word] 68 | vip_pos = None 69 | for key, values in VIP_dict.items(): 70 | if word in values: 71 | vip_pos = key 72 | break 73 | if vip_pos is not None: 74 | pos_vec = self._get_pos_ohot(vip_pos) 75 | else: 76 | pos_vec = self._get_pos_ohot(pos) 77 | else: 78 | word_vec = self.word2vec['unk'] 79 | pos_vec = self._get_pos_ohot('OTHER') 80 | return word_vec, pos_vec -------------------------------------------------------------------------------- /mogen/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from mogen.utils.collect_env import collect_env 2 | from mogen.utils.dist_utils import DistOptimizerHook, allreduce_grads 3 | from mogen.utils.logger import get_root_logger 4 | from mogen.utils.misc import multi_apply, torch_to_numpy 5 | from mogen.utils.path_utils import ( 6 | Existence, 7 | check_input_path, 8 | check_path_existence, 9 | check_path_suffix, 10 | prepare_output_path, 11 | ) 12 | 13 | 14 | __all__ = [ 15 | 'collect_env', 'DistOptimizerHook', 'allreduce_grads', 'get_root_logger', 16 | 'multi_apply', 'torch_to_numpy', 'Existence', 'check_input_path', 17 | 'check_path_existence', 'check_path_suffix', 'prepare_output_path' 18 | ] -------------------------------------------------------------------------------- /mogen/utils/collect_env.py: -------------------------------------------------------------------------------- 1 | from mmcv.utils import collect_env as collect_base_env 2 | from mmcv.utils import get_git_hash 3 | 4 | import mogen 5 | 6 | 7 | def collect_env(): 8 | """Collect the information of the running environments.""" 9 | env_info = collect_base_env() 10 | env_info['mogen'] = mogen.__version__ + '+' + get_git_hash()[:7] 11 | return env_info 12 | 13 | 14 | if __name__ == '__main__': 15 | for name, val in collect_env().items(): 16 | print(f'{name}: {val}') -------------------------------------------------------------------------------- /mogen/utils/dist_utils.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch.distributed as dist 4 | from mmcv.runner import OptimizerHook 5 | from torch._utils import ( 6 | _flatten_dense_tensors, 7 | _take_tensors, 8 | _unflatten_dense_tensors, 9 | ) 10 | 11 | 12 | def _allreduce_coalesced(tensors, world_size, bucket_size_mb=-1): 13 | if bucket_size_mb > 0: 14 | bucket_size_bytes = bucket_size_mb * 1024 * 1024 15 | buckets = _take_tensors(tensors, bucket_size_bytes) 16 | else: 17 | buckets = OrderedDict() 18 | for tensor in tensors: 19 | tp = 
tensor.type() 20 | if tp not in buckets: 21 | buckets[tp] = [] 22 | buckets[tp].append(tensor) 23 | buckets = buckets.values() 24 | 25 | for bucket in buckets: 26 | flat_tensors = _flatten_dense_tensors(bucket) 27 | dist.all_reduce(flat_tensors) 28 | flat_tensors.div_(world_size) 29 | for tensor, synced in zip( 30 | bucket, _unflatten_dense_tensors(flat_tensors, bucket)): 31 | tensor.copy_(synced) 32 | 33 | 34 | def allreduce_grads(params, coalesce=True, bucket_size_mb=-1): 35 | grads = [ 36 | param.grad.data for param in params 37 | if param.requires_grad and param.grad is not None 38 | ] 39 | world_size = dist.get_world_size() 40 | if coalesce: 41 | _allreduce_coalesced(grads, world_size, bucket_size_mb) 42 | else: 43 | for tensor in grads: 44 | dist.all_reduce(tensor.div_(world_size)) 45 | 46 | 47 | class DistOptimizerHook(OptimizerHook): 48 | 49 | def __init__(self, grad_clip=None, coalesce=True, bucket_size_mb=-1): 50 | self.grad_clip = grad_clip 51 | self.coalesce = coalesce 52 | self.bucket_size_mb = bucket_size_mb 53 | 54 | def after_train_iter(self, runner): 55 | runner.optimizer.zero_grad() 56 | runner.outputs['loss'].backward() 57 | if self.grad_clip is not None: 58 | self.clip_grads(runner.model.parameters()) 59 | runner.optimizer.step() -------------------------------------------------------------------------------- /mogen/utils/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from mmcv.utils import get_logger 4 | 5 | 6 | def get_root_logger(log_file=None, log_level=logging.INFO): 7 | return get_logger('mogen', log_file, log_level) -------------------------------------------------------------------------------- /mogen/utils/misc.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import torch 4 | 5 | 6 | def multi_apply(func, *args, **kwargs): 7 | pfunc = partial(func, **kwargs) if kwargs else func 8 | map_results = map(pfunc, *args) 9 | return tuple(map(list, zip(*map_results))) 10 | 11 | 12 | def torch_to_numpy(x): 13 | assert isinstance(x, torch.Tensor) 14 | return x.detach().cpu().numpy() -------------------------------------------------------------------------------- /mogen/utils/path_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import warnings 3 | from enum import Enum 4 | from pathlib import Path 5 | from typing import List, Union 6 | 7 | try: 8 | from typing import Literal 9 | except ImportError: 10 | from typing_extensions import Literal 11 | 12 | 13 | def check_path_suffix(path_str: str, 14 | allowed_suffix: Union[str, List[str]] = '') -> bool: 15 | """Check whether the suffix of the path is allowed. 16 | 17 | Args: 18 | path_str (str): 19 | Path to check. 20 | allowed_suffix (List[str], optional): 21 | What extension names are allowed. 22 | Offer a list like ['.jpg', ',jpeg']. 23 | When it's [], all will be received. 24 | Use [''] then directory is allowed. 25 | Defaults to []. 
26 | 27 | Returns: 28 | bool: 29 | True: suffix test passed 30 | False: suffix test failed 31 | """ 32 | if isinstance(allowed_suffix, str): 33 | allowed_suffix = [allowed_suffix] 34 | pathinfo = Path(path_str) 35 | suffix = pathinfo.suffix.lower() 36 | if len(allowed_suffix) == 0: 37 | return True 38 | if pathinfo.is_dir(): 39 | if '' in allowed_suffix: 40 | return True 41 | else: 42 | return False 43 | else: 44 | for index, tmp_suffix in enumerate(allowed_suffix): 45 | if not tmp_suffix.startswith('.'): 46 | tmp_suffix = '.' + tmp_suffix 47 | allowed_suffix[index] = tmp_suffix.lower() 48 | if suffix in allowed_suffix: 49 | return True 50 | else: 51 | return False 52 | 53 | 54 | class Existence(Enum): 55 | """State of file existence.""" 56 | FileExist = 0 57 | DirectoryExistEmpty = 1 58 | DirectoryExistNotEmpty = 2 59 | MissingParent = 3 60 | DirectoryNotExist = 4 61 | FileNotExist = 5 62 | 63 | 64 | def check_path_existence( 65 | path_str: str, 66 | path_type: Literal['file', 'dir', 'auto'] = 'auto', 67 | ) -> Existence: 68 | """Check whether a file or a directory exists at the expected path. 69 | 70 | Args: 71 | path_str (str): 72 | Path to check. 73 | path_type (Literal['file', 'dir', 'auto'], optional): 74 | What kind of file do we expect at the path. 75 | Choose among `file`, `dir`, `auto`. 76 | Defaults to 'auto'. 77 | 78 | Raises: 79 | KeyError: if `path_type` conflicts with `path_str` 80 | 81 | Returns: 82 | Existence: 83 | 0. FileExist: file at path_str exists. 84 | 1. DirectoryExistEmpty: folder at path_str exists and is empty. 85 | 2. DirectoryExistNotEmpty: folder at path_str exists and is not empty. 86 | 3. MissingParent: its parent doesn't exist. 87 | 4. DirectoryNotExist: expect a folder at path_str, but not found. 88 | 5. FileNotExist: expect a file at path_str, but not found. 89 | """ 90 | path_type = path_type.lower() 91 | assert path_type in {'file', 'dir', 'auto'} 92 | pathinfo = Path(path_str) 93 | if not pathinfo.parent.is_dir(): 94 | return Existence.MissingParent 95 | suffix = pathinfo.suffix.lower() 96 | if path_type == 'dir' or\ 97 | path_type == 'auto' and suffix == '': 98 | if pathinfo.is_dir(): 99 | if len(os.listdir(path_str)) == 0: 100 | return Existence.DirectoryExistEmpty 101 | else: 102 | return Existence.DirectoryExistNotEmpty 103 | else: 104 | return Existence.DirectoryNotExist 105 | elif path_type == 'file' or\ 106 | path_type == 'auto' and suffix != '': 107 | if pathinfo.is_file(): 108 | return Existence.FileExist 109 | elif pathinfo.is_dir(): 110 | if len(os.listdir(path_str)) == 0: 111 | return Existence.DirectoryExistEmpty 112 | else: 113 | return Existence.DirectoryExistNotEmpty 114 | if path_str.endswith('/'): 115 | return Existence.DirectoryNotExist 116 | else: 117 | return Existence.FileNotExist 118 | 119 | 120 | def prepare_output_path(output_path: str, 121 | allowed_suffix: List[str] = [], 122 | tag: str = 'output file', 123 | path_type: Literal['file', 'dir', 'auto'] = 'auto', 124 | overwrite: bool = True) -> None: 125 | """Check output folder or file. 126 | 127 | Args: 128 | output_path (str): could be folder or file. 129 | allowed_suffix (List[str], optional): 130 | Check the suffix of `output_path`. If folder, should be [] or ['']. 131 | If it could be either a folder or a file, should be [suffixes..., '']. 132 | Defaults to []. 133 | tag (str, optional): The `string` tag to specify the output type. 134 | Defaults to 'output file'. 135 | path_type (Literal['file', 'dir', 'auto'], optional): 136 | Choose `file` for file and `dir` for folder.
137 | Choose `auto` if allowed to be both. 138 | Defaults to 'auto'. 139 | overwrite (bool, optional): 140 | Whether to overwrite the existing file or folder. 141 | Defaults to True. 142 | 143 | Raises: 144 | FileNotFoundError: suffix does not match. 145 | FileExistsError: file or folder already exists and `overwrite` is 146 | False. 147 | 148 | Returns: 149 | None 150 | """ 151 | if path_type.lower() == 'dir': 152 | allowed_suffix = [] 153 | exist_result = check_path_existence(output_path, path_type=path_type) 154 | if exist_result == Existence.MissingParent: 155 | warnings.warn( 156 | f'The parent folder of {tag} does not exist: {output_path},' + 157 | f' will make dir {Path(output_path).parent.absolute().__str__()}') 158 | os.makedirs( 159 | Path(output_path).parent.absolute().__str__(), exist_ok=True) 160 | 161 | elif exist_result == Existence.DirectoryNotExist: 162 | os.mkdir(output_path) 163 | print(f'Making directory {output_path} for saving results.') 164 | elif exist_result == Existence.FileNotExist: 165 | suffix_matched = \ 166 | check_path_suffix(output_path, allowed_suffix=allowed_suffix) 167 | if not suffix_matched: 168 | raise FileNotFoundError( 169 | f'The {tag} should be {", ".join(allowed_suffix)}: ' 170 | f'{output_path}.') 171 | elif exist_result == Existence.FileExist: 172 | if not overwrite: 173 | raise FileExistsError( 174 | f'{output_path} exists (set overwrite = True to overwrite).') 175 | else: 176 | print(f'Overwriting {output_path}.') 177 | elif exist_result == Existence.DirectoryExistEmpty: 178 | pass 179 | elif exist_result == Existence.DirectoryExistNotEmpty: 180 | if not overwrite: 181 | raise FileExistsError( 182 | f'{output_path} is not empty (set overwrite = ' 183 | 'True to overwrite the files).') 184 | else: 185 | print(f'Overwriting {output_path} and its files.') 186 | else: 187 | raise FileNotFoundError(f'No Existence type for {output_path}.') 188 | 189 | 190 | def check_input_path( 191 | input_path: str, 192 | allowed_suffix: List[str] = [], 193 | tag: str = 'input file', 194 | path_type: Literal['file', 'dir', 'auto'] = 'auto', 195 | ): 196 | """Check input folder or file. 197 | 198 | Args: 199 | input_path (str): input folder or file path. 200 | allowed_suffix (List[str], optional): 201 | Check the suffix of `input_path`. If folder, should be [] or ['']. 202 | If it could be either a folder or a file, should be [suffixes..., '']. 203 | Defaults to []. 204 | tag (str, optional): The `string` tag to specify the input type. 205 | Defaults to 'input file'. 206 | path_type (Literal['file', 'dir', 'auto'], optional): 207 | Choose `file` for file and `dir` for folder. 208 | Choose `auto` if allowed to be both. 209 | Defaults to 'auto'. 210 | 211 | Raises: 212 | FileNotFoundError: file does not exist or suffix does not match.
213 | 214 | Returns: 215 | None 216 | """ 217 | if path_type.lower() == 'dir': 218 | allowed_suffix = [] 219 | exist_result = check_path_existence(input_path, path_type=path_type) 220 | 221 | if exist_result in [ 222 | Existence.FileExist, Existence.DirectoryExistEmpty, 223 | Existence.DirectoryExistNotEmpty 224 | ]: 225 | suffix_matched = \ 226 | check_path_suffix(input_path, allowed_suffix=allowed_suffix) 227 | if not suffix_matched: 228 | raise FileNotFoundError( 229 | f'The {tag} should be {", ".join(allowed_suffix)}: ' + 230 | f'{input_path}.') 231 | else: 232 | raise FileNotFoundError(f'The {tag} does not exist: {input_path}.') -------------------------------------------------------------------------------- /mogen/utils/plot_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code is borrowed from https://github.com/EricGuo5513/text-to-motion 3 | """ 4 | 5 | import torch 6 | import numpy as np 7 | 8 | import math 9 | import matplotlib 10 | import matplotlib.pyplot as plt 11 | from mpl_toolkits.mplot3d import Axes3D 12 | from matplotlib.animation import FuncAnimation, FFMpegFileWriter 13 | from mpl_toolkits.mplot3d.art3d import Poly3DCollection 14 | import mpl_toolkits.mplot3d.axes3d as p3 15 | 16 | # Define a kinematic tree for the skeletal structure 17 | kit_kinematic_chain = [[0, 11, 12, 13, 14, 15], [0, 16, 17, 18, 19, 20], [0, 1, 2, 3, 4], [3, 5, 6, 7], [3, 8, 9, 10]] 18 | t2m_kinematic_chain = [[0, 2, 5, 8, 11], [0, 1, 4, 7, 10], [0, 3, 6, 9, 12, 15], [9, 14, 17, 19, 21], [9, 13, 16, 18, 20]] 19 | t2m_left_hand_chain = [[20, 22, 23, 24], [20, 34, 35, 36], [20, 25, 26, 27], [20, 31, 32, 33], [20, 28, 29, 30]] 20 | t2m_right_hand_chain = [[21, 43, 44, 45], [21, 46, 47, 48], [21, 40, 41, 42], [21, 37, 38, 39], [21, 49, 50, 51]] 21 | 22 | 23 | def qinv(q): 24 | assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)' 25 | mask = torch.ones_like(q) 26 | mask[..., 1:] = -mask[..., 1:] 27 | return q * mask 28 | 29 | 30 | def qrot(q, v): 31 | """ 32 | Rotate vector(s) v about the rotation described by quaternion(s) q. 33 | Expects a tensor of shape (*, 4) for q and a tensor of shape (*, 3) for v, 34 | where * denotes any number of dimensions. 35 | Returns a tensor of shape (*, 3).
36 | """ 37 | assert q.shape[-1] == 4 38 | assert v.shape[-1] == 3 39 | assert q.shape[:-1] == v.shape[:-1] 40 | 41 | original_shape = list(v.shape) 42 | # print(q.shape) 43 | q = q.contiguous().view(-1, 4) 44 | v = v.contiguous().view(-1, 3) 45 | 46 | qvec = q[:, 1:] 47 | uv = torch.cross(qvec, v, dim=1) 48 | uuv = torch.cross(qvec, uv, dim=1) 49 | return (v + 2 * (q[:, :1] * uv + uuv)).view(original_shape) 50 | 51 | 52 | def recover_root_rot_pos(data): 53 | rot_vel = data[..., 0] 54 | r_rot_ang = torch.zeros_like(rot_vel).to(data.device) 55 | '''Get Y-axis rotation from rotation velocity''' 56 | r_rot_ang[..., 1:] = rot_vel[..., :-1] 57 | r_rot_ang = torch.cumsum(r_rot_ang, dim=-1) 58 | 59 | r_rot_quat = torch.zeros(data.shape[:-1] + (4,)).to(data.device) 60 | r_rot_quat[..., 0] = torch.cos(r_rot_ang) 61 | r_rot_quat[..., 2] = torch.sin(r_rot_ang) 62 | 63 | r_pos = torch.zeros(data.shape[:-1] + (3,)).to(data.device) 64 | r_pos[..., 1:, [0, 2]] = data[..., :-1, 1:3] 65 | '''Add Y-axis rotation to root position''' 66 | r_pos = qrot(qinv(r_rot_quat), r_pos) 67 | 68 | r_pos = torch.cumsum(r_pos, dim=-2) 69 | 70 | r_pos[..., 1] = data[..., 3] 71 | return r_rot_quat, r_pos 72 | 73 | 74 | def recover_from_ric(data, joints_num): 75 | r_rot_quat, r_pos = recover_root_rot_pos(data) 76 | positions = data[..., 4:(joints_num - 1) * 3 + 4] 77 | positions = positions.view(positions.shape[:-1] + (-1, 3)) 78 | 79 | '''Add Y-axis rotation to local joints''' 80 | positions = qrot(qinv(r_rot_quat[..., None, :]).expand(positions.shape[:-1] + (4,)), positions) 81 | 82 | '''Add root XZ to joints''' 83 | positions[..., 0] += r_pos[..., 0:1] 84 | positions[..., 2] += r_pos[..., 2:3] 85 | 86 | '''Concate root and joints''' 87 | positions = torch.cat([r_pos.unsqueeze(-2), positions], dim=-2) 88 | 89 | return positions 90 | 91 | 92 | def plot_3d_motion(save_path, kinematic_tree, joints, title, figsize=(10, 10), fps=120, radius=4): 93 | matplotlib.use('Agg') 94 | 95 | title_sp = title.split(' ') 96 | if len(title_sp) > 20: 97 | title = '\n'.join([' '.join(title_sp[:10]), ' '.join(title_sp[10:20]), ' '.join(title_sp[20:])]) 98 | elif len(title_sp) > 10: 99 | title = '\n'.join([' '.join(title_sp[:10]), ' '.join(title_sp[10:])]) 100 | 101 | def init(): 102 | ax.set_xlim3d([-radius / 4, radius / 4]) 103 | ax.set_ylim3d([0, radius / 2]) 104 | ax.set_zlim3d([0, radius / 2]) 105 | fig.suptitle(title, fontsize=20) 106 | ax.grid(b=False) 107 | 108 | def plot_xzPlane(minx, maxx, miny, minz, maxz): 109 | verts = [ 110 | [minx, miny, minz], 111 | [minx, miny, maxz], 112 | [maxx, miny, maxz], 113 | [maxx, miny, minz] 114 | ] 115 | xz_plane = Poly3DCollection([verts]) 116 | xz_plane.set_facecolor((0.5, 0.5, 0.5, 0.5)) 117 | ax.add_collection3d(xz_plane) 118 | 119 | # (seq_len, joints_num, 3) 120 | data = joints.copy().reshape(len(joints), -1, 3) 121 | fig = plt.figure(figsize=figsize) 122 | ax = p3.Axes3D(fig) 123 | init() 124 | MINS = data.min(axis=0).min(axis=0) 125 | MAXS = data.max(axis=0).max(axis=0) 126 | colors = ['red', 'blue', 'black', 'red', 'blue', 127 | 'yellow', 'yellow', 'darkblue', 'darkblue', 'darkblue', 128 | 'darkred', 'darkred', 'darkred', 'darkred', 'darkred'] 129 | 130 | frame_number = data.shape[0] 131 | 132 | height_offset = MINS[1] 133 | data[:, :, 1] -= height_offset 134 | trajec = data[:, 0, [0, 2]] 135 | 136 | data[..., 0] -= data[:, 0:1, 0] 137 | data[..., 2] -= data[:, 0:1, 2] 138 | 139 | def update(index): 140 | ax.lines = [] 141 | ax.collections = [] 142 | ax.view_init(elev=120, azim=-90) 143 | 
ax.dist = 7.5 144 | plot_xzPlane(MINS[0] - trajec[index, 0], MAXS[0] - trajec[index, 0], 0, MINS[2] - trajec[index, 1], 145 | MAXS[2] - trajec[index, 1]) 146 | 147 | if index > 1: 148 | ax.plot3D(trajec[:index, 0] - trajec[index, 0], np.zeros_like(trajec[:index, 0]), 149 | trajec[:index, 1] - trajec[index, 1], linewidth=1.0, 150 | color='blue') 151 | 152 | for i, (chain, color) in enumerate(zip(kinematic_tree, colors)): 153 | if i < 5: 154 | linewidth = 4.0 155 | else: 156 | linewidth = 2.0 157 | ax.plot3D(data[index, chain, 0], data[index, chain, 1], data[index, chain, 2], linewidth=linewidth, 158 | color=color) 159 | 160 | plt.axis('off') 161 | ax.set_xticklabels([]) 162 | ax.set_yticklabels([]) 163 | ax.set_zticklabels([]) 164 | 165 | ani = FuncAnimation(fig, update, frames=frame_number, interval=1000 / fps, repeat=False) 166 | ani.save(save_path, fps=fps) 167 | plt.close() 168 | 169 | 170 | def plot_3d_motion_kit(save_path, kinematic_tree, joints, title, figsize=(10, 10), fps=120, radius=4): 171 | matplotlib.use('Agg') 172 | 173 | title_sp = title.split(' ') 174 | if len(title_sp) > 20: 175 | title = '\n'.join([' '.join(title_sp[:10]), ' '.join(title_sp[10:20]), ' '.join(title_sp[20:])]) 176 | elif len(title_sp) > 10: 177 | title = '\n'.join([' '.join(title_sp[:10]), ' '.join(title_sp[10:])]) 178 | 179 | def init(): 180 | ax.set_xlim3d([-radius / 0.125, radius / 0.125]) 181 | ax.set_ylim3d([0, radius / 0.0625]) 182 | ax.set_zlim3d([0, radius / 0.0625]) 183 | fig.suptitle(title, fontsize=20) 184 | ax.grid(b=False) 185 | 186 | def plot_xzPlane(minx, maxx, miny, minz, maxz): 187 | verts = [ 188 | [minx, miny, minz], 189 | [minx, miny, maxz], 190 | [maxx, miny, maxz], 191 | [maxx, miny, minz] 192 | ] 193 | xz_plane = Poly3DCollection([verts]) 194 | xz_plane.set_facecolor((0.5, 0.5, 0.5, 0.5)) 195 | ax.add_collection3d(xz_plane) 196 | 197 | # (seq_len, joints_num, 3) 198 | data = joints.copy().reshape(len(joints), -1, 3) 199 | fig = plt.figure(figsize=figsize) 200 | ax = p3.Axes3D(fig) 201 | init() 202 | MINS = data.min(axis=0).min(axis=0) 203 | MAXS = data.max(axis=0).max(axis=0) 204 | colors = ['red', 'blue', 'black', 'red', 'blue',] 205 | 206 | frame_number = data.shape[0] 207 | 208 | height_offset = MINS[1] 209 | data[:, :, 1] -= height_offset 210 | trajec = data[:, 0, [0, 2]] 211 | 212 | data[..., 0] -= data[:, 0:1, 0] 213 | data[..., 2] -= data[:, 0:1, 2] 214 | 215 | def update(index): 216 | ax.lines = [] 217 | ax.collections = [] 218 | ax.view_init(elev=100, azim=-90) 219 | ax.dist = 225 220 | plot_xzPlane(MINS[0] - trajec[index, 0], MAXS[0] - trajec[index, 0], 0, MINS[2] - trajec[index, 1], 221 | MAXS[2] - trajec[index, 1]) 222 | 223 | if index > 1: 224 | ax.plot3D(trajec[:index, 0] - trajec[index, 0], np.zeros_like(trajec[:index, 0]), 225 | trajec[:index, 1] - trajec[index, 1], linewidth=1.0, 226 | color='blue') 227 | 228 | for i, (chain, color) in enumerate(zip(kinematic_tree, colors)): 229 | if i < 5: 230 | linewidth = 4.0 231 | else: 232 | linewidth = 2.0 233 | ax.plot3D(data[index, chain, 0], data[index, chain, 1], data[index, chain, 2], linewidth=linewidth, 234 | color=color) 235 | 236 | plt.axis('off') 237 | ax.set_xticklabels([]) 238 | ax.set_yticklabels([]) 239 | ax.set_zticklabels([]) 240 | 241 | ani = FuncAnimation(fig, update, frames=frame_number, interval=500 / fps, repeat=False) 242 | ani.save(save_path, fps=fps) 243 | plt.close() 244 | -------------------------------------------------------------------------------- /mogen/version.py: 
-------------------------------------------------------------------------------- 1 | __version__ = '0.0.1' 2 | 3 | 4 | def parse_version_info(version_str): 5 | """Parse a version string into a tuple. 6 | Args: 7 | version_str (str): The version string. 8 | Returns: 9 | tuple[int | str]: The version info, e.g., "1.3.0" is parsed into 10 | (1, 3, 0), and "2.0.0rc1" is parsed into (2, 0, 0, 'rc1'). 11 | """ 12 | version_info = [] 13 | for x in version_str.split('.'): 14 | if x.isdigit(): 15 | version_info.append(int(x)) 16 | elif x.find('rc') != -1: 17 | patch_version = x.split('rc') 18 | version_info.append(int(patch_version[0])) 19 | version_info.append(f'rc{patch_version[1]}') 20 | return tuple(version_info) 21 | 22 | 23 | version_info = parse_version_info(__version__) 24 | 25 | __all__ = ['__version__', 'version_info', 'parse_version_info'] -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | ftfy 2 | regex 3 | tqdm 4 | scipy 5 | matplotlib==3.3.1 6 | git+https://github.com/openai/CLIP.git -------------------------------------------------------------------------------- /tools/dist_train.sh: -------------------------------------------------------------------------------- 1 | CONFIG=$1 2 | WORK_DIR=$2 3 | GPUS=$3 4 | PORT=${PORT:-29500} 5 | 6 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 7 | python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \ 8 | $(dirname "$0")/train.py $CONFIG --work-dir=${WORK_DIR} --launcher pytorch ${@:4} -------------------------------------------------------------------------------- /tools/slurm_test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Copyright (c) OpenMMLab. All rights reserved. 3 | 4 | set -x 5 | 6 | PARTITION=$1 7 | JOB_NAME=$2 8 | CONFIG=$3 9 | WORK_DIR=$4 10 | CHECKPOINT=$5 11 | GPUS=1 12 | GPUS_PER_NODE=$((${GPUS}<8?${GPUS}:8)) 13 | CPUS_PER_TASK=${CPUS_PER_TASK:-2} 14 | SRUN_ARGS=${SRUN_ARGS:-""} 15 | PY_ARGS=${@:6} 16 | 17 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 18 | srun -p ${PARTITION} \ 19 | --job-name=${JOB_NAME} \ 20 | --gres=gpu:${GPUS_PER_NODE} \ 21 | --ntasks=${GPUS} \ 22 | --ntasks-per-node=${GPUS_PER_NODE} \ 23 | --cpus-per-task=${CPUS_PER_TASK} \ 24 | --kill-on-bad-exit=1 \ 25 | ${SRUN_ARGS} \ 26 | python -u tools/test.py ${CONFIG} --work-dir=${WORK_DIR} ${CHECKPOINT} --launcher="slurm" ${PY_ARGS} -------------------------------------------------------------------------------- /tools/slurm_train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Copyright (c) OpenMMLab. All rights reserved. 
3 | 4 | set -x 5 | 6 | PARTITION=$1 7 | JOB_NAME=$2 8 | CONFIG=$3 9 | WORK_DIR=$4 10 | GPUS=$5 11 | GPUS_PER_NODE=$((${GPUS}<8?${GPUS}:8)) 12 | CPUS_PER_TASK=${CPUS_PER_TASK:-2} 13 | SRUN_ARGS=${SRUN_ARGS:-""} 14 | PY_ARGS=${@:6} 15 | 16 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 17 | srun -p ${PARTITION} \ 18 | --job-name=${JOB_NAME} \ 19 | --gres=gpu:${GPUS_PER_NODE} \ 20 | --ntasks=${GPUS} \ 21 | --ntasks-per-node=${GPUS_PER_NODE} \ 22 | --cpus-per-task=${CPUS_PER_TASK} \ 23 | --kill-on-bad-exit=1 \ 24 | ${SRUN_ARGS} \ 25 | python -u tools/train.py ${CONFIG} --work-dir=${WORK_DIR} --launcher="slurm" ${PY_ARGS} -------------------------------------------------------------------------------- /tools/test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import os.path as osp 4 | 5 | import mmcv 6 | import torch 7 | from mmcv import DictAction 8 | from mmcv.parallel import MMDataParallel, MMDistributedDataParallel 9 | from mmcv.runner import ( 10 | get_dist_info, 11 | init_dist, 12 | load_checkpoint, 13 | wrap_fp16_model, 14 | ) 15 | 16 | from mogen.apis import multi_gpu_test, single_gpu_test 17 | from mogen.datasets import build_dataloader, build_dataset 18 | from mogen.models import build_architecture 19 | 20 | 21 | def parse_args(): 22 | parser = argparse.ArgumentParser(description='mogen evaluation') 23 | parser.add_argument('config', help='test config file path') 24 | parser.add_argument( 25 | '--work-dir', help='the dir to save evaluation results') 26 | parser.add_argument('checkpoint', help='checkpoint file') 27 | parser.add_argument('--out', help='output result file') 28 | parser.add_argument( 29 | '--gpu_collect', 30 | action='store_true', 31 | help='whether to use gpu to collect results') 32 | parser.add_argument('--tmpdir', help='tmp dir for writing some results') 33 | parser.add_argument( 34 | '--cfg-options', 35 | nargs='+', 36 | action=DictAction, 37 | help='override some settings in the used config, the key-value pair ' 38 | 'in xxx=yyy format will be merged into config file.') 39 | parser.add_argument( 40 | '--launcher', 41 | choices=['none', 'pytorch', 'slurm', 'mpi'], 42 | default='none', 43 | help='job launcher') 44 | parser.add_argument('--local_rank', type=int, default=0) 45 | parser.add_argument( 46 | '--device', 47 | choices=['cpu', 'cuda'], 48 | default='cuda', 49 | help='device used for testing') 50 | args = parser.parse_args() 51 | if 'LOCAL_RANK' not in os.environ: 52 | os.environ['LOCAL_RANK'] = str(args.local_rank) 53 | return args 54 | 55 | 56 | def main(): 57 | args = parse_args() 58 | 59 | cfg = mmcv.Config.fromfile(args.config) 60 | if args.cfg_options is not None: 61 | cfg.merge_from_dict(args.cfg_options) 62 | # set cudnn_benchmark 63 | if cfg.get('cudnn_benchmark', False): 64 | torch.backends.cudnn.benchmark = True 65 | cfg.data.test.test_mode = True 66 | 67 | # init distributed env first, since logger depends on the dist info. 
68 | if args.launcher == 'none': 69 | distributed = False 70 | else: 71 | distributed = True 72 | init_dist(args.launcher, **cfg.dist_params) 73 | 74 | # build the dataloader 75 | dataset = build_dataset(cfg.data.test) 76 | # the extra round_up data will be removed during gpu/cpu collect 77 | data_loader = build_dataloader( 78 | dataset, 79 | samples_per_gpu=cfg.data.samples_per_gpu, 80 | workers_per_gpu=cfg.data.workers_per_gpu, 81 | dist=distributed, 82 | shuffle=False, 83 | round_up=False) 84 | 85 | # build the model and load checkpoint 86 | model = build_architecture(cfg.model) 87 | fp16_cfg = cfg.get('fp16', None) 88 | if fp16_cfg is not None: 89 | wrap_fp16_model(model) 90 | load_checkpoint(model, args.checkpoint, map_location='cpu') 91 | 92 | if not distributed: 93 | if args.device == 'cpu': 94 | model = model.cpu() 95 | else: 96 | model = MMDataParallel(model, device_ids=[0]) 97 | outputs = single_gpu_test(model, data_loader) 98 | else: 99 | model = MMDistributedDataParallel( 100 | model.cuda(), 101 | device_ids=[torch.cuda.current_device()], 102 | broadcast_buffers=False) 103 | outputs = multi_gpu_test(model, data_loader, args.tmpdir, 104 | args.gpu_collect) 105 | 106 | rank, _ = get_dist_info() 107 | if rank == 0: 108 | mmcv.mkdir_or_exist(osp.abspath(args.work_dir)) 109 | results = dataset.evaluate(outputs, args.work_dir) 110 | for k, v in results.items(): 111 | print(f'\n{k} : {v:.4f}') 112 | 113 | if args.out and rank == 0: 114 | print(f'\nwriting results to {args.out}') 115 | mmcv.dump(results, args.out) 116 | 117 | 118 | if __name__ == '__main__': 119 | main() -------------------------------------------------------------------------------- /tools/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import copy 3 | import os 4 | import os.path as osp 5 | import time 6 | 7 | import mmcv 8 | import torch 9 | from mmcv import Config, DictAction 10 | from mmcv.runner import get_dist_info, init_dist 11 | 12 | from mogen import __version__ 13 | from mogen.apis import set_random_seed, train_model 14 | from mogen.datasets import build_dataset 15 | from mogen.models import build_architecture 16 | from mogen.utils import collect_env, get_root_logger 17 | 18 | 19 | def parse_args(): 20 | parser = argparse.ArgumentParser(description='Train a model') 21 | parser.add_argument('config', help='train config file path') 22 | parser.add_argument('--work-dir', help='the dir to save logs and models') 23 | parser.add_argument( 24 | '--resume-from', help='the checkpoint file to resume from') 25 | parser.add_argument( 26 | '--no-validate', 27 | action='store_true', 28 | help='whether not to evaluate the checkpoint during training') 29 | group_gpus = parser.add_mutually_exclusive_group() 30 | group_gpus.add_argument('--device', help='device used for training') 31 | group_gpus.add_argument( 32 | '--gpus', 33 | type=int, 34 | help='number of gpus to use ' 35 | '(only applicable to non-distributed training)') 36 | group_gpus.add_argument( 37 | '--gpu-ids', 38 | type=int, 39 | nargs='+', 40 | help='ids of gpus to use ' 41 | '(only applicable to non-distributed training)') 42 | parser.add_argument('--seed', type=int, default=None, help='random seed') 43 | parser.add_argument( 44 | '--deterministic', 45 | action='store_true', 46 | help='whether to set deterministic options for CUDNN backend.') 47 | parser.add_argument( 48 | '--options', nargs='+', action=DictAction, help='arguments in dict') 49 | parser.add_argument( 50 | '--launcher', 51 | 
choices=['none', 'pytorch', 'slurm', 'mpi'], 52 | default='none', 53 | help='job launcher') 54 | parser.add_argument('--local_rank', type=int, default=0) 55 | args = parser.parse_args() 56 | if 'LOCAL_RANK' not in os.environ: 57 | os.environ['LOCAL_RANK'] = str(args.local_rank) 58 | 59 | return args 60 | 61 | 62 | def main(): 63 | args = parse_args() 64 | 65 | cfg = Config.fromfile(args.config) 66 | if args.options is not None: 67 | cfg.merge_from_dict(args.options) 68 | # set cudnn_benchmark 69 | if cfg.get('cudnn_benchmark', False): 70 | torch.backends.cudnn.benchmark = True 71 | 72 | # work_dir is determined in this priority: CLI > segment in file > filename 73 | if args.work_dir is not None: 74 | # update configs according to CLI args if args.work_dir is not None 75 | cfg.work_dir = args.work_dir 76 | elif cfg.get('work_dir', None) is None: 77 | # use config filename as default work_dir if cfg.work_dir is None 78 | cfg.work_dir = osp.join('./work_dirs', 79 | osp.splitext(osp.basename(args.config))[0]) 80 | if args.resume_from is not None: 81 | cfg.resume_from = args.resume_from 82 | if args.gpu_ids is not None: 83 | cfg.gpu_ids = args.gpu_ids 84 | else: 85 | cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus) 86 | 87 | # init distributed env first, since logger depends on the dist info. 88 | if args.launcher == 'none': 89 | distributed = False 90 | else: 91 | distributed = True 92 | init_dist(args.launcher, **cfg.dist_params) 93 | _, world_size = get_dist_info() 94 | cfg.gpu_ids = range(world_size) 95 | 96 | # create work_dir 97 | mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir)) 98 | # dump config 99 | cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config))) 100 | # init the logger before other steps 101 | timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) 102 | log_file = osp.join(cfg.work_dir, f'{timestamp}.log') 103 | logger = get_root_logger(log_file=log_file, log_level=cfg.log_level) 104 | 105 | # init the meta dict to record some important information such as 106 | # environment info and seed, which will be logged 107 | meta = dict() 108 | # log env info 109 | env_info_dict = collect_env() 110 | env_info = '\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()]) 111 | dash_line = '-' * 60 + '\n' 112 | logger.info('Environment info:\n' + dash_line + env_info + '\n' + 113 | dash_line) 114 | meta['env_info'] = env_info 115 | 116 | # log some basic info 117 | logger.info(f'Distributed training: {distributed}') 118 | logger.info(f'Config:\n{cfg.pretty_text}') 119 | 120 | # set random seeds 121 | if args.seed is not None: 122 | logger.info(f'Set random seed to {args.seed}, ' 123 | f'deterministic: {args.deterministic}') 124 | set_random_seed(args.seed, deterministic=args.deterministic) 125 | cfg.seed = args.seed 126 | meta['seed'] = args.seed 127 | 128 | model = build_architecture(cfg.model) 129 | model.init_weights() 130 | 131 | datasets = [build_dataset(cfg.data.train)] 132 | if len(cfg.workflow) == 2: 133 | val_dataset = copy.deepcopy(cfg.data.val) 134 | val_dataset.pipeline = cfg.data.train.pipeline 135 | datasets.append(build_dataset(val_dataset)) 136 | # add an attribute for visualization convenience 137 | train_model( 138 | model, 139 | datasets, 140 | cfg, 141 | distributed=distributed, 142 | validate=(not args.no_validate), 143 | timestamp=timestamp, 144 | device='cpu' if args.device == 'cpu' else 'cuda', 145 | meta=meta) 146 | 147 | 148 | if __name__ == '__main__': 149 | main() 
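For reference, the two training entry points above can be launched roughly as follows; the config path exists in the repository's configs/ tree, while the work directory, seed, and GPU count are illustrative placeholders rather than values fixed by the code.

# single-GPU training with tools/train.py (work dir and seed are placeholders)
python tools/train.py configs/remodiffuse/remodiffuse_t2m.py --work-dir work_dirs/remodiffuse_t2m --seed 42

# distributed training via tools/dist_train.sh (arguments: CONFIG WORK_DIR GPUS), e.g. 8 GPUs
PORT=29500 bash tools/dist_train.sh configs/remodiffuse/remodiffuse_t2m.py work_dirs/remodiffuse_t2m 8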
-------------------------------------------------------------------------------- /tools/visualize.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import os.path as osp 4 | import mmcv 5 | import numpy as np 6 | import torch 7 | from mogen.models import build_architecture 8 | from mmcv.runner import init_dist, load_checkpoint, wrap_fp16_model 9 | from mmcv.parallel import MMDataParallel 10 | from mogen.utils.plot_utils import ( 11 | recover_from_ric, 12 | plot_3d_motion, 13 | t2m_kinematic_chain, 14 | plot_3d_motion_kit, 15 | kit_kinematic_chain 16 | ) 17 | from scipy.ndimage import gaussian_filter 18 | 19 | 20 | def motion_temporal_filter(motion, sigma=1): 21 | motion = motion.reshape(motion.shape[0], -1) 22 | for i in range(motion.shape[1]): 23 | motion[:, i] = gaussian_filter(motion[:, i], sigma=sigma, mode="nearest") 24 | return motion.reshape(motion.shape[0], -1, 3) 25 | 26 | 27 | def plot_t2m(data, result_path, npy_path, caption): 28 | joint = recover_from_ric(torch.from_numpy(data).float(), 22).numpy() 29 | joint = motion_temporal_filter(joint, sigma=2.5) 30 | plot_3d_motion(result_path, t2m_kinematic_chain, joint, title=caption, fps=20) 31 | if npy_path is not None: 32 | np.save(npy_path, joint) 33 | 34 | def plot_kit(data, result_path, npy_path, caption): 35 | joint = recover_from_ric(torch.from_numpy(data).float(), 21).numpy() # KIT-ML uses a 21-joint skeleton (251-dim features) 36 | joint = motion_temporal_filter(joint, sigma=2.5) 37 | plot_3d_motion_kit(result_path, kit_kinematic_chain, joint, title=caption, fps=20) 38 | if npy_path is not None: 39 | np.save(npy_path, joint) 40 | 41 | def parse_args(): 42 | parser = argparse.ArgumentParser(description='mogen visualization') 43 | parser.add_argument('config', help='test config file path') # e.g. configs/remodiffuse/remodiffuse_kit.py 44 | parser.add_argument('checkpoint', help='checkpoint file') 45 | parser.add_argument('--text', help='motion description') 46 | parser.add_argument('--motion_length', type=int, help='expected motion length') 47 | parser.add_argument('--out', help='output animation file') 48 | parser.add_argument('--pose_npy', help='output pose sequence file', default=None) 49 | parser.add_argument( 50 | '--launcher', 51 | choices=['none', 'pytorch', 'slurm', 'mpi'], 52 | default='none', 53 | help='job launcher') 54 | parser.add_argument('--local_rank', type=int, default=0) 55 | parser.add_argument( 56 | '--device', 57 | choices=['cpu', 'cuda'], 58 | default='cuda', 59 | help='device used for testing') 60 | args = parser.parse_args() 61 | if 'LOCAL_RANK' not in os.environ: 62 | os.environ['LOCAL_RANK'] = str(args.local_rank) 63 | return args 64 | 65 | 66 | def main(): 67 | args = parse_args() 68 | 69 | cfg = mmcv.Config.fromfile(args.config) 70 | # set cudnn_benchmark 71 | if cfg.get('cudnn_benchmark', False): 72 | torch.backends.cudnn.benchmark = True 73 | cfg.data.test.test_mode = True 74 | 75 | # init distributed env first, since logger depends on the dist info.
76 | if args.launcher == 'none': 77 | distributed = False 78 | else: 79 | distributed = True 80 | init_dist(args.launcher, **cfg.dist_params) 81 | 82 | assert args.motion_length >= 16 and args.motion_length <= 196 83 | 84 | # build the model and load checkpoint 85 | model = build_architecture(cfg.model) 86 | fp16_cfg = cfg.get('fp16', None) 87 | if fp16_cfg is not None: 88 | wrap_fp16_model(model) 89 | load_checkpoint(model, args.checkpoint, map_location='cpu') 90 | 91 | if args.device == 'cpu': 92 | model = model.cpu() 93 | else: 94 | model = MMDataParallel(model, device_ids=[0]) 95 | model.eval() 96 | 97 | dataset_name = cfg.data.test.dataset_name 98 | print("dataset_name",dataset_name) 99 | if dataset_name == 'human_ml3d': 100 | #assert dataset_name == "human_ml3d" 101 | mean_path = "data/datasets/human_ml3d/mean.npy" 102 | std_path = "data/datasets/human_ml3d/std.npy" 103 | mean = np.load(mean_path) 104 | std = np.load(std_path) 105 | else: 106 | #assert dataset_name == "kit_ml" 107 | mean_path = "data/datasets/kit_ml/mean.npy" 108 | std_path = "data/datasets/kit_ml/std.npy" 109 | mean = np.load(mean_path) 110 | std = np.load(std_path) 111 | 112 | device = args.device 113 | text = args.text 114 | motion_length = args.motion_length 115 | if dataset_name == 'human_ml3d': 116 | motion = torch.zeros(1, motion_length, 263).to(device) 117 | else: 118 | motion = torch.zeros(1, motion_length, 251).to(device) 119 | motion_mask = torch.ones(1, motion_length).to(device) 120 | motion_length = torch.Tensor([motion_length]).long().to(device) 121 | model = model.to(device) 122 | 123 | input = { 124 | 'motion': motion, 125 | 'motion_mask': motion_mask, 126 | 'motion_length': motion_length, 127 | 'motion_metas': [{'text': text}], 128 | } 129 | 130 | all_pred_motion = [] 131 | with torch.no_grad(): 132 | input['inference_kwargs'] = {} 133 | output_list = [] 134 | output = model(**input)[0]['pred_motion'] 135 | pred_motion = output.cpu().detach().numpy() 136 | pred_motion = pred_motion * std + mean 137 | 138 | if dataset_name == 'human_ml3d': 139 | plot_t2m(pred_motion, args.out, args.pose_npy, text) 140 | else: 141 | plot_kit(pred_motion, args.out, args.pose_npy, text) 142 | 143 | 144 | if __name__ == '__main__': 145 | main() --------------------------------------------------------------------------------
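For reference, tools/visualize.py above can be invoked roughly as follows; the checkpoint path and prompt are illustrative placeholders, the script expects mean.npy/std.npy under data/datasets/human_ml3d/ (or data/datasets/kit_ml/) as hard-coded in main(), and --motion_length must lie within [16, 196] to pass the assertion.

# generate and render a motion from text (checkpoint and prompt are placeholders)
python tools/visualize.py configs/remodiffuse/remodiffuse_t2m.py work_dirs/remodiffuse_t2m/latest.pth --text "a person walks forward and waves" --motion_length 120 --out demo.mp4 --device cuda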