├── .gitignore ├── LICENCE ├── README.md ├── configs ├── sampler │ ├── sampler_high_res.yml │ └── sampler_low_res.yml ├── video_transformer │ ├── video_trans_high_res.yml │ └── video_trans_low_res.yml └── vqgan │ ├── vqgan_decompose_high_res.yml │ └── vqgan_decompose_low_res.yml ├── data ├── __init__.py ├── data_sampler.py ├── decompose_dataset.py ├── moving_label_clip_all_rate_dataset.py ├── prefetch_dataloader.py └── sample_identity_dataset.py ├── env.yaml ├── generate_long_video.ipynb ├── img └── teaser.png ├── models ├── __init__.py ├── app_transformer_model.py ├── archs │ ├── __init__.py │ ├── dalle_transformer_arch.py │ ├── einops_exts.py │ ├── rotary_embedding_torch.py │ ├── transformer_arch.py │ └── vqgan_arch.py ├── base_model.py ├── losses │ ├── __init__.py │ └── vqgan_loss.py ├── video_transformer_model.py └── vqgan_decompose_model.py ├── train_dist.py ├── train_sampler.py ├── train_vqvae_iter_dist.py └── utils ├── __init__.py ├── dist_util.py ├── logger.py ├── options.py └── util.py /.gitignore: -------------------------------------------------------------------------------- 1 | results/* 2 | .ipynb_checkpoints/* 3 | *.pyc 4 | experiments/* 5 | pretrained_models/* 6 | tb_logger/* 7 | datasets/* 8 | -------------------------------------------------------------------------------- /LICENCE: -------------------------------------------------------------------------------- 1 | S-Lab License 1.0 2 | 3 | Copyright 2023 S-Lab 4 | 5 | Redistribution and use for non-commercial purpose in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 6 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 7 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 8 | 3. 
Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 9 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 10 | 4. In the event that redistribution and/or use for commercial purpose in source or binary forms, with or without modification is required, please contact the contributor(s) of the work. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 |

Text2Performer: Text-Driven Human Video Generation

4 | 5 |
6 | Yuming Jiang¹, 7 | Shuai Yang¹, 8 | Tong Liang Koh¹, 9 | Wayne Wu², 10 | Chen Change Loy¹, 11 | Ziwei Liu¹ 12 |
13 |
14 | ¹S-Lab, Nanyang Technological University  ²Shanghai AI Laboratory 15 |
16 | 17 | [Paper](https://arxiv.org/pdf/2304.08483.pdf) | [Project Page](https://yumingj.github.io/projects/Text2Performer.html) | [Dataset](https://github.com/yumingj/Fashion-Text2Video) | [Video](https://youtu.be/YwhaJUk_qo0) 18 |
19 | 20 | Text2Performer synthesizes human videos by taking text descriptions as the only input. 21 | 22 |
23 | 24 |
25 | 26 | :open_book: For more visual results, check out our project page. 27 | 28 |
29 | 30 | ## Installation 31 | **Clone this repo:** 32 | ```bash 33 | git clone https://github.com/yumingj/Text2Performer.git 34 | cd Text2Performer 35 | ``` 36 | 37 | **Dependencies:** 38 | 39 | ```bash 40 | conda env create -f env.yaml 41 | conda activate text2performer 42 | ``` 43 | 44 | ## (1) Dataset Preparation 45 | 46 | In this work, we contribute a human video dataset with rich label and text annotations, named the [Fashion-Text2Video](https://github.com/yumingj/Fashion-Text2Video) dataset. 47 | 48 | You can download our processed dataset from [Google Drive](https://drive.google.com/drive/folders/1NFd_irnw8kgNcu5KfWhRA8RZPdBK5p1I?usp=sharing). 49 | After downloading the dataset, unzip it and place the contents under the `./datasets` folder with the following structure: 50 | ``` 51 | ./datasets 52 | ├── FashionDataset_frames_crop 53 | ├── xxxxxx 54 | ├── 000.png 55 | ├── 001.png 56 | ├── ... 57 | ├── xxxxxx 58 | └── xxxxxx 59 | ├── train_frame_num.txt 60 | ├── val_frame_num.txt 61 | ├── test_frame_num.txt 62 | ├── moving_frames.npy 63 | ├── captions_app.json 64 | ├── caption_motion_template.json 65 | ├── action_label 66 | ├── xxxxxx.txt 67 | ├── xxxxxx.txt 68 | ├── ... 69 | └── xxxxxx.txt 70 | └── shhq_dataset (optional) 71 | ``` 72 | 73 | ## (2) Sampling 74 | 75 | ### Pretrained Models 76 | 77 | Pretrained models can be downloaded from [Google Drive](https://drive.google.com/drive/folders/1Dgg0EaldNfyPhykHw1TYrm4qme3CqrDz?usp=sharing). Unzip the file and place the models under the `./pretrained_models` folder with the following structure: 78 | ``` 79 | pretrained_models 80 | ├── sampler_high_res.pth 81 | ├── video_trans_high_res.pth 82 | └── vqgan_decomposed_high_res.pth 83 | ``` 84 | 85 | After downloading the pretrained models, you can use `generate_long_video.ipynb` to generate videos. 86 | 87 | ## (3) Training Text2Performer 88 | ### Stage I: Decomposed VQGAN 89 | Train the decomposed VQGAN.
If you want to skip the training of this network, you can download our pretrained model from [here](https://drive.google.com/file/d/1G59bRoOUEQA8xljRDsfyiw6g8spV3Y7_/view?usp=sharing). 90 | 91 | For better performance, we also use the data from [SHHQ dataset](https://github.com/stylegan-human/StyleGAN-Human/blob/main/docs/Dataset.md) to train this stage. 92 | ```bash 93 | python -m torch.distributed.launch --nproc_per_node=4 --master_port=29596 train_vqvae_iter_dist.py -opt ./configs/vqgan/vqgan_decompose_high_res.yml --launcher pytorch 94 | ``` 95 | 96 | ### Stage II: Video Transformer 97 | Train the video transformer. If you want to skip the training of this network, you can download our pretrained model from [here](https://drive.google.com/file/d/1QRQlhl8z4-BQfmUvHoVrJnSpxQaKDPZH/view?usp=sharing). 98 | ```bash 99 | python -m torch.distributed.launch --nproc_per_node=4 --master_port=29596 train_dist.py -opt ./configs/video_transformer/video_trans_high_res.yml --launcher pytorch 100 | ``` 101 | 102 | ### Stage III: Appearance Transformer 103 | Train the appearance transformer. If you want to skip the training of this network, you can download our pretrained model from [here](https://drive.google.com/file/d/19nYQT511XsBzq1sMUc2MmfpDKT7HVi8Z/view?usp=sharing). 104 | ```bash 105 | python train_sampler.py -opt ./configs/sampler/sampler_high_res.yml 106 | ``` 107 | 108 | ## Citation 109 | 110 | If you find this work useful for your research, please consider citing our paper: 111 | 112 | ```bibtex 113 | @inproceedings{jiang2023text2performer, 114 | title={Text2Performer: Text-Driven Human Video Generation}, 115 | author={Jiang, Yuming and Yang, Shuai and Koh, Tong Liang and Wu, Wayne and Loy, Chen Change and Liu, Ziwei}, 116 | booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision}, 117 | year={2023} 118 | } 119 | ``` 120 | 121 | ## :newspaper_roll: License 122 | 123 | Distributed under the S-Lab License.
See `LICENCE` for more information. 124 | 125 | ![visitor badge](https://visitor-badge.glitch.me/badge?page_id=yumingj/Text2Performer&left_color=red&right_color=green&left_text=HelloVisitors) 126 | 127 | -------------------------------------------------------------------------------- /configs/sampler/sampler_high_res.yml: -------------------------------------------------------------------------------- 1 | name: sampler_high_res 2 | use_tb_logger: true 3 | set_CUDA_VISIBLE_DEVICES: ~ 4 | gpu_ids: [3] 5 | 6 | # dataset configs 7 | batch_size: 4 8 | num_workers: 1 9 | datasets: 10 | train: 11 | video_dir: ./datasets/FashionDataset_frames_crop 12 | data_name_txt: ./datasets/train_frame_num.txt 13 | text_file: ./datasets/captions_app.json 14 | downsample_factor: 1 15 | xflip: True 16 | 17 | val: 18 | video_dir: ./datasets/FashionDataset_frames_crop 19 | data_name_txt: ./datasets/val_frame_num.txt 20 | text_file: ./datasets/captions_app.json 21 | downsample_factor: 1 22 | xflip: False 23 | 24 | test: 25 | video_dir: ./datasets/FashionDataset_frames_crop 26 | data_name_txt: ./datasets/test_frame_num.txt 27 | text_file: ./datasets/captions_app.json 28 | downsample_factor: 1 29 | xflip: False 30 | 31 | # pretrained models 32 | img_ae_path: ./pretrained_models/vqgan_decomposed_high_res.pth 33 | 34 | model_type: AppTransformerModel 35 | # network configs 36 | 37 | # image autoencoder 38 | img_embed_dim: 256 39 | img_n_embed: 1024 40 | img_double_z: false 41 | img_z_channels: 256 42 | img_resolution: 512 43 | img_in_channels: 3 44 | img_out_ch: 3 45 | img_ch: 128 46 | img_ch_mult: [1, 1, 2, 2, 4] 47 | img_other_ch_mult: [4, 4] 48 | img_num_res_blocks: 2 49 | img_attn_resolutions: [32] 50 | img_dropout: 0.0 51 | 52 | # sampler configs 53 | codebook_size: 1024 54 | bert_n_emb: 512 55 | bert_n_layers: 24 56 | bert_n_head: 8 57 | block_size: 512 # 32 x 16 58 | latent_shape: [32, 16] 59 | embd_pdrop: 0.0 60 | resid_pdrop: 0.0 61 | attn_pdrop: 0.0 62 | 63 | # loss configs 64 |
loss_type: reweighted_elbo 65 | mask_schedule: random 66 | 67 | sample_steps: 64 68 | 69 | # training configs 70 | val_freq: 50 71 | print_freq: 100 72 | weight_decay: 0 73 | manual_seed: 2021 74 | num_epochs: 1000 75 | lr: !!float 1e-4 76 | lr_decay: step 77 | gamma: 1.0 78 | step: 50 79 | 80 | text_seq_len: 50 81 | -------------------------------------------------------------------------------- /configs/sampler/sampler_low_res.yml: -------------------------------------------------------------------------------- 1 | name: sampler_low_res 2 | use_tb_logger: true 3 | set_CUDA_VISIBLE_DEVICES: ~ 4 | gpu_ids: [3] 5 | 6 | # dataset configs 7 | batch_size: 4 8 | num_workers: 1 9 | datasets: 10 | train: 11 | video_dir: ./datasets/FashionDataset_frames_crop 12 | data_name_txt: ./datasets/train_frame_num.txt 13 | text_file: ./datasets/captions_app.json 14 | downsample_factor: 2 15 | xflip: True 16 | 17 | val: 18 | video_dir: ./datasets/FashionDataset_frames_crop 19 | data_name_txt: ./datasets/val_frame_num.txt 20 | text_file: ./datasets/captions_app.json 21 | downsample_factor: 2 22 | xflip: False 23 | 24 | test: 25 | video_dir: ./datasets/FashionDataset_frames_crop 26 | data_name_txt: ./datasets/test_frame_num.txt 27 | text_file: ./datasets/captions_app.json 28 | downsample_factor: 2 29 | xflip: False 30 | 31 | # pretrained models 32 | img_ae_path: ./pretrained_models/vqgan_decomposed_low_res.pth 33 | 34 | model_type: AppTransformerModel 35 | # network configs 36 | 37 | # image autoencoder 38 | img_embed_dim: 256 39 | img_n_embed: 1024 40 | img_double_z: false 41 | img_z_channels: 256 42 | img_resolution: 512 43 | img_in_channels: 3 44 | img_out_ch: 3 45 | img_ch: 128 46 | img_ch_mult: [1, 1, 2, 2, 4] 47 | img_other_ch_mult: [4, 4] 48 | img_num_res_blocks: 2 49 | img_attn_resolutions: [32] 50 | img_dropout: 0.0 51 | 52 | # sampler configs 53 | codebook_size: 1024 54 | bert_n_emb: 512 55 | bert_n_layers: 24 56 | bert_n_head: 8 57 | block_size: 128 # 16 x 8 58 |
latent_shape: [16, 8] 59 | embd_pdrop: 0.0 60 | resid_pdrop: 0.0 61 | attn_pdrop: 0.0 62 | 63 | # loss configs 64 | loss_type: reweighted_elbo 65 | mask_schedule: random 66 | 67 | sample_steps: 64 68 | 69 | # training configs 70 | val_freq: 50 71 | print_freq: 100 72 | weight_decay: 0 73 | manual_seed: 2021 74 | num_epochs: 1000 75 | lr: !!float 1e-4 76 | lr_decay: step 77 | gamma: 1.0 78 | step: 50 79 | 80 | text_seq_len: 50 81 | -------------------------------------------------------------------------------- /configs/video_transformer/video_trans_high_res.yml: -------------------------------------------------------------------------------- 1 | name: video_trans_high_res 2 | use_tb_logger: true 3 | set_CUDA_VISIBLE_DEVICES: ~ 4 | gpu_ids: [3] 5 | num_gpu: 4 6 | 7 | # dataset configs 8 | batch_size_per_gpu: 1 9 | num_worker_per_gpu: 0 10 | datasets: 11 | train: 12 | type: MovingLabelsClipAllRateTextDataset 13 | video_dir: ./datasets/FashionDataset_frames_crop 14 | action_label_folder: ./datasets/action_label 15 | data_name_txt: ./datasets/train_frame_num.txt 16 | frame_sample_rate: 1 17 | fixed_video_len: 8 18 | moving_frame_dict: ./datasets/moving_frames.npy 19 | overall_caption_templates_file: ./datasets/caption_motion_template.json 20 | interpolation_rate: 0.2 21 | downsample_factor: 1 22 | random_start: True 23 | xflip: False 24 | 25 | val: 26 | type: MovingLabelsClipAllRateTextDataset 27 | video_dir: ./datasets/FashionDataset_frames_crop 28 | action_label_folder: ./datasets/action_label 29 | data_name_txt: ./datasets/val_frame_num.txt 30 | frame_sample_rate: 1 31 | fixed_video_len: 8 32 | moving_frame_dict: ./datasets/moving_frames.npy 33 | overall_caption_templates_file: ./datasets/caption_motion_template.json 34 | downsample_factor: 1 35 | interpolation_rate: 0.0 36 | random_start: True 37 | xflip: False 38 | 39 | test: 40 | type: MovingLabelsClipAllRateTextDataset 41 | video_dir: ./datasets/FashionDataset_frames_crop 42 | action_label_folder: 
./datasets/action_label 43 | data_name_txt: ./datasets/test_frame_num.txt 44 | frame_sample_rate: 1 45 | moving_frame_dict: ./datasets/moving_frames.npy 46 | fixed_video_len: 8 47 | overall_caption_templates_file: ./datasets/caption_motion_template.json 48 | downsample_factor: 1 49 | interpolation_rate: 0.0 50 | random_start: True 51 | xflip: False 52 | 53 | prefetch_mode: ~ 54 | 55 | # pretrained models 56 | img_ae_path: ./pretrained_models/vqgan_decomposed_high_res.pth 57 | 58 | model_type: VideoTransformerModel 59 | 60 | # network configs 61 | # image autoencoder 62 | img_embed_dim: 256 63 | img_n_embed: 1024 64 | img_double_z: false 65 | img_z_channels: 256 66 | img_resolution: 512 67 | img_in_channels: 3 68 | img_out_ch: 3 69 | img_ch: 128 70 | img_ch_mult: [1, 1, 2, 2, 4] 71 | img_other_ch_mult: [4, 4] 72 | img_num_res_blocks: 2 73 | img_attn_resolutions: [32] 74 | img_dropout: 0.0 75 | 76 | # sampler configs 77 | dim: 128 78 | depth: 6 79 | dim_head: 64 80 | heads: 12 81 | ff_mult: 4 82 | norm_out: true 83 | attn_dropout: 0.0 84 | ff_dropout: 0.0 85 | final_proj: true 86 | normformer: true 87 | rotary_emb: true 88 | latent_shape: [8, 4] 89 | action_label_num: 23 90 | 91 | # training configs 92 | val_freq: 50 93 | print_freq: 100 94 | weight_decay: 0 95 | manual_seed: 2022 96 | num_epochs: 1000 97 | lr: !!float 1e-4 98 | lr_decay: step 99 | gamma: 1.0 100 | step: 50 101 | perceptual_weight: 1.0 102 | 103 | larger_ratio: 3 104 | 105 | num_inside_timesteps: 24 106 | inside_ratio: 0 -------------------------------------------------------------------------------- /configs/video_transformer/video_trans_low_res.yml: -------------------------------------------------------------------------------- 1 | name: video_trans_low_res 2 | use_tb_logger: true 3 | set_CUDA_VISIBLE_DEVICES: ~ 4 | gpu_ids: [3] 5 | num_gpu: 4 6 | 7 | # dataset configs 8 | batch_size_per_gpu: 1 9 | num_worker_per_gpu: 0 10 | datasets: 11 | train: 12 | type: MovingLabelsClipAllRateTextDataset 13 | 
video_dir: ./datasets/FashionDataset_frames_crop 14 | action_label_folder: ./datasets/action_label 15 | data_name_txt: ./datasets/train_frame_num.txt 16 | frame_sample_rate: 1 17 | fixed_video_len: 20 18 | moving_frame_dict: ./datasets/moving_frames.npy 19 | overall_caption_templates_file: ./datasets/caption_motion_template.json 20 | interpolation_rate: 0.2 21 | downsample_factor: 2 22 | random_start: True 23 | xflip: False 24 | 25 | val: 26 | type: MovingLabelsClipAllRateTextDataset 27 | video_dir: ./datasets/FashionDataset_frames_crop 28 | action_label_folder: ./datasets/action_label 29 | data_name_txt: ./datasets/val_frame_num.txt 30 | frame_sample_rate: 1 31 | fixed_video_len: 20 32 | moving_frame_dict: ./datasets/moving_frames.npy 33 | overall_caption_templates_file: ./datasets/caption_motion_template.json 34 | downsample_factor: 2 35 | interpolation_rate: 0.0 36 | random_start: True 37 | xflip: False 38 | 39 | test: 40 | type: MovingLabelsClipAllRateTextDataset 41 | video_dir: ./datasets/FashionDataset_frames_crop 42 | action_label_folder: ./datasets/action_label 43 | data_name_txt: ./datasets/test_frame_num.txt 44 | frame_sample_rate: 1 45 | moving_frame_dict: ./datasets/moving_frames.npy 46 | fixed_video_len: 20 47 | overall_caption_templates_file: ./datasets/caption_motion_template.json 48 | downsample_factor: 2 49 | interpolation_rate: 0.0 50 | random_start: True 51 | xflip: False 52 | 53 | prefetch_mode: ~ 54 | 55 | # pretrained models 56 | img_ae_path: ./pretrained_models/vqgan_decomposed_low_res.pth 57 | 58 | model_type: VideoTransformerModel 59 | 60 | # network configs 61 | # image autoencoder 62 | img_embed_dim: 256 63 | img_n_embed: 1024 64 | img_double_z: false 65 | img_z_channels: 256 66 | img_resolution: 512 67 | img_in_channels: 3 68 | img_out_ch: 3 69 | img_ch: 128 70 | img_ch_mult: [1, 1, 2, 2, 4] 71 | img_other_ch_mult: [4, 4] 72 | img_num_res_blocks: 2 73 | img_attn_resolutions: [32] 74 | img_dropout: 0.0 75 | 76 | # sampler configs 77 | 
dim: 128 78 | depth: 6 79 | dim_head: 64 80 | heads: 12 81 | ff_mult: 4 82 | norm_out: true 83 | attn_dropout: 0.0 84 | ff_dropout: 0.0 85 | final_proj: true 86 | normformer: true 87 | rotary_emb: true 88 | latent_shape: [4, 2] 89 | action_label_num: 23 90 | 91 | # training configs 92 | val_freq: 50 93 | print_freq: 100 94 | weight_decay: 0 95 | manual_seed: 2022 96 | num_epochs: 1000 97 | lr: !!float 1e-4 98 | lr_decay: step 99 | gamma: 1.0 100 | step: 50 101 | perceptual_weight: 1.0 102 | 103 | larger_ratio: 6 104 | 105 | num_inside_timesteps: 6 106 | inside_ratio: 0 -------------------------------------------------------------------------------- /configs/vqgan/vqgan_decompose_high_res.yml: -------------------------------------------------------------------------------- 1 | name: vqgan_decompose_high_res 2 | use_tb_logger: true 3 | set_CUDA_VISIBLE_DEVICES: ~ 4 | gpu_ids: [3] 5 | num_gpu: 4 6 | 7 | # dataset configs 8 | batch_size_per_gpu: 4 9 | num_worker_per_gpu: 1 10 | datasets: 11 | train: 12 | type: DecomposeMixDataset 13 | video_dir: ./datasets/FashionDataset_frames_crop 14 | data_name_txt: ./datasets/train_frame_num.txt 15 | shhq_data_dir: ./datasets/shhq_dataset 16 | downsample_factor: 1 17 | xflip: True 18 | 19 | val: 20 | type: DecomposeDataset 21 | video_dir: ./datasets/FashionDataset_frames_crop 22 | data_name_txt: ./datasets/val_frame_num.txt 23 | downsample_factor: 1 24 | xflip: False 25 | 26 | test: 27 | type: DecomposeDataset 28 | video_dir: ./datasets/FashionDataset_frames_crop 29 | data_name_txt: ./datasets/test_frame_num.txt 30 | downsample_factor: 1 31 | xflip: False 32 | 33 | 34 | model_type: VQGANDecomposeModel 35 | # network configs 36 | embed_dim: 256 37 | n_embed: 1024 38 | double_z: false 39 | z_channels: 256 40 | resolution: 512 41 | in_channels: 3 42 | out_ch: 3 43 | ch: 128 44 | ch_mult: [1, 1, 2, 2, 4] 45 | other_ch_mult: [4, 4] 46 | num_res_blocks: 2 47 | attn_resolutions: [32] 48 | dropout: 0.0 49 | 50 | disc_layers: 3 51 |
disc_weight_max: 1 52 | disc_start_step: 40001 53 | n_channels: 3 54 | ndf: 64 55 | nf: 128 56 | perceptual_weight: 1.0 57 | 58 | num_segm_classes: 24 59 | 60 | 61 | # training configs 62 | val_freq: 5000 63 | print_freq: 10 64 | weight_decay: 0 65 | manual_seed: 2021 66 | num_epochs: 100000 67 | lr: !!float 1.0e-04 68 | lr_decay: step 69 | gamma: 1.0 70 | step: 50 71 | 72 | random_dropout: 1.0 -------------------------------------------------------------------------------- /configs/vqgan/vqgan_decompose_low_res.yml: -------------------------------------------------------------------------------- 1 | name: vqgan_decompose_low_res 2 | use_tb_logger: true 3 | set_CUDA_VISIBLE_DEVICES: ~ 4 | gpu_ids: [3] 5 | num_gpu: 4 6 | 7 | # dataset configs 8 | batch_size_per_gpu: 16 9 | num_worker_per_gpu: 4 10 | datasets: 11 | train: 12 | type: DecomposeMixDataset 13 | video_dir: ./datasets/FashionDataset_frames_crop 14 | data_name_txt: ./datasets/train_frame_num.txt 15 | shhq_data_dir: ./datasets/shhq_dataset 16 | downsample_factor: 2 17 | xflip: True 18 | 19 | val: 20 | type: DecomposeDataset 21 | video_dir: ./datasets/FashionDataset_frames_crop 22 | data_name_txt: ./datasets/val_frame_num.txt 23 | downsample_factor: 2 24 | xflip: False 25 | 26 | test: 27 | type: DecomposeDataset 28 | video_dir: ./datasets/FashionDataset_frames_crop 29 | data_name_txt: ./datasets/test_frame_num.txt 30 | downsample_factor: 2 31 | xflip: False 32 | 33 | 34 | model_type: VQGANDecomposeModel 35 | # network configs 36 | embed_dim: 256 37 | n_embed: 1024 38 | double_z: false 39 | z_channels: 256 40 | resolution: 512 41 | in_channels: 3 42 | out_ch: 3 43 | ch: 128 44 | ch_mult: [1, 1, 2, 2, 4] 45 | other_ch_mult: [4, 4] 46 | num_res_blocks: 2 47 | attn_resolutions: [32] 48 | dropout: 0.0 49 | 50 | disc_layers: 1 51 | disc_weight_max: 1 52 | disc_start_step: 10001 53 | n_channels: 3 54 | ndf: 64 55 | nf: 128 56 | perceptual_weight: 1.0 57 | 58 | num_segm_classes: 24 59 | 60 | 61 | # training configs 
62 | val_freq: 5000 63 | print_freq: 10 64 | weight_decay: 0 65 | manual_seed: 2021 66 | num_epochs: 100000 67 | lr: !!float 1.0e-04 68 | lr_decay: step 69 | gamma: 1.0 70 | step: 50 71 | 72 | random_dropout: 1.0 -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import importlib 3 | import random 4 | from functools import partial 5 | from os import path as osp 6 | 7 | import numpy as np 8 | import torch 9 | import torch.utils.data 10 | 11 | from data.prefetch_dataloader import PrefetchDataLoader 12 | from utils.dist_util import get_dist_info 13 | from utils.logger import get_root_logger 14 | 15 | __all__ = ['create_dataset', 'create_dataloader'] 16 | 17 | # automatically scan and import dataset modules 18 | # scan all the files under the data folder with '_dataset' in file names 19 | data_folder = osp.dirname(osp.abspath(__file__)) 20 | dataset_filenames = [ 21 | osp.splitext(osp.basename(v))[0] 22 | for v in glob.glob(f'{data_folder}/*_dataset.py') 23 | ] 24 | # import all the dataset modules 25 | _dataset_modules = [ 26 | importlib.import_module(f'data.{file_name}') 27 | for file_name in dataset_filenames 28 | ] 29 | 30 | 31 | def create_dataset(dataset_opt): 32 | """Create dataset. 33 | 34 | Args: 35 | dataset_opt (dict): Configuration for dataset. It contains: 36 | name (str): Dataset name. 37 | type (str): Dataset type.
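
    Example (illustrative only): ``type`` is looked up as a class name in
    the auto-imported ``*_dataset.py`` modules, and ``name`` is included
    because this function reads it for logging. The remaining keys mirror
    the ``val`` entry of ``configs/vqgan/vqgan_decompose_high_res.yml``;
    the call is skipped under doctest because it needs the dataset files.

        >>> opt = dict(name='val', type='DecomposeDataset',
        ...            video_dir='./datasets/FashionDataset_frames_crop',
        ...            data_name_txt='./datasets/val_frame_num.txt',
        ...            downsample_factor=1, xflip=False)
        >>> dataset = create_dataset(opt)  # doctest: +SKIP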
38 | """ 39 | dataset_type = dataset_opt['type'] 40 | 41 | # dynamic instantiation 42 | for module in _dataset_modules: 43 | dataset_cls = getattr(module, dataset_type, None) 44 | if dataset_cls is not None: 45 | break 46 | if dataset_cls is None: 47 | raise ValueError(f'Dataset {dataset_type} is not found.') 48 | 49 | dataset = dataset_cls(dataset_opt) 50 | 51 | logger = get_root_logger() 52 | logger.info( 53 | f'Dataset {dataset.__class__.__name__} - {dataset_opt["name"]} ' 54 | 'is created.') 55 | 56 | return dataset 57 | 58 | 59 | def create_dataloader(dataset, 60 | dataset_opt, 61 | phase, 62 | num_gpu=1, 63 | dist=False, 64 | sampler=None, 65 | seed=None): 66 | """Create dataloader. 67 | 68 | Args: 69 | dataset (torch.utils.data.Dataset): Dataset. 70 | dataset_opt (dict): Dataset options. It contains the following keys: 71 | phase (str): 'train' or 'val'. 72 | num_worker_per_gpu (int): Number of workers for each GPU. 73 | batch_size_per_gpu (int): Training batch size for each GPU. 74 | num_gpu (int): Number of GPUs. Used only in the train phase. 75 | Default: 1. 76 | dist (bool): Whether in distributed training. Used only in the train 77 | phase. Default: False. 78 | sampler (torch.utils.data.sampler): Data sampler. Default: None. 79 | seed (int | None): Seed. 
Default: None 80 | """ 81 | rank, _ = get_dist_info() 82 | if phase == 'train': 83 | if dist: # distributed training 84 | batch_size = dataset_opt['batch_size_per_gpu'] 85 | num_workers = dataset_opt['num_worker_per_gpu'] 86 | else: # non-distributed training 87 | multiplier = 1 if num_gpu == 0 else num_gpu 88 | batch_size = dataset_opt['batch_size_per_gpu'] * multiplier 89 | num_workers = dataset_opt['num_worker_per_gpu'] * multiplier 90 | dataloader_args = dict( 91 | dataset=dataset, 92 | batch_size=batch_size, 93 | shuffle=False, 94 | num_workers=num_workers, 95 | sampler=sampler, 96 | drop_last=True) 97 | if sampler is None: 98 | dataloader_args['shuffle'] = True 99 | dataloader_args['worker_init_fn'] = partial( 100 | worker_init_fn, num_workers=num_workers, rank=rank, 101 | seed=seed) if seed is not None else None 102 | elif phase in ['val', 'test']: # validation 103 | dataloader_args = dict( 104 | dataset=dataset, batch_size=1, shuffle=False, num_workers=1) 105 | else: 106 | raise ValueError(f'Wrong dataset phase: {phase}. 
' 107 | "Supported ones are 'train', 'val' and 'test'.") 108 | 109 | dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False) 110 | 111 | prefetch_mode = dataset_opt.get('prefetch_mode') 112 | if prefetch_mode == 'cpu': # CPUPrefetcher 113 | num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1) 114 | logger = get_root_logger() 115 | logger.info(f'Use {prefetch_mode} prefetch dataloader: ' 116 | f'num_prefetch_queue = {num_prefetch_queue}') 117 | return PrefetchDataLoader( 118 | num_prefetch_queue=num_prefetch_queue, **dataloader_args) 119 | else: 120 | # prefetch_mode=None: Normal dataloader 121 | # prefetch_mode='cuda': dataloader for CUDAPrefetcher 122 | return torch.utils.data.DataLoader(**dataloader_args) 123 | 124 | 125 | def worker_init_fn(worker_id, num_workers, rank, seed): 126 | # Set the worker seed to num_workers * rank + worker_id + seed 127 | worker_seed = num_workers * rank + worker_id + seed 128 | np.random.seed(worker_seed) 129 | random.seed(worker_seed) 130 | -------------------------------------------------------------------------------- /data/data_sampler.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch.utils.data.sampler import Sampler 5 | 6 | 7 | class EnlargedSampler(Sampler): 8 | """Sampler that restricts data loading to a subset of the dataset. 9 | 10 | Modified from torch.utils.data.distributed.DistributedSampler. 11 | Supports enlarging the dataset for iteration-based training, which saves 12 | time otherwise spent restarting the dataloader after each epoch. 13 | 14 | Args: 15 | dataset (torch.utils.data.Dataset): Dataset used for sampling. 16 | num_replicas (int): Number of processes participating in 17 | the training. It is usually the world_size. 18 | rank (int): Rank of the current process within num_replicas. 19 | ratio (int): Enlarging ratio. Default: 1.
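
    Example (illustrative sizes, not taken from the configs): a 10-item
    dataset split across 4 replicas with ``ratio=2`` gives
    ``num_samples = ceil(10 * 2 / 4) = 5`` and ``total_size = 20``. The
    shuffled indices are wrapped modulo 10, so across all ranks each item
    appears twice per enlarged epoch, and each rank receives 5 indices.

        >>> sampler = EnlargedSampler(list(range(10)), num_replicas=4,
        ...                           rank=0, ratio=2)
        >>> sampler.num_samples, sampler.total_size
        (5, 20)
        >>> len(list(sampler))
        5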
20 | """ 21 | 22 | def __init__(self, dataset, num_replicas, rank, ratio=1): 23 | self.dataset = dataset 24 | self.num_replicas = num_replicas 25 | self.rank = rank 26 | self.epoch = 0 27 | self.num_samples = math.ceil( 28 | len(self.dataset) * ratio / self.num_replicas) 29 | self.total_size = self.num_samples * self.num_replicas 30 | 31 | def __iter__(self): 32 | # deterministically shuffle based on epoch 33 | g = torch.Generator() 34 | g.manual_seed(self.epoch) 35 | indices = torch.randperm(self.total_size, generator=g).tolist() 36 | 37 | dataset_size = len(self.dataset) 38 | indices = [v % dataset_size for v in indices] 39 | 40 | # subsample 41 | indices = indices[self.rank:self.total_size:self.num_replicas] 42 | assert len(indices) == self.num_samples 43 | 44 | return iter(indices) 45 | 46 | def __len__(self): 47 | return self.num_samples 48 | 49 | def set_epoch(self, epoch): 50 | self.epoch = epoch 51 | -------------------------------------------------------------------------------- /data/decompose_dataset.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import os.path 4 | import random 5 | 6 | import numpy as np 7 | import torch 8 | import torch.utils.data as data 9 | import torchvision.transforms as transforms 10 | from PIL import Image 11 | 12 | 13 | class DecomposeDataset(data.Dataset): 14 | 15 | def __init__(self, opt): 16 | self._video_dir = opt['video_dir'] 17 | self.downsample_factor = opt['downsample_factor'] 18 | 19 | self._video_names = [] 20 | self._frame_nums = [] 21 | 22 | self.transform = transforms.Compose([ 23 | transforms.ColorJitter(brightness=.5, hue=.3), 24 | transforms.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5)) 25 | ]) 26 | 27 | self.xflip = opt['xflip'] 28 | 29 | video_names_list = open(opt['data_name_txt'], 'r').readlines() 30 | 31 | for row in video_names_list: 32 | video_name = row.split()[0] 33 | frame_num = int(row.split()[1]) 34 | 
self._video_names.append(video_name) 35 | self._frame_nums.append(frame_num) 36 | 37 | def _load_raw_image(self, img_path): 38 | with open(img_path, 'rb') as f: 39 | image = Image.open(f) 40 | image.load() 41 | if self.downsample_factor != 1: 42 | width, height = image.size 43 | width = width // self.downsample_factor 44 | height = height // self.downsample_factor 45 | image = image.resize( 46 | size=(width, height), resample=Image.LANCZOS) 47 | 48 | return image 49 | 50 | def __getitem__(self, index): 51 | video_name = self._video_names[index] 52 | 53 | random_frame_idx = random.randint(30, self._frame_nums[index] - 30) 54 | 55 | img_path = f'{self._video_dir}/{video_name}/{random_frame_idx:03d}.png' 56 | random_frame = self._load_raw_image(img_path) 57 | 58 | random_frame_aug = self.transform(random_frame) 59 | 60 | random_frame = np.array(random_frame).transpose(2, 0, 61 | 1).astype(np.float32) 62 | random_frame_aug = np.array(random_frame_aug).transpose( 63 | 2, 0, 1).astype(np.float32) 64 | 65 | identity_image = self._load_raw_image( 66 | f'{self._video_dir}/{video_name}/000.png') 67 | identity_image = np.array(identity_image).transpose(2, 0, 1).astype( 68 | np.float32) 69 | 70 | if self.xflip and random.random() > 0.5: 71 | random_frame = random_frame[:, :, ::-1].copy() # [C, H ,W] 72 | random_frame_aug = random_frame_aug[:, :, ::-1].copy() 73 | 74 | random_frame = random_frame / 127.5 - 1 75 | random_frame_aug = random_frame_aug / 127.5 - 1 76 | identity_image = identity_image / 127.5 - 1 77 | 78 | random_frame = torch.from_numpy(random_frame) 79 | identity_image = torch.from_numpy(identity_image) 80 | 81 | return_dict = { 82 | # 'densepose': pose, 83 | 'video_name': f'{video_name}_{random_frame_idx:03d}', 84 | 'frame_img': random_frame, 85 | 'frame_img_aug': random_frame_aug, 86 | 'identity_image': identity_image 87 | } 88 | 89 | return return_dict 90 | 91 | def __len__(self): 92 | return len(self._video_names) 93 | 94 | 95 | class 
DecomposeMixDataset(data.Dataset): 96 | 97 | def __init__(self, opt): 98 | self._video_dir = opt['video_dir'] 99 | self.downsample_factor = opt['downsample_factor'] 100 | 101 | _video_names = [] 102 | _frame_nums = [] 103 | 104 | self.transform = transforms.Compose([ 105 | transforms.ColorJitter(brightness=.5, hue=.3), 106 | transforms.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5)) 107 | ]) 108 | 109 | self.translate_transform = transforms.Compose( 110 | [transforms.RandomAffine(degrees=0, translate=(0.1, 0.3))]) 111 | 112 | self.xflip = opt['xflip'] 113 | 114 | video_names_list = open(opt['data_name_txt'], 'r').readlines() 115 | 116 | for row in video_names_list: 117 | video_name = row.split()[0] 118 | frame_num = int(row.split()[1]) 119 | _video_names.append(video_name) 120 | _frame_nums.append(frame_num) 121 | 122 | self._shhq_data_dir = opt['shhq_data_dir'] 123 | self._ann_dir = opt.get('ann_dir')  # optional; the configs do not set it and only the commented-out loader below reads it 124 | 125 | _image_fnames = [] 126 | # for idx, row in enumerate( 127 | # open(os.path.join(f'{self._ann_dir}/upper_fused.txt'), 'r')): 128 | # annotations = row.split() 129 | # if len(annotations[:-1]) == 1: 130 | # img_name = annotations[0] 131 | # else: 132 | # img_name = '' 133 | # for name in annotations[:-1]: 134 | # img_name += f'{name}\xa0' 135 | # img_name = img_name[:-1] 136 | # _image_fnames.append(img_name) 137 | shhq_path_list = glob.glob(f'{self._shhq_data_dir}/*.png') 138 | for shhq_path in shhq_path_list: 139 | _image_fnames.append(shhq_path.split('/')[-1]) 140 | 141 | augment_times = max(1, len(_image_fnames) // len(_video_names)) 142 | 143 | augmented_videos = _video_names * augment_times 144 | self._frame_nums = _frame_nums * augment_times 145 | self._all_file_name = augmented_videos + _image_fnames 146 | self.video_num = len(augmented_videos) 147 | 148 | def _load_raw_image(self, img_path): 149 | with open(img_path, 'rb') as f: 150 | image = Image.open(f) 151 | image.load() 152 | if self.downsample_factor != 1: 153 | width, height = image.size 154 |
            width = width // self.downsample_factor
            height = height // self.downsample_factor
            image = image.resize(
                size=(width, height), resample=Image.LANCZOS)

        return image

    def sample_video_data(self, index):
        video_name = self._all_file_name[index]

        random_frame_idx = random.randint(30, self._frame_nums[index] - 30)

        img_path = f'{self._video_dir}/{video_name}/{random_frame_idx:03d}.png'
        random_frame = self._load_raw_image(img_path)

        random_frame_aug = self.transform(random_frame)

        random_frame = np.array(random_frame).transpose(
            2, 0, 1).astype(np.float32)
        random_frame_aug = np.array(random_frame_aug).transpose(
            2, 0, 1).astype(np.float32)

        identity_image = self._load_raw_image(
            f'{self._video_dir}/{video_name}/000.png')
        identity_image = np.array(identity_image).transpose(2, 0, 1).astype(
            np.float32)

        if self.xflip and random.random() > 0.5:
            random_frame = random_frame[:, :, ::-1].copy()  # [C, H, W]
            random_frame_aug = random_frame_aug[:, :, ::-1].copy()

        return identity_image, random_frame, random_frame_aug

    def sample_img_data(self, index):
        img_name = self._all_file_name[index]

        img_path = f'{self._shhq_data_dir}/{img_name}'

        identity_image = self._load_raw_image(img_path)

        random_frame = self.translate_transform(identity_image)

        random_frame_aug = self.transform(random_frame)

        random_frame = np.array(random_frame).transpose(
            2, 0, 1).astype(np.float32)
        random_frame_aug = np.array(random_frame_aug).transpose(
            2, 0, 1).astype(np.float32)
        identity_image = np.array(identity_image).transpose(2, 0, 1).astype(
            np.float32)

        if self.xflip and random.random() > 0.5:
            random_frame = random_frame[:, :, ::-1].copy()  # [C, H, W]
            random_frame_aug = random_frame_aug[:, :, ::-1].copy()

        return identity_image, random_frame, random_frame_aug

    def __getitem__(self, index):

        if index < self.video_num:
            identity_image, random_frame, random_frame_aug = self.sample_video_data(
                index)
        else:
            identity_image, random_frame, random_frame_aug = self.sample_img_data(
                index)

        random_frame = random_frame / 127.5 - 1
        random_frame_aug = random_frame_aug / 127.5 - 1
        identity_image = identity_image / 127.5 - 1

        random_frame = torch.from_numpy(random_frame)
        random_frame_aug = torch.from_numpy(random_frame_aug)
        identity_image = torch.from_numpy(identity_image)

        return_dict = {
            'frame_img': random_frame,
            'frame_img_aug': random_frame_aug,
            'identity_image': identity_image
        }

        return return_dict

    def __len__(self):
        return len(self._all_file_name)

--------------------------------------------------------------------------------
/data/moving_label_clip_all_rate_dataset.py:
--------------------------------------------------------------------------------
import json
import random

import numpy as np
import torch
import torch.utils.data as data
from PIL import Image


def proper_capitalize(text):
    if len(text) > 0:
        text = text.lower()
        text = text[0].capitalize() + text[1:]
        for idx, char in enumerate(text):
            if char in ['.', '!', '?'] and (idx + 2) < len(text):
                text = (text[:idx + 2] + text[idx + 2].capitalize() +
                        text[idx + 3:])
        text = text.replace(' i ', ' I ')
        text = text.replace(',i ', ',I ')
        text = text.replace('.i ', '.I ')
        text = text.replace('!i ', '!I ')
    return text


class MovingLabelsClipAllRateTextDataset(data.Dataset):

    def __init__(self, opt):
        self._video_dir = opt['video_dir']
        self.downsample_factor = opt['downsample_factor']
        self.random_start = opt['random_start']
        self._video_names = []

        self.frame_sample_rate = opt['frame_sample_rate']

        self._action_label_folder = opt['action_label_folder']
        self.action_labels = []

        self.fixed_video_len = opt['fixed_video_len']
        self.xflip = opt['xflip']

        video_names_list = open(opt['data_name_txt'], 'r').readlines()

        self.moving_dict = np.load(
            opt['moving_frame_dict'], allow_pickle=True).item()

        self.interpolation_rate = opt['interpolation_rate']

        self.all_clip_start_frame_list = []
        self.all_clip_end_frame_list = []
        self.all_clip_action_label_list = []
        self.frame_num_list = []

        with open(opt['overall_caption_templates_file'], 'r') as f:
            self.overall_caption_templates = json.load(f)

        for row in video_names_list:
            video_name = row.split()[0]
            frame_nums = int(row.split()[1])
            action_label_txt = open(
                f'{self._action_label_folder}/{video_name}.txt',
                'r').readlines()

            clip_start_frame_list = []
            clip_end_frame_list = []
            clip_action_label_list = []
            for action_row in action_label_txt:
                start_frame, end_frame, action_label = action_row[:-1].split()
                start_frame = int(start_frame)
                end_frame = int(end_frame)
                action_label = int(action_label)

                if (end_frame - start_frame <
                        self.frame_sample_rate * self.fixed_video_len):
                    continue

                clip_start_frame_list.append(start_frame)
                clip_end_frame_list.append(end_frame)
                clip_action_label_list.append(action_label)

            if len(clip_start_frame_list) == 0:
                continue

            self._video_names.append(video_name)
            self.all_clip_start_frame_list.append(clip_start_frame_list)
            self.all_clip_end_frame_list.append(clip_end_frame_list)
            self.all_clip_action_label_list.append(clip_action_label_list)
            self.frame_num_list.append(frame_nums)

        assert len(self._video_names) == len(self.all_clip_start_frame_list)
        assert len(self._video_names) == len(self.all_clip_end_frame_list)
        assert len(self._video_names) == len(self.all_clip_action_label_list)
        assert len(self._video_names) == len(self.frame_num_list)

    def _load_raw_image(self, img_path):
        with open(img_path, 'rb') as f:
            image = Image.open(f)
            if self.downsample_factor != 1:
                width, height = image.size
                width = width // self.downsample_factor
                height = height // self.downsample_factor
                image = image.resize(
                    size=(width, height), resample=Image.LANCZOS)
            image = np.array(image)
        if image.ndim == 2:
            image = image[:, :, np.newaxis]  # HW => HWC
        image = image.transpose(2, 0, 1)  # HWC => CHW
        return image.astype(np.float32)

    def sample_motion_clip(self, index):
        clip_start_frame_list = self.all_clip_start_frame_list[index]
        clip_end_frame_list = self.all_clip_end_frame_list[index]
        clip_action_label_list = self.all_clip_action_label_list[index]

        num_clip = len(clip_start_frame_list)

        clip_index = random.randint(0, num_clip - 1)

        action_label_list = clip_action_label_list[clip_index]

        clip_idx = list(
            range(clip_start_frame_list[clip_index],
                  clip_end_frame_list[clip_index] + 1))

        segm = len(clip_idx) // self.fixed_video_len

        segm_dist = []
        for i in range(self.fixed_video_len - 1):
            segm_dist.append(segm)

        for i in range(
                min(len(clip_idx) - sum(segm_dist) - 2,
                    self.fixed_video_len - 1)):
            segm_dist[i] += 1

        frame_idx_list = []
        frame_idx_list.append(clip_start_frame_list[clip_index])

        for i in range(len(segm_dist) - 1):
            frame_idx_list.append(clip_start_frame_list[clip_index] +
                                  sum(segm_dist[:i + 1]))
        frame_idx_list.append(clip_end_frame_list[clip_index])

        return frame_idx_list, action_label_list

    def sample_random_clip(self, index):
        video_name = self._video_names[index]

        video_len = self.fixed_video_len

        if len(self.moving_dict[video_name]) == 0:
            start_frame = random.randint(
                30, self.frame_num_list[index] - 1 - video_len)
        else:
            start_frame = random.choice(self.moving_dict[video_name])

            while ((start_frame + video_len) >
                   (self.frame_num_list[index] - 1)):
                start_frame = random.choice(self.moving_dict[video_name])

        frame_idx_list = []
        for frame_idx in range(video_len):
            video_frame_idx = start_frame + frame_idx
            frame_idx_list.append(video_frame_idx)

        action_label_list = 22

        return frame_idx_list, action_label_list

    def generate_caption(self, label):
        caption = random.choice(self.overall_caption_templates[label])
        replacing_word = random.choice(
            self.overall_caption_templates["gender"])
        caption = caption.replace('', replacing_word)
        caption = proper_capitalize(caption)

        return caption

    def __getitem__(self, index):
        video_name = self._video_names[index]

        if np.random.uniform(low=0.0, high=1.0) < self.interpolation_rate:
            frame_idx_list, action_label_list = self.sample_random_clip(index)
            interpolation_mode = True
        else:
            frame_idx_list, action_label_list = self.sample_motion_clip(index)
            interpolation_mode = False

        frames = []
        for frame_idx in frame_idx_list:
            img_path = f'{self._video_dir}/{video_name}/{frame_idx:03d}.png'
            frames.append(self._load_raw_image(img_path))

        frames = np.stack(frames, axis=0)

        exemplar_img = self._load_raw_image(
            f'{self._video_dir}/{video_name}/000.png')

        frames = frames / 127.5 - 1
        exemplar_img = exemplar_img / 127.5 - 1

        frames = torch.from_numpy(frames)
        exemplar_img = torch.from_numpy(exemplar_img)

        if action_label_list == 22:
            text_description = 'empty'
        else:
            text_description = self.generate_caption(str(action_label_list))

        return_dict = {
            'video_name': video_name,
            'video_frames': frames,
            'video_len': self.fixed_video_len,
            'exemplar_img': exemplar_img,
            'action_labels': action_label_list,
            'text': text_description,
            'interpolation_mode': interpolation_mode
        }

        return return_dict

    def __len__(self):
        return len(self._video_names)

--------------------------------------------------------------------------------
/data/prefetch_dataloader.py:
--------------------------------------------------------------------------------
import queue as Queue
import threading

import torch
from torch.utils.data import DataLoader


class PrefetchGenerator(threading.Thread):
    """A general prefetch generator.

    Ref:
    https://stackoverflow.com/questions/7323664/python-generator-pre-fetch

    Args:
        generator: Python generator.
        num_prefetch_queue (int): Size of the prefetch queue.
    """

    def __init__(self, generator, num_prefetch_queue):
        threading.Thread.__init__(self)
        self.queue = Queue.Queue(num_prefetch_queue)
        self.generator = generator
        self.daemon = True
        self.start()

    def run(self):
        for item in self.generator:
            self.queue.put(item)
        self.queue.put(None)

    def __next__(self):
        next_item = self.queue.get()
        if next_item is None:
            raise StopIteration
        return next_item

    def __iter__(self):
        return self


class PrefetchDataLoader(DataLoader):
    """Prefetch version of dataloader.

    Ref:
    https://github.com/IgorSusmelj/pytorch-styleguide/issues/5#

    TODO:
    Need to test on single gpu and ddp (multi-gpu). There is a known issue in
    ddp.

    Args:
        num_prefetch_queue (int): Size of the prefetch queue.
        kwargs (dict): Other arguments for dataloader.
    """

    def __init__(self, num_prefetch_queue, **kwargs):
        self.num_prefetch_queue = num_prefetch_queue
        super(PrefetchDataLoader, self).__init__(**kwargs)

    def __iter__(self):
        return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue)


class CPUPrefetcher():
    """CPU prefetcher.

    Args:
        loader: Dataloader.
    """

    def __init__(self, loader):
        self.ori_loader = loader
        self.loader = iter(loader)

    def next(self):
        try:
            return next(self.loader)
        except StopIteration:
            return None

    def reset(self):
        self.loader = iter(self.ori_loader)


class CUDAPrefetcher():
    """CUDA prefetcher.

    Ref:
    https://github.com/NVIDIA/apex/issues/304#

    It may consume more GPU memory.

    Args:
        loader: Dataloader.
        opt (dict): Options.
    """

    def __init__(self, loader, opt):
        self.ori_loader = loader
        self.loader = iter(loader)
        self.opt = opt
        self.stream = torch.cuda.Stream()
        self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
        self.preload()

    def preload(self):
        try:
            self.batch = next(self.loader)  # self.batch is a dict
        except StopIteration:
            self.batch = None
            return None
        # put tensors to gpu
        with torch.cuda.stream(self.stream):
            for k, v in self.batch.items():
                if torch.is_tensor(v):
                    self.batch[k] = self.batch[k].to(
                        device=self.device, non_blocking=True)

    def next(self):
        torch.cuda.current_stream().wait_stream(self.stream)
        batch = self.batch
        self.preload()
        return batch

    def reset(self):
        self.loader = iter(self.ori_loader)
        self.preload()

--------------------------------------------------------------------------------
/data/sample_identity_dataset.py:
--------------------------------------------------------------------------------
import json
import random

import numpy as np
import torch
import torch.utils.data as data
from PIL import Image


class SampleIdentityDataset(data.Dataset):

    def __init__(self, opt):
        self._video_dir = opt['video_dir']
        self.downsample_factor = opt['downsample_factor']

        self._video_names = []
        self._frame_nums = []

        self.xflip = opt['xflip']

        video_names_list = open(opt['data_name_txt'], 'r').readlines()

        for row in video_names_list:
            video_name = row.split()[0]
            self._video_names.append(video_name)

        with open(opt['text_file']) as json_file:
            self.text_descriptions = json.load(json_file)

    def _load_raw_image(self, img_path):
        with open(img_path, 'rb') as f:
            image = Image.open(f)
            image.load()
        if self.downsample_factor != 1:
            width, height = image.size
            width = width // self.downsample_factor
            height = height // self.downsample_factor
            image = image.resize(
                size=(width, height), resample=Image.LANCZOS)

        return image

    def __getitem__(self, index):
        video_name = self._video_names[index]

        identity_image = self._load_raw_image(
            f'{self._video_dir}/{video_name}/000.png')
        identity_image = np.array(identity_image).transpose(2, 0, 1).astype(
            np.float32)
        if self.xflip and random.random() > 0.5:
            identity_image = identity_image[:, :, ::-1].copy()  # [C, H, W]

        identity_image = identity_image / 127.5 - 1

        identity_image = torch.from_numpy(identity_image)

        if len(self.text_descriptions[video_name]) == 1:
            text_description = self.text_descriptions[video_name]
        else:
            text_description = random.choice(
                self.text_descriptions[video_name])

        return_dict = {
            'img_name': video_name,
            'image': identity_image,
            'text': text_description
        }

        return return_dict

    def __len__(self):
        return len(self._video_names)

--------------------------------------------------------------------------------
/env.yaml:
--------------------------------------------------------------------------------
name: text2performer
channels:
  - pytorch
  - defaults
dependencies:
  - _libgcc_mutex=0.1=main
  - blas=1.0=mkl
  - ca-certificates=2023.08.22=h06a4308_0
  - certifi=2022.12.7=py37h06a4308_0
  - cudatoolkit=10.1.243=h6bb024c_0
  - freetype=2.11.0=h70c0345_0
  - giflib=5.2.1=h7b6447c_0
  - intel-openmp=2021.4.0=h06a4308_3561
  - jpeg=9b=h024ee3a_2
  - lcms2=2.12=h3be6417_0
  - ld_impl_linux-64=2.38=h1181459_1
  - libffi=3.3=he6710b0_2
  - libgcc-ng=9.1.0=hdf63c60_0
  - libpng=1.6.37=hbc83047_0
  - libstdcxx-ng=9.1.0=hdf63c60_0
  - libtiff=4.1.0=h2733197_1
  - libuv=1.40.0=h7b6447c_0
  - libwebp=1.2.0=h89dd481_0
  - lz4-c=1.9.3=h295c915_1
  - mkl=2021.4.0=h06a4308_640
  - mkl-service=2.4.0=py37h7f8727e_0
  - mkl_fft=1.3.1=py37hd3c417c_0
  - mkl_random=1.2.2=py37h51133e4_0
  - ncurses=6.3=h7f8727e_2
  - ninja=1.10.2=h06a4308_5
  - ninja-base=1.10.2=hd09550d_5
  - numpy=1.21.5=py37he7a7128_2
  - numpy-base=1.21.5=py37hf524024_2
  - openssl=1.1.1w=h7f8727e_0
  - pillow=9.0.1=py37h22f2fdc_0
  - pip=22.3.1=py37h06a4308_0
  - python=3.7.11=h12debd9_0
  - pytorch=1.7.1=py3.7_cuda10.1.243_cudnn7.6.3_0
  - readline=8.1.2=h7f8727e_1
  - setuptools=65.6.3=py37h06a4308_0
  - six=1.16.0=pyhd3eb1b0_1
  - sqlite=3.38.5=hc218d9a_0
  - tk=8.6.12=h1ccaba5_0
  - torchvision=0.8.2=py37_cu101
  - typing_extensions=4.3.0=py37h06a4308_0
  - wheel=0.38.4=py37h06a4308_0
  - xz=5.2.5=h7f8727e_1
  - zlib=1.2.12=h7f8727e_2
  - zstd=1.4.9=haebb681_0
  - pip:
    - absl-py==2.0.0
    - anyio==3.7.1
    - argon2-cffi==23.1.0
    - argon2-cffi-bindings==21.2.0
    - attrs==23.1.0
    - babel==2.12.1
    - backcall==0.2.0
    - beautifulsoup4==4.12.2
    - bleach==6.0.0
    - cachetools==4.2.4
    - cffi==1.15.1
    - charset-normalizer==3.2.0
    - click==8.1.7
    - debugpy==1.7.0
    - decorator==5.1.1
    - defusedxml==0.7.1
    - einops==0.6.1
    - entrypoints==0.4
    - exceptiongroup==1.1.3
    - fastjsonschema==2.18.0
    - filelock==3.12.2
    - fsspec==2023.1.0
    - google-auth==1.35.0
    - google-auth-oauthlib==0.4.6
    - grpcio==1.58.0
    - huggingface-hub==0.16.4
    - idna==3.4
    - importlib-metadata==6.7.0
    - importlib-resources==5.12.0
    - ipykernel==6.16.2
    - ipython==7.34.0
    - ipython-genutils==0.2.0
    - jedi==0.19.0
    - jinja2==3.1.2
    - joblib==1.3.2
    - json5==0.9.14
    - jsonschema==4.17.3
    - jupyter-client==7.4.9
    - jupyter-core==4.12.0
    - jupyter-server==1.24.0
    - jupyterlab==3.3.2
    - jupyterlab-pygments==0.2.2
    - jupyterlab-server==2.24.0
    - lpips==0.1.4
    - markdown==3.4.4
    - markupsafe==2.1.3
    - matplotlib-inline==0.1.6
    - mistune==3.0.1
    - nbclassic==0.5.6
    - nbclient==0.7.4
    - nbconvert==7.6.0
    - nbformat==5.8.0
    - nest-asyncio==1.5.8
    - nltk==3.8.1
    - notebook-shim==0.2.3
    - oauthlib==3.2.2
    - opencv-python==4.5.5.62
    - packaging==23.1
    - pandocfilters==1.5.0
    - parso==0.8.3
    - pexpect==4.8.0
    - pickleshare==0.7.5
    - pkgutil-resolve-name==1.3.10
    - prometheus-client==0.17.1
    - prompt-toolkit==3.0.39
    - protobuf==3.20.3
    - psutil==5.9.5
    - ptyprocess==0.7.0
    - pyasn1==0.5.0
    - pyasn1-modules==0.3.0
    - pycparser==2.21
    - pygments==2.16.1
    - pyrsistent==0.19.3
    - python-dateutil==2.8.2
    - pytz==2023.3.post1
    - pyyaml==6.0.1
    - pyzmq==25.1.1
    - regex==2023.8.8
    - requests==2.31.0
    - requests-oauthlib==1.3.1
    - safetensors==0.3.3
    - scikit-learn==1.0.2
    - scipy==1.7.3
    - send2trash==1.8.2
    - sentence-transformers==2.2.2
    - sentencepiece==0.1.99
    - sniffio==1.3.0
    - soupsieve==2.4.1
    - tensorboard==2.5.0
    - tensorboard-data-server==0.6.1
    - tensorboard-plugin-wit==1.8.1
    - terminado==0.17.1
    - threadpoolctl==3.1.0
    - tinycss2==1.2.1
    - tokenizers==0.13.3
    - tornado==6.2
    - tqdm==4.66.1
    - traitlets==5.9.0
    - transformers==4.30.2
    - urllib3==2.0.5
    - wcwidth==0.2.7
    - webencodings==0.5.1
    - websocket-client==1.6.1
    - werkzeug==2.2.3
    - zipp==3.15.0

--------------------------------------------------------------------------------
/img/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yumingj/Text2Performer/433489aae7bdd6fd868a2e272d98bed7c7b642a5/img/teaser.png
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
import glob
import importlib
import logging
import os.path as osp

# automatically scan and import model modules
# scan all the files under the 'models' folder and collect files ending with
# '_model.py'
model_folder = osp.dirname(osp.abspath(__file__))
model_filenames = [
    osp.splitext(osp.basename(v))[0]
    for v in glob.glob(f'{model_folder}/*_model.py')
]
# import all the model modules
_model_modules = [
    importlib.import_module(f'models.{file_name}')
    for file_name in model_filenames
]


def create_model(opt):
    """Create model.

    Args:
        opt (dict): Configuration. It contains:
            model_type (str): Model type.
27 | """ 28 | model_type = opt['model_type'] 29 | 30 | # dynamically instantiation 31 | for module in _model_modules: 32 | model_cls = getattr(module, model_type, None) 33 | if model_cls is not None: 34 | break 35 | if model_cls is None: 36 | raise ValueError(f'Model {model_type} is not found.') 37 | 38 | model = model_cls(opt) 39 | 40 | logger = logging.getLogger('base') 41 | logger.info(f'Model [{model.__class__.__name__}] is created.') 42 | return model 43 | -------------------------------------------------------------------------------- /models/app_transformer_model.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | from collections import OrderedDict 4 | 5 | import numpy as np 6 | import torch 7 | import torch.distributions as dists 8 | import torch.nn.functional as F 9 | from sentence_transformers import SentenceTransformer 10 | from torchvision.utils import save_image 11 | 12 | from models.archs.transformer_arch import TransformerLanguage 13 | from models.archs.vqgan_arch import ( 14 | DecoderUpOthersDoubleIdentity, 15 | EncoderDecomposeBaseDownOthersDoubleIdentity, VectorQuantizer) 16 | 17 | logger = logging.getLogger('base') 18 | 19 | 20 | class AppTransformerModel(): 21 | """Texture-Aware Diffusion based Transformer model. 
22 | """ 23 | 24 | def __init__(self, opt): 25 | self.opt = opt 26 | self.device = torch.device('cuda') 27 | self.is_train = opt['is_train'] 28 | 29 | # VQVAE for image 30 | self.img_encoder = EncoderDecomposeBaseDownOthersDoubleIdentity( 31 | ch=opt['img_ch'], 32 | num_res_blocks=opt['img_num_res_blocks'], 33 | attn_resolutions=opt['img_attn_resolutions'], 34 | ch_mult=opt['img_ch_mult'], 35 | other_ch_mult=opt['img_other_ch_mult'], 36 | in_channels=opt['img_in_channels'], 37 | resolution=opt['img_resolution'], 38 | z_channels=opt['img_z_channels'], 39 | double_z=opt['img_double_z'], 40 | dropout=opt['img_dropout']).to(self.device) 41 | self.img_decoder = DecoderUpOthersDoubleIdentity( 42 | in_channels=opt['img_in_channels'], 43 | resolution=opt['img_resolution'], 44 | z_channels=opt['img_z_channels'], 45 | ch=opt['img_ch'], 46 | out_ch=opt['img_out_ch'], 47 | num_res_blocks=opt['img_num_res_blocks'], 48 | attn_resolutions=opt['img_attn_resolutions'], 49 | ch_mult=opt['img_ch_mult'], 50 | other_ch_mult=opt['img_other_ch_mult'], 51 | dropout=opt['img_dropout'], 52 | resamp_with_conv=True, 53 | give_pre_end=False).to(self.device) 54 | self.quantize_identity = VectorQuantizer( 55 | opt['img_n_embed'], opt['img_embed_dim'], 56 | beta=0.25).to(self.device) 57 | self.quant_conv_identity = torch.nn.Conv2d(opt["img_z_channels"], 58 | opt['img_embed_dim'], 59 | 1).to(self.device) 60 | self.post_quant_conv_identity = torch.nn.Conv2d( 61 | opt['img_embed_dim'], opt["img_z_channels"], 1).to(self.device) 62 | 63 | self.quantize_others = VectorQuantizer( 64 | opt['img_n_embed'], opt['img_embed_dim'] // 2, 65 | beta=0.25).to(self.device) 66 | self.quant_conv_others = torch.nn.Conv2d(opt["img_z_channels"] // 2, 67 | opt['img_embed_dim'] // 2, 68 | 1).to(self.device) 69 | self.post_quant_conv_others = torch.nn.Conv2d( 70 | opt['img_embed_dim'] // 2, opt["img_z_channels"] // 2, 71 | 1).to(self.device) 72 | self.load_pretrained_image_vae() 73 | 74 | # define sampler 75 | 
        self._denoise_fn = TransformerLanguage(
            codebook_size=opt['codebook_size'],
            bert_n_emb=opt['bert_n_emb'],
            bert_n_layers=opt['bert_n_layers'],
            bert_n_head=opt['bert_n_head'],
            block_size=opt['block_size'] * 2,
            embd_pdrop=opt['embd_pdrop'],
            resid_pdrop=opt['resid_pdrop'],
            attn_pdrop=opt['attn_pdrop']).to(self.device)

        self.num_classes = opt['codebook_size']
        self.shape = tuple(opt['latent_shape'])
        self.num_timesteps = 1000

        self.mask_id = opt['codebook_size']
        self.loss_type = opt['loss_type']
        self.mask_schedule = opt['mask_schedule']

        self.sample_steps = opt['sample_steps']

        self.text_seq_len = opt['text_seq_len']

        self.init_training_settings()

        self.get_fixed_language_model()

    def load_pretrained_image_vae(self):
        # load the pretrained image VQGAN weights
        img_ae_checkpoint = torch.load(self.opt['img_ae_path'])
        self.img_encoder.load_state_dict(
            img_ae_checkpoint['encoder'], strict=True)
        self.img_decoder.load_state_dict(
            img_ae_checkpoint['decoder'], strict=True)
        self.quantize_identity.load_state_dict(
            img_ae_checkpoint['quantize_identity'], strict=True)
        self.quant_conv_identity.load_state_dict(
            img_ae_checkpoint['quant_conv_identity'], strict=True)
        self.post_quant_conv_identity.load_state_dict(
            img_ae_checkpoint['post_quant_conv_identity'], strict=True)
        self.quantize_others.load_state_dict(
            img_ae_checkpoint['quantize_others'], strict=True)
        self.quant_conv_others.load_state_dict(
            img_ae_checkpoint['quant_conv_others'], strict=True)
        self.post_quant_conv_others.load_state_dict(
            img_ae_checkpoint['post_quant_conv_others'], strict=True)
        self.img_encoder.eval()
        self.img_decoder.eval()
        self.quantize_identity.eval()
        self.quant_conv_identity.eval()
        self.post_quant_conv_identity.eval()
        self.quantize_others.eval()
        self.quant_conv_others.eval()
        self.post_quant_conv_others.eval()

    def init_training_settings(self):
        optim_params = []
        for v in self._denoise_fn.parameters():
            if v.requires_grad:
                optim_params.append(v)
        # set up optimizer
        self.optimizer = torch.optim.Adam(
            optim_params,
            self.opt['lr'],
            weight_decay=self.opt['weight_decay'])
        self.log_dict = OrderedDict()

    @torch.no_grad()
    def get_quantized_img(self, image):
        h_identity, _ = self.img_encoder(image)
        h_identity = self.quant_conv_identity(h_identity)
        _, _, [_, _, identity_tokens] = self.quantize_identity(h_identity)

        _, h_frame = self.img_encoder(image)
        h_frame = self.quant_conv_others(h_frame)
        _, _, [_, _, pose_tokens] = self.quantize_others(h_frame)

        # reshape the tokens
        b = image.size(0)
        identity_tokens = identity_tokens.view(b, -1)
        pose_tokens = pose_tokens.view(b, -1)

        return identity_tokens, pose_tokens

    @torch.no_grad()
    def decode(self, quant_list):
        quant_identity = self.post_quant_conv_identity(quant_list[0])
        quant_frame = self.post_quant_conv_others(quant_list[1])
        dec = self.img_decoder(quant_identity, quant_frame)
        return dec

    @torch.no_grad()
    def decode_image_indices(self, quant_identity, quant_frame):
        quant_identity = self.quantize_identity.get_codebook_entry(
            quant_identity, (quant_identity.size(0), self.shape[0],
                             self.shape[1], self.opt["img_z_channels"]))
        quant_frame = self.quantize_others.get_codebook_entry(
            quant_frame, (quant_frame.size(0), self.shape[0] // 4,
                          self.shape[1] // 4, self.opt["img_z_channels"] // 2))
        dec = self.decode([quant_identity, quant_frame])

        return dec

    def sample_time(self, b, device, method='uniform'):
        if method == 'importance':
            if not (self.Lt_count > 10).all():
                return self.sample_time(b, device, method='uniform')

            Lt_sqrt = torch.sqrt(self.Lt_history + 1e-10) + 0.0001
            Lt_sqrt[0] = Lt_sqrt[1]  # Overwrite decoder term with L1.
            pt_all = Lt_sqrt / Lt_sqrt.sum()

            t = torch.multinomial(pt_all, num_samples=b, replacement=True)

            pt = pt_all.gather(dim=0, index=t)

            return t, pt

        elif method == 'uniform':
            t = torch.randint(
                1, self.num_timesteps + 1, (b, ), device=device).long()
            pt = torch.ones_like(t).float() / self.num_timesteps
            return t, pt

        else:
            raise ValueError

    def q_sample(self, x_0, x_0_gt, t):
        # samples q(x_t | x_0)
        # randomly set token to mask with probability t/T
        # x_t, x_0_ignore = x_0.clone(), x_0.clone()
        x_t = x_0.clone()

        mask = torch.rand_like(x_t.float()) < (
            t.float().unsqueeze(-1) / self.num_timesteps)
        x_t[mask] = self.mask_id
        # x_0_ignore[torch.bitwise_not(mask)] = -1

        # for every gt token list, we also need to do the mask
        x_0_gt_ignore = x_0_gt.clone()
        x_0_gt_ignore[torch.bitwise_not(mask)] = -1

        return x_t, x_0_gt_ignore, mask

    def _train_loss(self, x_identity_0, x_pose_0):
        b, device = x_identity_0.size(0), x_identity_0.device

        # choose what time steps to compute loss at
        t, pt = self.sample_time(b, device, 'uniform')

        # make x noisy and denoise
        if self.mask_schedule == 'random':
            x_identity_t, x_identity_0_gt_ignore, mask = self.q_sample(
                x_0=x_identity_0, x_0_gt=x_identity_0, t=t)
            x_pose_t, x_pose_0_gt_ignore, mask = self.q_sample(
                x_0=x_pose_0, x_0_gt=x_pose_0, t=t)
        else:
            raise NotImplementedError

        # sample p(x_0 | x_t)
        x_identity_0_hat_logits, x_pose_0_hat_logits = self._denoise_fn(
            self.text_embedding, x_identity_t, x_pose_t, t=t)
        x_identity_0_hat_logits = x_identity_0_hat_logits[
            :, 1:1 + self.shape[0] * self.shape[1], :]
        x_pose_0_hat_logits = x_pose_0_hat_logits[
            :, 1 + self.shape[0] * self.shape[1]:]

        # Always compute ELBO for comparison purposes
        cross_entropy_loss = 0

        cross_entropy_loss = F.cross_entropy(
            x_identity_0_hat_logits.permute(0, 2, 1),
            x_identity_0_gt_ignore,
            ignore_index=-1,
            reduction='none').sum(1) + F.cross_entropy(
                x_pose_0_hat_logits.permute(0, 2, 1),
                x_pose_0_gt_ignore,
                ignore_index=-1,
                reduction='none').sum(1)
        vb_loss = cross_entropy_loss / t
        vb_loss = vb_loss / pt
        vb_loss = vb_loss / (math.log(2) * x_identity_0.shape[1:].numel())
        if self.loss_type == 'elbo':
            loss = vb_loss
        elif self.loss_type == 'mlm':
            denom = mask.float().sum(1)
            denom[denom == 0] = 1  # prevent divide by 0 errors.
            loss = cross_entropy_loss / denom
        elif self.loss_type == 'reweighted_elbo':
            weight = (1 - (t / self.num_timesteps))
            loss = weight * cross_entropy_loss
            loss = loss / (math.log(2) * x_identity_0.shape[1:].numel())
        else:
            raise ValueError

        return loss.mean(), vb_loss.mean()

    def feed_data(self, data):
        self.image = data['image'].to(self.device)
        self.text = data['text']  # .to(self.device)
        self.get_text_embedding()

        self.identity_tokens, self.pose_tokens = self.get_quantized_img(
            self.image)

    def get_fixed_language_model(self):
        self.language_model = SentenceTransformer('all-MiniLM-L6-v2')
        self.text_feature_dim = 384

    @torch.no_grad()
    def get_text_embedding(self):
        self.text_embedding = self.language_model.encode(
            self.text, show_progress_bar=False)
        self.text_embedding = torch.Tensor(self.text_embedding).to(
            self.device).unsqueeze(1)

    def optimize_parameters(self):
        self._denoise_fn.train()

        loss, vb_loss = self._train_loss(self.identity_tokens,
                                         self.pose_tokens)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        self.log_dict['loss'] = loss
        self.log_dict['vb_loss'] = vb_loss

        self._denoise_fn.eval()

    def sample_fn(self, temp=1.0, sample_steps=None):
        self._denoise_fn.eval()

        b, device = self.image.size(0), 'cuda'
        x_identity_t = torch.ones(
            (b, np.prod(self.shape)), device=device).long() * self.mask_id
        x_pose_t = torch.ones((b, np.prod(self.shape) // 16),
                              device=device).long() * self.mask_id
        unmasked_identity = torch.zeros_like(
            x_identity_t, device=device).bool()
        unmasked_pose = torch.zeros_like(x_pose_t, device=device).bool()
        sample_steps = list(range(1, sample_steps + 1))

        for t in reversed(sample_steps):
            print(f'Sample timestep {t:4d}', end='\r')
            t = torch.full((b, ), t, device=device, dtype=torch.long)

            # where to unmask
            changes_identity = torch.rand(
                x_identity_t.shape,
                device=device) < 1 / t.float().unsqueeze(-1)
            # don't unmask somewhere already unmasked
            changes_identity = torch.bitwise_xor(
                changes_identity,
                torch.bitwise_and(changes_identity, unmasked_identity))
            # update mask with changes
            unmasked_identity = torch.bitwise_or(unmasked_identity,
                                                 changes_identity)

            changes_pose = torch.rand(
                x_pose_t.shape, device=device) < 1 / t.float().unsqueeze(-1)
            # don't unmask somewhere already unmasked
            changes_pose = torch.bitwise_xor(
                changes_pose, torch.bitwise_and(changes_pose, unmasked_pose))
            # update mask with changes
            unmasked_pose = torch.bitwise_or(unmasked_pose, changes_pose)

            x_identity_0_hat_logits, x_pose_0_hat_logits = self._denoise_fn(
                self.text_embedding, x_identity_t, x_pose_t, t=t)

            x_identity_0_hat_logits = x_identity_0_hat_logits[
                :, 1:1 + self.shape[0] * self.shape[1], :]
            x_pose_0_hat_logits = x_pose_0_hat_logits[
                :, 1 + self.shape[0] * self.shape[1]:]

            # scale by temperature
            x_identity_0_hat_logits = x_identity_0_hat_logits / temp
            x_identity_0_dist = dists.Categorical(
                logits=x_identity_0_hat_logits)
            x_identity_0_hat = x_identity_0_dist.sample().long()

            x_pose_0_hat_logits = x_pose_0_hat_logits / temp
            x_pose_0_dist = dists.Categorical(logits=x_pose_0_hat_logits)
            x_pose_0_hat = x_pose_0_dist.sample().long()

            # x_t is fed back into the transformer as input, so its token
            # indices must stay within the same contiguous index range
            x_identity_t[changes_identity] = x_identity_0_hat[changes_identity]
            x_pose_t[changes_pose] = x_pose_0_hat[changes_pose]

        self._denoise_fn.train()

        return x_identity_t, x_pose_t

    def get_vis(self, image, gt_quant_identity, gt_quant_frame, quant_identity,
                quant_frame, save_path):
        # original image
        ori_img = self.decode_image_indices(gt_quant_identity, gt_quant_frame)
        # pred image
        pred_img = self.decode_image_indices(quant_identity, quant_frame)
        img_cat = torch.cat([
            image,
            ori_img,
            pred_img,
        ], dim=3).detach()
        img_cat = ((img_cat + 1) / 2)
        img_cat = img_cat.clamp_(0, 1)
        save_image(img_cat, save_path, nrow=1, padding=4)

    def inference(self, data_loader, save_dir):
        self._denoise_fn.eval()

        for _, data in enumerate(data_loader):
            img_name = data['img_name']
            self.feed_data(data)
            b = self.image.size(0)
            with torch.no_grad():
                x_identity_t, x_pose_t = self.sample_fn(
                    temp=1, sample_steps=self.sample_steps)
            for idx in range(b):
                self.get_vis(self.image[idx:idx + 1],
                             self.identity_tokens[idx:idx + 1],
                             self.pose_tokens[idx:idx + 1],
                             x_identity_t[idx:idx + 1],
                             x_pose_t[idx:idx + 1],
                             f'{save_dir}/{img_name[idx]}.png')

        self._denoise_fn.train()

    def sample_appearance(self, text, save_path, shape=[256, 128]):
        self._denoise_fn.eval()

        self.text = text
        self.image = torch.zeros([1, 3, shape[0], shape[1]]).to(self.device)
        self.get_text_embedding()

        with torch.no_grad():
            x_identity_t, x_pose_t = self.sample_fn(
                temp=1, sample_steps=self.sample_steps)

        self.get_vis_generated_only(x_identity_t, x_pose_t, save_path)

        quant_identity = self.quantize_identity.get_codebook_entry(
            x_identity_t, (x_identity_t.size(0), self.shape[0], self.shape[1],
                           self.opt["img_z_channels"]))
        quant_frame = self.quantize_others.get_codebook_entry(
            x_pose_t,
            (x_pose_t.size(0), self.shape[0] // 4, self.shape[1] // 4,
             self.opt["img_z_channels"] // 2)).view(
                 x_pose_t.size(0), self.opt["img_z_channels"] // 2,
                 -1).permute(0, 2, 1)

        self._denoise_fn.train()

        return quant_identity, quant_frame

    def get_vis_generated_only(self, quant_identity, quant_frame, save_path):
        # pred image
        pred_img = self.decode_image_indices(quant_identity, quant_frame)
        img_cat = ((pred_img.detach() + 1) / 2)
        img_cat = img_cat.clamp_(0, 1)
        save_image(img_cat, save_path, nrow=1, padding=4)

    def get_current_log(self):
        return self.log_dict

    def update_learning_rate(self, epoch, iters=None):
        """Update the learning rate according to ``self.opt['lr_decay']``.

        Args:
            epoch (int): Current epoch; used by the epoch-based schedules.
            iters (int, optional): Current iteration; only the 'warm_up'
                schedule uses it. Default: None.

        Returns:
            float: The learning rate written to every parameter group.
        """
        lr = self.optimizer.param_groups[0]['lr']

        if self.opt['lr_decay'] == 'step':
            lr = self.opt['lr'] * (
                self.opt['gamma']**(epoch // self.opt['step']))
        elif self.opt['lr_decay'] == 'cos':
            lr = self.opt['lr'] * (
                1 + math.cos(math.pi * epoch / self.opt['num_epochs'])) / 2
        elif self.opt['lr_decay'] == 'linear':
            lr = self.opt['lr'] * (1 - epoch / self.opt['num_epochs'])
        elif self.opt['lr_decay'] == 'linear2exp':
            if epoch < self.opt['turning_point'] + 1:
                # linear decay that reaches roughly 5% of the base lr (a 95%
                # reduction) at the turning point, since 1 / 0.95 is about 1.0526
                lr = self.opt['lr'] * (
                    1 - epoch / int(self.opt['turning_point'] * 1.0526))
            else:
                lr *= self.opt['gamma']
        elif self.opt['lr_decay'] == 'schedule':
            if epoch in self.opt['schedule']:
                lr *= self.opt['gamma']
        elif self.opt['lr_decay'] == 'warm_up':
            if iters <= self.opt['warmup_iters']:
                lr = self.opt['lr'] * float(iters) / self.opt['warmup_iters']
            else:
                lr = self.opt['lr']
        else:
            raise ValueError('Unknown lr mode {}'.format(self.opt['lr_decay']))
        # set learning rate
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

        return lr

    def save_network(self, net, save_path):
        """Save a network's state dict.

        Args:
            net (nn.Module): Network to be saved.
            save_path (str): Path of the file the state dict is written to.
        """
        state_dict = net.state_dict()
        torch.save(state_dict, save_path)

    def load_network(self):
        checkpoint = torch.load(self.opt['pretrained_sampler'])
        self._denoise_fn.load_state_dict(checkpoint, strict=True)
        self._denoise_fn.eval()

--------------------------------------------------------------------------------
/models/archs/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yumingj/Text2Performer/433489aae7bdd6fd868a2e272d98bed7c7b642a5/models/archs/__init__.py
--------------------------------------------------------------------------------
/models/archs/dalle_transformer_arch.py:
--------------------------------------------------------------------------------
import math
from functools import wraps

import torch
import torch.nn.functional as F
from einops import rearrange
from torch import einsum, nn

from models.archs.einops_exts import repeat_many
from models.archs.rotary_embedding_torch import RotaryEmbedding


# helper functions
def exists(val):
    return val is not None


def l2norm(t):
    return F.normalize(t, dim=-1)


# relative positional bias for causal transformer
class RelPosBias(nn.Module):

    def __init__(
        self,
        heads=8,
        num_buckets=32,
        max_distance=128,
    ):
        super().__init__()
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        self.relative_attention_bias = nn.Embedding(num_buckets, heads)

    @staticmethod
    def _relative_position_bucket(relative_position,
                                  num_buckets=32,
                                  max_distance=128):
        n = -relative_position
        n = torch.max(n, torch.zeros_like(n))

        max_exact = num_buckets // 2
        is_small = n < max_exact

        val_if_large = max_exact + (torch.log(n.float() / max_exact) /
                                    math.log(max_distance / max_exact) *
                                    (num_buckets - max_exact)).long()
        val_if_large = torch.min(
            val_if_large, torch.full_like(val_if_large, num_buckets - 1))
        return torch.where(is_small, n, val_if_large)

    def forward(self, i, j, *, device):
        q_pos = torch.arange(i, dtype=torch.long, device=device)
        k_pos = torch.arange(j, dtype=torch.long, device=device)
        rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1')
        rp_bucket = self._relative_position_bucket(
            rel_pos,
            num_buckets=self.num_buckets,
            max_distance=self.max_distance)
        values = self.relative_attention_bias(rp_bucket)
        return rearrange(values, 'i j h -> h i j')


class LayerNorm(nn.Module):

    def __init__(self, dim, eps=1e-5, fp16_eps=1e-3, stable=False):
        super().__init__()
        self.eps = eps
        self.fp16_eps = fp16_eps
        self.stable = stable
        self.g = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        eps = self.eps if x.dtype == torch.float32 else self.fp16_eps

        if self.stable:
            x = x / x.amax(dim=-1, keepdim=True).detach()

        var = torch.var(x, dim=-1, unbiased=False, keepdim=True)
        mean = torch.mean(x, dim=-1, keepdim=True)
        return (x - mean) * (var + eps).rsqrt() * self.g


# attention
class Attention(nn.Module):

    def __init__(self,
                 dim,
                 *,
                 dim_head=64,
                 heads=8,
                 dropout=0.,
                 causal=False,
                 rotary_emb=None,
                 cosine_sim=True,
                 cosine_sim_scale=16):
        super().__init__()
        self.scale = cosine_sim_scale if cosine_sim else (dim_head**-0.5)
        self.cosine_sim = cosine_sim

        self.heads = heads
        inner_dim = dim_head * heads

        self.causal = causal
        self.norm = LayerNorm(dim)
        self.dropout = nn.Dropout(dropout)

        self.null_kv = nn.Parameter(torch.randn(2, dim_head))
        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, dim_head * 2,
                               bias=False)

        self.rotary_emb = rotary_emb

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim, bias=False), LayerNorm(dim))

    def forward(self, x, mask=None, attn_bias=None):
        b, n, device = *x.shape[:2], x.device

        x = self.norm(x)
        q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim=-1))

        q = rearrange(q, 'b n (h d) -> b h n d', h=self.heads)
        q = q * self.scale

        # rotary embeddings

        if exists(self.rotary_emb):
            q, k = map(self.rotary_emb.rotate_queries_or_keys, (q, k))

        # add null key / value for classifier free guidance in prior net

        nk, nv = repeat_many(self.null_kv.unbind(dim=-2), 'd -> b 1 d', b=b)
        k = torch.cat((nk, k), dim=-2)
        v = torch.cat((nv, v), dim=-2)

        # whether to use cosine sim

        if self.cosine_sim:
            q, k = map(l2norm, (q, k))

        q, k = map(lambda t: t * math.sqrt(self.scale), (q, k))

        # calculate query / key similarities

        sim = einsum('b h i d, b j d -> b h i j', q, k)

        # relative positional encoding (T5 style)

        if exists(attn_bias):
            sim = sim + attn_bias

        # masking

        max_neg_value = -torch.finfo(sim.dtype).max

        if exists(mask):
            mask = F.pad(mask, (1, 0), value=True)
            mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, max_neg_value)

        if self.causal:
            i, j = sim.shape[-2:]
            causal_mask = torch.ones((i, j), dtype=torch.bool,
                                     device=device).triu(j - i + 1)
            sim = sim.masked_fill(causal_mask, max_neg_value)

        # attention

        attn = sim.softmax(dim=-1, dtype=torch.float32)
        attn = self.dropout(attn)

        # aggregate values

        out = einsum('b h i j, b j d -> b h i d', attn, v)

        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)


# feedforward
class SwiGLU(nn.Module):
    """ used successfully in https://arxiv.org/abs/2204.0231 """

    def forward(self, x):
        x, gate = x.chunk(2, dim=-1)
        return x * F.silu(gate)


def FeedForward(dim, mult=4, dropout=0., post_activation_norm=False):
    """ post-activation norm https://arxiv.org/abs/2110.09456 """

    inner_dim = int(mult * dim)
    return nn.Sequential(
        LayerNorm(dim), nn.Linear(dim, inner_dim * 2, bias=False), SwiGLU(),
        LayerNorm(inner_dim) if post_activation_norm else nn.Identity(),
        nn.Dropout(dropout), nn.Linear(inner_dim, dim, bias=False))


class NonCausalTransformerLanguage(nn.Module):

    def __init__(self,
                 *,
                 dim,
                 depth,
                 dim_head=64,
                 heads=8,
                 ff_mult=4,
                 norm_in=False,
                 norm_out=True,
                 attn_dropout=0.,
                 ff_dropout=0.,
                 final_proj=True,
                 normformer=False,
                 rotary_emb=True):
        super().__init__()
        # from latest BLOOM model and Yandex's YaLM
        self.init_norm = LayerNorm(dim) if norm_in else nn.Identity()

        self.rel_pos_bias = RelPosBias(heads=heads)

        self.text_feature_mapping = nn.Sequential(
            nn.LayerNorm(384),
            nn.Linear(384, 256),
            nn.LayerNorm(256),
            nn.Linear(256, dim),
            nn.LayerNorm(dim),
        )

        rotary_emb = RotaryEmbedding(
            dim=min(32, dim_head)) if rotary_emb else None

        self.mask_emb = nn.Parameter(torch.zeros(1, 1, dim))

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(
                nn.ModuleList([
                    Attention(
                        dim=dim,
                        causal=False,
                        dim_head=dim_head,
                        heads=heads,
                        dropout=attn_dropout,
                        rotary_emb=rotary_emb),
                    FeedForward(
                        dim=dim,
                        mult=ff_mult,
                        dropout=ff_dropout,
                        post_activation_norm=normformer)
                ]))

        # unclear in paper whether they projected after the classic layer norm
        # for the final denoised image embedding, or just had the transformer
        # output it directly: plan on offering both options
        self.norm = LayerNorm(dim, stable=True) if norm_out else nn.Identity()
        self.project_out = nn.Linear(
            dim, dim, bias=False) if final_proj else nn.Identity()

    def forward(self, x, exemplar_frame_embeddings, text_embedding, masks):
        x_masked = x.clone()
        x_masked[masks, :] = self.mask_emb

        x = torch.cat((self.text_feature_mapping(text_embedding),
                       exemplar_frame_embeddings, x_masked),
                      dim=1).clone()

        n, device = x.shape[1], x.device
        x = self.init_norm(x)

        attn_bias = self.rel_pos_bias(n, n + 1, device=device)

        for attn, ff in self.layers:
            x = attn(x, attn_bias=attn_bias) + x
            x = ff(x) + x

        out = self.norm(x)
        return self.project_out(out)
--------------------------------------------------------------------------------
/models/archs/einops_exts.py:
--------------------------------------------------------------------------------
import re
from functools import partial, wraps

from einops import rearrange, reduce, repeat
from torch import nn

# checking shape
# @nils-werner
# https://github.com/arogozhnikov/einops/issues/168#issuecomment-1042933838


def check_shape(tensor, pattern, **kwargs):
    return rearrange(tensor, f"{pattern} -> {pattern}", **kwargs)


# do same einops operations on a list of tensors


def _many(fn):

    @wraps(fn)
    def inner(tensors, pattern, **kwargs):
        return (fn(tensor, pattern, **kwargs) for tensor in tensors)

    return inner


# do einops with unflattening of anonymously named dimensions
# (...flattened) -> ...flattened


def _with_anon_dims(fn):

    @wraps(fn)
    def inner(tensor, pattern, **kwargs):
        regex = r'(\.\.\.[a-zA-Z]+)'
        matches = re.findall(regex, pattern)
        get_anon_dim_name = lambda t: t.lstrip('...')
        dim_prefixes = tuple(map(get_anon_dim_name, set(matches)))

        update_kwargs_dict = dict()

        for prefix in dim_prefixes:
            assert prefix in kwargs, f'dimension list "{prefix}" was not passed in'
            dim_list = kwargs[prefix]
            assert isinstance(
                dim_list, (list, tuple)
            ), f'dimension list "{prefix}" needs to be a tuple or list of dimensions'
            dim_names = list(
                map(lambda ind: f'{prefix}{ind}', range(len(dim_list))))
            update_kwargs_dict[prefix] = dict(zip(dim_names, dim_list))

        def sub_with_anonymous_dims(t):
            dim_name_prefix = get_anon_dim_name(t.groups()[0])
            return ' '.join(update_kwargs_dict[dim_name_prefix].keys())

        pattern_new = re.sub(regex, sub_with_anonymous_dims, pattern)

        for prefix, update_dict in update_kwargs_dict.items():
            del kwargs[prefix]
            kwargs.update(update_dict)

        return fn(tensor, pattern_new, **kwargs)

    return inner


# generate all helper functions

rearrange_many = _many(rearrange)
repeat_many = _many(repeat)
reduce_many = _many(reduce)

rearrange_with_anon_dims = _with_anon_dims(rearrange)
repeat_with_anon_dims = _with_anon_dims(repeat)
reduce_with_anon_dims = _with_anon_dims(reduce)

# for rearranging to and from a pattern


class EinopsToAndFrom(nn.Module):

    def __init__(self, from_einops, to_einops, fn):
        super().__init__()
        self.from_einops = from_einops
        self.to_einops = to_einops
        self.fn = fn

    def forward(self, x, **kwargs):
        shape = x.shape
        reconstitute_kwargs = dict(
            tuple(zip(self.from_einops.split(' '), shape)))
        x = rearrange(x, f'{self.from_einops} -> {self.to_einops}')
        x = self.fn(x, **kwargs)
        x = rearrange(x, f'{self.to_einops} -> {self.from_einops}',
                      **reconstitute_kwargs)
        return x

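`_with_anon_dims` is the least obvious helper in `einops_exts.py`: it rewrites every anonymous group such as `...d` in an einops pattern into numbered axis names (`d0 d1 ...`) and replaces the tuple passed as `d=` with one keyword argument per axis before the wrapped einops function runs. The stdlib-only sketch below mirrors just that pattern-rewriting step so it can be tried without torch or einops. `expand_anon_dims` is a hypothetical name that is not part of the repository, and unlike the original it raises exceptions instead of using `assert`.

```python
import re

# Same regex as _with_anon_dims: "..." followed by a name marks an anonymous group.
ANON_DIM_REGEX = re.compile(r'(\.\.\.[a-zA-Z]+)')


def expand_anon_dims(pattern, **kwargs):
    """Hypothetical helper: return (new_pattern, new_kwargs) the way
    _with_anon_dims prepares them before calling the einops function."""
    # **kwargs is already a fresh dict, so mutating it does not affect the caller.
    prefixes = {m.lstrip('.') for m in ANON_DIM_REGEX.findall(pattern)}

    named = {}
    for prefix in prefixes:
        if prefix not in kwargs:
            raise KeyError(f'dimension list "{prefix}" was not passed in')
        sizes = kwargs.pop(prefix)
        if not isinstance(sizes, (list, tuple)):
            raise TypeError(
                f'dimension list "{prefix}" must be a tuple or list of sizes')
        # e.g. prefix "d" with sizes (2, 3) -> {"d0": 2, "d1": 3}
        named[prefix] = {f'{prefix}{i}': size for i, size in enumerate(sizes)}

    def substitute(match):
        # Replace "...d" with the space-separated axis names "d0 d1".
        return ' '.join(named[match.group(1).lstrip('.')])

    new_pattern = ANON_DIM_REGEX.sub(substitute, pattern)
    for mapping in named.values():
        kwargs.update(mapping)
    return new_pattern, kwargs


if __name__ == '__main__':
    print(expand_anon_dims('b (...d) -> b ...d', d=(2, 3)))
```

Under this reading, `rearrange_with_anon_dims(t, 'b (...d) -> b ...d', d=(2, 3))` ends up calling `rearrange(t, 'b (d0 d1) -> b d0 d1', d0=2, d1=3)`.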
--------------------------------------------------------------------------------
/models/archs/rotary_embedding_torch.py:
--------------------------------------------------------------------------------
from inspect import isfunction
from math import log, pi

import torch
from einops import rearrange, repeat
from torch import einsum, nn

# helper functions


def exists(val):
    return val is not None


def broadcat(tensors, dim=-1):
    num_tensors = len(tensors)
    shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
    assert len(
        shape_lens) == 1, 'tensors must all have the same number of dimensions'
    shape_len = list(shape_lens)[0]

    dim = (dim + shape_len) if dim < 0 else dim
    dims = list(zip(*map(lambda t: list(t.shape), tensors)))

    expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
    assert all([*map(lambda t: len(set(t[1])) <= 2, expandable_dims)
                ]), 'invalid dimensions for broadcastable concatenation'
    max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
    expanded_dims = list(
        map(lambda t: (t[0], (t[1], ) * num_tensors), max_dims))
    expanded_dims.insert(dim, (dim, dims[dim]))
    expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
    tensors = list(
        map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
    return torch.cat(tensors, dim=dim)


# rotary embedding helper functions


def rotate_half(x):
    x = rearrange(x, '... (d r) -> ... d r', r=2)
    x1, x2 = x.unbind(dim=-1)
    x = torch.stack((-x2, x1), dim=-1)
    return rearrange(x, '... d r -> ... (d r)')


def apply_rotary_emb(freqs, t, start_index=0):
    freqs = freqs.to(t)
    rot_dim = freqs.shape[-1]
    end_index = start_index + rot_dim
    assert rot_dim <= t.shape[
        -1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}'
    t_left, t, t_right = t[..., :start_index], t[
        ..., start_index:end_index], t[..., end_index:]
    t = (t * freqs.cos()) + (rotate_half(t) * freqs.sin())
    return torch.cat((t_left, t, t_right), dim=-1)


# learned rotation helpers


def apply_learned_rotations(rotations, t, start_index=0, freq_ranges=None):
    if exists(freq_ranges):
        rotations = einsum('..., f -> ... f', rotations, freq_ranges)
        rotations = rearrange(rotations, '... r f -> ... (r f)')

    rotations = repeat(rotations, '... n -> ... (n r)', r=2)
    return apply_rotary_emb(rotations, t, start_index=start_index)


# classes


class RotaryEmbedding(nn.Module):

    def __init__(self,
                 dim,
                 custom_freqs=None,
                 freqs_for='lang',
                 theta=10000,
                 max_freq=10,
                 num_freqs=1,
                 learned_freq=False):
        super().__init__()
        if exists(custom_freqs):
            freqs = custom_freqs
        elif freqs_for == 'lang':
            freqs = 1. / (
                theta**(torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
        elif freqs_for == 'pixel':
            freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
        elif freqs_for == 'constant':
            freqs = torch.ones(num_freqs).float()
        else:
            raise ValueError(f'unknown modality {freqs_for}')

        self.cache = dict()

        if learned_freq:
            self.freqs = nn.Parameter(freqs)
        else:
            self.register_buffer('freqs', freqs)

    def rotate_queries_or_keys(self, t, seq_dim=-2):
        device = t.device
        seq_len = t.shape[seq_dim]
        freqs = self.forward(
            lambda: torch.arange(seq_len, device=device), cache_key=seq_len)
        return apply_rotary_emb(freqs, t)

    def forward(self, t, cache_key=None):
        if exists(cache_key) and cache_key in self.cache:
            return self.cache[cache_key]

        if isfunction(t):
            t = t()

        freqs = self.freqs

        freqs = torch.einsum('..., f -> ... f', t.type(freqs.dtype), freqs)
        freqs = repeat(freqs, '... n -> ... (n r)', r=2)

        if exists(cache_key):
            self.cache[cache_key] = freqs

        return freqs

--------------------------------------------------------------------------------
/models/archs/transformer_arch.py:
--------------------------------------------------------------------------------
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class CausalSelfAttention(nn.Module):
    """
    A vanilla multi-head masked self-attention layer with a projection at the end.
    It is possible to use torch.nn.MultiheadAttention here but I am including an
    explicit implementation here to show that there is nothing too scary here.
    """

    def __init__(self, bert_n_emb, bert_n_head, attn_pdrop, resid_pdrop,
                 block_size, sampler):
        super().__init__()
        assert bert_n_emb % bert_n_head == 0
        # key, query, value projections for all heads
        self.key = nn.Linear(bert_n_emb, bert_n_emb)
        self.query = nn.Linear(bert_n_emb, bert_n_emb)
        self.value = nn.Linear(bert_n_emb, bert_n_emb)
        # regularization
        self.attn_drop = nn.Dropout(attn_pdrop)
        self.resid_drop = nn.Dropout(resid_pdrop)
        # output projection
        self.proj = nn.Linear(bert_n_emb, bert_n_emb)
        self.n_head = bert_n_head
        self.causal = True if sampler == 'autoregressive' else False
        if self.causal:
            mask = torch.tril(torch.ones(block_size, block_size))
            self.register_buffer("mask", mask.view(1, 1, block_size,
                                                   block_size))

    def forward(self, x, layer_past=None):
        B, T, C = x.size()

        # calculate query, key, values for all heads in batch and move head
        # forward to be the batch dim
        k = self.key(x).view(B, T, self.n_head,
                             C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        q = self.query(x).view(B, T, self.n_head,
                               C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        v = self.value(x).view(B, T, self.n_head,
                               C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)

        present = torch.stack((k, v))
        if self.causal and layer_past is not None:
            past_key, past_value = layer_past
            k = torch.cat((past_key, k), dim=-2)
            v = torch.cat((past_value, v), dim=-2)

        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))

        if self.causal and layer_past is None:
            att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float('-inf'))

        att = F.softmax(att, dim=-1)
        att = self.attn_drop(att)
        y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        # re-assemble all head outputs side by side
        y = y.transpose(1, 2).contiguous().view(B, T, C)

        # output projection
        y = self.resid_drop(self.proj(y))
        return y, present


class Block(nn.Module):
    """ an unassuming Transformer block """

    def __init__(self, bert_n_emb, resid_pdrop, bert_n_head, attn_pdrop,
                 block_size, sampler):
        super().__init__()
        self.ln1 = nn.LayerNorm(bert_n_emb)
        self.ln2 = nn.LayerNorm(bert_n_emb)
        self.attn = CausalSelfAttention(bert_n_emb, bert_n_head, attn_pdrop,
                                        resid_pdrop, block_size, sampler)
        self.mlp = nn.Sequential(
            nn.Linear(bert_n_emb, 4 * bert_n_emb),
            nn.GELU(),  # nice
            nn.Linear(4 * bert_n_emb, bert_n_emb),
            nn.Dropout(resid_pdrop),
        )

    def forward(self, x, layer_past=None, return_present=False):

        attn, present = self.attn(self.ln1(x), layer_past)
        x = x + attn
        x = x + self.mlp(self.ln2(x))

        if layer_past is not None or return_present:
            return x, present
        return x


class TransformerLanguage(nn.Module):
    """ the full GPT language model, with a context size of block_size """

    def __init__(self,
                 codebook_size,
                 bert_n_emb,
                 bert_n_layers,
                 bert_n_head,
                 block_size,
                 embd_pdrop,
                 resid_pdrop,
                 attn_pdrop,
                 sampler='absorbing'):
        super().__init__()

        self.vocab_size = codebook_size + 1
        self.n_embd = bert_n_emb
        self.block_size = block_size
        self.n_layers = bert_n_layers
        self.codebook_size = codebook_size
        self.causal = sampler == 'autoregressive'
        if self.causal:
            self.vocab_size = codebook_size

        self.text_feature_mapping = nn.Sequential(
            nn.LayerNorm(384),
            nn.Linear(384, 256),
            nn.LayerNorm(256),
            nn.Linear(256, self.n_embd),
            nn.LayerNorm(self.n_embd),
        )

        self.identity_tok_emb = nn.Embedding(self.vocab_size, self.n_embd)
        self.pose_tok_emb = nn.Embedding(self.vocab_size, self.n_embd)
        self.pos_emb = nn.Parameter(
            torch.zeros(1, self.block_size, self.n_embd))
        self.start_tok = nn.Parameter(torch.zeros(1, 1, self.n_embd))
        self.drop = nn.Dropout(embd_pdrop)

        # transformer
        self.blocks = nn.Sequential(*[
            Block(bert_n_emb, resid_pdrop, bert_n_head, attn_pdrop, block_size,
                  sampler) for _ in range(self.n_layers)
        ])
        # decoder head
        self.ln_f = nn.LayerNorm(self.n_embd)
        self.identity_head = nn.Linear(
            self.n_embd, self.codebook_size, bias=False)
        self.pose_head = nn.Linear(self.n_embd, self.codebook_size, bias=False)

    def get_block_size(self):
        return self.block_size

    def _init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(self, text_embedding, identity_idx, pose_idx, t=None):
        # each index maps to a (learnable) vector
        identity_token_embeddings = self.identity_tok_emb(identity_idx)
        pose_token_embeddings = self.pose_tok_emb(pose_idx)

        token_embeddings = torch.cat(
            (identity_token_embeddings, pose_token_embeddings), dim=1)

        if self.causal:
            token_embeddings = torch.cat(
                (self.start_tok.repeat(token_embeddings.size(0), 1, 1),
                 token_embeddings),
                dim=1)

        t = token_embeddings.shape[1]
        assert t <= self.block_size, "Cannot forward, model block size is exhausted."
        # each position maps to a (learnable) vector

        position_embeddings = self.pos_emb[:, :t, :]

        x = token_embeddings + position_embeddings
        x = self.drop(x)
        x = torch.cat([self.text_feature_mapping(text_embedding), x], dim=1)

        for block in self.blocks:
            x = block(x)
        x = self.ln_f(x)
        identity_logits = self.identity_head(x)
        pose_logits = self.pose_head(x)

        return identity_logits, pose_logits

--------------------------------------------------------------------------------
/models/archs/vqgan_arch.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange


class VectorQuantizer(nn.Module):
    """
    Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
    avoids costly matrix multiplications and allows for post-hoc remapping of indices.
    """

    # NOTE: due to a bug the beta term was applied to the wrong term. For
    # backwards compatibility we use the buggy version by default, but you can
    # specify legacy=False to fix it.
    def __init__(self,
                 n_e,
                 e_dim,
                 beta,
                 remap=None,
                 unknown_index="random",
                 sane_index_shape=False,
                 legacy=True):
        super().__init__()
        self.n_e = n_e
        self.e_dim = e_dim
        self.beta = beta
        self.legacy = legacy

        self.embedding = nn.Embedding(self.n_e, self.e_dim)
        self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)

        self.remap = remap
        if self.remap is not None:
            self.register_buffer("used", torch.tensor(np.load(self.remap)))
            self.re_embed = self.used.shape[0]
            self.unknown_index = unknown_index  # "random" or "extra" or integer
            if self.unknown_index == "extra":
                self.unknown_index = self.re_embed
                self.re_embed = self.re_embed + 1
            print(f"Remapping {self.n_e} indices to {self.re_embed} indices. "
                  f"Using {self.unknown_index} for unknown indices.")
        else:
            self.re_embed = n_e

        self.sane_index_shape = sane_index_shape

    def remap_to_used(self, inds):
        ishape = inds.shape
        assert len(ishape) > 1
        inds = inds.reshape(ishape[0], -1)
        used = self.used.to(inds)
        match = (inds[:, :, None] == used[None, None, ...]).long()
        new = match.argmax(-1)
        unknown = match.sum(2) < 1
        if self.unknown_index == "random":
            new[unknown] = torch.randint(
                0, self.re_embed,
                size=new[unknown].shape).to(device=new.device)
        else:
            new[unknown] = self.unknown_index
        return new.reshape(ishape)

    def unmap_to_all(self, inds):
        ishape = inds.shape
        assert len(ishape) > 1
        inds = inds.reshape(ishape[0], -1)
        used = self.used.to(inds)
        if self.re_embed > self.used.shape[0]:  # extra token
            inds[inds >= self.used.shape[0]] = 0  # simply set to zero
        back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
        return back.reshape(ishape)

    def forward(self, z, temp=None, rescale_logits=False, return_logits=False):
        assert temp is None or temp == 1.0, "Only for interface compatibility with Gumbel"
        assert rescale_logits == False, "Only for interface compatibility with Gumbel"
        assert return_logits == False, "Only for interface compatibility with Gumbel"
        # reshape z -> (batch, height, width, channel) and flatten
        z = rearrange(z, 'b c h w -> b h w c').contiguous()
        z_flattened = z.view(-1, self.e_dim)
        # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z

        d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
            torch.sum(self.embedding.weight**2, dim=1) - 2 * \
            torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n'))

        min_encoding_indices = torch.argmin(d, dim=1)
        z_q = self.embedding(min_encoding_indices).view(z.shape)
        perplexity = None
        min_encodings = None

        # compute loss for embedding
        if not self.legacy:
            loss = self.beta * torch.mean((z_q.detach()-z)**2) + \
                   torch.mean((z_q - z.detach()) ** 2)
        else:
            loss = torch.mean((z_q.detach()-z)**2) + self.beta * \
                   torch.mean((z_q - z.detach()) ** 2)

        # preserve gradients
        z_q = z + (z_q - z).detach()

        # reshape back to match original input shape
        z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous()

        if self.remap is not None:
            min_encoding_indices = min_encoding_indices.reshape(
                z.shape[0], -1)  # add batch axis
            min_encoding_indices = self.remap_to_used(min_encoding_indices)
            min_encoding_indices = min_encoding_indices.reshape(-1,
                                                                1)  # flatten

        if self.sane_index_shape:
            min_encoding_indices = min_encoding_indices.reshape(
                z_q.shape[0], z_q.shape[2], z_q.shape[3])

        return z_q, loss, (perplexity, min_encodings, min_encoding_indices)

    def get_codebook_entry(self, indices, shape):
        # shape specifying (batch, height, width, channel)
        if self.remap is not None:
            indices = indices.reshape(shape[0], -1)  # add batch axis
            indices = self.unmap_to_all(indices)
            indices = indices.reshape(-1)  # flatten again

        # get quantized latent vectors
        z_q = self.embedding(indices)

        if shape is not None:
            z_q = z_q.view(shape)
            # reshape back to match original input shape
            z_q = z_q.permute(0, 3, 1, 2).contiguous()

        return z_q

    def get_nearest_codebook_embeddings(self, z, return_loss=False):
        # z = rearrange(z, 'b c h w -> b h w c').contiguous()
        z_flattened = z.view(-1, self.e_dim)

        d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
            torch.sum(self.embedding.weight**2, dim=1) - 2 * \
            torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n'))

        min_encoding_indices = torch.argmin(d, dim=1)
        z_q = self.embedding(min_encoding_indices).view(z.shape)

        if return_loss:
            # compute loss for embedding
            if not self.legacy:
                loss = self.beta * torch.mean((z_q.detach() - z)**2)
            else:
                loss = torch.mean((z_q.detach() - z)**2)

        # preserve gradients
        z_q = z + (z_q - z).detach()

        # reshape back to match original input shape
        # z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous()

        if return_loss:
            return z_q, loss
        else:
            return z_q


class ResnetBlock(nn.Module):

    def __init__(self,
                 *,
                 in_channels,
                 out_channels=None,
                 conv_shortcut=False,
                 dropout,
                 temb_channels=512):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if temb_channels
> 0: 185 | self.temb_proj = torch.nn.Linear(temb_channels, out_channels) 186 | self.norm2 = Normalize(out_channels) 187 | self.dropout = torch.nn.Dropout(dropout) 188 | self.conv2 = torch.nn.Conv2d( 189 | out_channels, out_channels, kernel_size=3, stride=1, padding=1) 190 | if self.in_channels != self.out_channels: 191 | if self.use_conv_shortcut: 192 | self.conv_shortcut = torch.nn.Conv2d( 193 | in_channels, 194 | out_channels, 195 | kernel_size=3, 196 | stride=1, 197 | padding=1) 198 | else: 199 | self.nin_shortcut = torch.nn.Conv2d( 200 | in_channels, 201 | out_channels, 202 | kernel_size=1, 203 | stride=1, 204 | padding=0) 205 | 206 | def forward(self, x, temb): 207 | h = x 208 | h = self.norm1(h) 209 | h = nonlinearity(h) 210 | h = self.conv1(h) 211 | 212 | if temb is not None: 213 | h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] 214 | 215 | h = self.norm2(h) 216 | h = nonlinearity(h) 217 | h = self.dropout(h) 218 | h = self.conv2(h) 219 | 220 | if self.in_channels != self.out_channels: 221 | if self.use_conv_shortcut: 222 | x = self.conv_shortcut(x) 223 | else: 224 | x = self.nin_shortcut(x) 225 | 226 | return x + h 227 | 228 | 229 | class AttnBlock(nn.Module): 230 | 231 | def __init__(self, in_channels): 232 | super().__init__() 233 | self.in_channels = in_channels 234 | 235 | self.norm = Normalize(in_channels) 236 | self.q = torch.nn.Conv2d( 237 | in_channels, in_channels, kernel_size=1, stride=1, padding=0) 238 | self.k = torch.nn.Conv2d( 239 | in_channels, in_channels, kernel_size=1, stride=1, padding=0) 240 | self.v = torch.nn.Conv2d( 241 | in_channels, in_channels, kernel_size=1, stride=1, padding=0) 242 | self.proj_out = torch.nn.Conv2d( 243 | in_channels, in_channels, kernel_size=1, stride=1, padding=0) 244 | 245 | def forward(self, x): 246 | h_ = x 247 | h_ = self.norm(h_) 248 | q = self.q(h_) 249 | k = self.k(h_) 250 | v = self.v(h_) 251 | 252 | # compute attention 253 | b, c, h, w = q.shape 254 | q = q.reshape(b, c, h * w) 255 | q = 
q.permute(0, 2, 1) # b,hw,c 256 | k = k.reshape(b, c, h * w) # b,c,hw 257 | w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] 258 | w_ = w_ * (int(c)**(-0.5)) 259 | w_ = torch.nn.functional.softmax(w_, dim=2) 260 | 261 | # attend to values 262 | v = v.reshape(b, c, h * w) 263 | w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) 264 | h_ = torch.bmm( 265 | v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] 266 | h_ = h_.reshape(b, c, h, w) 267 | 268 | h_ = self.proj_out(h_) 269 | 270 | return x + h_ 271 | 272 | 273 | class Upsample(nn.Module): 274 | 275 | def __init__(self, in_channels, with_conv): 276 | super().__init__() 277 | self.with_conv = with_conv 278 | if self.with_conv: 279 | self.conv = torch.nn.Conv2d( 280 | in_channels, in_channels, kernel_size=3, stride=1, padding=1) 281 | 282 | def forward(self, x): 283 | x = torch.nn.functional.interpolate( 284 | x, scale_factor=2.0, mode="nearest") 285 | if self.with_conv: 286 | x = self.conv(x) 287 | return x 288 | 289 | 290 | class Downsample(nn.Module): 291 | 292 | def __init__(self, in_channels, with_conv): 293 | super().__init__() 294 | self.with_conv = with_conv 295 | if self.with_conv: 296 | # no asymmetric padding in torch conv, must do it ourselves 297 | self.conv = torch.nn.Conv2d( 298 | in_channels, in_channels, kernel_size=3, stride=2, padding=0) 299 | 300 | def forward(self, x): 301 | if self.with_conv: 302 | pad = (0, 1, 0, 1) 303 | x = torch.nn.functional.pad(x, pad, mode="constant", value=0) 304 | x = self.conv(x) 305 | else: 306 | x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) 307 | return x 308 | 309 | 310 | def nonlinearity(x): 311 | # swish 312 | return x * torch.sigmoid(x) 313 | 314 | 315 | def Normalize(in_channels): 316 | return torch.nn.GroupNorm( 317 | num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) 318 | 319 | 320 | class DecoderUpOthersDoubleIdentity(nn.Module): 321 | 322 | def __init__(self, 323 | in_channels, 324 | 
resolution, 325 | z_channels, 326 | ch, 327 | out_ch, 328 | num_res_blocks, 329 | attn_resolutions, 330 | ch_mult=(1, 2, 4, 8), 331 | other_ch_mult=(8, 8), 332 | dropout=0.0, 333 | resamp_with_conv=True, 334 | give_pre_end=False): 335 | super().__init__() 336 | self.ch = ch 337 | self.temb_ch = 0 338 | self.num_resolutions = len(ch_mult) 339 | self.num_res_blocks = num_res_blocks 340 | self.resolution = resolution 341 | self.in_channels = in_channels 342 | self.give_pre_end = give_pre_end 343 | 344 | self.num_other_resolutions = len(other_ch_mult) 345 | 346 | # compute in_ch_mult, block_in and curr_res at lowest res 347 | in_ch_mult = (1, ) + tuple(ch_mult) 348 | block_in = ch * ch_mult[self.num_resolutions - 1] 349 | 350 | curr_res = resolution // 2**(self.num_resolutions - 1) 351 | self.z_shape = (1, z_channels, curr_res, curr_res // 2) 352 | print("Working with z of shape {} = {} dimensions.".format( 353 | self.z_shape, np.prod(self.z_shape))) 354 | 355 | # z to block_in 356 | self.conv_in_identity = torch.nn.Conv2d( 357 | z_channels, block_in, kernel_size=3, stride=1, padding=1) 358 | 359 | # z to block_in 360 | self.conv_in_others = torch.nn.Conv2d( 361 | z_channels // 2, block_in // 2, kernel_size=3, stride=1, padding=1) 362 | 363 | self.conv_in = torch.nn.Conv2d( 364 | block_in // 2 + block_in, 365 | block_in, 366 | kernel_size=3, 367 | stride=1, 368 | padding=1) 369 | 370 | # others upsampling 371 | self.others_up = nn.ModuleList() 372 | block_in_others = ch // 2 * ch_mult[self.num_resolutions - 1] 373 | for i_level in reversed(range(self.num_other_resolutions)): 374 | block = nn.ModuleList() 375 | block_out_others = ch // 2 * other_ch_mult[i_level] 376 | for i_block in range(self.num_res_blocks + 1): 377 | block.append( 378 | ResnetBlock( 379 | in_channels=block_in_others, 380 | out_channels=block_out_others, 381 | temb_channels=self.temb_ch, 382 | dropout=dropout)) 383 | block_in_others = block_out_others 384 | up = nn.Module() 385 | up.block = block 386 
| up.upsample = Upsample(block_in_others, resamp_with_conv) 387 | self.others_up.insert(0, up) # prepend to get consistent order 388 | 389 | # middle 390 | self.mid = nn.Module() 391 | self.mid.block_1 = ResnetBlock( 392 | in_channels=block_in, 393 | out_channels=block_in, 394 | temb_channels=self.temb_ch, 395 | dropout=dropout) 396 | self.mid.attn_1 = AttnBlock(block_in) 397 | self.mid.block_2 = ResnetBlock( 398 | in_channels=block_in, 399 | out_channels=block_in, 400 | temb_channels=self.temb_ch, 401 | dropout=dropout) 402 | 403 | # upsampling 404 | self.up = nn.ModuleList() 405 | for i_level in reversed(range(self.num_resolutions)): 406 | block = nn.ModuleList() 407 | attn = nn.ModuleList() 408 | block_out = ch * ch_mult[i_level] 409 | for i_block in range(self.num_res_blocks + 1): 410 | block.append( 411 | ResnetBlock( 412 | in_channels=block_in, 413 | out_channels=block_out, 414 | temb_channels=self.temb_ch, 415 | dropout=dropout)) 416 | block_in = block_out 417 | if curr_res in attn_resolutions: 418 | attn.append(AttnBlock(block_in)) 419 | up = nn.Module() 420 | up.block = block 421 | up.attn = attn 422 | if i_level != 0: 423 | up.upsample = Upsample(block_in, resamp_with_conv) 424 | curr_res = curr_res * 2 425 | self.up.insert(0, up) # prepend to get consistent order 426 | 427 | # end 428 | self.norm_out = Normalize(block_in) 429 | self.conv_out = torch.nn.Conv2d( 430 | block_in, out_ch, kernel_size=3, stride=1, padding=1) 431 | 432 | def forward(self, z_identity, z_others): 433 | # timestep embedding 434 | temb = None 435 | 436 | # upsampling others 437 | h_others = self.conv_in_others(z_others) 438 | for i_level in reversed(range(self.num_other_resolutions)): 439 | for i_block in range(self.num_res_blocks + 1): 440 | h_others = self.others_up[i_level].block[i_block](h_others, 441 | temb) 442 | h_others = self.others_up[i_level].upsample(h_others) 443 | 444 | # z to block_in 445 | h_identity = self.conv_in_identity(z_identity) 446 | 447 | h = 
self.conv_in(torch.cat([h_identity, h_others], dim=1)) 448 | 449 | # middle 450 | h = self.mid.block_1(h, temb) 451 | h = self.mid.attn_1(h) 452 | h = self.mid.block_2(h, temb) 453 | 454 | # upsampling 455 | for i_level in reversed(range(self.num_resolutions)): 456 | for i_block in range(self.num_res_blocks + 1): 457 | h = self.up[i_level].block[i_block](h, temb) 458 | if len(self.up[i_level].attn) > 0: 459 | h = self.up[i_level].attn[i_block](h) 460 | if i_level != 0: 461 | h = self.up[i_level].upsample(h) 462 | 463 | # end 464 | if self.give_pre_end: 465 | return h 466 | 467 | h = self.norm_out(h) 468 | h = nonlinearity(h) 469 | h = self.conv_out(h) 470 | return h 471 | 472 | 473 | class EncoderDecomposeBaseDownOthersDoubleIdentity(nn.Module): 474 | 475 | def __init__(self, 476 | ch, 477 | num_res_blocks, 478 | attn_resolutions, 479 | in_channels, 480 | resolution, 481 | z_channels, 482 | ch_mult=(1, 2, 4, 8), 483 | other_ch_mult=(8, 8), 484 | dropout=0.0, 485 | resamp_with_conv=True, 486 | double_z=True): 487 | super().__init__() 488 | self.ch = ch 489 | self.temb_ch = 0 490 | self.num_resolutions = len(ch_mult) 491 | self.num_res_blocks = num_res_blocks 492 | self.resolution = resolution 493 | self.in_channels = in_channels 494 | 495 | self.num_other_resolutions = len(other_ch_mult) 496 | 497 | # downsampling 498 | self.conv_in = torch.nn.Conv2d( 499 | in_channels, self.ch, kernel_size=3, stride=1, padding=1) 500 | 501 | curr_res = resolution 502 | in_ch_mult = (1, ) + tuple(ch_mult) 503 | self.down = nn.ModuleList() 504 | for i_level in range(self.num_resolutions): 505 | block = nn.ModuleList() 506 | attn = nn.ModuleList() 507 | block_in = ch * in_ch_mult[i_level] 508 | block_out = ch * ch_mult[i_level] 509 | for i_block in range(self.num_res_blocks): 510 | block.append( 511 | ResnetBlock( 512 | in_channels=block_in, 513 | out_channels=block_out, 514 | temb_channels=self.temb_ch, 515 | dropout=dropout)) 516 | block_in = block_out 517 | if curr_res in 
attn_resolutions: 518 | attn.append(AttnBlock(block_in)) 519 | down = nn.Module() 520 | down.block = block 521 | down.attn = attn 522 | if i_level != self.num_resolutions - 1: 523 | down.downsample = Downsample(block_in, resamp_with_conv) 524 | curr_res = curr_res // 2 525 | self.down.append(down) 526 | 527 | # identity branch 528 | # middle 529 | self.mid_identity = nn.Module() 530 | self.mid_identity.block_1 = ResnetBlock( 531 | in_channels=block_in, 532 | out_channels=block_in, 533 | temb_channels=self.temb_ch, 534 | dropout=dropout) 535 | self.mid_identity.attn_1 = AttnBlock(block_in) 536 | self.mid_identity.block_2 = ResnetBlock( 537 | in_channels=block_in, 538 | out_channels=block_in, 539 | temb_channels=self.temb_ch, 540 | dropout=dropout) 541 | 542 | # end 543 | self.norm_out_identity = Normalize(block_in) 544 | self.conv_out_identity = torch.nn.Conv2d( 545 | block_in, 546 | z_channels * 2 if double_z else z_channels, 547 | kernel_size=3, 548 | stride=1, 549 | padding=1) 550 | 551 | self.other_down = nn.ModuleList() 552 | for i_level in range(self.num_other_resolutions): 553 | block = nn.ModuleList() 554 | block_in = ch * other_ch_mult[i_level] 555 | block_out = ch * other_ch_mult[i_level] 556 | for i_block in range(self.num_res_blocks): 557 | block.append( 558 | ResnetBlock( 559 | in_channels=block_in, 560 | out_channels=block_out, 561 | temb_channels=self.temb_ch, 562 | dropout=dropout)) 563 | block_in = block_out 564 | 565 | down = nn.Module() 566 | down.block = block 567 | down.downsample = Downsample(block_in, resamp_with_conv) 568 | self.other_down.append(down) 569 | 570 | # other branch 571 | # middle 572 | self.mid_other = nn.Module() 573 | self.mid_other.block_1 = ResnetBlock( 574 | in_channels=block_in, 575 | out_channels=block_in, 576 | temb_channels=self.temb_ch, 577 | dropout=dropout) 578 | self.mid_other.attn_1 = AttnBlock(block_in) 579 | self.mid_other.block_2 = ResnetBlock( 580 | in_channels=block_in, 581 | out_channels=block_in, 582 | 
temb_channels=self.temb_ch, 583 | dropout=dropout) 584 | 585 | self.norm_out_other = Normalize(block_in) 586 | self.conv_out_other = torch.nn.Conv2d( 587 | block_in, 588 | z_channels if double_z else z_channels // 2, 589 | kernel_size=3, 590 | stride=1, 591 | padding=1) 592 | 593 | def forward(self, x): 594 | #assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution) 595 | 596 | # timestep embedding 597 | temb = None 598 | 599 | # downsampling 600 | hs = [self.conv_in(x)] 601 | for i_level in range(self.num_resolutions): 602 | for i_block in range(self.num_res_blocks): 603 | h = self.down[i_level].block[i_block](hs[-1], temb) 604 | if len(self.down[i_level].attn) > 0: 605 | h = self.down[i_level].attn[i_block](h) 606 | hs.append(h) 607 | if i_level != self.num_resolutions - 1: 608 | hs.append(self.down[i_level].downsample(hs[-1])) 609 | 610 | # identity branch 611 | # middle 612 | h_identity = self.mid_identity.block_1(hs[-1], temb) 613 | h_identity = self.mid_identity.attn_1(h_identity) 614 | h_identity = self.mid_identity.block_2(h_identity, temb) 615 | 616 | # end 617 | h_identity = self.norm_out_identity(h_identity) 618 | h_identity = nonlinearity(h_identity) 619 | h_identity = self.conv_out_identity(h_identity) 620 | 621 | # other branch 622 | for i_level in range(self.num_other_resolutions): 623 | for i_block in range(self.num_res_blocks): 624 | h = self.other_down[i_level].block[i_block](hs[-1], temb) 625 | hs.append(h) 626 | hs.append(self.other_down[i_level].downsample(hs[-1])) 627 | 628 | # middle 629 | h_other = self.mid_other.block_1(hs[-1], temb) 630 | h_other = self.mid_other.attn_1(h_other) 631 | h_other = self.mid_other.block_2(h_other, temb) 632 | 633 | # end 634 | h_other = self.norm_out_other(h_other) 635 | h_other = nonlinearity(h_other) 636 | h_other = self.conv_out_other(h_other) 637 | return h_identity, h_other 638 | 639 | 640 | # patch based discriminator 641 | class 
Discriminator(nn.Module): 642 | 643 | def __init__(self, nc, ndf, n_layers=3): 644 | super().__init__() 645 | 646 | layers = [ 647 | nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1), 648 | nn.LeakyReLU(0.2, True) 649 | ] 650 | ndf_mult = 1 651 | ndf_mult_prev = 1 652 | for n in range(1, 653 | n_layers): # gradually increase the number of filters 654 | ndf_mult_prev = ndf_mult 655 | ndf_mult = min(2**n, 8) 656 | layers += [ 657 | nn.Conv2d( 658 | ndf * ndf_mult_prev, 659 | ndf * ndf_mult, 660 | kernel_size=4, 661 | stride=2, 662 | padding=1, 663 | bias=False), 664 | nn.BatchNorm2d(ndf * ndf_mult), 665 | nn.LeakyReLU(0.2, True) 666 | ] 667 | 668 | ndf_mult_prev = ndf_mult 669 | ndf_mult = min(2**n_layers, 8) 670 | 671 | layers += [ 672 | nn.Conv2d( 673 | ndf * ndf_mult_prev, 674 | ndf * ndf_mult, 675 | kernel_size=4, 676 | stride=1, 677 | padding=1, 678 | bias=False), 679 | nn.BatchNorm2d(ndf * ndf_mult), 680 | nn.LeakyReLU(0.2, True) 681 | ] 682 | 683 | layers += [ 684 | nn.Conv2d(ndf * ndf_mult, 1, kernel_size=4, stride=1, padding=1) 685 | ] # output 1 channel prediction map 686 | self.main = nn.Sequential(*layers) 687 | 688 | def forward(self, x): 689 | return self.main(x) 690 | -------------------------------------------------------------------------------- /models/base_model.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | from collections import OrderedDict 4 | from copy import deepcopy 5 | 6 | import torch 7 | from torch.nn.parallel import DataParallel, DistributedDataParallel 8 | from utils.dist_util import master_only 9 | 10 | logger = logging.getLogger('base') 11 | 12 | 13 | class BaseModel(): 14 | """Base model.""" 15 | 16 | def __init__(self, opt): 17 | self.opt = opt 18 | self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu') 19 | self.is_train = opt['is_train'] 20 | 21 | def feed_data(self, data): 22 | pass 23 | 24 | def optimize_parameters(self): 25 | pass 26 
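The patch-based `Discriminator` at the end of `vqgan_arch.py` outputs a map of logits rather than one score per image, with one logit per overlapping input patch. The size of that map and of each patch follows from the standard convolution output formula; the BatchNorm and LeakyReLU layers do not change spatial size. The sketch below reads the layer list off the constructor (`kernel_size=4`, `padding=1`, `n_layers` stride-2 convolutions, then two stride-1 convolutions). The helper names are hypothetical.

```python
def conv_out(size, kernel=4, stride=2, padding=1):
    """Output length of a Conv2d along one spatial axis (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def patch_disc_convs(n_layers=3):
    # (kernel, stride, padding) of each Conv2d in Discriminator.main:
    # n_layers stride-2 convs, then one stride-1 conv and the 1-channel output conv.
    return [(4, 2, 1)] * n_layers + [(4, 1, 1), (4, 1, 1)]


def patch_map_size(size, n_layers=3):
    """Length of the logit map along one axis for an input of length `size`."""
    for k, s, p in patch_disc_convs(n_layers):
        size = conv_out(size, k, s, p)
    return size


def receptive_field(n_layers=3):
    """Input pixels seen by one output logit along one axis."""
    # Walk backwards from one output pixel: r_in = (r_out - 1) * stride + kernel
    r = 1
    for k, s, _ in reversed(patch_disc_convs(n_layers)):
        r = (r - 1) * s + k
    return r


print(patch_map_size(256), receptive_field())  # 30 70
```

With the default `n_layers=3`, a 256-pixel axis yields 30 logits, each looking at a 70-pixel window. This is the familiar 70x70 PatchGAN setting. Because the arithmetic is per axis, a non-square input simply produces a non-square logit map.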
| 27 | def get_current_visuals(self): 28 | pass 29 | 30 | def save(self, epoch, current_iter): 31 | """Save networks and training state.""" 32 | pass 33 | 34 | def validation(self, dataloader, current_iter, tb_logger, save_img=False): 35 | """Validation function. 36 | 37 | Args: 38 | dataloader (torch.utils.data.DataLoader): Validation dataloader. 39 | current_iter (int): Current iteration. 40 | tb_logger (tensorboard logger): Tensorboard logger. 41 | save_img (bool): Whether to save images. Default: False. 42 | """ 43 | if self.opt['dist']: 44 | self.dist_validation(dataloader, current_iter, tb_logger, save_img) 45 | else: 46 | self.nondist_validation(dataloader, current_iter, tb_logger, 47 | save_img) 48 | 49 | def get_current_log(self): 50 | return self.log_dict 51 | 52 | def model_to_device(self, net): 53 | """Model to device. It also wraps models with DistributedDataParallel 54 | or DataParallel. 55 | 56 | Args: 57 | net (nn.Module): Network to move (and wrap). 58 | """ 59 | net = net.to(self.device) 60 | if self.opt['dist']: 61 | find_unused_parameters = self.opt.get('find_unused_parameters', 62 | False) 63 | net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net) 64 | net = DistributedDataParallel( 65 | net, 66 | device_ids=[torch.cuda.current_device()], 67 | find_unused_parameters=find_unused_parameters) 68 | elif self.opt['num_gpu'] > 1: 69 | net = DataParallel(net) 70 | return net 71 | 72 | def get_bare_model(self, net): 73 | """Get the bare model, unwrapping it if it is wrapped with 74 | DistributedDataParallel or DataParallel. 75 | """ 76 | if isinstance(net, (DataParallel, DistributedDataParallel)): 77 | net = net.module 78 | return net 79 | 80 | @master_only 81 | def save_network(self, net, save_path): 82 | """Save a network's state dict. 83 | 84 | Args: 85 | net (nn.Module): Network to be saved. DataParallel/DDP wrappers 86 | are removed and all tensors are moved to CPU before saving. 87 | save_path (str): Path of the output file. 88 | 89 |
90 | """ 91 | 92 | net = self.get_bare_model(net) 93 | state_dict = OrderedDict() 94 | for key, param in net.state_dict().items(): 95 | if key.startswith('module.'): # remove unnecessary 'module.' 96 | key = key[7:] 97 | state_dict[key] = param.cpu() # fill a new dict; never resize the dict being iterated 98 | 99 | torch.save(state_dict, save_path) 100 | 101 | def load_network(self, net, load_path, strict=True, param_key='params'): 102 | """Load network. 103 | 104 | Args: 105 | load_path (str): The path of networks to be loaded. 106 | net (nn.Module): Network. 107 | strict (bool): Whether strictly loaded. 108 | param_key (str): The parameter key of loaded network. 109 | Default: 'params'. 110 | """ 111 | net = self.get_bare_model(net) 112 | logger.info( 113 | f'Loading {net.__class__.__name__} model from {load_path}.') 114 | load_net = torch.load( 115 | load_path, map_location=lambda storage, loc: storage)[param_key] 116 | # remove unnecessary 'module.' 117 | for k, v in deepcopy(load_net).items(): 118 | if k.startswith('module.'): 119 | load_net[k[7:]] = v 120 | load_net.pop(k) 121 | # self._print_different_keys_loading(net, load_net, strict) 122 | net.load_state_dict(load_net, strict=strict) 123 | 124 | def update_learning_rate(self, epoch, iters=None): 125 | """Update learning rate. 126 | 127 | Args: 128 | epoch (int): Current epoch. 129 | iters (int | None): Current iteration; only used by the 'warm_up' mode. 130 | Default: None.
131 | """ 132 | lr = self.optimizer.param_groups[0]['lr'] 133 | 134 | if self.opt['lr_decay'] == 'step': 135 | lr = self.opt['lr'] * ( 136 | self.opt['gamma']**(epoch // self.opt['step'])) 137 | elif self.opt['lr_decay'] == 'cos': 138 | lr = self.opt['lr'] * ( 139 | 1 + math.cos(math.pi * epoch / self.opt['num_epochs'])) / 2 140 | elif self.opt['lr_decay'] == 'linear': 141 | lr = self.opt['lr'] * (1 - epoch / self.opt['num_epochs']) 142 | elif self.opt['lr_decay'] == 'linear2exp': 143 | if epoch < self.opt['turning_point'] + 1: 144 | # decay linearly so that the learning rate has dropped by ~95% 145 | # at the turning point (1 / 0.95 ~= 1.0526) 146 | lr = self.opt['lr'] * ( 147 | 1 - epoch / int(self.opt['turning_point'] * 1.0526)) 148 | else: 149 | lr *= self.opt['gamma'] 150 | elif self.opt['lr_decay'] == 'schedule': 151 | if epoch in self.opt['schedule']: 152 | lr *= self.opt['gamma'] 153 | elif self.opt['lr_decay'] == 'warm_up': 154 | if iters <= self.opt['warmup_iters']: 155 | lr = self.opt['lr'] * float(iters) / self.opt['warmup_iters'] 156 | else: 157 | lr = self.opt['lr'] 158 | else: 159 | raise ValueError('Unknown lr mode {}'.format(self.opt['lr_decay'])) 160 | # set learning rate 161 | for param_group in self.optimizer.param_groups: 162 | param_group['lr'] = lr 163 | 164 | return lr 165 | 166 | def reduce_loss_dict(self, loss_dict): 167 | """Reduce loss dict. 168 | 169 | In distributed training, the losses are averaged across GPUs on rank 0, the only rank that receives the reduced result. 170 | 171 | Args: 172 | loss_dict (OrderedDict): Loss dict.
173 | """ 174 | with torch.no_grad(): 175 | if self.opt['dist']: 176 | keys = [] 177 | losses = [] 178 | for name, value in loss_dict.items(): 179 | keys.append(name) 180 | losses.append(value) 181 | losses = torch.stack(losses, 0) 182 | torch.distributed.reduce(losses, dst=0) 183 | if self.opt['rank'] == 0: 184 | losses /= self.opt['world_size'] 185 | loss_dict = {key: loss for key, loss in zip(keys, losses)} 186 | 187 | log_dict = OrderedDict() 188 | for name, value in loss_dict.items(): 189 | log_dict[name] = value.mean().item() 190 | 191 | return log_dict 192 | -------------------------------------------------------------------------------- /models/losses/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yumingj/Text2Performer/433489aae7bdd6fd868a2e272d98bed7c7b642a5/models/losses/__init__.py -------------------------------------------------------------------------------- /models/losses/vqgan_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | @torch.jit.script 6 | def hinge_d_loss(logits_real, logits_fake): 7 | loss_real = torch.mean(F.relu(1. - logits_real)) 8 | loss_fake = torch.mean(F.relu(1. 
+ logits_fake)) 9 | d_loss = 0.5 * (loss_real + loss_fake) 10 | 11 | return d_loss, loss_real, loss_fake 12 | 13 | 14 | def calculate_adaptive_weight(recon_loss, g_loss, last_layer, disc_weight_max): 15 | recon_grads = torch.autograd.grad( 16 | recon_loss, last_layer, retain_graph=True)[0] 17 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 18 | 19 | d_weight = torch.norm(recon_grads) / (torch.norm(g_grads) + 1e-4) 20 | d_weight = torch.clamp(d_weight, 0.0, disc_weight_max).detach() 21 | return d_weight 22 | 23 | 24 | def adopt_weight(weight, global_step, threshold=0, value=0.): 25 | if global_step < threshold: 26 | weight = value 27 | return weight 28 | 29 | 30 | def DiffAugment(x, policy='', channels_first=True): 31 | if policy: 32 | if not channels_first: 33 | x = x.permute(0, 3, 1, 2) 34 | for p in policy.split(','): 35 | for f in AUGMENT_FNS[p]: 36 | x = f(x) 37 | if not channels_first: 38 | x = x.permute(0, 2, 3, 1) 39 | x = x.contiguous() 40 | return x 41 | 42 | 43 | def rand_brightness(x): 44 | x = x + ( 45 | torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5) 46 | return x 47 | 48 | 49 | def rand_saturation(x): 50 | x_mean = x.mean(dim=1, keepdim=True) 51 | x = (x - x_mean) * (torch.rand( 52 | x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2) + x_mean 53 | return x 54 | 55 | 56 | def rand_contrast(x): 57 | x_mean = x.mean(dim=[1, 2, 3], keepdim=True) 58 | x = (x - x_mean) * (torch.rand( 59 | x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5) + x_mean 60 | return x 61 | 62 | 63 | def rand_translation(x, ratio=0.125): 64 | shift_x, shift_y = int(x.size(2) * ratio + 65 | 0.5), int(x.size(3) * ratio + 0.5) 66 | translation_x = torch.randint( 67 | -shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device) 68 | translation_y = torch.randint( 69 | -shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device) 70 | grid_batch, grid_x, grid_y = torch.meshgrid( 71 | torch.arange(x.size(0), 
dtype=torch.long, device=x.device), 72 | torch.arange(x.size(2), dtype=torch.long, device=x.device), 73 | torch.arange(x.size(3), dtype=torch.long, device=x.device), 74 | ) 75 | grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1) 76 | grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1) 77 | x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0]) 78 | x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, 79 | grid_y].permute(0, 3, 1, 2) 80 | return x 81 | 82 | 83 | def rand_cutout(x, ratio=0.5): 84 | cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5) 85 | offset_x = torch.randint( 86 | 0, 87 | x.size(2) + (1 - cutout_size[0] % 2), 88 | size=[x.size(0), 1, 1], 89 | device=x.device) 90 | offset_y = torch.randint( 91 | 0, 92 | x.size(3) + (1 - cutout_size[1] % 2), 93 | size=[x.size(0), 1, 1], 94 | device=x.device) 95 | grid_batch, grid_x, grid_y = torch.meshgrid( 96 | torch.arange(x.size(0), dtype=torch.long, device=x.device), 97 | torch.arange(cutout_size[0], dtype=torch.long, device=x.device), 98 | torch.arange(cutout_size[1], dtype=torch.long, device=x.device), 99 | ) 100 | grid_x = torch.clamp( 101 | grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1) 102 | grid_y = torch.clamp( 103 | grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1) 104 | mask = torch.ones( 105 | x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device) 106 | mask[grid_batch, grid_x, grid_y] = 0 107 | x = x * mask.unsqueeze(1) 108 | return x 109 | 110 | 111 | AUGMENT_FNS = { 112 | 'color': [rand_brightness, rand_saturation, rand_contrast], 113 | 'translation': [rand_translation], 114 | 'cutout': [rand_cutout], 115 | } 116 | -------------------------------------------------------------------------------- /models/vqgan_decompose_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import OrderedDict 3 | 4 | import lpips 5 | import torch 6 | 
from torchvision.utils import save_image 7 | 8 | from models.archs.vqgan_arch import ( 9 | DecoderUpOthersDoubleIdentity, Discriminator, 10 | EncoderDecomposeBaseDownOthersDoubleIdentity, VectorQuantizer) 11 | from models.base_model import BaseModel 12 | from models.losses.vqgan_loss import (DiffAugment, adopt_weight, 13 | calculate_adaptive_weight, hinge_d_loss) 14 | from utils.dist_util import master_only 15 | 16 | 17 | class VQGANDecomposeModel(BaseModel): 18 | 19 | def __init__(self, opt): 20 | super().__init__(opt) 21 | 22 | self.encoder = self.model_to_device( 23 | EncoderDecomposeBaseDownOthersDoubleIdentity( 24 | ch=opt['ch'], 25 | num_res_blocks=opt['num_res_blocks'], 26 | attn_resolutions=opt['attn_resolutions'], 27 | ch_mult=opt['ch_mult'], 28 | other_ch_mult=opt['other_ch_mult'], 29 | in_channels=opt['in_channels'], 30 | resolution=opt['resolution'], 31 | z_channels=opt['z_channels'], 32 | double_z=opt['double_z'], 33 | dropout=opt['dropout'])) 34 | self.decoder = self.model_to_device( 35 | DecoderUpOthersDoubleIdentity( 36 | in_channels=opt['in_channels'], 37 | resolution=opt['resolution'], 38 | z_channels=opt['z_channels'], 39 | ch=opt['ch'], 40 | out_ch=opt['out_ch'], 41 | num_res_blocks=opt['num_res_blocks'], 42 | attn_resolutions=opt['attn_resolutions'], 43 | ch_mult=opt['ch_mult'], 44 | other_ch_mult=opt['other_ch_mult'], 45 | dropout=opt['dropout'], 46 | resamp_with_conv=True, 47 | give_pre_end=False)) 48 | self.quantize_identity = self.model_to_device( 49 | VectorQuantizer(opt['n_embed'], opt['embed_dim'], beta=0.25)) 50 | self.quant_conv_identity = self.model_to_device( 51 | torch.nn.Conv2d(opt["z_channels"], opt['embed_dim'], 1)) 52 | self.post_quant_conv_identity = self.model_to_device( 53 | torch.nn.Conv2d(opt['embed_dim'], opt["z_channels"], 1)) 54 | 55 | self.quantize_others = self.model_to_device( 56 | VectorQuantizer(opt['n_embed'], opt['embed_dim'] // 2, beta=0.25)) 57 | self.quant_conv_others = self.model_to_device( 58 | 
torch.nn.Conv2d(opt["z_channels"] // 2, opt['embed_dim'] // 2, 1)) 59 | self.post_quant_conv_others = self.model_to_device( 60 | torch.nn.Conv2d(opt['embed_dim'] // 2, opt["z_channels"] // 2, 1)) 61 | 62 | self.disc = self.model_to_device( 63 | Discriminator( 64 | opt['n_channels'], opt['ndf'], n_layers=opt['disc_layers'])) 65 | self.perceptual = lpips.LPIPS(net="vgg").to(self.device) 66 | self.perceptual_weight = opt['perceptual_weight'] 67 | self.disc_start_step = opt['disc_start_step'] 68 | self.disc_weight_max = opt['disc_weight_max'] 69 | self.diff_aug = opt['diff_aug'] 70 | self.policy = "color,translation" 71 | 72 | self.disc.train() 73 | 74 | if self.opt['pretrained_models'] is not None: 75 | self.load_pretrained_network() 76 | 77 | self.init_training_settings() 78 | 79 | def init_training_settings(self): 80 | self.configure_optimizers() 81 | 82 | def configure_optimizers(self): 83 | self.optimizer = torch.optim.Adam( 84 | list(self.encoder.parameters()) + list(self.decoder.parameters()) + 85 | list(self.quantize_identity.parameters()) + 86 | list(self.quant_conv_identity.parameters()) + 87 | list(self.post_quant_conv_identity.parameters()) + 88 | list(self.quantize_others.parameters()) + 89 | list(self.quant_conv_others.parameters()) + 90 | list(self.post_quant_conv_others.parameters()), 91 | lr=self.opt['lr']) 92 | 93 | self.disc_optimizer = torch.optim.Adam( 94 | self.disc.parameters(), lr=self.opt['lr']) 95 | 96 | @master_only 97 | def save_network(self, save_path): 98 | """Save networks. 99 | 100 | Args: 101 | save_path (str): Path of the checkpoint file. The state dicts of 102 | all sub-networks (including the discriminator) are stored in a 103 | single dict keyed by module name.
104 | """ 105 | 106 | save_dict = {} 107 | save_dict['encoder'] = self.get_bare_model(self.encoder).state_dict() 108 | save_dict['decoder'] = self.get_bare_model(self.decoder).state_dict() 109 | save_dict['quantize_identity'] = self.get_bare_model( 110 | self.quantize_identity).state_dict() 111 | save_dict['quant_conv_identity'] = self.get_bare_model( 112 | self.quant_conv_identity).state_dict() 113 | save_dict['post_quant_conv_identity'] = self.get_bare_model( 114 | self.post_quant_conv_identity).state_dict() 115 | save_dict['quantize_others'] = self.get_bare_model( 116 | self.quantize_others).state_dict() 117 | save_dict['quant_conv_others'] = self.get_bare_model( 118 | self.quant_conv_others).state_dict() 119 | save_dict['post_quant_conv_others'] = self.get_bare_model( 120 | self.post_quant_conv_others).state_dict() 121 | save_dict['disc'] = self.get_bare_model(self.disc).state_dict() 122 | torch.save(save_dict, save_path) 123 | 124 | def load_pretrained_network(self): 125 | 126 | self.load_network( 127 | self.encoder, self.opt['pretrained_models'], param_key='encoder') 128 | self.load_network( 129 | self.decoder, self.opt['pretrained_models'], param_key='decoder') 130 | self.load_network( 131 | self.quantize_identity, 132 | self.opt['pretrained_models'], 133 | param_key='quantize_identity') 134 | self.load_network( 135 | self.quant_conv_identity, 136 | self.opt['pretrained_models'], 137 | param_key='quant_conv_identity') 138 | self.load_network( 139 | self.post_quant_conv_identity, 140 | self.opt['pretrained_models'], 141 | param_key='post_quant_conv_identity') 142 | self.load_network( 143 | self.quantize_others, 144 | self.opt['pretrained_models'], 145 | param_key='quantize_others') 146 | self.load_network( 147 | self.quant_conv_others, 148 | self.opt['pretrained_models'], 149 | param_key='quant_conv_others') 150 | self.load_network( 151 | self.post_quant_conv_others, 152 | self.opt['pretrained_models'], 153 | param_key='post_quant_conv_others') 154 | 155 | 
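The `quantize_identity` and `quantize_others` modules saved and loaded above are `VectorQuantizer` instances from `models/archs/vqgan_arch.py`, which is not part of this excerpt. Assuming they follow the standard VQ-VAE formulation (nearest-codebook lookup plus a codebook/commitment loss weighted by `beta=0.25`), the lookup can be sketched in pure Python as below. `quantize` and `squared_distance` are hypothetical helpers for illustration, not functions from this repository.

```python
from typing import List, Sequence, Tuple


def squared_distance(a: Sequence[float], b: Sequence[float]) -> float:
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def quantize(z: List[Sequence[float]],
             codebook: List[Sequence[float]],
             beta: float = 0.25) -> Tuple[List[List[float]], List[int], float]:
    """Hypothetical sketch: map each latent vector to its nearest code.

    Returns the quantized vectors, the chosen code indices, and the value of
    the standard VQ-VAE loss ||sg[z] - e||^2 + beta * ||z - sg[e]||^2.
    Without autograd both terms have the same value, so the loss equals
    (1 + beta) * mean squared error over all elements.
    """
    indices: List[int] = []
    quantized: List[List[float]] = []
    total = 0.0
    for vec in z:
        # Distance from this latent vector to every codebook entry.
        dists = [squared_distance(vec, e) for e in codebook]
        k = min(range(len(codebook)), key=dists.__getitem__)
        indices.append(k)
        quantized.append(list(codebook[k]))
        total += dists[k]
    # Mean over every element, as torch.mean would do on a (N, D) tensor.
    mse = total / (len(z) * len(codebook[0]))
    return quantized, indices, (1.0 + beta) * mse


codebook = [[0.0, 0.0], [1.0, 1.0], [-1.0, 2.0]]
q, idx, loss = quantize([[0.1, -0.1], [0.9, 1.2]], codebook)
print(idx, q, round(loss, 6))
```

In this model the same lookup would run twice with separate codebooks: once on identity features of width `embed_dim`, and once on the "others" features of width `embed_dim // 2`.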
def optimize_parameters(self, data, current_iter): 156 | self.encoder.train() 157 | self.decoder.train() 158 | self.quantize_identity.train() 159 | self.quant_conv_identity.train() 160 | self.post_quant_conv_identity.train() 161 | self.quantize_others.train() 162 | self.quant_conv_others.train() 163 | self.post_quant_conv_others.train() 164 | 165 | loss, d_loss = self.training_step(data, current_iter) 166 | self.optimizer.zero_grad() 167 | loss.backward() 168 | self.optimizer.step() 169 | 170 | if current_iter > self.disc_start_step: 171 | self.disc_optimizer.zero_grad() 172 | d_loss.backward() 173 | self.disc_optimizer.step() 174 | 175 | def feed_data(self, data): 176 | x_identity = data['identity_image'] 177 | x_frame_aug = data['frame_img_aug'] 178 | x_frame = data['frame_img'] 179 | 180 | return x_identity.float().to(self.device), x_frame_aug.float().to( 181 | self.device), x_frame.float().to(self.device) 182 | 183 | def encode(self, x_identity, x_frame): 184 | h_identity, _ = self.encoder(x_identity) 185 | h_identity = self.quant_conv_identity(h_identity) 186 | quant_identity, emb_loss_identity, _ = self.quantize_identity( 187 | h_identity) 188 | 189 | _, h_frame = self.encoder(x_frame) 190 | h_frame = self.quant_conv_others(h_frame) 191 | quant_frame, emb_loss_frame, _ = self.quantize_others(h_frame) 192 | return [quant_identity, 193 | quant_frame], emb_loss_identity + emb_loss_frame 194 | 195 | def decode(self, quant_list): 196 | quant_identity = self.post_quant_conv_identity(quant_list[0]) 197 | quant_frame = self.post_quant_conv_others(quant_list[1]) 198 | dec = self.decoder(quant_identity, quant_frame) 199 | return dec 200 | 201 | def forward_step(self, x_identity, x_frame_aug): 202 | quant_list, diff = self.encode(x_identity, x_frame_aug) 203 | dec = self.decode(quant_list) 204 | return dec, diff 205 | 206 | def training_step(self, data, step): 207 | x_identity, x_frame_aug, x_frame = self.feed_data(data) 208 | xrec, codebook_loss = 
self.forward_step(x_identity, x_frame_aug) 209 | 210 | # get recon/perceptual loss 211 | recon_loss = torch.abs(x_frame.contiguous() - xrec.contiguous()) 212 | p_loss = self.perceptual(x_frame.contiguous(), xrec.contiguous()) 213 | nll_loss = recon_loss + self.perceptual_weight * p_loss 214 | nll_loss = torch.mean(nll_loss) 215 | 216 | # augment for input to discriminator 217 | if self.diff_aug: 218 | xrec = DiffAugment(xrec, policy=self.policy) 219 | 220 | # update generator 221 | logits_fake = self.disc(xrec) 222 | g_loss = -torch.mean(logits_fake) 223 | last_layer = self.get_bare_model(self.decoder).conv_out.weight 224 | d_weight = calculate_adaptive_weight(nll_loss, g_loss, last_layer, 225 | self.disc_weight_max) 226 | d_weight *= adopt_weight(1, step, self.disc_start_step) 227 | loss = nll_loss + d_weight * g_loss + codebook_loss 228 | 229 | loss_dict = OrderedDict() 230 | 231 | loss_dict["loss"] = loss 232 | loss_dict["l1"] = recon_loss.mean() 233 | loss_dict["perceptual"] = p_loss.mean() 234 | loss_dict["nll_loss"] = nll_loss 235 | loss_dict["g_loss"] = g_loss 236 | loss_dict["d_weight"] = d_weight 237 | loss_dict["codebook_loss"] = codebook_loss 238 | 239 | if step > self.disc_start_step: 240 | if self.diff_aug: 241 | logits_real = self.disc( 242 | DiffAugment( 243 | x_frame.contiguous().detach(), policy=self.policy)) 244 | else: 245 | logits_real = self.disc(x_frame.contiguous().detach()) 246 | logits_fake = self.disc(xrec.contiguous().detach( 247 | )) # detach so that the generator isn't also updated 248 | d_loss, _, _ = hinge_d_loss(logits_real, logits_fake) 249 | loss_dict["d_loss"] = d_loss 250 | else: 251 | d_loss = None 252 | 253 | self.log_dict = self.reduce_loss_dict(loss_dict) 254 | 255 | return loss, d_loss 256 | 257 | @master_only 258 | def get_vis(self, x_identity, x_frame_aug, x_frame, xrec, save_dir, 259 | img_name): 260 | os.makedirs(save_dir, exist_ok=True) 261 | img_cat = torch.cat([x_identity, x_frame_aug, x_frame, xrec], 262 |
dim=3).detach() 263 | img_cat = ((img_cat + 1) / 2) 264 | img_cat = img_cat.clamp_(0, 1) 265 | save_image(img_cat, f'{save_dir}/{img_name}.png', nrow=1, padding=4) 266 | 267 | @torch.no_grad() 268 | def inference(self, data_loader, save_dir): 269 | self.encoder.eval() 270 | self.decoder.eval() 271 | self.quantize_identity.eval() 272 | self.quant_conv_identity.eval() 273 | self.post_quant_conv_identity.eval() 274 | self.quantize_others.eval() 275 | self.quant_conv_others.eval() 276 | self.post_quant_conv_others.eval() 277 | 278 | loss_total = 0 279 | num = 0 280 | 281 | for _, data in enumerate(data_loader): 282 | img_name = data['video_name'][0] 283 | x_identity, x_frame_aug, x_frame = self.feed_data(data) 284 | xrec, _ = self.forward_step(x_identity, x_frame_aug) 285 | 286 | recon_loss = torch.abs(x_frame.contiguous() - xrec.contiguous()) 287 | p_loss = self.perceptual(x_frame.contiguous(), xrec.contiguous()) 288 | nll_loss = recon_loss + self.perceptual_weight * p_loss 289 | nll_loss = torch.mean(nll_loss) 290 | loss_total += nll_loss 291 | 292 | num += x_frame.size(0) 293 | 294 | self.get_vis(x_identity, x_frame_aug, x_frame, xrec, save_dir, 295 | img_name) 296 | 297 | return (loss_total / num).item() * (-1) 298 | -------------------------------------------------------------------------------- /train_dist.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import math 4 | import os 5 | import os.path as osp 6 | import random 7 | import time 8 | 9 | import torch 10 | 11 | from data import create_dataloader, create_dataset 12 | from data.data_sampler import EnlargedSampler 13 | from data.prefetch_dataloader import CPUPrefetcher, CUDAPrefetcher 14 | from models import create_model 15 | from utils.dist_util import get_dist_info, init_dist 16 | from utils.logger import MessageLogger, get_root_logger, init_tb_logger 17 | from utils.options import dict2str, dict_to_nonedict, parse 18 | from 
utils.util import make_exp_dirs, set_random_seed 19 | 20 | 21 | def get_dataloader(opt, logger): 22 | # create train, test, val dataloaders 23 | train_loader, val_loader, test_loader = None, None, None 24 | dataset_enlarge_ratio = opt.get('dataset_enlarge_ratio', 1) 25 | train_set = create_dataset(opt['datasets']['train']) 26 | opt['max_iters'] = opt['num_epochs'] * len(train_set) // ( 27 | opt['batch_size_per_gpu'] * opt['num_gpu']) 28 | logger.info(f'Number of train set: {len(train_set)}.') 29 | train_sampler = EnlargedSampler(train_set, opt['world_size'], opt['rank'], 30 | dataset_enlarge_ratio) 31 | train_loader = create_dataloader( 32 | train_set, 33 | opt, 34 | phase='train', 35 | num_gpu=opt['num_gpu'], 36 | dist=opt['dist'], 37 | sampler=train_sampler, 38 | seed=opt['manual_seed']) 39 | 40 | val_set = create_dataset(opt['datasets']['val']) 41 | logger.info(f'Number of val set: {len(val_set)}.') 42 | val_loader = create_dataloader( 43 | val_set, 44 | opt, 45 | phase='val', 46 | num_gpu=opt['num_gpu'], 47 | dist=opt['dist'], 48 | sampler=None, 49 | seed=opt['manual_seed']) 50 | 51 | test_set = create_dataset(opt['datasets']['test']) 52 | logger.info(f'Number of test set: {len(test_set)}.') 53 | test_loader = create_dataloader( 54 | test_set, 55 | opt, 56 | phase='test', 57 | num_gpu=opt['num_gpu'], 58 | dist=opt['dist'], 59 | sampler=None, 60 | seed=opt['manual_seed']) 61 | 62 | return train_loader, train_sampler, val_loader, test_loader 63 | 64 | 65 | def main(): 66 | # options 67 | parser = argparse.ArgumentParser() 68 | parser.add_argument('-opt', type=str, help='Path to option YAML file.') 69 | parser.add_argument( 70 | '--launcher', choices=['none', 'pytorch', 'slurm'], default='none') 71 | parser.add_argument('--local_rank', type=int, default=0) 72 | args = parser.parse_args() 73 | opt = parse(args.opt, is_train=True) 74 | 75 | # distributed settings 76 | if args.launcher == 'none': 77 | opt['dist'] = False 78 | print('Disable distributed.', flush=True) 
79 | else: 80 | opt['dist'] = True 81 | if args.launcher == 'slurm' and 'dist_params' in opt: 82 | init_dist(args.launcher, **opt['dist_params']) 83 | else: 84 | init_dist(args.launcher) 85 | 86 | opt['rank'], opt['world_size'] = get_dist_info() 87 | 88 | # mkdir and loggers 89 | if opt['rank'] == 0: 90 | make_exp_dirs(opt) 91 | log_file = osp.join(opt['path']['log'], f"train_{opt['name']}.log") 92 | logger = get_root_logger( 93 | logger_name='base', log_level=logging.INFO, log_file=log_file) 94 | logger.info(dict2str(opt)) 95 | 96 | # initialize tensorboard logger 97 | tb_logger = None 98 | if opt['use_tb_logger'] and 'debug' not in opt['name'] and opt['rank'] == 0: 99 | tb_logger = init_tb_logger(log_dir='./tb_logger/' + opt['name']) 100 | 101 | # random seed 102 | seed = opt['manual_seed'] 103 | if seed is None: 104 | seed = random.randint(1, 10000) 105 | logger.info(f'Random seed: {seed}') 106 | set_random_seed(seed + opt['rank']) 107 | 108 | torch.backends.cudnn.benchmark = True 109 | # torch.backends.cudnn.deterministic = True 110 | 111 | # convert to NoneDict, which returns None for missing keys 112 | opt = dict_to_nonedict(opt) 113 | 114 | # set up data loader 115 | train_loader, train_sampler, val_loader, test_loader = get_dataloader( 116 | opt, logger) 117 | 118 | # dataloader prefetcher 119 | prefetch_mode = opt.get('prefetch_mode') 120 | if prefetch_mode is None or prefetch_mode == 'cpu': 121 | prefetcher = CPUPrefetcher(train_loader) 122 | elif prefetch_mode == 'cuda': 123 | prefetcher = CUDAPrefetcher(train_loader, opt) 124 | logger.info(f'Use {prefetch_mode} prefetch dataloader') 125 | if opt['datasets']['train'].get('pin_memory') is not True: 126 | raise ValueError('Please set pin_memory=True for CUDAPrefetcher.') 127 | else: 128 | raise ValueError(f'Wrong prefetch_mode {prefetch_mode}. '
129 | "Supported ones are: None, 'cuda', 'cpu'.") 130 | 131 | current_iter = 0 132 | best_epoch = None 133 | best_acc = -100 134 | 135 | model = create_model(opt) 136 | 137 | data_time, iter_time = 0, 0 138 | current_iter = 0 139 | 140 | # create message logger (formatted outputs) 141 | msg_logger = MessageLogger(opt, current_iter, tb_logger) 142 | 143 | for epoch in range(opt['num_epochs']): 144 | train_sampler.set_epoch(epoch) 145 | prefetcher.reset() 146 | train_data = prefetcher.next() 147 | 148 | lr = model.update_learning_rate(epoch) 149 | 150 | while train_data is not None: 151 | data_time = time.time() - data_time 152 | 153 | current_iter += 1 154 | 155 | model.feed_data(train_data) 156 | model.optimize_parameters() 157 | 158 | iter_time = time.time() - iter_time 159 | if current_iter % opt['print_freq'] == 0: 160 | log_vars = {'epoch': (epoch + 1), 'iter': current_iter} 161 | log_vars.update({'lrs': [lr]}) 162 | log_vars.update({'time': iter_time, 'data_time': data_time}) 163 | log_vars.update(model.get_current_log()) 164 | msg_logger(log_vars) 165 | 166 | data_time = time.time() 167 | iter_time = time.time() 168 | train_data = prefetcher.next() 169 | 170 | if (epoch + 1) % opt['val_freq'] == 0: 171 | save_dir = f'{opt["path"]["visualization"]}/valset/epoch_{(epoch + 1):03d}' # noqa 172 | val_acc = model.inference(val_loader, f'{save_dir}/inference') 173 | model.sample_multinomial(val_loader, f'{save_dir}/sample') 174 | 175 | save_dir = f'{opt["path"]["visualization"]}/testset/epoch_{(epoch + 1):03d}' # noqa 176 | test_acc = model.inference(test_loader, f'{save_dir}/inference') 177 | model.sample_multinomial(test_loader, f'{save_dir}/sample') 178 | 179 | logger.info(f'Epoch: {(epoch + 1)}, ' 180 | f'val_acc: {val_acc: .4f}, ' 181 | f'test_acc: {test_acc: .4f}.') 182 | 183 | if test_acc > best_acc: 184 | best_epoch = (epoch + 1) 185 | best_acc = test_acc 186 | 187 | logger.info(f'Best epoch: {best_epoch}, ' 188 | f'Best test acc: {best_acc: .4f}.') 189 | 
190 | # save model 191 | model.save_network( 192 | model.sampler, 193 | f'{opt["path"]["models"]}/epoch_{(epoch + 1)}.pth') 194 | 195 | # torch.cuda.empty_cache() 196 | 197 | 198 | if __name__ == '__main__': 199 | main() 200 | -------------------------------------------------------------------------------- /train_sampler.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | import os.path as osp 5 | import random 6 | import time 7 | 8 | import torch 9 | 10 | from data.sample_identity_dataset import SampleIdentityDataset 11 | from models import create_model 12 | from utils.logger import MessageLogger, get_root_logger, init_tb_logger 13 | from utils.options import dict2str, dict_to_nonedict, parse 14 | from utils.util import make_exp_dirs, set_random_seed 15 | 16 | 17 | def main(): 18 | # options 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument('-opt', type=str, help='Path to option YAML file.') 21 | args = parser.parse_args() 22 | opt = parse(args.opt, is_train=True) 23 | 24 | # mkdir and loggers 25 | make_exp_dirs(opt) 26 | log_file = osp.join(opt['path']['log'], f"train_{opt['name']}.log") 27 | logger = get_root_logger( 28 | logger_name='base', log_level=logging.INFO, log_file=log_file) 29 | logger.info(dict2str(opt)) 30 | # initialize tensorboard logger 31 | tb_logger = None 32 | if opt['use_tb_logger'] and 'debug' not in opt['name']: 33 | tb_logger = init_tb_logger(log_dir='./tb_logger/' + opt['name']) 34 | 35 | # random seed 36 | seed = opt['manual_seed'] 37 | if seed is None: 38 | seed = random.randint(1, 10000) 39 | logger.info(f'Random seed: {seed}') 40 | set_random_seed(seed) 41 | 42 | # convert to NoneDict, which returns None for missing keys 43 | opt = dict_to_nonedict(opt) 44 | 45 | # set up data loader 46 | train_dataset = SampleIdentityDataset(opt['datasets']['train']) 47 | train_loader = torch.utils.data.DataLoader( 48 | dataset=train_dataset, 49 | 
batch_size=opt['batch_size'], 50 | shuffle=True, 51 | num_workers=opt['num_workers'], 52 | persistent_workers=True, 53 | drop_last=True) 54 | logger.info(f'Number of train set: {len(train_dataset)}.') 55 | opt['max_iters'] = opt['num_epochs'] * len( 56 | train_dataset) // opt['batch_size'] 57 | 58 | val_dataset = SampleIdentityDataset(opt['datasets']['val']) 59 | val_loader = torch.utils.data.DataLoader( 60 | dataset=val_dataset, batch_size=opt['batch_size'], shuffle=False) 61 | logger.info(f'Number of val set: {len(val_dataset)}.') 62 | 63 | test_dataset = SampleIdentityDataset(opt['datasets']['test']) 64 | test_loader = torch.utils.data.DataLoader( 65 | dataset=test_dataset, batch_size=opt['batch_size'], shuffle=False) 66 | logger.info(f'Number of test set: {len(test_dataset)}.') 67 | 68 | current_iter = 0 69 | 70 | model = create_model(opt) 71 | 72 | data_time, iter_time = 0, 0 73 | current_iter = 0 74 | 75 | # create message logger (formatted outputs) 76 | msg_logger = MessageLogger(opt, current_iter, tb_logger) 77 | 78 | for epoch in range(opt['num_epochs']): 79 | lr = model.update_learning_rate(epoch, current_iter) 80 | 81 | for _, batch_data in enumerate(train_loader): 82 | data_time = time.time() - data_time 83 | 84 | current_iter += 1 85 | 86 | model.feed_data(batch_data) 87 | model.optimize_parameters() 88 | 89 | iter_time = time.time() - iter_time 90 | if current_iter % opt['print_freq'] == 0: 91 | log_vars = {'epoch': epoch, 'iter': current_iter} 92 | log_vars.update({'lrs': [lr]}) 93 | log_vars.update({'time': iter_time, 'data_time': data_time}) 94 | log_vars.update(model.get_current_log()) 95 | msg_logger(log_vars) 96 | 97 | data_time = time.time() 98 | iter_time = time.time() 99 | 100 | if epoch % opt['val_freq'] == 0 and epoch != 0: 101 | save_dir = f'{opt["path"]["visualization"]}/valset/epoch_{epoch:03d}' # noqa 102 | os.makedirs(save_dir, exist_ok=opt['debug']) 103 | model.inference(val_loader, save_dir) 104 | 105 | save_dir = 
f'{opt["path"]["visualization"]}/testset/epoch_{epoch:03d}' # noqa 106 | os.makedirs(save_dir, exist_ok=opt['debug']) 107 | model.inference(test_loader, save_dir) 108 | 109 | # save model 110 | model.save_network( 111 | model._denoise_fn, 112 | f'{opt["path"]["models"]}/sampler_epoch{epoch}.pth') 113 | 114 | 115 | if __name__ == '__main__': 116 | main() 117 | -------------------------------------------------------------------------------- /train_vqvae_iter_dist.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os.path as osp 4 | import random 5 | import time 6 | 7 | import torch 8 | 9 | from data import create_dataloader, create_dataset 10 | from data.data_sampler import EnlargedSampler 11 | from data.prefetch_dataloader import CPUPrefetcher, CUDAPrefetcher 12 | from models import create_model 13 | from utils.dist_util import get_dist_info, init_dist 14 | from utils.logger import MessageLogger, get_root_logger, init_tb_logger 15 | from utils.options import dict2str, dict_to_nonedict, parse 16 | from utils.util import make_exp_dirs, set_random_seed 17 | 18 | 19 | def get_dataloader(opt, logger): 20 | # create train, test, val dataloaders 21 | train_loader, val_loader, test_loader = None, None, None 22 | dataset_enlarge_ratio = opt.get('dataset_enlarge_ratio', 1) 23 | train_set = create_dataset(opt['datasets']['train']) 24 | opt['max_iters'] = opt['num_epochs'] * len(train_set) // ( 25 | opt['batch_size_per_gpu'] * opt['num_gpu']) 26 | logger.info(f'Number of train set: {len(train_set)}.') 27 | train_sampler = EnlargedSampler(train_set, opt['world_size'], opt['rank'], 28 | dataset_enlarge_ratio) 29 | train_loader = create_dataloader( 30 | train_set, 31 | opt, 32 | phase='train', 33 | num_gpu=opt['num_gpu'], 34 | dist=opt['dist'], 35 | sampler=train_sampler, 36 | seed=opt['manual_seed']) 37 | 38 | val_set = create_dataset(opt['datasets']['val']) 39 | logger.info(f'Number of val set: 
{len(val_set)}.') 40 | val_loader = create_dataloader( 41 | val_set, 42 | opt, 43 | phase='val', 44 | num_gpu=opt['num_gpu'], 45 | dist=opt['dist'], 46 | sampler=None, 47 | seed=opt['manual_seed']) 48 | 49 | test_set = create_dataset(opt['datasets']['test']) 50 | logger.info(f'Number of test set: {len(test_set)}.') 51 | test_loader = create_dataloader( 52 | test_set, 53 | opt, 54 | phase='test', 55 | num_gpu=opt['num_gpu'], 56 | dist=opt['dist'], 57 | sampler=None, 58 | seed=opt['manual_seed']) 59 | 60 | return train_loader, train_sampler, val_loader, test_loader 61 | 62 | 63 | def main(): 64 | # options 65 | parser = argparse.ArgumentParser() 66 | parser.add_argument('-opt', type=str, help='Path to option YAML file.') 67 | parser.add_argument( 68 | '--launcher', choices=['none', 'pytorch', 'slurm'], default='none') 69 | parser.add_argument('--local_rank', type=int, default=0) 70 | args = parser.parse_args() 71 | opt = parse(args.opt, is_train=True) 72 | 73 | # distributed settings 74 | if args.launcher == 'none': 75 | opt['dist'] = False 76 | print('Disable distributed.', flush=True) 77 | else: 78 | opt['dist'] = True 79 | if args.launcher == 'slurm' and 'dist_params' in opt: 80 | init_dist(args.launcher, **opt['dist_params']) 81 | else: 82 | init_dist(args.launcher) 83 | 84 | opt['rank'], opt['world_size'] = get_dist_info() 85 | 86 | # mkdir and loggers 87 | if opt['rank'] == 0: 88 | make_exp_dirs(opt) 89 | log_file = osp.join(opt['path']['log'], f"train_{opt['name']}.log") 90 | logger = get_root_logger( 91 | logger_name='base', log_level=logging.INFO, log_file=log_file) 92 | logger.info(dict2str(opt)) 93 | 94 | # initialize tensorboard logger 95 | tb_logger = None 96 | if opt['use_tb_logger'] and 'debug' not in opt['name'] and opt['rank'] == 0: 97 | tb_logger = init_tb_logger(log_dir='./tb_logger/' + opt['name']) 98 | 99 | # random seed 100 | seed = opt['manual_seed'] 101 | if seed is None: 102 | seed = random.randint(1, 10000) 103 | logger.info(f'Random seed: 
{seed}') 104 | set_random_seed(seed + opt['rank']) 105 | 106 | torch.backends.cudnn.benchmark = True 107 | # torch.backends.cudnn.deterministic = True 108 | 109 | # convert to NoneDict, which returns None for missing keys 110 | opt = dict_to_nonedict(opt) 111 | 112 | # set up data loader 113 | train_loader, train_sampler, val_loader, test_loader = get_dataloader( 114 | opt, logger) 115 | 116 | # dataloader prefetcher 117 | prefetch_mode = opt.get('prefetch_mode') 118 | if prefetch_mode is None or prefetch_mode == 'cpu': 119 | prefetcher = CPUPrefetcher(train_loader) 120 | elif prefetch_mode == 'cuda': 121 | prefetcher = CUDAPrefetcher(train_loader, opt) 122 | logger.info(f'Use {prefetch_mode} prefetch dataloader') 123 | if opt['datasets']['train'].get('pin_memory') is not True: 124 | raise ValueError('Please set pin_memory=True for CUDAPrefetcher.') 125 | else: 126 | raise ValueError(f'Wrong prefetch_mode {prefetch_mode}. ' 127 | "Supported ones are: None, 'cuda', 'cpu'.") 128 | 129 | current_iter = 0 130 | best_iter = None 131 | best_acc = -100 132 | 133 | model = create_model(opt) 134 | 135 | data_time, iter_time = 0, 0 136 | current_iter = 0 137 | 138 | # create message logger (formatted outputs) 139 | msg_logger = MessageLogger(opt, current_iter, tb_logger) 140 | 141 | for epoch in range(opt['num_epochs']): 142 | train_sampler.set_epoch(epoch) 143 | prefetcher.reset() 144 | train_data = prefetcher.next() 145 | 146 | lr = model.update_learning_rate(epoch) 147 | 148 | while train_data is not None: 149 | data_time = time.time() - data_time 150 | 151 | current_iter += 1 152 | 153 | model.optimize_parameters(train_data, current_iter) 154 | 155 | iter_time = time.time() - iter_time 156 | if current_iter % opt['print_freq'] == 0: 157 | log_vars = {'epoch': (epoch + 1), 'iter': current_iter} 158 | log_vars.update({'lrs': [lr]}) 159 | log_vars.update({'time': iter_time, 'data_time': data_time}) 160 | log_vars.update(model.get_current_log()) 161 | msg_logger(log_vars)
162 | 163 | data_time = time.time() 164 | iter_time = time.time() 165 | train_data = prefetcher.next() 166 | 167 | if (current_iter + 1) % opt['val_freq'] == 0: 168 | save_dir = f'{opt["path"]["visualization"]}/valset/iter_{(current_iter + 1):03d}' # noqa 169 | val_acc = model.inference(val_loader, f'{save_dir}/inference') 170 | 171 | save_dir = f'{opt["path"]["visualization"]}/testset/iter_{(current_iter + 1):03d}' # noqa 172 | test_acc = model.inference(test_loader, 173 | f'{save_dir}/inference') 174 | 175 | logger.info(f'Iter: {(current_iter + 1)}, ' 176 | f'val_acc: {val_acc: .4f}, ' 177 | f'test_acc: {test_acc: .4f}.') 178 | 179 | if test_acc > best_acc: 180 | best_iter = (current_iter + 1) 181 | best_acc = test_acc 182 | 183 | logger.info(f'Best iter: {best_iter}, ' 184 | f'Best test acc: {best_acc: .4f}.') 185 | 186 | # save model 187 | model.save_network( 188 | f'{opt["path"]["models"]}/iter_{(current_iter + 1)}.pth') 189 | 190 | torch.cuda.empty_cache() 191 | 192 | 193 | if __name__ == '__main__': 194 | main() 195 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yumingj/Text2Performer/433489aae7bdd6fd868a2e272d98bed7c7b642a5/utils/__init__.py -------------------------------------------------------------------------------- /utils/dist_util.py: -------------------------------------------------------------------------------- 1 | # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py # noqa: E501 2 | import functools 3 | import os 4 | import subprocess 5 | 6 | import torch 7 | import torch.distributed as dist 8 | import torch.multiprocessing as mp 9 | 10 | 11 | def init_dist(launcher, backend='nccl', **kwargs): 12 | if mp.get_start_method(allow_none=True) is None: 13 | mp.set_start_method('spawn') 14 | if launcher == 'pytorch': 15 | _init_dist_pytorch(backend, 
**kwargs) 16 | elif launcher == 'slurm': 17 | _init_dist_slurm(backend, **kwargs) 18 | else: 19 | raise ValueError(f'Invalid launcher type: {launcher}') 20 | 21 | 22 | def _init_dist_pytorch(backend, **kwargs): 23 | rank = int(os.environ['RANK']) 24 | num_gpus = torch.cuda.device_count() 25 | torch.cuda.set_device(rank % num_gpus) 26 | dist.init_process_group(backend=backend, **kwargs) 27 | 28 | 29 | def _init_dist_slurm(backend, port=None): 30 | """Initialize slurm distributed training environment. 31 | 32 | If argument ``port`` is not specified, then the master port will be system 33 | environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system 34 | environment variable, then a default port ``29500`` will be used. 35 | 36 | Args: 37 | backend (str): Backend of torch.distributed. 38 | port (int, optional): Master port. Defaults to None. 39 | """ 40 | proc_id = int(os.environ['SLURM_PROCID']) 41 | ntasks = int(os.environ['SLURM_NTASKS']) 42 | node_list = os.environ['SLURM_NODELIST'] 43 | num_gpus = torch.cuda.device_count() 44 | torch.cuda.set_device(proc_id % num_gpus) 45 | addr = subprocess.getoutput( 46 | f'scontrol show hostname {node_list} | head -n1') 47 | # specify master port 48 | if port is not None: 49 | os.environ['MASTER_PORT'] = str(port) 50 | elif 'MASTER_PORT' in os.environ: 51 | pass # use MASTER_PORT in the environment variable 52 | else: 53 | # 29500 is torch.distributed default port 54 | os.environ['MASTER_PORT'] = '29500' 55 | os.environ['MASTER_ADDR'] = addr 56 | os.environ['WORLD_SIZE'] = str(ntasks) 57 | os.environ['LOCAL_RANK'] = str(proc_id % num_gpus) 58 | os.environ['RANK'] = str(proc_id) 59 | dist.init_process_group(backend=backend) 60 | 61 | 62 | def get_dist_info(): 63 | if dist.is_available(): 64 | initialized = dist.is_initialized() 65 | else: 66 | initialized = False 67 | if initialized: 68 | rank = dist.get_rank() 69 | world_size = dist.get_world_size() 70 | else: 71 | rank = 0 72 | world_size = 1 73 | return rank, 
world_size 74 | 75 | 76 | def master_only(func): 77 | 78 | @functools.wraps(func) 79 | def wrapper(*args, **kwargs): 80 | rank, _ = get_dist_info() 81 | if rank == 0: 82 | return func(*args, **kwargs) 83 | 84 | return wrapper 85 | -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import time 4 | 5 | from .dist_util import get_dist_info, master_only 6 | 7 | 8 | class MessageLogger(): 9 | """Message logger for printing. 10 | 11 | Args: 12 | opt (dict): Config. It contains the following keys: 13 | name (str): Exp name. 14 | print_freq (int): Logging interval in iters. 15 | max_iters (int): Total number of iters, used for the ETA. 16 | use_tb_logger (bool): Use tensorboard logger. 17 | start_iter (int): Start iter. Default: 1. 18 | tb_logger (obj:`tb_logger`): Tensorboard logger. Default: None. 19 | """ 20 | 21 | def __init__(self, opt, start_iter=1, tb_logger=None): 22 | self.exp_name = opt['name'] 23 | self.interval = opt['print_freq'] 24 | self.start_iter = start_iter 25 | self.max_iters = opt['max_iters'] 26 | self.use_tb_logger = opt['use_tb_logger'] 27 | self.tb_logger = tb_logger 28 | self.start_time = time.time() 29 | self.logger = get_root_logger() 30 | 31 | @master_only 32 | def __call__(self, log_vars): 33 | """Format logging message. 34 | 35 | Args: 36 | log_vars (dict): It contains the following keys: 37 | epoch (int): Epoch number. 38 | iter (int): Current iter. 39 | lrs (list): List for learning rates. 40 | 41 | time (float): Iter time. 42 | data_time (float): Data time for each iter.
43 | """ 44 | # epoch, iter, learning rates 45 | epoch = log_vars.pop('epoch') 46 | current_iter = log_vars.pop('iter') 47 | lrs = log_vars.pop('lrs') 48 | 49 | message = (f'[{self.exp_name[:5]}..][epoch:{epoch:3d}, ' 50 | f'iter:{current_iter:8,d}, lr:(') 51 | for v in lrs: 52 | message += f'{v:.3e},' 53 | message += ')] ' 54 | 55 | # time and estimated time 56 | if 'time' in log_vars.keys(): 57 | iter_time = log_vars.pop('time') 58 | data_time = log_vars.pop('data_time') 59 | 60 | total_time = time.time() - self.start_time 61 | time_sec_avg = total_time / (current_iter - self.start_iter + 1) 62 | eta_sec = time_sec_avg * (self.max_iters - current_iter - 1) 63 | eta_str = str(datetime.timedelta(seconds=int(eta_sec))) 64 | message += f'[eta: {eta_str}, ' 65 | message += f'time: {iter_time:.3f}, data_time: {data_time:.3f}] ' 66 | 67 | # other items, especially losses 68 | for k, v in log_vars.items(): 69 | message += f'{k}: {v:.4e} ' 70 | # tensorboard logger 71 | if self.use_tb_logger and 'debug' not in self.exp_name: 72 | self.tb_logger.add_scalar(k, v, current_iter) 73 | 74 | self.logger.info(message) 75 | 76 | 77 | @master_only 78 | def init_tb_logger(log_dir): 79 | from torch.utils.tensorboard import SummaryWriter 80 | tb_logger = SummaryWriter(log_dir=log_dir) 81 | return tb_logger 82 | 83 | 84 | def get_root_logger(logger_name='base', log_level=logging.INFO, log_file=None): 85 | """Get the root logger. 86 | 87 | The logger will be initialized if it has not been initialized. By default a 88 | StreamHandler will be added. If `log_file` is specified, a FileHandler will 89 | also be added. 90 | 91 | Args: 92 | logger_name (str): root logger name. Default: base. 93 | log_file (str | None): The log filename. If specified, a FileHandler 94 | will be added to the root logger. 95 | log_level (int): The root logger level. Note that only the process of 96 | rank 0 is affected, while other processes will set the level to 97 | "Error" and be silent most of the time. 
98 | 99 | Returns: 100 | logging.Logger: The root logger. 101 | """ 102 | logger = logging.getLogger(logger_name) 103 | # if the logger has been initialized, just return it 104 | if logger.hasHandlers(): 105 | return logger 106 | 107 | format_str = '%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s' 108 | logging.basicConfig(format=format_str, level=log_level) 109 | rank, _ = get_dist_info() 110 | if rank != 0: 111 | logger.setLevel('ERROR') 112 | elif log_file is not None: 113 | file_handler = logging.FileHandler(log_file, 'w') 114 | file_handler.setFormatter(logging.Formatter(format_str)) 115 | file_handler.setLevel(log_level) 116 | logger.addHandler(file_handler) 117 | 118 | return logger 119 | -------------------------------------------------------------------------------- /utils/options.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | from collections import OrderedDict 4 | 5 | import yaml 6 | 7 | 8 | def ordered_yaml(): 9 | """Support OrderedDict for yaml. 10 | 11 | Returns: 12 | yaml Loader and Dumper. 13 | """ 14 | try: 15 | from yaml import CDumper as Dumper 16 | from yaml import CLoader as Loader 17 | except ImportError: 18 | from yaml import Dumper, Loader 19 | 20 | _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG 21 | 22 | def dict_representer(dumper, data): 23 | return dumper.represent_dict(data.items()) 24 | 25 | def dict_constructor(loader, node): 26 | return OrderedDict(loader.construct_pairs(node)) 27 | 28 | Dumper.add_representer(OrderedDict, dict_representer) 29 | Loader.add_constructor(_mapping_tag, dict_constructor) 30 | return Loader, Dumper 31 | 32 | 33 | def parse(opt_path, is_train=True): 34 | """Parse option file. 35 | 36 | Args: 37 | opt_path (str): Option file path. 38 | is_train (bool): Whether the options are for training. Default: True. 39 | 40 | Returns: 41 | (dict): Options.
42 | """ 43 | with open(opt_path, mode='r') as f: 44 | Loader, _ = ordered_yaml() 45 | opt = yaml.load(f, Loader=Loader) 46 | 47 | # gpu_list = ','.join(str(x) for x in opt['gpu_ids']) 48 | # if opt.get('set_CUDA_VISIBLE_DEVICES', None): 49 | # os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list 50 | # print('export CUDA_VISIBLE_DEVICES=' + gpu_list, flush=True) 51 | # else: 52 | # print('gpu_list: ', gpu_list, flush=True) 53 | 54 | opt['is_train'] = is_train 55 | 56 | # datasets 57 | if 'datasets' in opt.keys(): 58 | for phase, dataset in opt['datasets'].items(): 59 | # for several datasets, e.g., test_1, test_2 60 | phase = phase.split('_')[0] 61 | dataset['phase'] = phase 62 | 63 | # paths 64 | opt['path'] = {} 65 | opt['path']['root'] = osp.abspath( 66 | osp.join(__file__, osp.pardir, osp.pardir)) 67 | if is_train: 68 | experiments_root = osp.join(opt['path']['root'], 'experiments', 69 | opt['name']) 70 | opt['path']['experiments_root'] = experiments_root 71 | opt['path']['models'] = osp.join(experiments_root, 'models') 72 | opt['path']['log'] = experiments_root 73 | opt['path']['visualization'] = osp.join(experiments_root, 74 | 'visualization') 75 | 76 | # change some options for debug mode 77 | if 'debug' in opt['name']: 78 | opt['debug'] = True 79 | opt['val_freq'] = 1 80 | opt['print_freq'] = 1 81 | opt['save_checkpoint_freq'] = 1 82 | else: # test 83 | results_root = osp.join(opt['path']['root'], 'results', opt['name']) 84 | opt['path']['results_root'] = results_root 85 | opt['path']['log'] = results_root 86 | opt['path']['visualization'] = osp.join(results_root, 'visualization') 87 | 88 | return opt 89 | 90 | 91 | def dict2str(opt, indent_level=1): 92 | """Convert an option dict to an indented string for printing. 93 | 94 | Args: 95 | opt (dict): Option dict. 96 | indent_level (int): Indent level. Default: 1. 97 | 98 | Returns: 99 | (str): Option string for printing.
100 | """ 101 | msg = '' 102 | for k, v in opt.items(): 103 | if isinstance(v, dict): 104 | msg += ' ' * (indent_level * 2) + k + ':[\n' 105 | msg += dict2str(v, indent_level + 1) 106 | msg += ' ' * (indent_level * 2) + ']\n' 107 | else: 108 | msg += ' ' * (indent_level * 2) + k + ': ' + str(v) + '\n' 109 | return msg 110 | 111 | 112 | class NoneDict(dict): 113 | """Dict that returns None for missing keys instead of raising KeyError.""" 114 | 115 | def __missing__(self, key): 116 | return None 117 | 118 | 119 | def dict_to_nonedict(opt): 120 | """Convert to NoneDict, which returns None for missing keys. 121 | 122 | Args: 123 | opt (dict): Option dict. Nested dicts, including dicts inside lists, are converted recursively. 124 | 125 | Returns: 126 | (dict): NoneDict for options. 127 | """ 128 | if isinstance(opt, dict): 129 | new_opt = dict() 130 | for key, sub_opt in opt.items(): 131 | new_opt[key] = dict_to_nonedict(sub_opt) 132 | return NoneDict(**new_opt) 133 | elif isinstance(opt, list): 134 | return [dict_to_nonedict(sub_opt) for sub_opt in opt] 135 | else: 136 | return opt 137 | -------------------------------------------------------------------------------- /utils/util.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import random 4 | import sys 5 | import time 6 | from shutil import get_terminal_size 7 | 8 | import numpy as np 9 | import torch 10 | 11 | logger = logging.getLogger('base') 12 | 13 | 14 | def make_exp_dirs(opt): 15 | """Make dirs for experiments.""" 16 | path_opt = opt['path'].copy() 17 | if opt['is_train']: 18 | overwrite = True if 'debug' in opt['name'] else False 19 | os.makedirs(path_opt.pop('experiments_root'), exist_ok=overwrite) 20 | os.makedirs(path_opt.pop('models'), exist_ok=overwrite) 21 | else: 22 | os.makedirs(path_opt.pop('results_root')) 23 | 24 | 25 | def set_random_seed(seed): 26 | """Set random seeds.""" 27 | random.seed(seed) 28 | np.random.seed(seed) 29 | torch.manual_seed(seed) 30 | torch.cuda.manual_seed(seed) 31 |
torch.cuda.manual_seed_all(seed) 32 | 33 | 34 | class ProgressBar(object): 35 | """A progress bar which can print the progress. 36 | 37 | Modified from: 38 | https://github.com/hellock/cvbase/blob/master/cvbase/progress.py 39 | """ 40 | 41 | def __init__(self, task_num=0, bar_width=50, start=True): 42 | self.task_num = task_num 43 | max_bar_width = self._get_max_bar_width() 44 | self.bar_width = ( 45 | bar_width if bar_width <= max_bar_width else max_bar_width) 46 | self.completed = 0 47 | if start: 48 | self.start() 49 | 50 | def _get_max_bar_width(self): 51 | terminal_width, _ = get_terminal_size() 52 | max_bar_width = min(int(terminal_width * 0.6), terminal_width - 50) 53 | if max_bar_width < 10: 54 | print(f'terminal width is too small ({terminal_width}), ' 55 | 'please consider widening the terminal for better ' 56 | 'progressbar visualization') 57 | max_bar_width = 10 58 | return max_bar_width 59 | 60 | def start(self): 61 | if self.task_num > 0: 62 | sys.stdout.write(f"[{' ' * self.bar_width}] 0/{self.task_num}, " 63 | f'elapsed: 0s, ETA:\nStart...\n') 64 | else: 65 | sys.stdout.write('completed: 0, elapsed: 0s') 66 | sys.stdout.flush() 67 | self.start_time = time.time() 68 | 69 | def update(self, msg='In progress...'): 70 | self.completed += 1 71 | elapsed = time.time() - self.start_time 72 | fps = self.completed / elapsed if elapsed > 0 else 0.0 # guard against a zero elapsed time 73 | if self.task_num > 0: 74 | percentage = self.completed / float(self.task_num) 75 | eta = int(elapsed * (1 - percentage) / percentage + 0.5) 76 | mark_width = int(self.bar_width * percentage) 77 | bar_chars = '>' * mark_width + '-' * (self.bar_width - mark_width) 78 | sys.stdout.write('\033[2F') # cursor up 2 lines 79 | sys.stdout.write( 80 | '\033[J' 81 | ) # clean the output (remove extra chars since last display) 82 | sys.stdout.write( 83 | f'[{bar_chars}] {self.completed}/{self.task_num}, ' 84 | f'{fps:.1f} tasks/s, elapsed: {int(elapsed + 0.5)}s, ' 85 | f'ETA: {eta:5}s\n{msg}\n') 86 | else: 87 | sys.stdout.write( 88 | f'completed: 
{self.completed}, elapsed: {int(elapsed + 0.5)}s, ' 89 | f'{fps:.1f} tasks/s') 90 | sys.stdout.flush() 91 | 92 | 93 | class AverageMeter(object): 94 | """ 95 | Computes and stores the average and current value. 96 | Imported from 97 | https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262 98 | """ 99 | 100 | def __init__(self): 101 | self.reset() 102 | 103 | def reset(self): 104 | self.val = 0 105 | self.avg = 0 # running average = running sum / running count 106 | self.sum = 0 # running sum 107 | self.count = 0 # running count 108 | 109 | def update(self, val, n=1): 110 | # n: number of samples in the batch (e.g. batch_size) 111 | 112 | # val: mean value over the batch (e.g. a loss or an accuracy) 113 | self.val = val 114 | 115 | # running sum, weighted by the number of samples 116 | self.sum += val * n 117 | 118 | # total number of samples seen so far 119 | self.count += n 120 | 121 | # sample-weighted average over all batches so far 122 | # (running sum / running count) 123 | self.avg = self.sum / self.count 124 | --------------------------------------------------------------------------------
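`dict_to_nonedict` in `utils/options.py` wraps every dict in the parsed options, including dicts nested inside lists, in a `NoneDict`. As a result, looking up an optional key returns `None` rather than raising `KeyError`. The following is a minimal standalone sketch of that conversion. The config contents (`debug_vqgan`, `batch_size`, `schedulers`, `num_worker`, `gamma`) are hypothetical values chosen for illustration only.

```python
class NoneDict(dict):
    """Dict that returns None instead of raising KeyError for missing keys."""

    def __missing__(self, key):
        return None


def dict_to_nonedict(opt):
    """Recursively convert dicts (including dicts inside lists) to NoneDict."""
    if isinstance(opt, dict):
        return NoneDict(**{key: dict_to_nonedict(sub) for key, sub in opt.items()})
    if isinstance(opt, list):
        return [dict_to_nonedict(sub) for sub in opt]
    return opt


# Hypothetical parsed options, standing in for the output of parse().
opt = dict_to_nonedict({
    'name': 'debug_vqgan',
    'datasets': {'train': {'batch_size': 4}},
    'schedulers': [{'type': 'step'}],
})
print(opt['val_freq'])                         # None: key absent, no KeyError
print(opt['datasets']['train']['num_worker'])  # None: nested dicts converted too
print(opt['schedulers'][0]['gamma'])           # None: dicts inside lists as well
```

`__missing__` is consulted only by `d[key]` lookups. Note also that `NoneDict(**...)` requires string keys, which holds for YAML mapping keys like these.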
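Below is a standalone sketch of the `AverageMeter` bookkeeping in `utils/util.py`. Here `update()` records the latest value in `self.val`, so both the current value and the running average are available, as the class docstring describes. The sketch assumes `val` is a per-sample mean over a batch of `n` samples. Under that assumption, `avg` is the sample-weighted mean over all batches rather than a plain mean of batch means.

```python
class AverageMeter:
    """Track the latest value and the batch-size-weighted running average."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0    # most recent value passed to update()
        self.avg = 0    # running average = sum / count
        self.sum = 0    # running sum of val * n
        self.count = 0  # number of samples seen so far

    def update(self, val, n=1):
        # val is the mean over a batch of n samples (e.g. a loss)
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


meter = AverageMeter()
meter.update(2.0, n=4)  # batch of 4 with mean 2.0
meter.update(5.0, n=2)  # batch of 2 with mean 5.0
print(meter.val)  # 5.0
print(meter.avg)  # (2.0*4 + 5.0*2) / 6 = 3.0
```

A plain mean of the two batch means would give 3.5. Weighting by `n` keeps a small final batch from skewing the epoch average.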
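`MessageLogger.__call__` in `utils/logger.py` estimates the remaining time by dividing the wall time since the logger was created by the number of iterations run since `start_iter`. It then multiplies the result by the iterations still to go. The sketch below pulls out that arithmetic with a hypothetical helper, `estimate_eta`, which is not part of the repository. Because `total_time` is measured from the logger's construction, `start_iter` presumably needs to be the first iteration of the current session (for example, the resume iteration). Otherwise the per-iteration average is skewed.

```python
import datetime


def estimate_eta(total_time, current_iter, start_iter, max_iters):
    """Mirror the ETA arithmetic in MessageLogger.__call__.

    total_time: seconds elapsed since the logger was created.
    Returns the ETA formatted as H:MM:SS by datetime.timedelta.
    """
    # Average seconds per iteration over the iterations run in this session.
    time_sec_avg = total_time / (current_iter - start_iter + 1)
    # Remaining iterations, counted the same way as in MessageLogger.
    eta_sec = time_sec_avg * (max_iters - current_iter - 1)
    return str(datetime.timedelta(seconds=int(eta_sec)))


# Resumed at iter 1001, now at iter 2000 after 500 s, training to 10,000 iters:
# 0.5 s/iter * 7999 remaining iters = 3999.5 s, truncated to 3999 s.
print(estimate_eta(500.0, 2000, 1001, 10000))  # 1:06:39
```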