├── .gitignore
├── LICENSE
├── README.md
├── configs
├── _base_
│   └── datasets
│   │   ├── human_ml3d_bs128.py
│   │   ├── inter_human_bs128.py
│   │   └── kit_ml_bs128.py
├── finemogen
│   ├── finemogen_kit.py
│   └── finemogen_t2m.py
├── interhuman
│   ├── intergen_interhuman.py
│   ├── momatmogen_interhuman.py
│   ├── motiondiffuse_interhuman.py
│   └── remodiffuse_interhuman.py
├── mdm
│   ├── mdm_kit.py
│   ├── mdm_t2m.py
│   └── mdm_t2m_official.py
├── motiondiffuse
│   ├── motiondiffuse_kit.py
│   └── motiondiffuse_t2m.py
└── remodiffuse
│   ├── remodiffuse_kit.py
│   └── remodiffuse_t2m.py
├── imgs
├── pipeline.png
└── teaser.png
├── mogen
├── __init__.py
├── apis
│   ├── __init__.py
│   ├── test.py
│   └── train.py
├── core
│   ├── __init__.py
│   ├── distributed_wrapper.py
│   ├── evaluation
│   │   ├── __init__.py
│   │   ├── builder.py
│   │   ├── eval_hooks.py
│   │   ├── evaluators
│   │   │   ├── __init__.py
│   │   │   ├── base_evaluator.py
│   │   │   ├── diversity_evaluator.py
│   │   │   ├── fid_evaluator.py
│   │   │   ├── matching_score_evaluator.py
│   │   │   ├── multimodality_evaluator.py
│   │   │   └── precision_evaluator.py
│   │   └── utils.py
│   └── optimizer
│   │   ├── __init__.py
│   │   └── builder.py
├── datasets
│   ├── __init__.py
│   ├── base_dataset.py
│   ├── builder.py
│   ├── dataset_wrappers.py
│   ├── pipelines
│   │   ├── __init__.py
│   │   ├── compose.py
│   │   ├── formatting.py
│   │   ├── quaternion.py
│   │   ├── rotation_conversions.py
│   │   ├── siamese_motion.py
│   │   └── transforms.py
│   ├── samplers
│   │   ├── __init__.py
│   │   └── distributed_sampler.py
│   └── text_motion_dataset.py
├── models
│   ├── __init__.py
│   ├── architectures
│   │   ├── __init__.py
│   │   ├── base_architecture.py
│   │   ├── diffusion_architecture.py
│   │   └── vae_architecture.py
│   ├── attentions
│   │   ├── __init__.py
│   │   ├── base_attention.py
│   │   ├── efficient_attention.py
│   │   ├── fine_attention.py
│   │   └── semantics_modulated.py
│   ├── builder.py
│   ├── losses
│   │   ├── __init__.py
│   │   ├── gan_loss.py
│   │   ├── mse_loss.py
│   │   └── utils.py
│   ├── rnns
│   │   ├── __init__.py
│   │   └── t2m_bigru.py
│   ├── transformers
│   │   ├── __init__.py
│   │   ├── actor.py
│   │   ├── diffusion_transformer.py
│   │   ├── finemogen.py
│   │   ├── intergen.py
│   │   ├── mdm.py
│   │   ├── momatmogen.py
│   │   ├── motiondiffuse.py
│   │   └── remodiffuse.py
│   └── utils
│   │   ├── __init__.py
│   │   ├── gaussian_diffusion.py
│   │   ├── misc.py
│   │   ├── mlp.py
│   │   ├── position_encoding.py
│   │   ├── stylization_block.py
│   │   └── word_vectorizer.py
├── utils
│   ├── __init__.py
│   ├── collect_env.py
│   ├── dist_utils.py
│   ├── logger.py
│   ├── misc.py
│   ├── path_utils.py
│   └── plot_utils.py
└── version.py
├── requirements.txt
└── tools
├── dist_train.sh
├── slurm_test.sh
├── slurm_train.sh
├── test.py
├── train.py
└── visualize.py

/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | **/*.pyc
6 | 
7 | # C extensions
8 | *.so
9 | 
10 | # Distribution / packaging
11 | .Python
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 | 
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 | 
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 | 
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .coverage
43 | .coverage.*
44 | .cache
45 | nosetests.xml
46 | coverage.xml
47 | *.cover
48 | .hypothesis/
49 | .pytest_cache/
50 | 
51 | # Translations
52 | *.mo
53 | *.pot
54 | 
55 | # Django stuff:
56 | *.log
57 | local_settings.py
58 | db.sqlite3
59 | 
60 | # Flask stuff:
61 | instance/
62 | .webassets-cache
63 | 
64 | # Scrapy stuff:
65 | .scrapy
66 | 
67 | # Sphinx documentation
68 | docs/_build/
69 | 
70 | # PyBuilder
71 | target/
72 | 
73 | # Jupyter Notebook
74 | .ipynb_checkpoints
75 | 
76 | # pyenv
77 | .python-version
78 | 
79 | # celery beat schedule file
80 | celerybeat-schedule
81 | 
82 | # SageMath parsed files
83 | *.sage.py
84 | 
85 | # Environments
86 | .env
87 | .venv
88 | env/
89 | venv/
90 | ENV/
91 | env.bak/
92 | venv.bak/
93 | 
94 | # Spyder project settings
95 | .spyderproject
96 | .spyproject
97 | 
98 | # Rope project settings
99 | .ropeproject
100 | 
101 | # mkdocs documentation
102 | /site
103 | 
104 | # mypy
105 | .mypy_cache/
106 | 
107 | # custom
108 | data
109 | # data for pytest moved to http server
110 | # !tests/data
111 | .vscode
112 | .idea
113 | *.pkl
114 | *.pkl.json
115 | *.log.json
116 | work_dirs/
117 | logs/
118 | 
119 | # Pytorch
120 | *.pth
121 | *.pt
122 | 
123 | # Resources as exception
124 | !resources/*
125 | 
126 | # Loaded/Saved data files
127 | *.npz
128 | *.npy
129 | *.pickle
130 | 
131 | # MacOS
132 | *DS_Store*
133 | # git
134 | *.orig
135 | 
136 | env.sh
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | S-Lab License 1.0
2 | 
3 | Copyright 2023 S-Lab
4 | 
5 | Redistribution and use for non-commercial purpose in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
6 | 
7 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
8 | 
9 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
10 | 
11 | 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
12 | 
13 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
14 | 
15 | 4. In the event that redistribution and/or use for commercial purpose in source or binary forms, with or without modification is required, please contact the contributor(s) of the work.
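The configuration files that follow use the mmcv config convention: each model config inherits one of the dataset configs under configs/_base_/datasets through its _base_ list, and the entry points under tools/ (train.py, test.py, visualize.py) consume the merged config. As a rough sketch only (not a file in the repository), assuming the standard mmcv 1.x Config API and the build_dataset helper that mogen/apis/train.py imports further down, and with the motion data prepared under data/, loading one of these configs looks like:

from mmcv import Config

from mogen.datasets import build_dataset

# Merge configs/finemogen/finemogen_t2m.py with its _base_ dataset config.
cfg = Config.fromfile('configs/finemogen/finemogen_t2m.py')
print(cfg.model.type)            # 'MotionDiffusion'
print(cfg.data.samples_per_gpu)  # 64, overridden from the base value of 128
# Build the training set: a RepeatDataset wrapping a TextMotionDataset.
train_dataset = build_dataset(cfg.data.train)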
-------------------------------------------------------------------------------- /configs/_base_/datasets/human_ml3d_bs128.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | data_keys = ['motion', 'motion_mask', 'motion_length', 'clip_feat'] 3 | meta_keys = ['text', 'token'] 4 | train_pipeline = [ 5 | dict(type='Normalize', 6 | mean_path='data/datasets/human_ml3d/mean.npy', 7 | std_path='data/datasets/human_ml3d/std.npy'), 8 | dict(type='Crop', crop_size=196), 9 | dict(type='ToTensor', keys=data_keys), 10 | dict(type='Collect', keys=data_keys, meta_keys=meta_keys) 11 | ] 12 | 13 | data = dict( 14 | samples_per_gpu=128, 15 | workers_per_gpu=1, 16 | train=dict(type='RepeatDataset', 17 | dataset=dict( 18 | type='TextMotionDataset', 19 | dataset_name='human_ml3d', 20 | data_prefix='data', 21 | pipeline=train_pipeline, 22 | ann_file='train.txt', 23 | motion_dir='motions', 24 | text_dir='texts', 25 | token_dir='tokens', 26 | clip_feat_dir='clip_feats', 27 | ), 28 | times=100), 29 | test=dict(type='TextMotionDataset', 30 | dataset_name='human_ml3d', 31 | data_prefix='data', 32 | pipeline=train_pipeline, 33 | ann_file='test.txt', 34 | motion_dir='motions', 35 | text_dir='texts', 36 | token_dir='tokens', 37 | clip_feat_dir='clip_feats', 38 | eval_cfg=dict( 39 | shuffle_indexes=True, 40 | replication_times=20, 41 | replication_reduction='statistics', 42 | evaluator_model=dict( 43 | type='T2MContrastiveModel', 44 | motion_encoder=dict( 45 | input_size=263, 46 | movement_hidden_size=512, 47 | movement_latent_size=512, 48 | motion_hidden_size=1024, 49 | motion_latent_size=512, 50 | ), 51 | text_encoder=dict(word_size=300, 52 | pos_size=15, 53 | hidden_size=512, 54 | output_size=512, 55 | max_text_len=20), 56 | init_cfg=dict( 57 | type='Pretrained', 58 | checkpoint='data/evaluators/human_ml3d/finest.tar')), 59 | metrics=[ 60 | dict(type='R Precision', batch_size=32, top_k=3), 61 | dict(type='Matching Score', batch_size=32), 62 | dict(type='FID'), 63 | dict(type='Diversity', num_samples=300), 64 | dict(type='MultiModality', 65 | num_samples=100, 66 | num_repeats=30, 67 | num_picks=10) 68 | ]), 69 | test_mode=True)) 70 | -------------------------------------------------------------------------------- /configs/_base_/datasets/inter_human_bs128.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | data_keys = ['motion', 'motion_mask', 'motion_length', 'clip_feat'] 3 | meta_keys = ['text'] 4 | train_pipeline = [ 5 | dict(type='SwapSiameseMotion', prob=0.5), 6 | dict(type='ProcessSiameseMotion', 7 | feet_threshold=0.001, 8 | prev_frames=0, 9 | n_joints=22, 10 | prob=0.5), 11 | dict(type='Crop', crop_size=300), 12 | dict(type='Normalize', 13 | mean_path='data/datasets/inter_human/mean.npy', 14 | std_path='data/datasets/inter_human/std.npy'), 15 | dict(type='ToTensor', keys=data_keys), 16 | dict(type='Collect', keys=data_keys, meta_keys=meta_keys) 17 | ] 18 | test_pipeline = [ 19 | dict(type='SwapSiameseMotion', prob=0.5), 20 | dict(type='ProcessSiameseMotion', 21 | feet_threshold=0.001, 22 | prev_frames=0, 23 | n_joints=22, 24 | prob=0.5), 25 | dict(type='Crop', crop_size=300), 26 | dict(type='ToTensor', keys=data_keys), 27 | dict(type='Collect', keys=data_keys, meta_keys=meta_keys) 28 | ] 29 | 30 | data = dict( 31 | samples_per_gpu=128, 32 | workers_per_gpu=4, 33 | train=dict(type='RepeatDataset', 34 | dataset=dict(type='TextMotionDataset', 35 | dataset_name='inter_human', 36 | 
data_prefix='data', 37 | pipeline=train_pipeline, 38 | ann_file='train.txt', 39 | motion_dir='motions', 40 | text_dir='texts', 41 | clip_feat_dir='clip_feats', 42 | siamese_mode=True), 43 | times=100), 44 | test=dict( 45 | type='TextMotionDataset', 46 | dataset_name='inter_human', 47 | data_prefix='data', 48 | pipeline=test_pipeline, 49 | ann_file='test.txt', 50 | motion_dir='motions', 51 | text_dir='texts', 52 | clip_feat_dir='clip_feats', 53 | siamese_mode=True, 54 | eval_cfg=dict( 55 | shuffle_indexes=True, 56 | replication_times=1, 57 | replication_reduction='statistics', 58 | evaluator_model=dict( 59 | type='InterCLIP', 60 | input_dim=258, 61 | latent_dim=1024, 62 | ff_size=2048, 63 | num_layers=8, 64 | num_heads=8, 65 | dropout=0.1, 66 | activation="gelu", 67 | init_cfg=dict( 68 | type='Pretrained', 69 | checkpoint='data/evaluators/inter_human/interclip.ckpt')), 70 | metrics=[ 71 | dict(type='R Precision', batch_size=96, top_k=3), 72 | dict(type='Matching Score', batch_size=96), 73 | dict(type='FID', emb_scale=6), 74 | dict(type='Diversity', 75 | num_samples=300, 76 | emb_scale=6, 77 | norm_scale=0.5), 78 | dict(type='MultiModality', 79 | num_samples=100, 80 | num_repeats=30, 81 | num_picks=10) 82 | ]), 83 | test_mode=True)) 84 | -------------------------------------------------------------------------------- /configs/_base_/datasets/kit_ml_bs128.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | data_keys = ['motion', 'motion_mask', 'motion_length', 'clip_feat'] 3 | meta_keys = ['text', 'token'] 4 | train_pipeline = [ 5 | dict(type='Crop', crop_size=196), 6 | dict(type='Normalize', 7 | mean_path='data/datasets/kit_ml/mean.npy', 8 | std_path='data/datasets/kit_ml/std.npy'), 9 | dict(type='ToTensor', keys=data_keys), 10 | dict(type='Collect', keys=data_keys, meta_keys=meta_keys) 11 | ] 12 | 13 | data = dict(samples_per_gpu=128, 14 | workers_per_gpu=1, 15 | train=dict(type='RepeatDataset', 16 | dataset=dict( 17 | type='TextMotionDataset', 18 | dataset_name='kit_ml', 19 | data_prefix='data', 20 | pipeline=train_pipeline, 21 | ann_file='train.txt', 22 | motion_dir='motions', 23 | text_dir='texts', 24 | clip_feat_dir='clip_feats', 25 | ), 26 | times=200), 27 | test=dict( 28 | type='TextMotionDataset', 29 | dataset_name='kit_ml', 30 | data_prefix='data', 31 | pipeline=train_pipeline, 32 | ann_file='test.txt', 33 | motion_dir='motions', 34 | text_dir='texts', 35 | token_dir='tokens', 36 | clip_feat_dir='clip_feats', 37 | eval_cfg=dict( 38 | shuffle_indexes=True, 39 | replication_times=20, 40 | replication_reduction='statistics', 41 | evaluator_model=dict( 42 | type='T2MContrastiveModel', 43 | motion_encoder=dict( 44 | input_size=251, 45 | movement_hidden_size=512, 46 | movement_latent_size=512, 47 | motion_hidden_size=1024, 48 | motion_latent_size=512, 49 | ), 50 | text_encoder=dict(word_size=300, 51 | pos_size=15, 52 | hidden_size=512, 53 | output_size=512, 54 | max_text_len=20), 55 | init_cfg=dict( 56 | type='Pretrained', 57 | checkpoint='data/evaluators/kit_ml/finest.tar')), 58 | metrics=[ 59 | dict(type='R Precision', batch_size=32, top_k=3), 60 | dict(type='Matching Score', batch_size=32), 61 | dict(type='FID'), 62 | dict(type='Diversity', num_samples=300), 63 | dict(type='MultiModality', 64 | num_samples=50, 65 | num_repeats=30, 66 | num_picks=10) 67 | ]), 68 | test_mode=True)) 69 | -------------------------------------------------------------------------------- /configs/finemogen/finemogen_kit.py: 
-------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/datasets/kit_ml_bs128.py'] 2 | 3 | # checkpoint saving 4 | checkpoint_config = dict(interval=1) 5 | 6 | dist_params = dict(backend='nccl') 7 | log_level = 'INFO' 8 | load_from = None 9 | resume_from = None 10 | workflow = [('train', 1)] 11 | 12 | # optimizer 13 | optimizer = dict(type='Adam', lr=2e-4) 14 | optimizer_config = dict(grad_clip=None) 15 | # learning policy 16 | lr_config = dict(policy='step', step=[10]) 17 | runner = dict(type='EpochBasedRunner', max_epochs=12) 18 | 19 | log_config = dict( 20 | interval=50, 21 | hooks=[ 22 | dict(type='TextLoggerHook'), 23 | # dict(type='TensorboardLoggerHook') 24 | ]) 25 | 26 | input_feats = 251 27 | max_seq_len = 196 28 | latent_dim = 64 29 | time_embed_dim = 2048 30 | text_latent_dim = 256 31 | ff_size = 256 32 | num_heads = 8 33 | dropout = 0 34 | dataset_name = "kit_ml" 35 | 36 | # model settings 37 | model = dict(type='MotionDiffusion', 38 | model=dict(type='FineMoGenTransformer', 39 | input_feats=input_feats, 40 | max_seq_len=max_seq_len, 41 | latent_dim=latent_dim * 8, 42 | time_embed_dim=time_embed_dim, 43 | num_layers=4, 44 | ca_block_cfg=dict(type='SAMI', 45 | latent_dim=latent_dim, 46 | text_latent_dim=text_latent_dim, 47 | num_heads=8, 48 | num_text_heads=1, 49 | num_experts=16, 50 | topk=2, 51 | gate_type='cosine_top', 52 | gate_noise=1.0, 53 | ffn_dim=ff_size, 54 | time_embed_dim=time_embed_dim, 55 | max_seq_len=max_seq_len, 56 | max_text_seq_len=77, 57 | temporal_comb=False, 58 | dropout=dropout), 59 | ffn_cfg=dict(latent_dim=latent_dim, 60 | ffn_dim=ff_size, 61 | dropout=dropout, 62 | time_embed_dim=time_embed_dim), 63 | text_encoder=dict(pretrained_model='clip', 64 | latent_dim=text_latent_dim, 65 | num_layers=2, 66 | ff_size=2048, 67 | dropout=dropout, 68 | use_text_proj=False), 69 | pose_encoder_cfg=dict(dataset_name=dataset_name, 70 | latent_dim=latent_dim, 71 | input_dim=input_feats), 72 | pose_decoder_cfg=dict(dataset_name=dataset_name, 73 | latent_dim=latent_dim, 74 | output_dim=input_feats), 75 | scale_func_cfg=dict(scale=4.5), 76 | moe_route_loss_weight=10.0, 77 | template_kl_loss_weight=0.0001, 78 | use_pos_embedding=False), 79 | loss_recon=dict(type='MSELoss', loss_weight=1, reduction='none'), 80 | diffusion_train=dict( 81 | beta_scheduler='linear', 82 | diffusion_steps=1000, 83 | model_mean_type='start_x', 84 | model_var_type='fixed_large', 85 | ), 86 | diffusion_test=dict( 87 | beta_scheduler='linear', 88 | diffusion_steps=1000, 89 | model_mean_type='start_x', 90 | model_var_type='fixed_large', 91 | respace='15,15,8,6,6', 92 | ), 93 | inference_type='ddim', 94 | loss_reduction='frame') 95 | data = dict(samples_per_gpu=64, 96 | train=dict(dataset=dict(ann_file='trainval_wo_mirror.txt'))) 97 | -------------------------------------------------------------------------------- /configs/finemogen/finemogen_t2m.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/datasets/human_ml3d_bs128.py'] 2 | 3 | # checkpoint saving 4 | checkpoint_config = dict(interval=1) 5 | 6 | dist_params = dict(backend='nccl') 7 | log_level = 'INFO' 8 | load_from = None 9 | resume_from = None 10 | workflow = [('train', 1)] 11 | 12 | # optimizer 13 | optimizer = dict(type='Adam', lr=2e-4) 14 | optimizer_config = dict(grad_clip=None) 15 | # learning policy 16 | lr_config = dict(policy='step', step=[10]) 17 | runner = dict(type='EpochBasedRunner', max_epochs=12) 18 | 19 | log_config = 
dict( 20 | interval=50, 21 | hooks=[ 22 | dict(type='TextLoggerHook'), 23 | # dict(type='TensorboardLoggerHook') 24 | ]) 25 | 26 | input_feats = 263 27 | max_seq_len = 196 28 | latent_dim = 64 29 | time_embed_dim = 2048 30 | text_latent_dim = 256 31 | ff_size = 256 32 | num_heads = 8 33 | dropout = 0 34 | dataset_name = "human_ml3d" 35 | 36 | # model settings 37 | model = dict(type='MotionDiffusion', 38 | model=dict(type='FineMoGenTransformer', 39 | input_feats=input_feats, 40 | max_seq_len=max_seq_len, 41 | latent_dim=latent_dim * 8, 42 | time_embed_dim=time_embed_dim, 43 | num_layers=4, 44 | ca_block_cfg=dict(type='SAMI', 45 | latent_dim=latent_dim, 46 | text_latent_dim=text_latent_dim, 47 | num_heads=8, 48 | num_text_heads=1, 49 | num_experts=16, 50 | topk=2, 51 | gate_type='cosine_top', 52 | gate_noise=1.0, 53 | ffn_dim=ff_size, 54 | time_embed_dim=time_embed_dim, 55 | max_seq_len=max_seq_len, 56 | max_text_seq_len=77, 57 | temporal_comb=False, 58 | dropout=dropout), 59 | ffn_cfg=dict(latent_dim=latent_dim, 60 | ffn_dim=ff_size, 61 | dropout=dropout, 62 | time_embed_dim=time_embed_dim), 63 | text_encoder=dict(pretrained_model='clip', 64 | latent_dim=text_latent_dim, 65 | num_layers=2, 66 | ff_size=2048, 67 | dropout=dropout, 68 | use_text_proj=False), 69 | pose_encoder_cfg=dict(dataset_name=dataset_name, 70 | latent_dim=latent_dim), 71 | pose_decoder_cfg=dict(dataset_name=dataset_name, 72 | latent_dim=latent_dim, 73 | output_dim=input_feats), 74 | scale_func_cfg=dict(scale=6.5), 75 | moe_route_loss_weight=10.0, 76 | template_kl_loss_weight=0.0001, 77 | use_pos_embedding=False), 78 | loss_recon=dict(type='MSELoss', loss_weight=1, reduction='none'), 79 | diffusion_train=dict( 80 | beta_scheduler='linear', 81 | diffusion_steps=1000, 82 | model_mean_type='start_x', 83 | model_var_type='fixed_large', 84 | ), 85 | diffusion_test=dict( 86 | beta_scheduler='linear', 87 | diffusion_steps=1000, 88 | model_mean_type='start_x', 89 | model_var_type='fixed_large', 90 | respace='15,15,8,6,6', 91 | ), 92 | inference_type='ddim', 93 | loss_reduction='batch') 94 | data = dict(samples_per_gpu=64, 95 | train=dict(dataset=dict(ann_file='trainval_wo_mirror.txt'))) 96 | -------------------------------------------------------------------------------- /configs/interhuman/intergen_interhuman.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/datasets/inter_human_bs128.py'] 2 | # use_adversarial_train = True 3 | 4 | # checkpoint saving 5 | checkpoint_config = dict(interval=1) 6 | 7 | dist_params = dict(backend='nccl') 8 | log_level = 'INFO' 9 | load_from = None 10 | resume_from = None 11 | workflow = [('train', 1)] 12 | 13 | # optimizer 14 | optimizer = dict(type='Adam', lr=2e-4) 15 | optimizer_config = dict(grad_clip=None) 16 | # learning policy 17 | lr_config = dict(policy='step', step=[]) 18 | runner = dict(type='EpochBasedRunner', max_epochs=10) 19 | 20 | log_config = dict( 21 | interval=50, 22 | hooks=[ 23 | dict(type='TextLoggerHook'), 24 | # dict(type='TensorboardLoggerHook') 25 | ]) 26 | 27 | input_feats = 524 28 | max_seq_len = 300 29 | latent_dim = 512 30 | time_embed_dim = 2048 31 | text_latent_dim = 512 32 | ff_size = 1024 33 | num_heads = 8 34 | dropout = 0 35 | 36 | # model settings 37 | model = dict( 38 | type='MotionDiffusion', 39 | model=dict(type='InterGen', 40 | input_dim=262, 41 | latent_dim=1024, 42 | ff_size=2048, 43 | num_layers=8, 44 | num_heads=8, 45 | dropout=0.1, 46 | activation="gelu", 47 | cfg_weight=3.5), 48 | 
loss_recon=dict(type='MSELoss', loss_weight=1, reduction='none'), 49 | loss_reduction="batch", 50 | diffusion_train=dict( 51 | beta_scheduler='linear', 52 | diffusion_steps=1000, 53 | model_mean_type='start_x', 54 | model_var_type='fixed_large', 55 | ), 56 | diffusion_test=dict( 57 | beta_scheduler='linear', 58 | diffusion_steps=1000, 59 | model_mean_type='start_x', 60 | model_var_type='fixed_large', 61 | respace='15,15,8,6,6', 62 | # respace='30,30,16,12,12', 63 | ), 64 | inference_type='ddim') 65 | data = dict(samples_per_gpu=64) 66 | -------------------------------------------------------------------------------- /configs/interhuman/momatmogen_interhuman.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/datasets/inter_human_bs128.py'] 2 | 3 | # checkpoint saving 4 | checkpoint_config = dict(interval=1) 5 | 6 | dist_params = dict(backend='nccl') 7 | log_level = 'INFO' 8 | load_from = None 9 | resume_from = None 10 | workflow = [('train', 1)] 11 | 12 | # optimizer 13 | optimizer = dict(type='Adam', lr=2e-4) 14 | optimizer_config = dict(grad_clip=None) 15 | # learning policy 16 | lr_config = dict(policy='CosineAnnealing', min_lr_ratio=2e-5, by_epoch=False) 17 | runner = dict(type='EpochBasedRunner', max_epochs=20) 18 | 19 | log_config = dict( 20 | interval=50, 21 | hooks=[ 22 | dict(type='TextLoggerHook'), 23 | # dict(type='TensorboardLoggerHook') 24 | ]) 25 | 26 | input_feats = 262 27 | max_seq_len = 300 28 | latent_dim = 512 29 | time_embed_dim = 2048 30 | text_latent_dim = 256 31 | ff_size = 1024 32 | num_heads = 8 33 | dropout = 0 34 | 35 | # model settings 36 | model = dict( 37 | type='MotionDiffusion', 38 | model=dict( 39 | type='MoMatMoGenTransformer', 40 | input_feats=input_feats, 41 | max_seq_len=max_seq_len, 42 | latent_dim=latent_dim, 43 | time_embed_dim=time_embed_dim, 44 | num_layers=4, 45 | ca_block_cfg=dict(type='DualSemanticsModulatedAttention', 46 | latent_dim=latent_dim, 47 | text_latent_dim=text_latent_dim, 48 | num_heads=num_heads, 49 | dropout=dropout, 50 | time_embed_dim=time_embed_dim), 51 | ffn_cfg=dict(latent_dim=latent_dim, 52 | ffn_dim=ff_size, 53 | dropout=dropout, 54 | time_embed_dim=time_embed_dim), 55 | text_encoder=dict(pretrained_model='clip', 56 | latent_dim=text_latent_dim, 57 | num_layers=2, 58 | ff_size=2048, 59 | dropout=dropout, 60 | use_text_proj=False), 61 | retrieval_cfg=dict( 62 | num_retrieval=1, 63 | stride=4, 64 | num_layers=2, 65 | num_motion_layers=2, 66 | kinematic_coef=0.1, 67 | topk=1, 68 | retrieval_file='data/database/interhuman_text_train.npz', 69 | latent_dim=latent_dim, 70 | output_dim=latent_dim, 71 | max_seq_len=max_seq_len, 72 | num_heads=num_heads, 73 | ff_size=ff_size, 74 | dropout=dropout, 75 | ffn_cfg=dict( 76 | latent_dim=latent_dim, 77 | ffn_dim=ff_size, 78 | dropout=dropout, 79 | ), 80 | sa_block_cfg=dict(type='EfficientSelfAttention', 81 | latent_dim=latent_dim, 82 | num_heads=num_heads, 83 | dropout=dropout), 84 | ), 85 | scale_func_cfg=dict(coarse_scale=5.5, 86 | both_coef=0.52351, 87 | text_coef=-0.28419, 88 | retr_coef=2.39872), 89 | # post_process_cfg=dict( 90 | # unnormalized_infer=True, 91 | # mean_path='data/datasets/inter_human/mean.npy', 92 | # std_path='data/datasets/inter_human/std.npy' 93 | # ) 94 | ), 95 | loss_recon=dict(type='MSELoss', loss_weight=1, reduction='none'), 96 | loss_reduction="batch", 97 | diffusion_train=dict( 98 | beta_scheduler='linear', 99 | diffusion_steps=1000, 100 | model_mean_type='start_x', 101 | model_var_type='fixed_large', 
102 | ), 103 | diffusion_test=dict( 104 | beta_scheduler='linear', 105 | diffusion_steps=1000, 106 | model_mean_type='start_x', 107 | model_var_type='fixed_large', 108 | respace='15,15,8,6,6', 109 | ), 110 | inference_type='ddim') 111 | data = dict(samples_per_gpu=64) 112 | -------------------------------------------------------------------------------- /configs/interhuman/motiondiffuse_interhuman.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/datasets/inter_human_bs128.py'] 2 | 3 | # checkpoint saving 4 | checkpoint_config = dict(interval=1) 5 | 6 | dist_params = dict(backend='nccl') 7 | log_level = 'INFO' 8 | load_from = None 9 | resume_from = None 10 | workflow = [('train', 1)] 11 | 12 | # optimizer 13 | optimizer = dict(type='Adam', lr=2e-4) 14 | optimizer_config = dict(grad_clip=None) 15 | # learning policy 16 | lr_config = dict(policy='step', step=[]) 17 | runner = dict(type='EpochBasedRunner', max_epochs=20) 18 | 19 | log_config = dict( 20 | interval=50, 21 | hooks=[ 22 | dict(type='TextLoggerHook'), 23 | # dict(type='TensorboardLoggerHook') 24 | ]) 25 | 26 | input_feats = 524 27 | max_seq_len = 300 28 | latent_dim = 512 29 | time_embed_dim = 2048 30 | text_latent_dim = 256 31 | ff_size = 1024 32 | num_heads = 8 33 | dropout = 0 34 | # model settings 35 | model = dict(type='MotionDiffusion', 36 | model=dict(type='MotionDiffuseTransformer', 37 | input_feats=input_feats, 38 | max_seq_len=max_seq_len, 39 | latent_dim=latent_dim, 40 | time_embed_dim=time_embed_dim, 41 | num_layers=8, 42 | sa_block_cfg=dict(type='EfficientSelfAttention', 43 | latent_dim=latent_dim, 44 | num_heads=num_heads, 45 | dropout=dropout, 46 | time_embed_dim=time_embed_dim), 47 | ca_block_cfg=dict(type='EfficientCrossAttention', 48 | latent_dim=latent_dim, 49 | text_latent_dim=text_latent_dim, 50 | num_heads=num_heads, 51 | dropout=dropout, 52 | time_embed_dim=time_embed_dim), 53 | ffn_cfg=dict(latent_dim=latent_dim, 54 | ffn_dim=ff_size, 55 | dropout=dropout, 56 | time_embed_dim=time_embed_dim), 57 | text_encoder=dict(pretrained_model='clip', 58 | latent_dim=text_latent_dim, 59 | num_layers=4, 60 | num_heads=4, 61 | ff_size=2048, 62 | dropout=dropout, 63 | use_text_proj=True), 64 | post_process_cfg=dict( 65 | unnormalized_infer=True, 66 | mean_path='data/datasets/inter_human/mean.npy', 67 | std_path='data/datasets/inter_human/std.npy')), 68 | loss_recon=dict(type='MSELoss', loss_weight=1, reduction='none'), 69 | loss_reduction="batch", 70 | diffusion_train=dict( 71 | beta_scheduler='linear', 72 | diffusion_steps=1000, 73 | model_mean_type='start_x', 74 | model_var_type='fixed_small', 75 | ), 76 | diffusion_test=dict( 77 | beta_scheduler='linear', 78 | diffusion_steps=1000, 79 | model_mean_type='start_x', 80 | model_var_type='fixed_small', 81 | ), 82 | inference_type='ddpm') 83 | data = dict(samples_per_gpu=64) 84 | -------------------------------------------------------------------------------- /configs/interhuman/remodiffuse_interhuman.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/datasets/inter_human_bs128.py'] 2 | 3 | # checkpoint saving 4 | checkpoint_config = dict(interval=1) 5 | 6 | dist_params = dict(backend='nccl') 7 | log_level = 'INFO' 8 | load_from = None 9 | resume_from = None 10 | workflow = [('train', 1)] 11 | 12 | # optimizer 13 | optimizer = dict(type='Adam', lr=2e-4) 14 | optimizer_config = dict(grad_clip=None) 15 | # learning policy 16 | lr_config = 
dict(policy='CosineAnnealing', min_lr_ratio=2e-5, by_epoch=False) 17 | runner = dict(type='EpochBasedRunner', max_epochs=40) 18 | 19 | log_config = dict( 20 | interval=50, 21 | hooks=[ 22 | dict(type='TextLoggerHook'), 23 | # dict(type='TensorboardLoggerHook') 24 | ]) 25 | 26 | input_feats = 524 27 | max_seq_len = 300 28 | latent_dim = 512 29 | time_embed_dim = 2048 30 | text_latent_dim = 256 31 | ff_size = 1024 32 | num_heads = 8 33 | dropout = 0 34 | 35 | # model settings 36 | model = dict(type='MotionDiffusion', 37 | model=dict( 38 | type='ReMoDiffuseTransformer', 39 | input_feats=input_feats, 40 | max_seq_len=max_seq_len, 41 | latent_dim=latent_dim, 42 | time_embed_dim=time_embed_dim, 43 | num_layers=4, 44 | ca_block_cfg=dict(type='SemanticsModulatedAttention', 45 | latent_dim=latent_dim, 46 | text_latent_dim=text_latent_dim, 47 | num_heads=num_heads, 48 | dropout=dropout, 49 | time_embed_dim=time_embed_dim), 50 | ffn_cfg=dict(latent_dim=latent_dim, 51 | ffn_dim=ff_size, 52 | dropout=dropout, 53 | time_embed_dim=time_embed_dim), 54 | text_encoder=dict(pretrained_model='clip', 55 | latent_dim=text_latent_dim, 56 | num_layers=2, 57 | ff_size=2048, 58 | dropout=dropout, 59 | use_text_proj=False), 60 | retrieval_cfg=dict( 61 | num_retrieval=1, 62 | stride=4, 63 | num_layers=2, 64 | num_motion_layers=2, 65 | kinematic_coef=0.1, 66 | topk=1, 67 | retrieval_file='data/database/interhuman_text_train.npz', 68 | latent_dim=latent_dim, 69 | output_dim=latent_dim, 70 | max_seq_len=max_seq_len, 71 | num_heads=num_heads, 72 | ff_size=ff_size, 73 | dropout=dropout, 74 | ffn_cfg=dict( 75 | latent_dim=latent_dim, 76 | ffn_dim=ff_size, 77 | dropout=dropout, 78 | ), 79 | sa_block_cfg=dict(type='EfficientSelfAttention', 80 | latent_dim=latent_dim, 81 | num_heads=num_heads, 82 | dropout=dropout), 83 | ), 84 | scale_func_cfg=dict(coarse_scale=6.5, 85 | both_coef=0.52351, 86 | text_coef=-0.28419, 87 | retr_coef=2.39872), 88 | post_process_cfg=dict( 89 | unnormalized_infer=True, 90 | mean_path='data/datasets/inter_human/mean.npy', 91 | std_path='data/datasets/inter_human/std.npy')), 92 | loss_recon=dict(type='MSELoss', loss_weight=1, reduction='none'), 93 | loss_reduction="batch", 94 | diffusion_train=dict( 95 | beta_scheduler='linear', 96 | diffusion_steps=1000, 97 | model_mean_type='start_x', 98 | model_var_type='fixed_large', 99 | ), 100 | diffusion_test=dict( 101 | beta_scheduler='linear', 102 | diffusion_steps=1000, 103 | model_mean_type='start_x', 104 | model_var_type='fixed_large', 105 | respace='15,15,8,6,6', 106 | ), 107 | inference_type='ddim') 108 | data = dict(samples_per_gpu=64) 109 | -------------------------------------------------------------------------------- /configs/mdm/mdm_kit.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/datasets/kit_ml_bs128.py'] 2 | 3 | # checkpoint saving 4 | checkpoint_config = dict(interval=50000) 5 | 6 | dist_params = dict(backend='nccl') 7 | log_level = 'INFO' 8 | load_from = None 9 | resume_from = None 10 | workflow = [('train', 1)] 11 | 12 | # optimizer 13 | optimizer = dict(type='Adam', lr=1e-4) 14 | optimizer_config = dict(grad_clip=None) 15 | # learning policy 16 | lr_config = dict(policy='step', step=[]) 17 | runner = dict(type='IterBasedRunner', max_iters=500000) 18 | 19 | log_config = dict( 20 | interval=50, 21 | hooks=[ 22 | dict(type='TextLoggerHook'), 23 | # dict(type='TensorboardLoggerHook') 24 | ]) 25 | 26 | input_feats = 251 27 | max_seq_len = 196 28 | latent_dim = 512 29 | 
time_embed_dim = 2048 30 | text_latent_dim = 256 31 | ff_size = 1024 32 | num_layers = 8 33 | num_heads = 4 34 | dropout = 0.1 35 | cond_mask_prob = 0.1 36 | # model settings 37 | model = dict(type='MotionDiffusion', 38 | model=dict(type='MDMTransformer', 39 | input_feats=input_feats, 40 | latent_dim=latent_dim, 41 | ff_size=ff_size, 42 | num_layers=num_layers, 43 | num_heads=num_heads, 44 | dropout=dropout, 45 | time_embed_dim=time_embed_dim, 46 | cond_mask_prob=cond_mask_prob, 47 | guide_scale=2.5, 48 | clip_version='ViT-B/32'), 49 | loss_recon=dict(type='MSELoss', loss_weight=1, reduction='none'), 50 | diffusion_train=dict( 51 | beta_scheduler='cosine', 52 | diffusion_steps=1000, 53 | model_mean_type='start_x', 54 | model_var_type='fixed_small', 55 | ), 56 | diffusion_test=dict( 57 | beta_scheduler='cosine', 58 | diffusion_steps=1000, 59 | model_mean_type='start_x', 60 | model_var_type='fixed_small', 61 | ), 62 | inference_type='ddpm') 63 | data = dict(samples_per_gpu=64) 64 | -------------------------------------------------------------------------------- /configs/mdm/mdm_t2m.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/datasets/human_ml3d_bs128.py'] 2 | 3 | # checkpoint saving 4 | checkpoint_config = dict(interval=50000) 5 | 6 | dist_params = dict(backend='nccl') 7 | log_level = 'INFO' 8 | load_from = None 9 | resume_from = None 10 | workflow = [('train', 1)] 11 | 12 | # optimizer 13 | optimizer = dict(type='Adam', lr=1e-4) 14 | optimizer_config = dict(grad_clip=None) 15 | # learning policy 16 | lr_config = dict(policy='step', step=[]) 17 | runner = dict(type='IterBasedRunner', max_iters=500000) 18 | 19 | log_config = dict( 20 | interval=50, 21 | hooks=[ 22 | dict(type='TextLoggerHook'), 23 | # dict(type='TensorboardLoggerHook') 24 | ]) 25 | 26 | input_feats = 263 27 | max_seq_len = 196 28 | latent_dim = 512 29 | time_embed_dim = 2048 30 | text_latent_dim = 256 31 | ff_size = 1024 32 | num_layers = 8 33 | num_heads = 4 34 | dropout = 0.1 35 | cond_mask_prob = 0.1 36 | # model settings 37 | model = dict(type='MotionDiffusion', 38 | model=dict(type='MDMTransformer', 39 | input_feats=input_feats, 40 | latent_dim=latent_dim, 41 | ff_size=ff_size, 42 | num_layers=num_layers, 43 | num_heads=num_heads, 44 | dropout=dropout, 45 | time_embed_dim=time_embed_dim, 46 | cond_mask_prob=cond_mask_prob, 47 | guide_scale=2.5, 48 | clip_version='ViT-B/32'), 49 | loss_recon=dict(type='MSELoss', loss_weight=1, reduction='none'), 50 | diffusion_train=dict( 51 | beta_scheduler='cosine', 52 | diffusion_steps=1000, 53 | model_mean_type='start_x', 54 | model_var_type='fixed_small', 55 | ), 56 | diffusion_test=dict( 57 | beta_scheduler='cosine', 58 | diffusion_steps=1000, 59 | model_mean_type='start_x', 60 | model_var_type='fixed_small', 61 | ), 62 | inference_type='ddpm') 63 | data = dict(samples_per_gpu=64) 64 | -------------------------------------------------------------------------------- /configs/mdm/mdm_t2m_official.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/datasets/human_ml3d_bs128.py'] 2 | 3 | # checkpoint saving 4 | checkpoint_config = dict(interval=50000) 5 | 6 | dist_params = dict(backend='nccl') 7 | log_level = 'INFO' 8 | load_from = None 9 | resume_from = None 10 | workflow = [('train', 1)] 11 | 12 | # optimizer 13 | optimizer = dict(type='Adam', lr=1e-4) 14 | optimizer_config = dict(grad_clip=None) 15 | # learning policy 16 | lr_config = dict(policy='step', 
step=[]) 17 | runner = dict(type='IterBasedRunner', max_iters=600000) 18 | 19 | log_config = dict( 20 | interval=50, 21 | hooks=[ 22 | dict(type='TextLoggerHook'), 23 | # dict(type='TensorboardLoggerHook') 24 | ]) 25 | 26 | input_feats = 263 27 | max_seq_len = 196 28 | latent_dim = 512 29 | time_embed_dim = 2048 30 | text_latent_dim = 256 31 | ff_size = 1024 32 | num_layers = 8 33 | num_heads = 4 34 | dropout = 0.1 35 | cond_mask_prob = 0.1 36 | # model settings 37 | model = dict(type='MotionDiffusion', 38 | model=dict(type='MDMTransformer', 39 | input_feats=input_feats, 40 | latent_dim=latent_dim, 41 | ff_size=ff_size, 42 | num_layers=num_layers, 43 | num_heads=num_heads, 44 | dropout=dropout, 45 | time_embed_dim=time_embed_dim, 46 | cond_mask_prob=cond_mask_prob, 47 | guide_scale=2.5, 48 | clip_version='ViT-B/32', 49 | use_official_ckpt=True), 50 | loss_recon=dict(type='MSELoss', loss_weight=1, reduction='none'), 51 | diffusion_train=dict( 52 | beta_scheduler='cosine', 53 | diffusion_steps=1000, 54 | model_mean_type='start_x', 55 | model_var_type='fixed_small', 56 | ), 57 | diffusion_test=dict( 58 | beta_scheduler='cosine', 59 | diffusion_steps=1000, 60 | model_mean_type='start_x', 61 | model_var_type='fixed_small', 62 | ), 63 | inference_type='ddpm') 64 | -------------------------------------------------------------------------------- /configs/motiondiffuse/motiondiffuse_kit.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/datasets/kit_ml_bs128.py'] 2 | 3 | # checkpoint saving 4 | checkpoint_config = dict(interval=1) 5 | 6 | dist_params = dict(backend='nccl') 7 | log_level = 'INFO' 8 | load_from = None 9 | resume_from = None 10 | workflow = [('train', 1)] 11 | 12 | # optimizer 13 | optimizer = dict(type='Adam', lr=2e-4) 14 | optimizer_config = dict(grad_clip=None) 15 | # learning policy 16 | lr_config = dict(policy='step', step=[]) 17 | runner = dict(type='EpochBasedRunner', max_epochs=50) 18 | 19 | log_config = dict( 20 | interval=50, 21 | hooks=[ 22 | dict(type='TextLoggerHook'), 23 | # dict(type='TensorboardLoggerHook') 24 | ]) 25 | 26 | input_feats = 251 27 | max_seq_len = 196 28 | latent_dim = 512 29 | time_embed_dim = 2048 30 | text_latent_dim = 256 31 | ff_size = 1024 32 | num_heads = 8 33 | dropout = 0 34 | # model settings 35 | model = dict(type='MotionDiffusion', 36 | model=dict(type='MotionDiffuseTransformer', 37 | input_feats=input_feats, 38 | max_seq_len=max_seq_len, 39 | latent_dim=latent_dim, 40 | time_embed_dim=time_embed_dim, 41 | num_layers=8, 42 | sa_block_cfg=dict(type='EfficientSelfAttention', 43 | latent_dim=latent_dim, 44 | num_heads=num_heads, 45 | dropout=dropout, 46 | time_embed_dim=time_embed_dim), 47 | ca_block_cfg=dict(type='EfficientCrossAttention', 48 | latent_dim=latent_dim, 49 | text_latent_dim=text_latent_dim, 50 | num_heads=num_heads, 51 | dropout=dropout, 52 | time_embed_dim=time_embed_dim), 53 | ffn_cfg=dict(latent_dim=latent_dim, 54 | ffn_dim=ff_size, 55 | dropout=dropout, 56 | time_embed_dim=time_embed_dim), 57 | text_encoder=dict(pretrained_model='clip', 58 | latent_dim=text_latent_dim, 59 | num_layers=4, 60 | num_heads=4, 61 | ff_size=2048, 62 | dropout=dropout, 63 | use_text_proj=True)), 64 | loss_recon=dict(type='MSELoss', loss_weight=1, reduction='none'), 65 | diffusion_train=dict( 66 | beta_scheduler='linear', 67 | diffusion_steps=1000, 68 | model_mean_type='epsilon', 69 | model_var_type='fixed_small', 70 | ), 71 | diffusion_test=dict( 72 | beta_scheduler='linear', 73 | 
diffusion_steps=1000, 74 | model_mean_type='epsilon', 75 | model_var_type='fixed_small', 76 | ), 77 | inference_type='ddpm') 78 | -------------------------------------------------------------------------------- /configs/motiondiffuse/motiondiffuse_t2m.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/datasets/human_ml3d_bs128.py'] 2 | 3 | # checkpoint saving 4 | checkpoint_config = dict(interval=1) 5 | 6 | dist_params = dict(backend='nccl') 7 | log_level = 'INFO' 8 | load_from = None 9 | resume_from = None 10 | workflow = [('train', 1)] 11 | 12 | # optimizer 13 | optimizer = dict(type='Adam', lr=2e-4) 14 | optimizer_config = dict(grad_clip=None) 15 | # learning policy 16 | lr_config = dict(policy='step', step=[]) 17 | runner = dict(type='EpochBasedRunner', max_epochs=50) 18 | 19 | log_config = dict( 20 | interval=50, 21 | hooks=[ 22 | dict(type='TextLoggerHook'), 23 | # dict(type='TensorboardLoggerHook') 24 | ]) 25 | 26 | input_feats = 263 27 | max_seq_len = 196 28 | latent_dim = 512 29 | time_embed_dim = 2048 30 | text_latent_dim = 256 31 | ff_size = 1024 32 | num_heads = 8 33 | dropout = 0 34 | # model settings 35 | model = dict(type='MotionDiffusion', 36 | model=dict(type='MotionDiffuseTransformer', 37 | input_feats=input_feats, 38 | max_seq_len=max_seq_len, 39 | latent_dim=latent_dim, 40 | time_embed_dim=time_embed_dim, 41 | num_layers=8, 42 | sa_block_cfg=dict(type='EfficientSelfAttention', 43 | latent_dim=latent_dim, 44 | num_heads=num_heads, 45 | dropout=dropout, 46 | time_embed_dim=time_embed_dim), 47 | ca_block_cfg=dict(type='EfficientCrossAttention', 48 | latent_dim=latent_dim, 49 | text_latent_dim=text_latent_dim, 50 | num_heads=num_heads, 51 | dropout=dropout, 52 | time_embed_dim=time_embed_dim), 53 | ffn_cfg=dict(latent_dim=latent_dim, 54 | ffn_dim=ff_size, 55 | dropout=dropout, 56 | time_embed_dim=time_embed_dim), 57 | text_encoder=dict(pretrained_model='clip', 58 | latent_dim=text_latent_dim, 59 | num_layers=4, 60 | num_heads=4, 61 | ff_size=2048, 62 | dropout=dropout, 63 | use_text_proj=True)), 64 | loss_recon=dict(type='MSELoss', loss_weight=1, reduction='none'), 65 | diffusion_train=dict( 66 | beta_scheduler='linear', 67 | diffusion_steps=1000, 68 | model_mean_type='epsilon', 69 | model_var_type='fixed_small', 70 | ), 71 | diffusion_test=dict( 72 | beta_scheduler='linear', 73 | diffusion_steps=1000, 74 | model_mean_type='epsilon', 75 | model_var_type='fixed_small', 76 | ), 77 | inference_type='ddpm') 78 | data = dict(samples_per_gpu=128) 79 | -------------------------------------------------------------------------------- /configs/remodiffuse/remodiffuse_kit.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/datasets/kit_ml_bs128.py'] 2 | 3 | # checkpoint saving 4 | checkpoint_config = dict(interval=1) 5 | 6 | dist_params = dict(backend='nccl') 7 | log_level = 'INFO' 8 | load_from = None 9 | resume_from = None 10 | workflow = [('train', 1)] 11 | 12 | # optimizer 13 | optimizer = dict(type='Adam', lr=2e-4) 14 | optimizer_config = dict(grad_clip=None) 15 | # learning policy 16 | lr_config = dict(policy='CosineAnnealing', min_lr_ratio=2e-5, by_epoch=False) 17 | runner = dict(type='EpochBasedRunner', max_epochs=20) 18 | 19 | log_config = dict( 20 | interval=50, 21 | hooks=[ 22 | dict(type='TextLoggerHook'), 23 | # dict(type='TensorboardLoggerHook') 24 | ]) 25 | 26 | input_feats = 251 27 | max_seq_len = 196 28 | latent_dim = 512 29 | time_embed_dim = 2048 30 | 
text_latent_dim = 256 31 | ff_size = 1024 32 | num_heads = 8 33 | dropout = 0 34 | 35 | # model settings 36 | model = dict(type='MotionDiffusion', 37 | model=dict(type='ReMoDiffuseTransformer', 38 | input_feats=input_feats, 39 | max_seq_len=max_seq_len, 40 | latent_dim=latent_dim, 41 | time_embed_dim=time_embed_dim, 42 | num_layers=4, 43 | ca_block_cfg=dict(type='SemanticsModulatedAttention', 44 | latent_dim=latent_dim, 45 | text_latent_dim=text_latent_dim, 46 | num_heads=num_heads, 47 | dropout=dropout, 48 | time_embed_dim=time_embed_dim), 49 | ffn_cfg=dict(latent_dim=latent_dim, 50 | ffn_dim=ff_size, 51 | dropout=dropout, 52 | time_embed_dim=time_embed_dim), 53 | text_encoder=dict(pretrained_model='clip', 54 | latent_dim=text_latent_dim, 55 | num_layers=2, 56 | ff_size=2048, 57 | dropout=dropout, 58 | use_text_proj=False), 59 | retrieval_cfg=dict( 60 | num_retrieval=2, 61 | stride=4, 62 | num_layers=2, 63 | num_motion_layers=2, 64 | kinematic_coef=0.1, 65 | topk=2, 66 | retrieval_file='data/database/kit_text_train.npz', 67 | latent_dim=latent_dim, 68 | output_dim=latent_dim, 69 | max_seq_len=max_seq_len, 70 | num_heads=num_heads, 71 | ff_size=ff_size, 72 | dropout=dropout, 73 | ffn_cfg=dict( 74 | latent_dim=latent_dim, 75 | ffn_dim=ff_size, 76 | dropout=dropout, 77 | ), 78 | sa_block_cfg=dict(type='EfficientSelfAttention', 79 | latent_dim=latent_dim, 80 | num_heads=num_heads, 81 | dropout=dropout), 82 | ), 83 | scale_func_cfg=dict(coarse_scale=4.0, 84 | both_coef=0.78123, 85 | text_coef=0.39284, 86 | retr_coef=-0.12475)), 87 | loss_recon=dict(type='MSELoss', loss_weight=1, reduction='none'), 88 | diffusion_train=dict( 89 | beta_scheduler='linear', 90 | diffusion_steps=1000, 91 | model_mean_type='start_x', 92 | model_var_type='fixed_large', 93 | ), 94 | diffusion_test=dict( 95 | beta_scheduler='linear', 96 | diffusion_steps=1000, 97 | model_mean_type='start_x', 98 | model_var_type='fixed_large', 99 | respace='15,15,8,6,6', 100 | ), 101 | inference_type='ddim') 102 | -------------------------------------------------------------------------------- /configs/remodiffuse/remodiffuse_t2m.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/datasets/human_ml3d_bs128.py'] 2 | 3 | # checkpoint saving 4 | checkpoint_config = dict(interval=1) 5 | 6 | dist_params = dict(backend='nccl') 7 | log_level = 'INFO' 8 | load_from = None 9 | resume_from = None 10 | workflow = [('train', 1)] 11 | 12 | # optimizer 13 | optimizer = dict(type='Adam', lr=2e-4) 14 | optimizer_config = dict(grad_clip=None) 15 | # learning policy 16 | lr_config = dict(policy='CosineAnnealing', min_lr_ratio=2e-5, by_epoch=False) 17 | runner = dict(type='EpochBasedRunner', max_epochs=40) 18 | 19 | log_config = dict( 20 | interval=50, 21 | hooks=[ 22 | dict(type='TextLoggerHook'), 23 | # dict(type='TensorboardLoggerHook') 24 | ]) 25 | 26 | input_feats = 263 27 | max_seq_len = 196 28 | latent_dim = 512 29 | time_embed_dim = 2048 30 | text_latent_dim = 256 31 | ff_size = 1024 32 | num_heads = 8 33 | dropout = 0 34 | 35 | # model settings 36 | model = dict(type='MotionDiffusion', 37 | model=dict(type='ReMoDiffuseTransformer', 38 | input_feats=input_feats, 39 | max_seq_len=max_seq_len, 40 | latent_dim=latent_dim, 41 | time_embed_dim=time_embed_dim, 42 | num_layers=4, 43 | ca_block_cfg=dict(type='SemanticsModulatedAttention', 44 | latent_dim=latent_dim, 45 | text_latent_dim=text_latent_dim, 46 | num_heads=num_heads, 47 | dropout=dropout, 48 | time_embed_dim=time_embed_dim), 49 | 
ffn_cfg=dict(latent_dim=latent_dim, 50 | ffn_dim=ff_size, 51 | dropout=dropout, 52 | time_embed_dim=time_embed_dim), 53 | text_encoder=dict(pretrained_model='clip', 54 | latent_dim=text_latent_dim, 55 | num_layers=2, 56 | ff_size=2048, 57 | dropout=dropout, 58 | use_text_proj=False), 59 | retrieval_cfg=dict( 60 | num_retrieval=2, 61 | stride=4, 62 | num_layers=2, 63 | num_motion_layers=2, 64 | kinematic_coef=0.1, 65 | topk=2, 66 | retrieval_file='data/database/t2m_text_train.npz', 67 | latent_dim=latent_dim, 68 | output_dim=latent_dim, 69 | max_seq_len=max_seq_len, 70 | num_heads=num_heads, 71 | ff_size=ff_size, 72 | dropout=dropout, 73 | ffn_cfg=dict( 74 | latent_dim=latent_dim, 75 | ffn_dim=ff_size, 76 | dropout=dropout, 77 | ), 78 | sa_block_cfg=dict(type='EfficientSelfAttention', 79 | latent_dim=latent_dim, 80 | num_heads=num_heads, 81 | dropout=dropout), 82 | ), 83 | scale_func_cfg=dict(coarse_scale=6.5, 84 | both_coef=0.52351, 85 | text_coef=-0.28419, 86 | retr_coef=2.39872)), 87 | loss_recon=dict(type='MSELoss', loss_weight=1, reduction='none'), 88 | diffusion_train=dict( 89 | beta_scheduler='linear', 90 | diffusion_steps=1000, 91 | model_mean_type='start_x', 92 | model_var_type='fixed_large', 93 | ), 94 | diffusion_test=dict( 95 | beta_scheduler='linear', 96 | diffusion_steps=1000, 97 | model_mean_type='start_x', 98 | model_var_type='fixed_large', 99 | respace='15,15,8,6,6', 100 | ), 101 | inference_type='ddim') 102 | -------------------------------------------------------------------------------- /imgs/pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mingyuan-zhang/FineMoGen/d5697b5aa6ad2de0301e77a251c7d28fb177ee23/imgs/pipeline.png -------------------------------------------------------------------------------- /imgs/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mingyuan-zhang/FineMoGen/d5697b5aa6ad2de0301e77a251c7d28fb177ee23/imgs/teaser.png -------------------------------------------------------------------------------- /mogen/__init__.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import mmcv 4 | from packaging.version import parse 5 | 6 | from .version import __version__ 7 | 8 | 9 | def digit_version(version_str: str, length: int = 4): 10 | """Convert a version string into a tuple of integers. 11 | This method is usually used for comparing two versions. For pre-release 12 | versions: alpha < beta < rc. 13 | Args: 14 | version_str (str): The version string. 15 | length (int): The maximum number of version levels. Default: 4. 16 | Returns: 17 | tuple[int]: The version info in digits (integers). 
18 | """ 19 | version = parse(version_str) 20 | assert version.release, f'failed to parse version {version_str}' 21 | release = list(version.release) 22 | release = release[:length] 23 | if len(release) < length: 24 | release = release + [0] * (length - len(release)) 25 | if version.is_prerelease: 26 | mapping = {'a': -3, 'b': -2, 'rc': -1} 27 | val = -4 28 | # version.pre can be None 29 | if version.pre: 30 | if version.pre[0] not in mapping: 31 | warnings.warn(f'unknown prerelease version {version.pre[0]}, ' 32 | 'version checking may go wrong') 33 | else: 34 | val = mapping[version.pre[0]] 35 | release.extend([val, version.pre[-1]]) 36 | else: 37 | release.extend([val, 0]) 38 | 39 | elif version.is_postrelease: 40 | release.extend([1, version.post]) 41 | else: 42 | release.extend([0, 0]) 43 | return tuple(release) 44 | 45 | 46 | mmcv_minimum_version = '1.4.2' 47 | mmcv_maximum_version = '1.9.0' 48 | mmcv_version = digit_version(mmcv.__version__) 49 | 50 | 51 | assert (mmcv_version >= digit_version(mmcv_minimum_version) 52 | and mmcv_version <= digit_version(mmcv_maximum_version)), \ 53 | f'MMCV=={mmcv.__version__} is used but incompatible. ' \ 54 | f'Please install mmcv>={mmcv_minimum_version}, <={mmcv_maximum_version}.' 55 | 56 | __all__ = ['__version__', 'digit_version'] 57 | -------------------------------------------------------------------------------- /mogen/apis/__init__.py: -------------------------------------------------------------------------------- 1 | from mogen.apis.test import (collect_results_cpu, collect_results_gpu, 2 | multi_gpu_test, single_gpu_test) 3 | from mogen.apis.train import set_random_seed, train_model 4 | 5 | __all__ = [ 6 | 'collect_results_cpu', 'collect_results_gpu', 'multi_gpu_test', 7 | 'single_gpu_test', 'set_random_seed', 'train_model' 8 | ] 9 | -------------------------------------------------------------------------------- /mogen/apis/test.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import pickle 3 | import shutil 4 | import tempfile 5 | import time 6 | 7 | import mmcv 8 | import torch 9 | import torch.distributed as dist 10 | from mmcv.runner import get_dist_info 11 | 12 | 13 | def single_gpu_test(model, data_loader): 14 | """Test with single gpu.""" 15 | model.eval() 16 | results = [] 17 | dataset = data_loader.dataset 18 | prog_bar = mmcv.ProgressBar(len(dataset)) 19 | for i, data in enumerate(data_loader): 20 | with torch.no_grad(): 21 | result = model(return_loss=False, **data) 22 | 23 | batch_size = len(result) 24 | if isinstance(result, list): 25 | results.extend(result) 26 | else: 27 | results.append(result) 28 | 29 | batch_size = data['motion'].size(0) 30 | for _ in range(batch_size): 31 | prog_bar.update() 32 | return results 33 | 34 | 35 | def multi_gpu_test(model, data_loader, tmpdir=None, gpu_collect=False): 36 | """Test model with multiple gpus. 37 | This method tests model with multiple gpus and collects the results 38 | under two different modes: gpu and cpu modes. By setting 'gpu_collect=True' 39 | it encodes results to gpu tensors and use gpu communication for results 40 | collection. On cpu mode it saves the results on different gpus to 'tmpdir' 41 | and collects them by the rank 0 worker. 42 | Args: 43 | model (nn.Module): Model to be tested. 44 | data_loader (nn.Dataloader): Pytorch data loader. 45 | tmpdir (str): Path of directory to save the temporary results from 46 | different gpus under cpu mode. 
47 | gpu_collect (bool): Option to use either gpu or cpu to collect results. 48 | Returns: 49 | list: The prediction results. 50 | """ 51 | model.eval() 52 | results = [] 53 | dataset = data_loader.dataset 54 | rank, world_size = get_dist_info() 55 | if rank == 0: 56 | # Check if tmpdir is valid for cpu_collect 57 | if (not gpu_collect) and (tmpdir is not None and osp.exists(tmpdir)): 58 | raise OSError((f'The tmpdir {tmpdir} already exists.', 59 | ' Since tmpdir will be deleted after testing,', 60 | ' please make sure you specify an empty one.')) 61 | prog_bar = mmcv.ProgressBar(len(dataset)) 62 | time.sleep(2) # This line can prevent deadlock problem in some cases. 63 | for i, data in enumerate(data_loader): 64 | with torch.no_grad(): 65 | result = model(return_loss=False, **data) 66 | if isinstance(result, list): 67 | results.extend(result) 68 | else: 69 | results.append(result) 70 | 71 | if rank == 0: 72 | batch_size = data['motion'].size(0) 73 | for _ in range(batch_size * world_size): 74 | prog_bar.update() 75 | 76 | # collect results from all ranks 77 | if gpu_collect: 78 | results = collect_results_gpu(results, len(dataset)) 79 | else: 80 | results = collect_results_cpu(results, len(dataset), tmpdir) 81 | return results 82 | 83 | 84 | def collect_results_cpu(result_part, size, tmpdir=None): 85 | """Collect results in cpu.""" 86 | rank, world_size = get_dist_info() 87 | # create a tmp dir if it is not specified 88 | if tmpdir is None: 89 | MAX_LEN = 512 90 | # 32 is whitespace 91 | dir_tensor = torch.full((MAX_LEN, ), 92 | 32, 93 | dtype=torch.uint8, 94 | device='cuda') 95 | if rank == 0: 96 | mmcv.mkdir_or_exist('.dist_test') 97 | tmpdir = tempfile.mkdtemp(dir='.dist_test') 98 | tmpdir = torch.tensor(bytearray(tmpdir.encode()), 99 | dtype=torch.uint8, 100 | device='cuda') 101 | dir_tensor[:len(tmpdir)] = tmpdir 102 | dist.broadcast(dir_tensor, 0) 103 | tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip() 104 | else: 105 | mmcv.mkdir_or_exist(tmpdir) 106 | # dump the part result to the dir 107 | mmcv.dump(result_part, osp.join(tmpdir, f'part_{rank}.pkl')) 108 | dist.barrier() 109 | # collect all parts 110 | if rank != 0: 111 | return None 112 | else: 113 | # load results of all parts from tmp dir 114 | part_list = [] 115 | for i in range(world_size): 116 | part_file = osp.join(tmpdir, f'part_{i}.pkl') 117 | part_result = mmcv.load(part_file) 118 | part_list.append(part_result) 119 | # sort the results 120 | ordered_results = [] 121 | for res in zip(*part_list): 122 | ordered_results.extend(list(res)) 123 | # the dataloader may pad some samples 124 | ordered_results = ordered_results[:size] 125 | # remove tmp dir 126 | shutil.rmtree(tmpdir) 127 | return ordered_results 128 | 129 | 130 | def collect_results_gpu(result_part, size): 131 | """Collect results in gpu.""" 132 | rank, world_size = get_dist_info() 133 | # dump result part to tensor with pickle 134 | part_tensor = torch.tensor(bytearray(pickle.dumps(result_part)), 135 | dtype=torch.uint8, 136 | device='cuda') 137 | # gather all result part tensor shape 138 | shape_tensor = torch.tensor(part_tensor.shape, device='cuda') 139 | shape_list = [shape_tensor.clone() for _ in range(world_size)] 140 | dist.all_gather(shape_list, shape_tensor) 141 | # padding result part tensor to max length 142 | shape_max = torch.tensor(shape_list).max() 143 | part_send = torch.zeros(shape_max, dtype=torch.uint8, device='cuda') 144 | part_send[:shape_tensor[0]] = part_tensor 145 | part_recv_list = [ 146 | part_tensor.new_zeros(shape_max) for 
_ in range(world_size) 147 | ] 148 | # gather all result part 149 | dist.all_gather(part_recv_list, part_send) 150 | 151 | if rank == 0: 152 | part_list = [] 153 | for recv, shape in zip(part_recv_list, shape_list): 154 | part_result = pickle.loads(recv[:shape[0]].cpu().numpy().tobytes()) 155 | part_list.append(part_result) 156 | # sort the results 157 | ordered_results = [] 158 | for res in zip(*part_list): 159 | ordered_results.extend(list(res)) 160 | # the dataloader may pad some samples 161 | ordered_results = ordered_results[:size] 162 | return ordered_results 163 | -------------------------------------------------------------------------------- /mogen/apis/train.py: -------------------------------------------------------------------------------- 1 | import random 2 | import warnings 3 | 4 | import numpy as np 5 | import torch 6 | from mmcv.parallel import MMDataParallel, MMDistributedDataParallel 7 | from mmcv.runner import (DistSamplerSeedHook, Fp16OptimizerHook, OptimizerHook, 8 | build_runner) 9 | 10 | from mogen.core.distributed_wrapper import DistributedDataParallelWrapper 11 | from mogen.core.evaluation import DistEvalHook, EvalHook 12 | from mogen.core.optimizer import build_optimizers 13 | from mogen.datasets import build_dataloader, build_dataset 14 | from mogen.utils import get_root_logger 15 | 16 | 17 | def set_random_seed(seed, deterministic=False): 18 | """Set random seed. 19 | Args: 20 | seed (int): Seed to be used. 21 | deterministic (bool): Whether to set the deterministic option for 22 | CUDNN backend, i.e., set `torch.backends.cudnn.deterministic` 23 | to True and `torch.backends.cudnn.benchmark` to False. 24 | Default: False. 25 | """ 26 | random.seed(seed) 27 | np.random.seed(seed) 28 | torch.manual_seed(seed) 29 | torch.cuda.manual_seed_all(seed) 30 | if deterministic: 31 | torch.backends.cudnn.deterministic = True 32 | torch.backends.cudnn.benchmark = False 33 | 34 | 35 | def train_model(model, 36 | dataset, 37 | cfg, 38 | distributed=False, 39 | validate=False, 40 | timestamp=None, 41 | device='cuda', 42 | meta=None): 43 | """Main api for training model.""" 44 | logger = get_root_logger(cfg.log_level) 45 | 46 | # prepare data loaders 47 | dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset] 48 | 49 | data_loaders = [ 50 | build_dataloader( 51 | ds, 52 | cfg.data.samples_per_gpu, 53 | cfg.data.workers_per_gpu, 54 | # cfg.gpus will be ignored if distributed 55 | num_gpus=len(cfg.gpu_ids), 56 | dist=distributed, 57 | round_up=True, 58 | seed=cfg.seed) for ds in dataset 59 | ] 60 | 61 | # determine whether use adversarial training precess or not 62 | use_adverserial_train = cfg.get('use_adversarial_train', False) 63 | 64 | # put model on gpus 65 | if distributed: 66 | find_unused_parameters = cfg.get('find_unused_parameters', True) 67 | # Sets the `find_unused_parameters` parameter in 68 | # torch.nn.parallel.DistributedDataParallel 69 | if use_adverserial_train: 70 | # Use DistributedDataParallelWrapper for adversarial training 71 | model = DistributedDataParallelWrapper( 72 | model, 73 | device_ids=[torch.cuda.current_device()], 74 | broadcast_buffers=False, 75 | find_unused_parameters=find_unused_parameters) 76 | else: 77 | model = MMDistributedDataParallel( 78 | model.cuda(), 79 | device_ids=[torch.cuda.current_device()], 80 | broadcast_buffers=False, 81 | find_unused_parameters=find_unused_parameters) 82 | else: 83 | if device == 'cuda': 84 | model = MMDataParallel(model.cuda(cfg.gpu_ids[0]), 85 | device_ids=cfg.gpu_ids) 86 | elif device == 
'cpu': 87 | model = model.cpu() 88 | else: 89 | raise ValueError(F'unsupported device name {device}.') 90 | 91 | # build runner 92 | optimizer = build_optimizers(model, cfg.optimizer) 93 | 94 | if cfg.get('runner') is None: 95 | cfg.runner = { 96 | 'type': 'EpochBasedRunner', 97 | 'max_epochs': cfg.total_epochs 98 | } 99 | warnings.warn( 100 | 'config is now expected to have a `runner` section, ' 101 | 'please set `runner` in your config.', UserWarning) 102 | 103 | runner = build_runner(cfg.runner, 104 | default_args=dict(model=model, 105 | batch_processor=None, 106 | optimizer=optimizer, 107 | work_dir=cfg.work_dir, 108 | logger=logger, 109 | meta=meta)) 110 | 111 | # an ugly walkaround to make the .log and .log.json filenames the same 112 | runner.timestamp = timestamp 113 | 114 | if use_adverserial_train: 115 | # The optimizer step process is included in the train_step function 116 | # of the model, so the runner should NOT include optimizer hook. 117 | optimizer_config = None 118 | else: 119 | # fp16 setting 120 | fp16_cfg = cfg.get('fp16', None) 121 | if fp16_cfg is not None: 122 | optimizer_config = Fp16OptimizerHook(**cfg.optimizer_config, 123 | **fp16_cfg, 124 | distributed=distributed) 125 | elif distributed and 'type' not in cfg.optimizer_config: 126 | optimizer_config = OptimizerHook(**cfg.optimizer_config) 127 | else: 128 | optimizer_config = cfg.optimizer_config 129 | 130 | # register hooks 131 | runner.register_training_hooks(cfg.lr_config, 132 | optimizer_config, 133 | cfg.checkpoint_config, 134 | cfg.log_config, 135 | cfg.get('momentum_config', None), 136 | custom_hooks_config=cfg.get( 137 | 'custom_hooks', None)) 138 | if distributed: 139 | runner.register_hook(DistSamplerSeedHook()) 140 | 141 | # register eval hooks 142 | if validate: 143 | val_dataset = build_dataset(cfg.data.val, dict(test_mode=True)) 144 | val_dataloader = build_dataloader( 145 | val_dataset, 146 | samples_per_gpu=cfg.data.samples_per_gpu, 147 | workers_per_gpu=cfg.data.workers_per_gpu, 148 | dist=distributed, 149 | shuffle=False, 150 | round_up=True) 151 | eval_cfg = cfg.get('evaluation', {}) 152 | eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner' 153 | eval_hook = DistEvalHook if distributed else EvalHook 154 | runner.register_hook(eval_hook(val_dataloader, **eval_cfg)) 155 | 156 | if cfg.resume_from: 157 | runner.resume(cfg.resume_from) 158 | elif cfg.load_from: 159 | runner.load_checkpoint(cfg.load_from) 160 | runner.run(data_loaders, cfg.workflow) 161 | -------------------------------------------------------------------------------- /mogen/core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mingyuan-zhang/FineMoGen/d5697b5aa6ad2de0301e77a251c7d28fb177ee23/mogen/core/__init__.py -------------------------------------------------------------------------------- /mogen/core/distributed_wrapper.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | import torch.nn as nn 4 | from mmcv.parallel import MODULE_WRAPPERS, MMDistributedDataParallel 5 | from mmcv.parallel.scatter_gather import scatter_kwargs 6 | from torch.cuda._utils import _get_device_index 7 | 8 | 9 | @MODULE_WRAPPERS.register_module() 10 | class DistributedDataParallelWrapper(nn.Module): 11 | """A DistributedDataParallel wrapper for models in 3D mesh estimation task. 
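    Illustrative usage (editor's sketch, mirroring the adversarial-training branch of
    ``train_model`` shown above; the argument values are examples, not requirements):

        >>> model = DistributedDataParallelWrapper(
        ...     model,
        ...     device_ids=[torch.cuda.current_device()],
        ...     broadcast_buffers=False,
        ...     find_unused_parameters=True)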
12 | 13 | In 3D mesh estimation task, there is a need to wrap different modules in 14 | the models with separate DistributedDataParallel. Otherwise, it will cause 15 | errors for GAN training. 16 | More specific, the GAN model, usually has two sub-modules: 17 | generator and discriminator. If we wrap both of them in one 18 | standard DistributedDataParallel, it will cause errors during training, 19 | because when we update the parameters of the generator (or discriminator), 20 | the parameters of the discriminator (or generator) is not updated, which is 21 | not allowed for DistributedDataParallel. 22 | So we design this wrapper to separately wrap DistributedDataParallel 23 | for generator and discriminator. 24 | In this wrapper, we perform two operations: 25 | 1. Wrap the modules in the models with separate MMDistributedDataParallel. 26 | Note that only modules with parameters will be wrapped. 27 | 2. Do scatter operation for 'forward', 'train_step' and 'val_step'. 28 | Note that the arguments of this wrapper is the same as those in 29 | `torch.nn.parallel.distributed.DistributedDataParallel`. 30 | Args: 31 | module (nn.Module): Module that needs to be wrapped. 32 | device_ids (list[int | `torch.device`]): Same as that in 33 | `torch.nn.parallel.distributed.DistributedDataParallel`. 34 | dim (int, optional): Same as that in the official scatter function in 35 | pytorch. Defaults to 0. 36 | broadcast_buffers (bool): Same as that in 37 | `torch.nn.parallel.distributed.DistributedDataParallel`. 38 | Defaults to False. 39 | find_unused_parameters (bool, optional): Same as that in 40 | `torch.nn.parallel.distributed.DistributedDataParallel`. 41 | Traverse the autograd graph of all tensors contained in returned 42 | value of the wrapped module’s forward function. Defaults to False. 43 | kwargs (dict): Other arguments used in 44 | `torch.nn.parallel.distributed.DistributedDataParallel`. 45 | """ 46 | 47 | def __init__(self, 48 | module, 49 | device_ids, 50 | dim=0, 51 | broadcast_buffers=False, 52 | find_unused_parameters=False, 53 | **kwargs): 54 | super().__init__() 55 | assert len(device_ids) == 1, ( 56 | 'Currently, DistributedDataParallelWrapper only supports one' 57 | 'single CUDA device for each process.' 58 | f'The length of device_ids must be 1, but got {len(device_ids)}.') 59 | self.module = module 60 | self.dim = dim 61 | self.to_ddp(device_ids=device_ids, 62 | dim=dim, 63 | broadcast_buffers=broadcast_buffers, 64 | find_unused_parameters=find_unused_parameters, 65 | **kwargs) 66 | self.output_device = _get_device_index(device_ids[0], True) 67 | 68 | def to_ddp(self, device_ids, dim, broadcast_buffers, 69 | find_unused_parameters, **kwargs): 70 | """Wrap models with separate MMDistributedDataParallel. 71 | 72 | It only wraps the modules with parameters. 73 | """ 74 | for name, module in self.module._modules.items(): 75 | if next(module.parameters(), None) is None: 76 | module = module.cuda() 77 | elif all(not p.requires_grad for p in module.parameters()): 78 | module = module.cuda() 79 | else: 80 | module = MMDistributedDataParallel( 81 | module.cuda(), 82 | device_ids=device_ids, 83 | dim=dim, 84 | broadcast_buffers=broadcast_buffers, 85 | find_unused_parameters=find_unused_parameters, 86 | **kwargs) 87 | self.module._modules[name] = module 88 | 89 | def scatter(self, inputs, kwargs, device_ids): 90 | """Scatter function. 91 | 92 | Args: 93 | inputs (Tensor): Input Tensor. 94 | kwargs (dict): Args for 95 | ``mmcv.parallel.scatter_gather.scatter_kwargs``. 
96 | device_ids (int): Device id. 97 | """ 98 | return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim) 99 | 100 | def forward(self, *inputs, **kwargs): 101 | """Forward function. 102 | 103 | Args: 104 | inputs (tuple): Input data. 105 | kwargs (dict): Args for 106 | ``mmcv.parallel.scatter_gather.scatter_kwargs``. 107 | """ 108 | inputs, kwargs = self.scatter(inputs, kwargs, 109 | [torch.cuda.current_device()]) 110 | return self.module(*inputs[0], **kwargs[0]) 111 | 112 | def train_step(self, *inputs, **kwargs): 113 | """Train step function. 114 | 115 | Args: 116 | inputs (Tensor): Input Tensor. 117 | kwargs (dict): Args for 118 | ``mmcv.parallel.scatter_gather.scatter_kwargs``. 119 | """ 120 | inputs, kwargs = self.scatter(inputs, kwargs, 121 | [torch.cuda.current_device()]) 122 | output = self.module.train_step(*inputs[0], **kwargs[0]) 123 | return output 124 | 125 | def val_step(self, *inputs, **kwargs): 126 | """Validation step function. 127 | 128 | Args: 129 | inputs (tuple): Input data. 130 | kwargs (dict): Args for ``scatter_kwargs``. 131 | """ 132 | inputs, kwargs = self.scatter(inputs, kwargs, 133 | [torch.cuda.current_device()]) 134 | output = self.module.val_step(*inputs[0], **kwargs[0]) 135 | return output 136 | -------------------------------------------------------------------------------- /mogen/core/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | from mogen.core.evaluation.builder import build_evaluator 2 | from mogen.core.evaluation.eval_hooks import DistEvalHook, EvalHook 3 | 4 | __all__ = ["DistEvalHook", "EvalHook", "build_evaluator"] 5 | -------------------------------------------------------------------------------- /mogen/core/evaluation/builder.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import numpy as np 4 | from mmcv.utils import Registry 5 | 6 | from .evaluators.diversity_evaluator import DiversityEvaluator 7 | from .evaluators.fid_evaluator import FIDEvaluator 8 | from .evaluators.matching_score_evaluator import MatchingScoreEvaluator 9 | from .evaluators.multimodality_evaluator import MultiModalityEvaluator 10 | from .evaluators.precision_evaluator import PrecisionEvaluator 11 | 12 | EVALUATORS = Registry('evaluators') 13 | 14 | EVALUATORS.register_module(name='R Precision', module=PrecisionEvaluator) 15 | EVALUATORS.register_module(name='Matching Score', 16 | module=MatchingScoreEvaluator) 17 | EVALUATORS.register_module(name='FID', module=FIDEvaluator) 18 | EVALUATORS.register_module(name='Diversity', module=DiversityEvaluator) 19 | EVALUATORS.register_module(name='MultiModality', module=MultiModalityEvaluator) 20 | 21 | 22 | def build_evaluator(metric, eval_cfg, data_len, eval_indexes): 23 | cfg = copy.deepcopy(eval_cfg) 24 | cfg.update(metric) 25 | cfg.pop('metrics') 26 | cfg['data_len'] = data_len 27 | cfg['eval_indexes'] = eval_indexes 28 | evaluator = EVALUATORS.build(cfg) 29 | if evaluator.append_indexes is not None: 30 | for i in range(eval_cfg['replication_times']): 31 | eval_indexes[i] = np.concatenate( 32 | (eval_indexes[i], evaluator.append_indexes[i]), axis=0) 33 | return evaluator, eval_indexes 34 | -------------------------------------------------------------------------------- /mogen/core/evaluation/eval_hooks.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
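# Editor's note (illustrative sketch, not part of the original file): the
# ``build_evaluator`` helper shown above is driven by a dataset-level ``eval_cfg``
# plus one entry of its ``metrics`` list. The field values below are examples and
# may differ from this repo's actual configs:
#
#     eval_cfg = dict(
#         shuffle_indexes=True,
#         replication_times=20,
#         replication_reduction='statistics',
#         evaluator_model=...,  # built and attached in the dataset's prepare_evaluation
#         metrics=[
#             dict(type='R Precision', top_k=3, batch_size=32),
#             dict(type='Matching Score', top_k=3, batch_size=32),
#             dict(type='FID', emb_scale=1),
#             dict(type='Diversity', num_samples=300),
#         ])
#     evaluator, eval_indexes = build_evaluator(
#         eval_cfg['metrics'][0], eval_cfg, data_len=len(dataset), eval_indexes=eval_indexes)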
2 | import tempfile 3 | import warnings 4 | 5 | from mmcv.runner import DistEvalHook as BaseDistEvalHook 6 | from mmcv.runner import EvalHook as BaseEvalHook 7 | 8 | mogen_GREATER_KEYS = [] 9 | mogen_LESS_KEYS = [] 10 | 11 | 12 | class EvalHook(BaseEvalHook): 13 | 14 | def __init__(self, 15 | dataloader, 16 | start=None, 17 | interval=1, 18 | by_epoch=True, 19 | save_best=None, 20 | rule=None, 21 | test_fn=None, 22 | greater_keys=mogen_GREATER_KEYS, 23 | less_keys=mogen_LESS_KEYS, 24 | **eval_kwargs): 25 | if test_fn is None: 26 | from mogen.apis import single_gpu_test 27 | test_fn = single_gpu_test 28 | 29 | # remove "gpu_collect" from eval_kwargs 30 | if 'gpu_collect' in eval_kwargs: 31 | warnings.warn( 32 | '"gpu_collect" will be deprecated in EvalHook.' 33 | 'Please remove it from the config.', DeprecationWarning) 34 | _ = eval_kwargs.pop('gpu_collect') 35 | 36 | # update "save_best" according to "key_indicator" and remove the 37 | # latter from eval_kwargs 38 | if 'key_indicator' in eval_kwargs or isinstance(save_best, bool): 39 | warnings.warn( 40 | '"key_indicator" will be deprecated in EvalHook.' 41 | 'Please use "save_best" to specify the metric key,' 42 | 'e.g., save_best="pa-mpjpe".', DeprecationWarning) 43 | 44 | key_indicator = eval_kwargs.pop('key_indicator', None) 45 | if save_best is True and key_indicator is None: 46 | raise ValueError('key_indicator should not be None, when ' 47 | 'save_best is set to True.') 48 | save_best = key_indicator 49 | 50 | super().__init__(dataloader, start, interval, by_epoch, save_best, 51 | rule, test_fn, greater_keys, less_keys, **eval_kwargs) 52 | 53 | def evaluate(self, runner, results): 54 | 55 | with tempfile.TemporaryDirectory() as tmp_dir: 56 | eval_res = self.dataloader.dataset.evaluate(results, 57 | work_dir=tmp_dir, 58 | logger=runner.logger, 59 | **self.eval_kwargs) 60 | 61 | for name, val in eval_res.items(): 62 | runner.log_buffer.output[name] = val 63 | runner.log_buffer.ready = True 64 | 65 | if self.save_best is not None: 66 | if self.key_indicator == 'auto': 67 | self._init_rule(self.rule, list(eval_res.keys())[0]) 68 | 69 | return eval_res[self.key_indicator] 70 | 71 | return None 72 | 73 | 74 | class DistEvalHook(BaseDistEvalHook): 75 | 76 | def __init__(self, 77 | dataloader, 78 | start=None, 79 | interval=1, 80 | by_epoch=True, 81 | save_best=None, 82 | rule=None, 83 | test_fn=None, 84 | greater_keys=mogen_GREATER_KEYS, 85 | less_keys=mogen_LESS_KEYS, 86 | broadcast_bn_buffer=True, 87 | tmpdir=None, 88 | gpu_collect=False, 89 | **eval_kwargs): 90 | 91 | if test_fn is None: 92 | from mogen.apis import multi_gpu_test 93 | test_fn = multi_gpu_test 94 | 95 | # update "save_best" according to "key_indicator" and remove the 96 | # latter from eval_kwargs 97 | if 'key_indicator' in eval_kwargs or isinstance(save_best, bool): 98 | warnings.warn( 99 | '"key_indicator" will be deprecated in EvalHook.' 100 | 'Please use "save_best" to specify the metric key,' 101 | 'e.g., save_best="pa-mpjpe".', DeprecationWarning) 102 | 103 | key_indicator = eval_kwargs.pop('key_indicator', None) 104 | if save_best is True and key_indicator is None: 105 | raise ValueError('key_indicator should not be None, when ' 106 | 'save_best is set to True.') 107 | save_best = key_indicator 108 | 109 | super().__init__(dataloader, start, interval, by_epoch, save_best, 110 | rule, test_fn, greater_keys, less_keys, 111 | broadcast_bn_buffer, tmpdir, gpu_collect, 112 | **eval_kwargs) 113 | 114 | def evaluate(self, runner, results): 115 | """Evaluate the results. 
116 | Args: 117 | runner (:obj:`mmcv.Runner`): The underlined training runner. 118 | results (list): Output results. 119 | """ 120 | with tempfile.TemporaryDirectory() as tmp_dir: 121 | eval_res = self.dataloader.dataset.evaluate(results, 122 | work_dir=tmp_dir, 123 | logger=runner.logger, 124 | **self.eval_kwargs) 125 | 126 | for name, val in eval_res.items(): 127 | runner.log_buffer.output[name] = val 128 | runner.log_buffer.ready = True 129 | 130 | if self.save_best is not None: 131 | if self.key_indicator == 'auto': 132 | # infer from eval_results 133 | self._init_rule(self.rule, list(eval_res.keys())[0]) 134 | return eval_res[self.key_indicator] 135 | 136 | return None 137 | -------------------------------------------------------------------------------- /mogen/core/evaluation/evaluators/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /mogen/core/evaluation/evaluators/diversity_evaluator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ..utils import calculate_diversity 4 | from .base_evaluator import BaseEvaluator 5 | 6 | 7 | class DiversityEvaluator(BaseEvaluator): 8 | 9 | def __init__(self, 10 | data_len=0, 11 | evaluator_model=None, 12 | num_samples=300, 13 | batch_size=None, 14 | drop_last=False, 15 | replication_times=1, 16 | replication_reduction='statistics', 17 | emb_scale=1, 18 | norm_scale=1, 19 | **kwargs): 20 | super().__init__(replication_times=replication_times, 21 | replication_reduction=replication_reduction, 22 | batch_size=batch_size, 23 | drop_last=drop_last, 24 | eval_begin_idx=0, 25 | eval_end_idx=data_len, 26 | evaluator_model=evaluator_model) 27 | self.num_samples = num_samples 28 | self.append_indexes = None 29 | self.emb_scale = emb_scale 30 | self.norm_scale = norm_scale 31 | 32 | def single_evaluate(self, results): 33 | results = self.prepare_results(results) 34 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 35 | pred_motion = results['pred_motion'] 36 | pred_motion_length = results['pred_motion_length'] 37 | pred_motion_mask = results['pred_motion_mask'] 38 | with torch.no_grad(): 39 | pred_motion_emb = self.encode_motion( 40 | motion=pred_motion, 41 | motion_length=pred_motion_length, 42 | motion_mask=pred_motion_mask, 43 | device=device).cpu().detach().numpy() 44 | diversity = calculate_diversity(pred_motion_emb, self.num_samples, 45 | self.emb_scale, self.norm_scale) 46 | return diversity 47 | 48 | def parse_values(self, values): 49 | metrics = {} 50 | metrics['Diversity (mean)'] = values[0] 51 | metrics['Diversity (conf)'] = values[1] 52 | return metrics 53 | -------------------------------------------------------------------------------- /mogen/core/evaluation/evaluators/fid_evaluator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ..utils import calculate_activation_statistics, calculate_frechet_distance 4 | from .base_evaluator import BaseEvaluator 5 | 6 | 7 | class FIDEvaluator(BaseEvaluator): 8 | 9 | def __init__(self, 10 | data_len=0, 11 | evaluator_model=None, 12 | batch_size=None, 13 | drop_last=False, 14 | replication_times=1, 15 | emb_scale=1, 16 | replication_reduction='statistics', 17 | **kwargs): 18 | super().__init__(replication_times=replication_times, 19 | replication_reduction=replication_reduction, 20 | batch_size=batch_size, 21 | 
drop_last=drop_last, 22 | eval_begin_idx=0, 23 | eval_end_idx=data_len, 24 | evaluator_model=evaluator_model) 25 | self.emb_scale = emb_scale 26 | self.append_indexes = None 27 | 28 | def single_evaluate(self, results): 29 | results = self.prepare_results(results) 30 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 31 | pred_motion = results['pred_motion'] 32 | pred_motion_length = results['pred_motion_length'] 33 | pred_motion_mask = results['pred_motion_mask'] 34 | motion = results['motion'] 35 | motion_length = results['motion_length'] 36 | motion_mask = results['motion_mask'] 37 | with torch.no_grad(): 38 | pred_motion_emb = self.encode_motion( 39 | motion=pred_motion, 40 | motion_length=pred_motion_length, 41 | motion_mask=pred_motion_mask, 42 | device=device).cpu().detach().numpy() 43 | gt_motion_emb = self.encode_motion( 44 | motion=motion, 45 | motion_length=motion_length, 46 | motion_mask=motion_mask, 47 | device=device).cpu().detach().numpy() 48 | gt_mu, gt_cov = calculate_activation_statistics( 49 | gt_motion_emb, self.emb_scale) 50 | pred_mu, pred_cov = calculate_activation_statistics( 51 | pred_motion_emb, self.emb_scale) 52 | fid = calculate_frechet_distance(gt_mu, gt_cov, pred_mu, pred_cov) 53 | return fid 54 | 55 | def parse_values(self, values): 56 | metrics = {} 57 | metrics['FID (mean)'] = values[0] 58 | metrics['FID (conf)'] = values[1] 59 | return metrics 60 | -------------------------------------------------------------------------------- /mogen/core/evaluation/evaluators/matching_score_evaluator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ..utils import euclidean_distance_matrix 4 | from .base_evaluator import BaseEvaluator 5 | 6 | 7 | class MatchingScoreEvaluator(BaseEvaluator): 8 | 9 | def __init__(self, 10 | data_len=0, 11 | evaluator_model=None, 12 | top_k=3, 13 | batch_size=32, 14 | drop_last=False, 15 | replication_times=1, 16 | replication_reduction='statistics', 17 | **kwargs): 18 | super().__init__(replication_times=replication_times, 19 | replication_reduction=replication_reduction, 20 | batch_size=batch_size, 21 | drop_last=drop_last, 22 | eval_begin_idx=0, 23 | eval_end_idx=data_len, 24 | evaluator_model=evaluator_model) 25 | self.append_indexes = None 26 | self.top_k = top_k 27 | 28 | def single_evaluate(self, results): 29 | results = self.prepare_results(results) 30 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 31 | pred_motion = results['pred_motion'] 32 | pred_motion_length = results['pred_motion_length'] 33 | pred_motion_mask = results['pred_motion_mask'] 34 | text = results['text'] 35 | token = results['token'] 36 | with torch.no_grad(): 37 | word_emb = self.encode_text(text=text, token=token, 38 | device=device).cpu().detach().numpy() 39 | motion_emb = self.encode_motion( 40 | motion=pred_motion, 41 | motion_length=pred_motion_length, 42 | motion_mask=pred_motion_mask, 43 | device=device).cpu().detach().numpy() 44 | dist_mat = euclidean_distance_matrix(word_emb, motion_emb) 45 | matching_score = dist_mat.trace() 46 | all_size = word_emb.shape[0] 47 | return matching_score, all_size 48 | 49 | def concat_batch_metrics(self, batch_metrics): 50 | matching_score_sum = 0 51 | all_size = 0 52 | for batch_matching_score, batch_all_size in batch_metrics: 53 | matching_score_sum += batch_matching_score 54 | all_size += batch_all_size 55 | matching_score = matching_score_sum / all_size 56 | return matching_score 57 | 58 | def 
parse_values(self, values): 59 | metrics = {} 60 | metrics['Matching Score (mean)'] = values[0] 61 | metrics['Matching Score (conf)'] = values[1] 62 | return metrics 63 | -------------------------------------------------------------------------------- /mogen/core/evaluation/evaluators/multimodality_evaluator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from ..utils import calculate_multimodality 5 | from .base_evaluator import BaseEvaluator 6 | 7 | 8 | class MultiModalityEvaluator(BaseEvaluator): 9 | 10 | def __init__(self, 11 | data_len=0, 12 | evaluator_model=None, 13 | num_samples=100, 14 | num_repeats=30, 15 | num_picks=10, 16 | batch_size=None, 17 | drop_last=False, 18 | replication_times=1, 19 | replication_reduction='statistics', 20 | **kwargs): 21 | super().__init__(replication_times=replication_times, 22 | replication_reduction=replication_reduction, 23 | batch_size=batch_size, 24 | drop_last=drop_last, 25 | eval_begin_idx=data_len, 26 | eval_end_idx=data_len + num_samples * num_repeats, 27 | evaluator_model=evaluator_model) 28 | self.num_samples = num_samples 29 | self.num_repeats = num_repeats 30 | self.num_picks = num_picks 31 | self.append_indexes = [] 32 | for i in range(replication_times): 33 | append_indexes = [] 34 | selected_indexs = np.random.choice(data_len, self.num_samples) 35 | for index in selected_indexs: 36 | append_indexes = append_indexes + [index] * self.num_repeats 37 | self.append_indexes.append(np.array(append_indexes)) 38 | 39 | def single_evaluate(self, results): 40 | results = self.prepare_results(results) 41 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 42 | pred_motion = results['pred_motion'] 43 | pred_motion_length = results['pred_motion_length'] 44 | pred_motion_mask = results['pred_motion_mask'] 45 | with torch.no_grad(): 46 | pred_motion_emb = self.encode_motion( 47 | motion=pred_motion, 48 | motion_length=pred_motion_length, 49 | motion_mask=pred_motion_mask, 50 | device=device).cpu().detach().numpy() 51 | pred_motion_emb = \ 52 | pred_motion_emb.reshape((self.num_samples, self.num_repeats, -1)) 53 | multimodality = calculate_multimodality(pred_motion_emb, 54 | self.num_picks) 55 | return multimodality 56 | 57 | def parse_values(self, values): 58 | metrics = {} 59 | metrics['MultiModality (mean)'] = values[0] 60 | metrics['MultiModality (conf)'] = values[1] 61 | return metrics 62 | -------------------------------------------------------------------------------- /mogen/core/evaluation/evaluators/precision_evaluator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from ..utils import calculate_top_k, euclidean_distance_matrix 5 | from .base_evaluator import BaseEvaluator 6 | 7 | 8 | class PrecisionEvaluator(BaseEvaluator): 9 | 10 | def __init__(self, 11 | data_len=0, 12 | evaluator_model=None, 13 | top_k=3, 14 | batch_size=32, 15 | drop_last=False, 16 | replication_times=1, 17 | replication_reduction='statistics', 18 | **kwargs): 19 | super().__init__(replication_times=replication_times, 20 | replication_reduction=replication_reduction, 21 | batch_size=batch_size, 22 | drop_last=drop_last, 23 | eval_begin_idx=0, 24 | eval_end_idx=data_len, 25 | evaluator_model=evaluator_model) 26 | self.append_indexes = None 27 | self.top_k = top_k 28 | 29 | def single_evaluate(self, results): 30 | results = self.prepare_results(results) 31 | device = torch.device('cuda' 
if torch.cuda.is_available() else 'cpu') 32 | pred_motion = results['pred_motion'] 33 | pred_motion_length = results['pred_motion_length'] 34 | pred_motion_mask = results['pred_motion_mask'] 35 | text = results['text'] 36 | token = results['token'] 37 | with torch.no_grad(): 38 | word_emb = self.encode_text(text=text, token=token, 39 | device=device).cpu().detach().numpy() 40 | motion_emb = self.encode_motion( 41 | motion=pred_motion, 42 | motion_length=pred_motion_length, 43 | motion_mask=pred_motion_mask, 44 | device=device).cpu().detach().numpy() 45 | dist_mat = euclidean_distance_matrix(word_emb, motion_emb) 46 | argsmax = np.argsort(dist_mat, axis=1) 47 | top_k_mat = calculate_top_k(argsmax, top_k=self.top_k) 48 | top_k_count = top_k_mat.sum(axis=0) 49 | all_size = word_emb.shape[0] 50 | return top_k_count, all_size 51 | 52 | def concat_batch_metrics(self, batch_metrics): 53 | top_k_count = 0 54 | all_size = 0 55 | for batch_top_k_count, batch_all_size in batch_metrics: 56 | top_k_count += batch_top_k_count 57 | all_size += batch_all_size 58 | R_precision = top_k_count / all_size 59 | return R_precision 60 | 61 | def parse_values(self, values): 62 | metrics = {} 63 | for top_k in range(self.top_k): 64 | metric_name_mean = 'R_precision Top %d (mean)' % (top_k + 1) 65 | metrics[metric_name_mean] = values[0][top_k] 66 | metric_name_conf = 'R_precision Top %d (conf)' % (top_k + 1) 67 | metrics[metric_name_conf] = values[1][top_k] 68 | return metrics 69 | -------------------------------------------------------------------------------- /mogen/core/evaluation/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy import linalg 3 | 4 | 5 | def get_metric_statistics(values, replication_times): 6 | mean = np.mean(values, axis=0) 7 | std = np.std(values, axis=0) 8 | conf_interval = 1.96 * std / np.sqrt(replication_times) 9 | return mean, conf_interval 10 | 11 | 12 | def euclidean_distance_matrix(matrix1, matrix2): 13 | """ 14 | Params: 15 | -- matrix1: N1 x D 16 | -- matrix2: N2 x D 17 | Returns: 18 | -- dist: N1 x N2 19 | dist[i, j] == distance(matrix1[i], matrix2[j]) 20 | """ 21 | assert matrix1.shape[1] == matrix2.shape[1] 22 | d1 = -2 * np.dot(matrix1, matrix2.T) 23 | d2 = np.sum(np.square(matrix1), axis=1, keepdims=True) 24 | d3 = np.sum(np.square(matrix2), axis=1) 25 | dists = np.sqrt(d1 + d2 + d3) 26 | return dists 27 | 28 | 29 | def calculate_top_k(mat, top_k): 30 | size = mat.shape[0] 31 | gt_mat = np.expand_dims(np.arange(size), 1).repeat(size, 1) 32 | bool_mat = (mat == gt_mat) 33 | correct_vec = False 34 | top_k_list = [] 35 | for i in range(top_k): 36 | correct_vec = (correct_vec | bool_mat[:, i]) 37 | top_k_list.append(correct_vec[:, None]) 38 | top_k_mat = np.concatenate(top_k_list, axis=1) 39 | return top_k_mat 40 | 41 | 42 | def calculate_activation_statistics(activations, emb_scale): 43 | """ 44 | Params: 45 | -- activation: num_samples x dim_feat 46 | Returns: 47 | -- mu: dim_feat 48 | -- sigma: dim_feat x dim_feat 49 | """ 50 | activations = activations * emb_scale 51 | mu = np.mean(activations, axis=0) 52 | cov = np.cov(activations, rowvar=False) 53 | return mu, cov 54 | 55 | 56 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): 57 | """Numpy implementation of the Frechet Distance. 58 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) 59 | and X_2 ~ N(mu_2, C_2) is 60 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). 
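    Illustrative check (editor's example): identical statistics yield a
    (numerically) zero distance.

        >>> import numpy as np
        >>> mu, sigma = np.zeros(4), np.eye(4)
        >>> bool(abs(calculate_frechet_distance(mu, sigma, mu, sigma)) < 1e-6)
        True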
61 | Stable version by Dougal J. Sutherland. 62 | Params: 63 | -- mu1 : Numpy array containing the activations of a layer of the 64 | inception net (like returned by the function 'get_predictions') 65 | for generated samples. 66 | -- mu2 : The sample mean over activations, precalculated on an 67 | representative data set. 68 | -- sigma1: The covariance matrix over activations for generated samples. 69 | -- sigma2: The covariance matrix over activations, precalculated on an 70 | representative data set. 71 | Returns: 72 | -- : The Frechet Distance. 73 | """ 74 | 75 | mu1 = np.atleast_1d(mu1) 76 | mu2 = np.atleast_1d(mu2) 77 | 78 | sigma1 = np.atleast_2d(sigma1) 79 | sigma2 = np.atleast_2d(sigma2) 80 | 81 | assert mu1.shape == mu2.shape, \ 82 | 'Training and test mean vectors have different lengths' 83 | assert sigma1.shape == sigma2.shape, \ 84 | 'Training and test covariances have different dimensions' 85 | 86 | diff = mu1 - mu2 87 | 88 | # Product might be almost singular 89 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) 90 | if not np.isfinite(covmean).all(): 91 | msg = ('fid calculation produces singular product; ' 92 | 'adding %s to diagonal of cov estimates') % eps 93 | print(msg) 94 | offset = np.eye(sigma1.shape[0]) * eps 95 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 96 | 97 | # Numerical error might give slight imaginary component 98 | if np.iscomplexobj(covmean): 99 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 100 | m = np.max(np.abs(covmean.imag)) 101 | raise ValueError('Imaginary component {}'.format(m)) 102 | covmean = covmean.real 103 | 104 | tr_covmean = np.trace(covmean) 105 | 106 | return (diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 107 | 2 * tr_covmean) 108 | 109 | 110 | def calculate_diversity(activation, diversity_times, emb_scale, norm_scale): 111 | assert len(activation.shape) == 2 112 | assert activation.shape[0] > diversity_times 113 | num_samples = activation.shape[0] 114 | 115 | activation = activation * emb_scale 116 | first_indices = np.random.choice(num_samples, 117 | diversity_times, 118 | replace=False) 119 | second_indices = np.random.choice(num_samples, 120 | diversity_times, 121 | replace=False) 122 | delta = activation[first_indices] - activation[second_indices] 123 | dist = linalg.norm(delta * norm_scale, axis=1) 124 | return dist.mean() 125 | 126 | 127 | def calculate_multimodality(activation, multimodality_times): 128 | assert len(activation.shape) == 3 129 | assert activation.shape[1] > multimodality_times 130 | num_per_sent = activation.shape[1] 131 | 132 | first_dices = np.random.choice(num_per_sent, 133 | multimodality_times, 134 | replace=False) 135 | second_dices = np.random.choice(num_per_sent, 136 | multimodality_times, 137 | replace=False) 138 | delta = activation[:, first_dices] - activation[:, second_dices] 139 | dist = linalg.norm(delta, axis=2) 140 | return dist.mean() 141 | -------------------------------------------------------------------------------- /mogen/core/optimizer/__init__.py: -------------------------------------------------------------------------------- 1 | from .builder import OPTIMIZERS, build_optimizers 2 | 3 | __all__ = ['build_optimizers', 'OPTIMIZERS'] 4 | -------------------------------------------------------------------------------- /mogen/core/optimizer/builder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
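# Editor's note (illustrative sketch, not part of the original file): the
# ``calculate_diversity`` and ``calculate_multimodality`` helpers shown above both
# sample random index pairs and report a mean L2 distance over motion embeddings,
# e.g. (array shapes are examples only):
#
#     emb = np.random.randn(1024, 512)        # N x D generated-motion embeddings
#     calculate_diversity(emb, diversity_times=300, emb_scale=1, norm_scale=1)
#     emb_mm = np.random.randn(100, 30, 512)  # samples x repeats-per-text x D
#     calculate_multimodality(emb_mm, multimodality_times=10)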
2 | from mmcv.runner import build_optimizer 3 | from mmcv.utils import Registry 4 | 5 | OPTIMIZERS = Registry('optimizers') 6 | 7 | 8 | def build_optimizers(model, cfgs): 9 | """Build multiple optimizers from configs. If `cfgs` contains several dicts 10 | for optimizers, then a dict for each constructed optimizers will be 11 | returned. If `cfgs` only contains one optimizer config, the constructed 12 | optimizer itself will be returned. For example, 13 | 14 | 1) Multiple optimizer configs: 15 | 16 | .. code-block:: python 17 | 18 | optimizer_cfg = dict( 19 | model1=dict(type='SGD', lr=lr), 20 | model2=dict(type='SGD', lr=lr)) 21 | 22 | The return dict is 23 | ``dict('model1': torch.optim.Optimizer, 'model2': torch.optim.Optimizer)`` 24 | 25 | 2) Single optimizer config: 26 | 27 | .. code-block:: python 28 | 29 | optimizer_cfg = dict(type='SGD', lr=lr) 30 | 31 | The return is ``torch.optim.Optimizer``. 32 | 33 | Args: 34 | model (:obj:`nn.Module`): The model with parameters to be optimized. 35 | cfgs (dict): The config dict of the optimizer. 36 | 37 | Returns: 38 | dict[:obj:`torch.optim.Optimizer`] | :obj:`torch.optim.Optimizer`: 39 | The initialized optimizers. 40 | """ 41 | optimizers = {} 42 | if hasattr(model, 'module'): 43 | model = model.module 44 | # determine whether 'cfgs' has several dicts for optimizers 45 | if all(isinstance(v, dict) for v in cfgs.values()): 46 | for key, cfg in cfgs.items(): 47 | cfg_ = cfg.copy() 48 | module = getattr(model, key) 49 | optimizers[key] = build_optimizer(module, cfg_) 50 | return optimizers 51 | 52 | return build_optimizer(model, cfgs) 53 | -------------------------------------------------------------------------------- /mogen/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_dataset import BaseMotionDataset 2 | from .builder import DATASETS, PIPELINES, build_dataloader, build_dataset 3 | from .pipelines import Compose 4 | from .samplers import DistributedSampler 5 | from .text_motion_dataset import TextMotionDataset 6 | 7 | __all__ = [ 8 | 'BaseMotionDataset', 'TextMotionDataset', 'DATASETS', 'PIPELINES', 9 | 'build_dataloader', 'build_dataset', 'Compose', 'DistributedSampler' 10 | ] 11 | -------------------------------------------------------------------------------- /mogen/datasets/base_dataset.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os 3 | from abc import abstractmethod 4 | from typing import Optional, Union 5 | 6 | import numpy as np 7 | import torch 8 | from torch.utils.data import Dataset 9 | 10 | from mogen.core.evaluation import build_evaluator 11 | from mogen.models.builder import build_submodule 12 | 13 | from .builder import DATASETS 14 | from .pipelines import Compose 15 | 16 | 17 | @DATASETS.register_module() 18 | class BaseMotionDataset(Dataset): 19 | """Base motion dataset. 20 | Args: 21 | data_prefix (str): the prefix of data path. 22 | pipeline (list): a list of dict, where each element represents 23 | a operation defined in `mogen.datasets.pipelines`. 24 | ann_file (str | None, optional): the annotation file. When ann_file is 25 | str, the subclass is expected to read from the ann_file. When 26 | ann_file is None, the subclass is expected to read according 27 | to data_prefix. 28 | test_mode (bool): in train mode or test mode. Default: None. 29 | dataset_name (str | None, optional): the name of dataset. It is used 30 | to identify the type of evaluation metric. Default: None. 
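        Example (illustrative only; the type, paths and pipeline below are
        placeholders and may differ from this repo's real configs, and
        dataset-specific fields such as text annotations are omitted):

            cfg = dict(
                type='TextMotionDataset',
                data_prefix='data',
                dataset_name='human_ml3d',
                ann_file='train.txt',
                motion_dir='motions',
                pipeline=[dict(type='Crop', crop_size=196),
                          dict(type='ToTensor', keys=['motion']),
                          dict(type='Collect', keys=['motion', 'motion_mask'])],
                test_mode=False)
            dataset = build_dataset(cfg)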
31 | """ 32 | 33 | def __init__(self, 34 | data_prefix: str, 35 | pipeline: list, 36 | dataset_name: Optional[Union[str, None]] = None, 37 | fixed_length: Optional[Union[int, None]] = None, 38 | ann_file: Optional[Union[str, None]] = None, 39 | motion_dir: Optional[Union[str, None]] = None, 40 | eval_cfg: Optional[Union[dict, None]] = None, 41 | test_mode: Optional[bool] = False): 42 | super(BaseMotionDataset, self).__init__() 43 | 44 | self.data_prefix = data_prefix 45 | self.pipeline = Compose(pipeline) 46 | self.dataset_name = dataset_name 47 | self.fixed_length = fixed_length 48 | self.ann_file = os.path.join(data_prefix, 'datasets', dataset_name, 49 | ann_file) 50 | self.motion_dir = os.path.join(data_prefix, 'datasets', dataset_name, 51 | motion_dir) 52 | self.eval_cfg = copy.deepcopy(eval_cfg) 53 | self.test_mode = test_mode 54 | 55 | self.load_annotations() 56 | if self.test_mode: 57 | self.prepare_evaluation() 58 | 59 | @abstractmethod 60 | def load_anno(self, name): 61 | pass 62 | 63 | def load_annotations(self): 64 | """Load annotations from ``ann_file`` to ``data_infos``""" 65 | self.data_infos = [] 66 | for line in open(self.ann_file, 'r').readlines(): 67 | line = line.strip() 68 | self.data_infos.append(self.load_anno(line)) 69 | 70 | def prepare_data(self, idx: int): 71 | """"Prepare raw data for the f'{idx'}-th data.""" 72 | results = copy.deepcopy(self.data_infos[idx]) 73 | results['dataset_name'] = self.dataset_name 74 | results['sample_idx'] = idx 75 | return self.pipeline(results) 76 | 77 | def __len__(self): 78 | """Return the length of current dataset.""" 79 | if self.test_mode: 80 | return len(self.eval_indexes) 81 | elif self.fixed_length is not None: 82 | return self.fixed_length 83 | return len(self.data_infos) 84 | 85 | def __getitem__(self, idx: int): 86 | """Prepare data for the ``idx``-th data. 87 | As for video dataset, we can first parse raw data for each frame. Then 88 | we combine annotations from all frames. This interface is used to 89 | simplify the logic of video dataset and other special datasets. 
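        Example (illustrative; assumes a pipeline whose final transform is
        ``Collect(keys=['motion', 'motion_mask'])``, so the exact keys depend on
        the config):

            >>> sample = dataset[0]
            >>> sorted(sample.keys())
            ['motion', 'motion_mask', 'motion_metas']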
90 | """ 91 | if self.test_mode: 92 | idx = self.eval_indexes[idx] 93 | elif self.fixed_length is not None: 94 | idx = idx % len(self.data_infos) 95 | return self.prepare_data(idx) 96 | 97 | def prepare_evaluation(self): 98 | self.evaluators = [] 99 | self.eval_indexes = [] 100 | self.evaluator_model = build_submodule( 101 | self.eval_cfg.get('evaluator_model', None)) 102 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 103 | self.evaluator_model = self.evaluator_model.to(device) 104 | self.evaluator_model.eval() 105 | self.eval_cfg['evaluator_model'] = self.evaluator_model 106 | for _ in range(self.eval_cfg['replication_times']): 107 | eval_indexes = np.arange(len(self.data_infos)) 108 | if self.eval_cfg.get('shuffle_indexes', False): 109 | np.random.shuffle(eval_indexes) 110 | self.eval_indexes.append(eval_indexes) 111 | for metric in self.eval_cfg['metrics']: 112 | evaluator, self.eval_indexes = build_evaluator( 113 | metric, self.eval_cfg, len(self.data_infos), self.eval_indexes) 114 | self.evaluators.append(evaluator) 115 | 116 | self.eval_indexes = np.concatenate(self.eval_indexes) 117 | 118 | def evaluate(self, results, work_dir, logger=None): 119 | metrics = {} 120 | for evaluator in self.evaluators: 121 | metrics.update(evaluator.evaluate(results)) 122 | if logger is not None: 123 | logger.info(metrics) 124 | return metrics 125 | -------------------------------------------------------------------------------- /mogen/datasets/builder.py: -------------------------------------------------------------------------------- 1 | import platform 2 | import random 3 | from functools import partial 4 | from typing import Optional, Union 5 | 6 | import numpy as np 7 | from mmcv.parallel import collate 8 | from mmcv.runner import get_dist_info 9 | from mmcv.utils import Registry, build_from_cfg 10 | from torch.utils.data import DataLoader 11 | from torch.utils.data.dataset import Dataset 12 | 13 | from .samplers import DistributedSampler 14 | 15 | if platform.system() != 'Windows': 16 | # https://github.com/pytorch/pytorch/issues/973 17 | import resource 18 | rlimit = resource.getrlimit(resource.RLIMIT_NOFILE) 19 | base_soft_limit = rlimit[0] 20 | hard_limit = rlimit[1] 21 | soft_limit = min(max(4096, base_soft_limit), hard_limit) 22 | resource.setrlimit(resource.RLIMIT_NOFILE, (soft_limit, hard_limit)) 23 | 24 | DATASETS = Registry('dataset') 25 | PIPELINES = Registry('pipeline') 26 | 27 | 28 | def build_dataset(cfg: Union[dict, list, tuple], 29 | default_args: Optional[Union[dict, None]] = None): 30 | """"Build dataset by the given config.""" 31 | from .dataset_wrappers import ConcatDataset, RepeatDataset 32 | if isinstance(cfg, (list, tuple)): 33 | dataset = ConcatDataset([build_dataset(c, default_args) for c in cfg]) 34 | elif cfg['type'] == 'RepeatDataset': 35 | dataset = RepeatDataset(build_dataset(cfg['dataset'], default_args), 36 | cfg['times']) 37 | else: 38 | dataset = build_from_cfg(cfg, DATASETS, default_args) 39 | 40 | return dataset 41 | 42 | 43 | def build_dataloader(dataset: Dataset, 44 | samples_per_gpu: int, 45 | workers_per_gpu: int, 46 | num_gpus: Optional[int] = 1, 47 | dist: Optional[bool] = True, 48 | shuffle: Optional[bool] = True, 49 | round_up: Optional[bool] = True, 50 | seed: Optional[Union[int, None]] = None, 51 | persistent_workers: Optional[bool] = True, 52 | **kwargs): 53 | """Build PyTorch DataLoader. 54 | In distributed training, each GPU/process has a dataloader. 
55 | In non-distributed training, there is only one dataloader for all GPUs. 56 | Args: 57 | dataset (:obj:`Dataset`): A PyTorch dataset. 58 | samples_per_gpu (int): Number of training samples on each GPU, i.e., 59 | batch size of each GPU. 60 | workers_per_gpu (int): How many subprocesses to use for data loading 61 | for each GPU. 62 | num_gpus (int, optional): Number of GPUs. Only used in non-distributed 63 | training. 64 | dist (bool, optional): Distributed training/test or not. Default: True. 65 | shuffle (bool, optional): Whether to shuffle the data at every epoch. 66 | Default: True. 67 | round_up (bool, optional): Whether to round up the length of dataset by 68 | adding extra samples to make it evenly divisible. Default: True. 69 | kwargs: any keyword argument to be used to initialize DataLoader 70 | Returns: 71 | DataLoader: A PyTorch dataloader. 72 | """ 73 | rank, world_size = get_dist_info() 74 | if dist: 75 | sampler = DistributedSampler(dataset, 76 | world_size, 77 | rank, 78 | shuffle=shuffle, 79 | round_up=round_up) 80 | shuffle = False 81 | batch_size = samples_per_gpu 82 | num_workers = workers_per_gpu 83 | else: 84 | sampler = None 85 | batch_size = num_gpus * samples_per_gpu 86 | num_workers = num_gpus * workers_per_gpu 87 | 88 | init_fn = partial( 89 | worker_init_fn, num_workers=num_workers, rank=rank, 90 | seed=seed) if seed is not None else None 91 | 92 | data_loader = DataLoader(dataset, 93 | batch_size=batch_size, 94 | sampler=sampler, 95 | num_workers=num_workers, 96 | collate_fn=partial( 97 | collate, samples_per_gpu=samples_per_gpu), 98 | pin_memory=False, 99 | shuffle=shuffle, 100 | worker_init_fn=init_fn, 101 | persistent_workers=persistent_workers, 102 | **kwargs) 103 | 104 | return data_loader 105 | 106 | 107 | def worker_init_fn(worker_id: int, num_workers: int, rank: int, seed: int): 108 | """Init random seed for each worker.""" 109 | # The seed of each worker equals to 110 | # num_worker * rank + worker_id + user_seed 111 | worker_seed = num_workers * rank + worker_id + seed 112 | np.random.seed(worker_seed) 113 | random.seed(worker_seed) 114 | -------------------------------------------------------------------------------- /mogen/datasets/dataset_wrappers.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data.dataset import ConcatDataset as _ConcatDataset 2 | from torch.utils.data.dataset import Dataset 3 | 4 | from .builder import DATASETS 5 | 6 | 7 | @DATASETS.register_module() 8 | class ConcatDataset(_ConcatDataset): 9 | """A wrapper of concatenated dataset. 10 | Same as :obj:`torch.utils.data.dataset.ConcatDataset`, but 11 | add `get_cat_ids` function. 12 | Args: 13 | datasets (list[:obj:`Dataset`]): A list of datasets. 14 | """ 15 | 16 | def __init__(self, datasets: list): 17 | super(ConcatDataset, self).__init__(datasets) 18 | 19 | 20 | @DATASETS.register_module() 21 | class RepeatDataset(object): 22 | """A wrapper of repeated dataset. 23 | The length of repeated dataset will be `times` larger than the original 24 | dataset. This is useful when the data loading time is long but the dataset 25 | is small. Using RepeatDataset can reduce the data loading time between 26 | epochs. 27 | Args: 28 | dataset (:obj:`Dataset`): The dataset to be repeated. 29 | times (int): Repeat times. 
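        Example (illustrative, using a plain list as a stand-in dataset):

            >>> wrapped = RepeatDataset(['a', 'b', 'c'], times=2)
            >>> len(wrapped), wrapped[4]
            (6, 'b')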
30 | """ 31 | 32 | def __init__(self, dataset: Dataset, times: int): 33 | self.dataset = dataset 34 | self.times = times 35 | 36 | self._ori_len = len(self.dataset) 37 | 38 | def __getitem__(self, idx: int): 39 | return self.dataset[idx % self._ori_len] 40 | 41 | def __len__(self): 42 | return self.times * self._ori_len 43 | -------------------------------------------------------------------------------- /mogen/datasets/pipelines/__init__.py: -------------------------------------------------------------------------------- 1 | from .compose import Compose 2 | from .formatting import (Collect, ToTensor, Transpose, WrapFieldsToLists, 3 | to_tensor) 4 | from .siamese_motion import ProcessSiameseMotion, SwapSiameseMotion 5 | from .transforms import Crop, Normalize, RandomCrop 6 | 7 | __all__ = [ 8 | 'Compose', 'to_tensor', 'Transpose', 'Collect', 'WrapFieldsToLists', 9 | 'ToTensor', 'Crop', 'RandomCrop', 'Normalize', 'SwapSiameseMotion', 10 | 'ProcessSiameseMotion' 11 | ] 12 | -------------------------------------------------------------------------------- /mogen/datasets/pipelines/compose.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Sequence 2 | 3 | from mmcv.utils import build_from_cfg 4 | 5 | from ..builder import PIPELINES 6 | 7 | 8 | @PIPELINES.register_module() 9 | class Compose(object): 10 | """Compose a data pipeline with a sequence of transforms. 11 | 12 | Args: 13 | transforms (list[dict | callable]): 14 | Either config dicts of transforms or transform objects. 15 | """ 16 | 17 | def __init__(self, transforms): 18 | assert isinstance(transforms, Sequence) 19 | self.transforms = [] 20 | for transform in transforms: 21 | if isinstance(transform, dict): 22 | transform = build_from_cfg(transform, PIPELINES) 23 | self.transforms.append(transform) 24 | elif callable(transform): 25 | self.transforms.append(transform) 26 | else: 27 | raise TypeError('transform must be callable or a dict, but got' 28 | f' {type(transform)}') 29 | 30 | def __call__(self, data): 31 | for t in self.transforms: 32 | data = t(data) 33 | if data is None: 34 | return None 35 | return data 36 | 37 | def __repr__(self): 38 | format_string = self.__class__.__name__ + '(' 39 | for t in self.transforms: 40 | format_string += f'\n {t}' 41 | format_string += '\n)' 42 | return format_string 43 | -------------------------------------------------------------------------------- /mogen/datasets/pipelines/formatting.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Sequence 2 | 3 | import mmcv 4 | import numpy as np 5 | import torch 6 | from mmcv.parallel import DataContainer as DC 7 | 8 | from ..builder import PIPELINES 9 | 10 | 11 | def to_tensor(data): 12 | """Convert objects of various python types to :obj:`torch.Tensor`. 13 | 14 | Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`, 15 | :class:`Sequence`, :class:`int` and :class:`float`. 16 | """ 17 | if isinstance(data, torch.Tensor): 18 | return data 19 | elif isinstance(data, np.ndarray): 20 | return torch.from_numpy(data) 21 | elif isinstance(data, Sequence) and not mmcv.is_str(data): 22 | return torch.tensor(data) 23 | elif isinstance(data, int): 24 | return torch.LongTensor([data]) 25 | elif isinstance(data, float): 26 | return torch.FloatTensor([data]) 27 | else: 28 | raise TypeError( 29 | f'Type {type(data)} cannot be converted to tensor.' 
30 | 'Supported types are: `numpy.ndarray`, `torch.Tensor`, ' 31 | '`Sequence`, `int` and `float`') 32 | 33 | 34 | @PIPELINES.register_module() 35 | class ToTensor(object): 36 | 37 | def __init__(self, keys): 38 | self.keys = keys 39 | 40 | def __call__(self, results): 41 | for key in self.keys: 42 | results[key] = to_tensor(results[key]) 43 | return results 44 | 45 | def __repr__(self): 46 | return self.__class__.__name__ + f'(keys={self.keys})' 47 | 48 | 49 | @PIPELINES.register_module() 50 | class Transpose(object): 51 | 52 | def __init__(self, keys, order): 53 | self.keys = keys 54 | self.order = order 55 | 56 | def __call__(self, results): 57 | for key in self.keys: 58 | results[key] = results[key].transpose(self.order) 59 | return results 60 | 61 | def __repr__(self): 62 | return self.__class__.__name__ + \ 63 | f'(keys={self.keys}, order={self.order})' 64 | 65 | 66 | @PIPELINES.register_module() 67 | class Collect(object): 68 | """Collect data from the loader relevant to the specific task. 69 | 70 | This is usually the last stage of the data loader pipeline. 71 | 72 | Args: 73 | keys (Sequence[str]): Keys of results to be collected in ``data``. 74 | meta_keys (Sequence[str], optional): Meta keys to be converted to 75 | ``mmcv.DataContainer`` and collected in ``data[motion_metas]``. 76 | Default: ``('filename', 'ori_filename', 77 | 'ori_shape', 'motion_shape', 'motion_mask')`` 78 | 79 | Returns: 80 | dict: The result dict contains the following keys 81 | - keys in``self.keys`` 82 | - ``motion_metas`` if available 83 | """ 84 | 85 | def __init__(self, 86 | keys, 87 | meta_keys=('filename', 'ori_filename', 'ori_shape', 88 | 'motion_shape', 'motion_mask')): 89 | self.keys = keys 90 | self.meta_keys = meta_keys 91 | 92 | def __call__(self, results): 93 | data = {} 94 | motion_meta = {} 95 | for key in self.meta_keys: 96 | if key in results: 97 | motion_meta[key] = results[key] 98 | data['motion_metas'] = DC(motion_meta, cpu_only=True) 99 | for key in self.keys: 100 | data[key] = results[key] 101 | return data 102 | 103 | def __repr__(self): 104 | return self.__class__.__name__ + \ 105 | f'(keys={self.keys}, meta_keys={self.meta_keys})' 106 | 107 | 108 | @PIPELINES.register_module() 109 | class WrapFieldsToLists(object): 110 | """Wrap fields of the data dictionary into lists for evaluation. 111 | 112 | This class can be used as a last step of a test or validation 113 | pipeline for single image evaluation or inference. 
114 | 115 | Example: 116 | >>> test_pipeline = [ 117 | >>> dict(type='LoadImageFromFile'), 118 | >>> dict(type='Normalize', 119 | mean=[123.675, 116.28, 103.53], 120 | std=[58.395, 57.12, 57.375], 121 | to_rgb=True), 122 | >>> dict(type='ImageToTensor', keys=['img']), 123 | >>> dict(type='Collect', keys=['img']), 124 | >>> dict(type='WrapIntoLists') 125 | >>> ] 126 | """ 127 | 128 | def __call__(self, results): 129 | # Wrap dict fields into lists 130 | for key, val in results.items(): 131 | results[key] = [val] 132 | return results 133 | 134 | def __repr__(self): 135 | return f'{self.__class__.__name__}()' 136 | -------------------------------------------------------------------------------- /mogen/datasets/pipelines/siamese_motion.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from ..builder import PIPELINES 5 | from .quaternion import qbetween_np, qinv_np, qmul_np, qrot_np 6 | 7 | face_joint_indx = [2, 1, 17, 16] 8 | fid_l = [7, 10] 9 | fid_r = [8, 11] 10 | 11 | trans_matrix = torch.Tensor([[1.0, 0.0, 0.0], [0.0, 0.0, 1.0], 12 | [0.0, -1.0, 0.0]]) 13 | 14 | 15 | def rigid_transform(relative, data): 16 | 17 | global_positions = data[..., :22 * 3].reshape(data.shape[:-1] + (22, 3)) 18 | global_vel = data[..., 22 * 3:22 * 6].reshape(data.shape[:-1] + (22, 3)) 19 | 20 | relative_rot = relative[0] 21 | relative_t = relative[1:3] 22 | relative_r_rot_quat = np.zeros(global_positions.shape[:-1] + (4, )) 23 | relative_r_rot_quat[..., 0] = np.cos(relative_rot) 24 | relative_r_rot_quat[..., 2] = np.sin(relative_rot) 25 | global_positions = qrot_np(qinv_np(relative_r_rot_quat), global_positions) 26 | global_positions[..., [0, 2]] += relative_t 27 | data[..., :22 * 3] = global_positions.reshape(data.shape[:-1] + (-1, )) 28 | global_vel = qrot_np(qinv_np(relative_r_rot_quat), global_vel) 29 | data[..., 22 * 3:22 * 6] = global_vel.reshape(data.shape[:-1] + (-1, )) 30 | 31 | return data 32 | 33 | 34 | @PIPELINES.register_module() 35 | class SwapSiameseMotion(object): 36 | r"""Swap motion sequences. 37 | 38 | Args: 39 | prob (float): The probability of swapping siamese motions 40 | """ 41 | 42 | def __init__(self, prob=0.5): 43 | self.prob = prob 44 | assert prob >= 0 and prob <= 1.0 45 | 46 | def __call__(self, results): 47 | if np.random.rand() <= self.prob: 48 | motion1 = results['motion1'] 49 | motion2 = results['motion2'] 50 | results['motion1'] = motion2 51 | results['motion2'] = motion1 52 | return results 53 | 54 | def __repr__(self): 55 | repr_str = self.__class__.__name__ + f'(prob={self.prob})' 56 | return repr_str 57 | 58 | 59 | @PIPELINES.register_module() 60 | class ProcessSiameseMotion(object): 61 | r"""Process siamese motion sequences. 
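    In brief (editor's summary): ``__call__`` canonicalizes ``results['motion1']``
    and ``results['motion2']`` with ``process_single_motion``, re-expresses the
    second person in the first person's frame via ``rigid_transform`` (a relative
    yaw angle plus an XZ offset), randomly swaps the two with probability ``prob``,
    and concatenates them along the feature axis into ``results['motion']``.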
62 | The code is borrowed from 63 | https://github.com/tr3e/InterGen/blob/master/utils/utils.py 64 | """ 65 | 66 | def __init__(self, feet_threshold, prev_frames, n_joints, prob): 67 | self.feet_threshold = feet_threshold 68 | self.prev_frames = prev_frames 69 | self.n_joints = n_joints 70 | self.prob = prob 71 | 72 | def process_single_motion(self, motion): 73 | feet_thre = self.feet_threshold 74 | prev_frames = self.prev_frames 75 | n_joints = self.n_joints 76 | '''Uniform Skeleton''' 77 | # positions = uniform_skeleton(positions, tgt_offsets) 78 | 79 | positions = motion[:, :n_joints * 3].reshape(-1, n_joints, 3) 80 | rotations = motion[:, n_joints * 3:] 81 | 82 | positions = np.einsum("mn, tjn->tjm", trans_matrix, positions) 83 | '''Put on Floor''' 84 | floor_height = positions.min(axis=0).min(axis=0)[1] 85 | positions[:, :, 1] -= floor_height 86 | '''XZ at origin''' 87 | root_pos_init = positions[prev_frames] 88 | root_pose_init_xz = root_pos_init[0] * np.array([1, 0, 1]) 89 | positions = positions - root_pose_init_xz 90 | '''All initially face Z+''' 91 | r_hip, l_hip, sdr_r, sdr_l = face_joint_indx 92 | across = root_pos_init[r_hip] - root_pos_init[l_hip] 93 | across = across / np.sqrt((across**2).sum(axis=-1))[..., np.newaxis] 94 | 95 | # forward (3,), rotate around y-axis 96 | forward_init = np.cross(np.array([[0, 1, 0]]), across, axis=-1) 97 | # forward (3,) 98 | forward_init = forward_init / np.sqrt((forward_init**2).sum(axis=-1)) 99 | forward_init = forward_init[..., np.newaxis] 100 | 101 | target = np.array([[0, 0, 1]]) 102 | root_quat_init = qbetween_np(forward_init, target) 103 | root_quat_init_for_all = \ 104 | np.ones(positions.shape[:-1] + (4,)) * root_quat_init 105 | 106 | positions = qrot_np(root_quat_init_for_all, positions) 107 | """ Get Foot Contacts """ 108 | 109 | def foot_detect(positions, thres): 110 | velfactor, heightfactor = \ 111 | np.array([thres, thres]), np.array([0.12, 0.05]) 112 | 113 | feet_l_x = \ 114 | (positions[1:, fid_l, 0] - positions[:-1, fid_l, 0]) ** 2 115 | feet_l_y = \ 116 | (positions[1:, fid_l, 1] - positions[:-1, fid_l, 1]) ** 2 117 | feet_l_z = \ 118 | (positions[1:, fid_l, 2] - positions[:-1, fid_l, 2]) ** 2 119 | feet_l_h = positions[:-1, fid_l, 1] 120 | feet_l_sum = feet_l_x + feet_l_y + feet_l_z 121 | feet_l = ((feet_l_sum < velfactor) & (feet_l_h < heightfactor)) 122 | feet_l = feet_l.astype(np.float32) 123 | 124 | feet_r_x = \ 125 | (positions[1:, fid_r, 0] - positions[:-1, fid_r, 0]) ** 2 126 | feet_r_y = \ 127 | (positions[1:, fid_r, 1] - positions[:-1, fid_r, 1]) ** 2 128 | feet_r_z = \ 129 | (positions[1:, fid_r, 2] - positions[:-1, fid_r, 2]) ** 2 130 | feet_r_h = positions[:-1, fid_r, 1] 131 | feet_r_sum = feet_r_x + feet_r_y + feet_r_z 132 | feet_r = ((feet_r_sum < velfactor) & (feet_r_h < heightfactor)) 133 | feet_r = feet_r.astype(np.float32) 134 | return feet_l, feet_r 135 | 136 | feet_l, feet_r = foot_detect(positions, feet_thre) 137 | '''Get Joint Rotation Representation''' 138 | rot_data = rotations 139 | '''Get Joint Rotation Invariant Position Represention''' 140 | joint_positions = positions.reshape(len(positions), -1) 141 | joint_vels = positions[1:] - positions[:-1] 142 | joint_vels = joint_vels.reshape(len(joint_vels), -1) 143 | 144 | data = joint_positions[:-1] 145 | data = np.concatenate([data, joint_vels], axis=-1) 146 | data = np.concatenate([data, rot_data[:-1]], axis=-1) 147 | data = np.concatenate([data, feet_l, feet_r], axis=-1) 148 | 149 | return data, root_quat_init, root_pose_init_xz[None] 150 | 151 | 
def __call__(self, results): 152 | motion1, root_quat_init1, root_pos_init1 = \ 153 | self.process_single_motion(results['motion1']) 154 | motion2, root_quat_init2, root_pos_init2 = \ 155 | self.process_single_motion(results['motion2']) 156 | r_relative = qmul_np(root_quat_init2, qinv_np(root_quat_init1)) 157 | angle = np.arctan2(r_relative[:, 2:3], r_relative[:, 0:1]) 158 | 159 | xz = qrot_np(root_quat_init1, root_pos_init2 - root_pos_init1)[:, 160 | [0, 2]] 161 | relative = np.concatenate([angle, xz], axis=-1)[0] 162 | motion2 = rigid_transform(relative, motion2) 163 | if np.random.rand() <= self.prob: 164 | motion2, motion1 = motion1, motion2 165 | motion = np.concatenate((motion1, motion2), axis=-1) 166 | results['motion'] = motion 167 | return results 168 | 169 | def __repr__(self): 170 | repr_str = self.__class__.__name__ 171 | repr_str += f'(feet_threshold={self.feet_threshold})' 172 | repr_str += f'(feet_threshold={self.feet_threshold})' 173 | repr_str += f'(n_joints={self.n_joints})' 174 | repr_str += f'(prob={self.prob})' 175 | return repr_str 176 | -------------------------------------------------------------------------------- /mogen/datasets/pipelines/transforms.py: -------------------------------------------------------------------------------- 1 | import random 2 | from typing import Optional, Union 3 | 4 | import numpy as np 5 | import torch 6 | 7 | from ..builder import PIPELINES 8 | 9 | 10 | @PIPELINES.register_module() 11 | class Crop(object): 12 | r"""Crop motion sequences. 13 | 14 | Args: 15 | crop_size (int): The size of the cropped motion sequence. 16 | """ 17 | 18 | def __init__(self, crop_size: Optional[Union[int, None]] = None): 19 | self.crop_size = crop_size 20 | assert self.crop_size is not None 21 | 22 | def __call__(self, results): 23 | motion = results['motion'] 24 | length = len(motion) 25 | if length >= self.crop_size: 26 | idx = random.randint(0, length - self.crop_size) 27 | motion = motion[idx:idx + self.crop_size] 28 | results['motion_length'] = self.crop_size 29 | else: 30 | padding_length = self.crop_size - length 31 | D = motion.shape[1:] 32 | padding_zeros = np.zeros((padding_length, *D), dtype=np.float32) 33 | motion = np.concatenate([motion, padding_zeros], axis=0) 34 | results['motion_length'] = length 35 | assert len(motion) == self.crop_size 36 | results['motion'] = motion 37 | results['motion_shape'] = motion.shape 38 | if length >= self.crop_size: 39 | results['motion_mask'] = torch.ones(self.crop_size).numpy() 40 | else: 41 | results['motion_mask'] = torch.cat( 42 | (torch.ones(length), 43 | torch.zeros(self.crop_size - length))).numpy() 44 | return results 45 | 46 | def __repr__(self): 47 | repr_str = self.__class__.__name__ + f'(crop_size={self.crop_size})' 48 | return repr_str 49 | 50 | 51 | @PIPELINES.register_module() 52 | class RandomCrop(object): 53 | r"""Random crop motion sequences. Each sequence will be padded with zeros 54 | to the maximum length. 55 | 56 | Args: 57 | min_size (int or None): The minimum size of the cropped motion 58 | sequence (inclusive). 59 | max_size (int or None): The maximum size of the cropped motion 60 | sequence (inclusive). 
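        Example (illustrative pipeline entry; the sizes are placeholders):

            dict(type='RandomCrop', min_size=64, max_size=196)

        The returned ``motion`` is always padded to ``max_size`` frames and
        ``motion_mask`` marks which frames contain real data.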
61 | """ 62 | 63 | def __init__(self, 64 | min_size: Optional[Union[int, None]] = None, 65 | max_size: Optional[Union[int, None]] = None): 66 | self.min_size = min_size 67 | self.max_size = max_size 68 | assert self.min_size is not None 69 | assert self.max_size is not None 70 | 71 | def __call__(self, results): 72 | motion = results['motion'] 73 | length = len(motion) 74 | crop_size = random.randint(self.min_size, self.max_size) 75 | if length > crop_size: 76 | idx = random.randint(0, length - crop_size) 77 | motion = motion[idx:idx + crop_size] 78 | results['motion_length'] = crop_size 79 | else: 80 | results['motion_length'] = length 81 | padding_length = self.max_size - min(crop_size, length) 82 | if padding_length > 0: 83 | D = motion.shape[1:] 84 | padding_zeros = np.zeros((padding_length, *D), dtype=np.float32) 85 | motion = np.concatenate([motion, padding_zeros], axis=0) 86 | results['motion'] = motion 87 | results['motion_shape'] = motion.shape 88 | if length >= self.max_size and crop_size == self.max_size: 89 | results['motion_mask'] = torch.ones(self.max_size).numpy() 90 | else: 91 | results['motion_mask'] = torch.cat( 92 | (torch.ones(min(length, crop_size)), 93 | torch.zeros(self.max_size - min(length, crop_size))), 94 | dim=0).numpy() 95 | assert len(motion) == self.max_size 96 | return results 97 | 98 | def __repr__(self): 99 | repr_str = self.__class__.__name__ + f'(min_size={self.min_size}' 100 | repr_str += f', max_size={self.max_size})' 101 | return repr_str 102 | 103 | 104 | @PIPELINES.register_module() 105 | class Normalize(object): 106 | """Normalize motion sequences. 107 | 108 | Args: 109 | mean_path (str): Path of mean file. 110 | std_path (str): Path of std file. 111 | """ 112 | 113 | def __init__(self, mean_path, std_path, eps=1e-9, keys=['motion']): 114 | self.mean = np.load(mean_path) 115 | self.std = np.load(std_path) 116 | self.eps = eps 117 | self.keys = keys 118 | 119 | def __call__(self, results): 120 | for k in self.keys: 121 | motion = results[k] 122 | motion = (motion - self.mean) / (self.std + self.eps) 123 | results[k] = motion 124 | return results 125 | -------------------------------------------------------------------------------- /mogen/datasets/samplers/__init__.py: -------------------------------------------------------------------------------- 1 | from .distributed_sampler import DistributedSampler 2 | 3 | __all__ = ['DistributedSampler'] 4 | -------------------------------------------------------------------------------- /mogen/datasets/samplers/distributed_sampler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DistributedSampler as _DistributedSampler 3 | 4 | 5 | class DistributedSampler(_DistributedSampler): 6 | 7 | def __init__(self, 8 | dataset, 9 | num_replicas=None, 10 | rank=None, 11 | shuffle=True, 12 | round_up=True): 13 | super().__init__(dataset, num_replicas=num_replicas, rank=rank) 14 | self.shuffle = shuffle 15 | self.round_up = round_up 16 | if self.round_up: 17 | self.total_size = self.num_samples * self.num_replicas 18 | else: 19 | self.total_size = len(self.dataset) 20 | 21 | def __iter__(self): 22 | # deterministically shuffle based on epoch 23 | if self.shuffle: 24 | g = torch.Generator() 25 | g.manual_seed(self.epoch) 26 | indices = torch.randperm(len(self.dataset), generator=g).tolist() 27 | else: 28 | indices = torch.arange(len(self.dataset)).tolist() 29 | 30 | # add extra samples to make it evenly divisible 31 | if self.round_up: 32 | 
indices = ( 33 | indices * 34 | int(self.total_size / len(indices) + 1))[:self.total_size] 35 | assert len(indices) == self.total_size 36 | 37 | # subsample 38 | indices = indices[self.rank:self.total_size:self.num_replicas] 39 | if self.round_up: 40 | assert len(indices) == self.num_samples 41 | 42 | return iter(indices) 43 | -------------------------------------------------------------------------------- /mogen/datasets/text_motion_dataset.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os 3 | import os.path 4 | from typing import Optional, Union 5 | 6 | import numpy as np 7 | import torch 8 | 9 | from .base_dataset import BaseMotionDataset 10 | from .builder import DATASETS 11 | 12 | 13 | @DATASETS.register_module() 14 | class TextMotionDataset(BaseMotionDataset): 15 | """TextMotion dataset. 16 | 17 | Args: 18 | text_dir (str): Path to the directory containing the text files. 19 | """ 20 | 21 | def __init__(self, 22 | data_prefix: str, 23 | pipeline: list, 24 | dataset_name: Optional[Union[str, None]] = None, 25 | fixed_length: Optional[Union[int, None]] = None, 26 | ann_file: Optional[Union[str, None]] = None, 27 | motion_dir: Optional[Union[str, None]] = None, 28 | text_dir: Optional[Union[str, None]] = None, 29 | token_dir: Optional[Union[str, None]] = None, 30 | clip_feat_dir: Optional[Union[str, None]] = None, 31 | eval_cfg: Optional[Union[dict, None]] = None, 32 | test_mode: Optional[bool] = False, 33 | siamese_mode: Optional[bool] = False, 34 | tcomb_mode: Optional[bool] = False): 35 | self.text_dir = os.path.join(data_prefix, 'datasets', dataset_name, 36 | text_dir) 37 | if token_dir is not None: 38 | self.token_dir = os.path.join(data_prefix, 'datasets', 39 | dataset_name, token_dir) 40 | else: 41 | self.token_dir = None 42 | if clip_feat_dir is not None: 43 | self.clip_feat_dir = os.path.join(data_prefix, 'datasets', 44 | dataset_name, clip_feat_dir) 45 | else: 46 | self.clip_feat_dir = None 47 | self.siamese_mode = siamese_mode 48 | self.tcomb_mode = tcomb_mode 49 | super(TextMotionDataset, self).__init__(data_prefix=data_prefix, 50 | pipeline=pipeline, 51 | dataset_name=dataset_name, 52 | fixed_length=fixed_length, 53 | ann_file=ann_file, 54 | motion_dir=motion_dir, 55 | eval_cfg=eval_cfg, 56 | test_mode=test_mode) 57 | 58 | def load_anno(self, name): 59 | results = {} 60 | if self.siamese_mode: 61 | motion_path = os.path.join(self.motion_dir, name + '.npz') 62 | motion_data = np.load(motion_path) 63 | results['motion1'] = motion_data['motion1'] 64 | results['motion2'] = motion_data['motion2'] 65 | assert results['motion1'].shape == results['motion2'].shape 66 | else: 67 | motion_path = os.path.join(self.motion_dir, name + '.npy') 68 | motion_data = np.load(motion_path) 69 | results['motion'] = motion_data 70 | text_path = os.path.join(self.text_dir, name + '.txt') 71 | text_data = [] 72 | for line in open(text_path, 'r'): 73 | text_data.append(line.strip()) 74 | results['text'] = text_data 75 | if self.token_dir is not None: 76 | token_path = os.path.join(self.token_dir, name + '.txt') 77 | token_data = [] 78 | for line in open(token_path, 'r'): 79 | token_data.append(line.strip()) 80 | results['token'] = token_data 81 | if self.clip_feat_dir is not None: 82 | clip_feat_path = os.path.join(self.clip_feat_dir, name + '.npy') 83 | clip_feat = torch.from_numpy(np.load(clip_feat_path)) 84 | results['clip_feat'] = clip_feat 85 | return results 86 | 87 | def prepare_data(self, idx: int): 88 | """"Prepare raw data for 
the f'{idx'}-th data.""" 89 | results = copy.deepcopy(self.data_infos[idx]) 90 | text_list = results['text'] 91 | idx = np.random.randint(0, len(text_list)) 92 | results['text'] = text_list[idx] 93 | if 'clip_feat' in results.keys(): 94 | results['clip_feat'] = results['clip_feat'][idx] 95 | if 'token' in results.keys(): 96 | results['token'] = results['token'][idx] 97 | results['dataset_name'] = self.dataset_name 98 | results = self.pipeline(results) 99 | return results 100 | -------------------------------------------------------------------------------- /mogen/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .architectures import * # noqa: F401,F403 2 | from .attentions import * # noqa: F401,F403 3 | from .builder import * # noqa: F401,F403 4 | from .losses import * # noqa: F401,F403 5 | from .rnns import * # noqa: F401,F403 6 | from .transformers import * # noqa: F401,F403 7 | from .utils import * # noqa: F401,F403 8 | -------------------------------------------------------------------------------- /mogen/models/architectures/__init__.py: -------------------------------------------------------------------------------- 1 | from .diffusion_architecture import MotionDiffusion 2 | from .vae_architecture import MotionVAE 3 | 4 | __all__ = ['MotionVAE', 'MotionDiffusion'] 5 | -------------------------------------------------------------------------------- /mogen/models/architectures/base_architecture.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | import torch.distributed as dist 5 | from mmcv.runner import BaseModule 6 | 7 | 8 | def to_cpu(x): 9 | if isinstance(x, torch.Tensor): 10 | return x.detach().cpu() 11 | return x 12 | 13 | 14 | class BaseArchitecture(BaseModule): 15 | """Base class for mogen architecture.""" 16 | 17 | def __init__(self, init_cfg=None): 18 | super(BaseArchitecture, self).__init__(init_cfg) 19 | 20 | def forward_train(self, **kwargs): 21 | pass 22 | 23 | def forward_test(self, **kwargs): 24 | pass 25 | 26 | def _parse_losses(self, losses): 27 | """Parse the raw outputs (losses) of the network. 28 | Args: 29 | losses (dict): Raw output of the network, which usually contain 30 | losses and other necessary information. 31 | Returns: 32 | tuple[Tensor, dict]: (loss, log_vars), loss is the loss tensor \ 33 | which may be a weighted sum of all losses, log_vars contains \ 34 | all the variables to be sent to the logger. 35 | """ 36 | log_vars = OrderedDict() 37 | for loss_name, loss_value in losses.items(): 38 | if isinstance(loss_value, torch.Tensor): 39 | log_vars[loss_name] = loss_value.mean() 40 | elif isinstance(loss_value, list): 41 | log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value) 42 | else: 43 | raise TypeError( 44 | f'{loss_name} is not a tensor or list of tensors') 45 | 46 | loss = sum(_value for _key, _value in log_vars.items() 47 | if 'loss' in _key) 48 | 49 | log_vars['loss'] = loss 50 | for loss_name, loss_value in log_vars.items(): 51 | # reduce loss when distributed training 52 | if dist.is_available() and dist.is_initialized(): 53 | loss_value = loss_value.data.clone() 54 | dist.all_reduce(loss_value.div_(dist.get_world_size())) 55 | log_vars[loss_name] = loss_value.item() 56 | 57 | return loss, log_vars 58 | 59 | def train_step(self, data, optimizer): 60 | """The iteration step during training. 
61 | This method defines an iteration step during training, except for the 62 | back propagation and optimizer updating, which are done in an optimizer 63 | hook. Note that in some complicated cases or models, the whole process 64 | including back propagation and optimizer updating is also defined in 65 | this method, such as GAN. 66 | Args: 67 | data (dict): The output of dataloader. 68 | optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer of 69 | runner is passed to ``train_step()``. This argument is unused 70 | and reserved. 71 | Returns: 72 | dict: It should contain at least 3 keys: ``loss``, ``log_vars``, \ 73 | ``num_samples``. 74 | - ``loss`` is a tensor for back propagation, which can be a 75 | weighted sum of multiple losses. 76 | - ``log_vars`` contains all the variables to be sent to the 77 | logger. 78 | - ``num_samples`` indicates the batch size (when the model is 79 | DDP, it means the batch size on each GPU), which is used for 80 | averaging the logs. 81 | """ 82 | losses = self(**data) 83 | loss, log_vars = self._parse_losses(losses) 84 | 85 | outputs = dict(loss=loss, 86 | log_vars=log_vars, 87 | num_samples=len(data['motion'])) 88 | 89 | return outputs 90 | 91 | def val_step(self, data, optimizer=None): 92 | """The iteration step during validation. 93 | This method shares the same signature as :func:`train_step`, but used 94 | during val epochs. Note that the evaluation after training epochs is 95 | not implemented with this method, but an evaluation hook. 96 | """ 97 | losses = self(**data) 98 | loss, log_vars = self._parse_losses(losses) 99 | 100 | outputs = dict(loss=loss, 101 | log_vars=log_vars, 102 | num_samples=len(data['motion'])) 103 | 104 | return outputs 105 | 106 | def forward(self, **kwargs): 107 | if self.training: 108 | return self.forward_train(**kwargs) 109 | else: 110 | return self.forward_test(**kwargs) 111 | 112 | def split_results(self, results): 113 | B = results['motion'].shape[0] 114 | output = [] 115 | for i in range(B): 116 | batch_output = dict() 117 | batch_output['motion'] = to_cpu(results['motion'][i]) 118 | batch_output['pred_motion'] = to_cpu(results['pred_motion'][i]) 119 | batch_output['motion_length'] = to_cpu(results['motion_length'][i]) 120 | batch_output['motion_mask'] = to_cpu(results['motion_mask'][i]) 121 | if 'pred_motion_length' in results.keys(): 122 | batch_output['pred_motion_length'] = \ 123 | to_cpu(results['pred_motion_length'][i]) 124 | else: 125 | batch_output['pred_motion_length'] = \ 126 | to_cpu(results['motion_length'][i]) 127 | if 'pred_motion_mask' in results: 128 | batch_output['pred_motion_mask'] = \ 129 | to_cpu(results['pred_motion_mask'][i]) 130 | else: 131 | batch_output['pred_motion_mask'] = \ 132 | to_cpu(results['motion_mask'][i]) 133 | if 'motion_metas' in results.keys(): 134 | motion_metas = results['motion_metas'][i] 135 | if 'text' in motion_metas.keys(): 136 | batch_output['text'] = motion_metas['text'] 137 | if 'token' in motion_metas.keys(): 138 | batch_output['token'] = motion_metas['token'] 139 | output.append(batch_output) 140 | return output 141 | -------------------------------------------------------------------------------- /mogen/models/architectures/vae_architecture.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ..builder import ARCHITECTURES, build_loss, build_submodule 4 | from .base_architecture import BaseArchitecture 5 | 6 | 7 | @ARCHITECTURES.register_module() 8 | class PoseVAE(BaseArchitecture): 9 | 10 | 
def __init__(self, 11 | encoder=None, 12 | decoder=None, 13 | loss_recon=None, 14 | kl_div_loss_weight=None, 15 | init_cfg=None, 16 | **kwargs): 17 | super().__init__(init_cfg=init_cfg, **kwargs) 18 | self.encoder = build_submodule(encoder) 19 | self.decoder = build_submodule(decoder) 20 | self.loss_recon = build_loss(loss_recon) 21 | self.kl_div_loss_weight = kl_div_loss_weight 22 | 23 | def reparameterize(self, mu, logvar): 24 | std = torch.exp(logvar / 2) 25 | 26 | eps = std.data.new(std.size()).normal_() 27 | latent_code = eps.mul(std).add_(mu) 28 | return latent_code 29 | 30 | def encode(self, pose): 31 | mu, logvar = self.encoder(pose) 32 | return mu 33 | 34 | def forward(self, **kwargs): 35 | motion = kwargs['motion'].float() 36 | B, T = motion.shape[:2] 37 | pose = motion.reshape(B * T, -1) 38 | pose = pose[:, :-4] 39 | 40 | mu, logvar = self.encoder(pose) 41 | z = self.reparameterize(mu, logvar) 42 | pred = self.decoder(z) 43 | 44 | loss = dict() 45 | recon_loss = self.loss_recon(pred, pose, reduction_override='none') 46 | loss['recon_loss'] = recon_loss 47 | if self.kl_div_loss_weight is not None: 48 | loss_kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) 49 | loss['kl_div_loss'] = (loss_kl * self.kl_div_loss_weight) 50 | 51 | return loss 52 | 53 | 54 | @ARCHITECTURES.register_module() 55 | class MotionVAE(BaseArchitecture): 56 | 57 | def __init__(self, 58 | encoder=None, 59 | decoder=None, 60 | loss_recon=None, 61 | kl_div_loss_weight=None, 62 | init_cfg=None, 63 | **kwargs): 64 | super().__init__(init_cfg=init_cfg, **kwargs) 65 | self.encoder = build_submodule(encoder) 66 | self.decoder = build_submodule(decoder) 67 | self.loss_recon = build_loss(loss_recon) 68 | self.kl_div_loss_weight = kl_div_loss_weight 69 | 70 | def sample(self, std=1, latent_code=None): 71 | if latent_code is not None: 72 | z = latent_code 73 | else: 74 | z = torch.randn(1, 7, self.decoder.latent_dim).cuda() * std 75 | output = self.decoder(z) 76 | if self.use_normalization: 77 | output = output * self.motion_std 78 | output = output + self.motion_mean 79 | return output 80 | 81 | def reparameterize(self, mu, logvar): 82 | std = torch.exp(logvar / 2) 83 | 84 | eps = std.data.new(std.size()).normal_() 85 | latent_code = eps.mul(std).add_(mu) 86 | return latent_code 87 | 88 | def encode(self, motion, motion_mask): 89 | mu, logvar = self.encoder(motion, motion_mask) 90 | return self.reparameterize(mu, logvar) 91 | 92 | def decode(self, z, motion_mask): 93 | return self.decoder(z, motion_mask) 94 | 95 | def forward(self, **kwargs): 96 | motion, motion_mask = kwargs['motion'].float(), kwargs['motion_mask'] 97 | B, T = motion.shape[:2] 98 | 99 | mu, logvar = self.encoder(motion, motion_mask) 100 | z = self.reparameterize(mu, logvar) 101 | pred = self.decoder(z, motion_mask) 102 | 103 | loss = dict() 104 | recon_loss = self.loss_recon(pred, motion, reduction_override='none') 105 | recon_loss = recon_loss.mean(dim=-1) * motion_mask 106 | recon_loss = recon_loss.sum() / motion_mask.sum() 107 | loss['recon_loss'] = recon_loss 108 | if self.kl_div_loss_weight is not None: 109 | loss_kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) 110 | loss['kl_div_loss'] = (loss_kl * self.kl_div_loss_weight) 111 | 112 | return loss 113 | -------------------------------------------------------------------------------- /mogen/models/attentions/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_attention import BaseMixedAttention 2 | from 
.efficient_attention import (EfficientCrossAttention, 3 | EfficientMixedAttention, 4 | EfficientSelfAttention) 5 | from .fine_attention import SAMI 6 | from .semantics_modulated import (DualSemanticsModulatedAttention, 7 | SemanticsModulatedAttention) 8 | 9 | __all__ = [ 10 | 'EfficientSelfAttention', 'EfficientCrossAttention', 11 | 'EfficientMixedAttention', 'SemanticsModulatedAttention', 12 | 'DualSemanticsModulatedAttention', 'BaseMixedAttention', 'SAMI' 13 | ] 14 | -------------------------------------------------------------------------------- /mogen/models/attentions/base_attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from ..builder import ATTENTIONS 6 | from ..utils.stylization_block import StylizationBlock 7 | 8 | 9 | @ATTENTIONS.register_module() 10 | class BaseMixedAttention(nn.Module): 11 | 12 | def __init__(self, latent_dim, text_latent_dim, num_heads, dropout, 13 | time_embed_dim): 14 | super().__init__() 15 | self.num_heads = num_heads 16 | 17 | self.norm = nn.LayerNorm(latent_dim) 18 | self.text_norm = nn.LayerNorm(text_latent_dim) 19 | 20 | self.query = nn.Linear(latent_dim, latent_dim) 21 | self.key_text = nn.Linear(text_latent_dim, latent_dim) 22 | self.value_text = nn.Linear(text_latent_dim, latent_dim) 23 | self.key_motion = nn.Linear(latent_dim, latent_dim) 24 | self.value_motion = nn.Linear(latent_dim, latent_dim) 25 | 26 | self.dropout = nn.Dropout(dropout) 27 | self.proj_out = StylizationBlock(latent_dim, time_embed_dim, dropout) 28 | 29 | def forward(self, x, xf, emb, src_mask, cond_type, **kwargs): 30 | """ 31 | x: B, T, D 32 | xf: B, N, L 33 | """ 34 | B, T, D = x.shape 35 | N = xf.shape[1] + x.shape[1] 36 | H = self.num_heads 37 | # B, T, D 38 | query = self.query(self.norm(x)).view(B, T, H, -1) 39 | # B, N, D 40 | text_cond_type = ((cond_type % 10) > 0).float().view(B, 1, 1) 41 | text_cond_type = text_cond_type.repeat(1, xf.shape[1], 1) 42 | key = torch.cat( 43 | (self.key_text(self.text_norm(xf)), self.key_motion(self.norm(x))), 44 | dim=1).view(B, N, H, -1) 45 | 46 | attention = torch.einsum('bnhl,bmhl->bnmh', query, key) 47 | motion_mask = src_mask.view(B, 1, T, 1) 48 | text_mask = text_cond_type.view(B, 1, -1, 1) 49 | mask = torch.cat((text_mask, motion_mask), dim=2) 50 | attention = attention + (1 - mask) * -1000000 51 | attention = F.softmax(attention, dim=2) 52 | 53 | value = torch.cat(( 54 | self.value_text(self.text_norm(xf)) * text_cond_type, 55 | self.value_motion(self.norm(x)) * src_mask, 56 | ), 57 | dim=1).view(B, N, H, -1) 58 | 59 | y = torch.einsum('bnmh,bmhl->bnhl', attention, value).reshape(B, T, D) 60 | y = x + self.proj_out(y, emb) 61 | return y 62 | 63 | 64 | @ATTENTIONS.register_module() 65 | class BaseSelfAttention(nn.Module): 66 | 67 | def __init__(self, latent_dim, num_heads, dropout, time_embed_dim): 68 | super().__init__() 69 | self.num_heads = num_heads 70 | 71 | self.norm = nn.LayerNorm(latent_dim) 72 | self.query = nn.Linear(latent_dim, latent_dim) 73 | self.key = nn.Linear(latent_dim, latent_dim) 74 | self.value = nn.Linear(latent_dim, latent_dim) 75 | 76 | self.dropout = nn.Dropout(dropout) 77 | self.proj_out = StylizationBlock(latent_dim, time_embed_dim, dropout) 78 | 79 | def forward(self, x, emb, src_mask, **kwargs): 80 | """ 81 | x: B, T, D 82 | """ 83 | B, T, D = x.shape 84 | H = self.num_heads 85 | # B, T, D 86 | query = self.query(self.norm(x)).view(B, T, H, -1) 87 | # B, N, D 88 | key = 
self.key(self.norm(x)).view(B, T, H, -1) 89 | 90 | attention = torch.einsum('bnhl,bmhl->bnmh', query, key) 91 | mask = src_mask.view(B, 1, T, 1) 92 | attention = attention + (1 - mask) * -1000000 93 | attention = F.softmax(attention, dim=2) 94 | value = (self.value(self.norm(x)) * src_mask).view(B, T, H, -1) 95 | y = torch.einsum('bnmh,bmhl->bnhl', attention, value).reshape(B, T, D) 96 | y = x + self.proj_out(y, emb) 97 | return y 98 | 99 | 100 | @ATTENTIONS.register_module() 101 | class BaseCrossAttention(nn.Module): 102 | 103 | def __init__(self, latent_dim, text_latent_dim, num_heads, dropout, 104 | time_embed_dim): 105 | super().__init__() 106 | self.num_heads = num_heads 107 | 108 | self.norm = nn.LayerNorm(latent_dim) 109 | self.text_norm = nn.LayerNorm(text_latent_dim) 110 | self.query = nn.Linear(latent_dim, latent_dim) 111 | self.key = nn.Linear(text_latent_dim, latent_dim) 112 | self.value = nn.Linear(text_latent_dim, latent_dim) 113 | self.dropout = nn.Dropout(dropout) 114 | self.proj_out = StylizationBlock(latent_dim, time_embed_dim, dropout) 115 | 116 | def forward(self, x, xf, emb, src_mask, cond_type=None, **kwargs): 117 | """ 118 | x: B, T, D 119 | xf: B, N, L 120 | """ 121 | B, T, D = x.shape 122 | N = xf.shape[1] 123 | H = self.num_heads 124 | # B, T, D 125 | query = self.query(self.norm(x)).view(B, T, H, -1) 126 | # B, N, D 127 | if cond_type is None: 128 | text_cond_type = 1 129 | mask = 1 130 | else: 131 | text_cond_type = ((cond_type % 10) > 0).float().view(B, 1, 1) 132 | text_cond_type = text_cond_type.repeat(1, xf.shape[1], 1) 133 | mask = text_cond_type.view(B, 1, -1, 1) 134 | key = self.key(self.text_norm(xf)).view(B, N, H, -1) 135 | attention = torch.einsum('bnhl,bmhl->bnmh', query, key) 136 | attention = attention + (1 - mask) * -1000000 137 | attention = F.softmax(attention, dim=2) 138 | 139 | value = (self.value(self.text_norm(xf)) * text_cond_type) 140 | value = value.view(B, N, H, -1) 141 | y = torch.einsum('bnmh,bmhl->bnhl', attention, value).reshape(B, T, D) 142 | y = x + self.proj_out(y, emb) 143 | return y 144 | -------------------------------------------------------------------------------- /mogen/models/attentions/efficient_attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from ..builder import ATTENTIONS 6 | from ..utils.stylization_block import StylizationBlock 7 | 8 | 9 | @ATTENTIONS.register_module() 10 | class EfficientSelfAttention(nn.Module): 11 | 12 | def __init__(self, latent_dim, num_heads, dropout, time_embed_dim=None): 13 | super().__init__() 14 | self.num_heads = num_heads 15 | self.norm = nn.LayerNorm(latent_dim) 16 | self.query = nn.Linear(latent_dim, latent_dim) 17 | self.key = nn.Linear(latent_dim, latent_dim) 18 | self.value = nn.Linear(latent_dim, latent_dim) 19 | self.dropout = nn.Dropout(dropout) 20 | self.time_embed_dim = time_embed_dim 21 | if time_embed_dim is not None: 22 | self.proj_out = StylizationBlock(latent_dim, time_embed_dim, 23 | dropout) 24 | 25 | def forward(self, x, src_mask, emb=None, **kwargs): 26 | """ 27 | x: B, T, D 28 | """ 29 | B, T, D = x.shape 30 | H = self.num_heads 31 | # B, T, D 32 | query = self.query(self.norm(x)) 33 | # B, T, D 34 | key = (self.key(self.norm(x)) + (1 - src_mask) * -1000000) 35 | query = F.softmax(query.view(B, T, H, -1), dim=-1) 36 | key = F.softmax(key.view(B, T, H, -1), dim=1) 37 | # B, T, H, HD 38 | value = (self.value(self.norm(x)) * src_mask).view(B, 
T, H, -1) 39 | # B, H, HD, HD 40 | attention = torch.einsum('bnhd,bnhl->bhdl', key, value) 41 | y = torch.einsum('bnhd,bhdl->bnhl', query, attention).reshape(B, T, D) 42 | if self.time_embed_dim is None: 43 | y = x + y 44 | else: 45 | y = x + self.proj_out(y, emb) 46 | return y 47 | 48 | 49 | @ATTENTIONS.register_module() 50 | class EfficientCrossAttention(nn.Module): 51 | 52 | def __init__(self, latent_dim, text_latent_dim, num_heads, dropout, 53 | time_embed_dim): 54 | super().__init__() 55 | self.num_heads = num_heads 56 | self.norm = nn.LayerNorm(latent_dim) 57 | self.text_norm = nn.LayerNorm(text_latent_dim) 58 | self.query = nn.Linear(latent_dim, latent_dim) 59 | self.key = nn.Linear(text_latent_dim, latent_dim) 60 | self.value = nn.Linear(text_latent_dim, latent_dim) 61 | self.dropout = nn.Dropout(dropout) 62 | self.proj_out = StylizationBlock(latent_dim, time_embed_dim, dropout) 63 | 64 | def forward(self, x, xf, emb, cond_type=None, **kwargs): 65 | """ 66 | x: B, T, D 67 | xf: B, N, L 68 | """ 69 | B, T, D = x.shape 70 | N = xf.shape[1] 71 | H = self.num_heads 72 | # B, T, D 73 | query = self.query(self.norm(x)) 74 | # B, N, D 75 | key = self.key(self.text_norm(xf)) 76 | query = F.softmax(query.view(B, T, H, -1), dim=-1) 77 | if cond_type is None: 78 | key = F.softmax(key.view(B, N, H, -1), dim=1) 79 | # B, N, H, HD 80 | value = self.value(self.text_norm(xf)).view(B, N, H, -1) 81 | else: 82 | text_cond_type = ((cond_type % 10) > 0).float().view(B, 1, 1) 83 | text_cond_type = text_cond_type.repeat(1, xf.shape[1], 1) 84 | key = key + (1 - text_cond_type) * -1000000 85 | key = F.softmax(key.view(B, N, H, -1), dim=1) 86 | value = self.value(self.text_norm(xf) * text_cond_type) 87 | value = value.view(B, N, H, -1) 88 | # B, H, HD, HD 89 | attention = torch.einsum('bnhd,bnhl->bhdl', key, value) 90 | y = torch.einsum('bnhd,bhdl->bnhl', query, attention).reshape(B, T, D) 91 | y = x + self.proj_out(y, emb) 92 | return y 93 | 94 | 95 | @ATTENTIONS.register_module() 96 | class EfficientMixedAttention(nn.Module): 97 | 98 | def __init__(self, latent_dim, text_latent_dim, num_heads, dropout, 99 | time_embed_dim): 100 | super().__init__() 101 | self.num_heads = num_heads 102 | 103 | self.norm = nn.LayerNorm(latent_dim) 104 | self.text_norm = nn.LayerNorm(text_latent_dim) 105 | 106 | self.query = nn.Linear(latent_dim, latent_dim) 107 | self.key_text = nn.Linear(text_latent_dim, latent_dim) 108 | self.value_text = nn.Linear(text_latent_dim, latent_dim) 109 | self.key_motion = nn.Linear(latent_dim, latent_dim) 110 | self.value_motion = nn.Linear(latent_dim, latent_dim) 111 | 112 | self.dropout = nn.Dropout(dropout) 113 | self.proj_out = StylizationBlock(latent_dim, time_embed_dim, dropout) 114 | 115 | def forward(self, x, xf, emb, src_mask, cond_type, **kwargs): 116 | """ 117 | x: B, T, D 118 | xf: B, N, L 119 | """ 120 | B, T, D = x.shape 121 | N = xf.shape[1] + x.shape[1] 122 | H = self.num_heads 123 | 124 | text_feat = xf 125 | # B, T, D 126 | query = self.query(self.norm(x)).view(B, T, H, -1) 127 | # B, N, D 128 | text_cond_type = (cond_type % 10 > 0).float() 129 | src_mask = src_mask.view(B, T, 1) 130 | key_text = self.key_text(self.text_norm(text_feat)) 131 | key_text = key_text + (1 - text_cond_type) * -1000000 132 | key_motion = self.key_motion(self.norm(x)) + (1 - src_mask) * -1000000 133 | key = torch.cat((key_text, key_motion), dim=1) 134 | 135 | query = F.softmax(query.view(B, T, H, -1), dim=-1) 136 | key = self.dropout(F.softmax(key.view(B, N, H, -1), dim=1)) 137 | value = torch.cat(( 
138 | self.value_text(self.text_norm(text_feat)) * text_cond_type, 139 | self.value_motion(self.norm(x)) * src_mask, 140 | ), 141 | dim=1).view(B, N, H, -1) 142 | # B, H, HD, HD 143 | attention = torch.einsum('bnhd,bnhl->bhdl', key, value) 144 | y = torch.einsum('bnhd,bhdl->bnhl', query, attention).reshape(B, T, D) 145 | y = x + self.proj_out(y, emb) 146 | return y 147 | -------------------------------------------------------------------------------- /mogen/models/builder.py: -------------------------------------------------------------------------------- 1 | from mmcv.cnn import MODELS as MMCV_MODELS 2 | from mmcv.utils import Registry 3 | 4 | 5 | def build_from_cfg(cfg, registry, default_args=None): 6 | if cfg is None: 7 | return None 8 | return MMCV_MODELS.build_func(cfg, registry, default_args) 9 | 10 | 11 | MODELS = Registry('models', parent=MMCV_MODELS, build_func=build_from_cfg) 12 | 13 | LOSSES = MODELS 14 | ARCHITECTURES = MODELS 15 | SUBMODULES = MODELS 16 | ATTENTIONS = MODELS 17 | 18 | 19 | def build_loss(cfg): 20 | """Build loss.""" 21 | return LOSSES.build(cfg) 22 | 23 | 24 | def build_architecture(cfg): 25 | """Build framework.""" 26 | return ARCHITECTURES.build(cfg) 27 | 28 | 29 | def build_submodule(cfg): 30 | """Build submodule.""" 31 | return SUBMODULES.build(cfg) 32 | 33 | 34 | def build_attention(cfg): 35 | """Build attention.""" 36 | return ATTENTIONS.build(cfg) 37 | -------------------------------------------------------------------------------- /mogen/models/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from .gan_loss import GANLoss 2 | from .mse_loss import MSELoss 3 | from .utils import (convert_to_one_hot, reduce_loss, weight_reduce_loss, 4 | weighted_loss) 5 | 6 | __all__ = [ 7 | 'convert_to_one_hot', 'reduce_loss', 'weight_reduce_loss', 'weighted_loss', 8 | 'MSELoss', 'GANLoss' 9 | ] 10 | -------------------------------------------------------------------------------- /mogen/models/losses/gan_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch.nn as nn 3 | 4 | from ..builder import LOSSES 5 | 6 | 7 | @LOSSES.register_module() 8 | class GANLoss(nn.Module): 9 | """Define GAN loss. 10 | 11 | Args: 12 | gan_type (str): Support 'vanilla', 'lsgan', 'wgan', 'hinge'. 13 | real_label_val (float): The value for real label. Default: 1.0. 14 | fake_label_val (float): The value for fake label. Default: 0.0. 15 | loss_weight (float): Loss weight. Default: 1.0. 16 | Note that loss_weight is only for generators; and it is always 1.0 17 | for discriminators. 18 | """ 19 | 20 | def __init__(self, 21 | gan_type, 22 | real_label_val=1.0, 23 | fake_label_val=0.0, 24 | loss_weight=1.0): 25 | super().__init__() 26 | self.gan_type = gan_type 27 | self.loss_weight = loss_weight 28 | self.real_label_val = real_label_val 29 | self.fake_label_val = fake_label_val 30 | 31 | if self.gan_type == 'vanilla': 32 | self.loss = nn.BCEWithLogitsLoss() 33 | elif self.gan_type == 'lsgan': 34 | self.loss = nn.MSELoss() 35 | elif self.gan_type == 'wgan': 36 | self.loss = self._wgan_loss 37 | elif self.gan_type == 'hinge': 38 | self.loss = nn.ReLU() 39 | else: 40 | raise NotImplementedError( 41 | f'GAN type {self.gan_type} is not implemented.') 42 | 43 | @staticmethod 44 | def _wgan_loss(input, target): 45 | """wgan loss. 46 | 47 | Args: 48 | input (Tensor): Input tensor. 49 | target (bool): Target label. 
50 |         Returns:
51 |             Tensor: wgan loss.
52 |         """
53 |         return -input.mean() if target else input.mean()
54 |
55 |     def get_target_label(self, input, target_is_real):
56 |         """Get target label.
57 |
58 |         Args:
59 |             input (Tensor): Input tensor.
60 |             target_is_real (bool): Whether the target is real or fake.
61 |         Returns:
62 |             (bool | Tensor): Target tensor. Return bool for wgan, otherwise,
63 |                 return Tensor.
64 |         """
65 |
66 |         if self.gan_type == 'wgan':
67 |             return target_is_real
68 |         target_val = (self.real_label_val
69 |                       if target_is_real else self.fake_label_val)
70 |         return input.new_ones(input.size()) * target_val
71 |
72 |     def forward(self, input, target_is_real, is_disc=False):
73 |         """
74 |         Args:
75 |             input (Tensor): The input for the loss module, i.e., the network
76 |                 prediction.
77 |             target_is_real (bool): Whether the target is real or fake.
78 |             is_disc (bool): Whether the loss is for discriminators or not.
79 |                 Default: False.
80 |         Returns:
81 |             Tensor: GAN loss value.
82 |         """
83 |         target_label = self.get_target_label(input, target_is_real)
84 |         if self.gan_type == 'hinge':
85 |             if is_disc:  # for discriminators in hinge-gan
86 |                 input = -input if target_is_real else input
87 |                 loss = self.loss(1 + input).mean()
88 |             else:  # for generators in hinge-gan
89 |                 loss = -input.mean()
90 |         else:  # other gan types
91 |             loss = self.loss(input, target_label)
92 |
93 |         # loss_weight is always 1.0 for discriminators
94 |         return loss if is_disc else loss * self.loss_weight
95 |
--------------------------------------------------------------------------------
/mogen/models/losses/mse_loss.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 |
4 | from ..builder import LOSSES
5 | from .utils import weighted_loss
6 |
7 |
8 | def gmof(x, sigma):
9 |     """Geman-McClure error function."""
10 |     x_squared = x**2
11 |     sigma_squared = sigma**2
12 |     return (sigma_squared * x_squared) / (sigma_squared + x_squared)
13 |
14 |
15 | @weighted_loss
16 | def mse_loss(pred, target):
17 |     """Wrapper of mse loss."""
18 |     return F.mse_loss(pred, target, reduction='none')
19 |
20 |
21 | @weighted_loss
22 | def mse_loss_with_gmof(pred, target, sigma):
23 |     """Extended MSE Loss with GMOF."""
24 |     loss = F.mse_loss(pred, target, reduction='none')
25 |     loss = gmof(loss, sigma)
26 |     return loss
27 |
28 |
29 | @LOSSES.register_module()
30 | class MSELoss(nn.Module):
31 |     """MSELoss.
32 |     Args:
33 |         reduction (str, optional): The method that reduces the loss to a
34 |             scalar. Options are "none", "mean" and "sum".
35 |         loss_weight (float, optional): The weight of the loss. Defaults to 1.0
36 |     """
37 |
38 |     def __init__(self, reduction='mean', loss_weight=1.0):
39 |         super().__init__()
40 |         assert reduction in (None, 'none', 'mean', 'sum')
41 |         reduction = 'none' if reduction is None else reduction
42 |         self.reduction = reduction
43 |         self.loss_weight = loss_weight
44 |
45 |     def forward(self,
46 |                 pred,
47 |                 target,
48 |                 weight=None,
49 |                 avg_factor=None,
50 |                 reduction_override=None):
51 |         """Forward function of loss.
52 |         Args:
53 |             pred (torch.Tensor): The prediction.
54 |             target (torch.Tensor): The learning target of the prediction.
55 |             weight (torch.Tensor, optional): Weight of the loss for each
56 |                 prediction. Defaults to None.
57 |             avg_factor (int, optional): Average factor that is used to average
58 |                 the loss. Defaults to None.
59 | reduction_override (str, optional): The reduction method used to 60 | override the original reduction method of the loss. 61 | Defaults to None. 62 | Returns: 63 | torch.Tensor: The calculated loss 64 | """ 65 | assert reduction_override in (None, 'none', 'mean', 'sum') 66 | reduction = (reduction_override 67 | if reduction_override else self.reduction) 68 | loss = self.loss_weight * mse_loss( 69 | pred, target, weight, reduction=reduction, avg_factor=avg_factor) 70 | return loss 71 | -------------------------------------------------------------------------------- /mogen/models/losses/utils.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | 7 | def reduce_loss(loss, reduction): 8 | """Reduce loss as specified. 9 | Args: 10 | loss (Tensor): Elementwise loss tensor. 11 | reduction (str): Options are "none", "mean" and "sum". 12 | Return: 13 | Tensor: Reduced loss tensor. 14 | """ 15 | reduction_enum = F._Reduction.get_enum(reduction) 16 | # none: 0, elementwise_mean:1, sum: 2 17 | if reduction_enum == 0: 18 | return loss 19 | elif reduction_enum == 1: 20 | return loss.mean() 21 | elif reduction_enum == 2: 22 | return loss.sum() 23 | 24 | 25 | def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None): 26 | """Apply element-wise weight and reduce loss. 27 | Args: 28 | loss (Tensor): Element-wise loss. 29 | weight (Tensor): Element-wise weights. 30 | reduction (str): Same as built-in losses of PyTorch. 31 | avg_factor (float): Average factor when computing the mean of losses. 32 | Returns: 33 | Tensor: Processed loss values. 34 | """ 35 | # if weight is specified, apply element-wise weight 36 | if weight is not None: 37 | loss = loss * weight 38 | 39 | # if avg_factor is not specified, just reduce the loss 40 | if avg_factor is None: 41 | loss = reduce_loss(loss, reduction) 42 | else: 43 | # if reduction is mean, then average the loss by avg_factor 44 | if reduction == 'mean': 45 | loss = loss.sum() / avg_factor 46 | # if reduction is 'none', then do nothing, otherwise raise an error 47 | elif reduction != 'none': 48 | raise ValueError('avg_factor can not be used with reduction="sum"') 49 | return loss 50 | 51 | 52 | def weighted_loss(loss_func): 53 | """Create a weighted version of a given loss function. 54 | To use this decorator, the loss function must have the signature like 55 | `loss_func(pred, target, **kwargs)`. The function only needs to compute 56 | element-wise loss without any reduction. This decorator will add weight 57 | and reduction arguments to the function. The decorated function will have 58 | the signature like `loss_func(pred, target, weight=None, reduction='mean', 59 | avg_factor=None, **kwargs)`. 60 | :Example: 61 | >>> import torch 62 | >>> @weighted_loss 63 | >>> def l1_loss(pred, target): 64 | >>> return (pred - target).abs() 65 | >>> pred = torch.Tensor([0, 2, 3]) 66 | >>> target = torch.Tensor([1, 1, 1]) 67 | >>> weight = torch.Tensor([1, 0, 1]) 68 | >>> l1_loss(pred, target) 69 | tensor(1.3333) 70 | >>> l1_loss(pred, target, weight) 71 | tensor(1.) 
72 | >>> l1_loss(pred, target, reduction='none') 73 | tensor([1., 1., 2.]) 74 | >>> l1_loss(pred, target, weight, avg_factor=2) 75 | tensor(1.5000) 76 | """ 77 | 78 | @functools.wraps(loss_func) 79 | def wrapper(pred, 80 | target, 81 | weight=None, 82 | reduction='mean', 83 | avg_factor=None, 84 | **kwargs): 85 | # get element-wise loss 86 | loss = loss_func(pred, target, **kwargs) 87 | loss = weight_reduce_loss(loss, weight, reduction, avg_factor) 88 | return loss 89 | 90 | return wrapper 91 | 92 | 93 | def convert_to_one_hot(targets: torch.Tensor, classes) -> torch.Tensor: 94 | """This function converts target class indices to one-hot vectors, given 95 | the number of classes. 96 | Args: 97 | targets (Tensor): The ground truth label of the prediction 98 | with shape (N, 1) 99 | classes (int): the number of classes. 100 | Returns: 101 | Tensor: Processed loss values. 102 | """ 103 | assert (torch.max(targets).item() 104 | < classes), 'Class Index must be less than number of classes' 105 | one_hot_targets = torch.zeros((targets.shape[0], classes), 106 | dtype=torch.long, 107 | device=targets.device) 108 | one_hot_targets.scatter_(1, targets.long(), 1) 109 | return one_hot_targets 110 | -------------------------------------------------------------------------------- /mogen/models/rnns/__init__.py: -------------------------------------------------------------------------------- 1 | from .t2m_bigru import T2MMotionEncoder, T2MTextEncoder 2 | 3 | __all__ = ['T2MMotionEncoder', 'T2MTextEncoder'] 4 | -------------------------------------------------------------------------------- /mogen/models/transformers/__init__.py: -------------------------------------------------------------------------------- 1 | from .actor import ACTORDecoder, ACTOREncoder 2 | from .finemogen import FineMoGenTransformer 3 | from .intergen import InterCLIP 4 | from .mdm import MDMTransformer 5 | from .momatmogen import MoMatMoGenTransformer 6 | from .motiondiffuse import MotionDiffuseTransformer 7 | from .remodiffuse import ReMoDiffuseTransformer 8 | 9 | __all__ = [ 10 | 'ACTOREncoder', 'ACTORDecoder', 'MotionDiffuseTransformer', 11 | 'ReMoDiffuseTransformer', 'MDMTransformer', 'FineMoGenTransformer', 12 | 'InterCLIP', 'MoMatMoGenTransformer' 13 | ] 14 | -------------------------------------------------------------------------------- /mogen/models/transformers/momatmogen.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from mogen.models.utils.misc import zero_module 5 | from mogen.models.utils.position_encoding import timestep_embedding 6 | from mogen.models.utils.stylization_block import StylizationBlock 7 | 8 | from ..builder import SUBMODULES, build_attention 9 | from .remodiffuse import ReMoDiffuseTransformer 10 | 11 | 12 | class FFN(nn.Module): 13 | 14 | def __init__(self, latent_dim, ffn_dim, dropout, time_embed_dim): 15 | super().__init__() 16 | self.latent_dim = latent_dim 17 | self.linear1 = nn.Linear(latent_dim, ffn_dim) 18 | self.linear2 = zero_module(nn.Linear(ffn_dim, latent_dim)) 19 | self.activation = nn.GELU() 20 | self.dropout = nn.Dropout(dropout) 21 | self.proj_out = StylizationBlock(latent_dim, time_embed_dim, dropout) 22 | 23 | def forward(self, x, emb, **kwargs): 24 | x1 = x[:, :, :self.latent_dim].contiguous() 25 | x2 = x[:, :, self.latent_dim:].contiguous() 26 | y1 = self.linear2(self.dropout(self.activation(self.linear1(x1)))) 27 | y1 = x1 + self.proj_out(y1, emb) 28 | y2 = 
self.linear2(self.dropout(self.activation(self.linear1(x2)))) 29 | y2 = x2 + self.proj_out(y2, emb) 30 | y = torch.cat((y1, y2), dim=-1) 31 | return y 32 | 33 | 34 | class DecoderLayer(nn.Module): 35 | 36 | def __init__(self, ca_block_cfg=None, ffn_cfg=None): 37 | super().__init__() 38 | self.ca_block = build_attention(ca_block_cfg) 39 | self.ffn = FFN(**ffn_cfg) 40 | 41 | def forward(self, **kwargs): 42 | if self.ca_block is not None: 43 | x = self.ca_block(**kwargs) 44 | kwargs.update({'x': x}) 45 | if self.ffn is not None: 46 | x = self.ffn(**kwargs) 47 | return x 48 | 49 | 50 | @SUBMODULES.register_module() 51 | class MoMatMoGenTransformer(ReMoDiffuseTransformer): 52 | 53 | def build_temporal_blocks(self, sa_block_cfg, ca_block_cfg, ffn_cfg): 54 | self.temporal_decoder_blocks = nn.ModuleList() 55 | for i in range(self.num_layers): 56 | self.temporal_decoder_blocks.append( 57 | DecoderLayer(ca_block_cfg=ca_block_cfg, ffn_cfg=ffn_cfg)) 58 | 59 | def forward(self, motion, timesteps, motion_mask=None, **kwargs): 60 | """ 61 | motion: B, T, D 62 | """ 63 | T = motion.shape[1] 64 | conditions = self.get_precompute_condition(device=motion.device, 65 | **kwargs) 66 | if len(motion_mask.shape) == 2: 67 | src_mask = motion_mask.clone().unsqueeze(-1) 68 | else: 69 | src_mask = motion_mask.clone() 70 | 71 | if self.time_embedding_type == 'sinusoidal': 72 | emb = self.time_embed( 73 | timestep_embedding(timesteps, self.latent_dim)) 74 | else: 75 | emb = self.time_embed(self.time_tokens(timesteps)) 76 | 77 | if self.use_text_proj: 78 | emb = emb + conditions['xf_proj'] 79 | # B, T, latent_dim 80 | motion1 = motion[:, :, :self.input_feats].contiguous() 81 | motion2 = motion[:, :, self.input_feats:].contiguous() 82 | h1 = self.joint_embed(motion1) 83 | h2 = self.joint_embed(motion2) 84 | if self.use_pos_embedding: 85 | h1 = h1 + self.sequence_embedding.unsqueeze(0)[:, :T, :] 86 | h2 = h2 + self.sequence_embedding.unsqueeze(0)[:, :T, :] 87 | h = torch.cat((h1, h2), dim=-1) 88 | 89 | if self.training: 90 | output = self.forward_train(h=h, 91 | src_mask=src_mask, 92 | emb=emb, 93 | timesteps=timesteps, 94 | **conditions) 95 | else: 96 | output = self.forward_test(h=h, 97 | src_mask=src_mask, 98 | emb=emb, 99 | timesteps=timesteps, 100 | **conditions) 101 | if self.use_residual_connection: 102 | output = motion + output 103 | return output 104 | 105 | def forward_train(self, 106 | h=None, 107 | src_mask=None, 108 | emb=None, 109 | xf_out=None, 110 | re_dict=None, 111 | **kwargs): 112 | B, T = h.shape[0], h.shape[1] 113 | cond_type = torch.randint(0, 100, size=(B, 1, 1)).to(h.device) 114 | for module in self.temporal_decoder_blocks: 115 | h = module(x=h, 116 | xf=xf_out, 117 | emb=emb, 118 | src_mask=src_mask, 119 | cond_type=cond_type, 120 | re_dict=re_dict) 121 | 122 | out1 = self.out(h[:, :, :self.latent_dim].contiguous()) 123 | out1 = out1.view(B, T, -1).contiguous() 124 | out2 = self.out(h[:, :, self.latent_dim:].contiguous()) 125 | out2 = out2.view(B, T, -1).contiguous() 126 | output = torch.cat((out1, out2), dim=-1) 127 | return output 128 | 129 | def forward_test(self, 130 | h=None, 131 | src_mask=None, 132 | emb=None, 133 | xf_out=None, 134 | re_dict=None, 135 | timesteps=None, 136 | **kwargs): 137 | B, T = h.shape[0], h.shape[1] 138 | both_cond_type = torch.zeros(B, 1, 1).to(h.device) + 99 139 | text_cond_type = torch.zeros(B, 1, 1).to(h.device) + 1 140 | retr_cond_type = torch.zeros(B, 1, 1).to(h.device) + 10 141 | none_cond_type = torch.zeros(B, 1, 1).to(h.device) 142 | 143 | all_cond_type = 
torch.cat( 144 | (both_cond_type, text_cond_type, retr_cond_type, none_cond_type), 145 | dim=0) 146 | h = h.repeat(4, 1, 1) 147 | xf_out = xf_out.repeat(4, 1, 1) 148 | emb = emb.repeat(4, 1) 149 | src_mask = src_mask.repeat(4, 1, 1) 150 | if re_dict['re_motion'].shape[0] != h.shape[0]: 151 | re_dict['re_motion'] = re_dict['re_motion'].repeat(4, 1, 1, 1) 152 | re_dict['re_text'] = re_dict['re_text'].repeat(4, 1, 1, 1) 153 | re_dict['re_mask'] = re_dict['re_mask'].repeat(4, 1, 1) 154 | for module in self.temporal_decoder_blocks: 155 | h = module(x=h, 156 | xf=xf_out, 157 | emb=emb, 158 | src_mask=src_mask, 159 | cond_type=all_cond_type, 160 | re_dict=re_dict) 161 | out1 = self.out(h[:, :, :self.latent_dim].contiguous()) 162 | out1 = out1.view(4 * B, T, -1).contiguous() 163 | out2 = self.out(h[:, :, self.latent_dim:].contiguous()) 164 | out2 = out2.view(4 * B, T, -1).contiguous() 165 | out = torch.cat((out1, out2), dim=-1) 166 | out_both = out[:B].contiguous() 167 | out_text = out[B:2 * B].contiguous() 168 | out_retr = out[2 * B:3 * B].contiguous() 169 | out_none = out[3 * B:].contiguous() 170 | 171 | coef_cfg = self.scale_func(int(timesteps[0])) 172 | both_coef = coef_cfg['both_coef'] 173 | text_coef = coef_cfg['text_coef'] 174 | retr_coef = coef_cfg['retr_coef'] 175 | none_coef = coef_cfg['none_coef'] 176 | output = out_both * both_coef 177 | output += out_text * text_coef 178 | output += out_retr * retr_coef 179 | output += out_none * none_coef 180 | return output 181 | -------------------------------------------------------------------------------- /mogen/models/transformers/motiondiffuse.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from ..builder import SUBMODULES 5 | from .diffusion_transformer import DiffusionTransformer 6 | 7 | 8 | @SUBMODULES.register_module() 9 | class MotionDiffuseTransformer(DiffusionTransformer): 10 | 11 | def __init__(self, **kwargs): 12 | super().__init__(**kwargs) 13 | 14 | def get_precompute_condition(self, 15 | text=None, 16 | xf_proj=None, 17 | xf_out=None, 18 | device=None, 19 | clip_feat=None, 20 | **kwargs): 21 | if xf_proj is None or xf_out is None: 22 | xf_proj, xf_out = self.encode_text(text, clip_feat, device) 23 | return {'xf_proj': xf_proj, 'xf_out': xf_out} 24 | 25 | def post_process(self, motion): 26 | if self.post_process_cfg is not None: 27 | if self.post_process_cfg.get("unnormalized_infer", False): 28 | mean = torch.from_numpy( 29 | np.load(self.post_process_cfg['mean_path'])) 30 | mean = mean.type_as(motion) 31 | std = torch.from_numpy( 32 | np.load(self.post_process_cfg['std_path'])) 33 | std = std.type_as(motion) 34 | motion = motion * std + mean 35 | return motion 36 | 37 | def forward_train(self, 38 | h=None, 39 | src_mask=None, 40 | emb=None, 41 | xf_out=None, 42 | **kwargs): 43 | B, T = h.shape[0], h.shape[1] 44 | for module in self.temporal_decoder_blocks: 45 | h = module(x=h, xf=xf_out, emb=emb, src_mask=src_mask) 46 | output = self.out(h).view(B, T, -1).contiguous() 47 | return output 48 | 49 | def forward_test(self, 50 | h=None, 51 | src_mask=None, 52 | emb=None, 53 | xf_out=None, 54 | **kwargs): 55 | B, T = h.shape[0], h.shape[1] 56 | for module in self.temporal_decoder_blocks: 57 | h = module(x=h, xf=xf_out, emb=emb, src_mask=src_mask) 58 | output = self.out(h).view(B, T, -1).contiguous() 59 | return output 60 | -------------------------------------------------------------------------------- /mogen/models/utils/__init__.py: 
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/mingyuan-zhang/FineMoGen/d5697b5aa6ad2de0301e77a251c7d28fb177ee23/mogen/models/utils/__init__.py
--------------------------------------------------------------------------------
/mogen/models/utils/misc.py:
--------------------------------------------------------------------------------
1 | def set_requires_grad(nets, requires_grad=False):
2 |     """Set requires_grad for all the networks.
3 |
4 |     Args:
5 |         nets (nn.Module | list[nn.Module]): A list of networks or a single
6 |             network.
7 |         requires_grad (bool): Whether the networks require gradients or not
8 |     """
9 |     if not isinstance(nets, list):
10 |         nets = [nets]
11 |     for net in nets:
12 |         if net is not None:
13 |             for param in net.parameters():
14 |                 param.requires_grad = requires_grad
15 |
16 |
17 | def zero_module(module):
18 |     """
19 |     Zero out the parameters of a module and return it.
20 |     """
21 |     for p in module.parameters():
22 |         p.detach().zero_()
23 |     return module
24 |
--------------------------------------------------------------------------------
/mogen/models/utils/mlp.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
4 | def build_MLP(dim_list, latent_dim):
5 |     model_list = []
6 |     prev = dim_list[0]
7 |     for cur in dim_list[1:]:
8 |         model_list.append(nn.Linear(prev, cur))
9 |         model_list.append(nn.GELU())
10 |         prev = cur
11 |     model_list.append(nn.Linear(prev, latent_dim))
12 |     model = nn.Sequential(*model_list)
13 |     return model
14 |
--------------------------------------------------------------------------------
/mogen/models/utils/position_encoding.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 |
7 |
8 | class SinusoidalPositionalEncoding(nn.Module):
9 |
10 |     def __init__(self, d_model, dropout=0.1, max_len=5000):
11 |         super(SinusoidalPositionalEncoding, self).__init__()
12 |         self.dropout = nn.Dropout(p=dropout)
13 |
14 |         pe = torch.zeros(max_len, d_model)
15 |         position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
16 |         div_term = torch.arange(0, d_model, 2).float()
17 |         div_term = div_term * (-np.log(10000.0) / d_model)
18 |         div_term = torch.exp(div_term)
19 |         pe[:, 0::2] = torch.sin(position * div_term)
20 |         pe[:, 1::2] = torch.cos(position * div_term)
21 |         pe = pe.unsqueeze(0).transpose(0, 1)
22 |         # T, 1, D
23 |         self.register_buffer('pe', pe)
24 |
25 |     def forward(self, x):
26 |         x = x + self.pe[:x.shape[0]]
27 |         return self.dropout(x)
28 |
29 |
30 | class LearnedPositionalEncoding(nn.Module):
31 |
32 |     def __init__(self, d_model, dropout=0.1, max_len=5000):
33 |         super(LearnedPositionalEncoding, self).__init__()
34 |         self.dropout = nn.Dropout(p=dropout)
35 |         self.pe = nn.Parameter(torch.randn(max_len, 1, d_model))
36 |
37 |     def forward(self, x):
38 |         x = x + self.pe[:x.shape[0]]
39 |         return self.dropout(x)
40 |
41 |
42 | def timestep_embedding(timesteps, dim, max_period=10000):
43 |     """
44 |     Create sinusoidal timestep embeddings.
45 |     :param timesteps: a 1-D Tensor of N indices, one per batch element.
46 |                       These may be fractional.
47 |     :param dim: the dimension of the output.
48 |     :param max_period: controls the minimum frequency of the embeddings.
49 |     :return: an [N x dim] Tensor of positional embeddings.
50 | """ 51 | half = dim // 2 52 | idx = torch.arange(start=0, end=half, dtype=torch.float32) 53 | freqs = torch.exp(-math.log(max_period) * idx / 54 | half).to(device=timesteps.device) 55 | args = timesteps[:, None].float() * freqs[None] 56 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 57 | if dim % 2: 58 | embedding = torch.cat( 59 | [embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 60 | return embedding 61 | -------------------------------------------------------------------------------- /mogen/models/utils/stylization_block.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def zero_module(module): 6 | """ 7 | Zero out the parameters of a module and return it. 8 | """ 9 | for p in module.parameters(): 10 | p.detach().zero_() 11 | return module 12 | 13 | 14 | class StylizationBlock(nn.Module): 15 | 16 | def __init__(self, latent_dim, time_embed_dim, dropout): 17 | super().__init__() 18 | self.emb_layers = nn.Sequential( 19 | nn.SiLU(), 20 | nn.Linear(time_embed_dim, 2 * latent_dim), 21 | ) 22 | self.norm = nn.LayerNorm(latent_dim) 23 | self.out_layers = nn.Sequential( 24 | nn.SiLU(), 25 | nn.Dropout(p=dropout), 26 | zero_module(nn.Linear(latent_dim, latent_dim)), 27 | ) 28 | 29 | def forward(self, h, emb): 30 | """ 31 | h: B, T, D 32 | emb: B, D 33 | """ 34 | # B, 1, 2D 35 | emb_out = self.emb_layers(emb).unsqueeze(1) 36 | # scale: B, 1, D / shift: B, 1, D 37 | scale, shift = torch.chunk(emb_out, 2, dim=2) 38 | h = self.norm(h) * (1 + scale) + shift 39 | h = self.out_layers(h) 40 | return h 41 | -------------------------------------------------------------------------------- /mogen/models/utils/word_vectorizer.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | from os.path import join as pjoin 3 | 4 | import numpy as np 5 | 6 | POS_enumerator = { 7 | 'VERB': 0, 8 | 'NOUN': 1, 9 | 'DET': 2, 10 | 'ADP': 3, 11 | 'NUM': 4, 12 | 'AUX': 5, 13 | 'PRON': 6, 14 | 'ADJ': 7, 15 | 'ADV': 8, 16 | 'Loc_VIP': 9, 17 | 'Body_VIP': 10, 18 | 'Obj_VIP': 11, 19 | 'Act_VIP': 12, 20 | 'Desc_VIP': 13, 21 | 'OTHER': 14, 22 | } 23 | 24 | Loc_list = ('left', 'right', 'clockwise', 'counterclockwise', 'anticlockwise', 25 | 'forward', 'back', 'backward', 'up', 'down', 'straight', 'curve') 26 | 27 | Body_list = ('arm', 'chin', 'foot', 'feet', 'face', 'hand', 'mouth', 'leg', 28 | 'waist', 'eye', 'knee', 'shoulder', 'thigh') 29 | 30 | Obj_List = ('stair', 'dumbbell', 'chair', 'window', 'floor', 'car', 'ball', 31 | 'handrail', 'baseball', 'basketball') 32 | 33 | Act_list = ('walk', 'run', 'swing', 'pick', 'bring', 'kick', 'put', 'squat', 34 | 'throw', 'hop', 'dance', 'jump', 'turn', 'stumble', 'dance', 35 | 'stop', 'sit', 'lift', 'lower', 'raise', 'wash', 'stand', 'kneel', 36 | 'stroll', 'rub', 'bend', 'balance', 'flap', 'jog', 'shuffle', 37 | 'lean', 'rotate', 'spin', 'spread', 'climb') 38 | 39 | Desc_list = ('slowly', 'carefully', 'fast', 'careful', 'slow', 'quickly', 40 | 'happy', 'angry', 'sad', 'happily', 'angrily', 'sadly') 41 | 42 | VIP_dict = { 43 | 'Loc_VIP': Loc_list, 44 | 'Body_VIP': Body_list, 45 | 'Obj_VIP': Obj_List, 46 | 'Act_VIP': Act_list, 47 | 'Desc_VIP': Desc_list, 48 | } 49 | 50 | 51 | class WordVectorizer(object): 52 | 53 | def __init__(self, meta_root, prefix): 54 | vectors = np.load(pjoin(meta_root, '%s_data.npy' % prefix)) 55 | words = pickle.load( 56 | open(pjoin(meta_root, '%s_words.pkl' % prefix), 'rb')) 57 | word2idx = 
pickle.load( 58 | open(pjoin(meta_root, '%s_idx.pkl' % prefix), 'rb')) 59 | self.word2vec = {w: vectors[word2idx[w]] for w in words} 60 | 61 | def _get_pos_ohot(self, pos): 62 | pos_vec = np.zeros(len(POS_enumerator)) 63 | if pos in POS_enumerator: 64 | pos_vec[POS_enumerator[pos]] = 1 65 | else: 66 | pos_vec[POS_enumerator['OTHER']] = 1 67 | return pos_vec 68 | 69 | def __len__(self): 70 | return len(self.word2vec) 71 | 72 | def __getitem__(self, item): 73 | word, pos = item.split('/') 74 | if word in self.word2vec: 75 | word_vec = self.word2vec[word] 76 | vip_pos = None 77 | for key, values in VIP_dict.items(): 78 | if word in values: 79 | vip_pos = key 80 | break 81 | if vip_pos is not None: 82 | pos_vec = self._get_pos_ohot(vip_pos) 83 | else: 84 | pos_vec = self._get_pos_ohot(pos) 85 | else: 86 | word_vec = self.word2vec['unk'] 87 | pos_vec = self._get_pos_ohot('OTHER') 88 | return word_vec, pos_vec 89 | -------------------------------------------------------------------------------- /mogen/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from mogen.utils.collect_env import collect_env 2 | from mogen.utils.dist_utils import DistOptimizerHook, allreduce_grads 3 | from mogen.utils.logger import get_root_logger 4 | from mogen.utils.misc import multi_apply, torch_to_numpy 5 | from mogen.utils.path_utils import (Existence, check_input_path, 6 | check_path_existence, check_path_suffix, 7 | prepare_output_path) 8 | 9 | __all__ = [ 10 | 'collect_env', 'DistOptimizerHook', 'allreduce_grads', 'get_root_logger', 11 | 'multi_apply', 'torch_to_numpy', 'Existence', 'check_input_path', 12 | 'check_path_existence', 'check_path_suffix', 'prepare_output_path' 13 | ] 14 | -------------------------------------------------------------------------------- /mogen/utils/collect_env.py: -------------------------------------------------------------------------------- 1 | from mmcv.utils import collect_env as collect_base_env 2 | from mmcv.utils import get_git_hash 3 | 4 | import mogen 5 | 6 | 7 | def collect_env(): 8 | """Collect the information of the running environments.""" 9 | env_info = collect_base_env() 10 | env_info['mogen'] = mogen.__version__ + '+' + get_git_hash()[:7] 11 | return env_info 12 | 13 | 14 | if __name__ == '__main__': 15 | for name, val in collect_env().items(): 16 | print(f'{name}: {val}') 17 | -------------------------------------------------------------------------------- /mogen/utils/dist_utils.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch.distributed as dist 4 | from mmcv.runner import OptimizerHook 5 | from torch._utils import (_flatten_dense_tensors, _take_tensors, 6 | _unflatten_dense_tensors) 7 | 8 | 9 | def _allreduce_coalesced(tensors, world_size, bucket_size_mb=-1): 10 | if bucket_size_mb > 0: 11 | bucket_size_bytes = bucket_size_mb * 1024 * 1024 12 | buckets = _take_tensors(tensors, bucket_size_bytes) 13 | else: 14 | buckets = OrderedDict() 15 | for tensor in tensors: 16 | tp = tensor.type() 17 | if tp not in buckets: 18 | buckets[tp] = [] 19 | buckets[tp].append(tensor) 20 | buckets = buckets.values() 21 | 22 | for bucket in buckets: 23 | flat_tensors = _flatten_dense_tensors(bucket) 24 | dist.all_reduce(flat_tensors) 25 | flat_tensors.div_(world_size) 26 | for tensor, synced in zip( 27 | bucket, _unflatten_dense_tensors(flat_tensors, bucket)): 28 | tensor.copy_(synced) 29 | 30 | 31 | def allreduce_grads(params, 
coalesce=True, bucket_size_mb=-1): 32 | grads = [ 33 | param.grad.data for param in params 34 | if param.requires_grad and param.grad is not None 35 | ] 36 | world_size = dist.get_world_size() 37 | if coalesce: 38 | _allreduce_coalesced(grads, world_size, bucket_size_mb) 39 | else: 40 | for tensor in grads: 41 | dist.all_reduce(tensor.div_(world_size)) 42 | 43 | 44 | class DistOptimizerHook(OptimizerHook): 45 | 46 | def __init__(self, grad_clip=None, coalesce=True, bucket_size_mb=-1): 47 | self.grad_clip = grad_clip 48 | self.coalesce = coalesce 49 | self.bucket_size_mb = bucket_size_mb 50 | 51 | def after_train_iter(self, runner): 52 | runner.optimizer.zero_grad() 53 | runner.outputs['loss'].backward() 54 | if self.grad_clip is not None: 55 | self.clip_grads(runner.model.parameters()) 56 | runner.optimizer.step() 57 | -------------------------------------------------------------------------------- /mogen/utils/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from mmcv.utils import get_logger 4 | 5 | 6 | def get_root_logger(log_file=None, log_level=logging.INFO): 7 | return get_logger('mogen', log_file, log_level) 8 | -------------------------------------------------------------------------------- /mogen/utils/misc.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import torch 4 | 5 | 6 | def multi_apply(func, *args, **kwargs): 7 | pfunc = partial(func, **kwargs) if kwargs else func 8 | map_results = map(pfunc, *args) 9 | return tuple(map(list, zip(*map_results))) 10 | 11 | 12 | def torch_to_numpy(x): 13 | assert isinstance(x, torch.Tensor) 14 | return x.detach().cpu().numpy() 15 | -------------------------------------------------------------------------------- /mogen/version.py: -------------------------------------------------------------------------------- 1 | __version__ = '0.0.1' 2 | 3 | 4 | def parse_version_info(version_str): 5 | """Parse a version string into a tuple. 6 | Args: 7 | version_str (str): The version string. 8 | Returns: 9 | tuple[int | str]: The version info, e.g., "1.3.0" is parsed into 10 | (1, 3, 0), and "2.0.0rc1" is parsed into (2, 0, 0, 'rc1'). 
11 | """ 12 | version_info = [] 13 | for x in version_str.split('.'): 14 | if x.isdigit(): 15 | version_info.append(int(x)) 16 | elif x.find('rc') != -1: 17 | patch_version = x.split('rc') 18 | version_info.append(int(patch_version[0])) 19 | version_info.append(f'rc{patch_version[1]}') 20 | return tuple(version_info) 21 | 22 | 23 | version_info = parse_version_info(__version__) 24 | 25 | __all__ = ['__version__', 'version_info', 'parse_version_info'] 26 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | ftfy 2 | regex 3 | tqdm 4 | scipy 5 | git+https://github.com/openai/CLIP.git -------------------------------------------------------------------------------- /tools/dist_train.sh: -------------------------------------------------------------------------------- 1 | CONFIG=$1 2 | WORK_DIR=$2 3 | GPUS=$3 4 | PORT=${PORT:-29500} 5 | 6 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 7 | python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \ 8 | $(dirname "$0")/train.py $CONFIG --work-dir=${WORK_DIR} --launcher pytorch ${@:4} -------------------------------------------------------------------------------- /tools/slurm_test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Copyright (c) OpenMMLab. All rights reserved. 3 | 4 | set -x 5 | 6 | PARTITION=$1 7 | JOB_NAME=$2 8 | CONFIG=$3 9 | WORK_DIR=$4 10 | CHECKPOINT=$5 11 | GPUS=1 12 | GPUS_PER_NODE=$((${GPUS}<8?${GPUS}:8)) 13 | CPUS_PER_TASK=${CPUS_PER_TASK:-2} 14 | SRUN_ARGS=${SRUN_ARGS:-""} 15 | PY_ARGS=${@:6} 16 | 17 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 18 | srun -p ${PARTITION} \ 19 | --job-name=${JOB_NAME} \ 20 | --gres=gpu:${GPUS_PER_NODE} \ 21 | --ntasks=${GPUS} \ 22 | --ntasks-per-node=${GPUS_PER_NODE} \ 23 | --cpus-per-task=${CPUS_PER_TASK} \ 24 | --kill-on-bad-exit=1 \ 25 | ${SRUN_ARGS} \ 26 | python -u tools/test.py ${CONFIG} --work-dir=${WORK_DIR} ${CHECKPOINT} --launcher="slurm" ${PY_ARGS} -------------------------------------------------------------------------------- /tools/slurm_train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -x 3 | 4 | PARTITION=$1 5 | JOB_NAME=$2 6 | CONFIG=$3 7 | WORK_DIR=$4 8 | GPUS=$5 9 | GPUS_PER_NODE=$((${GPUS}<8?${GPUS}:8)) 10 | CPUS_PER_TASK=${CPUS_PER_TASK:-2} 11 | SRUN_ARGS=${SRUN_ARGS:-""} 12 | PY_ARGS=${@:6} 13 | 14 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 15 | srun -p ${PARTITION} \ 16 | --job-name=${JOB_NAME} \ 17 | --gres=gpu:${GPUS_PER_NODE} \ 18 | --ntasks=${GPUS} \ 19 | --ntasks-per-node=${GPUS_PER_NODE} \ 20 | --cpus-per-task=${CPUS_PER_TASK} \ 21 | --kill-on-bad-exit=1 \ 22 | -w SG-IDC2-10-51-5-49 \ 23 | ${SRUN_ARGS} \ 24 | python -u tools/train.py ${CONFIG} --work-dir=${WORK_DIR} --launcher="slurm" ${PY_ARGS} -------------------------------------------------------------------------------- /tools/test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import os.path as osp 4 | 5 | import mmcv 6 | import torch 7 | from mmcv import DictAction 8 | from mmcv.parallel import MMDataParallel, MMDistributedDataParallel 9 | from mmcv.runner import (get_dist_info, init_dist, load_checkpoint, 10 | wrap_fp16_model) 11 | 12 | from mogen.apis import multi_gpu_test, single_gpu_test 13 | from mogen.datasets import build_dataloader, 
build_dataset 14 | from mogen.models import build_architecture 15 | 16 | 17 | def parse_args(): 18 | parser = argparse.ArgumentParser(description='mogen evaluation') 19 | parser.add_argument('config', help='test config file path') 20 | parser.add_argument('--work-dir', 21 | help='the dir to save evaluation results') 22 | parser.add_argument('checkpoint', help='checkpoint file') 23 | parser.add_argument('--out', help='output result file') 24 | parser.add_argument('--gpu_collect', 25 | action='store_true', 26 | help='whether to use gpu to collect results') 27 | parser.add_argument('--tmpdir', help='tmp dir for writing some results') 28 | parser.add_argument( 29 | '--cfg-options', 30 | nargs='+', 31 | action=DictAction, 32 | help='override some settings in the used config, the key-value pair ' 33 | 'in xxx=yyy format will be merged into config file.') 34 | parser.add_argument('--launcher', 35 | choices=['none', 'pytorch', 'slurm', 'mpi'], 36 | default='none', 37 | help='job launcher') 38 | parser.add_argument('--local_rank', type=int, default=0) 39 | parser.add_argument('--device', 40 | choices=['cpu', 'cuda'], 41 | default='cuda', 42 | help='device used for testing') 43 | args = parser.parse_args() 44 | if 'LOCAL_RANK' not in os.environ: 45 | os.environ['LOCAL_RANK'] = str(args.local_rank) 46 | return args 47 | 48 | 49 | def main(): 50 | args = parse_args() 51 | 52 | cfg = mmcv.Config.fromfile(args.config) 53 | if args.cfg_options is not None: 54 | cfg.merge_from_dict(args.cfg_options) 55 | # set cudnn_benchmark 56 | if cfg.get('cudnn_benchmark', False): 57 | torch.backends.cudnn.benchmark = True 58 | cfg.data.test.test_mode = True 59 | 60 | # init distributed env first, since logger depends on the dist info. 61 | if args.launcher == 'none': 62 | distributed = False 63 | else: 64 | distributed = True 65 | init_dist(args.launcher, **cfg.dist_params) 66 | 67 | # build the dataloader 68 | dataset = build_dataset(cfg.data.test) 69 | # the extra round_up data will be removed during gpu/cpu collect 70 | data_loader = build_dataloader(dataset, 71 | samples_per_gpu=cfg.data.samples_per_gpu, 72 | workers_per_gpu=cfg.data.workers_per_gpu, 73 | dist=distributed, 74 | shuffle=False, 75 | round_up=False) 76 | 77 | # build the model and load checkpoint 78 | model = build_architecture(cfg.model) 79 | fp16_cfg = cfg.get('fp16', None) 80 | if fp16_cfg is not None: 81 | wrap_fp16_model(model) 82 | load_checkpoint(model, args.checkpoint, map_location='cpu') 83 | 84 | if not distributed: 85 | if args.device == 'cpu': 86 | model = model.cpu() 87 | else: 88 | model = MMDataParallel(model, device_ids=[0]) 89 | outputs = single_gpu_test(model, data_loader) 90 | else: 91 | model = MMDistributedDataParallel( 92 | model.cuda(), 93 | device_ids=[torch.cuda.current_device()], 94 | broadcast_buffers=False) 95 | outputs = multi_gpu_test(model, data_loader, args.tmpdir, 96 | args.gpu_collect) 97 | 98 | rank, _ = get_dist_info() 99 | if rank == 0: 100 | mmcv.mkdir_or_exist(osp.abspath(args.work_dir)) 101 | results = dataset.evaluate(outputs, args.work_dir) 102 | for k, v in results.items(): 103 | print(f'\n{k} : {v:.4f}') 104 | 105 | if args.out and rank == 0: 106 | print(f'\nwriting results to {args.out}') 107 | mmcv.dump(results, args.out) 108 | 109 | 110 | if __name__ == '__main__': 111 | main() 112 | -------------------------------------------------------------------------------- /tools/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import copy 3 | 
import os 4 | import os.path as osp 5 | import time 6 | 7 | import mmcv 8 | import torch 9 | from mmcv import Config, DictAction 10 | from mmcv.runner import get_dist_info, init_dist 11 | 12 | from mogen.apis import set_random_seed, train_model 13 | from mogen.datasets import build_dataset 14 | from mogen.models import build_architecture 15 | from mogen.utils import collect_env, get_root_logger 16 | 17 | 18 | def parse_args(): 19 | parser = argparse.ArgumentParser(description='Train a model') 20 | parser.add_argument('config', help='train config file path') 21 | parser.add_argument('--work-dir', help='the dir to save logs and models') 22 | parser.add_argument('--resume-from', 23 | help='the checkpoint file to resume from') 24 | parser.add_argument( 25 | '--no-validate', 26 | action='store_true', 27 | help='whether not to evaluate the checkpoint during training') 28 | group_gpus = parser.add_mutually_exclusive_group() 29 | group_gpus.add_argument('--device', help='device used for training') 30 | group_gpus.add_argument('--gpus', 31 | type=int, 32 | help='number of gpus to use ' 33 | '(only applicable to non-distributed training)') 34 | group_gpus.add_argument('--gpu-ids', 35 | type=int, 36 | nargs='+', 37 | help='ids of gpus to use ' 38 | '(only applicable to non-distributed training)') 39 | parser.add_argument('--seed', type=int, default=None, help='random seed') 40 | parser.add_argument( 41 | '--deterministic', 42 | action='store_true', 43 | help='whether to set deterministic options for CUDNN backend.') 44 | parser.add_argument('--options', 45 | nargs='+', 46 | action=DictAction, 47 | help='arguments in dict') 48 | parser.add_argument('--launcher', 49 | choices=['none', 'pytorch', 'slurm', 'mpi'], 50 | default='none', 51 | help='job launcher') 52 | parser.add_argument('--local_rank', type=int, default=0) 53 | args = parser.parse_args() 54 | if 'LOCAL_RANK' not in os.environ: 55 | os.environ['LOCAL_RANK'] = str(args.local_rank) 56 | 57 | return args 58 | 59 | 60 | def main(): 61 | args = parse_args() 62 | 63 | cfg = Config.fromfile(args.config) 64 | if args.options is not None: 65 | cfg.merge_from_dict(args.options) 66 | # set cudnn_benchmark 67 | if cfg.get('cudnn_benchmark', False): 68 | torch.backends.cudnn.benchmark = True 69 | 70 | # work_dir is determined in this priority: CLI > segment in file > filename 71 | if args.work_dir is not None: 72 | # update configs according to CLI args if args.work_dir is not None 73 | cfg.work_dir = args.work_dir 74 | elif cfg.get('work_dir', None) is None: 75 | # use config filename as default work_dir if cfg.work_dir is None 76 | cfg.work_dir = osp.join('./work_dirs', 77 | osp.splitext(osp.basename(args.config))[0]) 78 | if args.resume_from is not None: 79 | cfg.resume_from = args.resume_from 80 | if args.gpu_ids is not None: 81 | cfg.gpu_ids = args.gpu_ids 82 | else: 83 | cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus) 84 | 85 | # init distributed env first, since logger depends on the dist info. 
86 | if args.launcher == 'none': 87 | distributed = False 88 | else: 89 | distributed = True 90 | init_dist(args.launcher, **cfg.dist_params) 91 | _, world_size = get_dist_info() 92 | cfg.gpu_ids = range(world_size) 93 | 94 | # create work_dir 95 | mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir)) 96 | # dump config 97 | cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config))) 98 | # init the logger before other steps 99 | timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) 100 | log_file = osp.join(cfg.work_dir, f'{timestamp}.log') 101 | logger = get_root_logger(log_file=log_file, log_level=cfg.log_level) 102 | 103 | # init the meta dict to record some important information such as 104 | # environment info and seed, which will be logged 105 | meta = dict() 106 | # log env info 107 | env_info_dict = collect_env() 108 | env_info = '\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()]) 109 | dash_line = '-' * 60 + '\n' 110 | logger.info('Environment info:\n' + dash_line + env_info + '\n' + 111 | dash_line) 112 | meta['env_info'] = env_info 113 | 114 | # log some basic info 115 | logger.info(f'Distributed training: {distributed}') 116 | logger.info(f'Config:\n{cfg.pretty_text}') 117 | 118 | # set random seeds 119 | if args.seed is not None: 120 | logger.info(f'Set random seed to {args.seed}, ' 121 | f'deterministic: {args.deterministic}') 122 | set_random_seed(args.seed, deterministic=args.deterministic) 123 | cfg.seed = args.seed 124 | meta['seed'] = args.seed 125 | 126 | model = build_architecture(cfg.model) 127 | model.init_weights() 128 | 129 | datasets = [build_dataset(cfg.data.train)] 130 | if len(cfg.workflow) == 2: 131 | val_dataset = copy.deepcopy(cfg.data.val) 132 | val_dataset.pipeline = cfg.data.train.pipeline 133 | datasets.append(build_dataset(val_dataset)) 134 | # add an attribute for visualization convenience 135 | train_model(model, 136 | datasets, 137 | cfg, 138 | distributed=distributed, 139 | validate=(not args.no_validate), 140 | timestamp=timestamp, 141 | device='cpu' if args.device == 'cpu' else 'cuda', 142 | meta=meta) 143 | 144 | 145 | if __name__ == '__main__': 146 | main() 147 | -------------------------------------------------------------------------------- /tools/visualize.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import mmcv 5 | import numpy as np 6 | import torch 7 | from mmcv.parallel import MMDataParallel 8 | from mmcv.runner import load_checkpoint 9 | from scipy.ndimage import gaussian_filter 10 | 11 | from mogen.models import build_architecture 12 | from mogen.utils.plot_utils import (plot_3d_motion, plot_siamese_3d_motion, 13 | recover_from_ric, t2m_kinematic_chain) 14 | 15 | 16 | def motion_temporal_filter(motion, sigma=1): 17 | motion = motion.reshape(motion.shape[0], -1) 18 | for i in range(motion.shape[1]): 19 | motion[:, i] = gaussian_filter(motion[:, i], 20 | sigma=sigma, 21 | mode="nearest") 22 | return motion.reshape(motion.shape[0], -1, 3) 23 | 24 | 25 | def plot_t2m(data, motion_length, result_path, npy_path, caption): 26 | joints = recover_from_ric(torch.from_numpy(data).float(), 22).numpy() 27 | joints = motion_temporal_filter(joints, sigma=2.5) 28 | plot_3d_motion(save_path=result_path, 29 | motion_length=motion_length, 30 | kinematic_tree=t2m_kinematic_chain, 31 | joints=joints, 32 | title=caption, 33 | fps=20) 34 | if npy_path is not None: 35 | np.save(npy_path, joints) 36 | 37 | 38 | def plot_interhuman(data, result_path, npy_path, caption): 
39 | data = data.reshape(data.shape[0], 2, -1) 40 | joints1 = data[:, 0, :22 * 3].reshape(-1, 22, 3) 41 | joints2 = data[:, 1, :22 * 3].reshape(-1, 22, 3) 42 | joints1 = motion_temporal_filter(joints1, sigma=4.5) 43 | joints2 = motion_temporal_filter(joints2, sigma=4.5) 44 | plot_siamese_3d_motion(save_path=result_path, 45 | kinematic_tree=t2m_kinematic_chain, 46 | mp_joints=[joints1, joints2], 47 | title=caption, 48 | fps=30) 49 | 50 | 51 | def parse_args(): 52 | parser = argparse.ArgumentParser(description='mogen evaluation') 53 | parser.add_argument('config', help='test config file path') 54 | parser.add_argument('checkpoint', help='checkpoint file') 55 | parser.add_argument('--text', help='motion description', nargs='+') 56 | parser.add_argument('--motion_length', 57 | type=int, 58 | help='expected motion length', 59 | nargs='+') 60 | parser.add_argument('--out', help='output animation file') 61 | parser.add_argument('--pose_npy', 62 | help='output pose sequence file', 63 | default=None) 64 | parser.add_argument('--launcher', 65 | choices=['none', 'pytorch', 'slurm', 'mpi'], 66 | default='none', 67 | help='job launcher') 68 | parser.add_argument('--local_rank', type=int, default=0) 69 | parser.add_argument('--device', 70 | choices=['cpu', 'cuda'], 71 | default='cuda', 72 | help='device used for testing') 73 | args = parser.parse_args() 74 | if 'LOCAL_RANK' not in os.environ: 75 | os.environ['LOCAL_RANK'] = str(args.local_rank) 76 | return args 77 | 78 | 79 | def main(): 80 | args = parse_args() 81 | 82 | cfg = mmcv.Config.fromfile(args.config) 83 | # set cudnn_benchmark 84 | if cfg.get('cudnn_benchmark', False): 85 | torch.backends.cudnn.benchmark = True 86 | cfg.data.test.test_mode = True 87 | 88 | # build the model and load checkpoint 89 | model = build_architecture(cfg.model) 90 | load_checkpoint(model, args.checkpoint, map_location='cpu') 91 | 92 | if args.device == 'cpu': 93 | model = model.cpu() 94 | else: 95 | model = MMDataParallel(model, device_ids=[0]) 96 | model.eval() 97 | 98 | dataset_name = cfg.data.test.dataset_name 99 | assert dataset_name in ["human_ml3d", "inter_human"] 100 | assert len(args.motion_length) == len(args.text) 101 | max_length = max(args.motion_length) 102 | if dataset_name == "human_ml3d": 103 | input_dim = 263 104 | assert max_length >= 16 and max_length <= 196 105 | elif dataset_name == "inter_human": 106 | input_dim = 524 107 | assert max_length >= 16 and max_length <= 300 108 | mean_path = os.path.join("data", "datasets", dataset_name, "mean.npy") 109 | std_path = os.path.join("data", "datasets", dataset_name, "std.npy") 110 | mean = np.load(mean_path) 111 | std = np.load(std_path) 112 | 113 | device = args.device 114 | num_intervals = len(args.text) 115 | motion = torch.zeros(num_intervals, max_length, input_dim).to(device) 116 | motion_mask = torch.zeros(num_intervals, max_length).to(device) 117 | for i in range(num_intervals): 118 | motion_mask[i, :args.motion_length[i]] = 1 119 | motion_length = torch.Tensor(args.motion_length).long().to(device) 120 | model = model.to(device) 121 | metas = [] 122 | for t in args.text: 123 | metas.append({'text': t}) 124 | input = { 125 | 'motion': motion, 126 | 'motion_mask': motion_mask, 127 | 'motion_length': motion_length, 128 | 'num_intervals': num_intervals, 129 | 'motion_metas': metas, 130 | } 131 | 132 | all_pred_motion = [] 133 | with torch.no_grad(): 134 | input['inference_kwargs'] = {} 135 | output = model(**input) 136 | for i in range(num_intervals): 137 | pred_motion = 
output[i]['pred_motion'][:int(motion_length[i])] 138 | pred_motion = pred_motion.cpu().detach().numpy() 139 | pred_motion = pred_motion * std + mean 140 | all_pred_motion.append(pred_motion) 141 | pred_motion = np.concatenate(all_pred_motion, axis=0) 142 | 143 | if dataset_name == "human_ml3d": 144 | plot_t2m(data=pred_motion, 145 | motion_length=args.motion_length, 146 | result_path=args.out, 147 | npy_path=args.pose_npy, 148 | caption=args.text) 149 | elif dataset_name == "inter_human": 150 | plot_interhuman(data=pred_motion, 151 | result_path=args.out, 152 | npy_path=args.pose_npy, 153 | caption=args.text) 154 | 155 | 156 | if __name__ == '__main__': 157 | main() 158 | --------------------------------------------------------------------------------
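Usage note: below is a minimal single-GPU invocation of tools/visualize.py, written to match the argument parser shown above. It is a sketch, not part of the repository: the config and checkpoint paths are illustrative placeholders (any text-to-motion config whose test dataset is human_ml3d should fit), and it assumes data/datasets/human_ml3d/mean.npy and std.npy have already been prepared as the script expects.

    # Hypothetical config/checkpoint paths; substitute your own trained model.
    # PYTHONPATH must include the repo root, as in the tools/*.sh launch scripts.
    PYTHONPATH=".":$PYTHONPATH \
    python tools/visualize.py \
        configs/remodiffuse/remodiffuse_t2m.py \
        work_dirs/remodiffuse_t2m/latest.pth \
        --text "a person walks forward and waves" \
        --motion_length 120 \
        --out demo.mp4 \
        --device cuda

The single --text prompt is paired with a single --motion_length value (the script asserts the two lists have equal length, and for human_ml3d the length must lie in [16, 196]); --pose_npy can additionally be passed to dump the recovered joint positions.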