├── .gitignore
├── LICENCE
├── README.md
├── configs
│   ├── sampler
│   │   ├── sampler_high_res.yml
│   │   └── sampler_low_res.yml
│   ├── video_transformer
│   │   ├── video_trans_high_res.yml
│   │   └── video_trans_low_res.yml
│   └── vqgan
│       ├── vqgan_decompose_high_res.yml
│       └── vqgan_decompose_low_res.yml
├── data
│   ├── __init__.py
│   ├── data_sampler.py
│   ├── decompose_dataset.py
│   ├── moving_label_clip_all_rate_dataset.py
│   ├── prefetch_dataloader.py
│   └── sample_identity_dataset.py
├── env.yaml
├── generate_long_video.ipynb
├── img
│   └── teaser.png
├── models
│   ├── __init__.py
│   ├── app_transformer_model.py
│   ├── archs
│   │   ├── __init__.py
│   │   ├── dalle_transformer_arch.py
│   │   ├── einops_exts.py
│   │   ├── rotary_embedding_torch.py
│   │   ├── transformer_arch.py
│   │   └── vqgan_arch.py
│   ├── base_model.py
│   ├── losses
│   │   ├── __init__.py
│   │   └── vqgan_loss.py
│   ├── video_transformer_model.py
│   └── vqgan_decompose_model.py
├── train_dist.py
├── train_sampler.py
├── train_vqvae_iter_dist.py
└── utils
    ├── __init__.py
    ├── dist_util.py
    ├── logger.py
    ├── options.py
    └── util.py
/.gitignore:
--------------------------------------------------------------------------------
1 | results/*
2 | .ipynb_checkpoints/*
3 | *.pyc
4 | experiments/*
5 | pretrained_models/*
6 | tb_logger/*
7 | datasets/*
8 |
--------------------------------------------------------------------------------
/LICENCE:
--------------------------------------------------------------------------------
1 | S-Lab License 1.0
2 |
3 | Copyright 2023 S-Lab
4 |
5 | Redistribution and use for non-commercial purpose in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
6 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
7 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
8 | 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
9 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
10 | 4. In the event that redistribution and/or use for commercial purpose in source or binary forms, with or without modification is required, please contact the contributor(s) of the work.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # Text2Performer: Text-Driven Human Video Generation
4 |
5 |
6 | Yuming Jiang, Shuai Yang, Tong Liang Koh, Wayne Wu, Chen Change Loy, Ziwei Liu
13 |
14 | ¹S-Lab, Nanyang Technological University &emsp; ²Shanghai AI Laboratory
15 |
16 |
17 | [Paper](https://arxiv.org/pdf/2304.08483.pdf) | [Project Page](https://yumingj.github.io/projects/Text2Performer.html) | [Dataset](https://github.com/yumingj/Fashion-Text2Video) | [Video](https://youtu.be/YwhaJUk_qo0)
18 |
19 |
20 | Text2Performer synthesizes human videos by taking text descriptions as the only input.
21 |
22 |
23 | ![teaser](./img/teaser.png)
24 |
25 |
26 | :open_book: For more visual results, check out our [project page](https://yumingj.github.io/projects/Text2Performer.html).
27 |
28 |
29 |
30 | ## Installation
31 | **Clone this repo:**
32 | ```bash
33 | git clone https://github.com/yumingj/Text2Performer.git
34 | cd Text2Performer
35 | ```
36 |
37 | **Dependencies:**
38 |
39 | ```bash
40 | conda env create -f env.yaml
41 | conda activate text2performer
42 | ```
43 |
44 | ## (1) Dataset Preparation
45 |
46 | In this work, we contribute a human video dataset with rich label and text annotations, named the [Fashion-Text2Video](https://github.com/yumingj/Fashion-Text2Video) dataset.
47 |
48 | You can download our processed dataset from this [Google Drive](https://drive.google.com/drive/folders/1NFd_irnw8kgNcu5KfWhRA8RZPdBK5p1I?usp=sharing).
49 | After downloading the dataset, unzip the files and put them under the `./datasets` folder with the following structure:
50 | ```
51 | ./datasets
52 | ├── FashionDataset_frames_crop
53 | │   ├── xxxxxx
54 | │   │   ├── 000.png
55 | │   │   ├── 001.png
56 | │   │   ├── ...
57 | │   ├── xxxxxx
58 | │   └── xxxxxx
59 | ├── train_frame_num.txt
60 | ├── val_frame_num.txt
61 | ├── test_frame_num.txt
62 | ├── moving_frames.npy
63 | ├── captions_app.json
64 | ├── caption_motion_template.json
65 | ├── action_label
66 | │   ├── xxxxxx.txt
67 | │   ├── xxxxxx.txt
68 | │   ├── ...
69 | │   └── xxxxxx.txt
70 | └── shhq_dataset % optional
71 | ```
72 |
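The loaders in `data/` read `train_frame_num.txt` (and its val/test counterparts) as one `<video_name> <frame_count>` pair per line, and expect frames stored as zero-padded PNGs (`000.png`, `001.png`, ...), with `000.png` also serving as the identity/exemplar frame. The optional sketch below (not part of the repo) checks that layout under exactly those assumptions:

```python
# Optional sanity check for the ./datasets layout described above (not part of the repo).
# Assumes each line of train_frame_num.txt is "<video_name> <frame_count>",
# which is how the loaders in data/ parse it.
import os

dataset_root = './datasets'
video_dir = os.path.join(dataset_root, 'FashionDataset_frames_crop')

with open(os.path.join(dataset_root, 'train_frame_num.txt')) as f:
    for row in f:
        video_name = row.split()[0]
        # 000.png is the frame the loaders use as the identity/exemplar image
        identity_frame = os.path.join(video_dir, video_name, '000.png')
        if not os.path.isfile(identity_frame):
            print(f'missing identity frame: {identity_frame}')
```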
73 | ## (2) Sampling
74 |
75 | ### Pretrained Models
76 |
77 | Pretrained models can be downloaded from the [Google Drive](https://drive.google.com/drive/folders/1Dgg0EaldNfyPhykHw1TYrm4qme3CqrDz?usp=sharing). Unzip the file and put them under the pretrained_models folder with the following structure:
78 | ```
79 | pretrained_models
80 | ├── sampler_high_res.pth
81 | ├── video_trans_high_res.pth
82 | └── vqgan_decomposed_high_res.pth
83 | ```
84 |
85 | After downloading pretrained models, you can use ```generate_long_video.ipynb``` to generate videos.
86 |
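The checkpoints are ordinary `torch.save` dictionaries whose keys name the sub-modules they hold; `models/app_transformer_model.py`, for example, reads `encoder`, `decoder`, `quantize_identity` and related entries from the VQGAN checkpoint. If you want to verify a download before opening the notebook, here is a small optional inspection sketch (paths as in the structure above):

```python
# Optional: list the sub-module state dicts stored in a downloaded checkpoint.
# Key names follow the loading code in models/app_transformer_model.py
# (e.g. 'encoder', 'decoder', 'quantize_identity', ...).
import torch

ckpt = torch.load('./pretrained_models/vqgan_decomposed_high_res.pth', map_location='cpu')
for name, value in ckpt.items():
    try:
        num_params = sum(v.numel() for v in value.values())
        print(f'{name}: {len(value)} tensors, {num_params / 1e6:.2f}M parameters')
    except AttributeError:
        # entries that are not state dicts are just reported by type
        print(f'{name}: {type(value).__name__}')
```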
87 | ## (3) Training Text2Performer
88 | ### Stage I: Decomposed VQGAN
89 | Train the decomposed VQGAN. If you want to skip the training of this network, you can download our pretrained model from [here](https://drive.google.com/file/d/1G59bRoOUEQA8xljRDsfyiw6g8spV3Y7_/view?usp=sharing).
90 |
91 | For better performance, we also use data from the [SHHQ dataset](https://github.com/stylegan-human/StyleGAN-Human/blob/main/docs/Dataset.md) to train this stage.
92 | ```bash
93 | python -m torch.distributed.launch --nproc_per_node=4 --master_port=29596 train_vqvae_iter_dist.py -opt ./configs/vqgan/vqgan_decompose_high_res.yml --launcher pytorch
94 | ```
95 |
96 | ### Stage II: Video Transformer
97 | Train the video transformer. If you want to skip the training of this network, you can download our pretrained model from [here](https://drive.google.com/file/d/1QRQlhl8z4-BQfmUvHoVrJnSpxQaKDPZH/view?usp=sharing).
98 | ```bash
99 | python -m torch.distributed.launch --nproc_per_node=4 --master_port=29596 train_dist.py -opt ./configs/video_transformer/video_trans_high_res.yml --launcher pytorch
100 | ```
101 |
102 | ### Stage III: Appearance Transformer
103 | Train the appearance transformer. If you want to skip the training of this network, you can download our pretrained model from [here](https://drive.google.com/file/d/19nYQT511XsBzq1sMUc2MmfpDKT7HVi8Z/view?usp=sharing).
104 | ```bash
105 | python train_sampler.py -opt ./configs/sampler/sampler_high_res.yml
106 | ```
107 |
108 | ## Citation
109 |
110 | If you find this work useful for your research, please consider citing our paper:
111 |
112 | ```bibtex
113 | @inproceedings{jiang2023text2performer,
114 | title={Text2Performer: Text-Driven Human Video Generation},
115 | author={Jiang, Yuming and Yang, Shuai and Koh, Tong Liang and Wu, Wayne and Loy, Chen Change and Liu, Ziwei},
116 | booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
117 | year={2023}
118 | }
119 | ```
120 |
121 | ## :newspaper_roll: License
122 |
123 | Distributed under the S-Lab License. See `LICENCE` for more information.
124 |
125 | 
126 |
127 |
--------------------------------------------------------------------------------
/configs/sampler/sampler_high_res.yml:
--------------------------------------------------------------------------------
1 | name: sampler_high_res
2 | use_tb_logger: true
3 | set_CUDA_VISIBLE_DEVICES: ~
4 | gpu_ids: [3]
5 |
6 | # dataset configs
7 | batch_size: 4
8 | num_workers: 1
9 | datasets:
10 |   train:
11 |     video_dir: ./datasets/FashionDataset_frames_crop
12 |     data_name_txt: ./datasets/train_frame_num.txt
13 |     text_file: ./datasets/captions_app.json
14 |     downsample_factor: 1
15 |     xflip: True
16 |
17 |   val:
18 |     video_dir: ./datasets/FashionDataset_frames_crop
19 |     data_name_txt: ./datasets/val_frame_num.txt
20 |     text_file: ./datasets/captions_app.json
21 |     downsample_factor: 1
22 |     xflip: False
23 |
24 |   test:
25 |     video_dir: ./datasets/FashionDataset_frames_crop
26 |     data_name_txt: ./datasets/test_frame_num.txt
27 |     text_file: ./datasets/captions_app.json
28 |     downsample_factor: 1
29 |     xflip: False
30 |
31 | # pretrained models
32 | img_ae_path: ./pretrained_models/vqgan_decomposed_high_res.pth
33 |
34 | model_type: AppTransformerModel
35 | # network configs
36 |
37 | # image autoencoder
38 | img_embed_dim: 256
39 | img_n_embed: 1024
40 | img_double_z: false
41 | img_z_channels: 256
42 | img_resolution: 512
43 | img_in_channels: 3
44 | img_out_ch: 3
45 | img_ch: 128
46 | img_ch_mult: [1, 1, 2, 2, 4]
47 | img_other_ch_mult: [4, 4]
48 | img_num_res_blocks: 2
49 | img_attn_resolutions: [32]
50 | img_dropout: 0.0
51 |
52 | # sampler configs
53 | codebook_size: 1024
54 | bert_n_emb: 512
55 | bert_n_layers: 24
56 | bert_n_head: 8
57 | block_size: 512 # 32 x 16
58 | latent_shape: [32, 16]
59 | embd_pdrop: 0.0
60 | resid_pdrop: 0.0
61 | attn_pdrop: 0.0
62 |
63 | # loss configs
64 | loss_type: reweighted_elbo
65 | mask_schedule: random
66 |
67 | sample_steps: 64
68 |
69 | # training configs
70 | val_freq: 50
71 | print_freq: 100
72 | weight_decay: 0
73 | manual_seed: 2021
74 | num_epochs: 1000
75 | lr: !!float 1e-4
76 | lr_decay: step
77 | gamma: 1.0
78 | step: 50
79 |
80 | text_seq_len: 50
81 |
--------------------------------------------------------------------------------
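The training scripts parse these files through `utils/options.py` (not included in this dump). For a quick standalone look at a config, plain PyYAML (already listed in `env.yaml`) is enough; the sketch below also checks the relation implied by the `# 32 x 16` comment, namely `block_size == latent_shape[0] * latent_shape[1]`:

```python
# Standalone config inspection with PyYAML (the repo itself loads options via utils/options.py).
import yaml

with open('./configs/sampler/sampler_high_res.yml') as f:
    opt = yaml.safe_load(f)

print(opt['model_type'])           # AppTransformerModel
print(opt['datasets']['train'])    # nested dataset options

# The "# 32 x 16" comment implies block_size is the flattened latent grid:
h, w = opt['latent_shape']
assert opt['block_size'] == h * w, 'block_size should equal latent_shape[0] * latent_shape[1]'
```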
/configs/sampler/sampler_low_res.yml:
--------------------------------------------------------------------------------
1 | name: sampler_low_res
2 | use_tb_logger: true
3 | set_CUDA_VISIBLE_DEVICES: ~
4 | gpu_ids: [3]
5 |
6 | # dataset configs
7 | batch_size: 4
8 | num_workers: 1
9 | datasets:
10 |   train:
11 |     video_dir: ./datasets/FashionDataset_frames_crop
12 |     data_name_txt: ./datasets/train_frame_num.txt
13 |     text_file: ./datasets/captions_app.json
14 |     downsample_factor: 2
15 |     xflip: True
16 |
17 |   val:
18 |     video_dir: ./datasets/FashionDataset_frames_crop
19 |     data_name_txt: ./datasets/val_frame_num.txt
20 |     text_file: ./datasets/captions_app.json
21 |     downsample_factor: 2
22 |     xflip: False
23 |
24 |   test:
25 |     video_dir: ./datasets/FashionDataset_frames_crop
26 |     data_name_txt: ./datasets/test_frame_num.txt
27 |     text_file: ./datasets/captions_app.json
28 |     downsample_factor: 2
29 |     xflip: False
30 |
31 | # pretrained models
32 | img_ae_path: ./pretrained_models/vqgan_decomposed_low_res.pth
33 |
34 | model_type: AppTransformerModel
35 | # network configs
36 |
37 | # image autoencoder
38 | img_embed_dim: 256
39 | img_n_embed: 1024
40 | img_double_z: false
41 | img_z_channels: 256
42 | img_resolution: 512
43 | img_in_channels: 3
44 | img_out_ch: 3
45 | img_ch: 128
46 | img_ch_mult: [1, 1, 2, 2, 4]
47 | img_other_ch_mult: [4, 4]
48 | img_num_res_blocks: 2
49 | img_attn_resolutions: [32]
50 | img_dropout: 0.0
51 |
52 | # sampler configs
53 | codebook_size: 1024
54 | bert_n_emb: 512
55 | bert_n_layers: 24
56 | bert_n_head: 8
57 | block_size: 128 # 32 x 16
58 | latent_shape: [16, 8]
59 | embd_pdrop: 0.0
60 | resid_pdrop: 0.0
61 | attn_pdrop: 0.0
62 |
63 | # loss configs
64 | loss_type: reweighted_elbo
65 | mask_schedule: random
66 |
67 | sample_steps: 64
68 |
69 | # training configs
70 | val_freq: 50
71 | print_freq: 100
72 | weight_decay: 0
73 | manual_seed: 2021
74 | num_epochs: 1000
75 | lr: !!float 1e-4
76 | lr_decay: step
77 | gamma: 1.0
78 | step: 50
79 |
80 | text_seq_len: 50
81 |
--------------------------------------------------------------------------------
/configs/video_transformer/video_trans_high_res.yml:
--------------------------------------------------------------------------------
1 | name: video_trans_high_res
2 | use_tb_logger: true
3 | set_CUDA_VISIBLE_DEVICES: ~
4 | gpu_ids: [3]
5 | num_gpu: 4
6 |
7 | # dataset configs
8 | batch_size_per_gpu: 1
9 | num_worker_per_gpu: 0
10 | datasets:
11 |   train:
12 |     type: MovingLabelsClipAllRateTextDataset
13 |     video_dir: ./datasets/FashionDataset_frames_crop
14 |     action_label_folder: ./datasets/action_label
15 |     data_name_txt: ./datasets/train_frame_num.txt
16 |     frame_sample_rate: 1
17 |     fixed_video_len: 8
18 |     moving_frame_dict: ./datasets/moving_frames.npy
19 |     overall_caption_templates_file: ./datasets/caption_motion_template.json
20 |     interpolation_rate: 0.2
21 |     downsample_factor: 1
22 |     random_start: True
23 |     xflip: False
24 |
25 |   val:
26 |     type: MovingLabelsClipAllRateTextDataset
27 |     video_dir: ./datasets/FashionDataset_frames_crop
28 |     action_label_folder: ./datasets/action_label
29 |     data_name_txt: ./datasets/val_frame_num.txt
30 |     frame_sample_rate: 1
31 |     fixed_video_len: 8
32 |     moving_frame_dict: ./datasets/moving_frames.npy
33 |     overall_caption_templates_file: ./datasets/caption_motion_template.json
34 |     downsample_factor: 1
35 |     interpolation_rate: 0.0
36 |     random_start: True
37 |     xflip: False
38 |
39 |   test:
40 |     type: MovingLabelsClipAllRateTextDataset
41 |     video_dir: ./datasets/FashionDataset_frames_crop
42 |     action_label_folder: ./datasets/action_label
43 |     data_name_txt: ./datasets/test_frame_num.txt
44 |     frame_sample_rate: 1
45 |     moving_frame_dict: ./datasets/moving_frames.npy
46 |     fixed_video_len: 8
47 |     overall_caption_templates_file: ./datasets/caption_motion_template.json
48 |     downsample_factor: 1
49 |     interpolation_rate: 0.0
50 |     random_start: True
51 |     xflip: False
52 |
53 | prefetch_mode: ~
54 |
55 | # pretrained models
56 | img_ae_path: ./pretrained_models/vqgan_decomposed_high_res.pth
57 |
58 | model_type: VideoTransformerModel
59 |
60 | # network configs
61 | # image autoencoder
62 | img_embed_dim: 256
63 | img_n_embed: 1024
64 | img_double_z: false
65 | img_z_channels: 256
66 | img_resolution: 512
67 | img_in_channels: 3
68 | img_out_ch: 3
69 | img_ch: 128
70 | img_ch_mult: [1, 1, 2, 2, 4]
71 | img_other_ch_mult: [4, 4]
72 | img_num_res_blocks: 2
73 | img_attn_resolutions: [32]
74 | img_dropout: 0.0
75 |
76 | # sampler configs
77 | dim: 128
78 | depth: 6
79 | dim_head: 64
80 | heads: 12
81 | ff_mult: 4
82 | norm_out: true
83 | attn_dropout: 0.0
84 | ff_dropout: 0.0
85 | final_proj: true
86 | normformer: true
87 | rotary_emb: true
88 | latent_shape: [8, 4]
89 | action_label_num: 23
90 |
91 | # training configs
92 | val_freq: 50
93 | print_freq: 100
94 | weight_decay: 0
95 | manual_seed: 2022
96 | num_epochs: 1000
97 | lr: !!float 1e-4
98 | lr_decay: step
99 | gamma: 1.0
100 | step: 50
101 | perceptual_weight: 1.0
102 |
103 | larger_ratio: 3
104 |
105 | num_inside_timesteps: 24
106 | inside_ratio: 0
--------------------------------------------------------------------------------
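A back-of-the-envelope reading of this config, under the assumption (suggested by the `# 32 x 16` `block_size` comment in the sampler configs and the extra `other_ch_mult` downsampling) that `latent_shape` is the per-frame token grid of the motion branch: each frame contributes 8 × 4 = 32 tokens, so a `fixed_video_len` clip of 8 frames spans 256 tokens. The sketch below just recomputes this from the YAML:

```python
# Back-of-the-envelope token count for one training clip. This assumes, as the
# "# 32 x 16" block_size comment in the sampler configs suggests, that
# latent_shape is the per-frame (height, width) token grid of the motion branch.
import yaml

with open('./configs/video_transformer/video_trans_high_res.yml') as f:
    opt = yaml.safe_load(f)

h, w = opt['latent_shape']                               # [8, 4]
tokens_per_frame = h * w                                 # 32
clip_len = opt['datasets']['train']['fixed_video_len']   # 8
print(tokens_per_frame, tokens_per_frame * clip_len)     # 32 tokens/frame, 256 per clip
```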
/configs/video_transformer/video_trans_low_res.yml:
--------------------------------------------------------------------------------
1 | name: video_trans_low_res
2 | use_tb_logger: true
3 | set_CUDA_VISIBLE_DEVICES: ~
4 | gpu_ids: [3]
5 | num_gpu: 4
6 |
7 | # dataset configs
8 | batch_size_per_gpu: 1
9 | num_worker_per_gpu: 0
10 | datasets:
11 |   train:
12 |     type: MovingLabelsClipAllRateTextDataset
13 |     video_dir: ./datasets/FashionDataset_frames_crop
14 |     action_label_folder: ./datasets/action_label
15 |     data_name_txt: ./datasets/train_frame_num.txt
16 |     frame_sample_rate: 1
17 |     fixed_video_len: 20
18 |     moving_frame_dict: ./datasets/moving_frames.npy
19 |     overall_caption_templates_file: ./datasets/caption_motion_template.json
20 |     interpolation_rate: 0.2
21 |     downsample_factor: 2
22 |     random_start: True
23 |     xflip: False
24 |
25 |   val:
26 |     type: MovingLabelsClipAllRateTextDataset
27 |     video_dir: ./datasets/FashionDataset_frames_crop
28 |     action_label_folder: ./datasets/action_label
29 |     data_name_txt: ./datasets/val_frame_num.txt
30 |     frame_sample_rate: 1
31 |     fixed_video_len: 20
32 |     moving_frame_dict: ./datasets/moving_frames.npy
33 |     overall_caption_templates_file: ./datasets/caption_motion_template.json
34 |     downsample_factor: 2
35 |     interpolation_rate: 0.0
36 |     random_start: True
37 |     xflip: False
38 |
39 |   test:
40 |     type: MovingLabelsClipAllRateTextDataset
41 |     video_dir: ./datasets/FashionDataset_frames_crop
42 |     action_label_folder: ./datasets/action_label
43 |     data_name_txt: ./datasets/test_frame_num.txt
44 |     frame_sample_rate: 1
45 |     moving_frame_dict: ./datasets/moving_frames.npy
46 |     fixed_video_len: 20
47 |     overall_caption_templates_file: ./datasets/caption_motion_template.json
48 |     downsample_factor: 2
49 |     interpolation_rate: 0.0
50 |     random_start: True
51 |     xflip: False
52 |
53 | prefetch_mode: ~
54 |
55 | # pretrained models
56 | img_ae_path: ./pretrained_models/vqgan_decomposed_low_res.pth
57 |
58 | model_type: VideoTransformerModel
59 |
60 | # network configs
61 | # image autoencoder
62 | img_embed_dim: 256
63 | img_n_embed: 1024
64 | img_double_z: false
65 | img_z_channels: 256
66 | img_resolution: 512
67 | img_in_channels: 3
68 | img_out_ch: 3
69 | img_ch: 128
70 | img_ch_mult: [1, 1, 2, 2, 4]
71 | img_other_ch_mult: [4, 4]
72 | img_num_res_blocks: 2
73 | img_attn_resolutions: [32]
74 | img_dropout: 0.0
75 |
76 | # sampler configs
77 | dim: 128
78 | depth: 6
79 | dim_head: 64
80 | heads: 12
81 | ff_mult: 4
82 | norm_out: true
83 | attn_dropout: 0.0
84 | ff_dropout: 0.0
85 | final_proj: true
86 | normformer: true
87 | rotary_emb: true
88 | latent_shape: [4, 2]
89 | action_label_num: 23
90 |
91 | # training configs
92 | val_freq: 50
93 | print_freq: 100
94 | weight_decay: 0
95 | manual_seed: 2022
96 | num_epochs: 1000
97 | lr: !!float 1e-4
98 | lr_decay: step
99 | gamma: 1.0
100 | step: 50
101 | perceptual_weight: 1.0
102 |
103 | larger_ratio: 6
104 |
105 | num_inside_timesteps: 6
106 | inside_ratio: 0
--------------------------------------------------------------------------------
/configs/vqgan/vqgan_decompose_high_res.yml:
--------------------------------------------------------------------------------
1 | name: vqgan_decompose_high_res
2 | use_tb_logger: true
3 | set_CUDA_VISIBLE_DEVICES: ~
4 | gpu_ids: [3]
5 | num_gpu: 4
6 |
7 | # dataset configs
8 | batch_size_per_gpu: 4
9 | num_worker_per_gpu: 1
10 | datasets:
11 |   train:
12 |     type: DecomposeMixDataset
13 |     video_dir: ./datasets/FashionDataset_frames_crop
14 |     data_name_txt: ./datasets/train_frame_num.txt
15 |     shhq_data_dir: ./datasets/shhq_data
16 |     downsample_factor: 1
17 |     xflip: True
18 |
19 |   val:
20 |     type: DecomposeDataset
21 |     video_dir: ./datasets/FashionDataset_frames_crop
22 |     data_name_txt: ./datasets/val_frame_num.txt
23 |     downsample_factor: 1
24 |     xflip: False
25 |
26 |   test:
27 |     type: DecomposeDataset
28 |     video_dir: ./datasets/FashionDataset_frames_crop
29 |     data_name_txt: ./datasets/test_frame_num.txt
30 |     downsample_factor: 1
31 |     xflip: False
32 |
33 |
34 | model_type: VQGANDecomposeModel
35 | # network configs
36 | embed_dim: 256
37 | n_embed: 1024
38 | double_z: false
39 | z_channels: 256
40 | resolution: 512
41 | in_channels: 3
42 | out_ch: 3
43 | ch: 128
44 | ch_mult: [1, 1, 2, 2, 4]
45 | other_ch_mult: [4, 4]
46 | num_res_blocks: 2
47 | attn_resolutions: [32]
48 | dropout: 0.0
49 |
50 | disc_layers: 3
51 | disc_weight_max: 1
52 | disc_start_step: 40001
53 | n_channels: 3
54 | ndf: 64
55 | nf: 128
56 | perceptual_weight: 1.0
57 |
58 | num_segm_classes: 24
59 |
60 |
61 | # training configs
62 | val_freq: 5000
63 | print_freq: 10
64 | weight_decay: 0
65 | manual_seed: 2021
66 | num_epochs: 100000
67 | lr: !!float 1.0e-04
68 | lr_decay: step
69 | gamma: 1.0
70 | step: 50
71 |
72 | random_dropout: 1.0
--------------------------------------------------------------------------------
/configs/vqgan/vqgan_decompose_low_res.yml:
--------------------------------------------------------------------------------
1 | name: vqgan_decompose_low_res
2 | use_tb_logger: true
3 | set_CUDA_VISIBLE_DEVICES: ~
4 | gpu_ids: [3]
5 | num_gpu: 4
6 |
7 | # dataset configs
8 | batch_size_per_gpu: 16
9 | num_worker_per_gpu: 4
10 | datasets:
11 |   train:
12 |     type: DecomposeMixDataset
13 |     video_dir: ./datasets/FashionDataset_frames_crop
14 |     data_name_txt: ./datasets/train_frame_num.txt
15 |     shhq_data_dir: ./datasets/shhq_dataset
16 |     downsample_factor: 2
17 |     xflip: True
18 |
19 |   val:
20 |     type: DecomposeDataset
21 |     video_dir: ./datasets/FashionDataset_frames_crop
22 |     data_name_txt: ./datasets/val_frame_num.txt
23 |     downsample_factor: 2
24 |     xflip: False
25 |
26 |   test:
27 |     type: DecomposeDataset
28 |     video_dir: ./datasets/FashionDataset_frames_crop
29 |     data_name_txt: ./datasets/test_frame_num.txt
30 |     downsample_factor: 2
31 |     xflip: False
32 |
33 |
34 | model_type: VQGANDecomposeModel
35 | # network configs
36 | embed_dim: 256
37 | n_embed: 1024
38 | double_z: false
39 | z_channels: 256
40 | resolution: 512
41 | in_channels: 3
42 | out_ch: 3
43 | ch: 128
44 | ch_mult: [1, 1, 2, 2, 4]
45 | other_ch_mult: [4, 4]
46 | num_res_blocks: 2
47 | attn_resolutions: [32]
48 | dropout: 0.0
49 |
50 | disc_layers: 1
51 | disc_weight_max: 1
52 | disc_start_step: 10001
53 | n_channels: 3
54 | ndf: 64
55 | nf: 128
56 | perceptual_weight: 1.0
57 |
58 | num_segm_classes: 24
59 |
60 |
61 | # training configs
62 | val_freq: 5000
63 | print_freq: 10
64 | weight_decay: 0
65 | manual_seed: 2021
66 | num_epochs: 100000
67 | lr: !!float 1.0e-04
68 | lr_decay: step
69 | gamma: 1.0
70 | step: 50
71 |
72 | random_dropout: 1.0
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import importlib
3 | import random
4 | from functools import partial
5 | from os import path as osp
6 |
7 | import numpy as np
8 | import torch
9 | import torch.utils.data
10 |
11 | from data.prefetch_dataloader import PrefetchDataLoader
12 | from utils.dist_util import get_dist_info
13 | from utils.logger import get_root_logger
14 |
15 | __all__ = ['create_dataset', 'create_dataloader']
16 |
17 | # automatically scan and import dataset modules
18 | # scan all the files under the data folder with '_dataset' in file names
19 | data_folder = osp.dirname(osp.abspath(__file__))
20 | dataset_filenames = [
21 | osp.splitext(osp.basename(v))[0]
22 | for v in glob.glob(f'{data_folder}/*_dataset.py')
23 | ]
24 | # import all the dataset modules
25 | _dataset_modules = [
26 | importlib.import_module(f'data.{file_name}')
27 | for file_name in dataset_filenames
28 | ]
29 |
30 |
31 | def create_dataset(dataset_opt):
32 | """Create dataset.
33 |
34 | Args:
35 | dataset_opt (dict): Configuration for dataset. It contains:
36 | name (str): Dataset name.
37 | type (str): Dataset type.
38 | """
39 | dataset_type = dataset_opt['type']
40 |
41 | # dynamic instantiation
42 | for module in _dataset_modules:
43 | dataset_cls = getattr(module, dataset_type, None)
44 | if dataset_cls is not None:
45 | break
46 | if dataset_cls is None:
47 | raise ValueError(f'Dataset {dataset_type} is not found.')
48 |
49 | dataset = dataset_cls(dataset_opt)
50 |
51 | logger = get_root_logger()
52 | logger.info(
53 | f'Dataset {dataset.__class__.__name__} - {dataset_opt["name"]} '
54 | 'is created.')
55 |
56 | return dataset
57 |
58 |
59 | def create_dataloader(dataset,
60 | dataset_opt,
61 | phase,
62 | num_gpu=1,
63 | dist=False,
64 | sampler=None,
65 | seed=None):
66 | """Create dataloader.
67 |
68 | Args:
69 | dataset (torch.utils.data.Dataset): Dataset.
70 | dataset_opt (dict): Dataset options. It contains the following keys:
71 | phase (str): 'train' or 'val'.
72 | num_worker_per_gpu (int): Number of workers for each GPU.
73 | batch_size_per_gpu (int): Training batch size for each GPU.
74 | num_gpu (int): Number of GPUs. Used only in the train phase.
75 | Default: 1.
76 | dist (bool): Whether in distributed training. Used only in the train
77 | phase. Default: False.
78 | sampler (torch.utils.data.sampler): Data sampler. Default: None.
79 | seed (int | None): Seed. Default: None
80 | """
81 | rank, _ = get_dist_info()
82 | if phase == 'train':
83 | if dist: # distributed training
84 | batch_size = dataset_opt['batch_size_per_gpu']
85 | num_workers = dataset_opt['num_worker_per_gpu']
86 | else: # non-distributed training
87 | multiplier = 1 if num_gpu == 0 else num_gpu
88 | batch_size = dataset_opt['batch_size_per_gpu'] * multiplier
89 | num_workers = dataset_opt['num_worker_per_gpu'] * multiplier
90 | dataloader_args = dict(
91 | dataset=dataset,
92 | batch_size=batch_size,
93 | shuffle=False,
94 | num_workers=num_workers,
95 | sampler=sampler,
96 | drop_last=True)
97 | if sampler is None:
98 | dataloader_args['shuffle'] = True
99 | dataloader_args['worker_init_fn'] = partial(
100 | worker_init_fn, num_workers=num_workers, rank=rank,
101 | seed=seed) if seed is not None else None
102 | elif phase in ['val', 'test']: # validation
103 | dataloader_args = dict(
104 | dataset=dataset, batch_size=1, shuffle=False, num_workers=1)
105 | else:
106 | raise ValueError(f'Wrong dataset phase: {phase}. '
107 | "Supported ones are 'train', 'val' and 'test'.")
108 |
109 | dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False)
110 |
111 | prefetch_mode = dataset_opt.get('prefetch_mode')
112 | if prefetch_mode == 'cpu': # CPUPrefetcher
113 | num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1)
114 | logger = get_root_logger()
115 | logger.info(f'Use {prefetch_mode} prefetch dataloader: '
116 | f'num_prefetch_queue = {num_prefetch_queue}')
117 | return PrefetchDataLoader(
118 | num_prefetch_queue=num_prefetch_queue, **dataloader_args)
119 | else:
120 | # prefetch_mode=None: Normal dataloader
121 | # prefetch_mode='cuda': dataloader for CUDAPrefetcher
122 | return torch.utils.data.DataLoader(**dataloader_args)
123 |
124 |
125 | def worker_init_fn(worker_id, num_workers, rank, seed):
126 | # Set the worker seed to num_workers * rank + worker_id + seed
127 | worker_seed = num_workers * rank + worker_id + seed
128 | np.random.seed(worker_seed)
129 | random.seed(worker_seed)
130 |
--------------------------------------------------------------------------------
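For reference, a minimal usage sketch of the two factory functions defined above. The option values mirror the `datasets.val` block of `configs/vqgan/vqgan_decompose_high_res.yml`; the extra `name` key is only used in the log message, and the dataset files from the README are assumed to be in place:

```python
# Minimal usage sketch for create_dataset / create_dataloader (assumes the dataset
# files from the README are in place; option values follow the vqgan_decompose config).
from data import create_dataset, create_dataloader

dataset_opt = {
    'name': 'val',                  # used only for the log message
    'type': 'DecomposeDataset',     # class name resolved from data/*_dataset.py
    'video_dir': './datasets/FashionDataset_frames_crop',
    'data_name_txt': './datasets/val_frame_num.txt',
    'downsample_factor': 1,
    'xflip': False,
}

val_set = create_dataset(dataset_opt)
val_loader = create_dataloader(val_set, dataset_opt, phase='val')
batch = next(iter(val_loader))
print(batch['frame_img'].shape, batch['identity_image'].shape)
```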
/data/data_sampler.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | from torch.utils.data.sampler import Sampler
5 |
6 |
7 | class EnlargedSampler(Sampler):
8 | """Sampler that restricts data loading to a subset of the dataset.
9 |
10 | Modified from torch.utils.data.distributed.DistributedSampler
11 | Supports enlarging the dataset for iteration-based training, saving
12 | time when restarting the dataloader after each epoch.
13 |
14 | Args:
15 | dataset (torch.utils.data.Dataset): Dataset used for sampling.
16 | num_replicas (int | None): Number of processes participating in
17 | the training. It is usually the world_size.
18 | rank (int | None): Rank of the current process within num_replicas.
19 | ratio (int): Enlarging ratio. Default: 1.
20 | """
21 |
22 | def __init__(self, dataset, num_replicas, rank, ratio=1):
23 | self.dataset = dataset
24 | self.num_replicas = num_replicas
25 | self.rank = rank
26 | self.epoch = 0
27 | self.num_samples = math.ceil(
28 | len(self.dataset) * ratio / self.num_replicas)
29 | self.total_size = self.num_samples * self.num_replicas
30 |
31 | def __iter__(self):
32 | # deterministically shuffle based on epoch
33 | g = torch.Generator()
34 | g.manual_seed(self.epoch)
35 | indices = torch.randperm(self.total_size, generator=g).tolist()
36 |
37 | dataset_size = len(self.dataset)
38 | indices = [v % dataset_size for v in indices]
39 |
40 | # subsample
41 | indices = indices[self.rank:self.total_size:self.num_replicas]
42 | assert len(indices) == self.num_samples
43 |
44 | return iter(indices)
45 |
46 | def __len__(self):
47 | return self.num_samples
48 |
49 | def set_epoch(self, epoch):
50 | self.epoch = epoch
51 |
--------------------------------------------------------------------------------
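A short usage sketch for `EnlargedSampler`. The `num_replicas`/`rank` values are hypothetical stand-ins here (the training scripts obtain them from `utils/dist_util.get_dist_info`); calling `set_epoch` each epoch changes the deterministic shuffle, as with `DistributedSampler`:

```python
# Usage sketch: enlarge a small dataset 10x so the dataloader restarts less often.
# world size / rank are hypothetical here; the training scripts get them from dist_util.
import torch
from torch.utils.data import DataLoader, TensorDataset

from data.data_sampler import EnlargedSampler

dataset = TensorDataset(torch.arange(100).float())
sampler = EnlargedSampler(dataset, num_replicas=4, rank=0, ratio=10)
loader = DataLoader(dataset, batch_size=8, sampler=sampler)

for epoch in range(2):
    sampler.set_epoch(epoch)   # different deterministic shuffle every epoch
    for batch, in loader:
        pass                   # training step would go here
print(len(sampler))            # ceil(100 * 10 / 4) = 250 samples per replica
```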
/data/decompose_dataset.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 | import os.path
4 | import random
5 |
6 | import numpy as np
7 | import torch
8 | import torch.utils.data as data
9 | import torchvision.transforms as transforms
10 | from PIL import Image
11 |
12 |
13 | class DecomposeDataset(data.Dataset):
14 |
15 | def __init__(self, opt):
16 | self._video_dir = opt['video_dir']
17 | self.downsample_factor = opt['downsample_factor']
18 |
19 | self._video_names = []
20 | self._frame_nums = []
21 |
22 | self.transform = transforms.Compose([
23 | transforms.ColorJitter(brightness=.5, hue=.3),
24 | transforms.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5))
25 | ])
26 |
27 | self.xflip = opt['xflip']
28 |
29 | video_names_list = open(opt['data_name_txt'], 'r').readlines()
30 |
31 | for row in video_names_list:
32 | video_name = row.split()[0]
33 | frame_num = int(row.split()[1])
34 | self._video_names.append(video_name)
35 | self._frame_nums.append(frame_num)
36 |
37 | def _load_raw_image(self, img_path):
38 | with open(img_path, 'rb') as f:
39 | image = Image.open(f)
40 | image.load()
41 | if self.downsample_factor != 1:
42 | width, height = image.size
43 | width = width // self.downsample_factor
44 | height = height // self.downsample_factor
45 | image = image.resize(
46 | size=(width, height), resample=Image.LANCZOS)
47 |
48 | return image
49 |
50 | def __getitem__(self, index):
51 | video_name = self._video_names[index]
52 |
53 | random_frame_idx = random.randint(30, self._frame_nums[index] - 30)
54 |
55 | img_path = f'{self._video_dir}/{video_name}/{random_frame_idx:03d}.png'
56 | random_frame = self._load_raw_image(img_path)
57 |
58 | random_frame_aug = self.transform(random_frame)
59 |
60 | random_frame = np.array(random_frame).transpose(2, 0,
61 | 1).astype(np.float32)
62 | random_frame_aug = np.array(random_frame_aug).transpose(
63 | 2, 0, 1).astype(np.float32)
64 |
65 | identity_image = self._load_raw_image(
66 | f'{self._video_dir}/{video_name}/000.png')
67 | identity_image = np.array(identity_image).transpose(2, 0, 1).astype(
68 | np.float32)
69 |
70 | if self.xflip and random.random() > 0.5:
71 | random_frame = random_frame[:, :, ::-1].copy() # [C, H ,W]
72 | random_frame_aug = random_frame_aug[:, :, ::-1].copy()
73 |
74 | random_frame = random_frame / 127.5 - 1
75 | random_frame_aug = random_frame_aug / 127.5 - 1
76 | identity_image = identity_image / 127.5 - 1
77 |
78 | random_frame = torch.from_numpy(random_frame)
79 | identity_image = torch.from_numpy(identity_image)
80 |
81 | return_dict = {
82 | # 'densepose': pose,
83 | 'video_name': f'{video_name}_{random_frame_idx:03d}',
84 | 'frame_img': random_frame,
85 | 'frame_img_aug': random_frame_aug,
86 | 'identity_image': identity_image
87 | }
88 |
89 | return return_dict
90 |
91 | def __len__(self):
92 | return len(self._video_names)
93 |
94 |
95 | class DecomposeMixDataset(data.Dataset):
96 |
97 | def __init__(self, opt):
98 | self._video_dir = opt['video_dir']
99 | self.downsample_factor = opt['downsample_factor']
100 |
101 | _video_names = []
102 | _frame_nums = []
103 |
104 | self.transform = transforms.Compose([
105 | transforms.ColorJitter(brightness=.5, hue=.3),
106 | transforms.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5))
107 | ])
108 |
109 | self.translate_transform = transforms.Compose(
110 | [transforms.RandomAffine(degrees=0, translate=(0.1, 0.3))])
111 |
112 | self.xflip = opt['xflip']
113 |
114 | video_names_list = open(opt['data_name_txt'], 'r').readlines()
115 |
116 | for row in video_names_list:
117 | video_name = row.split()[0]
118 | frame_num = int(row.split()[1])
119 | _video_names.append(video_name)
120 | _frame_nums.append(frame_num)
121 |
122 | self._shhq_data_dir = opt['shhq_data_dir']
123 | self._ann_dir = opt.get('ann_dir', None)  # optional; only referenced by the commented-out annotation parsing below
124 |
125 | _image_fnames = []
126 | # for idx, row in enumerate(
127 | # open(os.path.join(f'{self._ann_dir}/upper_fused.txt'), 'r')):
128 | # annotations = row.split()
129 | # if len(annotations[:-1]) == 1:
130 | # img_name = annotations[0]
131 | # else:
132 | # img_name = ''
133 | # for name in annotations[:-1]:
134 | # img_name += f'{name}\xa0'
135 | # img_name = img_name[:-1]
136 | # _image_fnames.append(img_name)
137 | shhq_path_list = glob.glob(f'{self._shhq_data_dir}/*.png')
138 | for shhq_path in shhq_path_list:
139 | _image_fnames.append(shhq_path.split('/')[-1])
140 |
141 | augment_times = max(1, len(_image_fnames) // len(_video_names))
142 |
143 | augmented_videos = _video_names * augment_times
144 | self._frame_nums = _frame_nums * augment_times
145 | self._all_file_name = augmented_videos + _image_fnames
146 | self.video_num = len(augmented_videos)
147 |
148 | def _load_raw_image(self, img_path):
149 | with open(img_path, 'rb') as f:
150 | image = Image.open(f)
151 | image.load()
152 | if self.downsample_factor != 1:
153 | width, height = image.size
154 | width = width // self.downsample_factor
155 | height = height // self.downsample_factor
156 | image = image.resize(
157 | size=(width, height), resample=Image.LANCZOS)
158 |
159 | return image
160 |
161 | def sample_video_data(self, index):
162 | video_name = self._all_file_name[index]
163 |
164 | random_frame_idx = random.randint(30, self._frame_nums[index] - 30)
165 |
166 | img_path = f'{self._video_dir}/{video_name}/{random_frame_idx:03d}.png'
167 | random_frame = self._load_raw_image(img_path)
168 |
169 | random_frame_aug = self.transform(random_frame)
170 |
171 | random_frame = np.array(random_frame).transpose(2, 0,
172 | 1).astype(np.float32)
173 | random_frame_aug = np.array(random_frame_aug).transpose(
174 | 2, 0, 1).astype(np.float32)
175 |
176 | identity_image = self._load_raw_image(
177 | f'{self._video_dir}/{video_name}/000.png')
178 | identity_image = np.array(identity_image).transpose(2, 0, 1).astype(
179 | np.float32)
180 |
181 | if self.xflip and random.random() > 0.5:
182 | random_frame = random_frame[:, :, ::-1].copy() # [C, H ,W]
183 | random_frame_aug = random_frame_aug[:, :, ::-1].copy()
184 |
185 | return identity_image, random_frame, random_frame_aug
186 |
187 | def sample_img_data(self, index):
188 | img_name = self._all_file_name[index]
189 |
190 | img_path = f'{self._shhq_data_dir}/{img_name}'
191 |
192 | identity_image = self._load_raw_image(img_path)
193 |
194 | random_frame = self.translate_transform(identity_image)
195 |
196 | random_frame_aug = self.transform(random_frame)
197 |
198 | random_frame = np.array(random_frame).transpose(2, 0,
199 | 1).astype(np.float32)
200 | random_frame_aug = np.array(random_frame_aug).transpose(
201 | 2, 0, 1).astype(np.float32)
202 | identity_image = np.array(identity_image).transpose(2, 0, 1).astype(
203 | np.float32)
204 |
205 | if self.xflip and random.random() > 0.5:
206 | random_frame = random_frame[:, :, ::-1].copy() # [C, H ,W]
207 | random_frame_aug = random_frame_aug[:, :, ::-1].copy()
208 |
209 | return identity_image, random_frame, random_frame_aug
210 |
211 | def __getitem__(self, index):
212 |
213 | if index < self.video_num:
214 | identity_image, random_frame, random_frame_aug = self.sample_video_data(
215 | index)
216 | else:
217 | identity_image, random_frame, random_frame_aug = self.sample_img_data(
218 | index)
219 |
220 | random_frame = random_frame / 127.5 - 1
221 | random_frame_aug = random_frame_aug / 127.5 - 1
222 | identity_image = identity_image / 127.5 - 1
223 |
224 | random_frame = torch.from_numpy(random_frame)
225 | random_frame_aug = torch.from_numpy(random_frame_aug)
226 | identity_image = torch.from_numpy(identity_image)
227 |
228 | return_dict = {
229 | 'frame_img': random_frame,
230 | 'frame_img_aug': random_frame_aug,
231 | 'identity_image': identity_image
232 | }
233 |
234 | return return_dict
235 |
236 | def __len__(self):
237 | return len(self._all_file_name)
238 |
--------------------------------------------------------------------------------
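Both datasets above return images normalized to [-1, 1] via `img / 127.5 - 1`. Below is a small helper sketch for undoing that normalization when visualizing returned samples; the `to_uint8` name is ours, not the repo's:

```python
# Convert a [-1, 1] CHW float tensor (as returned in 'frame_img' / 'identity_image')
# back to an HWC uint8 array for visualization. Helper name is ours, not the repo's.
import numpy as np
import torch


def to_uint8(img: torch.Tensor) -> np.ndarray:
    img = (img.clamp(-1, 1) + 1) * 127.5           # invert img / 127.5 - 1
    return img.round().byte().permute(1, 2, 0).cpu().numpy()

# Example use: Image.fromarray(to_uint8(sample['frame_img'])).save('frame.png')
```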
/data/moving_label_clip_all_rate_dataset.py:
--------------------------------------------------------------------------------
1 | import json
2 | import random
3 |
4 | import numpy as np
5 | import torch
6 | import torch.utils.data as data
7 | from PIL import Image
8 |
9 |
10 | def proper_capitalize(text):
11 | if len(text) > 0:
12 | text = text.lower()
13 | text = text[0].capitalize() + text[1:]
14 | for idx, char in enumerate(text):
15 | if char in ['.', '!', '?'] and (idx + 2) < len(text):
16 | text = text[:idx + 2] + text[idx + 2].capitalize() + text[idx +
17 | 3:]
18 | text = text.replace(' i ', ' I ')
19 | text = text.replace(',i ', ',I ')
20 | text = text.replace('.i ', '.I ')
21 | text = text.replace('!i ', '!I ')
22 | return text
23 |
24 |
25 | class MovingLabelsClipAllRateTextDataset(data.Dataset):
26 |
27 | def __init__(self, opt):
28 | self._video_dir = opt['video_dir']
29 | self.downsample_factor = opt['downsample_factor']
30 | self.random_start = opt['random_start']
31 | self._video_names = []
32 |
33 | self.frame_sample_rate = opt['frame_sample_rate']
34 |
35 | self._action_label_folder = opt['action_label_folder']
36 | self.action_labels = []
37 |
38 | self.fixed_video_len = opt['fixed_video_len']
39 | self.xflip = opt['xflip']
40 |
41 | video_names_list = open(opt['data_name_txt'], 'r').readlines()
42 |
43 | self.moving_dict = np.load(
44 | opt['moving_frame_dict'], allow_pickle=True).item()
45 |
46 | self.interpolation_rate = opt['interpolation_rate']
47 |
48 | self.all_clip_start_frame_list = []
49 | self.all_clip_end_frame_list = []
50 | self.all_clip_action_label_list = []
51 | self.frame_num_list = []
52 |
53 | with open(opt['overall_caption_templates_file'], 'r') as f:
54 | self.overall_caption_templates = json.load(f)
55 |
56 | for row in video_names_list:
57 | video_name = row.split()[0]
58 | frame_nums = int(row.split()[1])
59 | action_label_txt = open(
60 | f'{self._action_label_folder}/{video_name}.txt',
61 | 'r').readlines()
62 |
63 | clip_start_frame_list = []
64 | clip_end_frame_list = []
65 | clip_action_label_list = []
66 | for action_row in action_label_txt:
67 | start_frame, end_frame, action_label = action_row[:-1].split()
68 | start_frame = int(start_frame)
69 | end_frame = int(end_frame)
70 | action_label = int(action_label)
71 |
72 | if (end_frame - start_frame
73 | ) < self.frame_sample_rate * self.fixed_video_len:
74 | continue
75 |
76 | clip_start_frame_list.append(start_frame)
77 | clip_end_frame_list.append(end_frame)
78 | clip_action_label_list.append(action_label)
79 |
80 | if len(clip_start_frame_list) == 0:
81 | continue
82 |
83 | self._video_names.append(video_name)
84 | self.all_clip_start_frame_list.append(clip_start_frame_list)
85 | self.all_clip_end_frame_list.append(clip_end_frame_list)
86 | self.all_clip_action_label_list.append(clip_action_label_list)
87 | self.frame_num_list.append(frame_nums)
88 |
89 | assert len(self._video_names) == len(self.all_clip_start_frame_list)
90 | assert len(self._video_names) == len(self.all_clip_end_frame_list)
91 | assert len(self._video_names) == len(self.all_clip_action_label_list)
92 | assert len(self._video_names) == len(self.frame_num_list)
93 |
94 | def _load_raw_image(self, img_path):
95 | with open(img_path, 'rb') as f:
96 | image = Image.open(f)
97 | if self.downsample_factor != 1:
98 | width, height = image.size
99 | width = width // self.downsample_factor
100 | height = height // self.downsample_factor
101 | image = image.resize(
102 | size=(width, height), resample=Image.LANCZOS)
103 | image = np.array(image)
104 | if image.ndim == 2:
105 | image = image[:, :, np.newaxis] # HW => HWC
106 | image = image.transpose(2, 0, 1) # HWC => CHW
107 | return image.astype(np.float32)
108 |
109 | def sample_motion_clip(self, index):
110 | clip_start_frame_list = self.all_clip_start_frame_list[index]
111 | clip_end_frame_list = self.all_clip_end_frame_list[index]
112 | clip_action_label_list = self.all_clip_action_label_list[index]
113 |
114 | num_clip = len(clip_start_frame_list)
115 |
116 | clip_index = random.randint(0, num_clip - 1)
117 |
118 | action_label_list = clip_action_label_list[clip_index]
119 |
120 | clip_idx = list(
121 | range(clip_start_frame_list[clip_index],
122 | clip_end_frame_list[clip_index] + 1))
123 |
124 | segm = len(clip_idx) // self.fixed_video_len
125 |
126 | segm_dist = []
127 | for i in range(self.fixed_video_len - 1):
128 | segm_dist.append(segm)
129 |
130 | for i in range(
131 | min(
132 | len(clip_idx) - sum(segm_dist) - 2,
133 | self.fixed_video_len - 1)):
134 | segm_dist[i] += 1
135 |
136 | frame_idx_list = []
137 | frame_idx_list.append(clip_start_frame_list[clip_index])
138 |
139 | for i in range(len(segm_dist) - 1):
140 | frame_idx_list.append(clip_start_frame_list[clip_index] +
141 | sum(segm_dist[:i + 1]))
142 | frame_idx_list.append(clip_end_frame_list[clip_index])
143 |
144 | return frame_idx_list, action_label_list
145 |
146 | def sample_random_clip(self, index):
147 |
148 | video_name = self._video_names[index]
149 |
150 | video_len = self.fixed_video_len
151 |
152 | if len(self.moving_dict[video_name]) == 0:
153 | start_frame = random.randint(
154 | 30, self.frame_num_list[index] - 1 - video_len)
155 | else:
156 | start_frame = random.choice(self.moving_dict[video_name])
157 |
158 | while ((start_frame + video_len) >
159 | (self.frame_num_list[index] - 1)):
160 | start_frame = random.choice(self.moving_dict[video_name])
161 |
162 | frame_idx_list = []
163 | for frame_idx in range(video_len):
164 | video_frame_idx = start_frame + frame_idx
165 | frame_idx_list.append(video_frame_idx)
166 |
167 | action_label_list = 22
168 |
169 | return frame_idx_list, action_label_list
170 |
171 | def generate_caption(self, label):
172 | caption = random.choice(self.overall_caption_templates[label])
173 | replacing_word = random.choice(
174 | self.overall_caption_templates["gender"])
175 | caption = caption.replace('', replacing_word)
176 | caption = proper_capitalize(caption)
177 |
178 | return caption
179 |
180 | def __getitem__(self, index):
181 | video_name = self._video_names[index]
182 |
183 | if np.random.uniform(low=0.0, high=1.0) < self.interpolation_rate:
184 | frame_idx_list, action_label_list = self.sample_random_clip(index)
185 | interpolation_mode = True
186 | else:
187 | frame_idx_list, action_label_list = self.sample_motion_clip(index)
188 | interpolation_mode = False
189 |
190 | frames = []
191 | for frame_idx in frame_idx_list:
192 | img_path = f'{self._video_dir}/{video_name}/{frame_idx:03d}.png'
193 | frames.append(self._load_raw_image(img_path))
194 |
195 | frames = np.stack(frames, axis=0)
196 |
197 | exemplar_img = self._load_raw_image(
198 | f'{self._video_dir}/{video_name}/000.png')
199 |
200 | frames = frames / 127.5 - 1
201 | exemplar_img = exemplar_img / 127.5 - 1
202 |
203 | frames = torch.from_numpy(frames)
204 | exemplar_img = torch.from_numpy(exemplar_img)
205 |
206 | if action_label_list == 22:
207 | text_description = 'empty'
208 | else:
209 | text_description = self.generate_caption(str(action_label_list))
210 |
211 | return_dict = {
212 | 'video_name': video_name,
213 | 'video_frames': frames,
214 | 'video_len': self.fixed_video_len,
215 | 'exemplar_img': exemplar_img,
216 | 'action_labels': action_label_list,
217 | 'text': text_description,
218 | 'interpolation_mode': interpolation_mode
219 | }
220 |
221 | return return_dict
222 |
223 | def __len__(self):
224 | return len(self._video_names)
225 |
--------------------------------------------------------------------------------
/data/prefetch_dataloader.py:
--------------------------------------------------------------------------------
1 | import queue as Queue
2 | import threading
3 |
4 | import torch
5 | from torch.utils.data import DataLoader
6 |
7 |
8 | class PrefetchGenerator(threading.Thread):
9 | """A general prefetch generator.
10 |
11 | Ref:
12 | https://stackoverflow.com/questions/7323664/python-generator-pre-fetch
13 |
14 | Args:
15 | generator: Python generator.
16 | num_prefetch_queue (int): Number of prefetch queue.
17 | """
18 |
19 | def __init__(self, generator, num_prefetch_queue):
20 | threading.Thread.__init__(self)
21 | self.queue = Queue.Queue(num_prefetch_queue)
22 | self.generator = generator
23 | self.daemon = True
24 | self.start()
25 |
26 | def run(self):
27 | for item in self.generator:
28 | self.queue.put(item)
29 | self.queue.put(None)
30 |
31 | def __next__(self):
32 | next_item = self.queue.get()
33 | if next_item is None:
34 | raise StopIteration
35 | return next_item
36 |
37 | def __iter__(self):
38 | return self
39 |
40 |
41 | class PrefetchDataLoader(DataLoader):
42 | """Prefetch version of dataloader.
43 |
44 | Ref:
45 | https://github.com/IgorSusmelj/pytorch-styleguide/issues/5#
46 |
47 | TODO:
48 | Need to test on single gpu and ddp (multi-gpu). There is a known issue in
49 | ddp.
50 |
51 | Args:
52 | num_prefetch_queue (int): Number of prefetch queue.
53 | kwargs (dict): Other arguments for dataloader.
54 | """
55 |
56 | def __init__(self, num_prefetch_queue, **kwargs):
57 | self.num_prefetch_queue = num_prefetch_queue
58 | super(PrefetchDataLoader, self).__init__(**kwargs)
59 |
60 | def __iter__(self):
61 | return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue)
62 |
63 |
64 | class CPUPrefetcher():
65 | """CPU prefetcher.
66 |
67 | Args:
68 | loader: Dataloader.
69 | """
70 |
71 | def __init__(self, loader):
72 | self.ori_loader = loader
73 | self.loader = iter(loader)
74 |
75 | def next(self):
76 | try:
77 | return next(self.loader)
78 | except StopIteration:
79 | return None
80 |
81 | def reset(self):
82 | self.loader = iter(self.ori_loader)
83 |
84 |
85 | class CUDAPrefetcher():
86 | """CUDA prefetcher.
87 |
88 | Ref:
89 | https://github.com/NVIDIA/apex/issues/304#
90 |
91 | It may consume more GPU memory.
92 |
93 | Args:
94 | loader: Dataloader.
95 | opt (dict): Options.
96 | """
97 |
98 | def __init__(self, loader, opt):
99 | self.ori_loader = loader
100 | self.loader = iter(loader)
101 | self.opt = opt
102 | self.stream = torch.cuda.Stream()
103 | self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
104 | self.preload()
105 |
106 | def preload(self):
107 | try:
108 | self.batch = next(self.loader) # self.batch is a dict
109 | except StopIteration:
110 | self.batch = None
111 | return None
112 | # put tensors to gpu
113 | with torch.cuda.stream(self.stream):
114 | for k, v in self.batch.items():
115 | if torch.is_tensor(v):
116 | self.batch[k] = self.batch[k].to(
117 | device=self.device, non_blocking=True)
118 |
119 | def next(self):
120 | torch.cuda.current_stream().wait_stream(self.stream)
121 | batch = self.batch
122 | self.preload()
123 | return batch
124 |
125 | def reset(self):
126 | self.loader = iter(self.ori_loader)
127 | self.preload()
128 |
--------------------------------------------------------------------------------
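A brief usage sketch for `CPUPrefetcher` as defined above: `.next()` yields batches until the underlying loader is exhausted (then returns `None`), and `.reset()` rewinds it for the next epoch:

```python
# Usage sketch for CPUPrefetcher: iterate with .next() until it returns None,
# then .reset() before the next epoch.
import torch
from torch.utils.data import DataLoader, TensorDataset

from data.prefetch_dataloader import CPUPrefetcher

loader = DataLoader(TensorDataset(torch.arange(16).float()), batch_size=4)
prefetcher = CPUPrefetcher(loader)

for epoch in range(2):
    batch = prefetcher.next()
    while batch is not None:
        # training step would go here
        batch = prefetcher.next()
    prefetcher.reset()
```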
/data/sample_identity_dataset.py:
--------------------------------------------------------------------------------
1 | import json
2 | import random
3 |
4 | import numpy as np
5 | import torch
6 | import torch.utils.data as data
7 | from PIL import Image
8 |
9 |
10 | class SampleIdentityDataset(data.Dataset):
11 |
12 | def __init__(self, opt):
13 | self._video_dir = opt['video_dir']
14 | self.downsample_factor = opt['downsample_factor']
15 |
16 | self._video_names = []
17 | self._frame_nums = []
18 |
19 | self.xflip = opt['xflip']
20 |
21 | video_names_list = open(opt['data_name_txt'], 'r').readlines()
22 |
23 | for row in video_names_list:
24 | video_name = row.split()[0]
25 | self._video_names.append(video_name)
26 |
27 | with open(opt['text_file']) as json_file:
28 | self.text_descriptions = json.load(json_file)
29 |
30 | def _load_raw_image(self, img_path):
31 | with open(img_path, 'rb') as f:
32 | image = Image.open(f)
33 | image.load()
34 | if self.downsample_factor != 1:
35 | width, height = image.size
36 | width = width // self.downsample_factor
37 | height = height // self.downsample_factor
38 | image = image.resize(
39 | size=(width, height), resample=Image.LANCZOS)
40 |
41 | return image
42 |
43 | def __getitem__(self, index):
44 | video_name = self._video_names[index]
45 |
46 | identity_image = self._load_raw_image(
47 | f'{self._video_dir}/{video_name}/000.png')
48 | identity_image = np.array(identity_image).transpose(2, 0, 1).astype(
49 | np.float32)
50 | if self.xflip and random.random() > 0.5:
51 | identity_image = identity_image[:, :, ::-1].copy() # [C, H ,W]
52 |
53 | identity_image = identity_image / 127.5 - 1
54 |
55 | identity_image = torch.from_numpy(identity_image)
56 |
57 | if len(self.text_descriptions[video_name]) == 1:
58 | text_description = self.text_descriptions[video_name]
59 | else:
60 | text_description = random.choice(
61 | self.text_descriptions[video_name])
62 |
63 | return_dict = {
64 | 'img_name': video_name,
65 | 'image': identity_image,
66 | 'text': text_description
67 | }
68 |
69 | return return_dict
70 |
71 | def __len__(self):
72 | return len(self._video_names)
73 |
--------------------------------------------------------------------------------
/env.yaml:
--------------------------------------------------------------------------------
1 | name: text2performer
2 | channels:
3 | - pytorch
4 | - defaults
5 | dependencies:
6 | - _libgcc_mutex=0.1=main
7 | - blas=1.0=mkl
8 | - ca-certificates=2023.08.22=h06a4308_0
9 | - certifi=2022.12.7=py37h06a4308_0
10 | - cudatoolkit=10.1.243=h6bb024c_0
11 | - freetype=2.11.0=h70c0345_0
12 | - giflib=5.2.1=h7b6447c_0
13 | - intel-openmp=2021.4.0=h06a4308_3561
14 | - jpeg=9b=h024ee3a_2
15 | - lcms2=2.12=h3be6417_0
16 | - ld_impl_linux-64=2.38=h1181459_1
17 | - libffi=3.3=he6710b0_2
18 | - libgcc-ng=9.1.0=hdf63c60_0
19 | - libpng=1.6.37=hbc83047_0
20 | - libstdcxx-ng=9.1.0=hdf63c60_0
21 | - libtiff=4.1.0=h2733197_1
22 | - libuv=1.40.0=h7b6447c_0
23 | - libwebp=1.2.0=h89dd481_0
24 | - lz4-c=1.9.3=h295c915_1
25 | - mkl=2021.4.0=h06a4308_640
26 | - mkl-service=2.4.0=py37h7f8727e_0
27 | - mkl_fft=1.3.1=py37hd3c417c_0
28 | - mkl_random=1.2.2=py37h51133e4_0
29 | - ncurses=6.3=h7f8727e_2
30 | - ninja=1.10.2=h06a4308_5
31 | - ninja-base=1.10.2=hd09550d_5
32 | - numpy=1.21.5=py37he7a7128_2
33 | - numpy-base=1.21.5=py37hf524024_2
34 | - openssl=1.1.1w=h7f8727e_0
35 | - pillow=9.0.1=py37h22f2fdc_0
36 | - pip=22.3.1=py37h06a4308_0
37 | - python=3.7.11=h12debd9_0
38 | - pytorch=1.7.1=py3.7_cuda10.1.243_cudnn7.6.3_0
39 | - readline=8.1.2=h7f8727e_1
40 | - setuptools=65.6.3=py37h06a4308_0
41 | - six=1.16.0=pyhd3eb1b0_1
42 | - sqlite=3.38.5=hc218d9a_0
43 | - tk=8.6.12=h1ccaba5_0
44 | - torchvision=0.8.2=py37_cu101
45 | - typing_extensions=4.3.0=py37h06a4308_0
46 | - wheel=0.38.4=py37h06a4308_0
47 | - xz=5.2.5=h7f8727e_1
48 | - zlib=1.2.12=h7f8727e_2
49 | - zstd=1.4.9=haebb681_0
50 | - pip:
51 | - absl-py==2.0.0
52 | - anyio==3.7.1
53 | - argon2-cffi==23.1.0
54 | - argon2-cffi-bindings==21.2.0
55 | - attrs==23.1.0
56 | - babel==2.12.1
57 | - backcall==0.2.0
58 | - beautifulsoup4==4.12.2
59 | - bleach==6.0.0
60 | - cachetools==4.2.4
61 | - cffi==1.15.1
62 | - charset-normalizer==3.2.0
63 | - click==8.1.7
64 | - debugpy==1.7.0
65 | - decorator==5.1.1
66 | - defusedxml==0.7.1
67 | - einops==0.6.1
68 | - entrypoints==0.4
69 | - exceptiongroup==1.1.3
70 | - fastjsonschema==2.18.0
71 | - filelock==3.12.2
72 | - fsspec==2023.1.0
73 | - google-auth==1.35.0
74 | - google-auth-oauthlib==0.4.6
75 | - grpcio==1.58.0
76 | - huggingface-hub==0.16.4
77 | - idna==3.4
78 | - importlib-metadata==6.7.0
79 | - importlib-resources==5.12.0
80 | - ipykernel==6.16.2
81 | - ipython==7.34.0
82 | - ipython-genutils==0.2.0
83 | - jedi==0.19.0
84 | - jinja2==3.1.2
85 | - joblib==1.3.2
86 | - json5==0.9.14
87 | - jsonschema==4.17.3
88 | - jupyter-client==7.4.9
89 | - jupyter-core==4.12.0
90 | - jupyter-server==1.24.0
91 | - jupyterlab==3.3.2
92 | - jupyterlab-pygments==0.2.2
93 | - jupyterlab-server==2.24.0
94 | - lpips==0.1.4
95 | - markdown==3.4.4
96 | - markupsafe==2.1.3
97 | - matplotlib-inline==0.1.6
98 | - mistune==3.0.1
99 | - nbclassic==0.5.6
100 | - nbclient==0.7.4
101 | - nbconvert==7.6.0
102 | - nbformat==5.8.0
103 | - nest-asyncio==1.5.8
104 | - nltk==3.8.1
105 | - notebook-shim==0.2.3
106 | - oauthlib==3.2.2
107 | - opencv-python==4.5.5.62
108 | - packaging==23.1
109 | - pandocfilters==1.5.0
110 | - parso==0.8.3
111 | - pexpect==4.8.0
112 | - pickleshare==0.7.5
113 | - pkgutil-resolve-name==1.3.10
114 | - prometheus-client==0.17.1
115 | - prompt-toolkit==3.0.39
116 | - protobuf==3.20.3
117 | - psutil==5.9.5
118 | - ptyprocess==0.7.0
119 | - pyasn1==0.5.0
120 | - pyasn1-modules==0.3.0
121 | - pycparser==2.21
122 | - pygments==2.16.1
123 | - pyrsistent==0.19.3
124 | - python-dateutil==2.8.2
125 | - pytz==2023.3.post1
126 | - pyyaml==6.0.1
127 | - pyzmq==25.1.1
128 | - regex==2023.8.8
129 | - requests==2.31.0
130 | - requests-oauthlib==1.3.1
131 | - safetensors==0.3.3
132 | - scikit-learn==1.0.2
133 | - scipy==1.7.3
134 | - send2trash==1.8.2
135 | - sentence-transformers==2.2.2
136 | - sentencepiece==0.1.99
137 | - sniffio==1.3.0
138 | - soupsieve==2.4.1
139 | - tensorboard==2.5.0
140 | - tensorboard-data-server==0.6.1
141 | - tensorboard-plugin-wit==1.8.1
142 | - terminado==0.17.1
143 | - threadpoolctl==3.1.0
144 | - tinycss2==1.2.1
145 | - tokenizers==0.13.3
146 | - tornado==6.2
147 | - tqdm==4.66.1
148 | - traitlets==5.9.0
149 | - transformers==4.30.2
150 | - urllib3==2.0.5
151 | - wcwidth==0.2.7
152 | - webencodings==0.5.1
153 | - websocket-client==1.6.1
154 | - werkzeug==2.2.3
155 | - zipp==3.15.0
156 |
157 |
--------------------------------------------------------------------------------
/img/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yumingj/Text2Performer/433489aae7bdd6fd868a2e272d98bed7c7b642a5/img/teaser.png
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import importlib
3 | import logging
4 | import os.path as osp
5 |
6 | # automatically scan and import model modules
7 | # scan all the files under the 'models' folder and collect files ending with
8 | # '_model.py'
9 | model_folder = osp.dirname(osp.abspath(__file__))
10 | model_filenames = [
11 | osp.splitext(osp.basename(v))[0]
12 | for v in glob.glob(f'{model_folder}/*_model.py')
13 | ]
14 | # import all the model modules
15 | _model_modules = [
16 | importlib.import_module(f'models.{file_name}')
17 | for file_name in model_filenames
18 | ]
19 |
20 |
21 | def create_model(opt):
22 | """Create model.
23 |
24 | Args:
25 | opt (dict): Configuration. It contains:
26 | model_type (str): Model type.
27 | """
28 | model_type = opt['model_type']
29 |
30 | # dynamic instantiation
31 | for module in _model_modules:
32 | model_cls = getattr(module, model_type, None)
33 | if model_cls is not None:
34 | break
35 | if model_cls is None:
36 | raise ValueError(f'Model {model_type} is not found.')
37 |
38 | model = model_cls(opt)
39 |
40 | logger = logging.getLogger('base')
41 | logger.info(f'Model [{model.__class__.__name__}] is created.')
42 | return model
43 |
--------------------------------------------------------------------------------
/models/app_transformer_model.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import math
3 | from collections import OrderedDict
4 |
5 | import numpy as np
6 | import torch
7 | import torch.distributions as dists
8 | import torch.nn.functional as F
9 | from sentence_transformers import SentenceTransformer
10 | from torchvision.utils import save_image
11 |
12 | from models.archs.transformer_arch import TransformerLanguage
13 | from models.archs.vqgan_arch import (
14 | DecoderUpOthersDoubleIdentity,
15 | EncoderDecomposeBaseDownOthersDoubleIdentity, VectorQuantizer)
16 |
17 | logger = logging.getLogger('base')
18 |
19 |
20 | class AppTransformerModel():
21 | """Texture-Aware Diffusion based Transformer model.
22 | """
23 |
24 | def __init__(self, opt):
25 | self.opt = opt
26 | self.device = torch.device('cuda')
27 | self.is_train = opt['is_train']
28 |
29 | # VQVAE for image
30 | self.img_encoder = EncoderDecomposeBaseDownOthersDoubleIdentity(
31 | ch=opt['img_ch'],
32 | num_res_blocks=opt['img_num_res_blocks'],
33 | attn_resolutions=opt['img_attn_resolutions'],
34 | ch_mult=opt['img_ch_mult'],
35 | other_ch_mult=opt['img_other_ch_mult'],
36 | in_channels=opt['img_in_channels'],
37 | resolution=opt['img_resolution'],
38 | z_channels=opt['img_z_channels'],
39 | double_z=opt['img_double_z'],
40 | dropout=opt['img_dropout']).to(self.device)
41 | self.img_decoder = DecoderUpOthersDoubleIdentity(
42 | in_channels=opt['img_in_channels'],
43 | resolution=opt['img_resolution'],
44 | z_channels=opt['img_z_channels'],
45 | ch=opt['img_ch'],
46 | out_ch=opt['img_out_ch'],
47 | num_res_blocks=opt['img_num_res_blocks'],
48 | attn_resolutions=opt['img_attn_resolutions'],
49 | ch_mult=opt['img_ch_mult'],
50 | other_ch_mult=opt['img_other_ch_mult'],
51 | dropout=opt['img_dropout'],
52 | resamp_with_conv=True,
53 | give_pre_end=False).to(self.device)
54 | self.quantize_identity = VectorQuantizer(
55 | opt['img_n_embed'], opt['img_embed_dim'],
56 | beta=0.25).to(self.device)
57 | self.quant_conv_identity = torch.nn.Conv2d(opt["img_z_channels"],
58 | opt['img_embed_dim'],
59 | 1).to(self.device)
60 | self.post_quant_conv_identity = torch.nn.Conv2d(
61 | opt['img_embed_dim'], opt["img_z_channels"], 1).to(self.device)
62 |
63 | self.quantize_others = VectorQuantizer(
64 | opt['img_n_embed'], opt['img_embed_dim'] // 2,
65 | beta=0.25).to(self.device)
66 | self.quant_conv_others = torch.nn.Conv2d(opt["img_z_channels"] // 2,
67 | opt['img_embed_dim'] // 2,
68 | 1).to(self.device)
69 | self.post_quant_conv_others = torch.nn.Conv2d(
70 | opt['img_embed_dim'] // 2, opt["img_z_channels"] // 2,
71 | 1).to(self.device)
72 | self.load_pretrained_image_vae()
73 |
74 | # define sampler
75 | self._denoise_fn = TransformerLanguage(
76 | codebook_size=opt['codebook_size'],
77 | bert_n_emb=opt['bert_n_emb'],
78 | bert_n_layers=opt['bert_n_layers'],
79 | bert_n_head=opt['bert_n_head'],
80 | block_size=opt['block_size'] * 2,
81 | embd_pdrop=opt['embd_pdrop'],
82 | resid_pdrop=opt['resid_pdrop'],
83 | attn_pdrop=opt['attn_pdrop']).to(self.device)
84 |
85 | self.num_classes = opt['codebook_size']
86 | self.shape = tuple(opt['latent_shape'])
87 | self.num_timesteps = 1000
88 |
89 | self.mask_id = opt['codebook_size']
90 | self.loss_type = opt['loss_type']
91 | self.mask_schedule = opt['mask_schedule']
92 |
93 | self.sample_steps = opt['sample_steps']
94 |
95 | self.text_seq_len = opt['text_seq_len']
96 |
97 | self.init_training_settings()
98 |
99 | self.get_fixed_language_model()
100 |
101 | def load_pretrained_image_vae(self):
102 |         # load the pretrained decomposed VQGAN for images
103 | img_ae_checkpoint = torch.load(self.opt['img_ae_path'])
104 | self.img_encoder.load_state_dict(
105 | img_ae_checkpoint['encoder'], strict=True)
106 | self.img_decoder.load_state_dict(
107 | img_ae_checkpoint['decoder'], strict=True)
108 | self.quantize_identity.load_state_dict(
109 | img_ae_checkpoint['quantize_identity'], strict=True)
110 | self.quant_conv_identity.load_state_dict(
111 | img_ae_checkpoint['quant_conv_identity'], strict=True)
112 | self.post_quant_conv_identity.load_state_dict(
113 | img_ae_checkpoint['post_quant_conv_identity'], strict=True)
114 | self.quantize_others.load_state_dict(
115 | img_ae_checkpoint['quantize_others'], strict=True)
116 | self.quant_conv_others.load_state_dict(
117 | img_ae_checkpoint['quant_conv_others'], strict=True)
118 | self.post_quant_conv_others.load_state_dict(
119 | img_ae_checkpoint['post_quant_conv_others'], strict=True)
120 | self.img_encoder.eval()
121 | self.img_decoder.eval()
122 | self.quantize_identity.eval()
123 | self.quant_conv_identity.eval()
124 | self.post_quant_conv_identity.eval()
125 | self.quantize_others.eval()
126 | self.quant_conv_others.eval()
127 | self.post_quant_conv_others.eval()
128 |
129 | def init_training_settings(self):
130 | optim_params = []
131 | for v in self._denoise_fn.parameters():
132 | if v.requires_grad:
133 | optim_params.append(v)
134 | # set up optimizer
135 | self.optimizer = torch.optim.Adam(
136 | optim_params,
137 | self.opt['lr'],
138 | weight_decay=self.opt['weight_decay'])
139 | self.log_dict = OrderedDict()
140 |
141 | @torch.no_grad()
142 | def get_quantized_img(self, image):
143 | h_identity, _ = self.img_encoder(image)
144 | h_identity = self.quant_conv_identity(h_identity)
145 | _, _, [_, _, identity_tokens] = self.quantize_identity(h_identity)
146 |
147 | _, h_frame = self.img_encoder(image)
148 | h_frame = self.quant_conv_others(h_frame)
149 | _, _, [_, _, pose_tokens] = self.quantize_others(h_frame)
150 |
151 | # reshape the tokens
152 | b = image.size(0)
153 | identity_tokens = identity_tokens.view(b, -1)
154 | pose_tokens = pose_tokens.view(b, -1)
155 |
156 | return identity_tokens, pose_tokens
157 |
158 | @torch.no_grad()
159 | def decode(self, quant_list):
160 | quant_identity = self.post_quant_conv_identity(quant_list[0])
161 | quant_frame = self.post_quant_conv_others(quant_list[1])
162 | dec = self.img_decoder(quant_identity, quant_frame)
163 | return dec
164 |
165 | @torch.no_grad()
166 | def decode_image_indices(self, quant_identity, quant_frame):
167 | quant_identity = self.quantize_identity.get_codebook_entry(
168 | quant_identity, (quant_identity.size(0), self.shape[0],
169 | self.shape[1], self.opt["img_z_channels"]))
170 | quant_frame = self.quantize_others.get_codebook_entry(
171 | quant_frame, (quant_frame.size(0), self.shape[0] // 4,
172 | self.shape[1] // 4, self.opt["img_z_channels"] // 2))
173 | dec = self.decode([quant_identity, quant_frame])
174 |
175 | return dec
176 |
177 | def sample_time(self, b, device, method='uniform'):
178 | if method == 'importance':
179 | if not (self.Lt_count > 10).all():
180 | return self.sample_time(b, device, method='uniform')
181 |
182 | Lt_sqrt = torch.sqrt(self.Lt_history + 1e-10) + 0.0001
183 | Lt_sqrt[0] = Lt_sqrt[1] # Overwrite decoder term with L1.
184 | pt_all = Lt_sqrt / Lt_sqrt.sum()
185 |
186 | t = torch.multinomial(pt_all, num_samples=b, replacement=True)
187 |
188 | pt = pt_all.gather(dim=0, index=t)
189 |
190 | return t, pt
191 |
192 | elif method == 'uniform':
193 | t = torch.randint(
194 | 1, self.num_timesteps + 1, (b, ), device=device).long()
195 | pt = torch.ones_like(t).float() / self.num_timesteps
196 | return t, pt
197 |
198 | else:
199 | raise ValueError
200 |
201 | def q_sample(self, x_0, x_0_gt, t):
202 | # samples q(x_t | x_0)
203 |         # randomly replace tokens with the mask id with probability t/T
204 | # x_t, x_0_ignore = x_0.clone(), x_0.clone()
205 | x_t = x_0.clone()
206 |
207 | mask = torch.rand_like(x_t.float()) < (
208 | t.float().unsqueeze(-1) / self.num_timesteps)
209 | x_t[mask] = self.mask_id
210 | # x_0_ignore[torch.bitwise_not(mask)] = -1
211 |
212 | # for every gt token list, we also need to do the mask
213 | x_0_gt_ignore = x_0_gt.clone()
214 | x_0_gt_ignore[torch.bitwise_not(mask)] = -1
215 |
216 | return x_t, x_0_gt_ignore, mask
217 |
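    # Note (added): q_sample above is the forward "absorbing" corruption of the
    # discrete diffusion sampler. For a sample drawn at timestep t, every token
    # is independently replaced by mask_id with probability t / T
    # (T = num_timesteps), so t = T masks everything and small t masks little.
    # Positions that stay visible are set to -1 in x_0_gt_ignore so that
    # F.cross_entropy(..., ignore_index=-1) in _train_loss only scores the
    # masked positions.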
218 | def _train_loss(self, x_identity_0, x_pose_0):
219 | b, device = x_identity_0.size(0), x_identity_0.device
220 |
221 | # choose what time steps to compute loss at
222 | t, pt = self.sample_time(b, device, 'uniform')
223 |
224 | # make x noisy and denoise
225 | if self.mask_schedule == 'random':
226 | x_identity_t, x_identity_0_gt_ignore, mask = self.q_sample(
227 | x_0=x_identity_0, x_0_gt=x_identity_0, t=t)
228 | x_pose_t, x_pose_0_gt_ignore, mask = self.q_sample(
229 | x_0=x_pose_0, x_0_gt=x_pose_0, t=t)
230 | else:
231 | raise NotImplementedError
232 |
233 | # sample p(x_0 | x_t)
234 | x_identity_0_hat_logits, x_pose_0_hat_logits = self._denoise_fn(
235 | self.text_embedding, x_identity_t, x_pose_t, t=t)
236 |
237 | x_identity_0_hat_logits = x_identity_0_hat_logits[:,
238 | 1:1 + self.shape[0] *
239 | self.shape[1], :]
240 | x_pose_0_hat_logits = x_pose_0_hat_logits[:, 1 + self.shape[0] *
241 | self.shape[1]:]
242 |
243 | # Always compute ELBO for comparison purposes
244 | cross_entropy_loss = 0
245 |
246 | cross_entropy_loss = F.cross_entropy(
247 | x_identity_0_hat_logits.permute(0, 2, 1),
248 | x_identity_0_gt_ignore,
249 | ignore_index=-1,
250 | reduction='none').sum(1) + F.cross_entropy(
251 | x_pose_0_hat_logits.permute(0, 2, 1),
252 | x_pose_0_gt_ignore,
253 | ignore_index=-1,
254 | reduction='none').sum(1)
255 | vb_loss = cross_entropy_loss / t
256 | vb_loss = vb_loss / pt
257 | vb_loss = vb_loss / (math.log(2) * x_identity_0.shape[1:].numel())
258 | if self.loss_type == 'elbo':
259 | loss = vb_loss
260 | elif self.loss_type == 'mlm':
261 | denom = mask.float().sum(1)
262 | denom[denom == 0] = 1 # prevent divide by 0 errors.
263 | loss = cross_entropy_loss / denom
264 | elif self.loss_type == 'reweighted_elbo':
265 | weight = (1 - (t / self.num_timesteps))
266 | loss = weight * cross_entropy_loss
267 | loss = loss / (math.log(2) * x_identity_0.shape[1:].numel())
268 | else:
269 | raise ValueError
270 |
271 | return loss.mean(), vb_loss.mean()
272 |
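    # Note (added): the loss variants above differ only in how the summed
    # cross-entropy (CE) over masked tokens is reweighted:
    #   'elbo'            -> CE / (t * pt * ln2 * D), the variational bound with
    #                        importance weight 1/pt (pt is the probability of
    #                        sampling t; uniform here, so pt = 1/T),
    #   'mlm'             -> CE divided by the per-sample count of masked positions,
    #   'reweighted_elbo' -> CE scaled by (1 - t/T) / (ln2 * D),
    # where D is the number of identity tokens (x_identity_0.shape[1:].numel()).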
273 | def feed_data(self, data):
274 | self.image = data['image'].to(self.device)
275 | self.text = data['text'] #.to(self.device)
276 | self.get_text_embedding()
277 |
278 | self.identity_tokens, self.pose_tokens = self.get_quantized_img(
279 | self.image)
280 |
281 | def get_fixed_language_model(self):
282 | self.language_model = SentenceTransformer('all-MiniLM-L6-v2')
283 | self.text_feature_dim = 384
284 |
285 | @torch.no_grad()
286 | def get_text_embedding(self):
287 | self.text_embedding = self.language_model.encode(self.text, show_progress_bar=False)
288 | self.text_embedding = torch.Tensor(self.text_embedding).to(
289 | self.device).unsqueeze(1)
290 |
291 | def optimize_parameters(self):
292 | self._denoise_fn.train()
293 |
294 | loss, vb_loss = self._train_loss(self.identity_tokens,
295 | self.pose_tokens)
296 |
297 | self.optimizer.zero_grad()
298 | loss.backward()
299 | self.optimizer.step()
300 |
301 | self.log_dict['loss'] = loss
302 | self.log_dict['vb_loss'] = vb_loss
303 |
304 | self._denoise_fn.eval()
305 |
306 | def sample_fn(self, temp=1.0, sample_steps=None):
307 | self._denoise_fn.eval()
308 |
309 | b, device = self.image.size(0), 'cuda'
310 | x_identity_t = torch.ones(
311 | (b, np.prod(self.shape)), device=device).long() * self.mask_id
312 | x_pose_t = torch.ones((b, np.prod(self.shape) // 16),
313 | device=device).long() * self.mask_id
314 | unmasked_identity = torch.zeros_like(
315 | x_identity_t, device=device).bool()
316 | unmasked_pose = torch.zeros_like(x_pose_t, device=device).bool()
317 | sample_steps = list(range(1, sample_steps + 1))
318 |
319 | for t in reversed(sample_steps):
320 | print(f'Sample timestep {t:4d}', end='\r')
321 | t = torch.full((b, ), t, device=device, dtype=torch.long)
322 |
323 | # where to unmask
324 | changes_identity = torch.rand(
325 | x_identity_t.shape,
326 | device=device) < 1 / t.float().unsqueeze(-1)
327 | # don't unmask somewhere already unmasked
328 | changes_identity = torch.bitwise_xor(
329 | changes_identity,
330 | torch.bitwise_and(changes_identity, unmasked_identity))
331 | # update mask with changes
332 | unmasked_identity = torch.bitwise_or(unmasked_identity,
333 | changes_identity)
334 |
335 | changes_pose = torch.rand(
336 | x_pose_t.shape, device=device) < 1 / t.float().unsqueeze(-1)
337 | # don't unmask somewhere already unmasked
338 | changes_pose = torch.bitwise_xor(
339 | changes_pose, torch.bitwise_and(changes_pose, unmasked_pose))
340 | # update mask with changes
341 | unmasked_pose = torch.bitwise_or(unmasked_pose, changes_pose)
342 |
343 | x_identity_0_hat_logits, x_pose_0_hat_logits = self._denoise_fn(
344 | self.text_embedding, x_identity_t, x_pose_t, t=t)
345 |
346 | x_identity_0_hat_logits = x_identity_0_hat_logits[:, 1:1 +
347 | self.shape[0] *
348 | self.shape[1], :]
349 | x_pose_0_hat_logits = x_pose_0_hat_logits[:, 1 + self.shape[0] *
350 | self.shape[1]:]
351 |
352 | # scale by temperature
353 | x_identity_0_hat_logits = x_identity_0_hat_logits / temp
354 | x_identity_0_dist = dists.Categorical(
355 | logits=x_identity_0_hat_logits)
356 | x_identity_0_hat = x_identity_0_dist.sample().long()
357 |
358 | x_pose_0_hat_logits = x_pose_0_hat_logits / temp
359 | x_pose_0_dist = dists.Categorical(logits=x_pose_0_hat_logits)
360 | x_pose_0_hat = x_pose_0_dist.sample().long()
361 |
362 |             # x_t is fed back into the transformer at the next step, so the newly unmasked positions must hold valid (in-range) token indices
363 | x_identity_t[changes_identity] = x_identity_0_hat[changes_identity]
364 | x_pose_t[changes_pose] = x_pose_0_hat[changes_pose]
365 |
366 | self._denoise_fn.train()
367 |
368 | return x_identity_t, x_pose_t
369 |
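    # Note (added): sample_fn runs the reverse (unmasking) process: both token
    # maps start fully masked, and at each step t a random ~1/t fraction of the
    # still-masked positions is committed to a sample from the predicted
    # categorical p(x_0 | x_t), so everything is unmasked by t = 1. Identity
    # tokens form a (shape[0] x shape[1]) grid, pose tokens a grid that is 4x
    # smaller in each spatial dimension.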
370 | def get_vis(self, image, gt_quant_identity, gt_quant_frame, quant_identity,
371 | quant_frame, save_path):
372 | # original image
373 | ori_img = self.decode_image_indices(gt_quant_identity, gt_quant_frame)
374 | # pred image
375 | pred_img = self.decode_image_indices(quant_identity, quant_frame)
376 | img_cat = torch.cat([
377 | image,
378 | ori_img,
379 | pred_img,
380 | ], dim=3).detach()
381 | img_cat = ((img_cat + 1) / 2)
382 | img_cat = img_cat.clamp_(0, 1)
383 | save_image(img_cat, save_path, nrow=1, padding=4)
384 |
385 | def inference(self, data_loader, save_dir):
386 | self._denoise_fn.eval()
387 |
388 | for _, data in enumerate(data_loader):
389 | img_name = data['img_name']
390 | self.feed_data(data)
391 | b = self.image.size(0)
392 | with torch.no_grad():
393 | x_identity_t, x_pose_t = self.sample_fn(
394 | temp=1, sample_steps=self.sample_steps)
395 | for idx in range(b):
396 | self.get_vis(self.image[idx:idx + 1],
397 | self.identity_tokens[idx:idx + 1],
398 | self.pose_tokens[idx:idx + 1],
399 | x_identity_t[idx:idx + 1], x_pose_t[idx:idx + 1],
400 | f'{save_dir}/{img_name[idx]}.png')
401 |
402 | self._denoise_fn.train()
403 |
404 | def sample_appearance(self, text, save_path, shape=[256, 128]):
405 | self._denoise_fn.eval()
406 |
407 | self.text = text
408 | self.image = torch.zeros([1, 3, shape[0], shape[1]]).to(self.device)
409 | self.get_text_embedding()
410 |
411 | with torch.no_grad():
412 | x_identity_t, x_pose_t = self.sample_fn(
413 | temp=1, sample_steps=self.sample_steps)
414 |
415 | self.get_vis_generated_only(x_identity_t, x_pose_t, save_path)
416 |
417 | quant_identity = self.quantize_identity.get_codebook_entry(
418 | x_identity_t, (x_identity_t.size(0), self.shape[0], self.shape[1],
419 | self.opt["img_z_channels"]))
420 | quant_frame = self.quantize_others.get_codebook_entry(
421 | x_pose_t,
422 | (x_pose_t.size(0), self.shape[0] // 4, self.shape[1] // 4,
423 | self.opt["img_z_channels"] // 2)).view(
424 | x_pose_t.size(0), self.opt["img_z_channels"] // 2,
425 | -1).permute(0, 2, 1)
426 |
427 | self._denoise_fn.train()
428 |
429 | return quant_identity, quant_frame
430 |
431 | def get_vis_generated_only(self, quant_identity, quant_frame, save_path):
432 | # pred image
433 | pred_img = self.decode_image_indices(quant_identity, quant_frame)
434 | img_cat = ((pred_img.detach() + 1) / 2)
435 | img_cat = img_cat.clamp_(0, 1)
436 | save_image(img_cat, save_path, nrow=1, padding=4)
437 |
438 | def get_current_log(self):
439 | return self.log_dict
440 |
441 | def update_learning_rate(self, epoch, iters=None):
442 | """Update learning rate.
443 |
444 | Args:
445 |             epoch (int): Current epoch.
446 |             iters (int): Current iteration. Only used by the 'warm_up'
447 |                 decay mode. Default: None.
448 | """
449 | lr = self.optimizer.param_groups[0]['lr']
450 |
451 | if self.opt['lr_decay'] == 'step':
452 | lr = self.opt['lr'] * (
453 | self.opt['gamma']**(epoch // self.opt['step']))
454 | elif self.opt['lr_decay'] == 'cos':
455 | lr = self.opt['lr'] * (
456 | 1 + math.cos(math.pi * epoch / self.opt['num_epochs'])) / 2
457 | elif self.opt['lr_decay'] == 'linear':
458 | lr = self.opt['lr'] * (1 - epoch / self.opt['num_epochs'])
459 | elif self.opt['lr_decay'] == 'linear2exp':
460 | if epoch < self.opt['turning_point'] + 1:
461 |                 # decay linearly so that the learning rate has dropped
462 |                 # by 95% at the turning point (1 / 0.95 = 1.0526)
463 | lr = self.opt['lr'] * (
464 | 1 - epoch / int(self.opt['turning_point'] * 1.0526))
465 | else:
466 | lr *= self.opt['gamma']
467 | elif self.opt['lr_decay'] == 'schedule':
468 | if epoch in self.opt['schedule']:
469 | lr *= self.opt['gamma']
470 | elif self.opt['lr_decay'] == 'warm_up':
471 | if iters <= self.opt['warmup_iters']:
472 | lr = self.opt['lr'] * float(iters) / self.opt['warmup_iters']
473 | else:
474 | lr = self.opt['lr']
475 | else:
476 | raise ValueError('Unknown lr mode {}'.format(self.opt['lr_decay']))
477 | # set learning rate
478 | for param_group in self.optimizer.param_groups:
479 | param_group['lr'] = lr
480 |
481 | return lr
482 |
483 | def save_network(self, net, save_path):
484 | """Save networks.
485 |
486 | Args:
487 | net (nn.Module): Network to be saved.
488 |             save_path (str): Path where the network state_dict
489 |                 is saved.
490 | """
491 | state_dict = net.state_dict()
492 | torch.save(state_dict, save_path)
493 |
494 | def load_network(self):
495 | checkpoint = torch.load(self.opt['pretrained_sampler'])
496 | self._denoise_fn.load_state_dict(checkpoint, strict=True)
497 | self._denoise_fn.eval()
498 |
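# Usage sketch (added, not part of the upstream file): at inference time the
# appearance sampler is typically built from a sampler config plus its
# pretrained checkpoint, roughly:
#
#     model = AppTransformerModel(opt)   # opt loaded from configs/sampler/*.yml
#     model.load_network()               # reads opt['pretrained_sampler']
#     quant_identity, quant_frame = model.sample_appearance(
#         ['<appearance description>'], 'results/appearance.png')
#
# The option keys and prompt format above are assumptions; see
# generate_long_video.ipynb for the authors' actual calling convention.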
--------------------------------------------------------------------------------
/models/archs/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yumingj/Text2Performer/433489aae7bdd6fd868a2e272d98bed7c7b642a5/models/archs/__init__.py
--------------------------------------------------------------------------------
/models/archs/dalle_transformer_arch.py:
--------------------------------------------------------------------------------
1 | import math
2 | from functools import wraps
3 |
4 | import torch
5 | import torch.nn.functional as F
6 | from einops import rearrange
7 | from torch import einsum, nn
8 |
9 | from models.archs.einops_exts import repeat_many
10 | from models.archs.rotary_embedding_torch import RotaryEmbedding
11 |
12 |
13 | # helper functions
14 | def exists(val):
15 | return val is not None
16 |
17 |
18 | def l2norm(t):
19 | return F.normalize(t, dim=-1)
20 |
21 |
22 | # relative positional bias for causal transformer
23 | class RelPosBias(nn.Module):
24 |
25 | def __init__(
26 | self,
27 | heads=8,
28 | num_buckets=32,
29 | max_distance=128,
30 | ):
31 | super().__init__()
32 | self.num_buckets = num_buckets
33 | self.max_distance = max_distance
34 | self.relative_attention_bias = nn.Embedding(num_buckets, heads)
35 |
36 | @staticmethod
37 | def _relative_position_bucket(relative_position,
38 | num_buckets=32,
39 | max_distance=128):
40 | n = -relative_position
41 | n = torch.max(n, torch.zeros_like(n))
42 |
43 | max_exact = num_buckets // 2
44 | is_small = n < max_exact
45 |
46 | val_if_large = max_exact + (torch.log(n.float() / max_exact) /
47 | math.log(max_distance / max_exact) *
48 | (num_buckets - max_exact)).long()
49 | val_if_large = torch.min(
50 | val_if_large, torch.full_like(val_if_large, num_buckets - 1))
51 | return torch.where(is_small, n, val_if_large)
52 |
53 | def forward(self, i, j, *, device):
54 | q_pos = torch.arange(i, dtype=torch.long, device=device)
55 | k_pos = torch.arange(j, dtype=torch.long, device=device)
56 | rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1')
57 | rp_bucket = self._relative_position_bucket(
58 | rel_pos,
59 | num_buckets=self.num_buckets,
60 | max_distance=self.max_distance)
61 | values = self.relative_attention_bias(rp_bucket)
62 | return rearrange(values, 'i j h -> h i j')
63 |
64 |
65 | class LayerNorm(nn.Module):
66 |
67 | def __init__(self, dim, eps=1e-5, fp16_eps=1e-3, stable=False):
68 | super().__init__()
69 | self.eps = eps
70 | self.fp16_eps = fp16_eps
71 | self.stable = stable
72 | self.g = nn.Parameter(torch.ones(dim))
73 |
74 | def forward(self, x):
75 | eps = self.eps if x.dtype == torch.float32 else self.fp16_eps
76 |
77 | if self.stable:
78 | x = x / x.amax(dim=-1, keepdim=True).detach()
79 |
80 | var = torch.var(x, dim=-1, unbiased=False, keepdim=True)
81 | mean = torch.mean(x, dim=-1, keepdim=True)
82 | return (x - mean) * (var + eps).rsqrt() * self.g
83 |
84 |
85 | # attention
86 | class Attention(nn.Module):
87 |
88 | def __init__(self,
89 | dim,
90 | *,
91 | dim_head=64,
92 | heads=8,
93 | dropout=0.,
94 | causal=False,
95 | rotary_emb=None,
96 | cosine_sim=True,
97 | cosine_sim_scale=16):
98 | super().__init__()
99 | self.scale = cosine_sim_scale if cosine_sim else (dim_head**-0.5)
100 | self.cosine_sim = cosine_sim
101 |
102 | self.heads = heads
103 | inner_dim = dim_head * heads
104 |
105 | self.causal = causal
106 | self.norm = LayerNorm(dim)
107 | self.dropout = nn.Dropout(dropout)
108 |
109 | self.null_kv = nn.Parameter(torch.randn(2, dim_head))
110 | self.to_q = nn.Linear(dim, inner_dim, bias=False)
111 | self.to_kv = nn.Linear(dim, dim_head * 2, bias=False)
112 |
113 | self.rotary_emb = rotary_emb
114 |
115 | self.to_out = nn.Sequential(
116 | nn.Linear(inner_dim, dim, bias=False), LayerNorm(dim))
117 |
118 | def forward(self, x, mask=None, attn_bias=None):
119 | b, n, device = *x.shape[:2], x.device
120 |
121 | x = self.norm(x)
122 | q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim=-1))
123 |
124 | q = rearrange(q, 'b n (h d) -> b h n d', h=self.heads)
125 | q = q * self.scale
126 |
127 | # rotary embeddings
128 |
129 | if exists(self.rotary_emb):
130 | q, k = map(self.rotary_emb.rotate_queries_or_keys, (q, k))
131 |
132 | # add null key / value for classifier free guidance in prior net
133 |
134 | nk, nv = repeat_many(self.null_kv.unbind(dim=-2), 'd -> b 1 d', b=b)
135 | k = torch.cat((nk, k), dim=-2)
136 | v = torch.cat((nv, v), dim=-2)
137 |
138 | # whether to use cosine sim
139 |
140 | if self.cosine_sim:
141 | q, k = map(l2norm, (q, k))
142 |
143 | q, k = map(lambda t: t * math.sqrt(self.scale), (q, k))
144 |
145 | # calculate query / key similarities
146 |
147 | sim = einsum('b h i d, b j d -> b h i j', q, k)
148 |
149 | # relative positional encoding (T5 style)
150 |
151 | if exists(attn_bias):
152 | sim = sim + attn_bias
153 |
154 | # masking
155 |
156 | max_neg_value = -torch.finfo(sim.dtype).max
157 |
158 | if exists(mask):
159 | mask = F.pad(mask, (1, 0), value=True)
160 | mask = rearrange(mask, 'b j -> b 1 1 j')
161 | sim = sim.masked_fill(~mask, max_neg_value)
162 |
163 | if self.causal:
164 | i, j = sim.shape[-2:]
165 | causal_mask = torch.ones((i, j), dtype=torch.bool,
166 | device=device).triu(j - i + 1)
167 | sim = sim.masked_fill(causal_mask, max_neg_value)
168 |
169 | # attention
170 |
171 | attn = sim.softmax(dim=-1, dtype=torch.float32)
172 | attn = self.dropout(attn)
173 |
174 | # aggregate values
175 |
176 | out = einsum('b h i j, b j d -> b h i d', attn, v)
177 |
178 | out = rearrange(out, 'b h n d -> b n (h d)')
179 | return self.to_out(out)
180 |
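# Note (added): the Attention above uses multi-head queries but a single shared
# key/value "head": to_kv projects to dim_head * 2 and the einsum pattern
# 'b h i d, b j d -> b h i j' broadcasts those keys over the query heads. A
# learned null key/value pair is prepended so every query always has something
# to attend to, which also supports dropping the conditioning for
# classifier-free-guidance-style sampling without empty attention rows.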
181 |
182 | # feedforward
183 | class SwiGLU(nn.Module):
184 | """ used successfully in https://arxiv.org/abs/2204.0231 """
185 |
186 | def forward(self, x):
187 | x, gate = x.chunk(2, dim=-1)
188 | return x * F.silu(gate)
189 |
190 |
191 | def FeedForward(dim, mult=4, dropout=0., post_activation_norm=False):
192 | """ post-activation norm https://arxiv.org/abs/2110.09456 """
193 |
194 | inner_dim = int(mult * dim)
195 | return nn.Sequential(
196 | LayerNorm(dim), nn.Linear(dim, inner_dim * 2, bias=False), SwiGLU(),
197 | LayerNorm(inner_dim) if post_activation_norm else nn.Identity(),
198 | nn.Dropout(dropout), nn.Linear(inner_dim, dim, bias=False))
199 |
200 |
201 | class NonCausalTransformerLanguage(nn.Module):
202 |
203 | def __init__(self,
204 | *,
205 | dim,
206 | depth,
207 | dim_head=64,
208 | heads=8,
209 | ff_mult=4,
210 | norm_in=False,
211 | norm_out=True,
212 | attn_dropout=0.,
213 | ff_dropout=0.,
214 | final_proj=True,
215 | normformer=False,
216 | rotary_emb=True):
217 | super().__init__()
218 | self.init_norm = LayerNorm(dim) if norm_in else nn.Identity(
219 | ) # from latest BLOOM model and Yandex's YaLM
220 |
221 | self.rel_pos_bias = RelPosBias(heads=heads)
222 |
223 | self.text_feature_mapping = nn.Sequential(
224 | nn.LayerNorm(384),
225 | nn.Linear(384, 256),
226 | nn.LayerNorm(256),
227 | nn.Linear(256, dim),
228 | nn.LayerNorm(dim),
229 | )
230 |
231 | rotary_emb = RotaryEmbedding(
232 | dim=min(32, dim_head)) if rotary_emb else None
233 |
234 | self.mask_emb = nn.Parameter(torch.zeros(1, 1, dim))
235 |
236 | self.layers = nn.ModuleList([])
237 | for _ in range(depth):
238 | self.layers.append(
239 | nn.ModuleList([
240 | Attention(
241 | dim=dim,
242 | causal=False,
243 | dim_head=dim_head,
244 | heads=heads,
245 | dropout=attn_dropout,
246 | rotary_emb=rotary_emb),
247 | FeedForward(
248 | dim=dim,
249 | mult=ff_mult,
250 | dropout=ff_dropout,
251 | post_activation_norm=normformer)
252 | ]))
253 |
254 | self.norm = LayerNorm(
255 | dim, stable=True
256 | ) if norm_out else nn.Identity(
257 | ) # unclear in paper whether they projected after the classic layer norm for the final denoised image embedding, or just had the transformer output it directly: plan on offering both options
258 | self.project_out = nn.Linear(
259 | dim, dim, bias=False) if final_proj else nn.Identity()
260 |
261 | def forward(self, x, exemplar_frame_embeddings, text_embedding, masks):
262 | x_masked = x.clone()
263 | x_masked[masks, :] = self.mask_emb
264 |
265 | x = torch.cat((self.text_feature_mapping(text_embedding),
266 | exemplar_frame_embeddings, x_masked),
267 | dim=1).clone()
268 |
269 | n, device = x.shape[1], x.device
270 | x = self.init_norm(x)
271 |
272 | attn_bias = self.rel_pos_bias(n, n + 1, device=device)
273 |
274 | for attn, ff in self.layers:
275 | x = attn(x, attn_bias=attn_bias) + x
276 | x = ff(x) + x
277 |
278 | out = self.norm(x)
279 | return self.project_out(out)
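# Shape note (added): in forward(), x holds the embeddings of the frames being
# denoised and masks marks which of those positions should be replaced by the
# learned mask_emb. The 384-d sentence embedding is mapped to the model width
# and concatenated with the exemplar-frame embeddings in front of x_masked, so
# the returned sequence is [text, exemplar frames, denoised frame tokens].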
--------------------------------------------------------------------------------
/models/archs/einops_exts.py:
--------------------------------------------------------------------------------
1 | import re
2 | from functools import partial, wraps
3 |
4 | from einops import rearrange, reduce, repeat
5 | from torch import nn
6 |
7 | # checking shape
8 | # @nils-werner
9 | # https://github.com/arogozhnikov/einops/issues/168#issuecomment-1042933838
10 |
11 |
12 | def check_shape(tensor, pattern, **kwargs):
13 | return rearrange(tensor, f"{pattern} -> {pattern}", **kwargs)
14 |
15 |
16 | # do same einops operations on a list of tensors
17 |
18 |
19 | def _many(fn):
20 |
21 | @wraps(fn)
22 | def inner(tensors, pattern, **kwargs):
23 | return (fn(tensor, pattern, **kwargs) for tensor in tensors)
24 |
25 | return inner
26 |
27 |
28 | # do einops with unflattening of anonymously named dimensions
29 | # (...flattened) -> ...flattened
30 |
31 |
32 | def _with_anon_dims(fn):
33 |
34 | @wraps(fn)
35 | def inner(tensor, pattern, **kwargs):
36 | regex = r'(\.\.\.[a-zA-Z]+)'
37 | matches = re.findall(regex, pattern)
38 | get_anon_dim_name = lambda t: t.lstrip('...')
39 | dim_prefixes = tuple(map(get_anon_dim_name, set(matches)))
40 |
41 | update_kwargs_dict = dict()
42 |
43 | for prefix in dim_prefixes:
44 | assert prefix in kwargs, f'dimension list "{prefix}" was not passed in'
45 | dim_list = kwargs[prefix]
46 | assert isinstance(
47 | dim_list, (list, tuple)
48 |         ), f'dimension list "{prefix}" needs to be a tuple or list of dimensions'
49 | dim_names = list(
50 | map(lambda ind: f'{prefix}{ind}', range(len(dim_list))))
51 | update_kwargs_dict[prefix] = dict(zip(dim_names, dim_list))
52 |
53 | def sub_with_anonymous_dims(t):
54 | dim_name_prefix = get_anon_dim_name(t.groups()[0])
55 | return ' '.join(update_kwargs_dict[dim_name_prefix].keys())
56 |
57 | pattern_new = re.sub(regex, sub_with_anonymous_dims, pattern)
58 |
59 | for prefix, update_dict in update_kwargs_dict.items():
60 | del kwargs[prefix]
61 | kwargs.update(update_dict)
62 |
63 | return fn(tensor, pattern_new, **kwargs)
64 |
65 | return inner
66 |
67 |
68 | # generate all helper functions
69 |
70 | rearrange_many = _many(rearrange)
71 | repeat_many = _many(repeat)
72 | reduce_many = _many(reduce)
73 |
74 | rearrange_with_anon_dims = _with_anon_dims(rearrange)
75 | repeat_with_anon_dims = _with_anon_dims(repeat)
76 | reduce_with_anon_dims = _with_anon_dims(reduce)
77 |
78 | # for rearranging to and from a pattern
79 |
80 |
81 | class EinopsToAndFrom(nn.Module):
82 |
83 | def __init__(self, from_einops, to_einops, fn):
84 | super().__init__()
85 | self.from_einops = from_einops
86 | self.to_einops = to_einops
87 | self.fn = fn
88 |
89 | def forward(self, x, **kwargs):
90 | shape = x.shape
91 | reconstitute_kwargs = dict(
92 | tuple(zip(self.from_einops.split(' '), shape)))
93 | x = rearrange(x, f'{self.from_einops} -> {self.to_einops}')
94 | x = self.fn(x, **kwargs)
95 | x = rearrange(x, f'{self.to_einops} -> {self.from_einops}',
96 | **reconstitute_kwargs)
97 | return x
98 |
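# Added usage sketch (not part of the upstream file): a minimal check of the
# helpers above, assuming only torch and einops are installed. EinopsToAndFrom
# rearranges into `to_einops`, applies `fn`, then restores the original layout;
# repeat_many maps one einops `repeat` call over a tuple of tensors (as used
# for the null key/value pair in dalle_transformer_arch.py).
if __name__ == '__main__':
    import torch

    # identity fn just to exercise the rearrange round-trip
    wrapper = EinopsToAndFrom('b c f h w', 'b (f h w) c', lambda t: t)
    x = torch.randn(2, 64, 4, 8, 8)
    assert wrapper(x).shape == x.shape

    nk, nv = repeat_many((torch.randn(16), torch.randn(16)), 'd -> b 1 d', b=2)
    assert nk.shape == nv.shape == (2, 1, 16)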
--------------------------------------------------------------------------------
/models/archs/rotary_embedding_torch.py:
--------------------------------------------------------------------------------
1 | from inspect import isfunction
2 | from math import log, pi
3 |
4 | import torch
5 | from einops import rearrange, repeat
6 | from torch import einsum, nn
7 |
8 | # helper functions
9 |
10 |
11 | def exists(val):
12 | return val is not None
13 |
14 |
15 | def broadcat(tensors, dim=-1):
16 | num_tensors = len(tensors)
17 | shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
18 | assert len(
19 | shape_lens) == 1, 'tensors must all have the same number of dimensions'
20 | shape_len = list(shape_lens)[0]
21 |
22 | dim = (dim + shape_len) if dim < 0 else dim
23 | dims = list(zip(*map(lambda t: list(t.shape), tensors)))
24 |
25 | expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
26 | assert all([*map(lambda t: len(set(t[1])) <= 2, expandable_dims)
27 |                 ]), 'invalid dimensions for broadcastable concatenation'
28 | max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
29 | expanded_dims = list(
30 | map(lambda t: (t[0], (t[1], ) * num_tensors), max_dims))
31 | expanded_dims.insert(dim, (dim, dims[dim]))
32 | expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
33 | tensors = list(
34 | map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
35 | return torch.cat(tensors, dim=dim)
36 |
37 |
38 | # rotary embedding helper functions
39 |
40 |
41 | def rotate_half(x):
42 | x = rearrange(x, '... (d r) -> ... d r', r=2)
43 | x1, x2 = x.unbind(dim=-1)
44 | x = torch.stack((-x2, x1), dim=-1)
45 | return rearrange(x, '... d r -> ... (d r)')
46 |
47 |
48 | def apply_rotary_emb(freqs, t, start_index=0):
49 | freqs = freqs.to(t)
50 | rot_dim = freqs.shape[-1]
51 | end_index = start_index + rot_dim
52 | assert rot_dim <= t.shape[
53 | -1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}'
54 | t_left, t, t_right = t[..., :start_index], t[
55 | ..., start_index:end_index], t[..., end_index:]
56 | t = (t * freqs.cos()) + (rotate_half(t) * freqs.sin())
57 | return torch.cat((t_left, t, t_right), dim=-1)
58 |
59 |
60 | # learned rotation helpers
61 |
62 |
63 | def apply_learned_rotations(rotations, t, start_index=0, freq_ranges=None):
64 | if exists(freq_ranges):
65 | rotations = einsum('..., f -> ... f', rotations, freq_ranges)
66 | rotations = rearrange(rotations, '... r f -> ... (r f)')
67 |
68 | rotations = repeat(rotations, '... n -> ... (n r)', r=2)
69 | return apply_rotary_emb(rotations, t, start_index=start_index)
70 |
71 |
72 | # classes
73 |
74 |
75 | class RotaryEmbedding(nn.Module):
76 |
77 | def __init__(self,
78 | dim,
79 | custom_freqs=None,
80 | freqs_for='lang',
81 | theta=10000,
82 | max_freq=10,
83 | num_freqs=1,
84 | learned_freq=False):
85 | super().__init__()
86 | if exists(custom_freqs):
87 | freqs = custom_freqs
88 | elif freqs_for == 'lang':
89 | freqs = 1. / (
90 | theta**(torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
91 | elif freqs_for == 'pixel':
92 | freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
93 | elif freqs_for == 'constant':
94 | freqs = torch.ones(num_freqs).float()
95 | else:
96 | raise ValueError(f'unknown modality {freqs_for}')
97 |
98 | self.cache = dict()
99 |
100 | if learned_freq:
101 | self.freqs = nn.Parameter(freqs)
102 | else:
103 | self.register_buffer('freqs', freqs)
104 |
105 | def rotate_queries_or_keys(self, t, seq_dim=-2):
106 | device = t.device
107 | seq_len = t.shape[seq_dim]
108 | freqs = self.forward(
109 | lambda: torch.arange(seq_len, device=device), cache_key=seq_len)
110 | return apply_rotary_emb(freqs, t)
111 |
112 | def forward(self, t, cache_key=None):
113 | if exists(cache_key) and cache_key in self.cache:
114 | return self.cache[cache_key]
115 |
116 | if isfunction(t):
117 | t = t()
118 |
119 | freqs = self.freqs
120 |
121 | freqs = torch.einsum('..., f -> ... f', t.type(freqs.dtype), freqs)
122 | freqs = repeat(freqs, '... n -> ... (n r)', r=2)
123 |
124 | if exists(cache_key):
125 | self.cache[cache_key] = freqs
126 |
127 | return freqs
128 |
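# Added usage sketch (not part of the upstream file): RotaryEmbedding is used
# by the Attention in dalle_transformer_arch.py via rotate_queries_or_keys,
# which rotates the first `dim` feature channels of q / k by a
# position-dependent angle and leaves the remaining channels untouched.
if __name__ == '__main__':
    rot = RotaryEmbedding(dim=32)
    q = torch.randn(1, 8, 10, 64)            # (batch, heads, seq, dim_head)
    q_rot = rot.rotate_queries_or_keys(q)    # same shape, positions encoded in the first 32 channels
    assert q_rot.shape == q.shape
    assert torch.allclose(q_rot[..., 32:], q[..., 32:])  # tail channels unchanged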
--------------------------------------------------------------------------------
/models/archs/transformer_arch.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 |
8 | class CausalSelfAttention(nn.Module):
9 | """
10 | A vanilla multi-head masked self-attention layer with a projection at the end.
11 |     It is possible to use torch.nn.MultiheadAttention here, but an explicit
12 |     implementation is included to show that there is nothing too scary going on.
13 | """
14 |
15 | def __init__(self, bert_n_emb, bert_n_head, attn_pdrop, resid_pdrop,
16 | block_size, sampler):
17 | super().__init__()
18 | assert bert_n_emb % bert_n_head == 0
19 | # key, query, value projections for all heads
20 | self.key = nn.Linear(bert_n_emb, bert_n_emb)
21 | self.query = nn.Linear(bert_n_emb, bert_n_emb)
22 | self.value = nn.Linear(bert_n_emb, bert_n_emb)
23 | # regularization
24 | self.attn_drop = nn.Dropout(attn_pdrop)
25 | self.resid_drop = nn.Dropout(resid_pdrop)
26 | # output projection
27 | self.proj = nn.Linear(bert_n_emb, bert_n_emb)
28 | self.n_head = bert_n_head
29 |         self.causal = sampler == 'autoregressive'
30 | if self.causal:
31 | mask = torch.tril(torch.ones(block_size, block_size))
32 | self.register_buffer("mask", mask.view(1, 1, block_size,
33 | block_size))
34 |
35 | def forward(self, x, layer_past=None):
36 | B, T, C = x.size()
37 |
38 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim
39 | k = self.key(x).view(B, T, self.n_head,
40 | C // self.n_head).transpose(1,
41 | 2) # (B, nh, T, hs)
42 | q = self.query(x).view(B, T, self.n_head,
43 | C // self.n_head).transpose(1,
44 | 2) # (B, nh, T, hs)
45 | v = self.value(x).view(B, T, self.n_head,
46 | C // self.n_head).transpose(1,
47 | 2) # (B, nh, T, hs)
48 |
49 | present = torch.stack((k, v))
50 | if self.causal and layer_past is not None:
51 | past_key, past_value = layer_past
52 | k = torch.cat((past_key, k), dim=-2)
53 | v = torch.cat((past_value, v), dim=-2)
54 |
55 | # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
56 | att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
57 |
58 | if self.causal and layer_past is None:
59 | att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float('-inf'))
60 |
61 | att = F.softmax(att, dim=-1)
62 | att = self.attn_drop(att)
63 | y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
64 | # re-assemble all head outputs side by side
65 | y = y.transpose(1, 2).contiguous().view(B, T, C)
66 |
67 | # output projection
68 | y = self.resid_drop(self.proj(y))
69 | return y, present
70 |
71 |
72 | class Block(nn.Module):
73 | """ an unassuming Transformer block """
74 |
75 | def __init__(self, bert_n_emb, resid_pdrop, bert_n_head, attn_pdrop,
76 | block_size, sampler):
77 | super().__init__()
78 | self.ln1 = nn.LayerNorm(bert_n_emb)
79 | self.ln2 = nn.LayerNorm(bert_n_emb)
80 | self.attn = CausalSelfAttention(bert_n_emb, bert_n_head, attn_pdrop,
81 | resid_pdrop, block_size, sampler)
82 | self.mlp = nn.Sequential(
83 | nn.Linear(bert_n_emb, 4 * bert_n_emb),
84 | nn.GELU(), # nice
85 | nn.Linear(4 * bert_n_emb, bert_n_emb),
86 | nn.Dropout(resid_pdrop),
87 | )
88 |
89 | def forward(self, x, layer_past=None, return_present=False):
90 |
91 | attn, present = self.attn(self.ln1(x), layer_past)
92 | x = x + attn
93 | x = x + self.mlp(self.ln2(x))
94 |
95 | if layer_past is not None or return_present:
96 | return x, present
97 | return x
98 |
99 |
100 | class TransformerLanguage(nn.Module):
101 | """ the full GPT language model, with a context size of block_size """
102 |
103 | def __init__(self,
104 | codebook_size,
105 | bert_n_emb,
106 | bert_n_layers,
107 | bert_n_head,
108 | block_size,
109 | embd_pdrop,
110 | resid_pdrop,
111 | attn_pdrop,
112 | sampler='absorbing'):
113 | super().__init__()
114 |
115 | self.vocab_size = codebook_size + 1
116 | self.n_embd = bert_n_emb
117 | self.block_size = block_size
118 | self.n_layers = bert_n_layers
119 | self.codebook_size = codebook_size
120 | self.causal = sampler == 'autoregressive'
121 | if self.causal:
122 | self.vocab_size = codebook_size
123 |
124 | self.text_feature_mapping = nn.Sequential(
125 | nn.LayerNorm(384),
126 | nn.Linear(384, 256),
127 | nn.LayerNorm(256),
128 | nn.Linear(256, self.n_embd),
129 | nn.LayerNorm(self.n_embd),
130 | )
131 |
132 | self.identity_tok_emb = nn.Embedding(self.vocab_size, self.n_embd)
133 | self.pose_tok_emb = nn.Embedding(self.vocab_size, self.n_embd)
134 | self.pos_emb = nn.Parameter(
135 | torch.zeros(1, self.block_size, self.n_embd))
136 | self.start_tok = nn.Parameter(torch.zeros(1, 1, self.n_embd))
137 | self.drop = nn.Dropout(embd_pdrop)
138 |
139 | # transformer
140 | self.blocks = nn.Sequential(*[
141 | Block(bert_n_emb, resid_pdrop, bert_n_head, attn_pdrop, block_size,
142 | sampler) for _ in range(self.n_layers)
143 | ])
144 | # decoder head
145 | self.ln_f = nn.LayerNorm(self.n_embd)
146 | self.identity_head = nn.Linear(
147 | self.n_embd, self.codebook_size, bias=False)
148 | self.pose_head = nn.Linear(self.n_embd, self.codebook_size, bias=False)
149 |
150 | def get_block_size(self):
151 | return self.block_size
152 |
153 | def _init_weights(self, module):
154 | if isinstance(module, (nn.Linear, nn.Embedding)):
155 | module.weight.data.normal_(mean=0.0, std=0.02)
156 | if isinstance(module, nn.Linear) and module.bias is not None:
157 | module.bias.data.zero_()
158 | elif isinstance(module, nn.LayerNorm):
159 | module.bias.data.zero_()
160 | module.weight.data.fill_(1.0)
161 |
162 | def forward(self, text_embedding, identity_idx, pose_idx, t=None):
163 | # each index maps to a (learnable) vector
164 | identity_token_embeddings = self.identity_tok_emb(identity_idx)
165 | pose_token_embeddings = self.pose_tok_emb(pose_idx)
166 |
167 | token_embeddings = torch.cat(
168 | (identity_token_embeddings, pose_token_embeddings), dim=1)
169 |
170 | if self.causal:
171 | token_embeddings = torch.cat((self.start_tok.repeat(
172 | token_embeddings.size(0), 1, 1), token_embeddings),
173 | dim=1)
174 |
175 | t = token_embeddings.shape[1]
176 | assert t <= self.block_size, "Cannot forward, model block size is exhausted."
177 | # each position maps to a (learnable) vector
178 |
179 | position_embeddings = self.pos_emb[:, :t, :]
180 |
181 | x = token_embeddings + position_embeddings
182 | x = self.drop(x)
183 | x = torch.cat([self.text_feature_mapping(text_embedding), x], dim=1)
184 |
185 | for block in self.blocks:
186 | x = block(x)
187 | x = self.ln_f(x)
188 | identity_logits = self.identity_head(x)
189 | pose_logits = self.pose_head(x)
190 |
191 | return identity_logits, pose_logits
192 |
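# Shape note (added): the logits returned by TransformerLanguage.forward cover
# the whole input sequence (1 mapped text token + identity tokens + pose
# tokens). AppTransformerModel therefore slices identity logits with
# [:, 1:1 + H*W, :] and pose logits with [:, 1 + H*W:] (H, W = latent_shape)
# before computing the loss or sampling.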
--------------------------------------------------------------------------------
/models/archs/vqgan_arch.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | from einops import rearrange
5 |
6 |
7 | class VectorQuantizer(nn.Module):
8 | """
9 |     Improved version of the original VectorQuantizer that can be used as a drop-in
10 |     replacement. It mostly avoids costly matrix multiplications and allows for post-hoc remapping of indices.
11 | """
12 |
13 | # NOTE: due to a bug the beta term was applied to the wrong term. for
14 | # backwards compatibility we use the buggy version by default, but you can
15 | # specify legacy=False to fix it.
16 | def __init__(self,
17 | n_e,
18 | e_dim,
19 | beta,
20 | remap=None,
21 | unknown_index="random",
22 | sane_index_shape=False,
23 | legacy=True):
24 | super().__init__()
25 | self.n_e = n_e
26 | self.e_dim = e_dim
27 | self.beta = beta
28 | self.legacy = legacy
29 |
30 | self.embedding = nn.Embedding(self.n_e, self.e_dim)
31 | self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
32 |
33 | self.remap = remap
34 | if self.remap is not None:
35 | self.register_buffer("used", torch.tensor(np.load(self.remap)))
36 | self.re_embed = self.used.shape[0]
37 | self.unknown_index = unknown_index # "random" or "extra" or integer
38 | if self.unknown_index == "extra":
39 | self.unknown_index = self.re_embed
40 | self.re_embed = self.re_embed + 1
41 | print(f"Remapping {self.n_e} indices to {self.re_embed} indices. "
42 | f"Using {self.unknown_index} for unknown indices.")
43 | else:
44 | self.re_embed = n_e
45 |
46 | self.sane_index_shape = sane_index_shape
47 |
48 | def remap_to_used(self, inds):
49 | ishape = inds.shape
50 | assert len(ishape) > 1
51 | inds = inds.reshape(ishape[0], -1)
52 | used = self.used.to(inds)
53 | match = (inds[:, :, None] == used[None, None, ...]).long()
54 | new = match.argmax(-1)
55 | unknown = match.sum(2) < 1
56 | if self.unknown_index == "random":
57 | new[unknown] = torch.randint(
58 | 0, self.re_embed,
59 | size=new[unknown].shape).to(device=new.device)
60 | else:
61 | new[unknown] = self.unknown_index
62 | return new.reshape(ishape)
63 |
64 | def unmap_to_all(self, inds):
65 | ishape = inds.shape
66 | assert len(ishape) > 1
67 | inds = inds.reshape(ishape[0], -1)
68 | used = self.used.to(inds)
69 | if self.re_embed > self.used.shape[0]: # extra token
70 | inds[inds >= self.used.shape[0]] = 0 # simply set to zero
71 | back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
72 | return back.reshape(ishape)
73 |
74 | def forward(self, z, temp=None, rescale_logits=False, return_logits=False):
75 | assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel"
76 | assert rescale_logits == False, "Only for interface compatible with Gumbel"
77 | assert return_logits == False, "Only for interface compatible with Gumbel"
78 | # reshape z -> (batch, height, width, channel) and flatten
79 | z = rearrange(z, 'b c h w -> b h w c').contiguous()
80 | z_flattened = z.view(-1, self.e_dim)
81 | # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
82 |
83 | d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
84 | torch.sum(self.embedding.weight**2, dim=1) - 2 * \
85 | torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n'))
86 |
87 | min_encoding_indices = torch.argmin(d, dim=1)
88 | z_q = self.embedding(min_encoding_indices).view(z.shape)
89 | perplexity = None
90 | min_encodings = None
91 |
92 | # compute loss for embedding
93 | if not self.legacy:
94 | loss = self.beta * torch.mean((z_q.detach()-z)**2) + \
95 | torch.mean((z_q - z.detach()) ** 2)
96 | else:
97 | loss = torch.mean((z_q.detach()-z)**2) + self.beta * \
98 | torch.mean((z_q - z.detach()) ** 2)
99 |
100 | # preserve gradients
101 | z_q = z + (z_q - z).detach()
102 |
103 | # reshape back to match original input shape
104 | z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous()
105 |
106 | if self.remap is not None:
107 | min_encoding_indices = min_encoding_indices.reshape(
108 | z.shape[0], -1) # add batch axis
109 | min_encoding_indices = self.remap_to_used(min_encoding_indices)
110 | min_encoding_indices = min_encoding_indices.reshape(-1,
111 | 1) # flatten
112 |
113 | if self.sane_index_shape:
114 | min_encoding_indices = min_encoding_indices.reshape(
115 | z_q.shape[0], z_q.shape[2], z_q.shape[3])
116 |
117 | return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
118 |
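    # Note (added): `z_q = z + (z_q - z).detach()` above is the straight-through
    # estimator: the forward value is the quantized z_q, but gradients flow to
    # the encoder as if quantization were the identity, while the codebook is
    # trained through the MSE terms collected in `loss`.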
119 | def get_codebook_entry(self, indices, shape):
120 | # shape specifying (batch, height, width, channel)
121 | if self.remap is not None:
122 | indices = indices.reshape(shape[0], -1) # add batch axis
123 | indices = self.unmap_to_all(indices)
124 | indices = indices.reshape(-1) # flatten again
125 |
126 | # get quantized latent vectors
127 | z_q = self.embedding(indices)
128 |
129 | if shape is not None:
130 | z_q = z_q.view(shape)
131 | # reshape back to match original input shape
132 | z_q = z_q.permute(0, 3, 1, 2).contiguous()
133 |
134 | return z_q
135 |
136 | def get_nearest_codebook_embeddings(self, z, return_loss=False):
137 | # z = rearrange(z, 'b c h w -> b h w c').contiguous()
138 | z_flattened = z.view(-1, self.e_dim)
139 |
140 | d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
141 | torch.sum(self.embedding.weight**2, dim=1) - 2 * \
142 | torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n'))
143 |
144 | min_encoding_indices = torch.argmin(d, dim=1)
145 | z_q = self.embedding(min_encoding_indices).view(z.shape)
146 |
147 | if return_loss:
148 | # compute loss for embedding
149 | if not self.legacy:
150 | loss = self.beta * torch.mean((z_q.detach() - z)**2)
151 | else:
152 | loss = torch.mean((z_q.detach() - z)**2)
153 |
154 | # preserve gradients
155 | z_q = z + (z_q - z).detach()
156 |
157 | # reshape back to match original input shape
158 | # z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous()
159 |
160 | if return_loss:
161 | return z_q, loss
162 | else:
163 | return z_q
164 |
165 |
166 | class ResnetBlock(nn.Module):
167 |
168 | def __init__(self,
169 | *,
170 | in_channels,
171 | out_channels=None,
172 | conv_shortcut=False,
173 | dropout,
174 | temb_channels=512):
175 | super().__init__()
176 | self.in_channels = in_channels
177 | out_channels = in_channels if out_channels is None else out_channels
178 | self.out_channels = out_channels
179 | self.use_conv_shortcut = conv_shortcut
180 |
181 | self.norm1 = Normalize(in_channels)
182 | self.conv1 = torch.nn.Conv2d(
183 | in_channels, out_channels, kernel_size=3, stride=1, padding=1)
184 | if temb_channels > 0:
185 | self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
186 | self.norm2 = Normalize(out_channels)
187 | self.dropout = torch.nn.Dropout(dropout)
188 | self.conv2 = torch.nn.Conv2d(
189 | out_channels, out_channels, kernel_size=3, stride=1, padding=1)
190 | if self.in_channels != self.out_channels:
191 | if self.use_conv_shortcut:
192 | self.conv_shortcut = torch.nn.Conv2d(
193 | in_channels,
194 | out_channels,
195 | kernel_size=3,
196 | stride=1,
197 | padding=1)
198 | else:
199 | self.nin_shortcut = torch.nn.Conv2d(
200 | in_channels,
201 | out_channels,
202 | kernel_size=1,
203 | stride=1,
204 | padding=0)
205 |
206 | def forward(self, x, temb):
207 | h = x
208 | h = self.norm1(h)
209 | h = nonlinearity(h)
210 | h = self.conv1(h)
211 |
212 | if temb is not None:
213 | h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
214 |
215 | h = self.norm2(h)
216 | h = nonlinearity(h)
217 | h = self.dropout(h)
218 | h = self.conv2(h)
219 |
220 | if self.in_channels != self.out_channels:
221 | if self.use_conv_shortcut:
222 | x = self.conv_shortcut(x)
223 | else:
224 | x = self.nin_shortcut(x)
225 |
226 | return x + h
227 |
228 |
229 | class AttnBlock(nn.Module):
230 |
231 | def __init__(self, in_channels):
232 | super().__init__()
233 | self.in_channels = in_channels
234 |
235 | self.norm = Normalize(in_channels)
236 | self.q = torch.nn.Conv2d(
237 | in_channels, in_channels, kernel_size=1, stride=1, padding=0)
238 | self.k = torch.nn.Conv2d(
239 | in_channels, in_channels, kernel_size=1, stride=1, padding=0)
240 | self.v = torch.nn.Conv2d(
241 | in_channels, in_channels, kernel_size=1, stride=1, padding=0)
242 | self.proj_out = torch.nn.Conv2d(
243 | in_channels, in_channels, kernel_size=1, stride=1, padding=0)
244 |
245 | def forward(self, x):
246 | h_ = x
247 | h_ = self.norm(h_)
248 | q = self.q(h_)
249 | k = self.k(h_)
250 | v = self.v(h_)
251 |
252 | # compute attention
253 | b, c, h, w = q.shape
254 | q = q.reshape(b, c, h * w)
255 | q = q.permute(0, 2, 1) # b,hw,c
256 | k = k.reshape(b, c, h * w) # b,c,hw
257 | w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
258 | w_ = w_ * (int(c)**(-0.5))
259 | w_ = torch.nn.functional.softmax(w_, dim=2)
260 |
261 | # attend to values
262 | v = v.reshape(b, c, h * w)
263 | w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
264 | h_ = torch.bmm(
265 | v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
266 | h_ = h_.reshape(b, c, h, w)
267 |
268 | h_ = self.proj_out(h_)
269 |
270 | return x + h_
271 |
272 |
273 | class Upsample(nn.Module):
274 |
275 | def __init__(self, in_channels, with_conv):
276 | super().__init__()
277 | self.with_conv = with_conv
278 | if self.with_conv:
279 | self.conv = torch.nn.Conv2d(
280 | in_channels, in_channels, kernel_size=3, stride=1, padding=1)
281 |
282 | def forward(self, x):
283 | x = torch.nn.functional.interpolate(
284 | x, scale_factor=2.0, mode="nearest")
285 | if self.with_conv:
286 | x = self.conv(x)
287 | return x
288 |
289 |
290 | class Downsample(nn.Module):
291 |
292 | def __init__(self, in_channels, with_conv):
293 | super().__init__()
294 | self.with_conv = with_conv
295 | if self.with_conv:
296 | # no asymmetric padding in torch conv, must do it ourselves
297 | self.conv = torch.nn.Conv2d(
298 | in_channels, in_channels, kernel_size=3, stride=2, padding=0)
299 |
300 | def forward(self, x):
301 | if self.with_conv:
302 | pad = (0, 1, 0, 1)
303 | x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
304 | x = self.conv(x)
305 | else:
306 | x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
307 | return x
308 |
309 |
310 | def nonlinearity(x):
311 | # swish
312 | return x * torch.sigmoid(x)
313 |
314 |
315 | def Normalize(in_channels):
316 | return torch.nn.GroupNorm(
317 | num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
318 |
319 |
320 | class DecoderUpOthersDoubleIdentity(nn.Module):
321 |
322 | def __init__(self,
323 | in_channels,
324 | resolution,
325 | z_channels,
326 | ch,
327 | out_ch,
328 | num_res_blocks,
329 | attn_resolutions,
330 | ch_mult=(1, 2, 4, 8),
331 | other_ch_mult=(8, 8),
332 | dropout=0.0,
333 | resamp_with_conv=True,
334 | give_pre_end=False):
335 | super().__init__()
336 | self.ch = ch
337 | self.temb_ch = 0
338 | self.num_resolutions = len(ch_mult)
339 | self.num_res_blocks = num_res_blocks
340 | self.resolution = resolution
341 | self.in_channels = in_channels
342 | self.give_pre_end = give_pre_end
343 |
344 | self.num_other_resolutions = len(other_ch_mult)
345 |
346 | # compute in_ch_mult, block_in and curr_res at lowest res
347 | in_ch_mult = (1, ) + tuple(ch_mult)
348 | block_in = ch * ch_mult[self.num_resolutions - 1]
349 |
350 | curr_res = resolution // 2**(self.num_resolutions - 1)
351 | self.z_shape = (1, z_channels, curr_res, curr_res // 2)
352 | print("Working with z of shape {} = {} dimensions.".format(
353 | self.z_shape, np.prod(self.z_shape)))
354 |
355 | # z to block_in
356 | self.conv_in_identity = torch.nn.Conv2d(
357 | z_channels, block_in, kernel_size=3, stride=1, padding=1)
358 |
359 | # z to block_in
360 | self.conv_in_others = torch.nn.Conv2d(
361 | z_channels // 2, block_in // 2, kernel_size=3, stride=1, padding=1)
362 |
363 | self.conv_in = torch.nn.Conv2d(
364 | block_in // 2 + block_in,
365 | block_in,
366 | kernel_size=3,
367 | stride=1,
368 | padding=1)
369 |
370 | # others upsampling
371 | self.others_up = nn.ModuleList()
372 | block_in_others = ch // 2 * ch_mult[self.num_resolutions - 1]
373 | for i_level in reversed(range(self.num_other_resolutions)):
374 | block = nn.ModuleList()
375 | block_out_others = ch // 2 * other_ch_mult[i_level]
376 | for i_block in range(self.num_res_blocks + 1):
377 | block.append(
378 | ResnetBlock(
379 | in_channels=block_in_others,
380 | out_channels=block_out_others,
381 | temb_channels=self.temb_ch,
382 | dropout=dropout))
383 | block_in_others = block_out_others
384 | up = nn.Module()
385 | up.block = block
386 | up.upsample = Upsample(block_in_others, resamp_with_conv)
387 | self.others_up.insert(0, up) # prepend to get consistent order
388 |
389 | # middle
390 | self.mid = nn.Module()
391 | self.mid.block_1 = ResnetBlock(
392 | in_channels=block_in,
393 | out_channels=block_in,
394 | temb_channels=self.temb_ch,
395 | dropout=dropout)
396 | self.mid.attn_1 = AttnBlock(block_in)
397 | self.mid.block_2 = ResnetBlock(
398 | in_channels=block_in,
399 | out_channels=block_in,
400 | temb_channels=self.temb_ch,
401 | dropout=dropout)
402 |
403 | # upsampling
404 | self.up = nn.ModuleList()
405 | for i_level in reversed(range(self.num_resolutions)):
406 | block = nn.ModuleList()
407 | attn = nn.ModuleList()
408 | block_out = ch * ch_mult[i_level]
409 | for i_block in range(self.num_res_blocks + 1):
410 | block.append(
411 | ResnetBlock(
412 | in_channels=block_in,
413 | out_channels=block_out,
414 | temb_channels=self.temb_ch,
415 | dropout=dropout))
416 | block_in = block_out
417 | if curr_res in attn_resolutions:
418 | attn.append(AttnBlock(block_in))
419 | up = nn.Module()
420 | up.block = block
421 | up.attn = attn
422 | if i_level != 0:
423 | up.upsample = Upsample(block_in, resamp_with_conv)
424 | curr_res = curr_res * 2
425 | self.up.insert(0, up) # prepend to get consistent order
426 |
427 | # end
428 | self.norm_out = Normalize(block_in)
429 | self.conv_out = torch.nn.Conv2d(
430 | block_in, out_ch, kernel_size=3, stride=1, padding=1)
431 |
432 | def forward(self, z_identity, z_others):
433 | # timestep embedding
434 | temb = None
435 |
436 | # upsampling others
437 | h_others = self.conv_in_others(z_others)
438 | for i_level in reversed(range(self.num_other_resolutions)):
439 | for i_block in range(self.num_res_blocks + 1):
440 | h_others = self.others_up[i_level].block[i_block](h_others,
441 | temb)
442 | h_others = self.others_up[i_level].upsample(h_others)
443 |
444 | # z to block_in
445 | h_identity = self.conv_in_identity(z_identity)
446 |
447 | h = self.conv_in(torch.cat([h_identity, h_others], dim=1))
448 |
449 | # middle
450 | h = self.mid.block_1(h, temb)
451 | h = self.mid.attn_1(h)
452 | h = self.mid.block_2(h, temb)
453 |
454 | # upsampling
455 | for i_level in reversed(range(self.num_resolutions)):
456 | for i_block in range(self.num_res_blocks + 1):
457 | h = self.up[i_level].block[i_block](h, temb)
458 | if len(self.up[i_level].attn) > 0:
459 | h = self.up[i_level].attn[i_block](h)
460 | if i_level != 0:
461 | h = self.up[i_level].upsample(h)
462 |
463 | # end
464 | if self.give_pre_end:
465 | return h
466 |
467 | h = self.norm_out(h)
468 | h = nonlinearity(h)
469 | h = self.conv_out(h)
470 | return h
471 |
472 |
473 | class EncoderDecomposeBaseDownOthersDoubleIdentity(nn.Module):
474 |
475 | def __init__(self,
476 | ch,
477 | num_res_blocks,
478 | attn_resolutions,
479 | in_channels,
480 | resolution,
481 | z_channels,
482 | ch_mult=(1, 2, 4, 8),
483 | other_ch_mult=(8, 8),
484 | dropout=0.0,
485 | resamp_with_conv=True,
486 | double_z=True):
487 | super().__init__()
488 | self.ch = ch
489 | self.temb_ch = 0
490 | self.num_resolutions = len(ch_mult)
491 | self.num_res_blocks = num_res_blocks
492 | self.resolution = resolution
493 | self.in_channels = in_channels
494 |
495 | self.num_other_resolutions = len(other_ch_mult)
496 |
497 | # downsampling
498 | self.conv_in = torch.nn.Conv2d(
499 | in_channels, self.ch, kernel_size=3, stride=1, padding=1)
500 |
501 | curr_res = resolution
502 | in_ch_mult = (1, ) + tuple(ch_mult)
503 | self.down = nn.ModuleList()
504 | for i_level in range(self.num_resolutions):
505 | block = nn.ModuleList()
506 | attn = nn.ModuleList()
507 | block_in = ch * in_ch_mult[i_level]
508 | block_out = ch * ch_mult[i_level]
509 | for i_block in range(self.num_res_blocks):
510 | block.append(
511 | ResnetBlock(
512 | in_channels=block_in,
513 | out_channels=block_out,
514 | temb_channels=self.temb_ch,
515 | dropout=dropout))
516 | block_in = block_out
517 | if curr_res in attn_resolutions:
518 | attn.append(AttnBlock(block_in))
519 | down = nn.Module()
520 | down.block = block
521 | down.attn = attn
522 | if i_level != self.num_resolutions - 1:
523 | down.downsample = Downsample(block_in, resamp_with_conv)
524 | curr_res = curr_res // 2
525 | self.down.append(down)
526 |
527 | # identity branch
528 | # middle
529 | self.mid_identity = nn.Module()
530 | self.mid_identity.block_1 = ResnetBlock(
531 | in_channels=block_in,
532 | out_channels=block_in,
533 | temb_channels=self.temb_ch,
534 | dropout=dropout)
535 | self.mid_identity.attn_1 = AttnBlock(block_in)
536 | self.mid_identity.block_2 = ResnetBlock(
537 | in_channels=block_in,
538 | out_channels=block_in,
539 | temb_channels=self.temb_ch,
540 | dropout=dropout)
541 |
542 | # end
543 | self.norm_out_identity = Normalize(block_in)
544 | self.conv_out_identity = torch.nn.Conv2d(
545 | block_in,
546 | z_channels * 2 if double_z else z_channels,
547 | kernel_size=3,
548 | stride=1,
549 | padding=1)
550 |
551 | self.other_down = nn.ModuleList()
552 | for i_level in range(self.num_other_resolutions):
553 | block = nn.ModuleList()
554 | block_in = ch * other_ch_mult[i_level]
555 | block_out = ch * other_ch_mult[i_level]
556 | for i_block in range(self.num_res_blocks):
557 | block.append(
558 | ResnetBlock(
559 | in_channels=block_in,
560 | out_channels=block_out,
561 | temb_channels=self.temb_ch,
562 | dropout=dropout))
563 | block_in = block_out
564 |
565 | down = nn.Module()
566 | down.block = block
567 | down.downsample = Downsample(block_in, resamp_with_conv)
568 | self.other_down.append(down)
569 |
570 | # other branch
571 | # middle
572 | self.mid_other = nn.Module()
573 | self.mid_other.block_1 = ResnetBlock(
574 | in_channels=block_in,
575 | out_channels=block_in,
576 | temb_channels=self.temb_ch,
577 | dropout=dropout)
578 | self.mid_other.attn_1 = AttnBlock(block_in)
579 | self.mid_other.block_2 = ResnetBlock(
580 | in_channels=block_in,
581 | out_channels=block_in,
582 | temb_channels=self.temb_ch,
583 | dropout=dropout)
584 |
585 | self.norm_out_other = Normalize(block_in)
586 | self.conv_out_other = torch.nn.Conv2d(
587 | block_in,
588 | z_channels if double_z else z_channels // 2,
589 | kernel_size=3,
590 | stride=1,
591 | padding=1)
592 |
593 | def forward(self, x):
594 | #assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution)
595 |
596 | # timestep embedding
597 | temb = None
598 |
599 | # downsampling
600 | hs = [self.conv_in(x)]
601 | for i_level in range(self.num_resolutions):
602 | for i_block in range(self.num_res_blocks):
603 | h = self.down[i_level].block[i_block](hs[-1], temb)
604 | if len(self.down[i_level].attn) > 0:
605 | h = self.down[i_level].attn[i_block](h)
606 | hs.append(h)
607 | if i_level != self.num_resolutions - 1:
608 | hs.append(self.down[i_level].downsample(hs[-1]))
609 |
610 | # identity branch
611 | # middle
612 | h_identity = self.mid_identity.block_1(hs[-1], temb)
613 | h_identity = self.mid_identity.attn_1(h_identity)
614 | h_identity = self.mid_identity.block_2(h_identity, temb)
615 |
616 | # end
617 | h_identity = self.norm_out_identity(h_identity)
618 | h_identity = nonlinearity(h_identity)
619 | h_identity = self.conv_out_identity(h_identity)
620 |
621 | # other branch
622 | for i_level in range(self.num_other_resolutions):
623 | for i_block in range(self.num_res_blocks):
624 | h = self.other_down[i_level].block[i_block](hs[-1], temb)
625 | hs.append(h)
626 | hs.append(self.other_down[i_level].downsample(hs[-1]))
627 |
628 | # middle
629 | h_other = self.mid_other.block_1(hs[-1], temb)
630 | h_other = self.mid_other.attn_1(h_other)
631 | h_other = self.mid_other.block_2(h_other, temb)
632 |
633 | # end
634 | h_other = self.norm_out_other(h_other)
635 | h_other = nonlinearity(h_other)
636 | h_other = self.conv_out_other(h_other)
637 | return h_identity, h_other
638 |
639 |
640 | # patch based discriminator
641 | class Discriminator(nn.Module):
642 |
643 | def __init__(self, nc, ndf, n_layers=3):
644 | super().__init__()
645 |
646 | layers = [
647 | nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1),
648 | nn.LeakyReLU(0.2, True)
649 | ]
650 | ndf_mult = 1
651 | ndf_mult_prev = 1
652 | for n in range(1,
653 | n_layers): # gradually increase the number of filters
654 | ndf_mult_prev = ndf_mult
655 | ndf_mult = min(2**n, 8)
656 | layers += [
657 | nn.Conv2d(
658 | ndf * ndf_mult_prev,
659 | ndf * ndf_mult,
660 | kernel_size=4,
661 | stride=2,
662 | padding=1,
663 | bias=False),
664 | nn.BatchNorm2d(ndf * ndf_mult),
665 | nn.LeakyReLU(0.2, True)
666 | ]
667 |
668 | ndf_mult_prev = ndf_mult
669 | ndf_mult = min(2**n_layers, 8)
670 |
671 | layers += [
672 | nn.Conv2d(
673 | ndf * ndf_mult_prev,
674 | ndf * ndf_mult,
675 | kernel_size=4,
676 | stride=1,
677 | padding=1,
678 | bias=False),
679 | nn.BatchNorm2d(ndf * ndf_mult),
680 | nn.LeakyReLU(0.2, True)
681 | ]
682 |
683 | layers += [
684 | nn.Conv2d(ndf * ndf_mult, 1, kernel_size=4, stride=1, padding=1)
685 | ] # output 1 channel prediction map
686 | self.main = nn.Sequential(*layers)
687 |
688 | def forward(self, x):
689 | return self.main(x)
690 |
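if __name__ == '__main__':
    # Minimal usage sketch (channel counts and the input size are illustrative
    # values, not taken from the repo's configs): the PatchGAN-style
    # Discriminator above returns a grid of per-patch real/fake logits rather
    # than a single scalar, so the hinge loss later averages over this map.
    disc = Discriminator(nc=3, ndf=64, n_layers=3)
    logits = disc(torch.randn(2, 3, 256, 256))
    print(logits.shape)  # torch.Size([2, 1, 30, 30]) -> one logit per patch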
--------------------------------------------------------------------------------
/models/base_model.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import math
3 | from collections import OrderedDict
4 | from copy import deepcopy
5 |
6 | import torch
7 | from torch.nn.parallel import DataParallel, DistributedDataParallel
8 | from utils.dist_util import master_only
9 |
10 | logger = logging.getLogger('base')
11 |
12 |
13 | class BaseModel():
14 | """Base model."""
15 |
16 | def __init__(self, opt):
17 | self.opt = opt
18 | self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
19 | self.is_train = opt['is_train']
20 |
21 | def feed_data(self, data):
22 | pass
23 |
24 | def optimize_parameters(self):
25 | pass
26 |
27 | def get_current_visuals(self):
28 | pass
29 |
30 | def save(self, epoch, current_iter):
31 | """Save networks and training state."""
32 | pass
33 |
34 | def validation(self, dataloader, current_iter, tb_logger, save_img=False):
35 | """Validation function.
36 |
37 | Args:
38 | dataloader (torch.utils.data.DataLoader): Validation dataloader.
39 | current_iter (int): Current iteration.
40 | tb_logger (tensorboard logger): Tensorboard logger.
41 | save_img (bool): Whether to save images. Default: False.
42 | """
43 | if self.opt['dist']:
44 | self.dist_validation(dataloader, current_iter, tb_logger, save_img)
45 | else:
46 | self.nondist_validation(dataloader, current_iter, tb_logger,
47 | save_img)
48 |
49 | def get_current_log(self):
50 | return self.log_dict
51 |
52 | def model_to_device(self, net):
53 |         """Model to device. It also wraps models with DistributedDataParallel
54 | or DataParallel.
55 |
56 | Args:
57 | net (nn.Module)
58 | """
59 | net = net.to(self.device)
60 | if self.opt['dist']:
61 | find_unused_parameters = self.opt.get('find_unused_parameters',
62 | False)
63 | net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net)
64 | net = DistributedDataParallel(
65 | net,
66 | device_ids=[torch.cuda.current_device()],
67 | find_unused_parameters=find_unused_parameters)
68 | elif self.opt['num_gpu'] > 1:
69 | net = DataParallel(net)
70 | return net
71 |
72 | def get_bare_model(self, net):
73 | """Get bare model, especially under wrapping with
74 | DistributedDataParallel or DataParallel.
75 | """
76 | if isinstance(net, (DataParallel, DistributedDataParallel)):
77 | net = net.module
78 | return net
79 |
80 | @master_only
81 | def save_network(self, net, save_path):
82 | """Save networks.
83 |
84 | Args:
85 |             net (nn.Module): Network to be saved. DataParallel and
86 |                 DistributedDataParallel wrappers are removed via
87 |                 ``get_bare_model`` before the state dict is extracted.
88 |             save_path (str): Path of the checkpoint file. Parameters are
89 |                 moved to CPU before being written to disk.
90 |         """
91 |
92 | net = self.get_bare_model(net)
93 | state_dict = net.state_dict()
94 | for key, param in state_dict.items():
95 | if key.startswith('module.'): # remove unnecessary 'module.'
96 | key = key[7:]
97 | state_dict[key] = param.cpu()
98 |
99 | torch.save(state_dict, save_path)
100 |
101 | def load_network(self, net, load_path, strict=True, param_key='params'):
102 | """Load network.
103 |
104 | Args:
105 | load_path (str): The path of networks to be loaded.
106 | net (nn.Module): Network.
107 | strict (bool): Whether strictly loaded.
108 | param_key (str): The parameter key of loaded network.
109 | Default: 'params'.
110 | """
111 | net = self.get_bare_model(net)
112 | logger.info(
113 | f'Loading {net.__class__.__name__} model from {load_path}.')
114 | load_net = torch.load(
115 | load_path, map_location=lambda storage, loc: storage)[param_key]
116 | # remove unnecessary 'module.'
117 | for k, v in deepcopy(load_net).items():
118 | if k.startswith('module.'):
119 | load_net[k[7:]] = v
120 | load_net.pop(k)
121 | # self._print_different_keys_loading(net, load_net, strict)
122 | net.load_state_dict(load_net, strict=strict)
123 |
124 | def update_learning_rate(self, epoch, iters=None):
125 | """Update learning rate.
126 |
127 | Args:
128 |             epoch (int): Current epoch, used by the epoch-based schedules.
129 |             iters (int, optional): Current iteration, only used by the
130 |                 'warm_up' schedule. Default: None.
131 |         """
132 | lr = self.optimizer.param_groups[0]['lr']
133 |
134 | if self.opt['lr_decay'] == 'step':
135 | lr = self.opt['lr'] * (
136 | self.opt['gamma']**(epoch // self.opt['step']))
137 | elif self.opt['lr_decay'] == 'cos':
138 | lr = self.opt['lr'] * (
139 | 1 + math.cos(math.pi * epoch / self.opt['num_epochs'])) / 2
140 | elif self.opt['lr_decay'] == 'linear':
141 | lr = self.opt['lr'] * (1 - epoch / self.opt['num_epochs'])
142 | elif self.opt['lr_decay'] == 'linear2exp':
143 | if epoch < self.opt['turning_point'] + 1:
144 |                 # decay linearly so that the lr has dropped by 95%
145 |                 # at the turning point (1 / 95% = 1.0526)
146 | lr = self.opt['lr'] * (
147 | 1 - epoch / int(self.opt['turning_point'] * 1.0526))
148 | else:
149 | lr *= self.opt['gamma']
150 | elif self.opt['lr_decay'] == 'schedule':
151 | if epoch in self.opt['schedule']:
152 | lr *= self.opt['gamma']
153 | elif self.opt['lr_decay'] == 'warm_up':
154 | if iters <= self.opt['warmup_iters']:
155 | lr = self.opt['lr'] * float(iters) / self.opt['warmup_iters']
156 | else:
157 | lr = self.opt['lr']
158 | else:
159 | raise ValueError('Unknown lr mode {}'.format(self.opt['lr_decay']))
160 | # set learning rate
161 | for param_group in self.optimizer.param_groups:
162 | param_group['lr'] = lr
163 |
164 | return lr
165 |
166 | def reduce_loss_dict(self, loss_dict):
167 |         """Reduce loss dict.
168 |
169 |         In distributed training, it averages the losses among different GPUs.
170 |
171 | Args:
172 | loss_dict (OrderedDict): Loss dict.
173 | """
174 | with torch.no_grad():
175 | if self.opt['dist']:
176 | keys = []
177 | losses = []
178 | for name, value in loss_dict.items():
179 | keys.append(name)
180 | losses.append(value)
181 | losses = torch.stack(losses, 0)
182 | torch.distributed.reduce(losses, dst=0)
183 | if self.opt['rank'] == 0:
184 | losses /= self.opt['world_size']
185 | loss_dict = {key: loss for key, loss in zip(keys, losses)}
186 |
187 | log_dict = OrderedDict()
188 | for name, value in loss_dict.items():
189 | log_dict[name] = value.mean().item()
190 |
191 | return log_dict
192 |
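if __name__ == '__main__':
    # Minimal sketch of the 'cos' branch of update_learning_rate above (the
    # base lr and epoch count are illustrative, not repo defaults): the lr
    # follows half a cosine from the base value down to zero.
    base_lr, num_epochs = 1e-4, 100
    for epoch in (0, 50, 100):
        lr = base_lr * (1 + math.cos(math.pi * epoch / num_epochs)) / 2
        print(f'epoch {epoch:3d}: lr = {lr:.2e}')  # 1.00e-04, 5.00e-05, 0.00e+00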
--------------------------------------------------------------------------------
/models/losses/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yumingj/Text2Performer/433489aae7bdd6fd868a2e272d98bed7c7b642a5/models/losses/__init__.py
--------------------------------------------------------------------------------
/models/losses/vqgan_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 |
5 | @torch.jit.script
6 | def hinge_d_loss(logits_real, logits_fake):
7 | loss_real = torch.mean(F.relu(1. - logits_real))
8 | loss_fake = torch.mean(F.relu(1. + logits_fake))
9 | d_loss = 0.5 * (loss_real + loss_fake)
10 |
11 | return d_loss, loss_real, loss_fake
12 |
13 |
14 | def calculate_adaptive_weight(recon_loss, g_loss, last_layer, disc_weight_max):
15 | recon_grads = torch.autograd.grad(
16 | recon_loss, last_layer, retain_graph=True)[0]
17 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
18 |
19 | d_weight = torch.norm(recon_grads) / (torch.norm(g_grads) + 1e-4)
20 | d_weight = torch.clamp(d_weight, 0.0, disc_weight_max).detach()
21 | return d_weight
22 |
23 |
24 | def adopt_weight(weight, global_step, threshold=0, value=0.):
25 | if global_step < threshold:
26 | weight = value
27 | return weight
28 |
29 |
30 | def DiffAugment(x, policy='', channels_first=True):
31 | if policy:
32 | if not channels_first:
33 | x = x.permute(0, 3, 1, 2)
34 | for p in policy.split(','):
35 | for f in AUGMENT_FNS[p]:
36 | x = f(x)
37 | if not channels_first:
38 | x = x.permute(0, 2, 3, 1)
39 | x = x.contiguous()
40 | return x
41 |
42 |
43 | def rand_brightness(x):
44 | x = x + (
45 | torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5)
46 | return x
47 |
48 |
49 | def rand_saturation(x):
50 | x_mean = x.mean(dim=1, keepdim=True)
51 | x = (x - x_mean) * (torch.rand(
52 | x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2) + x_mean
53 | return x
54 |
55 |
56 | def rand_contrast(x):
57 | x_mean = x.mean(dim=[1, 2, 3], keepdim=True)
58 | x = (x - x_mean) * (torch.rand(
59 | x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5) + x_mean
60 | return x
61 |
62 |
63 | def rand_translation(x, ratio=0.125):
64 | shift_x, shift_y = int(x.size(2) * ratio +
65 | 0.5), int(x.size(3) * ratio + 0.5)
66 | translation_x = torch.randint(
67 | -shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device)
68 | translation_y = torch.randint(
69 | -shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device)
70 | grid_batch, grid_x, grid_y = torch.meshgrid(
71 | torch.arange(x.size(0), dtype=torch.long, device=x.device),
72 | torch.arange(x.size(2), dtype=torch.long, device=x.device),
73 | torch.arange(x.size(3), dtype=torch.long, device=x.device),
74 | )
75 | grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1)
76 | grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1)
77 | x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0])
78 | x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x,
79 | grid_y].permute(0, 3, 1, 2)
80 | return x
81 |
82 |
83 | def rand_cutout(x, ratio=0.5):
84 | cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
85 | offset_x = torch.randint(
86 | 0,
87 | x.size(2) + (1 - cutout_size[0] % 2),
88 | size=[x.size(0), 1, 1],
89 | device=x.device)
90 | offset_y = torch.randint(
91 | 0,
92 | x.size(3) + (1 - cutout_size[1] % 2),
93 | size=[x.size(0), 1, 1],
94 | device=x.device)
95 | grid_batch, grid_x, grid_y = torch.meshgrid(
96 | torch.arange(x.size(0), dtype=torch.long, device=x.device),
97 | torch.arange(cutout_size[0], dtype=torch.long, device=x.device),
98 | torch.arange(cutout_size[1], dtype=torch.long, device=x.device),
99 | )
100 | grid_x = torch.clamp(
101 | grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1)
102 | grid_y = torch.clamp(
103 | grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1)
104 | mask = torch.ones(
105 | x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device)
106 | mask[grid_batch, grid_x, grid_y] = 0
107 | x = x * mask.unsqueeze(1)
108 | return x
109 |
110 |
111 | AUGMENT_FNS = {
112 | 'color': [rand_brightness, rand_saturation, rand_contrast],
113 | 'translation': [rand_translation],
114 | 'cutout': [rand_cutout],
115 | }
116 |
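if __name__ == '__main__':
    # Minimal sketch (tensor shapes are illustrative): DiffAugment applies the
    # same differentiable augmentations used before the discriminator,
    # hinge_d_loss pushes real logits above +1 and fake logits below -1, and
    # adopt_weight zeroes the adversarial weight before the chosen start step.
    fake = DiffAugment(torch.randn(4, 3, 128, 64), policy='color,translation')
    d_loss, loss_real, loss_fake = hinge_d_loss(
        torch.randn(4, 1, 14, 6), torch.randn(4, 1, 14, 6))
    print(fake.shape, d_loss.item())
    print(adopt_weight(1, global_step=100, threshold=500))  # 0.0 before start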
--------------------------------------------------------------------------------
/models/vqgan_decompose_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | from collections import OrderedDict
3 |
4 | import lpips
5 | import torch
6 | from torchvision.utils import save_image
7 |
8 | from models.archs.vqgan_arch import (
9 | DecoderUpOthersDoubleIdentity, Discriminator,
10 | EncoderDecomposeBaseDownOthersDoubleIdentity, VectorQuantizer)
11 | from models.base_model import BaseModel
12 | from models.losses.vqgan_loss import (DiffAugment, adopt_weight,
13 | calculate_adaptive_weight, hinge_d_loss)
14 | from utils.dist_util import master_only
15 |
16 |
17 | class VQGANDecomposeModel(BaseModel):
18 |
19 | def __init__(self, opt):
20 | super().__init__(opt)
21 |
22 | self.encoder = self.model_to_device(
23 | EncoderDecomposeBaseDownOthersDoubleIdentity(
24 | ch=opt['ch'],
25 | num_res_blocks=opt['num_res_blocks'],
26 | attn_resolutions=opt['attn_resolutions'],
27 | ch_mult=opt['ch_mult'],
28 | other_ch_mult=opt['other_ch_mult'],
29 | in_channels=opt['in_channels'],
30 | resolution=opt['resolution'],
31 | z_channels=opt['z_channels'],
32 | double_z=opt['double_z'],
33 | dropout=opt['dropout']))
34 | self.decoder = self.model_to_device(
35 | DecoderUpOthersDoubleIdentity(
36 | in_channels=opt['in_channels'],
37 | resolution=opt['resolution'],
38 | z_channels=opt['z_channels'],
39 | ch=opt['ch'],
40 | out_ch=opt['out_ch'],
41 | num_res_blocks=opt['num_res_blocks'],
42 | attn_resolutions=opt['attn_resolutions'],
43 | ch_mult=opt['ch_mult'],
44 | other_ch_mult=opt['other_ch_mult'],
45 | dropout=opt['dropout'],
46 | resamp_with_conv=True,
47 | give_pre_end=False))
48 | self.quantize_identity = self.model_to_device(
49 | VectorQuantizer(opt['n_embed'], opt['embed_dim'], beta=0.25))
50 | self.quant_conv_identity = self.model_to_device(
51 | torch.nn.Conv2d(opt["z_channels"], opt['embed_dim'], 1))
52 | self.post_quant_conv_identity = self.model_to_device(
53 | torch.nn.Conv2d(opt['embed_dim'], opt["z_channels"], 1))
54 |
55 | self.quantize_others = self.model_to_device(
56 | VectorQuantizer(opt['n_embed'], opt['embed_dim'] // 2, beta=0.25))
57 | self.quant_conv_others = self.model_to_device(
58 | torch.nn.Conv2d(opt["z_channels"] // 2, opt['embed_dim'] // 2, 1))
59 | self.post_quant_conv_others = self.model_to_device(
60 | torch.nn.Conv2d(opt['embed_dim'] // 2, opt["z_channels"] // 2, 1))
61 |
62 | self.disc = self.model_to_device(
63 | Discriminator(
64 | opt['n_channels'], opt['ndf'], n_layers=opt['disc_layers']))
65 | self.perceptual = lpips.LPIPS(net="vgg").to(self.device)
66 | self.perceptual_weight = opt['perceptual_weight']
67 | self.disc_start_step = opt['disc_start_step']
68 | self.disc_weight_max = opt['disc_weight_max']
69 | self.diff_aug = opt['diff_aug']
70 | self.policy = "color,translation"
71 |
72 | self.disc.train()
73 |
74 | if self.opt['pretrained_models'] is not None:
75 | self.load_pretrained_network()
76 |
77 | self.init_training_settings()
78 |
79 | def init_training_settings(self):
80 | self.configure_optimizers()
81 |
82 | def configure_optimizers(self):
83 | self.optimizer = torch.optim.Adam(
84 | list(self.encoder.parameters()) + list(self.decoder.parameters()) +
85 | list(self.quantize_identity.parameters()) +
86 | list(self.quant_conv_identity.parameters()) +
87 | list(self.post_quant_conv_identity.parameters()) +
88 | list(self.quantize_others.parameters()) +
89 | list(self.quant_conv_others.parameters()) +
90 | list(self.post_quant_conv_others.parameters()),
91 | lr=self.opt['lr'])
92 |
93 | self.disc_optimizer = torch.optim.Adam(
94 | self.disc.parameters(), lr=self.opt['lr'])
95 |
96 | @master_only
97 | def save_network(self, save_path):
98 | """Save networks.
99 |
100 | Args:
101 |             save_path (str): Path of the checkpoint file. The encoder,
102 |                 decoder, both quantizers, the quant/post-quant convs and
103 |                 the discriminator are saved into a single state dict.
104 |         """
105 |
106 | save_dict = {}
107 | save_dict['encoder'] = self.get_bare_model(self.encoder).state_dict()
108 | save_dict['decoder'] = self.get_bare_model(self.decoder).state_dict()
109 | save_dict['quantize_identity'] = self.get_bare_model(
110 | self.quantize_identity).state_dict()
111 | save_dict['quant_conv_identity'] = self.get_bare_model(
112 | self.quant_conv_identity).state_dict()
113 | save_dict['post_quant_conv_identity'] = self.get_bare_model(
114 | self.post_quant_conv_identity).state_dict()
115 | save_dict['quantize_others'] = self.get_bare_model(
116 | self.quantize_others).state_dict()
117 | save_dict['quant_conv_others'] = self.get_bare_model(
118 | self.quant_conv_others).state_dict()
119 | save_dict['post_quant_conv_others'] = self.get_bare_model(
120 | self.post_quant_conv_others).state_dict()
121 | save_dict['disc'] = self.get_bare_model(self.disc).state_dict()
122 | torch.save(save_dict, save_path)
123 |
124 | def load_pretrained_network(self):
125 |
126 | self.load_network(
127 | self.encoder, self.opt['pretrained_models'], param_key='encoder')
128 | self.load_network(
129 | self.decoder, self.opt['pretrained_models'], param_key='decoder')
130 | self.load_network(
131 | self.quantize_identity,
132 | self.opt['pretrained_models'],
133 | param_key='quantize_identity')
134 | self.load_network(
135 | self.quant_conv_identity,
136 | self.opt['pretrained_models'],
137 | param_key='quant_conv_identity')
138 | self.load_network(
139 | self.post_quant_conv_identity,
140 | self.opt['pretrained_models'],
141 | param_key='post_quant_conv_identity')
142 | self.load_network(
143 | self.quantize_others,
144 | self.opt['pretrained_models'],
145 | param_key='quantize_others')
146 | self.load_network(
147 | self.quant_conv_others,
148 | self.opt['pretrained_models'],
149 | param_key='quant_conv_others')
150 | self.load_network(
151 | self.post_quant_conv_others,
152 | self.opt['pretrained_models'],
153 | param_key='post_quant_conv_others')
154 |
155 | def optimize_parameters(self, data, current_iter):
156 | self.encoder.train()
157 | self.decoder.train()
158 | self.quantize_identity.train()
159 | self.quant_conv_identity.train()
160 | self.post_quant_conv_identity.train()
161 | self.quantize_others.train()
162 | self.quant_conv_others.train()
163 | self.post_quant_conv_others.train()
164 |
165 | loss, d_loss = self.training_step(data, current_iter)
166 | self.optimizer.zero_grad()
167 | loss.backward()
168 | self.optimizer.step()
169 |
170 | if current_iter > self.disc_start_step:
171 | self.disc_optimizer.zero_grad()
172 | d_loss.backward()
173 | self.disc_optimizer.step()
174 |
175 | def feed_data(self, data):
176 | x_identity = data['identity_image']
177 | x_frame_aug = data['frame_img_aug']
178 | x_frame = data['frame_img']
179 |
180 | return x_identity.float().to(self.device), x_frame_aug.float().to(
181 | self.device), x_frame.float().to(self.device)
182 |
183 | def encode(self, x_identity, x_frame):
184 | h_identity, _ = self.encoder(x_identity)
185 | h_identity = self.quant_conv_identity(h_identity)
186 | quant_identity, emb_loss_identity, _ = self.quantize_identity(
187 | h_identity)
188 |
189 | _, h_frame = self.encoder(x_frame)
190 | h_frame = self.quant_conv_others(h_frame)
191 | quant_frame, emb_loss_frame, _ = self.quantize_others(h_frame)
192 | return [quant_identity,
193 | quant_frame], emb_loss_identity + emb_loss_frame
194 |
195 | def decode(self, quant_list):
196 | quant_identity = self.post_quant_conv_identity(quant_list[0])
197 | quant_frame = self.post_quant_conv_others(quant_list[1])
198 | dec = self.decoder(quant_identity, quant_frame)
199 | return dec
200 |
201 | def forward_step(self, x_identity, x_frame_aug):
202 | quant_list, diff = self.encode(x_identity, x_frame_aug)
203 | dec = self.decode(quant_list)
204 | return dec, diff
205 |
206 | def training_step(self, data, step):
207 | x_identity, x_frame_aug, x_frame = self.feed_data(data)
208 | xrec, codebook_loss = self.forward_step(x_identity, x_frame_aug)
209 |
210 | # get recon/perceptual loss
211 | recon_loss = torch.abs(x_frame.contiguous() - xrec.contiguous())
212 | p_loss = self.perceptual(x_frame.contiguous(), xrec.contiguous())
213 | nll_loss = recon_loss + self.perceptual_weight * p_loss
214 | nll_loss = torch.mean(nll_loss)
215 |
216 | # augment for input to discriminator
217 | if self.diff_aug:
218 | xrec = DiffAugment(xrec, policy=self.policy)
219 |
220 | # update generator
221 | logits_fake = self.disc(xrec)
222 | g_loss = -torch.mean(logits_fake)
223 | last_layer = self.get_bare_model(self.decoder).conv_out.weight
224 | d_weight = calculate_adaptive_weight(nll_loss, g_loss, last_layer,
225 | self.disc_weight_max)
226 | d_weight *= adopt_weight(1, step, self.disc_start_step)
227 | loss = nll_loss + d_weight * g_loss + codebook_loss
228 |
229 | loss_dict = OrderedDict()
230 |
231 | loss_dict["loss"] = loss
232 | loss_dict["l1"] = recon_loss.mean()
233 | loss_dict["perceptual"] = p_loss.mean()
234 | loss_dict["nll_loss"] = nll_loss
235 | loss_dict["g_loss"] = g_loss
236 | loss_dict["d_weight"] = d_weight
237 | loss_dict["codebook_loss"] = codebook_loss
238 |
239 | if step > self.disc_start_step:
240 | if self.diff_aug:
241 | logits_real = self.disc(
242 | DiffAugment(
243 | x_frame.contiguous().detach(), policy=self.policy))
244 | else:
245 | logits_real = self.disc(x_frame.contiguous().detach())
246 | logits_fake = self.disc(xrec.contiguous().detach(
247 |             )) # detach so that the generator isn't also updated
248 | d_loss, _, _ = hinge_d_loss(logits_real, logits_fake)
249 | loss_dict["d_loss"] = d_loss
250 | else:
251 | d_loss = None
252 |
253 | self.log_dict = self.reduce_loss_dict(loss_dict)
254 |
255 | return loss, d_loss
256 |
257 | @master_only
258 | def get_vis(self, x_identity, x_frame_aug, x_frame, xrec, save_dir,
259 | img_name):
260 | os.makedirs(save_dir, exist_ok=True)
261 | img_cat = torch.cat([x_identity, x_frame_aug, x_frame, xrec],
262 | dim=3).detach()
263 | img_cat = ((img_cat + 1) / 2)
264 | img_cat = img_cat.clamp_(0, 1)
265 | save_image(img_cat, f'{save_dir}/{img_name}.png', nrow=1, padding=4)
266 |
267 | @torch.no_grad()
268 | def inference(self, data_loader, save_dir):
269 | self.encoder.eval()
270 | self.decoder.eval()
271 | self.quantize_identity.eval()
272 | self.quant_conv_identity.eval()
273 | self.post_quant_conv_identity.eval()
274 | self.quantize_others.eval()
275 | self.quant_conv_others.eval()
276 | self.post_quant_conv_others.eval()
277 |
278 | loss_total = 0
279 | num = 0
280 |
281 | for _, data in enumerate(data_loader):
282 | img_name = data['video_name'][0]
283 | x_identity, x_frame_aug, x_frame = self.feed_data(data)
284 | xrec, _ = self.forward_step(x_identity, x_frame_aug)
285 |
286 | recon_loss = torch.abs(x_frame.contiguous() - xrec.contiguous())
287 | p_loss = self.perceptual(x_frame.contiguous(), xrec.contiguous())
288 | nll_loss = recon_loss + self.perceptual_weight * p_loss
289 | nll_loss = torch.mean(nll_loss)
290 | loss_total += nll_loss
291 |
292 | num += x_frame.size(0)
293 |
294 | self.get_vis(x_identity, x_frame_aug, x_frame, xrec, save_dir,
295 | img_name)
296 |
297 | return (loss_total / num).item() * (-1)
298 |
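# Data-flow sketch for the decomposed VQGAN above (comments only; tensor shapes
# depend on the config):
#
#   x_identity, x_frame_aug, x_frame = self.feed_data(data)
#   quant_list, codebook_loss = self.encode(x_identity, x_frame_aug)
#   # quant_list[0]: code from the identity image, quantized by quantize_identity
#   # quant_list[1]: code from the (augmented) frame, quantized by quantize_others
#   xrec = self.decode(quant_list)   # the decoder fuses both codes into a frame
#   # training_step then combines L1 + LPIPS reconstruction, the codebook loss
#   # and a hinge GAN loss whose weight comes from calculate_adaptive_weight.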
--------------------------------------------------------------------------------
/train_dist.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 | import math
4 | import os
5 | import os.path as osp
6 | import random
7 | import time
8 |
9 | import torch
10 |
11 | from data import create_dataloader, create_dataset
12 | from data.data_sampler import EnlargedSampler
13 | from data.prefetch_dataloader import CPUPrefetcher, CUDAPrefetcher
14 | from models import create_model
15 | from utils.dist_util import get_dist_info, init_dist
16 | from utils.logger import MessageLogger, get_root_logger, init_tb_logger
17 | from utils.options import dict2str, dict_to_nonedict, parse
18 | from utils.util import make_exp_dirs, set_random_seed
19 |
20 |
21 | def get_dataloader(opt, logger):
22 | # create train, test, val dataloaders
23 | train_loader, val_loader, test_loader = None, None, None
24 | dataset_enlarge_ratio = opt.get('dataset_enlarge_ratio', 1)
25 | train_set = create_dataset(opt['datasets']['train'])
26 | opt['max_iters'] = opt['num_epochs'] * len(train_set) // (
27 | opt['batch_size_per_gpu'] * opt['num_gpu'])
28 | logger.info(f'Number of train set: {len(train_set)}.')
29 | train_sampler = EnlargedSampler(train_set, opt['world_size'], opt['rank'],
30 | dataset_enlarge_ratio)
31 | train_loader = create_dataloader(
32 | train_set,
33 | opt,
34 | phase='train',
35 | num_gpu=opt['num_gpu'],
36 | dist=opt['dist'],
37 | sampler=train_sampler,
38 | seed=opt['manual_seed'])
39 |
40 | val_set = create_dataset(opt['datasets']['val'])
41 | logger.info(f'Number of val set: {len(val_set)}.')
42 | val_loader = create_dataloader(
43 | val_set,
44 | opt,
45 | phase='val',
46 | num_gpu=opt['num_gpu'],
47 | dist=opt['dist'],
48 | sampler=None,
49 | seed=opt['manual_seed'])
50 |
51 | test_set = create_dataset(opt['datasets']['test'])
52 | logger.info(f'Number of test set: {len(test_set)}.')
53 | test_loader = create_dataloader(
54 | test_set,
55 | opt,
56 | phase='test',
57 | num_gpu=opt['num_gpu'],
58 | dist=opt['dist'],
59 | sampler=None,
60 | seed=opt['manual_seed'])
61 |
62 | return train_loader, train_sampler, val_loader, test_loader
63 |
64 |
65 | def main():
66 | # options
67 | parser = argparse.ArgumentParser()
68 | parser.add_argument('-opt', type=str, help='Path to option YAML file.')
69 | parser.add_argument(
70 | '--launcher', choices=['none', 'pytorch', 'slurm'], default='none')
71 | parser.add_argument('--local_rank', type=int, default=0)
72 | args = parser.parse_args()
73 | opt = parse(args.opt, is_train=True)
74 |
75 | # distributed settings
76 | if args.launcher == 'none':
77 | opt['dist'] = False
78 | print('Disable distributed.', flush=True)
79 | else:
80 | opt['dist'] = True
81 | if args.launcher == 'slurm' and 'dist_params' in opt:
82 | init_dist(args.launcher, **opt['dist_params'])
83 | else:
84 | init_dist(args.launcher)
85 |
86 | opt['rank'], opt['world_size'] = get_dist_info()
87 |
88 | # mkdir and loggers
89 | if opt['rank'] == 0:
90 | make_exp_dirs(opt)
91 | log_file = osp.join(opt['path']['log'], f"train_{opt['name']}.log")
92 | logger = get_root_logger(
93 | logger_name='base', log_level=logging.INFO, log_file=log_file)
94 | logger.info(dict2str(opt))
95 |
96 | # initialize tensorboard logger
97 | tb_logger = None
98 | if opt['use_tb_logger'] and 'debug' not in opt['name'] and opt['rank'] == 0:
99 | tb_logger = init_tb_logger(log_dir='./tb_logger/' + opt['name'])
100 |
101 | # random seed
102 | seed = opt['manual_seed']
103 | if seed is None:
104 | seed = random.randint(1, 10000)
105 | logger.info(f'Random seed: {seed}')
106 | set_random_seed(seed + opt['rank'])
107 |
108 | torch.backends.cudnn.benchmark = True
109 | # torch.backends.cudnn.deterministic = True
110 |
111 | # convert to NoneDict, which returns None for missing keys
112 | opt = dict_to_nonedict(opt)
113 |
114 | # set up data loader
115 | train_loader, train_sampler, val_loader, test_loader = get_dataloader(
116 | opt, logger)
117 |
118 | # dataloader prefetcher
119 | prefetch_mode = opt.get('prefetch_mode')
120 | if prefetch_mode is None or prefetch_mode == 'cpu':
121 | prefetcher = CPUPrefetcher(train_loader)
122 | elif prefetch_mode == 'cuda':
123 | prefetcher = CUDAPrefetcher(train_loader, opt)
124 | logger.info(f'Use {prefetch_mode} prefetch dataloader')
125 | if opt['datasets']['train'].get('pin_memory') is not True:
126 | raise ValueError('Please set pin_memory=True for CUDAPrefetcher.')
127 | else:
128 |         raise ValueError(f'Wrong prefetch_mode {prefetch_mode}. '
129 | "Supported ones are: None, 'cuda', 'cpu'.")
130 |
131 | current_iter = 0
132 | best_epoch = None
133 | best_acc = -100
134 |
135 | model = create_model(opt)
136 |
137 | data_time, iter_time = 0, 0
138 | current_iter = 0
139 |
140 | # create message logger (formatted outputs)
141 | msg_logger = MessageLogger(opt, current_iter, tb_logger)
142 |
143 | for epoch in range(opt['num_epochs']):
144 | train_sampler.set_epoch(epoch)
145 | prefetcher.reset()
146 | train_data = prefetcher.next()
147 |
148 | lr = model.update_learning_rate(epoch)
149 |
150 | while train_data is not None:
151 | data_time = time.time() - data_time
152 |
153 | current_iter += 1
154 |
155 | model.feed_data(train_data)
156 | model.optimize_parameters()
157 |
158 | iter_time = time.time() - iter_time
159 | if current_iter % opt['print_freq'] == 0:
160 | log_vars = {'epoch': (epoch + 1), 'iter': current_iter}
161 | log_vars.update({'lrs': [lr]})
162 | log_vars.update({'time': iter_time, 'data_time': data_time})
163 | log_vars.update(model.get_current_log())
164 | msg_logger(log_vars)
165 |
166 | data_time = time.time()
167 | iter_time = time.time()
168 | train_data = prefetcher.next()
169 |
170 | if (epoch + 1) % opt['val_freq'] == 0:
171 | save_dir = f'{opt["path"]["visualization"]}/valset/epoch_{(epoch + 1):03d}' # noqa
172 | val_acc = model.inference(val_loader, f'{save_dir}/inference')
173 | model.sample_multinomial(val_loader, f'{save_dir}/sample')
174 |
175 | save_dir = f'{opt["path"]["visualization"]}/testset/epoch_{(epoch + 1):03d}' # noqa
176 | test_acc = model.inference(test_loader, f'{save_dir}/inference')
177 | model.sample_multinomial(test_loader, f'{save_dir}/sample')
178 |
179 | logger.info(f'Epoch: {(epoch + 1)}, '
180 | f'val_acc: {val_acc: .4f}, '
181 | f'test_acc: {test_acc: .4f}.')
182 |
183 | if test_acc > best_acc:
184 | best_epoch = (epoch + 1)
185 | best_acc = test_acc
186 |
187 | logger.info(f'Best epoch: {best_epoch}, '
188 | f'Best test acc: {best_acc: .4f}.')
189 |
190 | # save model
191 | model.save_network(
192 | model.sampler,
193 | f'{opt["path"]["models"]}/epoch_{(epoch + 1)}.pth')
194 |
195 | # torch.cuda.empty_cache()
196 |
197 |
198 | if __name__ == '__main__':
199 | main()
200 |
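# Launch sketch (config paths are placeholders; the flags are the argparse
# options defined above):
#
#   # single process, no distributed backend
#   python train_dist.py -opt path/to/option.yml
#
#   # multi-GPU with the legacy PyTorch launcher, which supplies --local_rank
#   # and the RANK / WORLD_SIZE environment variables read by init_dist
#   python -m torch.distributed.launch --nproc_per_node=8 train_dist.py \
#       -opt path/to/option.yml --launcher pytorch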
--------------------------------------------------------------------------------
/train_sampler.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 | import os
4 | import os.path as osp
5 | import random
6 | import time
7 |
8 | import torch
9 |
10 | from data.sample_identity_dataset import SampleIdentityDataset
11 | from models import create_model
12 | from utils.logger import MessageLogger, get_root_logger, init_tb_logger
13 | from utils.options import dict2str, dict_to_nonedict, parse
14 | from utils.util import make_exp_dirs, set_random_seed
15 |
16 |
17 | def main():
18 | # options
19 | parser = argparse.ArgumentParser()
20 | parser.add_argument('-opt', type=str, help='Path to option YAML file.')
21 | args = parser.parse_args()
22 | opt = parse(args.opt, is_train=True)
23 |
24 | # mkdir and loggers
25 | make_exp_dirs(opt)
26 | log_file = osp.join(opt['path']['log'], f"train_{opt['name']}.log")
27 | logger = get_root_logger(
28 | logger_name='base', log_level=logging.INFO, log_file=log_file)
29 | logger.info(dict2str(opt))
30 | # initialize tensorboard logger
31 | tb_logger = None
32 | if opt['use_tb_logger'] and 'debug' not in opt['name']:
33 | tb_logger = init_tb_logger(log_dir='./tb_logger/' + opt['name'])
34 |
35 | # random seed
36 | seed = opt['manual_seed']
37 | if seed is None:
38 | seed = random.randint(1, 10000)
39 | logger.info(f'Random seed: {seed}')
40 | set_random_seed(seed)
41 |
42 | # convert to NoneDict, which returns None for missing keys
43 | opt = dict_to_nonedict(opt)
44 |
45 | # set up data loader
46 | train_dataset = SampleIdentityDataset(opt['datasets']['train'])
47 | train_loader = torch.utils.data.DataLoader(
48 | dataset=train_dataset,
49 | batch_size=opt['batch_size'],
50 | shuffle=True,
51 | num_workers=opt['num_workers'],
52 | persistent_workers=True,
53 | drop_last=True)
54 | logger.info(f'Number of train set: {len(train_dataset)}.')
55 | opt['max_iters'] = opt['num_epochs'] * len(
56 | train_dataset) // opt['batch_size']
57 |
58 | val_dataset = SampleIdentityDataset(opt['datasets']['val'])
59 | val_loader = torch.utils.data.DataLoader(
60 | dataset=val_dataset, batch_size=opt['batch_size'], shuffle=False)
61 | logger.info(f'Number of val set: {len(val_dataset)}.')
62 |
63 | test_dataset = SampleIdentityDataset(opt['datasets']['test'])
64 | test_loader = torch.utils.data.DataLoader(
65 | dataset=test_dataset, batch_size=opt['batch_size'], shuffle=False)
66 | logger.info(f'Number of test set: {len(test_dataset)}.')
67 |
68 | current_iter = 0
69 |
70 | model = create_model(opt)
71 |
72 | data_time, iter_time = 0, 0
73 | current_iter = 0
74 |
75 | # create message logger (formatted outputs)
76 | msg_logger = MessageLogger(opt, current_iter, tb_logger)
77 |
78 | for epoch in range(opt['num_epochs']):
79 | lr = model.update_learning_rate(epoch, current_iter)
80 |
81 | for _, batch_data in enumerate(train_loader):
82 | data_time = time.time() - data_time
83 |
84 | current_iter += 1
85 |
86 | model.feed_data(batch_data)
87 | model.optimize_parameters()
88 |
89 | iter_time = time.time() - iter_time
90 | if current_iter % opt['print_freq'] == 0:
91 | log_vars = {'epoch': epoch, 'iter': current_iter}
92 | log_vars.update({'lrs': [lr]})
93 | log_vars.update({'time': iter_time, 'data_time': data_time})
94 | log_vars.update(model.get_current_log())
95 | msg_logger(log_vars)
96 |
97 | data_time = time.time()
98 | iter_time = time.time()
99 |
100 | if epoch % opt['val_freq'] == 0 and epoch != 0:
101 | save_dir = f'{opt["path"]["visualization"]}/valset/epoch_{epoch:03d}' # noqa
102 | os.makedirs(save_dir, exist_ok=opt['debug'])
103 | model.inference(val_loader, save_dir)
104 |
105 | save_dir = f'{opt["path"]["visualization"]}/testset/epoch_{epoch:03d}' # noqa
106 | os.makedirs(save_dir, exist_ok=opt['debug'])
107 | model.inference(test_loader, save_dir)
108 |
109 | # save model
110 | model.save_network(
111 | model._denoise_fn,
112 | f'{opt["path"]["models"]}/sampler_epoch{epoch}.pth')
113 |
114 |
115 | if __name__ == '__main__':
116 | main()
117 |
--------------------------------------------------------------------------------
/train_vqvae_iter_dist.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 | import os.path as osp
4 | import random
5 | import time
6 |
7 | import torch
8 |
9 | from data import create_dataloader, create_dataset
10 | from data.data_sampler import EnlargedSampler
11 | from data.prefetch_dataloader import CPUPrefetcher, CUDAPrefetcher
12 | from models import create_model
13 | from utils.dist_util import get_dist_info, init_dist
14 | from utils.logger import MessageLogger, get_root_logger, init_tb_logger
15 | from utils.options import dict2str, dict_to_nonedict, parse
16 | from utils.util import make_exp_dirs, set_random_seed
17 |
18 |
19 | def get_dataloader(opt, logger):
20 | # create train, test, val dataloaders
21 | train_loader, val_loader, test_loader = None, None, None
22 | dataset_enlarge_ratio = opt.get('dataset_enlarge_ratio', 1)
23 | train_set = create_dataset(opt['datasets']['train'])
24 | opt['max_iters'] = opt['num_epochs'] * len(train_set) // (
25 | opt['batch_size_per_gpu'] * opt['num_gpu'])
26 | logger.info(f'Number of train set: {len(train_set)}.')
27 | train_sampler = EnlargedSampler(train_set, opt['world_size'], opt['rank'],
28 | dataset_enlarge_ratio)
29 | train_loader = create_dataloader(
30 | train_set,
31 | opt,
32 | phase='train',
33 | num_gpu=opt['num_gpu'],
34 | dist=opt['dist'],
35 | sampler=train_sampler,
36 | seed=opt['manual_seed'])
37 |
38 | val_set = create_dataset(opt['datasets']['val'])
39 | logger.info(f'Number of val set: {len(val_set)}.')
40 | val_loader = create_dataloader(
41 | val_set,
42 | opt,
43 | phase='val',
44 | num_gpu=opt['num_gpu'],
45 | dist=opt['dist'],
46 | sampler=None,
47 | seed=opt['manual_seed'])
48 |
49 | test_set = create_dataset(opt['datasets']['test'])
50 | logger.info(f'Number of test set: {len(test_set)}.')
51 | test_loader = create_dataloader(
52 | test_set,
53 | opt,
54 | phase='test',
55 | num_gpu=opt['num_gpu'],
56 | dist=opt['dist'],
57 | sampler=None,
58 | seed=opt['manual_seed'])
59 |
60 | return train_loader, train_sampler, val_loader, test_loader
61 |
62 |
63 | def main():
64 | # options
65 | parser = argparse.ArgumentParser()
66 | parser.add_argument('-opt', type=str, help='Path to option YAML file.')
67 | parser.add_argument(
68 | '--launcher', choices=['none', 'pytorch', 'slurm'], default='none')
69 | parser.add_argument('--local_rank', type=int, default=0)
70 | args = parser.parse_args()
71 | opt = parse(args.opt, is_train=True)
72 |
73 | # distributed settings
74 | if args.launcher == 'none':
75 | opt['dist'] = False
76 | print('Disable distributed.', flush=True)
77 | else:
78 | opt['dist'] = True
79 | if args.launcher == 'slurm' and 'dist_params' in opt:
80 | init_dist(args.launcher, **opt['dist_params'])
81 | else:
82 | init_dist(args.launcher)
83 |
84 | opt['rank'], opt['world_size'] = get_dist_info()
85 |
86 | # mkdir and loggers
87 | if opt['rank'] == 0:
88 | make_exp_dirs(opt)
89 | log_file = osp.join(opt['path']['log'], f"train_{opt['name']}.log")
90 | logger = get_root_logger(
91 | logger_name='base', log_level=logging.INFO, log_file=log_file)
92 | logger.info(dict2str(opt))
93 |
94 | # initialize tensorboard logger
95 | tb_logger = None
96 | if opt['use_tb_logger'] and 'debug' not in opt['name'] and opt['rank'] == 0:
97 | tb_logger = init_tb_logger(log_dir='./tb_logger/' + opt['name'])
98 |
99 | # random seed
100 | seed = opt['manual_seed']
101 | if seed is None:
102 | seed = random.randint(1, 10000)
103 | logger.info(f'Random seed: {seed}')
104 | set_random_seed(seed + opt['rank'])
105 |
106 | torch.backends.cudnn.benchmark = True
107 | # torch.backends.cudnn.deterministic = True
108 |
109 | # convert to NoneDict, which returns None for missing keys
110 | opt = dict_to_nonedict(opt)
111 |
112 | # set up data loader
113 | train_loader, train_sampler, val_loader, test_loader = get_dataloader(
114 | opt, logger)
115 |
116 | # dataloader prefetcher
117 | prefetch_mode = opt.get('prefetch_mode')
118 | if prefetch_mode is None or prefetch_mode == 'cpu':
119 | prefetcher = CPUPrefetcher(train_loader)
120 | elif prefetch_mode == 'cuda':
121 | prefetcher = CUDAPrefetcher(train_loader, opt)
122 | logger.info(f'Use {prefetch_mode} prefetch dataloader')
123 | if opt['datasets']['train'].get('pin_memory') is not True:
124 | raise ValueError('Please set pin_memory=True for CUDAPrefetcher.')
125 | else:
126 |         raise ValueError(f'Wrong prefetch_mode {prefetch_mode}. '
127 | "Supported ones are: None, 'cuda', 'cpu'.")
128 |
129 | current_iter = 0
130 |     best_iter = None
131 | best_acc = -100
132 |
133 | model = create_model(opt)
134 |
135 | data_time, iter_time = 0, 0
136 | current_iter = 0
137 |
138 | # create message logger (formatted outputs)
139 | msg_logger = MessageLogger(opt, current_iter, tb_logger)
140 |
141 | for epoch in range(opt['num_epochs']):
142 | train_sampler.set_epoch(epoch)
143 | prefetcher.reset()
144 | train_data = prefetcher.next()
145 |
146 | lr = model.update_learning_rate(epoch)
147 |
148 | while train_data is not None:
149 | data_time = time.time() - data_time
150 |
151 | current_iter += 1
152 |
153 | model.optimize_parameters(train_data, current_iter)
154 |
155 | iter_time = time.time() - iter_time
156 | if current_iter % opt['print_freq'] == 0:
157 | log_vars = {'epoch': (epoch + 1), 'iter': current_iter}
158 | log_vars.update({'lrs': [lr]})
159 | log_vars.update({'time': iter_time, 'data_time': data_time})
160 | log_vars.update(model.get_current_log())
161 | msg_logger(log_vars)
162 |
163 | data_time = time.time()
164 | iter_time = time.time()
165 | train_data = prefetcher.next()
166 |
167 | if (current_iter + 1) % opt['val_freq'] == 0:
168 | save_dir = f'{opt["path"]["visualization"]}/valset/iter_{(current_iter + 1):03d}' # noqa
169 | val_acc = model.inference(val_loader, f'{save_dir}/inference')
170 |
171 | save_dir = f'{opt["path"]["visualization"]}/testset/iter_{(current_iter + 1):03d}' # noqa
172 | test_acc = model.inference(test_loader,
173 | f'{save_dir}/inference')
174 |
175 | logger.info(f'Iter: {(current_iter + 1)}, '
176 | f'val_acc: {val_acc: .4f}, '
177 | f'test_acc: {test_acc: .4f}.')
178 |
179 | if test_acc > best_acc:
180 | best_iter = (current_iter + 1)
181 | best_acc = test_acc
182 |
183 | logger.info(f'Best iter: {best_iter}, '
184 | f'Best test acc: {best_acc: .4f}.')
185 |
186 | # save model
187 | model.save_network(
188 | f'{opt["path"]["models"]}/iter_{(current_iter + 1)}.pth')
189 |
190 | torch.cuda.empty_cache()
191 |
192 |
193 | if __name__ == '__main__':
194 | main()
195 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yumingj/Text2Performer/433489aae7bdd6fd868a2e272d98bed7c7b642a5/utils/__init__.py
--------------------------------------------------------------------------------
/utils/dist_util.py:
--------------------------------------------------------------------------------
1 | # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py # noqa: E501
2 | import functools
3 | import os
4 | import subprocess
5 |
6 | import torch
7 | import torch.distributed as dist
8 | import torch.multiprocessing as mp
9 |
10 |
11 | def init_dist(launcher, backend='nccl', **kwargs):
12 | if mp.get_start_method(allow_none=True) is None:
13 | mp.set_start_method('spawn')
14 | if launcher == 'pytorch':
15 | _init_dist_pytorch(backend, **kwargs)
16 | elif launcher == 'slurm':
17 | _init_dist_slurm(backend, **kwargs)
18 | else:
19 | raise ValueError(f'Invalid launcher type: {launcher}')
20 |
21 |
22 | def _init_dist_pytorch(backend, **kwargs):
23 | rank = int(os.environ['RANK'])
24 | num_gpus = torch.cuda.device_count()
25 | torch.cuda.set_device(rank % num_gpus)
26 | dist.init_process_group(backend=backend, **kwargs)
27 |
28 |
29 | def _init_dist_slurm(backend, port=None):
30 | """Initialize slurm distributed training environment.
31 |
32 |     If the argument ``port`` is not specified, the master port is read from the
33 |     environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not set in the
34 |     environment, the default port ``29500`` is used.
35 |
36 | Args:
37 | backend (str): Backend of torch.distributed.
38 | port (int, optional): Master port. Defaults to None.
39 | """
40 | proc_id = int(os.environ['SLURM_PROCID'])
41 | ntasks = int(os.environ['SLURM_NTASKS'])
42 | node_list = os.environ['SLURM_NODELIST']
43 | num_gpus = torch.cuda.device_count()
44 | torch.cuda.set_device(proc_id % num_gpus)
45 | addr = subprocess.getoutput(
46 | f'scontrol show hostname {node_list} | head -n1')
47 | # specify master port
48 | if port is not None:
49 | os.environ['MASTER_PORT'] = str(port)
50 | elif 'MASTER_PORT' in os.environ:
51 | pass # use MASTER_PORT in the environment variable
52 | else:
53 | # 29500 is torch.distributed default port
54 | os.environ['MASTER_PORT'] = '29500'
55 | os.environ['MASTER_ADDR'] = addr
56 | os.environ['WORLD_SIZE'] = str(ntasks)
57 | os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
58 | os.environ['RANK'] = str(proc_id)
59 | dist.init_process_group(backend=backend)
60 |
61 |
62 | def get_dist_info():
63 | if dist.is_available():
64 | initialized = dist.is_initialized()
65 | else:
66 | initialized = False
67 | if initialized:
68 | rank = dist.get_rank()
69 | world_size = dist.get_world_size()
70 | else:
71 | rank = 0
72 | world_size = 1
73 | return rank, world_size
74 |
75 |
76 | def master_only(func):
77 |
78 | @functools.wraps(func)
79 | def wrapper(*args, **kwargs):
80 | rank, _ = get_dist_info()
81 | if rank == 0:
82 | return func(*args, **kwargs)
83 |
84 | return wrapper
85 |
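if __name__ == '__main__':
    # Minimal sketch of master_only (the wrapped function name is only for
    # illustration): outside torch.distributed the rank is 0 and the call runs;
    # in a multi-process job it becomes a no-op on every rank except 0.
    @master_only
    def report(msg):
        print(f'[rank 0 only] {msg}')

    print(get_dist_info())  # (0, 1) when distributed is not initialized
    report('saving checkpoint')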
--------------------------------------------------------------------------------
/utils/logger.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import logging
3 | import time
4 |
5 | from .dist_util import get_dist_info, master_only
6 |
7 |
8 | class MessageLogger():
9 | """Message logger for printing.
10 |
11 | Args:
12 | opt (dict): Config. It contains the following keys:
13 | name (str): Exp name.
14 |             print_freq (int): Logging interval (in iterations).
15 |             max_iters (int): Total number of iterations.
16 | use_tb_logger (bool): Use tensorboard logger.
17 | start_iter (int): Start iter. Default: 1.
18 | tb_logger (obj:`tb_logger`): Tensorboard logger. Default: None.
19 | """
20 |
21 | def __init__(self, opt, start_iter=1, tb_logger=None):
22 | self.exp_name = opt['name']
23 | self.interval = opt['print_freq']
24 | self.start_iter = start_iter
25 | self.max_iters = opt['max_iters']
26 | self.use_tb_logger = opt['use_tb_logger']
27 | self.tb_logger = tb_logger
28 | self.start_time = time.time()
29 | self.logger = get_root_logger()
30 |
31 | @master_only
32 | def __call__(self, log_vars):
33 | """Format logging message.
34 |
35 | Args:
36 | log_vars (dict): It contains the following keys:
37 | epoch (int): Epoch number.
38 | iter (int): Current iter.
39 | lrs (list): List for learning rates.
40 |
41 | time (float): Iter time.
42 | data_time (float): Data time for each iter.
43 | """
44 | # epoch, iter, learning rates
45 | epoch = log_vars.pop('epoch')
46 | current_iter = log_vars.pop('iter')
47 | lrs = log_vars.pop('lrs')
48 |
49 | message = (f'[{self.exp_name[:5]}..][epoch:{epoch:3d}, '
50 | f'iter:{current_iter:8,d}, lr:(')
51 | for v in lrs:
52 | message += f'{v:.3e},'
53 | message += ')] '
54 |
55 | # time and estimated time
56 | if 'time' in log_vars.keys():
57 | iter_time = log_vars.pop('time')
58 | data_time = log_vars.pop('data_time')
59 |
60 | total_time = time.time() - self.start_time
61 | time_sec_avg = total_time / (current_iter - self.start_iter + 1)
62 | eta_sec = time_sec_avg * (self.max_iters - current_iter - 1)
63 | eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
64 | message += f'[eta: {eta_str}, '
65 | message += f'time: {iter_time:.3f}, data_time: {data_time:.3f}] '
66 |
67 | # other items, especially losses
68 | for k, v in log_vars.items():
69 | message += f'{k}: {v:.4e} '
70 | # tensorboard logger
71 | if self.use_tb_logger and 'debug' not in self.exp_name:
72 | self.tb_logger.add_scalar(k, v, current_iter)
73 |
74 | self.logger.info(message)
75 |
76 |
77 | @master_only
78 | def init_tb_logger(log_dir):
79 | from torch.utils.tensorboard import SummaryWriter
80 | tb_logger = SummaryWriter(log_dir=log_dir)
81 | return tb_logger
82 |
83 |
84 | def get_root_logger(logger_name='base', log_level=logging.INFO, log_file=None):
85 | """Get the root logger.
86 |
87 | The logger will be initialized if it has not been initialized. By default a
88 | StreamHandler will be added. If `log_file` is specified, a FileHandler will
89 | also be added.
90 |
91 | Args:
92 | logger_name (str): root logger name. Default: base.
93 | log_file (str | None): The log filename. If specified, a FileHandler
94 | will be added to the root logger.
95 | log_level (int): The root logger level. Note that only the process of
96 | rank 0 is affected, while other processes will set the level to
97 | "Error" and be silent most of the time.
98 |
99 | Returns:
100 | logging.Logger: The root logger.
101 | """
102 | logger = logging.getLogger(logger_name)
103 | # if the logger has been initialized, just return it
104 | if logger.hasHandlers():
105 | return logger
106 |
107 | format_str = '%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s'
108 | logging.basicConfig(format=format_str, level=log_level)
109 | rank, _ = get_dist_info()
110 | if rank != 0:
111 | logger.setLevel('ERROR')
112 | elif log_file is not None:
113 | file_handler = logging.FileHandler(log_file, 'w')
114 | file_handler.setFormatter(logging.Formatter(format_str))
115 | file_handler.setLevel(log_level)
116 | logger.addHandler(file_handler)
117 |
118 | return logger
119 |
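# Usage sketch (option values are illustrative; the keys are the ones
# MessageLogger reads above):
#
#   opt = {'name': 'demo', 'print_freq': 100, 'max_iters': 10000,
#          'use_tb_logger': False}
#   logger = get_root_logger(log_level=logging.INFO)
#   msg_logger = MessageLogger(opt, start_iter=1)
#   msg_logger({'epoch': 1, 'iter': 100, 'lrs': [1e-4],
#               'time': 0.35, 'data_time': 0.02, 'l1': 0.021})
#   # -> "[demo..][epoch:  1, iter:     100, lr:(1.000e-04,)] [eta: ...] l1: 2.1000e-02"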
--------------------------------------------------------------------------------
/utils/options.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | from collections import OrderedDict
4 |
5 | import yaml
6 |
7 |
8 | def ordered_yaml():
9 | """Support OrderedDict for yaml.
10 |
11 | Returns:
12 | yaml Loader and Dumper.
13 | """
14 | try:
15 | from yaml import CDumper as Dumper
16 | from yaml import CLoader as Loader
17 | except ImportError:
18 | from yaml import Dumper, Loader
19 |
20 | _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG
21 |
22 | def dict_representer(dumper, data):
23 | return dumper.represent_dict(data.items())
24 |
25 | def dict_constructor(loader, node):
26 | return OrderedDict(loader.construct_pairs(node))
27 |
28 | Dumper.add_representer(OrderedDict, dict_representer)
29 | Loader.add_constructor(_mapping_tag, dict_constructor)
30 | return Loader, Dumper
31 |
32 |
33 | def parse(opt_path, is_train=True):
34 | """Parse option file.
35 |
36 | Args:
37 | opt_path (str): Option file path.
38 |         is_train (bool): Indicate whether in training or not. Default: True.
39 |
40 | Returns:
41 | (dict): Options.
42 | """
43 | with open(opt_path, mode='r') as f:
44 | Loader, _ = ordered_yaml()
45 | opt = yaml.load(f, Loader=Loader)
46 |
47 | # gpu_list = ','.join(str(x) for x in opt['gpu_ids'])
48 | # if opt.get('set_CUDA_VISIBLE_DEVICES', None):
49 | # os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list
50 | # print('export CUDA_VISIBLE_DEVICES=' + gpu_list, flush=True)
51 | # else:
52 | # print('gpu_list: ', gpu_list, flush=True)
53 |
54 | opt['is_train'] = is_train
55 |
56 | # datasets
57 | if 'datasets' in opt.keys():
58 | for phase, dataset in opt['datasets'].items():
59 | # for several datasets, e.g., test_1, test_2
60 | phase = phase.split('_')[0]
61 | dataset['phase'] = phase
62 |
63 | # paths
64 | opt['path'] = {}
65 | opt['path']['root'] = osp.abspath(
66 | osp.join(__file__, osp.pardir, osp.pardir))
67 | if is_train:
68 | experiments_root = osp.join(opt['path']['root'], 'experiments',
69 | opt['name'])
70 | opt['path']['experiments_root'] = experiments_root
71 | opt['path']['models'] = osp.join(experiments_root, 'models')
72 | opt['path']['log'] = experiments_root
73 | opt['path']['visualization'] = osp.join(experiments_root,
74 | 'visualization')
75 |
76 | # change some options for debug mode
77 | if 'debug' in opt['name']:
78 | opt['debug'] = True
79 | opt['val_freq'] = 1
80 | opt['print_freq'] = 1
81 | opt['save_checkpoint_freq'] = 1
82 | else: # test
83 | results_root = osp.join(opt['path']['root'], 'results', opt['name'])
84 | opt['path']['results_root'] = results_root
85 | opt['path']['log'] = results_root
86 | opt['path']['visualization'] = osp.join(results_root, 'visualization')
87 |
88 | return opt
89 |
90 |
91 | def dict2str(opt, indent_level=1):
92 | """dict to string for printing options.
93 |
94 | Args:
95 | opt (dict): Option dict.
96 | indent_level (int): Indent level. Default: 1.
97 |
98 | Return:
99 | (str): Option string for printing.
100 | """
101 | msg = ''
102 | for k, v in opt.items():
103 | if isinstance(v, dict):
104 | msg += ' ' * (indent_level * 2) + k + ':[\n'
105 | msg += dict2str(v, indent_level + 1)
106 | msg += ' ' * (indent_level * 2) + ']\n'
107 | else:
108 | msg += ' ' * (indent_level * 2) + k + ': ' + str(v) + '\n'
109 | return msg
110 |
111 |
112 | class NoneDict(dict):
113 |     """None dict. It returns None if the key is not in the dict."""
114 |
115 | def __missing__(self, key):
116 | return None
117 |
118 |
119 | def dict_to_nonedict(opt):
120 | """Convert to NoneDict, which returns None for missing keys.
121 |
122 | Args:
123 | opt (dict): Option dict.
124 |
125 | Returns:
126 | (dict): NoneDict for options.
127 | """
128 | if isinstance(opt, dict):
129 | new_opt = dict()
130 | for key, sub_opt in opt.items():
131 | new_opt[key] = dict_to_nonedict(sub_opt)
132 | return NoneDict(**new_opt)
133 | elif isinstance(opt, list):
134 | return [dict_to_nonedict(sub_opt) for sub_opt in opt]
135 | else:
136 | return opt
137 |
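if __name__ == '__main__':
    # Minimal sketch of dict_to_nonedict (the dict contents are illustrative):
    # after conversion, missing keys return None instead of raising KeyError,
    # which is what the training scripts rely on for optional options.
    opt = dict_to_nonedict({'name': 'demo',
                            'datasets': {'train': {'phase': 'train'}}})
    print(opt['name'])                     # demo
    print(opt['datasets']['train']['lr'])  # None (nested dicts are NoneDicts)
    print(opt['missing_key'])              # None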
--------------------------------------------------------------------------------
/utils/util.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import random
4 | import sys
5 | import time
6 | from shutil import get_terminal_size
7 |
8 | import numpy as np
9 | import torch
10 |
11 | logger = logging.getLogger('base')
12 |
13 |
14 | def make_exp_dirs(opt):
15 | """Make dirs for experiments."""
16 | path_opt = opt['path'].copy()
17 | if opt['is_train']:
18 | overwrite = True if 'debug' in opt['name'] else False
19 | os.makedirs(path_opt.pop('experiments_root'), exist_ok=overwrite)
20 | os.makedirs(path_opt.pop('models'), exist_ok=overwrite)
21 | else:
22 | os.makedirs(path_opt.pop('results_root'))
23 |
24 |
25 | def set_random_seed(seed):
26 | """Set random seeds."""
27 | random.seed(seed)
28 | np.random.seed(seed)
29 | torch.manual_seed(seed)
30 | torch.cuda.manual_seed(seed)
31 | torch.cuda.manual_seed_all(seed)
32 |
33 |
34 | class ProgressBar(object):
35 | """A progress bar which can print the progress.
36 |
37 | Modified from:
38 | https://github.com/hellock/cvbase/blob/master/cvbase/progress.py
39 | """
40 |
41 | def __init__(self, task_num=0, bar_width=50, start=True):
42 | self.task_num = task_num
43 | max_bar_width = self._get_max_bar_width()
44 | self.bar_width = (
45 | bar_width if bar_width <= max_bar_width else max_bar_width)
46 | self.completed = 0
47 | if start:
48 | self.start()
49 |
50 | def _get_max_bar_width(self):
51 | terminal_width, _ = get_terminal_size()
52 | max_bar_width = min(int(terminal_width * 0.6), terminal_width - 50)
53 | if max_bar_width < 10:
54 | print(f'terminal width is too small ({terminal_width}), '
55 |                   'please consider widening the terminal for better '
56 | 'progressbar visualization')
57 | max_bar_width = 10
58 | return max_bar_width
59 |
60 | def start(self):
61 | if self.task_num > 0:
62 | sys.stdout.write(f"[{' ' * self.bar_width}] 0/{self.task_num}, "
63 | f'elapsed: 0s, ETA:\nStart...\n')
64 | else:
65 | sys.stdout.write('completed: 0, elapsed: 0s')
66 | sys.stdout.flush()
67 | self.start_time = time.time()
68 |
69 | def update(self, msg='In progress...'):
70 | self.completed += 1
71 | elapsed = time.time() - self.start_time
72 | fps = self.completed / elapsed
73 | if self.task_num > 0:
74 | percentage = self.completed / float(self.task_num)
75 | eta = int(elapsed * (1 - percentage) / percentage + 0.5)
76 | mark_width = int(self.bar_width * percentage)
77 | bar_chars = '>' * mark_width + '-' * (self.bar_width - mark_width)
78 | sys.stdout.write('\033[2F') # cursor up 2 lines
79 | sys.stdout.write(
80 | '\033[J'
81 | ) # clean the output (remove extra chars since last display)
82 | sys.stdout.write(
83 | f'[{bar_chars}] {self.completed}/{self.task_num}, '
84 | f'{fps:.1f} task/s, elapsed: {int(elapsed + 0.5)}s, '
85 | f'ETA: {eta:5}s\n{msg}\n')
86 | else:
87 | sys.stdout.write(
88 | f'completed: {self.completed}, elapsed: {int(elapsed + 0.5)}s, '
89 | f'{fps:.1f} tasks/s')
90 | sys.stdout.flush()
91 |
92 |
93 | class AverageMeter(object):
94 | """
95 | Computes and stores the average and current value
96 | Imported from
97 | https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262
98 | """
99 |
100 | def __init__(self):
101 | self.reset()
102 |
103 | def reset(self):
104 | self.val = 0
105 | self.avg = 0 # running average = running sum / running count
106 | self.sum = 0 # running sum
107 | self.count = 0 # running count
108 |
109 | def update(self, val, n=1):
110 |         # n = number of samples contributing to val (e.g. the batch size)
111 |
112 |         # val = per-batch value to be averaged (e.g. a loss or an accuracy)
113 |         # self.val = val
114 |
115 |         # sum = running sum of val, weighted by the number of samples
116 |         self.sum += val * n
117 |
118 |         # count = total samples so far
119 |         self.count += n
120 |
121 |         # avg = running average of val
122 |         # over all the batches so far
123 | self.avg = self.sum / self.count
124 |
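if __name__ == '__main__':
    # Minimal sketch of AverageMeter (values are illustrative): a
    # sample-weighted running average, e.g. for a loss over unequal batches.
    meter = AverageMeter()
    meter.update(0.80, n=32)   # batch of 32 with value 0.80
    meter.update(0.90, n=16)   # batch of 16 with value 0.90
    print(f'{meter.avg:.4f}')  # (0.80 * 32 + 0.90 * 16) / 48 = 0.8333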
--------------------------------------------------------------------------------