├── .gitignore ├── LICENSE ├── README.md ├── assets ├── mask_framework.png └── mask_visual2.png ├── configs ├── stage1 │ └── mqvae_imagenet_f8_r30.yml └── stage2 │ ├── stackformer_imagenet_v12p12_class.yml │ └── stackformer_imagenet_v12p12_uncond.yml ├── data ├── build.py ├── data_utils.py ├── default.py ├── faceshq.py ├── ffhq_lmdb.py ├── imagenet.py └── imagenet_base.py ├── environment.yml ├── models ├── stage1 │ └── utils.py ├── stage1_masked │ └── mqvae.py ├── stage2 │ └── utils.py └── stage2_masked │ ├── stackformer_class.py │ └── stackformer_uncond.py ├── modules ├── diffusionmodules │ ├── attn_model.py │ └── model.py ├── discriminator │ ├── model.py │ ├── stylegan.py │ └── stylegan_lucidrains.py ├── losses │ ├── lpips.py │ ├── vqperceptual.py │ ├── vqperceptual_budget.py │ ├── vqperceptual_epoch.py │ └── vqperceptual_multidisc.py ├── lpips │ └── vgg.pth ├── masked_quantization │ ├── decoder.py │ ├── demasker_vanilla.py │ ├── masker_random.py │ ├── masker_vanilla.py │ ├── masker_vanilla_refine.py │ ├── permuter.py │ ├── tools.py │ └── vqperceptual_budget.py ├── masked_quantization_stage2 │ ├── permuter.py │ └── stackedgpt.py ├── scheduler │ ├── lr_scheduler.py │ └── scheduler.py ├── text_encoders │ ├── clip_text_encoder │ │ ├── base_embedding.py │ │ ├── clip │ │ │ ├── README.md │ │ │ ├── clip.py │ │ │ ├── clip_tokenizer.py │ │ │ ├── model.py │ │ │ └── simple_tokenizer.py │ │ ├── clip_text_embedding.py │ │ └── my_tokenizer │ │ │ ├── base_codec.py │ │ │ └── my_tokenize.py │ ├── modules.py │ └── x_transformers.py ├── tokenizers │ └── SimpleSampleTokenizer.py ├── transformer │ ├── hybrid_decoders.py │ ├── mask_attention.py │ ├── mask_attention_decoders.py │ ├── mingpt.py │ ├── mingpt_t2i.py │ ├── modules.py │ ├── permuter.py │ ├── position_aware_mingpt.py │ ├── position_embeddings.py │ ├── stacked_mingpt.py │ ├── vit.py │ └── vit_modules.py ├── vector_quantization │ ├── common_utils.py │ ├── quantize.py │ ├── quantize2.py │ ├── quantize2_list.py │ ├── quantize2_mask.py │ ├── quantize_codebook_mask.py │ ├── quantize_lucidrains.py │ ├── quantize_rqvae.py │ └── quantize_vqgan.py └── vqvae │ └── quantize2.py ├── train.py └── utils ├── logger.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | 132 | /logs -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Corleone-Huang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MaskedVectorQuantization (CVPR2023) 2 | 3 | Official PyTorch implementation of our CVPR2023 paper "[Not All Image Regions Matter: Masked Vector Quantization for Autoregressive Image Generation](https://openaccess.thecvf.com/content/CVPR2023/papers/Huang_Not_All_Image_Regions_Matter_Masked_Vector_Quantization_for_Autoregressive_CVPR_2023_paper.pdf)". 4 | 5 | **TL;DR** Existing vector-quantization (VQ) based autoregressive image generation simply models `all local region` information of images `without distinguishing their different perceptual importance` in the first stage. This brings redundancy into the learned codebook, which not only limits the second-stage autoregressive model's ability to 6 | model important structure but also results in high training cost and slow generation speed. In this study, we borrow the idea of importance perception from classical image coding theory and propose a novel two-stage framework, consisting of Masked Quantization VAE (MQ-VAE) and Stackformer, to relieve the model from modeling this redundancy. 7 | 8 | Our framework consists of two parts: (1) MQ-VAE incorporates an adaptive mask module that masks redundant region features before quantization and an adaptive de-mask module that recovers the original grid feature map after quantization, so that the original images can be faithfully reconstructed. (2) Stackformer then learns to predict the combination 9 | of the next code and its position in the feature map. 10 | 11 | **See also our other CVPR2023 highlight work on vector-quantization based image generation:** "[Towards Accurate Image Coding: Improved Autoregressive Image Generation with Dynamic Vector Quantization](https://openaccess.thecvf.com/content/CVPR2023/papers/Huang_Towards_Accurate_Image_Coding_Improved_Autoregressive_Image_Generation_With_Dynamic_CVPR_2023_paper.pdf)" ([GitHub](https://github.com/CrossmodalGroup/DynamicVectorQuantization)) 12 | 13 | ![image](assets/mask_framework.png) 14 | 15 | # Requirements and Installation 16 | Please run the following command to install the necessary dependencies. 17 | 18 | ``` 19 | conda env create -f environment.yml 20 | ``` 21 | 22 | # Data Preparation 23 | Prepare the datasets as follows, then change the corresponding data paths in `data/default.py`. 24 | 25 | ## ImageNet 26 | Prepare the ImageNet dataset structure as follows: 27 | 28 | ``` 29 | ${Your Data Root Path}/ImageNet/ 30 | ├── train 31 | │ ├── n01440764 32 | │ | |── n01440764_10026.JPEG 33 | │ | |── n01440764_10027.JPEG 34 | │ | |── ... 35 | │ ├── n01443537 36 | │ | |── n01443537_2.JPEG 37 | │ | |── n01443537_16.JPEG 38 | │ | |── ... 39 | │ ├── ... 40 | ├── val 41 | │ ├── n01440764 42 | │ | |── ILSVRC2012_val_00000293.JPEG 43 | │ | |── ILSVRC2012_val_00002138.JPEG 44 | │ | |── ... 45 | │ ├── n01443537 46 | │ | |── ILSVRC2012_val_00000236.JPEG 47 | │ | |── ILSVRC2012_val_00000262.JPEG 48 | │ | |── ... 49 | │ ├── ... 50 | ├── imagenet_idx_to_synset.yml 51 | ├── synset_human.txt 52 | ``` 53 | 54 | ## FFHQ 55 | The FFHQ dataset can be obtained from the [FFHQ repository](https://github.com/NVlabs/ffhq-dataset). Then prepare the dataset structure as follows: 56 | ``` 57 | ${Your Data Root Path}/FFHQ/ 58 | ├── assets 59 | │ ├── ffhqtrain.txt 60 | │ ├── ffhqvalidation.txt 61 | ├── FFHQ 62 | │ ├── 00000.png 63 | │ ├── 00001.png 64 | ``` 65 | 66 | # Training of MQVAE 67 | 68 | ``` 69 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python train.py --gpus -1 --base configs/stage1/mqvae_imagenet_f8_r30.yml --max_epochs 50 70 | ``` 71 | 72 | The mask ratio can be set via `model.params.masker_config.params.topk_ratio` (i.e., mask ratio = 1 - `model.params.masker_config.params.topk_ratio`). 73 |
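For readers who want a concrete picture of the data flow described above, the following is a minimal, self-contained PyTorch sketch of the masked-quantization idea: score all region features, keep only the top `topk_ratio` fraction, quantize the kept ones, and scatter them back onto the grid for decoding. It is an illustration under simplified assumptions, not the code in this repository; the actual modules live under `modules/masked_quantization/`, `modules/vector_quantization/`, and `modules/masked_quantization_stage2/` and are wired together by the YAML files under `configs/`.

```
import torch
import torch.nn as nn

# Toy illustration only -- NOT the repository's implementation (see
# modules/masked_quantization and modules/vector_quantization for the real code).
# The scoring network, mask filling, and shapes are simplified assumptions.
B, C, H, W = 2, 256, 32, 32          # f8 encoder: a 256px image gives a 32x32 feature grid
topk_ratio = 0.30                    # keep ~30% of the 1024 region features (the "r30" setting)
codebook = nn.Embedding(1024, C)     # stand-in for the learned VQ codebook
score_net = nn.Linear(C, 1)          # stand-in for the importance-score predictor

feat = torch.randn(B, C, H, W)                                # encoder output
tokens = feat.flatten(2).transpose(1, 2)                      # (B, 1024, C)

# 1) Adaptive mask: score every region feature, keep only the top-k most important ones.
scores = score_net(tokens).squeeze(-1)                        # (B, 1024)
K = int(topk_ratio * tokens.size(1))
keep_pos = scores.topk(K, dim=1).indices                      # kept positions, (B, K)
kept = torch.gather(tokens, 1, keep_pos.unsqueeze(-1).expand(-1, -1, C))

# 2) Vector-quantize only the kept features (nearest codebook entry).
dist = torch.cdist(kept.reshape(-1, C), codebook.weight)      # (B*K, 1024)
codes = dist.argmin(dim=-1).view(B, K)                        # discrete codes, (B, K)
quant = codebook(codes)                                       # (B, K, C)

# 3) Adaptive de-mask: scatter quantized features back onto the full grid; masked
#    positions stay at a fill value (a learned mask token in the real model).
grid = torch.zeros(B, H * W, C).scatter(1, keep_pos.unsqueeze(-1).expand(-1, -1, C), quant)
grid = grid.transpose(1, 2).reshape(B, C, H, W)               # ready for the CNN decoder

# Stage 2 (Stackformer) then models the (code, position) pairs autoregressively,
# i.e. the paired sequences codes[b] and keep_pos[b].
print(grid.shape, codes.shape, keep_pos.shape)
```

Since Stackformer only has to model the kept (code, position) pairs, the second-stage sequence is roughly `topk_ratio` times the length of a full 32x32 code map, which is where the reduced training cost and faster generation come from.
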
74 | ## Visualization of the Adaptive Mask Module 75 | ![image](assets/mask_visual2.png) 76 | 77 | # Training of Stackformer 78 | 79 | ## Unconditional Training 80 | 81 | Copy the first-stage MQ-VAE config to `model.params.first_stage_config`, and set the pre-trained MQ-VAE checkpoint path in `model.params.first_stage_config.params.ckpt_path`. 82 | 83 | ``` 84 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python train.py --gpus -1 --base configs/stage2/stackformer_imagenet_v12p12_uncond.yml --max_epochs 50 85 | ``` 86 | 87 | NOTE: Some important hyper-parameters: 88 | - number of layers in the Code-Transformer: `model.params.transformer_config.params.value_layer` 89 | - number of layers in the Position-Transformer: `model.params.transformer_config.params.position_layer` 90 | 91 | ## Class-conditional Training 92 | 93 | ``` 94 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python train.py --gpus -1 --base configs/stage2/stackformer_imagenet_v12p12_class.yml --max_epochs 50 95 | ``` 96 | 97 | # Reference 98 | If you find this code useful, please cite the following papers: 99 | ``` 100 | @InProceedings{Huang_2023_CVPR, 101 | author = {Huang, Mengqi and Mao, Zhendong and Wang, Quan and Zhang, Yongdong}, 102 | title = {Not All Image Regions Matter: Masked Vector Quantization for Autoregressive Image Generation}, 103 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 104 | month = {June}, 105 | year = {2023}, 106 | pages = {2002-2011} 107 | } 108 | ``` 109 | 110 | ``` 111 | @InProceedings{Huang_2023_CVPR_Dynamic, 112 | author = {Huang, Mengqi and Mao, Zhendong and Chen, Zhuowei and Zhang, Yongdong}, 113 | title = {Towards Accurate Image Coding: Improved Autoregressive Image Generation With Dynamic Vector Quantization}, 114 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 115 | month = {June}, 116 | year = {2023}, 117 | pages = {22596-22605} 118 | } 119 | ``` -------------------------------------------------------------------------------- /assets/mask_framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CrossmodalGroup/MaskedVectorQuantization/f38a9ecbe5947d9c0a395c4f17aef356ad250eac/assets/mask_framework.png -------------------------------------------------------------------------------- /assets/mask_visual2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CrossmodalGroup/MaskedVectorQuantization/f38a9ecbe5947d9c0a395c4f17aef356ad250eac/assets/mask_visual2.png -------------------------------------------------------------------------------- /configs/stage1/mqvae_imagenet_f8_r30.yml: -------------------------------------------------------------------------------- 1 | model: 2 | # learning_rate: 0.0001 3 | base_learning_rate: 4.5e-06 4 | min_learning_rate: 0.0 5 | target: models.stage1_masked.mqvae.MaskedVectorQuantizationModel 6 |
params: 7 | encoder_config: 8 | target: modules.diffusionmodules.model.Encoder 9 | params: 10 | double_z: false 11 | z_channels: 256 12 | resolution: 256 13 | in_channels: 3 14 | out_ch: 3 15 | ch: 128 16 | ch_mult: [1, 1, 2, 2] 17 | num_res_blocks: 2 18 | attn_resolutions: [32] 19 | dropout: 0.0 20 | decoder_config: 21 | target: modules.masked_quantization.decoder.Decoder 22 | params: 23 | ch: 128 24 | in_ch: 256 25 | out_ch: 3 26 | ch_mult: [1, 1, 2, 2] 27 | num_res_blocks: 2 28 | resolution: 256 29 | attn_resolutions: [32] 30 | dropout: 0.0 31 | resamp_with_conv: true 32 | give_pre_end: false 33 | masker_config: 34 | target: modules.masked_quantization.masker_vanilla_refine.VanillaMasker 35 | params: 36 | topk_ratio: 0.30 37 | input_token_num: 1024 38 | input_dim: 256 39 | patch_size: 8 40 | score_pred_net_mode: 2layer 41 | codebook_dim: 256 42 | demasker_config: 43 | target: modules.masked_quantization.demasker_vanilla.VanillaDemasker 44 | params: 45 | output_dim: 256 46 | codebook_dim: 256 47 | height_and_width: 32 48 | n_layer: 8 49 | mask_init_value: 0.02 50 | lossconfig: 51 | target: modules.losses.vqperceptual.VQLPIPSWithDiscriminator 52 | params: 53 | disc_conditional: false 54 | disc_in_channels: 3 55 | disc_start: 0 56 | disc_weight: 0.75 57 | codebook_weight: 1.0 58 | disc_weight_max: 1.0 59 | vqconfig: 60 | target: modules.vector_quantization.quantize2.VectorQuantize2 61 | params: 62 | codebook_size: 1024 63 | codebook_dim: 256 64 | channel_last: false 65 | accept_image_fmap: false 66 | commitment_beta: 0.25 67 | decay: 0.99 68 | restart_unused_codes: true 69 | commit_loss_legacy: true 70 | image_key: image 71 | monitor: val_rec_loss 72 | warmup_epochs: 0.1 73 | scheduler_type: linear-warmup_cosine-decay 74 | 75 | data: 76 | target: data.build.DataModuleFromConfig 77 | params: 78 | batch_size: 4 # 30 79 | num_workers: 8 80 | train: 81 | target: data.imagenet.ImageNetTrain 82 | params: 83 | config: 84 | is_eval: False 85 | size: 256 86 | validation: 87 | target: data.imagenet.ImageNetValidation 88 | params: 89 | config: 90 | is_eval: True 91 | size: 256 -------------------------------------------------------------------------------- /configs/stage2/stackformer_imagenet_v12p12_class.yml: -------------------------------------------------------------------------------- 1 | model: 2 | learning_rate: 0.0005 3 | min_learning_rate: 0.0 4 | target: models.stage2_masked.stackformer_class.ClassTransformer 5 | params: 6 | monitor: val_loss 7 | weight_decay: 0.01 8 | warmup_epochs: 0 9 | 10 | position_value_permuter_config: 11 | target: modules.masked_quantization_stage2.permuter.raster_scan_permuter 12 | 13 | transformer_config: 14 | target: modules.masked_quantization_stage2.stackedgpt.ReverseStackedPositionGPT 15 | params: 16 | vocab_size: 2024 # 1024 + 1000 17 | position_size: 1025 # 1 + 32x32 18 | block_size: 1024 # large enough 19 | position_layer: 12 20 | value_layer: 12 21 | n_head: 16 22 | n_embd: 1024 23 | embd_pdrop: 0.1 24 | resid_pdrop: 0.1 25 | attn_pdrop: 0.1 26 | add_absolute_position: true 27 | 28 | first_stage_config: 29 | target: models.stage1_masked.mqvae.MaskedVectorQuantizationModel 30 | params: 31 | ckpt_path: "Your ckpt path" 32 | encoder_config: 33 | target: modules.diffusionmodules.model.Encoder 34 | params: 35 | double_z: false 36 | z_channels: 256 37 | resolution: 256 38 | in_channels: 3 39 | out_ch: 3 40 | ch: 128 41 | ch_mult: [1, 1, 2, 2] 42 | num_res_blocks: 2 43 | attn_resolutions: [32] 44 | dropout: 0.0 45 | decoder_config: 46 | target: 
modules.masked_quantization.decoder.Decoder 47 | params: 48 | ch: 128 49 | in_ch: 256 50 | out_ch: 3 51 | ch_mult: [1, 1, 2, 2] 52 | num_res_blocks: 2 53 | resolution: 256 54 | attn_resolutions: [32] 55 | dropout: 0.0 56 | resamp_with_conv: true 57 | give_pre_end: false 58 | masker_config: 59 | target: modules.masked_quantization.masker_vanilla_refine.VanillaMasker 60 | params: 61 | topk_ratio: 0.25 62 | input_token_num: 1024 63 | input_dim: 256 64 | patch_size: 8 65 | score_pred_net_mode: 2layer 66 | codebook_dim: 256 67 | demasker_config: 68 | target: modules.masked_quantization.demasker_vanilla.VanillaDemasker 69 | params: 70 | output_dim: 256 71 | codebook_dim: 256 72 | height_and_width: 32 73 | n_layer: 8 74 | mask_init_value: 0.02 75 | vqconfig: 76 | target: modules.vector_quantization.quantize2.VectorQuantize2 77 | params: 78 | codebook_size: 1024 79 | codebook_dim: 256 80 | channel_last: false 81 | accept_image_fmap: false 82 | commitment_beta: 0.25 83 | decay: 0.99 84 | restart_unused_codes: true 85 | commit_loss_legacy: true 86 | lossconfig: 87 | target: modules.losses.vqperceptual.DummyLoss 88 | 89 | ignore_keys: [] 90 | first_stage_key: image 91 | cond_stage_key: class_label 92 | pkeep: 1.0 93 | n_classes: 1000 94 | sos_pos_token: 1024 95 | loss_position_weight: 1.0 96 | 97 | height_and_weight: 32 98 | add_absolute_position: True 99 | 100 | 101 | data: 102 | target: data.build.DataModuleFromConfig 103 | params: 104 | batch_size: 4 # 30 105 | num_workers: 8 106 | train: 107 | target: data.imagenet.ImageNetTrain 108 | params: 109 | config: 110 | is_eval: False 111 | size: 256 112 | validation: 113 | target: data.imagenet.ImageNetValidation 114 | params: 115 | config: 116 | is_eval: True 117 | size: 256 -------------------------------------------------------------------------------- /configs/stage2/stackformer_imagenet_v12p12_uncond.yml: -------------------------------------------------------------------------------- 1 | model: 2 | learning_rate: 0.0005 3 | min_learning_rate: 0.0 4 | target: models.stage2_masked.stackformer_uncond.ReverseStackformer 5 | params: 6 | monitor: val_loss 7 | weight_decay: 0.01 8 | warmup_epochs: 0 9 | 10 | position_value_permuter_config: 11 | target: modules.masked_quantization_stage2.permuter.raster_scan_permuter 12 | 13 | transformer_config: 14 | target: modules.masked_quantization_stage2.stackedgpt.ReverseStackedPositionGPT 15 | params: 16 | vocab_size: 1025 # 1024 + 1 17 | position_size: 1025 # 1 + 32x32 18 | block_size: 1024 # large enough 19 | position_layer: 12 20 | value_layer: 12 21 | n_head: 16 22 | n_embd: 1024 23 | embd_pdrop: 0.1 24 | resid_pdrop: 0.1 25 | attn_pdrop: 0.1 26 | add_absolute_position: true 27 | 28 | first_stage_config: 29 | target: models.stage1_masked.mqvae.MaskedVectorQuantizationModel 30 | params: 31 | ckpt_path: "Your ckpt path" 32 | encoder_config: 33 | target: modules.diffusionmodules.model.Encoder 34 | params: 35 | double_z: false 36 | z_channels: 256 37 | resolution: 256 38 | in_channels: 3 39 | out_ch: 3 40 | ch: 128 41 | ch_mult: [1, 1, 2, 2] 42 | num_res_blocks: 2 43 | attn_resolutions: [32] 44 | dropout: 0.0 45 | decoder_config: 46 | target: modules.masked_quantization.decoder.Decoder 47 | params: 48 | ch: 128 49 | in_ch: 256 50 | out_ch: 3 51 | ch_mult: [1, 1, 2, 2] 52 | num_res_blocks: 2 53 | resolution: 256 54 | attn_resolutions: [32] 55 | dropout: 0.0 56 | resamp_with_conv: true 57 | give_pre_end: false 58 | masker_config: 59 | target: modules.masked_quantization.masker_vanilla_refine.VanillaMasker 60 | 
params: 61 | topk_ratio: 0.25 62 | input_token_num: 1024 63 | input_dim: 256 64 | patch_size: 8 65 | score_pred_net_mode: 2layer 66 | codebook_dim: 256 67 | demasker_config: 68 | target: modules.masked_quantization.demasker_vanilla.VanillaDemasker 69 | params: 70 | output_dim: 256 71 | codebook_dim: 256 72 | height_and_width: 32 73 | n_layer: 8 74 | mask_init_value: 0.02 75 | vqconfig: 76 | target: modules.vector_quantization.quantize2.VectorQuantize2 77 | params: 78 | codebook_size: 1024 79 | codebook_dim: 256 80 | channel_last: false 81 | accept_image_fmap: false 82 | commitment_beta: 0.25 83 | decay: 0.99 84 | restart_unused_codes: true 85 | commit_loss_legacy: true 86 | lossconfig: 87 | target: modules.losses.vqperceptual.DummyLoss 88 | 89 | ignore_keys: [] 90 | first_stage_key: image 91 | pkeep: 1.0 92 | sos_token: 1024 93 | sos_pos_token: 1024 94 | loss_position_weight: 1.0 95 | 96 | height_and_weight: 32 97 | add_absolute_position: True 98 | 99 | 100 | data: 101 | target: data.build.DataModuleFromConfig 102 | params: 103 | batch_size: 4 # 30 104 | num_workers: 8 105 | train: 106 | target: data.imagenet.ImageNetTrain 107 | params: 108 | config: 109 | is_eval: False 110 | size: 256 111 | validation: 112 | target: data.imagenet.ImageNetValidation 113 | params: 114 | config: 115 | is_eval: True 116 | size: 256 -------------------------------------------------------------------------------- /data/build.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader, Dataset 2 | import pytorch_lightning as pl 3 | from utils.utils import instantiate_from_config 4 | 5 | class WrappedDataset(Dataset): 6 | """Wraps an arbitrary object with __len__ and __getitem__ into a pytorch dataset""" 7 | def __init__(self, dataset): 8 | self.data = dataset 9 | 10 | def __len__(self): 11 | return len(self.data) 12 | 13 | def __getitem__(self, idx): 14 | return self.data[idx] 15 | 16 | class DataModuleFromConfig(pl.LightningDataModule): 17 | def __init__(self, batch_size, train=None, validation=None, test=None, 18 | wrap=False, num_workers=None, train_val=False): 19 | super().__init__() 20 | self.batch_size = batch_size 21 | self.train_val = train_val 22 | self.dataset_configs = dict() 23 | self.num_workers = num_workers if num_workers is not None else batch_size*2 24 | if train is not None: 25 | self.dataset_configs["train"] = train 26 | self.train_dataloader = self._train_dataloader 27 | if validation is not None: 28 | self.dataset_configs["validation"] = validation 29 | self.val_dataloader = self._val_dataloader 30 | if test is not None: 31 | self.dataset_configs["test"] = test 32 | self.test_dataloader = self._test_dataloader 33 | self.wrap = wrap 34 | 35 | # move setup here to avoid warning 36 | self.datasets = dict( 37 | (k, instantiate_from_config(self.dataset_configs[k])) 38 | for k in self.dataset_configs) 39 | if self.wrap: 40 | for k in self.datasets: 41 | self.datasets[k] = WrappedDataset(self.datasets[k]) 42 | if self.train_val: 43 | if "train" in self.datasets.keys() and "validation" in self.datasets.keys(): 44 | self.datasets["train"] = self.datasets["train"] + self.datasets["validation"] 45 | for k in self.datasets.keys(): 46 | print("dataset: ", k, len(self.datasets[k])) 47 | 48 | def prepare_data(self): 49 | for data_cfg in self.dataset_configs.values(): 50 | print("instantiate from: ", data_cfg) 51 | instantiate_from_config(data_cfg) 52 | 53 | # def setup(self, stage=None): 54 | # self.datasets = dict( 55 | # (k, 
instantiate_from_config(self.dataset_configs[k])) 56 | # for k in self.dataset_configs) 57 | # if self.wrap: 58 | # for k in self.datasets: 59 | # self.datasets[k] = WrappedDataset(self.datasets[k]) 60 | # if self.train_val: 61 | # if "train" in self.datasets.keys() and "validation" in self.datasets.keys(): 62 | # self.datasets["train"] = self.datasets["train"] + self.datasets["validation"] 63 | # for k in self.datasets.keys(): 64 | # print("dataset: ", k, len(self.datasets[k])) 65 | 66 | def _train_dataloader(self): 67 | if hasattr(self.datasets["train"], "collate_fn"): 68 | return DataLoader(self.datasets["train"], batch_size=self.batch_size, 69 | num_workers=self.num_workers, shuffle=True, 70 | collate_fn=self.datasets["train"].collate_fn) 71 | else: 72 | return DataLoader(self.datasets["train"], batch_size=self.batch_size, 73 | num_workers=self.num_workers, shuffle=True) 74 | 75 | 76 | def _val_dataloader(self): 77 | if hasattr(self.datasets['validation'], "collate_fn"): 78 | return DataLoader(self.datasets["validation"], batch_size=self.batch_size, 79 | num_workers=self.num_workers, collate_fn=self.datasets["validation"].collate_fn) 80 | else: 81 | return DataLoader(self.datasets["validation"], batch_size=self.batch_size, 82 | num_workers=self.num_workers) 83 | 84 | def _test_dataloader(self): 85 | if hasattr(self.datasets["test"], "collate_fn"): 86 | return DataLoader(self.datasets["test"], batch_size=self.batch_size, 87 | num_workers=self.num_workers, collate_fn=self.datasets["test"].collate_fn) 88 | else: 89 | return DataLoader(self.datasets["test"], batch_size=self.batch_size, 90 | num_workers=self.num_workers) -------------------------------------------------------------------------------- /data/data_utils.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import importlib 3 | import bisect 4 | import numpy as np 5 | import albumentations 6 | from PIL import Image 7 | from torch.utils.data import Dataset, ConcatDataset 8 | 9 | # data/utils 10 | def mark_prepared(root): 11 | Path(root).joinpath(".ready").touch() 12 | 13 | def is_prepared(root): 14 | return Path(root).joinpath(".ready").exists() 15 | 16 | def instantiate_from_config(config): 17 | if not "target" in config: 18 | raise KeyError("Expected key `target` to instantiate.") 19 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 20 | 21 | def get_obj_from_str(string, reload=False): 22 | module, cls = string.rsplit(".", 1) 23 | if reload: 24 | module_imp = importlib.import_module(module) 25 | importlib.reload(module_imp) 26 | return getattr(importlib.import_module(module, package=None), cls) 27 | 28 | # utils 29 | class KeyNotFoundError(Exception): 30 | def __init__(self, cause, keys=None, visited=None): 31 | self.cause = cause 32 | self.keys = keys 33 | self.visited = visited 34 | messages = list() 35 | if keys is not None: 36 | messages.append("Key not found: {}".format(keys)) 37 | if visited is not None: 38 | messages.append("Visited: {}".format(visited)) 39 | messages.append("Cause:\n{}".format(cause)) 40 | message = "\n".join(messages) 41 | super().__init__(message) 42 | 43 | def retrieve( 44 | list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False 45 | ): 46 | """Given a nested list or dict return the desired value at key expanding 47 | callable nodes if necessary and :attr:`expand` is ``True``. The expansion 48 | is done in-place. 
49 | 50 | Parameters 51 | ---------- 52 | list_or_dict : list or dict 53 | Possibly nested list or dictionary. 54 | key : str 55 | key/to/value, path like string describing all keys necessary to 56 | consider to get to the desired value. List indices can also be 57 | passed here. 58 | splitval : str 59 | String that defines the delimiter between keys of the 60 | different depth levels in `key`. 61 | default : obj 62 | Value returned if :attr:`key` is not found. 63 | expand : bool 64 | Whether to expand callable nodes on the path or not. 65 | 66 | Returns 67 | ------- 68 | The desired value or if :attr:`default` is not ``None`` and the 69 | :attr:`key` is not found returns ``default``. 70 | 71 | Raises 72 | ------ 73 | Exception if ``key`` not in ``list_or_dict`` and :attr:`default` is 74 | ``None``. 75 | """ 76 | 77 | keys = key.split(splitval) 78 | 79 | success = True 80 | try: 81 | visited = [] 82 | parent = None 83 | last_key = None 84 | for key in keys: 85 | if callable(list_or_dict): 86 | if not expand: 87 | raise KeyNotFoundError( 88 | ValueError( 89 | "Trying to get past callable node with expand=False." 90 | ), 91 | keys=keys, 92 | visited=visited, 93 | ) 94 | list_or_dict = list_or_dict() 95 | parent[last_key] = list_or_dict 96 | 97 | last_key = key 98 | parent = list_or_dict 99 | 100 | try: 101 | if isinstance(list_or_dict, dict): 102 | list_or_dict = list_or_dict[key] 103 | else: 104 | list_or_dict = list_or_dict[int(key)] 105 | except (KeyError, IndexError, ValueError) as e: 106 | raise KeyNotFoundError(e, keys=keys, visited=visited) 107 | 108 | visited += [key] 109 | # final expansion of retrieved value 110 | if expand and callable(list_or_dict): 111 | list_or_dict = list_or_dict() 112 | parent[last_key] = list_or_dict 113 | except KeyNotFoundError as e: 114 | if default is None: 115 | raise e 116 | else: 117 | list_or_dict = default 118 | success = False 119 | 120 | if not pass_success: 121 | return list_or_dict 122 | else: 123 | return list_or_dict, success 124 | 125 | class ConcatDatasetWithIndex(ConcatDataset): 126 | """Modified from original pytorch code to return dataset idx""" 127 | def __getitem__(self, idx): 128 | if idx < 0: 129 | if -idx > len(self): 130 | raise ValueError("absolute value of index should not exceed dataset length") 131 | idx = len(self) + idx 132 | dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) 133 | if dataset_idx == 0: 134 | sample_idx = idx 135 | else: 136 | sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] 137 | return self.datasets[dataset_idx][sample_idx], dataset_idx 138 | 139 | 140 | class ImagePaths(Dataset): 141 | def __init__(self, paths, size=None, random_crop=False, labels=None): 142 | self.size = size 143 | self.random_crop = random_crop 144 | 145 | self.labels = dict() if labels is None else labels 146 | self.labels["file_path_"] = paths 147 | self._length = len(paths) 148 | 149 | if self.size is not None and self.size > 0: 150 | self.rescaler = albumentations.SmallestMaxSize(max_size = self.size) 151 | if not self.random_crop: 152 | self.cropper = albumentations.CenterCrop(height=self.size,width=self.size) 153 | else: 154 | self.cropper = albumentations.RandomCrop(height=self.size,width=self.size) 155 | self.preprocessor = albumentations.Compose([self.rescaler, self.cropper]) 156 | else: 157 | self.preprocessor = lambda **kwargs: kwargs 158 | 159 | def __len__(self): 160 | return self._length 161 | 162 | def preprocess_image(self, image_path): 163 | image = Image.open(image_path) 164 | if not image.mode == 
"RGB": 165 | image = image.convert("RGB") 166 | image = np.array(image).astype(np.uint8) 167 | image = self.preprocessor(image=image)["image"] 168 | image = (image/127.5 - 1.0).astype(np.float32) 169 | return image 170 | 171 | def __getitem__(self, i): 172 | example = dict() 173 | example["image"] = self.preprocess_image(self.labels["file_path_"][i]) 174 | for k in self.labels: 175 | example[k] = self.labels[k][i] 176 | return example 177 | 178 | 179 | class NumpyPaths(ImagePaths): 180 | def preprocess_image(self, image_path): 181 | image = np.load(image_path).squeeze(0) # 3 x 1024 x 1024 182 | image = np.transpose(image, (1,2,0)) 183 | image = Image.fromarray(image, mode="RGB") 184 | image = np.array(image).astype(np.uint8) 185 | image = self.preprocessor(image=image)["image"] 186 | image = (image/127.5 - 1.0).astype(np.float32) 187 | return image -------------------------------------------------------------------------------- /data/default.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | DefaultDataPath = edict() 4 | 5 | DefaultDataPath.ImageNet = edict() 6 | 7 | DefaultDataPath.ImageNet.root = "Your Data Path/Datasets/ImageNet" 8 | DefaultDataPath.ImageNet.train_write_root = "Your Data Path/Datasets/ImageNet/train" 9 | DefaultDataPath.ImageNet.val_write_root = "Your Data Path/Datasets/ImageNet/val" 10 | 11 | # DefaultDataPath.ImageNet.root = "/home/huangmq/Datasets/ImageNet" 12 | # DefaultDataPath.ImageNet.train_write_root = "/home/huangmq/Datasets/ImageNet/train" 13 | # DefaultDataPath.ImageNet.val_write_root = "/home/huangmq/Datasets/ImageNet/val" -------------------------------------------------------------------------------- /data/faceshq.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-present, Kakao Brain Corp. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import os, sys 16 | sys.path.append(os.getcwd()) 17 | from data.default import DefaultDataPath 18 | from data.data_utils import ImagePaths, NumpyPaths, ConcatDatasetWithIndex 19 | from data.default import DefaultDataPath 20 | 21 | from pathlib import Path 22 | import torchvision 23 | import torchvision.transforms as transforms 24 | import numpy as np 25 | import albumentations 26 | import glob 27 | from torch.utils.data import Dataset 28 | 29 | class ImageFolder(torchvision.datasets.VisionDataset): 30 | 31 | def __init__(self, root, train_list_file, val_list_file, 32 | split='train', resolution=256, is_eval=False, **kwargs): 33 | 34 | root = Path(root) 35 | super().__init__(root, **kwargs) 36 | 37 | self.train_list_file = train_list_file 38 | self.val_list_file = val_list_file 39 | 40 | self.split = self._verify_split(split) 41 | 42 | self.loader = torchvision.datasets.folder.default_loader 43 | self.extensions = torchvision.datasets.folder.IMG_EXTENSIONS 44 | 45 | if self.split == 'trainval': 46 | fname_list = os.listdir(self.root) 47 | samples = [self.root.joinpath(fname) for fname in fname_list 48 | if fname.lower().endswith(self.extensions)] 49 | else: 50 | listfile = self.train_list_file if self.split == 'train' else self.val_list_file 51 | with open(listfile, 'r') as f: 52 | samples = [self.root.joinpath(line.strip()) for line in f.readlines()] 53 | 54 | self.samples = samples 55 | 56 | if split == "train" and not is_eval: 57 | transforms_ = [ 58 | transforms.RandomResizedCrop(resolution, scale=(0.75, 1.0), ratio=(1.0, 1.0)), 59 | transforms.RandomHorizontalFlip(p=0.5), 60 | transforms.ToTensor(), 61 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), 62 | ] 63 | else: 64 | transforms_ = [ 65 | transforms.Resize(resolution), 66 | transforms.CenterCrop(resolution), 67 | transforms.ToTensor(), 68 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) 69 | ] 70 | self.transforms = transforms.Compose(transforms_) 71 | 72 | def _verify_split(self, split): 73 | if split not in self.valid_splits: 74 | msg = "Unknown split {} .".format(split) 75 | msg += "Valid splits are {{}}.".format(", ".join(self.valid_splits)) 76 | raise ValueError(msg) 77 | return split 78 | 79 | @property 80 | def valid_splits(self): 81 | return 'train', 'val', 'trainval' 82 | 83 | def __len__(self): 84 | return len(self.samples) 85 | 86 | def __getitem__(self, index, with_transform=True): 87 | path = self.samples[index] 88 | sample = self.loader(path) 89 | if self.transforms is not None and with_transform: 90 | sample = self.transforms(sample) 91 | return { 92 | "image": sample 93 | } 94 | 95 | class FFHQ(ImageFolder): 96 | root = DefaultDataPath.FFHQ.root 97 | train_list_file = os.path.join(root, "assets/ffhqtrain.txt") 98 | val_list_file = os.path.join(root, "assets/ffhqvalidation.txt") 99 | 100 | def __init__(self, split='train', resolution=256, is_eval=False, **kwargs): 101 | super().__init__(FFHQ.root, FFHQ.train_list_file, FFHQ.val_list_file, split, resolution, is_eval, **kwargs) 102 | 103 | class FacesBase(Dataset): 104 | def __init__(self, *args, **kwargs): 105 | super().__init__() 106 | self.data = None 107 | self.keys = None 108 | 109 | def __len__(self): 110 | return len(self.data) 111 | 112 | def __getitem__(self, i): 113 | example = self.data[i] 114 | ex = {} 115 | if self.keys is not None: 116 | for k in self.keys: 117 | ex[k] = example[k] 118 | else: 119 | ex = example 120 | return ex 121 | 122 | class CelebAHQTrain(FacesBase): 123 | def __init__(self, size): 124 | super().__init__() 
125 | glob_pattern = os.path.join(DefaultDataPath.CelebAHQ.root, 'train/images', '*.jpg') 126 | paths = sorted(glob.glob(glob_pattern)) 127 | self.data = ImagePaths(paths=paths, size=size, random_crop=False) 128 | self.keys = None 129 | 130 | transforms_ = [torchvision.transforms.ToTensor(),] 131 | self.transforms = torchvision.transforms.Compose(transforms_) 132 | 133 | def __getitem__(self, i): 134 | example = self.data[i] 135 | example["image"] = self.transforms(example["image"]) 136 | return example 137 | 138 | class CelebAHQValidation(FacesBase): 139 | def __init__(self, size): 140 | super().__init__() 141 | glob_pattern = os.path.join(DefaultDataPath.CelebAHQ.root, 'test/images', '*.jpg') 142 | paths = sorted(glob.glob(glob_pattern)) 143 | self.data = ImagePaths(paths=paths, size=size, random_crop=False) 144 | self.keys = None 145 | 146 | transforms_ = [torchvision.transforms.ToTensor(),] 147 | self.transforms = torchvision.transforms.Compose(transforms_) 148 | 149 | def __getitem__(self, i): 150 | example = self.data[i] 151 | example["image"] = self.transforms(example["image"]) 152 | return example 153 | 154 | 155 | class FacesHQTrain(Dataset): 156 | def __init__(self, size, is_eval=False): 157 | d1 = CelebAHQTrain(size=size) 158 | d2 = FFHQ(split='train', resolution=size, is_eval=is_eval) 159 | self.data = ConcatDatasetWithIndex([d1, d2]) 160 | 161 | def __len__(self): 162 | return len(self.data) 163 | 164 | def __getitem__(self, i): 165 | image = self.data[i][0]["image"] 166 | return {"image": image} 167 | 168 | class FacesHQValidation(Dataset): 169 | def __init__(self, size, is_eval=False): 170 | d1 = CelebAHQValidation(size=size) 171 | d2 = FFHQ(split="val", resolution=size, is_eval=is_eval) 172 | self.data = ConcatDatasetWithIndex([d1, d2]) 173 | 174 | def __len__(self): 175 | return len(self.data) 176 | 177 | def __getitem__(self, i): 178 | image = self.data[i][0]["image"] 179 | return {"image": image} 180 | 181 | class FacesHQ(Dataset): 182 | def __init__(self, size, is_eval=False): 183 | d1 = CelebAHQTrain(size=size) 184 | d2 = FFHQ(split='train', resolution=size, is_eval=is_eval) 185 | d3 = CelebAHQValidation(size=size) 186 | d4 = FFHQ(split="val", resolution=size, is_eval=is_eval) 187 | 188 | self.data = ConcatDatasetWithIndex([d1, d2, d3, d4]) 189 | 190 | def __len__(self): 191 | return len(self.data) 192 | 193 | def __getitem__(self, i): 194 | image = self.data[i][0]["image"] 195 | return {"image": image} 196 | 197 | if __name__ == "__main__": 198 | dataset = FFHQ(split='train', resolution=256, is_eval=False) 199 | dataset_val = FFHQ(split='val', resolution=256, is_eval=False) 200 | 201 | # celebahq = CelebAHQTrain(size=256) 202 | # celebahq_val = CelebAHQValidation(size=256) 203 | # out = celebahq.__getitem__(0) 204 | 205 | print(len(dataset)) 206 | print(len(dataset_val)) 207 | # print(len(celebahq)) 208 | # print(len(celebahq_val)) 209 | 210 | # facehq = FacesHQTrain(size=256) 211 | # facehq_val = FacesHQValidation(size=256) 212 | # facehq_all = FacesHQ(size=256) 213 | # out = facehq.__getitem__(0) 214 | # torchvision.utils.save_image(out["image"], "facehq.png", normalize=True) 215 | 216 | # print(len(facehq), len(facehq_val), len(facehq_all)) -------------------------------------------------------------------------------- /data/ffhq_lmdb.py: -------------------------------------------------------------------------------- 1 | import io 2 | import os 3 | import pickle 4 | import string 5 | from pathlib import Path 6 | 7 | import lmdb 8 | import torchvision 9 | import 
torchvision.transforms as transforms 10 | from PIL import Image 11 | 12 | import os, sys 13 | sys.path.append(os.getcwd()) 14 | from data.default import DefaultDataPath 15 | 16 | class FFHQ_LMDB(torchvision.datasets.VisionDataset): 17 | 18 | def __init__(self, split="train", resolution=256, is_eval=False, **kwargs): 19 | 20 | if split == "train": 21 | lmdb_path = DefaultDataPath.FFHQ.train_lmdb 22 | elif split == "val": 23 | lmdb_path = DefaultDataPath.FFHQ.val_lmdb 24 | else: 25 | raise ValueError() 26 | 27 | root = str(Path(lmdb_path)) 28 | super().__init__(root, **kwargs) 29 | 30 | self.env = lmdb.open(root, max_readers=1, readonly=True, lock=False, readahead=False, meminit=False) 31 | with self.env.begin(write=False) as txn: 32 | self.length = txn.stat()["entries"] 33 | 34 | cache_file = "_cache_" + "".join(c for c in root if c in string.ascii_letters) 35 | cache_file = os.path.join(root, cache_file) 36 | if os.path.isfile(cache_file): 37 | self.keys = pickle.load(open(cache_file, "rb")) 38 | else: 39 | with self.env.begin(write=False) as txn: 40 | self.keys = [key for key in txn.cursor().iternext(keys=True, values=False)] 41 | pickle.dump(self.keys, open(cache_file, "wb")) 42 | 43 | if split == "train" and not is_eval: 44 | transforms_ = [ 45 | transforms.RandomResizedCrop(resolution, scale=(0.75, 1.0), ratio=(1.0, 1.0)), 46 | transforms.RandomHorizontalFlip(p=0.5), 47 | transforms.ToTensor(), 48 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), 49 | ] 50 | else: 51 | transforms_ = [ 52 | transforms.Resize(resolution), 53 | transforms.CenterCrop(resolution), 54 | transforms.ToTensor(), 55 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) 56 | ] 57 | self.transforms = transforms.Compose(transforms_) 58 | 59 | def __getitem__(self, index: int): 60 | env = self.env 61 | with env.begin(write=False) as txn: 62 | imgbuf = txn.get(self.keys[index]) 63 | 64 | buf = io.BytesIO() 65 | buf.write(imgbuf) 66 | buf.seek(0) 67 | img = Image.open(buf).convert("RGB") 68 | 69 | if self.transforms is not None: 70 | img = self.transforms(img) 71 | 72 | return { 73 | "image": img 74 | } 75 | 76 | def __len__(self): 77 | return self.length 78 | 79 | 80 | if __name__ == "__main__": 81 | dataset = FFHQ_LMDB(split='train', resolution=256, is_eval=False) 82 | dataset_val = FFHQ_LMDB(split='val', resolution=256, is_eval=False) 83 | 84 | print(len(dataset)) 85 | print(len(dataset_val)) 86 | 87 | # sample = dataset.__getitem__(0) 88 | 89 | # torchvision.utils.save_image(sample["image"], "sample_ffhq.png", normalize=True) 90 | -------------------------------------------------------------------------------- /data/imagenet_base.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import albumentations 3 | from PIL import Image 4 | from torch.utils.data import Dataset 5 | import torchvision.transforms as transforms 6 | 7 | class ImagePaths(Dataset): 8 | def __init__(self, split, is_val, paths, size=None, random_crop=False, labels=None): 9 | self.size = size 10 | self.random_crop = random_crop 11 | 12 | self.labels = dict() if labels is None else labels 13 | self.labels["file_path_"] = paths 14 | self._length = len(paths) 15 | 16 | if split == "train" and not is_val: 17 | transforms_ = [ 18 | transforms.Resize(256), 19 | transforms.RandomCrop(256), 20 | transforms.RandomHorizontalFlip(p=0.5), 21 | transforms.ToTensor(), 22 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) 23 | ] 24 | else: 25 | transforms_ = [ 26 | transforms.Resize(256), 
27 | transforms.CenterCrop(256), 28 | transforms.Resize((256, 256)), 29 | transforms.ToTensor(), 30 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) 31 | ] 32 | self.transforms = transforms.Compose(transforms_) 33 | 34 | # if self.size is not None and self.size > 0: 35 | # self.rescaler = albumentations.SmallestMaxSize(max_size = self.size) 36 | # if not self.random_crop: 37 | # self.cropper = albumentations.CenterCrop(height=self.size,width=self.size) 38 | # else: 39 | # self.cropper = albumentations.RandomCrop(height=self.size,width=self.size) 40 | # self.preprocessor = albumentations.Compose([self.rescaler, self.cropper]) 41 | # else: 42 | # self.preprocessor = lambda **kwargs: kwargs 43 | 44 | def __len__(self): 45 | return self._length 46 | 47 | def preprocess_image(self, image_path): 48 | image = Image.open(image_path) 49 | if not image.mode == "RGB": 50 | image = image.convert("RGB") 51 | # image = np.array(image).astype(np.uint8) 52 | # we replace the original taming version image preprocess 53 | # with the one in RQVAE 54 | # image = self.preprocessor(image=image)["image"] 55 | # image = (image/127.5 - 1.0).astype(np.float32) 56 | 57 | image = self.transforms(image) 58 | return image 59 | 60 | def __getitem__(self, i): 61 | example = dict() 62 | example["image"] = self.preprocess_image(self.labels["file_path_"][i]) 63 | for k in self.labels: 64 | example[k] = self.labels[k][i] 65 | return example -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: vq 2 | channels: 3 | - defaults 4 | dependencies: 5 | - _libgcc_mutex=0.1=main 6 | - bzip2=1.0.8=h7b6447c_0 7 | - ca-certificates=2022.10.11=h06a4308_0 8 | - certifi=2022.12.7=py310h06a4308_0 9 | - ld_impl_linux-64=2.38=h1181459_1 10 | - libffi=3.3=he6710b0_2 11 | - libgcc-ng=9.1.0=hdf63c60_0 12 | - libstdcxx-ng=9.1.0=hdf63c60_0 13 | - libuuid=1.0.3=h7f8727e_2 14 | - ncurses=6.3=h7f8727e_2 15 | - openssl=1.1.1s=h7f8727e_0 16 | - pip=22.3.1=py310h06a4308_0 17 | - python=3.10.4=h12debd9_0 18 | - readline=8.1.2=h7f8727e_1 19 | - setuptools=65.6.3=py310h06a4308_0 20 | - sqlite=3.38.5=hc218d9a_0 21 | - tk=8.6.12=h1ccaba5_0 22 | - tzdata=2022g=h04d1e81_0 23 | - wheel=0.37.1=pyhd3eb1b0_0 24 | - xz=5.2.5=h7f8727e_1 25 | - zlib=1.2.12=h7f8727e_2 26 | - pip: 27 | - absl-py==1.3.0 28 | - academictorrents==2.3.3 29 | - aiohttp==3.8.3 30 | - aiosignal==1.3.1 31 | - albumentations==1.3.0 32 | - aniso8601==9.0.1 33 | - antlr4-python3-runtime==4.9.3 34 | - anyio==3.6.2 35 | - argon2-cffi==21.3.0 36 | - argon2-cffi-bindings==21.2.0 37 | - arrow==1.2.3 38 | - astroid==2.14.1 39 | - asttokens==2.2.1 40 | - astunparse==1.6.3 41 | - async-timeout==4.0.2 42 | - attrs==22.2.0 43 | - autofaiss==2.15.5 44 | - backcall==0.2.0 45 | - beautifulsoup4==4.11.2 46 | - bencode.py==2.0.0 47 | - bitstring==3.1.5 48 | - bleach==6.0.0 49 | - braceexpand==0.1.7 50 | - cachetools==5.2.1 51 | - cffi==1.15.1 52 | - charset-normalizer==2.1.1 53 | - click==8.1.3 54 | - clip-anytorch==2.5.0 55 | - clip-retrieval==2.36.1 56 | - comm==0.1.2 57 | - contourpy==1.0.6 58 | - cycler==0.11.0 59 | - dataclasses==0.6 60 | - debugpy==1.6.6 61 | - decorator==5.1.1 62 | - defusedxml==0.7.1 63 | - diffusers==0.16.1 64 | - dill==0.3.6 65 | - docker-pycreds==0.4.0 66 | - easydict==1.10 67 | - einops==0.6.0 68 | - embedding-reader==1.5.0 69 | - executing==1.2.0 70 | - exifread-nocycle==3.0.1 71 | - faiss-cpu==1.7.3 72 | - fastjsonschema==2.16.3 73 | - 
filelock==3.9.0 74 | - fire==0.4.0 75 | - flask==2.2.2 76 | - flask-cors==3.0.10 77 | - flask-restful==0.3.9 78 | - flatbuffers==23.3.3 79 | - fonttools==4.38.0 80 | - fqdn==1.5.1 81 | - frozenlist==1.3.3 82 | - fsspec==2022.11.0 83 | - ftfy==6.1.1 84 | - future==0.16.0 85 | - gast==0.4.0 86 | - gitdb==4.0.10 87 | - gitpython==3.1.30 88 | - google-auth==2.16.0 89 | - google-auth-oauthlib==0.4.6 90 | - google-pasta==0.2.0 91 | - grpcio==1.51.1 92 | - h5py==3.8.0 93 | - huggingface-hub==0.14.1 94 | - idna==3.4 95 | - imageio==2.25.1 96 | - img2dataset==1.41.0 97 | - importlib-metadata==6.6.0 98 | - ipykernel==6.21.3 99 | - ipython==8.10.0 100 | - ipython-genutils==0.2.0 101 | - ipywidgets==8.0.4 102 | - isoduration==20.11.0 103 | - isort==5.12.0 104 | - itsdangerous==2.1.2 105 | - jedi==0.18.2 106 | - jinja2==3.1.2 107 | - joblib==1.2.0 108 | - jsonpointer==2.3 109 | - jsonschema==4.17.3 110 | - jupyter==1.0.0 111 | - jupyter-console==6.6.3 112 | - jupyter-events==0.6.3 113 | - jupyter_client==8.0.3 114 | - jupyter_core==5.3.0 115 | - jupyter_server==2.5.0 116 | - jupyter_server_terminals==0.4.4 117 | - jupyterlab-pygments==0.2.2 118 | - jupyterlab-widgets==3.0.5 119 | - keras==2.11.0 120 | - kiwisolver==1.4.4 121 | - kornia==0.6.9 122 | - lazy-object-proxy==1.9.0 123 | - libclang==15.0.6.1 124 | - lmdb==1.4.0 125 | - lpips==0.1.4 126 | - markdown==3.4.1 127 | - markupsafe==2.1.1 128 | - matplotlib==3.6.2 129 | - matplotlib-inline==0.1.6 130 | - mccabe==0.7.0 131 | - mistune==2.0.5 132 | - multidict==6.0.4 133 | - multilingual-clip==1.0.10 134 | - nbclassic==0.5.3 135 | - nbclient==0.7.2 136 | - nbconvert==7.2.10 137 | - nbformat==5.7.3 138 | - nest-asyncio==1.5.6 139 | - networkx==3.0 140 | - nltk==3.8.1 141 | - notebook==6.5.3 142 | - notebook_shim==0.2.2 143 | - numpy==1.24.1 144 | - nvidia-cublas-cu11==11.10.3.66 145 | - nvidia-cuda-nvrtc-cu11==11.7.99 146 | - nvidia-cuda-runtime-cu11==11.7.99 147 | - nvidia-cudnn-cu11==8.5.0.96 148 | - oauthlib==3.2.2 149 | - omegaconf==2.3.0 150 | - open-clip-torch==2.13.0 151 | - opencv-python==4.7.0.68 152 | - opencv-python-headless==4.7.0.68 153 | - opt-einsum==3.3.0 154 | - packaging==23.0 155 | - pandas==1.5.3 156 | - pandocfilters==1.5.0 157 | - parso==0.8.3 158 | - pathtools==0.1.2 159 | - pexpect==4.8.0 160 | - pickleshare==0.7.5 161 | - pillow==9.4.0 162 | - platformdirs==3.0.0 163 | - prometheus-client==0.16.0 164 | - promise==2.3 165 | - prompt-toolkit==3.0.36 166 | - protobuf==3.19.6 167 | - psutil==5.9.4 168 | - ptyprocess==0.7.0 169 | - pure-eval==0.2.2 170 | - pyarrow==7.0.0 171 | - pyasn1==0.4.8 172 | - pyasn1-modules==0.2.8 173 | - pycocotools==2.0.6 174 | - pycparser==2.21 175 | - pydeprecate==0.3.1 176 | - pygments==2.14.0 177 | - pylint==2.16.1 178 | - pyparsing==3.0.9 179 | - pypubsub==3.3.0 180 | - pyrsistent==0.19.3 181 | - python-dateutil==2.8.2 182 | - python-json-logger==2.0.7 183 | - python-version==0.0.2 184 | - pytorch-lightning==1.5.6 185 | - pytz==2022.7 186 | - pywavelets==1.4.1 187 | - pyyaml==6.0 188 | - pyzmq==25.0.1 189 | - qtconsole==5.4.1 190 | - qtpy==2.3.0 191 | - qudida==0.0.4 192 | - regex==2022.10.31 193 | - requests==2.28.1 194 | - requests-oauthlib==1.3.1 195 | - rfc3339-validator==0.1.4 196 | - rfc3986-validator==0.1.1 197 | - rsa==4.9 198 | - scikit-image==0.19.3 199 | - scikit-learn==1.2.1 200 | - scipy==1.10.0 201 | - seaborn==0.12.2 202 | - send2trash==1.8.0 203 | - sentence-transformers==2.2.2 204 | - sentencepiece==0.1.97 205 | - sentry-sdk==1.12.1 206 | - setproctitle==1.3.2 207 | - shortuuid==1.0.11 
208 | - six==1.16.0 209 | - sklearn==0.0.post1 210 | - smmap==5.0.0 211 | - sniffio==1.3.0 212 | - soupsieve==2.4 213 | - stack-data==0.6.2 214 | - tensorboard==2.11.0 215 | - tensorboard-data-server==0.6.1 216 | - tensorboard-plugin-wit==1.8.1 217 | - tensorflow-estimator==2.11.0 218 | - tensorflow-io-gcs-filesystem==0.31.0 219 | - termcolor==2.2.0 220 | - terminado==0.17.1 221 | - test-tube==0.7.5 222 | - threadpoolctl==3.1.0 223 | - tifffile==2023.2.3 224 | - timm==0.6.12 225 | - tinycss2==1.2.1 226 | - tokenizers==0.13.2 227 | - tomli==2.0.1 228 | - tomlkit==0.11.6 229 | - torch==1.13.1 230 | - torchmetrics==0.11.0 231 | - torchvision==0.14.1 232 | - tornado==6.2 233 | - tqdm==4.64.1 234 | - traitlets==5.9.0 235 | - transformers==4.26.1 236 | - typing_extensions==4.4.0 237 | - uri-template==1.2.0 238 | - urllib3==1.26.13 239 | - wandb==0.12.21 240 | - wcwidth==0.2.5 241 | - webcolors==1.12 242 | - webdataset==0.2.33 243 | - webencodings==0.5.1 244 | - websocket-client==1.5.1 245 | - werkzeug==2.2.2 246 | - widgetsnbextension==4.0.5 247 | - wrapt==1.14.1 248 | - yarl==1.8.2 249 | - zipp==3.15.0 250 | prefix: /home/huangmq/anaconda3/envs/vq 251 | 252 | -------------------------------------------------------------------------------- /models/stage1/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from functools import partial 4 | 5 | # step scheduler 6 | def fn_LinearWarmup(warmup_steps, step): 7 | if step < warmup_steps: # linear warmup 8 | return float(step) / float(max(1, warmup_steps)) 9 | else: 10 | return 1.0 11 | 12 | def Scheduler_LinearWarmup(warmup_steps): 13 | return partial(fn_LinearWarmup, warmup_steps) 14 | 15 | 16 | def fn_LinearWarmup_CosineDecay(warmup_steps, max_steps, multipler_min, step): 17 | if step < warmup_steps: # linear warmup 18 | return float(step) / float(max(1, warmup_steps)) 19 | else: # cosine learning rate schedule 20 | multipler = 0.5 * (math.cos((step - warmup_steps) / (max_steps - warmup_steps) * math.pi) + 1) 21 | return max(multipler, multipler_min) 22 | 23 | def Scheduler_LinearWarmup_CosineDecay(warmup_steps, max_steps, multipler_min): 24 | return partial(fn_LinearWarmup_CosineDecay, warmup_steps, max_steps, multipler_min) 25 | -------------------------------------------------------------------------------- /models/stage2/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from functools import partial 4 | 5 | # step scheduler 6 | def fn(warmup_steps, max_steps, multipler_min, step): 7 | if step < warmup_steps: # linear warmup 8 | return float(step) / float(max(1, warmup_steps)) 9 | else: # cosine learning rate schedule 10 | multipler = 0.5 * (math.cos((step - warmup_steps) / (max_steps - warmup_steps) * math.pi) + 1) 11 | return max(multipler, multipler_min) 12 | 13 | def learning_rate_schedule(warmup_steps, max_steps, multipler_min): 14 | return partial(fn, warmup_steps, max_steps, multipler_min) 15 | 16 | def disabled_train(self, mode=True): 17 | """Overwrite model.train with this function to make sure train/eval mode 18 | does not change anymore.""" 19 | return self 20 | 21 | # commonly used sample functions 22 | def top_k_logits(logits, k): 23 | v, ix = torch.topk(logits, k) 24 | out = logits.clone() 25 | out[out < v[..., [-1]]] = -float('Inf') 26 | return out 27 | 28 | def top_p_logits(probs, p): 29 | sorted_probs, sorted_indices = torch.sort(probs, dim=-1, descending=True) 30 | cum_probs = 
torch.cumsum(sorted_probs, dim=-1) 31 | 32 | sorted_idx_remove_cond = cum_probs >= p 33 | 34 | sorted_idx_remove_cond[..., 1:] = sorted_idx_remove_cond[..., :-1].clone() 35 | sorted_idx_remove_cond[..., 0] = 0 36 | 37 | indices_to_remove = sorted_idx_remove_cond.scatter(-1, sorted_indices, sorted_idx_remove_cond) 38 | probs = probs.masked_fill(indices_to_remove, 0.0) 39 | norm_probs = probs / torch.sum(probs, dim=-1, keepdim=True) 40 | return norm_probs -------------------------------------------------------------------------------- /modules/discriminator/model.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import torch.nn as nn 3 | 4 | 5 | from utils.utils import ActNorm 6 | 7 | 8 | def weights_init(m): 9 | classname = m.__class__.__name__ 10 | if classname.find('Conv') != -1: 11 | nn.init.normal_(m.weight.data, 0.0, 0.02) 12 | elif classname.find('BatchNorm') != -1: 13 | nn.init.normal_(m.weight.data, 1.0, 0.02) 14 | nn.init.constant_(m.bias.data, 0) 15 | 16 | 17 | class NLayerDiscriminator(nn.Module): 18 | """Defines a PatchGAN discriminator as in Pix2Pix 19 | --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py 20 | """ 21 | def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False): 22 | """Construct a PatchGAN discriminator 23 | Parameters: 24 | input_nc (int) -- the number of channels in input images 25 | ndf (int) -- the number of filters in the last conv layer 26 | n_layers (int) -- the number of conv layers in the discriminator 27 | norm_layer -- normalization layer 28 | """ 29 | super(NLayerDiscriminator, self).__init__() 30 | if not use_actnorm: 31 | norm_layer = nn.BatchNorm2d 32 | else: 33 | norm_layer = ActNorm 34 | if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters 35 | use_bias = norm_layer.func != nn.BatchNorm2d 36 | else: 37 | use_bias = norm_layer != nn.BatchNorm2d 38 | 39 | kw = 4 40 | padw = 1 41 | sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] 42 | nf_mult = 1 43 | nf_mult_prev = 1 44 | for n in range(1, n_layers): # gradually increase the number of filters 45 | nf_mult_prev = nf_mult 46 | nf_mult = min(2 ** n, 8) 47 | sequence += [ 48 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), 49 | norm_layer(ndf * nf_mult), 50 | nn.LeakyReLU(0.2, True) 51 | ] 52 | 53 | nf_mult_prev = nf_mult 54 | nf_mult = min(2 ** n_layers, 8) 55 | sequence += [ 56 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias), 57 | norm_layer(ndf * nf_mult), 58 | nn.LeakyReLU(0.2, True) 59 | ] 60 | 61 | sequence += [ 62 | nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map 63 | self.main = nn.Sequential(*sequence) 64 | 65 | def forward(self, input): 66 | """Standard forward.""" 67 | return self.main(input) 68 | -------------------------------------------------------------------------------- /modules/discriminator/stylegan_lucidrains.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | def default(val, d): 5 | return val if val is not None else d 6 | 7 | class ChanLayerNorm(nn.Module): 8 | def __init__( 9 | self, 10 | dim, 11 | eps = 1e-5 12 | ): 13 | super().__init__() 14 | self.eps = eps 15 | self.gamma = nn.Parameter(torch.ones(1, dim, 1, 1)) 16 | 
17 | def forward(self, x): 18 | var = torch.var(x, dim = 1, unbiased = False, keepdim = True) 19 | mean = torch.mean(x, dim = 1, keepdim = True) 20 | return (x - mean) * (var + self.eps).rsqrt() * self.gamma 21 | 22 | class CrossEmbedLayer(nn.Module): 23 | def __init__( 24 | self, 25 | dim_in, 26 | kernel_sizes, 27 | dim_out = None, 28 | stride = 2 29 | ): 30 | super().__init__() 31 | assert all([*map(lambda t: (t % 2) == (stride % 2), kernel_sizes)]) 32 | dim_out = default(dim_out, dim_in) 33 | 34 | kernel_sizes = sorted(kernel_sizes) 35 | num_scales = len(kernel_sizes) 36 | 37 | # calculate the dimension at each scale 38 | dim_scales = [int(dim_out / (2 ** i)) for i in range(1, num_scales)] 39 | dim_scales = [*dim_scales, dim_out - sum(dim_scales)] 40 | 41 | self.convs = nn.ModuleList([]) 42 | for kernel, dim_scale in zip(kernel_sizes, dim_scales): 43 | self.convs.append(nn.Conv2d(dim_in, dim_scale, kernel, stride = stride, padding = (kernel - stride) // 2)) 44 | 45 | def forward(self, x): 46 | fmaps = tuple(map(lambda conv: conv(x), self.convs)) 47 | return torch.cat(fmaps, dim = 1) 48 | 49 | class Block(nn.Module): 50 | def __init__( 51 | self, 52 | dim, 53 | dim_out, 54 | groups = 8 55 | ): 56 | super().__init__() 57 | self.groupnorm = nn.GroupNorm(groups, dim) 58 | self.activation = nn.LeakyReLU(0.1) 59 | self.project = nn.Conv2d(dim, dim_out, 3, padding = 1) 60 | 61 | def forward(self, x, scale_shift = None): 62 | x = self.groupnorm(x) 63 | x = self.activation(x) 64 | return self.project(x) 65 | 66 | class ResnetBlock(nn.Module): 67 | def __init__( 68 | self, 69 | dim, 70 | dim_out = None, 71 | *, 72 | groups = 8 73 | ): 74 | super().__init__() 75 | dim_out = default(dim_out, dim) 76 | self.block = Block(dim, dim_out, groups = groups) 77 | self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity() 78 | 79 | def forward(self, x): 80 | h = self.block(x) 81 | return h + self.res_conv(x) 82 | 83 | # discriminator 84 | 85 | class Discriminator(nn.Module): 86 | def __init__( 87 | self, 88 | dim = 256, 89 | discr_layers = 6, 90 | channels = 3, 91 | groups = 8, 92 | cross_embed_kernel_sizes = (3, 7, 15) 93 | ): 94 | super().__init__() 95 | 96 | layer_mults = list(map(lambda t: 2 ** t, range(discr_layers))) 97 | layer_dims = [dim * mult for mult in layer_mults] 98 | dims = (dim, *layer_dims) 99 | 100 | init_dim, *_, final_dim = dims 101 | dim_pairs = zip(dims[:-1], dims[1:]) 102 | 103 | self.layers = nn.ModuleList([nn.Sequential( 104 | CrossEmbedLayer(channels, cross_embed_kernel_sizes, init_dim, stride = 1), 105 | nn.LeakyReLU(0.1) 106 | )]) 107 | 108 | for dim_in, dim_out in dim_pairs: 109 | self.layers.append(nn.Sequential( 110 | nn.Conv2d(dim_in, dim_out, 4, stride = 2, padding = 1), 111 | nn.LeakyReLU(0.1), 112 | nn.GroupNorm(groups, dim_out), 113 | ResnetBlock(dim_out, dim_out), 114 | )) 115 | 116 | self.to_logits = nn.Sequential( # return 5 x 5, for PatchGAN-esque training 117 | nn.Conv2d(final_dim, final_dim, 1), 118 | nn.LeakyReLU(0.1), 119 | nn.Conv2d(final_dim, 1, 4) 120 | ) 121 | 122 | def forward(self, x): 123 | for net in self.layers: 124 | x = net(x) 125 | 126 | return self.to_logits(x) 127 | 128 | if __name__ == "__main__": 129 | x = torch.randn(10, 3, 256, 256) 130 | 131 | D = Discriminator( 132 | dim = 256, 133 | discr_layers = 6, 134 | channels = 3, 135 | groups = 8, 136 | cross_embed_kernel_sizes = (3, 7, 15) 137 | ) 138 | 139 | y = D(x) 140 | print(y.size()) -------------------------------------------------------------------------------- 
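Usage note (illustrative, not part of the repository): the PatchGAN-style NLayerDiscriminator defined in modules/discriminator/model.py above returns a spatial grid of per-patch real/fake logits rather than a single scalar per image. A minimal sketch, assuming a 256x256 RGB batch and the default n_layers=3 (both assumptions for illustration):

import torch
from modules.discriminator.model import NLayerDiscriminator, weights_init

# Hypothetical driver code: batch size, resolution and default arguments are assumed.
disc = NLayerDiscriminator(input_nc=3, ndf=64, n_layers=3, use_actnorm=False).apply(weights_init)
images = torch.randn(4, 3, 256, 256)   # stand-in for real or reconstructed images
patch_logits = disc(images)            # shape: torch.Size([4, 1, 30, 30])
print(patch_logits.shape)

The hinge and vanilla GAN losses defined later in modules/losses/vqperceptual.py reduce this patch map with a plain mean, so no extra pooling is required.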
/modules/losses/lpips.py: -------------------------------------------------------------------------------- 1 | """Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models""" 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torchvision import models 6 | from collections import namedtuple 7 | 8 | from utils.utils import get_ckpt_path 9 | 10 | 11 | class LPIPS(nn.Module): 12 | # Learned perceptual metric 13 | def __init__(self, use_dropout=True): 14 | super().__init__() 15 | self.scaling_layer = ScalingLayer() 16 | self.chns = [64, 128, 256, 512, 512] # vg16 features 17 | self.net = vgg16(pretrained=True, requires_grad=False) 18 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) 19 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) 20 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) 21 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) 22 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) 23 | self.load_from_pretrained() 24 | for param in self.parameters(): 25 | param.requires_grad = False 26 | 27 | def load_from_pretrained(self, name="vgg_lpips"): 28 | ckpt = get_ckpt_path(name, "modules/lpips") 29 | self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) 30 | print("loaded pretrained LPIPS loss from {}".format(ckpt)) 31 | 32 | @classmethod 33 | def from_pretrained(cls, name="vgg_lpips"): 34 | if name != "vgg_lpips": 35 | raise NotImplementedError 36 | model = cls() 37 | ckpt = get_ckpt_path(name) 38 | model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) 39 | return model 40 | 41 | def forward(self, input, target): 42 | in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target)) 43 | outs0, outs1 = self.net(in0_input), self.net(in1_input) 44 | feats0, feats1, diffs = {}, {}, {} 45 | lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] 46 | for kk in range(len(self.chns)): 47 | feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk]) 48 | diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 49 | 50 | res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))] 51 | val = res[0] 52 | for l in range(1, len(self.chns)): 53 | val += res[l] 54 | return val 55 | 56 | 57 | class ScalingLayer(nn.Module): 58 | def __init__(self): 59 | super(ScalingLayer, self).__init__() 60 | self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) 61 | self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None]) 62 | 63 | def forward(self, inp): 64 | return (inp - self.shift) / self.scale 65 | 66 | 67 | class NetLinLayer(nn.Module): 68 | """ A single linear layer which does a 1x1 conv """ 69 | def __init__(self, chn_in, chn_out=1, use_dropout=False): 70 | super(NetLinLayer, self).__init__() 71 | layers = [nn.Dropout(), ] if (use_dropout) else [] 72 | layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ] 73 | self.model = nn.Sequential(*layers) 74 | 75 | 76 | class vgg16(torch.nn.Module): 77 | def __init__(self, requires_grad=False, pretrained=True): 78 | super(vgg16, self).__init__() 79 | vgg_pretrained_features = models.vgg16(pretrained=pretrained).features 80 | self.slice1 = torch.nn.Sequential() 81 | self.slice2 = torch.nn.Sequential() 82 | self.slice3 = torch.nn.Sequential() 83 | self.slice4 = torch.nn.Sequential() 84 | self.slice5 = torch.nn.Sequential() 85 | self.N_slices = 5 86 | 
for x in range(4): 87 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 88 | for x in range(4, 9): 89 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 90 | for x in range(9, 16): 91 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 92 | for x in range(16, 23): 93 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 94 | for x in range(23, 30): 95 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 96 | if not requires_grad: 97 | for param in self.parameters(): 98 | param.requires_grad = False 99 | 100 | def forward(self, X): 101 | h = self.slice1(X) 102 | h_relu1_2 = h 103 | h = self.slice2(h) 104 | h_relu2_2 = h 105 | h = self.slice3(h) 106 | h_relu3_3 = h 107 | h = self.slice4(h) 108 | h_relu4_3 = h 109 | h = self.slice5(h) 110 | h_relu5_3 = h 111 | vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) 112 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) 113 | return out 114 | 115 | 116 | def normalize_tensor(x,eps=1e-10): 117 | norm_factor = torch.sqrt(torch.sum(x**2,dim=1,keepdim=True)) 118 | return x/(norm_factor+eps) 119 | 120 | 121 | def spatial_average(x, keepdim=True): 122 | return x.mean([2,3],keepdim=keepdim) 123 | 124 | -------------------------------------------------------------------------------- /modules/losses/vqperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from modules.losses.lpips import LPIPS 6 | from modules.discriminator.model import NLayerDiscriminator, weights_init 7 | 8 | 9 | class DummyLoss(nn.Module): 10 | def __init__(self): 11 | super().__init__() 12 | 13 | 14 | def adopt_weight(weight, global_step, threshold=0, value=0.): 15 | if global_step < threshold: 16 | weight = value 17 | return weight 18 | 19 | 20 | def hinge_d_loss(logits_real, logits_fake): 21 | loss_real = torch.mean(F.relu(1. - logits_real)) 22 | loss_fake = torch.mean(F.relu(1. 
+ logits_fake)) 23 | d_loss = 0.5 * (loss_real + loss_fake) 24 | return d_loss 25 | 26 | 27 | def vanilla_d_loss(logits_real, logits_fake): 28 | d_loss = 0.5 * ( 29 | torch.mean(torch.nn.functional.softplus(-logits_real)) + 30 | torch.mean(torch.nn.functional.softplus(logits_fake))) 31 | return d_loss 32 | 33 | 34 | class VQLPIPSWithDiscriminator(nn.Module): 35 | def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0, 36 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, 37 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, 38 | disc_ndf=64, disc_loss="hinge", disc_weight_max=None): 39 | super().__init__() 40 | assert disc_loss in ["hinge", "vanilla"] 41 | self.codebook_weight = codebook_weight 42 | self.pixel_weight = pixelloss_weight 43 | self.perceptual_loss = LPIPS().eval() 44 | self.perceptual_weight = perceptual_weight 45 | 46 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, 47 | n_layers=disc_num_layers, 48 | use_actnorm=use_actnorm, 49 | ndf=disc_ndf 50 | ).apply(weights_init) 51 | self.discriminator_iter_start = disc_start 52 | if disc_loss == "hinge": 53 | self.disc_loss = hinge_d_loss 54 | elif disc_loss == "vanilla": 55 | self.disc_loss = vanilla_d_loss 56 | else: 57 | raise ValueError(f"Unknown GAN loss '{disc_loss}'.") 58 | print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.") 59 | self.disc_factor = disc_factor 60 | self.discriminator_weight = disc_weight 61 | self.disc_conditional = disc_conditional 62 | 63 | self.disc_weight_max = disc_weight_max 64 | # recommend: 1000 for [churches, bedrooms], 1 for [ffhq] 65 | # paper: Unleashing Transformers: Parallel Token Prediction with Discrete Absorbing Diffusion for Fast High-Resolution Image Generation from Vector-Quantized Codes 66 | 67 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 68 | if last_layer is not None: 69 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 70 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 71 | else: 72 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] 73 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] 74 | 75 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 76 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 77 | d_weight = d_weight * self.discriminator_weight 78 | return d_weight 79 | 80 | def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx, global_step, last_layer=None, cond=None, split="train"): 81 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 82 | if self.perceptual_weight > 0: 83 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) 84 | rec_loss = rec_loss + self.perceptual_weight * p_loss 85 | else: 86 | p_loss = torch.tensor([0.0]) 87 | 88 | nll_loss = rec_loss 89 | #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 90 | nll_loss = torch.mean(nll_loss) 91 | 92 | # now the GAN part 93 | if optimizer_idx == 0: 94 | # generator update 95 | if cond is None: 96 | assert not self.disc_conditional 97 | logits_fake = self.discriminator(reconstructions.contiguous()) 98 | else: 99 | assert self.disc_conditional 100 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) 101 | g_loss = -torch.mean(logits_fake) 102 | 103 | try: 104 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) 105 | except 
RuntimeError:
106 | assert not self.training
107 | d_weight = torch.tensor(0.0)
108 |
109 | # cap the adaptive discriminator weight at disc_weight_max
110 | if self.disc_weight_max is not None:
111 | d_weight.clamp_max_(self.disc_weight_max)
112 |
113 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
114 | loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean()
115 |
116 | log = {"{}_total_loss".format(split): loss.clone().detach().mean(),
117 | "{}_quant_loss".format(split): codebook_loss.detach().mean(),
118 | "{}_nll_loss".format(split): nll_loss.detach().mean(),
119 | "{}_rec_loss".format(split): rec_loss.detach().mean(),
120 | "{}_p_loss".format(split): p_loss.detach().mean(),
121 | "{}_d_weight".format(split): d_weight.detach(),
122 | "{}_disc_factor".format(split): torch.tensor(disc_factor),
123 | "{}_g_loss".format(split): g_loss.detach().mean(),
124 | }
125 | return loss, log
126 |
127 | if optimizer_idx == 1:
128 | # second pass for discriminator update
129 | if cond is None:
130 | logits_real = self.discriminator(inputs.contiguous().detach())
131 | logits_fake = self.discriminator(reconstructions.contiguous().detach())
132 | else:
133 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
134 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
135 |
136 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
137 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
138 |
139 | log = {"{}_disc_loss".format(split): d_loss.clone().detach().mean(),
140 | "{}_logits_real".format(split): logits_real.detach().mean(),
141 | "{}_logits_fake".format(split): logits_fake.detach().mean()
142 | }
143 | return d_loss, log
144 |
-------------------------------------------------------------------------------- /modules/losses/vqperceptual_epoch.py: --------------------------------------------------------------------------------
1 | # Discarded: moved to vqperceptual_multidisc
2 | # since we start adversarial training just at the beginning, no need to check the epoch
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 | from modules.losses.lpips import LPIPS
9 | from modules.discriminator.model import NLayerDiscriminator, weights_init
10 |
11 |
12 | class DummyLoss(nn.Module):
13 | def __init__(self):
14 | super().__init__()
15 |
16 |
17 | def adopt_weight(weight, global_step, threshold=0, value=0.):
18 | if global_step < threshold:
19 | weight = value
20 | return weight
21 |
22 |
23 | def hinge_d_loss(logits_real, logits_fake):
24 | loss_real = torch.mean(F.relu(1. - logits_real))
25 | loss_fake = torch.mean(F.relu(1. + logits_fake))
26 | d_loss = 0.5 * (loss_real + loss_fake)
27 | return d_loss
28 |
29 |
30 | def vanilla_d_loss(logits_real, logits_fake):
31 | d_loss = 0.5 * (
32 | torch.mean(torch.nn.functional.softplus(-logits_real)) +
33 | torch.mean(torch.nn.functional.softplus(logits_fake)))
34 | return d_loss
35 |
36 |
37 | # we assume the disc_start denotes the start epoch, i.e. [0,1,2,3,4,...]
38 | class VQLPIPSWithDiscriminator(nn.Module): 39 | def __init__(self, disc_start_epoch, codebook_weight=1.0, pixelloss_weight=1.0, 40 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, 41 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, 42 | disc_ndf=64, disc_loss="hinge", disc_weight_max=None): 43 | super().__init__() 44 | assert disc_loss in ["hinge", "vanilla"] 45 | self.codebook_weight = codebook_weight 46 | self.pixel_weight = pixelloss_weight 47 | self.perceptual_loss = LPIPS().eval() 48 | self.perceptual_weight = perceptual_weight 49 | 50 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, 51 | n_layers=disc_num_layers, 52 | use_actnorm=use_actnorm, 53 | ndf=disc_ndf 54 | ).apply(weights_init) 55 | self.discriminator_iter_start_epoch = disc_start_epoch 56 | if disc_loss == "hinge": 57 | self.disc_loss = hinge_d_loss 58 | elif disc_loss == "vanilla": 59 | self.disc_loss = vanilla_d_loss 60 | else: 61 | raise ValueError(f"Unknown GAN loss '{disc_loss}'.") 62 | print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.") 63 | self.disc_factor = disc_factor 64 | self.discriminator_weight = disc_weight 65 | self.disc_conditional = disc_conditional 66 | 67 | self.disc_weight_max = disc_weight_max 68 | # recommend: 1000 for [churches, bedrooms], 1 for [ffhq] 69 | # paper: Unleashing Transformers: Parallel Token Prediction with Discrete Absorbing Diffusion for Fast High-Resolution Image Generation from Vector-Quantized Codes 70 | 71 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 72 | if last_layer is not None: 73 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 74 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 75 | else: 76 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] 77 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] 78 | 79 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 80 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 81 | d_weight = d_weight * self.discriminator_weight 82 | return d_weight 83 | 84 | def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx, current_epoch, last_layer=None, cond=None, split="train"): 85 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 86 | if self.perceptual_weight > 0: 87 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) 88 | rec_loss = rec_loss + self.perceptual_weight * p_loss 89 | else: 90 | p_loss = torch.tensor([0.0]) 91 | 92 | nll_loss = rec_loss 93 | #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 94 | nll_loss = torch.mean(nll_loss) 95 | 96 | # now the GAN part 97 | if optimizer_idx == 0: 98 | # generator update 99 | if cond is None: 100 | assert not self.disc_conditional 101 | logits_fake = self.discriminator(reconstructions.contiguous()) 102 | else: 103 | assert self.disc_conditional 104 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) 105 | g_loss = -torch.mean(logits_fake) 106 | 107 | try: 108 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) 109 | except RuntimeError: 110 | assert not self.training 111 | d_weight = torch.tensor(0.0) 112 | 113 | # 增加对disc_weight最大值的限制 114 | if self.disc_weight_max is not None: 115 | d_weight.clamp_max_(self.disc_weight_max) 116 | 117 | disc_factor = adopt_weight( 118 | self.disc_factor, current_epoch, 
threshold=self.discriminator_iter_start_epoch 119 | ) 120 | loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean() 121 | 122 | log = {"{}_total_loss".format(split): loss.clone().detach().mean(), 123 | "{}_quant_loss".format(split): codebook_loss.detach().mean(), 124 | "{}_nll_loss".format(split): nll_loss.detach().mean(), 125 | "{}_rec_loss".format(split): rec_loss.detach().mean(), 126 | "{}_p_loss".format(split): p_loss.detach().mean(), 127 | "{}_d_weight".format(split): d_weight.detach(), 128 | "{}_disc_factor".format(split): torch.tensor(disc_factor), 129 | "{}_g_loss".format(split): g_loss.detach().mean(), 130 | } 131 | return loss, log 132 | 133 | if optimizer_idx == 1: 134 | # second pass for discriminator update 135 | if cond is None: 136 | logits_real = self.discriminator(inputs.contiguous().detach()) 137 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 138 | else: 139 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) 140 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) 141 | 142 | disc_factor = adopt_weight( 143 | self.disc_factor, current_epoch, threshold=self.discriminator_iter_start_epoch 144 | ) 145 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 146 | 147 | log = {"{}_disc_loss".format(split): d_loss.clone().detach().mean(), 148 | "{}_logits_real".format(split): logits_real.detach().mean(), 149 | "{}_logits_fake".format(split): logits_fake.detach().mean() 150 | } 151 | return d_loss, log 152 | -------------------------------------------------------------------------------- /modules/lpips/vgg.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CrossmodalGroup/MaskedVectorQuantization/f38a9ecbe5947d9c0a395c4f17aef356ad250eac/modules/lpips/vgg.pth -------------------------------------------------------------------------------- /modules/masked_quantization/decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | import os, sys 6 | sys.path.append(os.getcwd()) 7 | from modules.diffusionmodules.model import ResnetBlock, AttnBlock, Upsample, Normalize, nonlinearity 8 | 9 | 10 | class Decoder(nn.Module): 11 | def __init__(self, 12 | ch, in_ch, out_ch, ch_mult, num_res_blocks, resolution, 13 | attn_resolutions, dropout = 0.0, resamp_with_conv = True, give_pre_end = False, 14 | ): 15 | super().__init__() 16 | self.num_resolutions = len(ch_mult) 17 | self.num_res_blocks = num_res_blocks 18 | self.resolution = resolution 19 | self.in_ch = in_ch 20 | self.temb_ch = 0 21 | self.ch = ch 22 | self.give_pre_end = give_pre_end 23 | 24 | # compute block_in and curr_res at lowest res 25 | block_in = ch*ch_mult[self.num_resolutions-1] 26 | curr_res = resolution // 2**(self.num_resolutions-1) 27 | self.z_shape = (1,in_ch,curr_res,curr_res) 28 | print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape))) 29 | 30 | # self.conv_in = torch.nn.Conv2d(in_ch, block_in, kernel_size=3, stride=1, padding=1) 31 | 32 | # # middle 33 | # self.mid = nn.Module() 34 | # self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout) 35 | # self.mid.attn_1 = AttnBlock(block_in) 36 | # self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, 
temb_channels=self.temb_ch, dropout=dropout) 37 | 38 | # upsampling 39 | self.up = nn.ModuleList() 40 | for i_level in reversed(range(self.num_resolutions)): 41 | block = nn.ModuleList() 42 | attn = nn.ModuleList() 43 | block_out = ch*ch_mult[i_level] 44 | for i_block in range(self.num_res_blocks+1): 45 | block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout)) 46 | block_in = block_out 47 | if curr_res in attn_resolutions: 48 | attn.append(AttnBlock(block_in)) 49 | up = nn.Module() 50 | up.block = block 51 | up.attn = attn 52 | if i_level != 0: 53 | up.upsample = Upsample(block_in, resamp_with_conv) 54 | curr_res = curr_res * 2 55 | self.up.insert(0, up) # prepend to get consistent order 56 | 57 | # end 58 | self.norm_out = Normalize(block_in) 59 | self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) 60 | 61 | def forward(self, h): 62 | # z to block_in 63 | temb = None 64 | # h = self.conv_in(h) 65 | 66 | # # middle 67 | # h = self.mid.block_1(h, temb) 68 | # h = self.mid.attn_1(h) 69 | # h = self.mid.block_2(h, temb) 70 | 71 | # upsampling 72 | for i_level in reversed(range(self.num_resolutions)): 73 | for i_block in range(self.num_res_blocks+1): 74 | h = self.up[i_level].block[i_block](h, temb) 75 | if len(self.up[i_level].attn) > 0: 76 | h = self.up[i_level].attn[i_block](h) 77 | if i_level != 0: 78 | h = self.up[i_level].upsample(h) 79 | 80 | # end 81 | if self.give_pre_end: 82 | return h 83 | 84 | h = self.norm_out(h) 85 | h = nonlinearity(h) 86 | h = self.conv_out(h) 87 | 88 | return h -------------------------------------------------------------------------------- /modules/masked_quantization/demasker_vanilla.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from einops import rearrange 4 | 5 | import os, sys 6 | sys.path.append(os.getcwd()) 7 | 8 | from modules.diffusionmodules.model import Normalize, nonlinearity 9 | 10 | class ResnetBlockwithKernel(nn.Module): 11 | def __init__(self, *, in_channels, out_channels = None, conv_shortcut = False, dropout = 0., temb_channels = 512, kernel_size = 3): 12 | super().__init__() 13 | if kernel_size == 3: 14 | padding = 1 15 | elif kernel_size == 1: 16 | padding = 0 17 | else: 18 | raise NotImplementedError() 19 | 20 | self.in_channels = in_channels 21 | out_channels = in_channels if out_channels is None else out_channels 22 | self.out_channels = out_channels 23 | self.use_conv_shortcut = conv_shortcut 24 | 25 | self.norm1 = Normalize(in_channels) 26 | self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size = kernel_size, stride = 1, padding = padding) 27 | if temb_channels > 0: 28 | self.temb_proj = torch.nn.Linear(temb_channels, out_channels) 29 | self.norm2 = Normalize(out_channels) 30 | self.dropout = torch.nn.Dropout(dropout) 31 | self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size = kernel_size, stride = 1, padding = padding) 32 | if self.in_channels != self.out_channels: 33 | if self.use_conv_shortcut: 34 | self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size = kernel_size, stride = 1, padding = padding) 35 | else: 36 | self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size = 1, stride = 1, padding = 0) 37 | 38 | def forward(self, x, temb=None, **ignore_kwargs): 39 | h = x 40 | h = self.norm1(h) 41 | h = nonlinearity(h) 42 | h = self.conv1(h) 43 | 44 | if temb is not None: 45 | h = h + 
self.temb_proj(nonlinearity(temb))[:,:,None,None] 46 | 47 | h = self.norm2(h) 48 | h = nonlinearity(h) 49 | h = self.dropout(h) 50 | h = self.conv2(h) 51 | 52 | if self.in_channels != self.out_channels: 53 | if self.use_conv_shortcut: 54 | x = self.conv_shortcut(x) 55 | else: 56 | x = self.nin_shortcut(x) 57 | 58 | return x+h 59 | 60 | class BiasedSelfAttnBlock(nn.Module): 61 | def __init__(self, in_channels, reweight = False): 62 | super().__init__() 63 | self.in_channels = in_channels 64 | self.apply_reweight = reweight 65 | 66 | self.norm = Normalize(in_channels) 67 | self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) 68 | self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) 69 | self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) 70 | self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) 71 | 72 | def forward(self, x, mask, **ignore_kwargs): 73 | h_ = x 74 | h_ = self.norm(h_) 75 | q = self.q(h_) 76 | k = self.k(h_) 77 | v = self.v(h_) 78 | 79 | # compute attention 80 | b,c,h,w = q.shape 81 | q = q.reshape(b,c,h*w) 82 | q = q.permute(0,2,1) # b,hw,c 83 | k = k.reshape(b,c,h*w) # b,c,hw 84 | w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] 85 | w_ = w_ * (int(c)**(-0.5)) 86 | w_ = torch.nn.functional.softmax(w_, dim=2) 87 | 88 | if mask is not None: 89 | unsqueezed_mask = mask.unsqueeze(-2) 90 | w_ = w_ * unsqueezed_mask 91 | 92 | if self.apply_reweight: 93 | w_sum = torch.sum(w_, dim=-1, keepdim=True) 94 | w_ = w_ / w_sum 95 | 96 | # attend to values 97 | v = v.reshape(b,c,h*w) 98 | w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q) 99 | h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] 100 | h_ = h_.reshape(b,c,h,w) 101 | 102 | h_ = self.proj_out(h_) 103 | 104 | return x+h_ 105 | 106 | class TransformerStyleEncoderBlock(nn.Module): 107 | def __init__(self, dim): 108 | super().__init__() 109 | self.attn = BiasedSelfAttnBlock(dim) 110 | self.resblock = ResnetBlockwithKernel(in_channels=dim) 111 | 112 | def forward(self, x, mask = None): 113 | x = self.resblock(x) 114 | x = self.attn(x, mask) 115 | return x 116 | 117 | class TransformerStyleEncoder(nn.Module): 118 | def __init__(self, dim, n_layer, mask_init_value = 0.02): 119 | super().__init__() 120 | self.n_layer = n_layer 121 | self.mask_init_value = mask_init_value 122 | self.blocks = nn.ModuleList() 123 | for i in range(self.n_layer): 124 | self.blocks.append(TransformerStyleEncoderBlock(dim)) 125 | self.last_resblock = ResnetBlockwithKernel(in_channels=dim) 126 | 127 | def forward(self, x, mask = None): 128 | mask = mask + self.mask_init_value * (1 - mask) # replace 0 with self.mask_init_value 129 | for i in range(self.n_layer): 130 | x = self.blocks[i](x = x, mask = mask) 131 | if mask is not None: 132 | mask = torch.sqrt(mask) 133 | x = self.last_resblock(x) 134 | return x 135 | 136 | class VanillaDemasker(nn.Module): 137 | def __init__(self, codebook_dim, output_dim, height_and_width, n_layer, mask_init_value = 0.02): 138 | super().__init__() 139 | self.output_dim = output_dim 140 | self.codebook_dim = codebook_dim 141 | self.hw = height_and_width 142 | self.total_code_num = height_and_width * height_and_width 143 | self.mask_token = nn.Parameter(torch.zeros(1, codebook_dim, 1), requires_grad = True) 144 | self.mask_token.data.uniform_(-1.0 / codebook_dim, 1.0 / codebook_dim) 145 | self.post_projection = torch.nn.Conv2d(codebook_dim, 
output_dim, 1) 146 | 147 | self.transformer = TransformerStyleEncoder(output_dim, n_layer, mask_init_value) 148 | # keep it here to see whether it should 149 | self.conv_in = torch.nn.Conv2d(output_dim, output_dim, kernel_size=3, stride=1, padding=1) 150 | 151 | def forward(self, sampled_quant, remain_quant, sample_index, remain_index = None, mask = None): 152 | batch_size = sampled_quant.size(0) 153 | sample_index = sample_index.cpu().numpy().tolist() 154 | 155 | if remain_index == None: 156 | remain_index = [[j for j in range(self.total_code_num) if j not in sample_index[i]] for i in range(batch_size)] 157 | else: 158 | remain_index = remain_index.cpu().numpy().tolist() 159 | 160 | full_embedding = nn.Parameter(torch.zeros((batch_size, self.codebook_dim, self.total_code_num))).to(sampled_quant.device) 161 | 162 | for i in range(batch_size): 163 | full_embedding[i, :, sample_index[i]] = sampled_quant[i] 164 | full_embedding[i, :, remain_index[i]] = self.mask_token 165 | 166 | full_embedding = rearrange(full_embedding, "B C (H W) -> B C H W", H=self.hw, W=self.hw) 167 | full_embedding = self.post_projection(full_embedding) 168 | full_embedding = self.conv_in(full_embedding) # The conv_in in the decoder is here 169 | recovered_embedding = self.transformer(full_embedding, mask) 170 | 171 | return recovered_embedding -------------------------------------------------------------------------------- /modules/masked_quantization/masker_random.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from einops import rearrange 4 | import numpy as np 5 | import torch.nn.functional as F 6 | 7 | class RandomTokenizer(nn.Module): 8 | def __init__(self, 9 | topk, 10 | input_token_num, 11 | input_dim, 12 | output_dim, 13 | z_dim, 14 | patch_size, 15 | apply_norm_image_features=True, 16 | mask_token_initialization=0, ): 17 | r""" 18 | mask_token_initialization: mask token 初始化的方式,默认为全0 19 | choice: 20 | 0 : 全0初始化, 最后训练的结果也是全0 21 | 1 : 全1初始化 22 | "uniform": 正态分布初始化 23 | """ 24 | super().__init__() 25 | self.input_token_num = input_token_num 26 | self.sample_num = int(topk * input_token_num) 27 | self.topk_num = int(topk * input_token_num) 28 | self.z_dim = z_dim 29 | self.patch_size = patch_size 30 | self.hw = int(input_token_num**0.5) 31 | self.apply_norm_image_features = apply_norm_image_features 32 | 33 | if self.apply_norm_image_features: 34 | self.norm_feature = nn.LayerNorm(input_dim, elementwise_affine=False) 35 | 36 | # 更新,2022-4-4,更新mask token的初始化方式 37 | if mask_token_initialization == 0: 38 | self.mask_token = nn.Parameter(torch.zeros(1, output_dim, 1), requires_grad=True) 39 | elif mask_token_initialization == 1: 40 | self.mask_token = nn.Parameter(torch.ones(1, output_dim, 1), requires_grad=True) 41 | elif mask_token_initialization == "uniform": 42 | self.mask_token = nn.Parameter(torch.zeros(1, output_dim, 1), requires_grad=True) 43 | self.mask_token.data.uniform_(-1.0 / output_dim, 1.0 / output_dim) 44 | elif mask_token_initialization == "random": 45 | self.mask_token = nn.Parameter(torch.randn(1, output_dim, 1), requires_grad=True) 46 | else: 47 | raise NotImplementedError() 48 | 49 | self.mask = torch.from_numpy(np.zeros(self.input_token_num)).float() 50 | 51 | self.decode_dim = output_dim 52 | self.post_projection = torch.nn.Conv2d(self.z_dim, self.decode_dim, 1) 53 | self.pre_projection = torch.nn.Linear(input_dim, self.z_dim) 54 | 55 | def preforward(self, image_features): 56 | image_features = 
rearrange(image_features, "B C H W -> B (H W) C") 57 | batch_size, length, channel = image_features.size() 58 | 59 | # random sample importance score 60 | pred_score = torch.randn(batch_size, length).to(image_features.device) 61 | 62 | sort_score, sort_order = pred_score.sort(descending=True,dim=1) 63 | sort_topk = sort_order[:, :self.topk_num] 64 | sort_topk_remaining = sort_order[:, self.topk_num:] 65 | ## flatten for gathering 66 | if self.apply_norm_image_features: 67 | image_features = self.norm_feature(image_features) 68 | 69 | ## (only) sampled features 70 | image_features_sampled = image_features.gather(1, sort_topk[...,None].expand(-1, -1, channel)) 71 | image_features = rearrange(self.pre_projection(image_features_sampled), "B N C -> B C N") 72 | 73 | # get mask 74 | self.mask = self.mask.to(image_features.device) 75 | for i in range(batch_size): 76 | if i == 0: 77 | mask = self.mask.scatter(-1, sort_topk[i], 1.).view(self.hw, self.hw).unsqueeze(0) 78 | else: 79 | mask_i = self.mask.scatter(-1, sort_topk[i], 1.).view(self.hw, self.hw).unsqueeze(0) 80 | mask = torch.cat([mask, mask_i], dim=0) 81 | squeezed_mask = mask.view(batch_size, -1) # [batch_size, length] 82 | mask = F.interpolate(mask.float().unsqueeze(1), scale_factor=self.patch_size, mode="nearest") 83 | 84 | # normalize the score just for better visualization 85 | normed_score = pred_score.sub(pred_score.min()).div(max(pred_score.max() - pred_score.min(), 1e-5)).unsqueeze(-1) 86 | normed_score = F.interpolate(rearrange(normed_score, "b (h w) c -> b c h w", h=self.hw, w=self.hw), scale_factor=self.patch_size, mode="nearest") 87 | 88 | return_dict = { 89 | "sample_h": image_features, 90 | "sample_index": sort_topk, 91 | "remain_index": sort_topk_remaining, 92 | "binary_map": mask, 93 | "score_map": normed_score, 94 | "squeezed_mask": squeezed_mask, 95 | "sort_score": sort_score[:, :self.topk_num], 96 | } 97 | 98 | return return_dict 99 | 100 | def postforward(self, sample_h, sample_index, remain_index=None): 101 | batch_size = sample_h.size(0) 102 | sample_index = sample_index.cpu().numpy().tolist() 103 | if remain_index == None: 104 | remain_index = [[j for j in range(self.input_token_num) if j not in sample_index[i]] for i in range(batch_size)] 105 | else: 106 | remain_index = remain_index.cpu().numpy().tolist() 107 | 108 | decoder_embeeding = nn.Parameter(torch.zeros((batch_size, self.decode_dim, self.input_token_num))).to(sample_h.device) 109 | 110 | for i in range(batch_size): 111 | decoder_embeeding[i, :, sample_index[i]] = sample_h[i] 112 | decoder_embeeding[i, :, remain_index[i]] = self.mask_token 113 | 114 | decoder_embeeding = rearrange(decoder_embeeding, "B C (H W) -> B C H W", H=self.hw, W=self.hw) 115 | decoder_embeeding = self.post_projection(decoder_embeeding) 116 | 117 | return { 118 | "decoder_embeeding": decoder_embeeding, 119 | } -------------------------------------------------------------------------------- /modules/masked_quantization/masker_vanilla.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from einops import rearrange 4 | import numpy as np 5 | import torch.nn.functional as F 6 | 7 | # first predict scores, then norm features 8 | class VanillaMasker(nn.Module): 9 | def __init__(self, 10 | topk_ratio, 11 | input_token_num, 12 | input_dim, 13 | patch_size, 14 | score_pred_net_mode = "2layer-fc", 15 | codebook_dim = 32, 16 | ): 17 | super().__init__() 18 | self.input_token_num = input_token_num 19 | self.sample_num 
= int(topk_ratio * input_token_num) 20 | self.unsampled_num = input_token_num - self.sample_num 21 | self.topk_num = int(topk_ratio * input_token_num) 22 | self.patch_size = patch_size 23 | self.hw = int(input_token_num**0.5) 24 | if score_pred_net_mode == "2layer-fc": 25 | self.score_pred_net = nn.Sequential(nn.Linear(input_dim, input_dim), 26 | nn.ReLU(), 27 | nn.Linear(input_dim, 1), 28 | nn.Sigmoid()) 29 | else: 30 | raise ValueError 31 | self.norm_feature = nn.LayerNorm(input_dim, elementwise_affine=False) 32 | 33 | self.mask = torch.from_numpy(np.zeros(self.input_token_num)).float() 34 | self.pre_projection = torch.nn.Linear(input_dim, codebook_dim, bias=False) 35 | 36 | 37 | def forward(self, image_features): 38 | image_features = rearrange(image_features, "B C H W -> B (H W) C") 39 | batch_size, length, channel = image_features.size() 40 | 41 | pred_score = self.score_pred_net(image_features).view(batch_size, -1) 42 | pred_score_clone = pred_score.clone().detach() 43 | 44 | sort_score, sort_order = pred_score_clone.sort(descending=True,dim=1) 45 | sort_topk = sort_order[:, :self.topk_num] 46 | sort_topk_remain = sort_order[:, self.topk_num:] 47 | ## flatten for gathering 48 | image_features = self.norm_feature(image_features) 49 | 50 | ## (only) sampled features multiply with score 51 | image_features_sampled = image_features.gather(1, sort_topk[...,None].expand(-1, -1, channel)) * pred_score.gather(1, sort_topk).unsqueeze(-1) 52 | image_features_sampled = rearrange(self.pre_projection(image_features_sampled), "B N C -> B C N") 53 | 54 | # get mask 55 | self.mask = self.mask.to(image_features_sampled.device) 56 | for i in range(batch_size): 57 | if i == 0: 58 | mask = self.mask.scatter(-1, sort_topk[i], 1.).view(self.hw, self.hw).unsqueeze(0) 59 | else: 60 | mask_i = self.mask.scatter(-1, sort_topk[i], 1.).view(self.hw, self.hw).unsqueeze(0) 61 | mask = torch.cat([mask, mask_i], dim=0) 62 | squeezed_mask = mask.view(batch_size, -1) # [batch_size, length] 63 | mask = F.interpolate(mask.float().unsqueeze(1), scale_factor=self.patch_size, mode="nearest") 64 | 65 | # normalize the score just for better visualization 66 | normed_score = pred_score_clone.sub(pred_score_clone.min()).div(max(pred_score_clone.max() - pred_score_clone.min(), 1e-5)).unsqueeze(-1) 67 | normed_score = F.interpolate(rearrange(normed_score, "b (h w) c -> b c h w", h=self.hw, w=self.hw), scale_factor=self.patch_size, mode="nearest") 68 | 69 | return_dict = { 70 | "sample_features": image_features_sampled, 71 | "remain_features": None, 72 | "sample_index": sort_topk, 73 | "remain_index": sort_topk_remain, 74 | "binary_map": mask, 75 | "score_map": normed_score, 76 | "squeezed_mask": squeezed_mask, 77 | "sort_score": sort_score[:, :self.topk_num], 78 | "sampled_length": image_features_sampled.size(-1), 79 | } 80 | 81 | return return_dict -------------------------------------------------------------------------------- /modules/masked_quantization/masker_vanilla_refine.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from einops import rearrange 4 | import numpy as np 5 | import torch.nn.functional as F 6 | 7 | # first predict scores, then norm features 8 | # replace score net with a conv one 9 | class VanillaMasker(nn.Module): 10 | def __init__(self, 11 | topk_ratio, 12 | input_token_num, 13 | input_dim, 14 | patch_size, 15 | score_pred_net_mode = "2layer", 16 | codebook_dim = 32, 17 | ): 18 | super().__init__() 19 | 
self.input_token_num = input_token_num 20 | self.sample_num = int(topk_ratio * input_token_num) 21 | self.unsampled_num = input_token_num - self.sample_num 22 | self.topk_num = int(topk_ratio * input_token_num) 23 | self.patch_size = patch_size 24 | self.hw = int(input_token_num**0.5) 25 | if score_pred_net_mode == "2layer": 26 | self.score_pred_net = nn.Sequential( 27 | nn.LayerNorm([input_dim, self.hw, self.hw], elementwise_affine=False), 28 | nn.Conv2d(input_dim, input_dim, kernel_size=3, stride=1, padding=1), 29 | nn.ReLU(), 30 | nn.Conv2d(input_dim, 1, kernel_size=1, stride=1, padding=0), 31 | nn.BatchNorm2d(1), 32 | nn.Sigmoid(), 33 | ) 34 | else: 35 | raise ValueError 36 | self.norm_feature = nn.LayerNorm(input_dim, elementwise_affine=False) 37 | 38 | self.mask = torch.from_numpy(np.zeros(self.input_token_num)).float() 39 | self.pre_projection = torch.nn.Linear(input_dim, codebook_dim, bias=False) 40 | 41 | 42 | def forward(self, image_features): 43 | batch_size, channel, height, width = image_features.size() 44 | pred_score = self.score_pred_net(image_features).view(batch_size, -1) 45 | pred_score_clone = pred_score.clone().detach() 46 | 47 | sort_score, sort_order = pred_score_clone.sort(descending=True,dim=1) 48 | sort_topk = sort_order[:, :self.topk_num] 49 | sort_topk_remain = sort_order[:, self.topk_num:] 50 | ## flatten for gathering 51 | image_features = rearrange(image_features, "B C H W -> B (H W) C") 52 | image_features = self.norm_feature(image_features) 53 | 54 | ## (only) sampled features multiply with score 55 | image_features_sampled = image_features.gather( 56 | 1, sort_topk[...,None].expand(-1, -1, channel)) * pred_score.gather(1, sort_topk).unsqueeze(-1) 57 | image_features_sampled = rearrange(self.pre_projection(image_features_sampled), "B N C -> B C N") 58 | 59 | # get mask 60 | self.mask = self.mask.to(image_features_sampled.device) 61 | for i in range(batch_size): 62 | if i == 0: 63 | mask = self.mask.scatter(-1, sort_topk[i], 1.).view(self.hw, self.hw).unsqueeze(0) 64 | else: 65 | mask_i = self.mask.scatter(-1, sort_topk[i], 1.).view(self.hw, self.hw).unsqueeze(0) 66 | mask = torch.cat([mask, mask_i], dim=0) 67 | squeezed_mask = mask.view(batch_size, -1) # [batch_size, length] 68 | mask = F.interpolate(mask.float().unsqueeze(1), scale_factor=self.patch_size, mode="nearest") 69 | 70 | # normalize the score just for better visualization 71 | normed_score = pred_score_clone.sub(pred_score_clone.min()).div(max(pred_score_clone.max() - pred_score_clone.min(), 1e-5)).unsqueeze(-1) 72 | normed_score = F.interpolate(rearrange(normed_score, "b (h w) c -> b c h w", h=self.hw, w=self.hw), scale_factor=self.patch_size, mode="nearest") 73 | 74 | return_dict = { 75 | "sample_features": image_features_sampled, 76 | "remain_features": None, 77 | "sample_index": sort_topk, 78 | "remain_index": sort_topk_remain, 79 | "binary_map": mask, 80 | "score_map": normed_score, 81 | "squeezed_mask": squeezed_mask, 82 | "sort_score": sort_score[:, :self.topk_num], 83 | "sampled_length": image_features_sampled.size(-1), 84 | } 85 | 86 | return return_dict 87 | 88 | 89 | if __name__ == "__main__": 90 | image_features = torch.randn(10, 256, 32, 32) 91 | masker = VanillaMasker( 92 | topk_ratio = 0.25, 93 | input_token_num = 1024, 94 | input_dim = 256, 95 | patch_size = 32, 96 | score_pred_net_mode = "2layer", 97 | codebook_dim = 256, 98 | ) 99 | masker(image_features) -------------------------------------------------------------------------------- /modules/masked_quantization/permuter.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | # return value_indices and position_indices into raster-scan 5 | class raster_scan_permuter(nn.Module): 6 | def forward(self, indices, position_indices, **ignore_kwargs): 7 | sorted_position_indices, sorted_order = torch.sort(position_indices, descending=False, dim=-1) 8 | sorted_indices = indices.gather(1, sorted_order) 9 | return sorted_indices, sorted_position_indices 10 | 11 | class identity_permuter(nn.Module): 12 | def forward(self, indices, position_indices, **ignore_kwargs): 13 | return indices, position_indices -------------------------------------------------------------------------------- /modules/masked_quantization/tools.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from einops import rearrange 4 | from PIL import Image 5 | import torchvision 6 | import torchvision.transforms as transforms 7 | 8 | transform_PIL = transforms.Compose([ 9 | transforms.ToPILImage(), 10 | ]) 11 | 12 | color_dict = { 13 | "red": (255, 0, 0), 14 | "green": (0, 255, 0), 15 | "white": (255, 255, 255), 16 | "yellow": (255, 255, 0), 17 | "blue": (5, 39, 175), 18 | } 19 | 20 | # 颜色列表 21 | # 红,绿,黄,蓝 22 | # 紫,青, 23 | # 灰色,冷灰色,石板灰,暖灰色 24 | # 香蕉色,镉黄,dougello,forum gold 25 | # 金黄色,黄花色 26 | color_list = [ 27 | (255, 0, 0), (0, 255, 0), (255, 255, 0), (5, 39, 175), 28 | (255,0,255), (0,255,255), 29 | (192,192,192), (128,138,135), (112,128,105), (128,128,105), 30 | (227,207,87), (255,153,18), (235,142,85), (255,227,132), 31 | (255,215,0), (218,165,105) 32 | ] 33 | # https://blog.csdn.net/pinaby/article/details/2823366 34 | 35 | 36 | # same function in torchvision.utils.save_image(normalize=True) 37 | def image_normalize(tensor, value_range=None, scale_each=False): 38 | tensor = tensor.clone() 39 | def norm_ip(img, low, high): 40 | img.clamp_(min=low, max=high) 41 | img.sub_(low).div_(max(high - low, 1e-5)) 42 | 43 | def norm_range(t, value_range): 44 | if value_range is not None: 45 | norm_ip(t, value_range[0], value_range[1]) 46 | else: 47 | norm_ip(t, float(t.min()), float(t.max())) 48 | 49 | if scale_each is True: 50 | for t in tensor: # loop over mini-batch dimension 51 | norm_range(t, value_range) 52 | else: 53 | norm_range(tensor, value_range) 54 | 55 | return tensor 56 | 57 | def transform_invert(img_, transform_train): 58 | """ 59 | 将data 进行反transfrom操作 60 | :param img_: tensor 61 | :param transform_train: torchvision.transforms 62 | :return: PIL image 63 | """ 64 | if 'Normalize' in str(transform_train): 65 | norm_transform = list(filter(lambda x: isinstance(x, transforms.Normalize), transform_train.transforms)) 66 | mean = torch.tensor(norm_transform[0].mean, dtype=img_.dtype, device=img_.device) 67 | std = torch.tensor(norm_transform[0].std, dtype=img_.dtype, device=img_.device) 68 | img_.mul_(std[:, None, None]).add_(mean[:, None, None]) 69 | 70 | img_ = img_.transpose(0, 2).transpose(0, 1) # C*H*W --> H*W*C 71 | img_ = np.array(img_) * 255 72 | 73 | if img_.shape[2] == 3: 74 | img_ = Image.fromarray(img_.astype('uint8')).convert('RGB') 75 | elif img_.shape[2] == 1: 76 | img_ = Image.fromarray(img_.astype('uint8').squeeze()) 77 | else: 78 | raise Exception("Invalid img shape, expected 1 or 3 in axis 2, but got {}!".format(img_.shape[2]) ) 79 | 80 | return img_ 81 | 82 | # images, score_map: tensor 83 | def build_score_image(images, score_map, low_color="blue", high_color="red", scaler=0.9): 
84 | bs = images.size(0) 85 | 86 | low = Image.new('RGB', (images.size(-2), images.size(-1)), color_dict[low_color]) 87 | high = Image.new('RGB', (images.size(-2), images.size(-1)), color_dict[high_color]) 88 | 89 | for i in range(bs): 90 | image_i_pil = transform_PIL(image_normalize(images[i])) 91 | 92 | score_map_i_np = rearrange(score_map[i], "C H W -> H W C").cpu().detach().numpy() 93 | score_map_i_blend = Image.fromarray( 94 | np.uint8(high * score_map_i_np + low * (1 - score_map_i_np))) 95 | 96 | image_i_blend = Image.blend(image_i_pil, score_map_i_blend, scaler) 97 | 98 | if i == 0: 99 | blended_images = torchvision.transforms.functional.to_tensor(image_i_blend).unsqueeze(0) 100 | else: 101 | blended_images = torch.cat([ 102 | blended_images, torchvision.transforms.functional.to_tensor(image_i_blend).unsqueeze(0) 103 | ], dim=0) 104 | return blended_images 105 | 106 | 107 | if __name__ == "__main__": 108 | images = torch.zeros(1, 3, 256, 256) 109 | score_map = torch.tensor([[[ 110 | [0.20, 0.07, 0.64, 0.09], 111 | [0.14, 0.12, 0.32, 0.02], 112 | [0.22, 0.97, 0.07, 0.07], 113 | [0.32, 0.37, 0.12, 0.53] 114 | ]]]).repeat_interleave(64, -1).repeat_interleave(64, -2) 115 | out = build_score_image(images, score_map, low_color="blue", high_color="red", scaler=0.9) 116 | torchvision.utils.save_image(out, "out.png") -------------------------------------------------------------------------------- /modules/masked_quantization_stage2/permuter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | # return value_indices and position_indices into raster-scan 5 | class raster_scan_permuter(nn.Module): 6 | def forward(self, indices, position_indices, **ignore_kwargs): 7 | sorted_position_indices, sorted_order = torch.sort(position_indices, descending=False, dim=-1) 8 | sorted_indices = indices.gather(1, sorted_order) 9 | return sorted_indices, sorted_position_indices 10 | 11 | class identity_permuter(nn.Module): 12 | def forward(self, indices, position_indices, **ignore_kwargs): 13 | return indices, position_indices 14 | 15 | if __name__ == "__main__": 16 | pass -------------------------------------------------------------------------------- /modules/scheduler/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class LinearWarmUpScheduler: 5 | def __init__(self, ): 6 | pass 7 | 8 | 9 | 10 | class LambdaWarmUpCosineScheduler: 11 | """ 12 | note: use with a base_lr of 1.0 13 | """ 14 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): 15 | self.lr_warm_up_steps = warm_up_steps 16 | self.lr_start = lr_start 17 | self.lr_min = lr_min 18 | self.lr_max = lr_max 19 | self.lr_max_decay_steps = max_decay_steps 20 | self.last_lr = 0. 
21 | self.verbosity_interval = verbosity_interval 22 | 23 | def schedule(self, n): 24 | if self.verbosity_interval > 0: 25 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") 26 | if n < self.lr_warm_up_steps: 27 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start 28 | self.last_lr = lr 29 | return lr 30 | else: 31 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) 32 | t = min(t, 1.0) 33 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( 34 | 1 + np.cos(t * np.pi)) 35 | self.last_lr = lr 36 | return lr 37 | 38 | def __call__(self, n): 39 | return self.schedule(n) 40 | 41 | -------------------------------------------------------------------------------- /modules/scheduler/scheduler.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | from torch.optim.lr_scheduler import CosineAnnealingLR 4 | 5 | 6 | def create_scheduler(optimizer, scheduler_config, steps_per_epoch, max_epoch): 7 | 8 | multiplier = scheduler_config.multiplier 9 | warmup_steps = scheduler_config.warmip_epoch * steps_per_epoch 10 | buffer_steps = scheduler_config.buffer_epoch * steps_per_epoch 11 | final_steps = max_epoch * steps_per_epoch 12 | min_lr = scheduler_config.min_lr 13 | warmup_policy = scheduler_config.warmup_policy 14 | start_from_zero = scheduler_config.start_from_zero 15 | 16 | scheduler = CosineAnnealingLR( 17 | optimizer, T_max=final_steps - warmup_steps - buffer_steps, eta_min=min_lr 18 | ) 19 | 20 | if warmup_steps > 0.0: 21 | if warmup_policy == 'linear': 22 | multiplier = max(1.0, multiplier) 23 | elif warmup_policy == 'sqrt': 24 | multiplier = max(1.0, multiplier) 25 | elif warmup_policy == 'fix': 26 | multiplier = max(1.0, multiplier) 27 | elif warmup_policy == 'none': 28 | pass 29 | else: 30 | raise NotImplementedError(f'{warmup_policy} is not a valid warmup policy') 31 | warmup = GradualWarmup( 32 | optimizer, 33 | steps=warmup_steps, 34 | buffer_steps=buffer_steps, 35 | multiplier=multiplier, 36 | start_from_zero=start_from_zero 37 | ) 38 | else: 39 | warmup = None 40 | 41 | scheduler = Scheduler(warmup_scheduler=warmup, after_scheduler=scheduler) 42 | 43 | return scheduler 44 | 45 | 46 | class GradualWarmup(torch.optim.lr_scheduler._LRScheduler): 47 | def __init__(self, optimizer, steps, buffer_steps, multiplier, start_from_zero=True, last_epoch=-1): 48 | self.steps = steps 49 | self.t_steps = steps + buffer_steps 50 | self.multiplier = multiplier 51 | self.start_from_zero = start_from_zero 52 | 53 | super(GradualWarmup, self).__init__(optimizer, last_epoch) 54 | 55 | def get_lr(self): 56 | if self.last_epoch > self.steps: 57 | return [group['lr'] for group in self.optimizer.param_groups] 58 | 59 | if self.start_from_zero: 60 | multiplier = self.multiplier * min(1.0, (self.last_epoch / self.steps)) 61 | else: 62 | multiplier = 1 + ((self.multiplier - 1) * min(1.0, (self.last_epoch / self.steps))) 63 | return [lr * multiplier for lr in self.base_lrs] 64 | 65 | 66 | class Scheduler: 67 | def __init__(self, warmup_scheduler, after_scheduler): 68 | self.warmup_scheduler = warmup_scheduler 69 | self.after_scheduler = after_scheduler 70 | 71 | def step(self, epoch=None): 72 | if self.warmup_scheduler is not None: 73 | self.warmup_scheduler.step(epoch=epoch) 74 | 75 | if self.warmup_scheduler is None or \ 76 | self.warmup_scheduler.last_epoch > self.warmup_scheduler.t_steps: 77 | 
self.after_scheduler.step(epoch=epoch) 78 | 79 | def get_last_lr(self): 80 | if self.warmup_scheduler is not None and \ 81 | self.warmup_scheduler.last_epoch <= self.warmup_scheduler.t_steps: 82 | print("self.warmup_scheduler.get_last_lr(): ", self.warmup_scheduler.get_last_lr()) 83 | return self.warmup_scheduler.get_last_lr() 84 | else: 85 | return self.after_scheduler.get_last_lr() 86 | 87 | def state_dict(self): 88 | return { 89 | 'warmup': None if self.warmup_scheduler is None else self.warmup_scheduler.state_dict(), 90 | 'after': self.after_scheduler.state_dict() 91 | } 92 | 93 | def load_state_dict(self, state_dict): 94 | if self.warmup_scheduler is not None: 95 | self.warmup_scheduler.load_state_dict(state_dict['warmup']) 96 | self.after_scheduler.load_state_dict(state_dict['after']) -------------------------------------------------------------------------------- /modules/text_encoders/clip_text_encoder/base_embedding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class BaseEmbedding(nn.Module): 6 | 7 | def get_loss(self): 8 | return None 9 | 10 | def forward(self, **kwargs): 11 | raise NotImplementedError 12 | 13 | def train(self, mode=True): 14 | self.training = mode 15 | if self.trainable and mode: 16 | super().train() 17 | return self 18 | 19 | def _set_trainable(self): 20 | if not self.trainable: 21 | for pn, p in self.named_parameters(): 22 | p.requires_grad = False 23 | self.eval() 24 | 25 | -------------------------------------------------------------------------------- /modules/text_encoders/clip_text_encoder/clip/README.md: -------------------------------------------------------------------------------- 1 | https://github.com/openai/CLIP -------------------------------------------------------------------------------- /modules/text_encoders/clip_text_encoder/clip/clip_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | """ 18 | Returns list of utf-8 byte and a corresponding list of unicode strings. 19 | The reversible bpe codes work on unicode strings. 20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 22 | This is a signficant percentage of your normal, say, 32K bpe vocab. 23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 24 | And avoids mapping to whitespace/control characters the bpe code barfs on. 25 | """ 26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 27 | cs = bs[:] 28 | n = 0 29 | for b in range(2**8): 30 | if b not in bs: 31 | bs.append(b) 32 | cs.append(2**8+n) 33 | n += 1 34 | cs = [chr(n) for n in cs] 35 | return dict(zip(bs, cs)) 36 | 37 | 38 | def get_pairs(word): 39 | """Return set of symbol pairs in a word. 40 | Word is represented as tuple of symbols (symbols being variable-length strings). 
41 | """ 42 | pairs = set() 43 | prev_char = word[0] 44 | for char in word[1:]: 45 | pairs.add((prev_char, char)) 46 | prev_char = char 47 | return pairs 48 | 49 | 50 | def basic_clean(text): 51 | text = ftfy.fix_text(text) 52 | text = html.unescape(html.unescape(text)) 53 | return text.strip() 54 | 55 | 56 | def whitespace_clean(text): 57 | text = re.sub(r'\s+', ' ', text) 58 | text = text.strip() 59 | return text 60 | 61 | 62 | class SimpleTokenizer(object): 63 | def __init__(self, end_idx=49152, bpe_path: str = default_bpe()): 64 | self.byte_encoder = bytes_to_unicode() 65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 67 | # merges = merges[1:49152-256-2+1] 68 | 69 | # end_idx can be 49152 for CLIP 70 | # or 16384 for DALL-E 71 | merges = merges[1:end_idx-256-2+1] 72 | merges = [tuple(merge.split()) for merge in merges] 73 | vocab = list(bytes_to_unicode().values()) 74 | vocab = vocab + [v+'' for v in vocab] # with length 256 75 | for merge in merges: 76 | vocab.append(''.join(merge)) 77 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) # with length = end_idx+256 78 | self.encoder = dict(zip(vocab, range(len(vocab)))) 79 | self.decoder = {v: k for k, v in self.encoder.items()} 80 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 81 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 82 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 83 | 84 | def bpe(self, token): 85 | if token in self.cache: 86 | return self.cache[token] 87 | word = tuple(token[:-1]) + (token[-1] + '',) 88 | pairs = get_pairs(word) 89 | 90 | if not pairs: 91 | return token + '' 92 | 93 | while True: 94 | bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf'))) 95 | if bigram not in self.bpe_ranks: 96 | break 97 | first, second = bigram 98 | new_word = [] 99 | i = 0 100 | while i < len(word): 101 | try: 102 | j = word.index(first, i) 103 | new_word.extend(word[i:j]) 104 | i = j 105 | except: 106 | new_word.extend(word[i:]) 107 | break 108 | 109 | if word[i] == first and i < len(word) - 1 and word[i+1] == second: 110 | new_word.append(first+second) 111 | i += 2 112 | else: 113 | new_word.append(word[i]) 114 | i += 1 115 | new_word = tuple(new_word) 116 | word = new_word 117 | if len(word) == 1: 118 | break 119 | else: 120 | pairs = get_pairs(word) 121 | word = ' '.join(word) 122 | self.cache[token] = word 123 | return word 124 | 125 | def encode(self, text): 126 | bpe_tokens = [] 127 | text = whitespace_clean(basic_clean(text)).lower() 128 | for token in re.findall(self.pat, text): 129 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 130 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 131 | return bpe_tokens 132 | 133 | def decode(self, tokens): 134 | text = ''.join([self.decoder[token] for token in tokens]) 135 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 136 | return text 137 | -------------------------------------------------------------------------------- /modules/text_encoders/clip_text_encoder/clip/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 
11 | def default_bpe(): 12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | """ 18 | Returns list of utf-8 byte and a corresponding list of unicode strings. 19 | The reversible bpe codes work on unicode strings. 20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 22 | This is a signficant percentage of your normal, say, 32K bpe vocab. 23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 24 | And avoids mapping to whitespace/control characters the bpe code barfs on. 25 | """ 26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 27 | cs = bs[:] 28 | n = 0 29 | for b in range(2**8): 30 | if b not in bs: 31 | bs.append(b) 32 | cs.append(2**8+n) 33 | n += 1 34 | cs = [chr(n) for n in cs] 35 | return dict(zip(bs, cs)) 36 | 37 | 38 | def get_pairs(word): 39 | """Return set of symbol pairs in a word. 40 | Word is represented as tuple of symbols (symbols being variable-length strings). 41 | """ 42 | pairs = set() 43 | prev_char = word[0] 44 | for char in word[1:]: 45 | pairs.add((prev_char, char)) 46 | prev_char = char 47 | return pairs 48 | 49 | 50 | def basic_clean(text): 51 | text = ftfy.fix_text(text) 52 | text = html.unescape(html.unescape(text)) 53 | return text.strip() 54 | 55 | 56 | def whitespace_clean(text): 57 | text = re.sub(r'\s+', ' ', text) 58 | text = text.strip() 59 | return text 60 | 61 | 62 | class SimpleTokenizer(object): 63 | def __init__(self, end_idx=49152, bpe_path: str = default_bpe()): 64 | self.byte_encoder = bytes_to_unicode() 65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 67 | # merges = merges[1:49152-256-2+1] 68 | 69 | # end_idx can be 49152 for CLIP 70 | # or 16384 for DALL-E 71 | merges = merges[1:end_idx-256-2+1] 72 | merges = [tuple(merge.split()) for merge in merges] 73 | vocab = list(bytes_to_unicode().values()) 74 | vocab = vocab + [v+'' for v in vocab] # with length 256 75 | for merge in merges: 76 | vocab.append(''.join(merge)) 77 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) # with length = end_idx+256 78 | self.encoder = dict(zip(vocab, range(len(vocab)))) 79 | self.decoder = {v: k for k, v in self.encoder.items()} 80 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 81 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 82 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 83 | 84 | def bpe(self, token): 85 | if token in self.cache: 86 | return self.cache[token] 87 | word = tuple(token[:-1]) + (token[-1] + '',) 88 | pairs = get_pairs(word) 89 | 90 | if not pairs: 91 | return token + '' 92 | 93 | while True: 94 | bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf'))) 95 | if bigram not in self.bpe_ranks: 96 | break 97 | first, second = bigram 98 | new_word = [] 99 | i = 0 100 | while i < len(word): 101 | try: 102 | j = word.index(first, i) 103 | new_word.extend(word[i:j]) 104 | i = j 105 | except: 106 | new_word.extend(word[i:]) 107 | break 108 | 109 | if word[i] == first and i < len(word) - 1 and word[i+1] == second: 110 | new_word.append(first+second) 111 | i += 2 
112 | else: 113 | new_word.append(word[i]) 114 | i += 1 115 | new_word = tuple(new_word) 116 | word = new_word 117 | if len(word) == 1: 118 | break 119 | else: 120 | pairs = get_pairs(word) 121 | word = ' '.join(word) 122 | self.cache[token] = word 123 | return word 124 | 125 | def encode(self, text): 126 | bpe_tokens = [] 127 | text = whitespace_clean(basic_clean(text)).lower() 128 | for token in re.findall(self.pat, text): 129 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 130 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 131 | return bpe_tokens 132 | 133 | def decode(self, tokens): 134 | text = ''.join([self.decoder[token] for token in tokens]) 135 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 136 | return text 137 | -------------------------------------------------------------------------------- /modules/text_encoders/clip_text_encoder/clip_text_embedding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import os, sys 4 | sys.path.append(os.getcwd()) 5 | from modules.clip_text_encoder.clip import clip 6 | from modules.clip_text_encoder.clip import model as clip_model 7 | from modules.clip_text_encoder.base_embedding import BaseEmbedding 8 | 9 | class CLIPTextEmbedding(BaseEmbedding): 10 | def __init__(self, 11 | clip_name='ViT-B/32', 12 | num_embed=49408, 13 | normalize=True, 14 | pick_last_embedding=True, 15 | keep_seq_len_dim=False, 16 | additional_last_embedding=False, 17 | embed_dim=512, 18 | ): 19 | super().__init__() 20 | self.num_embed = num_embed 21 | self.clip_name = clip_name 22 | self.normalize = normalize 23 | self.pick_last_embedding = pick_last_embedding 24 | self.keep_seq_len_dim = keep_seq_len_dim 25 | self.additional_last_embedding = additional_last_embedding 26 | 27 | model, _ = clip.load(clip_name, device='cpu',jit=False) 28 | model = clip_model.build_model(model.state_dict()) 29 | 30 | self.token_embedding = model.token_embedding 31 | self.positional_embedding = model.positional_embedding 32 | self.transformer = model.transformer 33 | self.ln_final = model.ln_final 34 | self.text_projection = model.text_projection 35 | 36 | if embed_dim == 1024: 37 | self.embed_dim = self.text_projection.shape[1]*2 # to fit 1024 dimension of image embedding 38 | else: 39 | self.embed_dim = self.text_projection.shape[1] # original output, 512 dim 40 | 41 | self.trainable = False 42 | self._set_trainable() 43 | 44 | @property 45 | def dtype(self): 46 | return self.transformer.resblocks[0].attn.in_proj_weight.dtype 47 | 48 | def encode_text(self, text): 49 | text[text < 0] = 0 # some padded text token maybe negative, so set them to 0 50 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] 51 | 52 | x = x + self.positional_embedding.type(self.dtype) 53 | x = x.permute(1, 0, 2) # NLD -> LND 54 | x = self.transformer(x) 55 | x = x.permute(1, 0, 2) # LND -> NLD 56 | x = self.ln_final(x).type(self.dtype) 57 | 58 | # x.shape = [batch_size, n_ctx, transformer.width] 59 | if self.pick_last_embedding: 60 | # take features from the eot embedding (eot_token is the highest number in each sequence) 61 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection # [batch_size, transformer.width] 62 | if self.keep_seq_len_dim: 63 | x = x.unsqueeze(dim=1) # [batch_size, 1, transformer.width] 64 | return x 65 | 66 | 67 | 68 | def forward(self, index, 
**kwargs): 69 | """ 70 | index: B x L, index 71 | mask: B x L, bool type. The value of False indicating padded index 72 | """ 73 | assert index.dim() == 2 # B x L 74 | text_feature = self.encode_text(index) 75 | 76 | if self.embed_dim == 1024: 77 | text_features = torch.cat((text_feature, text_feature), dim=2) 78 | else: 79 | text_features = text_feature 80 | if self.normalize: 81 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 82 | 83 | if self.additional_last_embedding == True: 84 | last_feature = text_feature[torch.arange(text_feature.shape[0]), index.argmax(dim=-1)] @ self.text_projection 85 | if self.keep_seq_len_dim: 86 | last_feature = last_feature.unsqueeze(dim=1) 87 | return text_features, last_feature 88 | 89 | 90 | return text_features 91 | 92 | 93 | if __name__ == "__main__": 94 | model = CLIPTextEmbedding() -------------------------------------------------------------------------------- /modules/text_encoders/clip_text_encoder/my_tokenizer/base_codec.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class BaseCodec(nn.Module): 6 | 7 | def get_tokens(self, x, **kwargs): 8 | """ 9 | Input: 10 | x: input data 11 | Return: 12 | indices: B x L, the codebook indices, where L is the length 13 | of flattened feature map size 14 | """ 15 | raise NotImplementedError 16 | 17 | def get_number_of_tokens(self): 18 | """ 19 | Return: int, the number of tokens 20 | """ 21 | raise NotImplementedError 22 | 23 | def encode(self, img): 24 | raise NotImplementedError 25 | 26 | def decode(self, img_seq): 27 | raise NotImplementedError 28 | 29 | def forward(self, **kwargs): 30 | raise NotImplementedError 31 | 32 | def train(self, mode=True): 33 | self.training = mode 34 | if self.trainable and mode: 35 | return super().train(True) 36 | else: 37 | return super().train(False) 38 | 39 | def _set_trainable(self): 40 | if not self.trainable: 41 | for pn, p in self.named_parameters(): 42 | p.requires_grad = False 43 | self.eval() -------------------------------------------------------------------------------- /modules/text_encoders/clip_text_encoder/my_tokenizer/my_tokenize.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os, sys 3 | sys.path.append(os.getcwd()) 4 | from modules.clip_text_encoder.clip.clip import tokenize 5 | from modules.clip_text_encoder.my_tokenizer.base_codec import BaseCodec 6 | from utils.utils import instantiate_from_config 7 | 8 | class Tokenize(BaseCodec): 9 | def __init__(self, 10 | context_length:int = 77, 11 | add_start_and_end:bool = True, 12 | just_token = False, 13 | with_mask:bool = True, 14 | pad_value:int = 0, 15 | clip_embedding = False, 16 | condition_emb_config = None, 17 | tokenizer_config={ 18 | 'target': 'modules.clip_text_encoder.clip.simple_tokenizer.SimpleTokenizer', 19 | 'params':{ 20 | 'end_idx': 49152 # 16384 fo DALL-E 21 | }, 22 | }, 23 | ): 24 | """ 25 | This is a wrapper class for tokenize of texts. 
26 | For CLIP and DALLE-pytorch tokenize, the default 27 | arguments are different: 28 | 29 | CLIP based: 30 | context_length: 77 31 | add_start_and_end: True 32 | 33 | DALLE-pytorch based: 34 | context_length: 256 35 | add_start_and_end: False 36 | 37 | """ 38 | super().__init__() 39 | self.context_length = context_length 40 | self.add_start_and_end = add_start_and_end 41 | self.with_mask = with_mask 42 | self.pad_value = pad_value 43 | self.just_token = just_token 44 | self.trainable = False 45 | self.condition_emb = None 46 | self.clip_embedding = clip_embedding 47 | if self.clip_embedding == True: 48 | assert condition_emb_config != None 49 | self.condition_emb = instantiate_from_config(condition_emb_config) 50 | 51 | self.tokenizer = instantiate_from_config(tokenizer_config) 52 | 53 | def __repr__(self): 54 | rep = "Tokenize for text\n\tcontent_length: {}\n\tadd_start_and_end: {}\n\twith_mask: {}"\ 55 | .format(self.context_length, self.add_start_and_end, self.with_mask) 56 | return rep 57 | 58 | def check_length(self, token): 59 | return len(token) <= self.context_length 60 | 61 | def get_tokens(self, text, **kwargs): 62 | text_token = tokenize(text, context_length=self.context_length, 63 | add_start_and_end=self.add_start_and_end, 64 | with_mask=self.with_mask, pad_value=self.pad_value, 65 | tokenizer=self.tokenizer, 66 | just_token=self.just_token) 67 | if self.clip_embedding == False: 68 | return text_token 69 | else: 70 | if self.condition_emb.additional_last_embedding == True: 71 | with torch.no_grad(): 72 | cond_emb, last_embedding = self.condition_emb(text_token['token'].cuda()) 73 | text_token['embed_token'] = cond_emb.detach() 74 | text_token['last_embed'] = last_embedding 75 | else: 76 | with torch.no_grad(): 77 | cond_emb = self.condition_emb(text_token['token'].cuda()) 78 | text_token['embed_token'] = cond_emb.detach() 79 | 80 | return text_token 81 | 82 | 83 | 84 | 85 | 86 | if __name__ == "__main__": 87 | text = "this is a test !" 
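    # Quick check with the CLIP-style defaults (context_length=77, add_start_and_end=True); clip_embedding defaults to False, so get_tokens returns the raw tokenizer output.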
88 | tokenizer = Tokenize() 89 | 90 | ids = tokenizer.get_tokens(text) 91 | print(text) 92 | print(ids) -------------------------------------------------------------------------------- /modules/tokenizers/SimpleSampleTokenizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from einops import rearrange 4 | import numpy as np 5 | import torch.nn.functional as F 6 | 7 | # 如果mask token初始化为全0,训练的结果也是全0 8 | class SimpleSampleTokenizer(nn.Module): 9 | def __init__(self, 10 | topk, 11 | input_token_num, 12 | input_dim, 13 | output_dim, 14 | z_dim, 15 | patch_size, 16 | score_pred_net_mode="2layer-fc", 17 | apply_norm_image_features=True, 18 | 19 | mask_token_initialization=0, 20 | ): 21 | r""" 22 | mask_token_initialization: mask token 初始化的方式,默认为全0 23 | choice: 24 | 0 : 全0初始化, 最后训练的结果也是全0 25 | 1 : 全1初始化 26 | "uniform": 正态分布初始化 27 | """ 28 | super().__init__() 29 | self.input_token_num = input_token_num 30 | self.sample_num = int(topk * input_token_num) 31 | self.topk_num = int(topk * input_token_num) 32 | self.apply_norm_image_features = apply_norm_image_features 33 | self.z_dim = z_dim 34 | self.patch_size = patch_size 35 | self.hw = int(input_token_num**0.5) 36 | if score_pred_net_mode == "2layer-fc": 37 | self.score_pred_net = nn.Sequential(nn.Linear(input_dim, input_dim), 38 | nn.ReLU(), 39 | nn.Linear(input_dim, 1), 40 | nn.Sigmoid()) 41 | else: 42 | raise ValueError 43 | 44 | if self.apply_norm_image_features: 45 | self.norm_feature = nn.LayerNorm(input_dim, elementwise_affine=False) 46 | 47 | # 更新,2022-4-4,更新mask token的初始化方式 48 | if mask_token_initialization == 0: 49 | self.mask_token = nn.Parameter(torch.zeros(1, output_dim, 1), requires_grad=True) 50 | elif mask_token_initialization == 1: 51 | self.mask_token = nn.Parameter(torch.ones(1, output_dim, 1), requires_grad=True) 52 | elif mask_token_initialization == "uniform": 53 | self.mask_token = nn.Parameter(torch.zeros(1, output_dim, 1), requires_grad=True) 54 | self.mask_token.data.uniform_(-1.0 / output_dim, 1.0 / output_dim) 55 | elif mask_token_initialization == "random": 56 | self.mask_token = nn.Parameter(torch.randn(1, output_dim, 1), requires_grad=True) 57 | else: 58 | raise NotImplementedError() 59 | 60 | self.mask = torch.from_numpy(np.zeros(self.input_token_num)).float() 61 | 62 | self.decode_dim = output_dim 63 | self.post_projection = torch.nn.Conv2d(self.z_dim, self.decode_dim, 1) 64 | self.pre_projection = torch.nn.Linear(input_dim, self.z_dim) 65 | 66 | def preforward(self, image_features): 67 | 68 | image_features_avg_values = image_features.mean(1) # for rebuttal 69 | 70 | image_features = rearrange(image_features, "B C H W -> B (H W) C") 71 | batch_size, length, channel = image_features.size() 72 | 73 | pred_score = self.score_pred_net(image_features).view(batch_size, -1) 74 | pred_score_clone = pred_score.clone().detach() # Note: why clone? 
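        # The detached clone is used only for ranking and visualization, so the top-k selection stays out of the autograd graph; gradients still reach score_pred_net because the gathered features are multiplied by pred_score below.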
75 | 76 | sort_score, sort_order = pred_score_clone.sort(descending=True,dim=1) 77 | sort_topk = sort_order[:, :self.topk_num] 78 | sort_topk_remaining = sort_order[:, self.topk_num:] 79 | ## flatten for gathering 80 | if self.apply_norm_image_features: 81 | image_features = self.norm_feature(image_features) 82 | 83 | ## (only) sampled features multiply with score 84 | image_features_sampled = image_features.gather( 85 | 1, sort_topk[...,None].expand(-1, -1, channel)) * pred_score.gather(1, sort_topk).unsqueeze(-1) 86 | image_features = rearrange(self.pre_projection(image_features_sampled), "B N C -> B C N") 87 | 88 | # get mask 89 | self.mask = self.mask.to(image_features.device) 90 | for i in range(batch_size): 91 | if i == 0: 92 | mask = self.mask.scatter(-1, sort_topk[i], 1.).view(self.hw, self.hw).unsqueeze(0) 93 | else: 94 | mask_i = self.mask.scatter(-1, sort_topk[i], 1.).view(self.hw, self.hw).unsqueeze(0) 95 | mask = torch.cat([mask, mask_i], dim=0) 96 | squeezed_mask = mask.view(batch_size, -1) # [batch_size, length] 97 | mask = F.interpolate(mask.float().unsqueeze(1), scale_factor=self.patch_size, mode="nearest") 98 | 99 | # normalize the score just for better visualization 100 | normed_score = pred_score_clone.sub(pred_score_clone.min()).div(max(pred_score_clone.max() - pred_score_clone.min(), 1e-5)).unsqueeze(-1) 101 | normed_score = F.interpolate(rearrange(normed_score, "b (h w) c -> b c h w", h=self.hw, w=self.hw), scale_factor=self.patch_size, mode="nearest") 102 | 103 | return_dict = { 104 | "sample_h": image_features, 105 | "sample_index": sort_topk, 106 | "remain_index": sort_topk_remaining, 107 | "binary_map": mask, 108 | "score_map": normed_score, 109 | "squeezed_mask": squeezed_mask, 110 | "sort_score": sort_score[:, :self.topk_num], 111 | "image_features_avg_values": image_features_avg_values, 112 | "predicted_score": pred_score_clone, 113 | } 114 | 115 | return return_dict 116 | 117 | def postforward(self, sample_h, sample_index, remain_index=None): 118 | batch_size = sample_h.size(0) 119 | sample_index = sample_index.cpu().numpy().tolist() 120 | if remain_index == None: 121 | remain_index = [[j for j in range(self.input_token_num) if j not in sample_index[i]] for i in range(batch_size)] 122 | else: 123 | remain_index = remain_index.cpu().numpy().tolist() 124 | 125 | decoder_embeeding = nn.Parameter(torch.zeros((batch_size, self.decode_dim, self.input_token_num))).to(sample_h.device) 126 | 127 | for i in range(batch_size): 128 | decoder_embeeding[i, :, sample_index[i]] = sample_h[i] 129 | decoder_embeeding[i, :, remain_index[i]] = self.mask_token 130 | 131 | decoder_embeeding = rearrange(decoder_embeeding, "B C (H W) -> B C H W", H=self.hw, W=self.hw) 132 | decoder_embeeding = self.post_projection(decoder_embeeding) 133 | 134 | return { 135 | "decoder_embeeding": decoder_embeeding, 136 | } 137 | 138 | 139 | if __name__ == "__main__": 140 | model = SimpleSampleTokenizer( 141 | topk=0.5, 142 | input_token_num=256, 143 | input_dim=256, 144 | output_dim=256, 145 | z_dim=256, 146 | patch_size=16, 147 | score_pred_net_mode="2layer-fc", 148 | apply_norm_image_features=True, 149 | ) 150 | x = torch.randn(10,256,16,16) 151 | preforward_dict = model.preforward(x) 152 | postforward_dict = model.postforward( 153 | preforward_dict["sample_h"], 154 | preforward_dict["sample_index"], 155 | ) 156 | 157 | # print(preforward_dict["binary_map"].size()) 158 | # print(preforward_dict["score_map"]) -------------------------------------------------------------------------------- 
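A quick shape sketch of the sampler above, mirroring its __main__ block (illustration only; assumes torch and the class are importable, and uses torch.no_grad() just to keep the example free of autograd bookkeeping):

import torch
from modules.tokenizers.SimpleSampleTokenizer import SimpleSampleTokenizer

model = SimpleSampleTokenizer(topk=0.5, input_token_num=256, input_dim=256,
                              output_dim=256, z_dim=256, patch_size=16)
x = torch.randn(2, 256, 16, 16)                  # (B, C, H, W) feature map
with torch.no_grad():
    pre = model.preforward(x)
    # pre["sample_h"]:     (2, 256, 128)    kept tokens: z_dim channels, topk * 256 = 128 positions
    # pre["sample_index"]: (2, 128)         indices of the kept positions
    # pre["binary_map"]:   (2, 1, 256, 256) keep-mask upsampled by patch_size for visualization
    post = model.postforward(pre["sample_h"], pre["sample_index"])
    # post["decoder_embeeding"]: (2, 256, 16, 16)  kept tokens scattered back; dropped positions get mask_token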
/modules/transformer/hybrid_decoders.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from einops import rearrange 4 | from timm.models.layers import trunc_normal_ 5 | 6 | from modules.transformer.modules import Block 7 | from modules.transformer.position_embeddings import build_position_embed 8 | from utils.utils import instantiate_from_config 9 | 10 | class VisionTransformerDecoder(nn.Module): 11 | def __init__(self, image_size, patch_size, pos_embed_type, embed_dim, 12 | depth, num_heads, attn_drop_rate=0., drop_rate=0.,init_type="default", 13 | mlp_ratio=4, act_type="GELU", norm_type="layer", attn_type="sa", init_values=0): 14 | super().__init__() 15 | 16 | self.hw = image_size // patch_size 17 | self.pos_emb = build_position_embed(embed_type=pos_embed_type, feats_dim=embed_dim, dropout=drop_rate, n_row=self.hw) 18 | self.blocks = nn.ModuleList([ 19 | Block( 20 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, drop=drop_rate, 21 | attn_drop=attn_drop_rate, init_values=init_values, act_type=act_type, 22 | norm_type=norm_type, attn_type=attn_type, size=int(self.hw)) 23 | for i in range(depth)]) 24 | 25 | self.patch_size = patch_size 26 | 27 | if init_type == "default": 28 | self.apply(self._init_weights) 29 | 30 | def _init_weights(self, m): 31 | if isinstance(m, nn.Linear): 32 | trunc_normal_(m.weight, std=.02) 33 | if isinstance(m, nn.Linear) and m.bias is not None: 34 | nn.init.constant_(m.bias, 0) 35 | elif isinstance(m, nn.LayerNorm): 36 | nn.init.constant_(m.bias, 0) 37 | nn.init.constant_(m.weight, 1.0) 38 | 39 | def forward(self, x): 40 | B = x.size(0) 41 | x = self.pos_emb(x) 42 | 43 | for blk in self.blocks: 44 | x = blk(x) 45 | 46 | x = rearrange(x, "B (H W) C -> B C H W", H=self.hw, W=self.hw).contiguous() 47 | return x 48 | 49 | 50 | class HybrdDecoder(nn.Module): 51 | def __init__(self, transformer_config, cnn_config): 52 | super().__init__() 53 | self.transformer = instantiate_from_config(transformer_config) 54 | self.cnn = instantiate_from_config(cnn_config) 55 | self.conv_out = self.cnn.conv_out 56 | 57 | def forward(self, x): 58 | x = self.transformer(x) 59 | x = self.cnn(x) 60 | return x 61 | 62 | # transformer 模块的输入还加入了mask,以实现自定义的功能 63 | class HybrdDecoder_V2(nn.Module): 64 | def __init__(self, transformer_config, cnn_config): 65 | super().__init__() 66 | self.transformer = instantiate_from_config(transformer_config) 67 | self.cnn = instantiate_from_config(cnn_config) 68 | self.conv_out = self.cnn.conv_out 69 | 70 | def forward(self, x, mask): 71 | x = self.transformer(x, mask) 72 | x = self.cnn(x) 73 | return x 74 | 75 | if __name__ == "__main__": 76 | x = torch.randn(10, 512, 16, 16) 77 | vit_decoder = VisionTransformerDecoder( 78 | image_size=256, patch_size=16, pos_embed_type="learned-2d", embed_dim=512, drop_rate=0., 79 | depth=6, num_heads=4, attn_drop_rate=0., output_channel=3, init_type="default" 80 | ) 81 | y = vit_decoder(x) 82 | print(y.size()) -------------------------------------------------------------------------------- /modules/transformer/mask_attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | # 我们希望自注意力中的信息流向 bias 为 unmasked token -> masked token 5 | class MaskSelfAttention_SquareGrowth(nn.Module): 6 | def __init__(self, dim, num_heads, attn_drop=0., proj_drop=0.): 7 | super().__init__() 8 | self.num_heads = num_heads 9 | head_dim = dim // num_heads 10 | self.scale 
= head_dim ** -0.5 11 | self.qkv = nn.Linear(dim, dim * 3, bias=False) 12 | self.proj = nn.Linear(dim, dim) 13 | 14 | self.attn_drop = nn.Dropout(attn_drop) 15 | self.proj_drop = nn.Dropout(proj_drop) 16 | 17 | def forward(self, h, mask=None): 18 | # mask (_type_, optional): [batch_size, length] 19 | B, N, C = h.shape 20 | qkv = self.qkv(h).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 21 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 22 | attn = (q @ k.transpose(-2, -1)).contiguous() * self.scale 23 | attn = attn.softmax(dim=-1) 24 | 25 | if mask is not None: 26 | unsqueezed_mask = mask.unsqueeze(-2).unsqueeze(-2) 27 | attn = attn * unsqueezed_mask 28 | 29 | # update mask with SquareGrowth 30 | new_mask = torch.sqrt(mask) 31 | 32 | attn = self.attn_drop(attn) 33 | h = (attn @ v).transpose(1, 2).contiguous().reshape(B, N, C) 34 | h = self.proj(h) 35 | h = self.proj_drop(h) 36 | return h, new_mask 37 | 38 | 39 | if __name__ == "__main__": 40 | batch_size = 1 41 | dim = 256 42 | height = 2 43 | model = MaskSelfAttention_SquareGrowth(dim=256, num_heads=4) 44 | 45 | h = torch.randn(batch_size, height*height, 256) # (10,256,16,16) 46 | mask = torch.randint(0, 2, (batch_size, height*height)) 47 | print(mask) 48 | 49 | new_mask = mask + 0.02 * (1 - mask) 50 | print(new_mask) 51 | 52 | 53 | model(h, new_mask) -------------------------------------------------------------------------------- /modules/transformer/mask_attention_decoders.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from einops import rearrange 4 | from timm.models.layers import trunc_normal_ 5 | 6 | import os, sys 7 | sys.path.append(os.getcwd()) 8 | 9 | from modules.transformer.modules import norm, Mlp 10 | from modules.transformer.mask_attention import MaskSelfAttention_SquareGrowth 11 | from modules.transformer.position_embeddings import build_position_embed 12 | 13 | class MaskBlock(nn.Module): 14 | def __init__(self, dim, num_heads, mlp_ratio=4., drop=0., attn_drop=0., init_values=None, 15 | act_type="GELU", norm_type="layer", attn_type="msa-sg", size=None): 16 | super().__init__() 17 | self.norm1 = norm(norm_type, dim) 18 | 19 | self.norm2 = norm(norm_type, dim) 20 | self.mlp = Mlp(dim, hidden_radio=mlp_ratio, act_type=act_type, drop=drop) 21 | 22 | if attn_type == "msa-sg": 23 | self.attn = MaskSelfAttention_SquareGrowth( 24 | dim=dim, num_heads=num_heads, attn_drop=attn_drop, proj_drop=drop 25 | ) 26 | else: 27 | raise ValueError 28 | 29 | if init_values > 0: 30 | self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True) 31 | self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True) 32 | else: 33 | self.gamma_1, self.gamma_2 = None, None 34 | 35 | def forward(self, x, mask, **ignore_kwargs): 36 | if self.gamma_1 is None: 37 | attn, new_mask = self.attn(h=self.norm1(x), mask=mask) 38 | x = x + attn 39 | x = x + self.mlp(self.norm2(x)) 40 | else: 41 | attn, new_mask = self.attn(h=self.norm1(x), mask=mask) 42 | x = x + self.gamma_1 * attn 43 | x = x + self.gamma_2 * self.mlp(self.norm2(x)) 44 | return x, new_mask 45 | 46 | class MaskVisionTransformerDecoder(nn.Module): 47 | def __init__(self, image_size, patch_size, pos_embed_type, embed_dim, 48 | depth, num_heads, attn_drop_rate=0., drop_rate=0.,init_type="default", 49 | mlp_ratio=4, act_type="GELU", norm_type="layer", attn_type="sa", init_values=0): 50 | super().__init__() 51 | 
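        # Mask-aware ViT decoder: each MaskBlock re-weights attention by the current soft mask and then relaxes it (new_mask = sqrt(mask)), so information flows progressively from unmasked toward masked positions.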
52 | self.hw = image_size // patch_size 53 | self.pos_emb = build_position_embed(embed_type=pos_embed_type, feats_dim=embed_dim, dropout=drop_rate, n_row=self.hw) 54 | self.blocks = nn.ModuleList([ 55 | MaskBlock( 56 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, drop=drop_rate, 57 | attn_drop=attn_drop_rate, init_values=init_values, act_type=act_type, 58 | norm_type=norm_type, attn_type=attn_type, size=int(self.hw)) 59 | for i in range(depth)]) 60 | 61 | self.patch_size = patch_size 62 | 63 | if init_type == "default": 64 | self.apply(self._init_weights) 65 | 66 | def _init_weights(self, m): 67 | if isinstance(m, nn.Linear): 68 | trunc_normal_(m.weight, std=.02) 69 | if isinstance(m, nn.Linear) and m.bias is not None: 70 | nn.init.constant_(m.bias, 0) 71 | elif isinstance(m, nn.LayerNorm): 72 | nn.init.constant_(m.bias, 0) 73 | nn.init.constant_(m.weight, 1.0) 74 | 75 | def forward(self, x, mask): 76 | B = x.size(0) 77 | x = self.pos_emb(x) 78 | 79 | new_mask = mask + 0.02 * (1 - mask) # 将0初始化为一个小值,0.02 80 | for blk in self.blocks: 81 | # print(new_mask) 82 | x, new_mask = blk(x=x, mask=new_mask) 83 | x = rearrange(x, "B (H W) C -> B C H W", H=self.hw, W=self.hw).contiguous() 84 | return x -------------------------------------------------------------------------------- /modules/transformer/permuter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | 6 | class AbstractPermuter(nn.Module): 7 | def __init__(self, *args, **kwargs): 8 | super().__init__() 9 | def forward(self, x, reverse=False): 10 | raise NotImplementedError 11 | 12 | 13 | class Identity(AbstractPermuter): 14 | def __init__(self): 15 | super().__init__() 16 | 17 | def forward(self, x, reverse=False): 18 | return x 19 | 20 | 21 | class Subsample(AbstractPermuter): 22 | def __init__(self, H, W): 23 | super().__init__() 24 | C = 1 25 | indices = np.arange(H*W).reshape(C,H,W) 26 | while min(H, W) > 1: 27 | indices = indices.reshape(C,H//2,2,W//2,2) 28 | indices = indices.transpose(0,2,4,1,3) 29 | indices = indices.reshape(C*4,H//2, W//2) 30 | H = H//2 31 | W = W//2 32 | C = C*4 33 | assert H == W == 1 34 | idx = torch.tensor(indices.ravel()) 35 | self.register_buffer('forward_shuffle_idx', 36 | nn.Parameter(idx, requires_grad=False)) 37 | self.register_buffer('backward_shuffle_idx', 38 | nn.Parameter(torch.argsort(idx), requires_grad=False)) 39 | 40 | def forward(self, x, reverse=False): 41 | if not reverse: 42 | return x[:, self.forward_shuffle_idx] 43 | else: 44 | return x[:, self.backward_shuffle_idx] 45 | 46 | 47 | def mortonify(i, j): 48 | """(i,j) index to linear morton code""" 49 | i = np.uint64(i) 50 | j = np.uint64(j) 51 | 52 | z = np.uint(0) 53 | 54 | for pos in range(32): 55 | z = (z | 56 | ((j & (np.uint64(1) << np.uint64(pos))) << np.uint64(pos)) | 57 | ((i & (np.uint64(1) << np.uint64(pos))) << np.uint64(pos+1)) 58 | ) 59 | return z 60 | 61 | 62 | class ZCurve(AbstractPermuter): 63 | def __init__(self, H, W): 64 | super().__init__() 65 | reverseidx = [np.int64(mortonify(i,j)) for i in range(H) for j in range(W)] 66 | idx = np.argsort(reverseidx) 67 | idx = torch.tensor(idx) 68 | reverseidx = torch.tensor(reverseidx) 69 | self.register_buffer('forward_shuffle_idx', 70 | idx) 71 | self.register_buffer('backward_shuffle_idx', 72 | reverseidx) 73 | 74 | def forward(self, x, reverse=False): 75 | if not reverse: 76 | return x[:, self.forward_shuffle_idx] 77 | else: 78 | return x[:, 
self.backward_shuffle_idx] 79 | 80 | 81 | class SpiralOut(AbstractPermuter): 82 | def __init__(self, H, W): 83 | super().__init__() 84 | assert H == W 85 | size = W 86 | indices = np.arange(size*size).reshape(size,size) 87 | 88 | i0 = size//2 89 | j0 = size//2-1 90 | 91 | i = i0 92 | j = j0 93 | 94 | idx = [indices[i0, j0]] 95 | step_mult = 0 96 | for c in range(1, size//2+1): 97 | step_mult += 1 98 | # steps left 99 | for k in range(step_mult): 100 | i = i - 1 101 | j = j 102 | idx.append(indices[i, j]) 103 | 104 | # step down 105 | for k in range(step_mult): 106 | i = i 107 | j = j + 1 108 | idx.append(indices[i, j]) 109 | 110 | step_mult += 1 111 | if c < size//2: 112 | # step right 113 | for k in range(step_mult): 114 | i = i + 1 115 | j = j 116 | idx.append(indices[i, j]) 117 | 118 | # step up 119 | for k in range(step_mult): 120 | i = i 121 | j = j - 1 122 | idx.append(indices[i, j]) 123 | else: 124 | # end reached 125 | for k in range(step_mult-1): 126 | i = i + 1 127 | idx.append(indices[i, j]) 128 | 129 | assert len(idx) == size*size 130 | idx = torch.tensor(idx) 131 | self.register_buffer('forward_shuffle_idx', idx) 132 | self.register_buffer('backward_shuffle_idx', torch.argsort(idx)) 133 | 134 | def forward(self, x, reverse=False): 135 | if not reverse: 136 | return x[:, self.forward_shuffle_idx] 137 | else: 138 | return x[:, self.backward_shuffle_idx] 139 | 140 | 141 | class SpiralIn(AbstractPermuter): 142 | def __init__(self, H, W): 143 | super().__init__() 144 | assert H == W 145 | size = W 146 | indices = np.arange(size*size).reshape(size,size) 147 | 148 | i0 = size//2 149 | j0 = size//2-1 150 | 151 | i = i0 152 | j = j0 153 | 154 | idx = [indices[i0, j0]] 155 | step_mult = 0 156 | for c in range(1, size//2+1): 157 | step_mult += 1 158 | # steps left 159 | for k in range(step_mult): 160 | i = i - 1 161 | j = j 162 | idx.append(indices[i, j]) 163 | 164 | # step down 165 | for k in range(step_mult): 166 | i = i 167 | j = j + 1 168 | idx.append(indices[i, j]) 169 | 170 | step_mult += 1 171 | if c < size//2: 172 | # step right 173 | for k in range(step_mult): 174 | i = i + 1 175 | j = j 176 | idx.append(indices[i, j]) 177 | 178 | # step up 179 | for k in range(step_mult): 180 | i = i 181 | j = j - 1 182 | idx.append(indices[i, j]) 183 | else: 184 | # end reached 185 | for k in range(step_mult-1): 186 | i = i + 1 187 | idx.append(indices[i, j]) 188 | 189 | assert len(idx) == size*size 190 | idx = idx[::-1] 191 | idx = torch.tensor(idx) 192 | self.register_buffer('forward_shuffle_idx', idx) 193 | self.register_buffer('backward_shuffle_idx', torch.argsort(idx)) 194 | 195 | def forward(self, x, reverse=False): 196 | if not reverse: 197 | return x[:, self.forward_shuffle_idx] 198 | else: 199 | return x[:, self.backward_shuffle_idx] 200 | 201 | 202 | class Random(nn.Module): 203 | def __init__(self, H, W): 204 | super().__init__() 205 | indices = np.random.RandomState(1).permutation(H*W) 206 | idx = torch.tensor(indices.ravel()) 207 | self.register_buffer('forward_shuffle_idx', idx) 208 | self.register_buffer('backward_shuffle_idx', torch.argsort(idx)) 209 | 210 | def forward(self, x, reverse=False): 211 | if not reverse: 212 | return x[:, self.forward_shuffle_idx] 213 | else: 214 | return x[:, self.backward_shuffle_idx] 215 | 216 | 217 | class AlternateParsing(AbstractPermuter): 218 | def __init__(self, H, W): 219 | super().__init__() 220 | indices = np.arange(W*H).reshape(H,W) 221 | for i in range(1, H, 2): 222 | indices[i, :] = indices[i, ::-1] 223 | idx = indices.flatten() 224 | 
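        # e.g. H = W = 2: [[0, 1], [2, 3]] -> reverse every odd row -> [[0, 1], [3, 2]] -> idx = [0, 1, 3, 2] (snake / boustrophedon scan)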
assert len(idx) == H*W 225 | idx = torch.tensor(idx) 226 | self.register_buffer('forward_shuffle_idx', idx) 227 | self.register_buffer('backward_shuffle_idx', torch.argsort(idx)) 228 | 229 | def forward(self, x, reverse=False): 230 | if not reverse: 231 | return x[:, self.forward_shuffle_idx] 232 | else: 233 | return x[:, self.backward_shuffle_idx] 234 | 235 | 236 | if __name__ == "__main__": 237 | p0 = AlternateParsing(16, 16) 238 | print(p0.forward_shuffle_idx) 239 | print(p0.backward_shuffle_idx) 240 | 241 | x = torch.randint(0, 768, size=(11, 256)) 242 | y = p0(x) 243 | xre = p0(y, reverse=True) 244 | assert torch.equal(x, xre) 245 | 246 | p1 = SpiralOut(2, 2) 247 | print(p1.forward_shuffle_idx) 248 | print(p1.backward_shuffle_idx) 249 | -------------------------------------------------------------------------------- /modules/transformer/position_aware_mingpt.py: -------------------------------------------------------------------------------- 1 | import math 2 | import logging 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torch.nn import functional as F 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | class PositionAwareGPTConfig: 11 | """ base GPT config, params common to all GPT versions """ 12 | embd_pdrop = 0.1 13 | resid_pdrop = 0.1 14 | attn_pdrop = 0.1 15 | 16 | def __init__(self, vocab_size, position_size, block_size, **kwargs): 17 | self.vocab_size = vocab_size 18 | self.position_size = position_size 19 | self.block_size = block_size 20 | for k,v in kwargs.items(): 21 | setattr(self, k, v) 22 | 23 | class CausalSelfAttention(nn.Module): 24 | """ 25 | A vanilla multi-head masked self-attention layer with a projection at the end. 26 | It is possible to use torch.nn.MultiheadAttention here but I am including an 27 | explicit implementation here to show that there is nothing too scary here. 
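    The lower-triangular mask restricts each position to attend only to itself and earlier positions; when config.n_unmasked is set, the first n_unmasked positions (e.g. a conditioning prefix) may also attend to each other freely.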
28 | """ 29 | 30 | def __init__(self, config): 31 | super().__init__() 32 | assert config.n_embd % config.n_head == 0 33 | # key, query, value projections for all heads 34 | self.key = nn.Linear(config.n_embd, config.n_embd) 35 | self.query = nn.Linear(config.n_embd, config.n_embd) 36 | self.value = nn.Linear(config.n_embd, config.n_embd) 37 | # regularization 38 | self.attn_drop = nn.Dropout(config.attn_pdrop) 39 | self.resid_drop = nn.Dropout(config.resid_pdrop) 40 | # output projection 41 | self.proj = nn.Linear(config.n_embd, config.n_embd) 42 | # causal mask to ensure that attention is only applied to the left in the input sequence 43 | mask = torch.tril(torch.ones(config.block_size, 44 | config.block_size)) 45 | if hasattr(config, "n_unmasked"): 46 | mask[:config.n_unmasked, :config.n_unmasked] = 1 47 | self.register_buffer("mask", mask.view(1, 1, config.block_size, config.block_size)) 48 | self.n_head = config.n_head 49 | 50 | def forward(self, x, layer_past=None): 51 | B, T, C = x.size() 52 | 53 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim 54 | k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 55 | q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 56 | v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 57 | 58 | present = torch.stack((k, v)) 59 | if layer_past is not None: 60 | past_key, past_value = layer_past 61 | k = torch.cat((past_key, k), dim=-2) 62 | v = torch.cat((past_value, v), dim=-2) 63 | 64 | # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) 65 | att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) 66 | if layer_past is None: 67 | att = att.masked_fill(self.mask[:,:,:T,:T] == 0, float('-inf')) 68 | 69 | att = F.softmax(att, dim=-1) 70 | att = self.attn_drop(att) 71 | y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) 72 | y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side 73 | 74 | # output projection 75 | y = self.resid_drop(self.proj(y)) 76 | return y, present # TODO: check that this does not break anything 77 | 78 | class Block(nn.Module): 79 | """ an unassuming Transformer block """ 80 | def __init__(self, config): 81 | super().__init__() 82 | self.ln1 = nn.LayerNorm(config.n_embd) 83 | self.ln2 = nn.LayerNorm(config.n_embd) 84 | self.attn = CausalSelfAttention(config) 85 | self.mlp = nn.Sequential( 86 | nn.Linear(config.n_embd, 4 * config.n_embd), 87 | nn.GELU(), # nice 88 | nn.Linear(4 * config.n_embd, config.n_embd), 89 | nn.Dropout(config.resid_pdrop), 90 | ) 91 | 92 | def forward(self, x, layer_past=None, return_present=False): 93 | # TODO: check that training still works 94 | if return_present: assert not self.training 95 | # layer past: tuple of length two with B, nh, T, hs 96 | attn, present = self.attn(self.ln1(x), layer_past=layer_past) 97 | 98 | x = x + attn 99 | x = x + self.mlp(self.ln2(x)) 100 | if layer_past is not None or return_present: 101 | return x, present 102 | return x 103 | 104 | class PositionAwareGPT(nn.Module): 105 | """ the full GPT language model, with a context size of block_size """ 106 | def __init__(self, vocab_size, position_size, block_size, n_layer=12, n_head=8, n_embd=256, 107 | embd_pdrop=0., resid_pdrop=0., attn_pdrop=0., n_unmasked=0): 108 | super().__init__() 109 | config = PositionAwareGPTConfig( 110 | vocab_size=vocab_size, block_size=block_size, 
position_size=position_size, 111 | embd_pdrop=embd_pdrop, resid_pdrop=resid_pdrop, attn_pdrop=attn_pdrop, 112 | n_layer=n_layer, n_head=n_head, n_embd=n_embd, 113 | n_unmasked=n_unmasked) 114 | # input embedding stem 115 | self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd) 116 | self.token_pos_emb = nn.Embedding(config.position_size, config.n_embd) 117 | self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd)) 118 | self.drop = nn.Dropout(config.embd_pdrop) 119 | # transformer 120 | self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)]) 121 | # decoder head for value predictions 122 | self.ln_f = nn.LayerNorm(config.n_embd) 123 | self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False) 124 | 125 | # decoder head for position predictiona 126 | self.ln_f_pos = nn.LayerNorm(config.n_embd) 127 | self.head_pos = nn.Linear(config.n_embd, config.position_size, bias=False) 128 | 129 | self.block_size = config.block_size 130 | self.apply(self._init_weights) 131 | self.config = config 132 | logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters())) 133 | 134 | def get_block_size(self): 135 | return self.block_size 136 | 137 | def _init_weights(self, module): 138 | if isinstance(module, (nn.Linear, nn.Embedding)): 139 | module.weight.data.normal_(mean=0.0, std=0.02) 140 | if isinstance(module, nn.Linear) and module.bias is not None: 141 | module.bias.data.zero_() 142 | elif isinstance(module, nn.LayerNorm): 143 | module.bias.data.zero_() 144 | module.weight.data.fill_(1.0) 145 | 146 | def forward(self, idx, pos_idx, embeddings=None): 147 | # forward the GPT model 148 | token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector 149 | token_pos_embeddings = self.token_pos_emb(pos_idx) # each position maps to a (learnable) vector 150 | token_embeddings = token_embeddings + token_pos_embeddings 151 | 152 | if embeddings is not None: # prepend explicit embeddings 153 | token_embeddings = torch.cat((embeddings, token_embeddings), dim=1) 154 | 155 | t = token_embeddings.shape[1] 156 | assert t <= self.block_size, "Cannot forward, model block size is exhausted." 
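        # Two positional signals are summed: token_pos_emb above embeds the explicit position indices pos_idx carried with each token, while pos_emb below encodes the plain autoregressive sequence position.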
157 | position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector 158 | x = self.drop(token_embeddings + position_embeddings) 159 | x = self.blocks(x) 160 | 161 | # for value prediction 162 | value_x = self.ln_f(x) 163 | value_logits = self.head(value_x) 164 | 165 | # for position prediction 166 | pos_x = self.ln_f_pos(x) 167 | pos_logits = self.head_pos(pos_x) 168 | 169 | return value_logits, pos_logits 170 | 171 | 172 | 173 | if __name__ == "__main__": 174 | model = PositionAwareGPT( 175 | vocab_size=1024, position_size=256, block_size=257, 176 | n_layer=3, n_head=4, n_embd=256, 177 | embd_pdrop=0., resid_pdrop=0., attn_pdrop=0., n_unmasked=0 178 | ) 179 | 180 | idx = torch.randint(0, 1024, (1, 257)) 181 | 182 | value_logits, pos_logits = model(idx) 183 | 184 | print(value_logits.size(), pos_logits.size()) -------------------------------------------------------------------------------- /modules/transformer/position_embeddings.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | import numpy as np 5 | from einops import rearrange 6 | 7 | # input is (B,C,H,W), output is (B,HW,C) 8 | def build_position_embed(embed_type='learned', feats_dim=512, dropout=0., n_row=16): 9 | if embed_type == 'sine-1d': 10 | pos_embed = PositionalEncoding1d(emb_dim=feats_dim, dropout=dropout) 11 | elif embed_type == "sine-2d": 12 | pos_embed = PositionalEncoding2d(emb_dim=feats_dim, dropout=dropout) 13 | elif embed_type == "learned-2d": 14 | pos_embed = PositionEmbeddingLearned(n_row=n_row, feats_dim=feats_dim, dropout=dropout) 15 | else: 16 | raise ValueError(f"nor supported {embed_type}") 17 | return pos_embed 18 | 19 | 20 | ###################################################################################### 21 | # 1D position embedding 22 | ###################################################################################### 23 | class PositionalEncoding1d(nn.Module): 24 | def __init__(self, emb_dim, dropout=0.1, max_len=5000): 25 | super(PositionalEncoding1d, self).__init__() 26 | self.dropout = nn.Dropout(p=dropout) 27 | 28 | pe = torch.zeros(max_len, emb_dim) 29 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) 30 | div_term = torch.exp(torch.arange(0, emb_dim, 2).float() * (-math.log(10000.0) / emb_dim)) 31 | pe[:, 0::2] = torch.sin(position * div_term) 32 | pe[:, 1::2] = torch.cos(position * div_term) 33 | pe = pe.unsqueeze(0).transpose(0, 1) 34 | self.register_buffer('pe', pe) 35 | 36 | def forward(self, x): 37 | x = rearrange(x, "B C H W -> B (H W) C").contiguous() 38 | x = x + self.pe[:x.size(0), :].to(x.device) 39 | return self.dropout(x) 40 | 41 | ###################################################################################### 42 | # 2D position embedding 43 | ###################################################################################### 44 | class PositionalEncoding2d(nn.Module): 45 | def __init__(self, emb_dim, dropout, max_len=5000): 46 | super(PositionalEncoding2d, self).__init__() 47 | 48 | self.dropout = nn.Dropout(dropout) 49 | 50 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) 51 | div_term = torch.exp(torch.arange(0, (emb_dim//2), 2).float() * (-math.log(10000.0) / (emb_dim//2))) 52 | pe_x = torch.zeros(max_len, emb_dim//2) 53 | pe_x[:, 0::2] = torch.sin(position * div_term) 54 | pe_x[:, 1::2] = torch.cos(position * div_term) 55 | 56 | pe_y = torch.zeros(max_len, emb_dim//2) 57 | pe_y[:, 1::2] = 
torch.sin(position * div_term) 58 | pe_y[:, 0::2] = torch.cos(position * div_term) 59 | 60 | self.register_buffer('pe_x', pe_x) 61 | self.register_buffer('pe_y', pe_y) 62 | 63 | def forward(self, x): 64 | _, _, h, w = x.shape 65 | add_x = self.pe_x[:h, :].unsqueeze(1).repeat(1,w,1) 66 | add_y = self.pe_y[:w, :].unsqueeze(0).repeat(h,1,1) 67 | add = torch.cat([add_x, add_y], dim=-1) #shape: h x w x dim 68 | add = add.permute(2, 0, 1).unsqueeze(0) 69 | 70 | x = x + add 71 | x = rearrange(x, "B C H W -> B (H W) C").contiguous() 72 | return self.dropout(x) 73 | 74 | class PositionEmbeddingLearned(nn.Module): 75 | """ 76 | This is a learned version of the position embedding 77 | """ 78 | def __init__(self, n_row, feats_dim, dropout, n_col=None): 79 | super().__init__() 80 | self.dropout = nn.Dropout(dropout) 81 | 82 | n_col = n_col if n_col is not None else n_row 83 | self.row_embed = nn.Embedding(n_row, feats_dim) 84 | self.col_embed = nn.Embedding(n_col, feats_dim) 85 | self.reset_parameters() 86 | 87 | def reset_parameters(self): 88 | nn.init.uniform_(self.row_embed.weight) 89 | nn.init.uniform_(self.col_embed.weight) 90 | 91 | def forward(self, x): 92 | h, w = x.shape[-2:] 93 | i = torch.arange(w, device=x.device) 94 | j = torch.arange(h, device=x.device) 95 | x_emb = self.col_embed(i).unsqueeze(0).repeat(h, 1, 1) 96 | y_emb = self.row_embed(j).unsqueeze(1).repeat(1, w, 1) 97 | pos = (x_emb + y_emb).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1) 98 | 99 | x = x + pos 100 | x = rearrange(x, "B C H W -> B (H W) C").contiguous() 101 | return self.dropout(x) -------------------------------------------------------------------------------- /modules/transformer/vit.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from einops import rearrange 4 | from timm.models.layers import trunc_normal_ 5 | 6 | 7 | from modules.transformer.modules import PatchEmbed, Block 8 | from modules.transformer.position_embeddings import build_position_embed 9 | 10 | # vision transformer 11 | class VisionTransformerEncoder(nn.Module): 12 | def __init__(self, image_size, patch_size, input_channel, embed_dim, init_type, 13 | pos_embed_type, attn_drop_rate, drop_rate, depth, num_heads, 14 | mlp_ratio=4, norm_type="layer", act_type="GELU", init_values=0, attn_type="sa"): 15 | super().__init__() 16 | 17 | self.patch_embed = PatchEmbed( 18 | img_size=image_size, patch_size=patch_size, in_chans=input_channel, embed_dim=embed_dim 19 | ) 20 | 21 | self.hw = image_size // patch_size 22 | self.pos_emb = build_position_embed( 23 | embed_type=pos_embed_type, feats_dim=embed_dim, dropout=drop_rate, n_row=self.hw 24 | ) 25 | 26 | self.blocks = nn.ModuleList([ 27 | Block( 28 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, drop=drop_rate, 29 | attn_drop=attn_drop_rate, init_values=init_values, act_type=act_type, 30 | norm_type=norm_type, attn_type=attn_type, size=int(self.hw)) 31 | for i in range(depth)]) 32 | 33 | if init_type == "default": 34 | self.apply(self._init_weights) 35 | 36 | def _init_weights(self, m): 37 | if isinstance(m, nn.Linear): 38 | trunc_normal_(m.weight, std=.02) 39 | if isinstance(m, nn.Linear) and m.bias is not None: 40 | nn.init.constant_(m.bias, 0) 41 | elif isinstance(m, nn.LayerNorm): 42 | nn.init.constant_(m.bias, 0) 43 | nn.init.constant_(m.weight, 1.0) 44 | 45 | def forward(self, images): 46 | x = self.patch_embed(images) 47 | x = rearrange(x, "B (H W) C -> B C H W", H=self.hw, W=self.hw).contiguous() 48 
| x = self.pos_emb(x) 49 | for blk in self.blocks: 50 | x = blk(x) 51 | x = rearrange(x, "B (H W) C -> B C H W", H=self.hw, W=self.hw).contiguous() 52 | return x 53 | 54 | 55 | # vision transformer decoder 56 | class VisionTransformerDecoder(nn.Module): 57 | def __init__(self, image_size, patch_size, pos_embed_type, embed_dim, drop_rate, 58 | depth, num_heads, attn_drop_rate, output_channel, init_type, 59 | mlp_ratio=4, act_type="GELU", norm_type="layer", attn_type="sa", init_values=0): 60 | super().__init__() 61 | 62 | self.hw = image_size // patch_size 63 | self.pos_emb = build_position_embed(embed_type=pos_embed_type, feats_dim=embed_dim, dropout=drop_rate, n_row=self.hw) 64 | self.blocks = nn.ModuleList([ 65 | Block( 66 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, drop=drop_rate, 67 | attn_drop=attn_drop_rate, init_values=init_values, act_type=act_type, 68 | norm_type=norm_type, attn_type=attn_type, size=int(self.hw)) 69 | for i in range(depth)]) 70 | 71 | # see: VECTOR-QUANTIZED IMAGE MODELING WITH IMPROVED VQGAN 72 | # self.output_linear = nn.Linear(embed_dim, patch_size*patch_size*output_channel) 73 | # self.output_linear = nn.Sequential( 74 | # nn.Linear(embed_dim, embed_dim), 75 | # nn.Tanh(), 76 | # nn.Linear(embed_dim, patch_size*patch_size*output_channel) 77 | # ) 78 | self.output_linear1 = nn.Linear(embed_dim, patch_size*patch_size*output_channel) 79 | self.conv_out = nn.Linear(patch_size*patch_size*output_channel, patch_size*patch_size*output_channel) 80 | # align with VQGAN 81 | 82 | self.output_channel = output_channel 83 | self.patch_size = patch_size 84 | 85 | if init_type == "default": 86 | self.apply(self._init_weights) 87 | 88 | def _init_weights(self, m): 89 | if isinstance(m, nn.Linear): 90 | trunc_normal_(m.weight, std=.02) 91 | if isinstance(m, nn.Linear) and m.bias is not None: 92 | nn.init.constant_(m.bias, 0) 93 | elif isinstance(m, nn.LayerNorm): 94 | nn.init.constant_(m.bias, 0) 95 | nn.init.constant_(m.weight, 1.0) 96 | 97 | def forward(self, x): 98 | B = x.size(0) 99 | x = self.pos_emb(x) 100 | 101 | for blk in self.blocks: 102 | x = blk(x) 103 | 104 | # x = self.output_linear(x) 105 | x = self.output_linear1(x) 106 | x = nn.Tanh()(x) 107 | x = self.conv_out(x) 108 | x = rearrange( 109 | x, "B (h w) (h_size w_size c) -> B c (h h_size) (w w_size)", 110 | B=B, h=self.hw, w=self.hw, h_size=self.patch_size, w_size=self.patch_size, c=self.output_channel).contiguous() 111 | return x 112 | 113 | if __name__ == "__main__": 114 | images = torch.randn(10,3,256,256) 115 | vit_encoder = VisionTransformerEncoder( 116 | image_size=256, patch_size=16, input_channel=3, embed_dim=512, init_type="default", 117 | pos_embed_type="learned-2d", attn_drop_rate=0., drop_rate=0., depth=6, num_heads=4 118 | ) 119 | x = vit_encoder(images) 120 | print(x.size()) 121 | 122 | vit_decoder = VisionTransformerDecoder( 123 | image_size=256, patch_size=16, pos_embed_type="learned-2d", embed_dim=512, drop_rate=0., 124 | depth=6, num_heads=4, attn_drop_rate=0., output_channel=3, init_type="default" 125 | ) 126 | y = vit_decoder(x) 127 | print(y.size()) -------------------------------------------------------------------------------- /modules/vector_quantization/common_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, einsum 3 | import torch.nn.functional as F 4 | import torch.distributed as distributed 5 | from einops import rearrange, repeat 6 | 7 | def exists(val): 8 | return val is not None 9 | 
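# NOTE: small shared helpers for the quantizers in this package (quantize.py imports this module as "utils"); gumbel_sample below falls back to a plain argmax when temperature == 0.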
10 | def default(val, d): 11 | return val if exists(val) else d 12 | 13 | def noop(*args, **kwargs): 14 | pass 15 | 16 | def l2norm(t): 17 | return F.normalize(t, p = 2, dim = -1) 18 | 19 | def log(t, eps = 1e-20): 20 | return torch.log(t.clamp(min = eps)) 21 | 22 | def uniform_init(*shape): 23 | t = torch.empty(shape) 24 | nn.init.kaiming_uniform_(t) 25 | return t 26 | 27 | def gumbel_noise(t): 28 | noise = torch.zeros_like(t).uniform_(0, 1) 29 | return -log(-log(noise)) 30 | 31 | def gumbel_sample(t, temperature = 1., dim = -1): 32 | if temperature == 0: 33 | return t.argmax(dim = dim) 34 | 35 | return ((t / temperature) + gumbel_noise(t)).argmax(dim = dim) 36 | 37 | def ema_inplace(moving_avg, new, decay): 38 | moving_avg.data.mul_(decay).add_(new, alpha = (1 - decay)) 39 | 40 | def laplace_smoothing(x, n_categories, eps = 1e-5): 41 | return (x + eps) / (x.sum() + n_categories * eps) 42 | 43 | def sample_vectors(samples, num): 44 | num_samples, device = samples.shape[0], samples.device 45 | if num_samples >= num: 46 | indices = torch.randperm(num_samples, device = device)[:num] 47 | else: 48 | indices = torch.randint(0, num_samples, (num,), device = device) 49 | 50 | return samples[indices] 51 | 52 | def batched_sample_vectors(samples, num): 53 | return torch.stack([sample_vectors(sample, num) for sample in samples.unbind(dim = 0)], dim = 0) 54 | 55 | def pad_shape(shape, size, dim = 0): 56 | return [size if i == dim else s for i, s in enumerate(shape)] 57 | 58 | def sample_multinomial(total_count, probs): 59 | device = probs.device 60 | probs = probs.cpu() 61 | 62 | total_count = probs.new_full((), total_count) 63 | remainder = probs.new_ones(()) 64 | sample = torch.empty_like(probs, dtype = torch.long) 65 | 66 | for i, p in enumerate(probs): 67 | s = torch.binomial(total_count, p / remainder) 68 | sample[i] = s 69 | total_count -= s 70 | remainder -= p 71 | 72 | return sample.to(device) 73 | 74 | def all_gather_sizes(x, dim): 75 | size = torch.tensor(x.shape[dim], dtype = torch.long, device = x.device) 76 | all_sizes = [torch.empty_like(size) for _ in range(distributed.get_world_size())] 77 | distributed.all_gather(all_sizes, size) 78 | 79 | return torch.stack(all_sizes) 80 | 81 | def all_gather_variably_sized(x, sizes, dim = 0): 82 | rank = distributed.get_rank() 83 | all_x = [] 84 | 85 | for i, size in enumerate(sizes): 86 | t = x if i == rank else x.new_empty(pad_shape(x.shape, size, dim)) 87 | distributed.broadcast(t, src = i, async_op = True) 88 | all_x.append(t) 89 | 90 | distributed.barrier() 91 | return all_x 92 | 93 | def sample_vectors_distributed(local_samples, num): 94 | rank = distributed.get_rank() 95 | all_num_samples = all_gather_sizes(local_samples, dim = 0) 96 | 97 | if rank == 0: 98 | samples_per_rank = sample_multinomial(num, all_num_samples / all_num_samples.sum()) 99 | else: 100 | samples_per_rank = torch.empty_like(all_num_samples) 101 | 102 | distributed.broadcast(samples_per_rank, src = 0) 103 | samples_per_rank = samples_per_rank.tolist() 104 | 105 | local_samples = batched_sample_vectors(local_samples, samples_per_rank[rank]) 106 | all_samples = all_gather_variably_sized(local_samples, samples_per_rank, dim = 0) 107 | return torch.cat(all_samples, dim = 0) 108 | 109 | def batched_bincount(x, *, minlength): 110 | batch, dtype, device = x.shape[0], x.dtype, x.device 111 | target = torch.zeros(batch, minlength, dtype = dtype, device = device) 112 | values = torch.ones_like(x) 113 | target.scatter_add_(-1, x, values) 114 | return target 115 | 116 | def kmeans( 
117 | samples, 118 | num_clusters, 119 | num_iters = 10, 120 | use_cosine_sim = False, 121 | sample_fn = batched_sample_vectors, 122 | all_reduce_fn = noop 123 | ): 124 | num_codebooks, dim, dtype, device = samples.shape[0], samples.shape[-1], samples.dtype, samples.device 125 | 126 | means = sample_fn(samples, num_clusters) 127 | 128 | for _ in range(num_iters): 129 | if use_cosine_sim: 130 | dists = samples @ rearrange(means, 'h n d -> h d n') 131 | else: 132 | dists = -torch.cdist(samples, means, p = 2) 133 | 134 | buckets = torch.argmax(dists, dim = -1) 135 | bins = batched_bincount(buckets, minlength = num_clusters) 136 | all_reduce_fn(bins) 137 | 138 | zero_mask = bins == 0 139 | bins_min_clamped = bins.masked_fill(zero_mask, 1) 140 | 141 | new_means = buckets.new_zeros(num_codebooks, num_clusters, dim, dtype = dtype) 142 | 143 | new_means.scatter_add_(1, repeat(buckets, 'h n -> h n d', d = dim), samples) 144 | new_means = new_means / rearrange(bins_min_clamped, '... -> ... 1') 145 | all_reduce_fn(new_means) 146 | 147 | if use_cosine_sim: 148 | new_means = l2norm(new_means) 149 | 150 | means = torch.where( 151 | rearrange(zero_mask, '... -> ... 1'), 152 | means, 153 | new_means 154 | ) 155 | 156 | return means, bins 157 | 158 | def batched_embedding(indices, embeds): 159 | batch, dim = indices.shape[1], embeds.shape[-1] 160 | indices = repeat(indices, 'h b n -> h b n d', d = dim) 161 | embeds = repeat(embeds, 'h c d -> h b c d', b = batch) 162 | return embeds.gather(2, indices) 163 | -------------------------------------------------------------------------------- /modules/vector_quantization/quantize.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from einops import rearrange 8 | from torch import einsum 9 | 10 | sys.path.append(os.getcwd()) 11 | 12 | import modules.vector_quantization.common_utils as utils 13 | 14 | # support: cosine-similarity; kmeans centroids initialization; orthogonal_reg_weight, temperature sample 15 | class VectorQuantize(nn.Module): 16 | def __init__( 17 | self, 18 | codebook_size, 19 | codebook_dim = None, 20 | kmeans_init = False, 21 | kmeans_iters = 10, 22 | use_cosine_sim = False, 23 | use_cosine_distance = False, 24 | channel_last = False, 25 | accept_image_fmap = True, 26 | commitment_beta = 0.25, 27 | orthogonal_reg_weight = 0., 28 | ): 29 | super().__init__() 30 | 31 | self.codebook_size = codebook_size 32 | self.codebook_dim = codebook_dim 33 | self.accept_image_fmap = accept_image_fmap 34 | self.channel_last = channel_last 35 | self.use_cosine_sim = use_cosine_sim 36 | self.use_cosine_distance = use_cosine_distance 37 | self.beta = commitment_beta 38 | 39 | self.embedding = nn.Embedding(self.codebook_size, self.codebook_dim) 40 | 41 | # codebook initialization 42 | if not kmeans_init: 43 | self.embedding.weight.data.uniform_(-1.0 / self.codebook_size, 1.0 / self.codebook_size) 44 | else: 45 | self.embedding.weight.data.zero_() 46 | self.kmeans_iters = kmeans_iters 47 | self.register_buffer('initted', torch.Tensor([not kmeans_init])) 48 | self.register_buffer('cluster_size', torch.zeros(1, codebook_size)) 49 | 50 | self.sample_fn = utils.batched_sample_vectors 51 | self.all_reduce_fn = utils.noop 52 | 53 | # codebook_orthogonal_loss 54 | self.orthogonal_reg_weight = orthogonal_reg_weight 55 | 56 | def init_embed_(self, data): 57 | if self.initted: 58 | return 59 | 60 | data = rearrange(data, '... 
-> 1 ...').contiguous() 61 | data = rearrange(data, 'h ... d -> h (...) d').contiguous() 62 | 63 | embed, cluster_size = utils.kmeans( 64 | data, 65 | self.codebook_size, 66 | self.kmeans_iters, 67 | sample_fn = self.sample_fn, 68 | all_reduce_fn = self.all_reduce_fn 69 | ) 70 | 71 | self.embedding.weight.data.copy_(embed.squeeze(0)) 72 | self.cluster_size.data.copy_(cluster_size) 73 | self.initted.data.copy_(torch.Tensor([True])) 74 | 75 | def forward(self, x, temp=0.): 76 | need_transpose = not self.channel_last and not self.accept_image_fmap 77 | 78 | if self.accept_image_fmap: 79 | height, width = x.shape[-2:] 80 | x = rearrange(x, 'b c h w -> b (h w) c').contiguous() 81 | 82 | if need_transpose: 83 | x = rearrange(x, 'b d n -> b n d').contiguous() 84 | shape, device, dtype = x.shape, x.device, x.dtype 85 | flatten = rearrange(x, 'h ... d -> h (...) d').contiguous() 86 | 87 | # if use cosine_sim, whether should norm the feature before k-means initialization ? 88 | # if self.use_cosine_sim: 89 | # flatten = F.normalize(flatten, p = 2, dim = -1) 90 | self.init_embed_(flatten) 91 | 92 | # calculate the distance 93 | if self.use_cosine_sim: # cosine similarity 94 | flatten_norm = F.normalize(flatten, p = 2, dim = -1) 95 | weight_norm = F.normalize(self.embedding.weight, p = 2, dim = -1).unsqueeze(0) 96 | 97 | dist = torch.matmul(flatten_norm, weight_norm.transpose(1, 2)) 98 | # dist = einsum('h n d, h c d -> h n c', flatten_norm, weight_norm) 99 | elif self.use_cosine_distance: # l2 normed distance, NOTE: experimental 100 | flatten_norm = F.normalize(flatten, p = 2, dim = -1).view(-1, self.codebook_dim) 101 | weight_norm = F.normalize(self.embedding.weight, p = 2, dim = -1) 102 | 103 | dist = - torch.sum(flatten_norm ** 2, dim=1, keepdim=True) - \ 104 | torch.sum(weight_norm**2, dim=1) + 2 * \ 105 | torch.einsum('bd,dn->bn', flatten_norm, rearrange(weight_norm, 'n d -> d n')) # more efficient, add "-" for argmax gumbel sample 106 | else: # L2 distance 107 | flatten = flatten.view(-1, self.codebook_dim) 108 | dist = - torch.sum(flatten ** 2, dim=1, keepdim=True) - \ 109 | torch.sum(self.embedding.weight**2, dim=1) + 2 * \ 110 | torch.einsum('bd,dn->bn', flatten, rearrange(self.embedding.weight, 'n d -> d n')) # more efficient, add "-" for argmax gumbel sample 111 | 112 | embed_ind = utils.gumbel_sample(dist, dim = -1, temperature = temp) 113 | # embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype) 114 | embed_ind = embed_ind.view(*shape[:-1]) 115 | 116 | x_q = self.embedding(embed_ind) 117 | 118 | # compute loss for embedding 119 | loss = self.beta * torch.mean((x_q.detach()-x)**2) + torch.mean((x_q - x.detach()) ** 2) 120 | 121 | # ortho reg term 122 | if self.orthogonal_reg_weight > 0. 
: 123 | # eq (2) from https://arxiv.org/abs/2112.00384 124 | emb_weight_after_norm = F.normalize(self.embedding.weight, p = 2, dim = -1) 125 | diff = torch.mm(emb_weight_after_norm, torch.transpose(emb_weight_after_norm, 0, 1)) - torch.eye(self.codebook_size, self.codebook_size).type_as(emb_weight_after_norm) 126 | ortho_reg_term = self.orthogonal_reg_weight * torch.sum(diff**2) / (diff.size(0)**2) 127 | 128 | # diff = torch.mm(self.embedding.weight, torch.transpose(self.embedding.weight, 0, 1)) - torch.eye(self.codebook_size, self.codebook_size).type_as(self.embedding.weight) 129 | # ortho_reg_term = self.orthogonal_reg_weight * torch.sum(diff**2) / (diff.size(0)**2) 130 | loss = loss + ortho_reg_term 131 | 132 | # preserve gradients 133 | x_q = x + (x_q - x).detach() 134 | 135 | if need_transpose: 136 | x_q = rearrange(x_q, 'b n d -> b d n').contiguous() 137 | 138 | if self.accept_image_fmap: 139 | x_q = rearrange(x_q, 'b (h w) c -> b c h w', h = height, w = width).contiguous() 140 | embed_ind = rearrange(embed_ind, 'b (h w) ... -> b h w ...', h = height, w = width).contiguous() 141 | 142 | return x_q, loss, (None, None, embed_ind) 143 | 144 | def get_codebook_entry(self, indices, *kwargs): 145 | # get quantized latent vectors 146 | z_q = self.embedding(indices) # (batch, height, width, channel) 147 | return z_q 148 | 149 | 150 | if __name__ == "__main__": 151 | quantizer = VectorQuantize( 152 | codebook_size = 1024, 153 | codebook_dim = 512, 154 | kmeans_init = True, 155 | kmeans_iters = 10, 156 | use_cosine_sim = False, 157 | use_cosine_distance = True, 158 | channel_last = False, 159 | accept_image_fmap = False, 160 | commitment_beta = 0.25, 161 | orthogonal_reg_weight = 10., 162 | ) 163 | 164 | # x = torch.randn(10, 512, 16, 16) 165 | x = torch.randn(10, 512, 120) 166 | 167 | x_q, loss, (_, _, embed_ind) = quantizer(x, 0.) 
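    # Expected shapes for this smoke test: with accept_image_fmap=False and
    # channel_last=False the input is read as (batch, dim, length) = (10, 512, 120),
    # so x_q comes back with the same (10, 512, 120) shape (gradients preserved via
    # the straight-through estimator) and embed_ind holds (10, 120) codebook indices.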
168 | print(loss) 169 | -------------------------------------------------------------------------------- /modules/vector_quantization/quantize2_list.py: -------------------------------------------------------------------------------- 1 | # add the random restart of rqvae 2 | import numpy as np 3 | import torch 4 | import torch.distributed as dist 5 | from torch import nn 6 | from torch.nn import functional as F 7 | from einops import rearrange 8 | 9 | 10 | class VQEmbedding(nn.Embedding): 11 | """VQ embedding module with ema update.""" 12 | 13 | def __init__(self, n_embed, embed_dim, ema=True, decay=0.99, restart_unused_codes=True, eps=1e-5): 14 | super().__init__(n_embed + 1, embed_dim, padding_idx=n_embed) 15 | 16 | self.ema = ema 17 | self.decay = decay 18 | self.eps = eps 19 | self.restart_unused_codes = restart_unused_codes 20 | self.n_embed = n_embed 21 | 22 | if self.ema: 23 | _ = [p.requires_grad_(False) for p in self.parameters()] 24 | 25 | # padding index is not updated by EMA 26 | self.register_buffer('cluster_size_ema', torch.zeros(n_embed)) 27 | self.register_buffer('embed_ema', self.weight[:-1, :].detach().clone()) 28 | 29 | @torch.no_grad() 30 | def compute_distances(self, inputs): 31 | codebook_t = self.weight[:-1, :].t() 32 | 33 | (embed_dim, _) = codebook_t.shape 34 | inputs_shape = inputs.shape 35 | assert inputs_shape[-1] == embed_dim 36 | 37 | inputs_flat = inputs.reshape(-1, embed_dim) 38 | 39 | inputs_norm_sq = inputs_flat.pow(2.).sum(dim=1, keepdim=True) 40 | codebook_t_norm_sq = codebook_t.pow(2.).sum(dim=0, keepdim=True) 41 | distances = torch.addmm( 42 | inputs_norm_sq + codebook_t_norm_sq, 43 | inputs_flat, 44 | codebook_t, 45 | alpha=-2.0, 46 | ) 47 | distances = distances.reshape(*inputs_shape[:-1], -1) # [B, h, w, n_embed or n_embed+1] 48 | return distances 49 | 50 | @torch.no_grad() 51 | def find_nearest_embedding(self, inputs): 52 | distances = self.compute_distances(inputs) # [B, h, w, n_embed or n_embed+1] 53 | embed_idxs = distances.argmin(dim=-1) # use padding index or not 54 | 55 | return embed_idxs 56 | 57 | @torch.no_grad() 58 | def _tile_with_noise(self, x, target_n): 59 | B, embed_dim = x.shape 60 | n_repeats = (target_n + B -1) // B 61 | std = x.new_ones(embed_dim) * 0.01 / np.sqrt(embed_dim) 62 | x = x.repeat(n_repeats, 1) 63 | x = x + torch.rand_like(x) * std 64 | return x 65 | 66 | @torch.no_grad() 67 | def _update_buffers(self, vectors, idxs): 68 | 69 | n_embed, embed_dim = self.weight.shape[0]-1, self.weight.shape[-1] 70 | 71 | vectors = vectors.reshape(-1, embed_dim) 72 | idxs = idxs.reshape(-1) 73 | 74 | n_vectors = vectors.shape[0] 75 | n_total_embed = n_embed 76 | 77 | one_hot_idxs = vectors.new_zeros(n_total_embed, n_vectors) 78 | one_hot_idxs.scatter_(dim=0, 79 | index=idxs.unsqueeze(0), 80 | src=vectors.new_ones(1, n_vectors) 81 | ) 82 | 83 | cluster_size = one_hot_idxs.sum(dim=1) 84 | vectors_sum_per_cluster = one_hot_idxs @ vectors 85 | 86 | if dist.is_initialized(): 87 | dist.all_reduce(vectors_sum_per_cluster, op=dist.ReduceOp.SUM) 88 | dist.all_reduce(cluster_size, op=dist.ReduceOp.SUM) 89 | 90 | self.cluster_size_ema.mul_(self.decay).add_(cluster_size, alpha=1 - self.decay) 91 | self.embed_ema.mul_(self.decay).add_(vectors_sum_per_cluster, alpha=1 - self.decay) 92 | 93 | if self.restart_unused_codes: 94 | if n_vectors < n_embed: 95 | vectors = self._tile_with_noise(vectors, n_embed) 96 | n_vectors = vectors.shape[0] 97 | _vectors_random = vectors[torch.randperm(n_vectors, device=vectors.device)][:n_embed] 98 | 99 | if 
dist.is_initialized(): 100 | dist.broadcast(_vectors_random, 0) 101 | 102 | usage = (self.cluster_size_ema.view(-1, 1) >= 1).float() 103 | self.embed_ema.mul_(usage).add_(_vectors_random * (1-usage)) 104 | self.cluster_size_ema.mul_(usage.view(-1)) 105 | self.cluster_size_ema.add_(torch.ones_like(self.cluster_size_ema) * (1-usage).view(-1)) 106 | 107 | @torch.no_grad() 108 | def _update_embedding(self): 109 | 110 | n_embed = self.weight.shape[0] - 1 111 | n = self.cluster_size_ema.sum() 112 | normalized_cluster_size = ( 113 | n * (self.cluster_size_ema + self.eps) / (n + n_embed * self.eps) 114 | ) 115 | self.weight[:-1, :] = self.embed_ema / normalized_cluster_size.reshape(-1, 1) 116 | 117 | def forward(self, inputs): 118 | embed_idxs = self.find_nearest_embedding(inputs) 119 | if self.training: 120 | if self.ema: 121 | self._update_buffers(inputs, embed_idxs) 122 | 123 | embeds = self.embed(embed_idxs) 124 | 125 | if self.ema and self.training: 126 | self._update_embedding() 127 | 128 | return embeds, embed_idxs 129 | 130 | def embed(self, idxs): 131 | embeds = super().forward(idxs) 132 | return embeds 133 | 134 | # simplified version with random restart unused, accept list features input 135 | class VectorQuantize2(nn.Module): 136 | def __init__(self, 137 | codebook_size, 138 | codebook_dim = None, 139 | commitment_beta = 0.25, 140 | decay = 0.99, 141 | restart_unused_codes = True, 142 | ): 143 | super().__init__() 144 | self.beta = commitment_beta 145 | self.restart_unused_codes = restart_unused_codes 146 | self.codebook = VQEmbedding(codebook_size, 147 | codebook_dim, 148 | decay = decay, 149 | restart_unused_codes = restart_unused_codes, 150 | ) 151 | self.codebook.weight.data.uniform_(-1.0 / codebook_size, 1.0 / codebook_size) 152 | 153 | def forward(self, x_list, *ignorewargs, **ignorekwargs): 154 | batch_size = len(x_list) 155 | 156 | x_q_list, x_code_list = [], [] 157 | loss = 0. 
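        # Each entry of x_list may contain a different number of feature vectors,
        # so the samples are quantized one by one against the shared EMA codebook;
        # the commitment loss accumulated below is averaged over the batch and the
        # straight-through trick keeps gradients flowing to the encoder features.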
158 | for i in range(batch_size): 159 | x_q_i, x_code_i = self.codebook(x_list[i]) 160 | 161 | loss += self.beta * torch.mean((x_q_i.detach()-x_list[i])**2) + torch.mean((x_q_i - x_list[i].detach()) ** 2) 162 | 163 | # preserve gradients 164 | x_q_i = x_list[i] + (x_q_i - x_list[i]).detach() 165 | 166 | x_q_list.append(x_q_i) 167 | x_code_list.append(x_code_i) 168 | 169 | loss /= batch_size 170 | return x_q_list, loss, (None, None, x_code_list) 171 | 172 | @torch.no_grad() 173 | def get_soft_codes(self, x, temp=1.0, stochastic=False): 174 | distances = self.codebook.compute_distances(x) 175 | soft_code = F.softmax(-distances / temp, dim=-1) 176 | 177 | if stochastic: 178 | soft_code_flat = soft_code.reshape(-1, soft_code.shape[-1]) 179 | code = torch.multinomial(soft_code_flat, 1) 180 | code = code.reshape(*soft_code.shape[:-1]) 181 | else: 182 | code = distances.argmin(dim=-1) 183 | 184 | return soft_code, code 185 | 186 | def get_codebook_entry(self, indices, *kwargs): 187 | # get quantized latent vectors 188 | z_q = self.codebook.embed(indices) # (batch, height, width, channel) 189 | return z_q -------------------------------------------------------------------------------- /modules/vector_quantization/quantize_codebook_mask.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch import einsum 4 | import torch.nn.functional as F 5 | import numpy as np 6 | from einops import rearrange, repeat 7 | import time 8 | 9 | import os, sys 10 | sys.path.append(os.getcwd()) 11 | 12 | import modules.vector_quantization.common_utils as utils 13 | 14 | # support: cosine-similarity; kmeans centroids initialization; orthogonal_reg_weight 15 | class MaskVectorQuantize(nn.Module): 16 | def __init__( 17 | self, 18 | codebook_size, 19 | codebook_dim = None, 20 | kmeans_init = False, 21 | kmeans_iters = 10, 22 | use_cosine_sim = False, 23 | channel_last = False, 24 | accept_image_fmap = True, 25 | commitment_beta = 0.25, 26 | orthogonal_reg_weight = 0., 27 | activate_mask_quantize = True, 28 | ): 29 | super().__init__() 30 | 31 | self.codebook_size = codebook_size 32 | self.codebook_dim = codebook_dim 33 | self.accept_image_fmap = accept_image_fmap 34 | self.channel_last = channel_last 35 | self.use_cosine_sim = use_cosine_sim 36 | self.beta = commitment_beta 37 | 38 | self.embedding = nn.Embedding(self.codebook_size, self.codebook_dim) 39 | 40 | # codebook initialization 41 | if not kmeans_init: 42 | self.embedding.weight.data.uniform_(-1.0 / self.codebook_size, 1.0 / self.codebook_size) 43 | else: 44 | self.embedding.weight.data.zero_() 45 | self.kmeans_iters = kmeans_iters 46 | self.register_buffer('initted', torch.Tensor([not kmeans_init])) 47 | self.register_buffer('cluster_size', torch.zeros(1, codebook_size)) 48 | 49 | self.sample_fn = utils.batched_sample_vectors 50 | self.all_reduce_fn = utils.noop 51 | 52 | # codebook_orthogonal_loss 53 | self.orthogonal_reg_weight = orthogonal_reg_weight 54 | 55 | # activate mask quantization 56 | self.activate_mask_quantize = activate_mask_quantize 57 | 58 | def init_embed_(self, data): 59 | if self.initted: 60 | return 61 | 62 | data = rearrange(data, '... -> 1 ...').contiguous() 63 | data = rearrange(data, 'h ... d -> h (...) 
d').contiguous() 64 | 65 | embed, cluster_size = utils.kmeans( 66 | data, 67 | self.codebook_size, 68 | self.kmeans_iters, 69 | sample_fn = self.sample_fn, 70 | all_reduce_fn = self.all_reduce_fn 71 | ) 72 | 73 | self.embedding.weight.data.copy_(embed.squeeze(0)) 74 | self.cluster_size.data.copy_(cluster_size) 75 | self.initted.data.copy_(torch.Tensor([True])) 76 | 77 | def forward(self, x, temp=0., codebook_mask=None): 78 | need_transpose = not self.channel_last and not self.accept_image_fmap 79 | 80 | if self.accept_image_fmap: 81 | height, width = x.shape[-2:] 82 | x = rearrange(x, 'b c h w -> b (h w) c').contiguous() 83 | 84 | if codebook_mask is not None and self.activate_mask_quantize: 85 | codebook_mask = rearrange(codebook_mask, "b c h w -> b (h w) c").contiguous() 86 | 87 | if need_transpose: 88 | x = rearrange(x, 'b d n -> b n d').contiguous() 89 | shape, device, dtype = x.shape, x.device, x.dtype 90 | flatten = rearrange(x, 'h ... d -> h (...) d').contiguous() 91 | 92 | # if use cosine_sim, whether should norm the feature before k-means initialization ? 93 | # if self.use_cosine_sim: 94 | # flatten = F.normalize(flatten, p = 2, dim = -1) 95 | self.init_embed_(flatten) 96 | 97 | # calculate the distance 98 | if self.use_cosine_sim: # cosine similarity 99 | flatten_norm = F.normalize(flatten, p = 2, dim = -1) 100 | weight_norm = F.normalize(self.embedding.weight, p = 2, dim = -1).unsqueeze(0) 101 | 102 | # compute inner product 103 | dist = einsum('h n d, h c d -> h n c', flatten_norm, weight_norm) 104 | else: # L2 distance 105 | flatten = flatten.view(-1, self.codebook_dim) 106 | dist = - torch.sum(flatten ** 2, dim=1, keepdim=True) - \ 107 | torch.sum(self.embedding.weight**2, dim=1) + 2 * \ 108 | torch.einsum('bd,dn->bn', flatten, rearrange(self.embedding.weight, 'n d -> d n')) # more efficient, add "-" for argmax gumbel sample 109 | 110 | embed_ind = utils.gumbel_sample(dist, dim = -1, temperature = temp) 111 | # embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype) 112 | embed_ind = embed_ind.view(*shape[:-1]) 113 | 114 | x_q = self.embedding(embed_ind) 115 | 116 | # compute loss for embedding 117 | if codebook_mask is not None and self.activate_mask_quantize: 118 | ratio = 1 / torch.mean(codebook_mask) 119 | loss = ratio * self.beta * torch.mean((x_q.detach()-x) ** 2 * codebook_mask) + ratio * torch.mean((x_q - x.detach()) ** 2 * codebook_mask) 120 | else: 121 | loss = self.beta * torch.mean((x_q.detach()-x)**2) + torch.mean((x_q - x.detach()) ** 2) 122 | 123 | # ortho reg term 124 | if self.orthogonal_reg_weight > 0. 
:
125 |             # eq (2) from https://arxiv.org/abs/2112.00384
126 |             emb_weight_after_norm = F.normalize(self.embedding.weight, p = 2, dim = -1)
127 |             diff = torch.mm(emb_weight_after_norm, torch.transpose(emb_weight_after_norm, 0, 1)) - torch.eye(self.codebook_size, self.codebook_size).type_as(emb_weight_after_norm)
128 |             ortho_reg_term = self.orthogonal_reg_weight * torch.sum(diff**2) / (diff.size(0)**2)
129 | 
130 |             # diff = torch.mm(self.embedding.weight, torch.transpose(self.embedding.weight, 0, 1)) - torch.eye(self.codebook_size, self.codebook_size).type_as(self.embedding.weight)
131 |             # ortho_reg_term = self.orthogonal_reg_weight * torch.sum(diff**2) / (diff.size(0)**2)
132 |             loss = loss + ortho_reg_term
133 | 
134 |         # preserve gradients
135 |         x_q = x + (x_q - x).detach()
136 | 
137 |         if need_transpose:
138 |             x_q = rearrange(x_q, 'b n d -> b d n').contiguous()
139 | 
140 |         if self.accept_image_fmap:
141 |             x_q = rearrange(x_q, 'b (h w) c -> b c h w', h = height, w = width).contiguous()
142 |             embed_ind = rearrange(embed_ind, 'b (h w) ... -> b h w ...', h = height, w = width).contiguous()
143 | 
144 |         return x_q, loss, (None, None, embed_ind)
145 | 
146 |     def get_codebook_entry(self, indices, shape, *kwargs):
147 |         # get quantized latent vectors
148 |         z_q = self.embedding(indices) # (batch, height, width, channel)
149 |         if shape is not None:
150 |             z_q = z_q.view(shape)
151 |             # reshape back to match original input shape
152 |             z_q = z_q.permute(0, 3, 1, 2).contiguous()
153 |         return z_q
154 | 
155 |     @torch.no_grad()
156 |     def embed_code_with_depth(self, code, to_latent_shape=False):
157 |         code_slices = torch.chunk(code, chunks=code.shape[-1], dim=-1)
158 | 
159 |         embeds = [self.embedding(code_slice) for i, code_slice in enumerate(code_slices)]
160 | 
161 |         if to_latent_shape:
162 |             embeds = [self.to_latent_shape(embed.squeeze(-2)).unsqueeze(-2) for embed in embeds]
163 |         embeds = torch.cat(embeds, dim=-2)
164 | 
165 |         return embeds, None
166 | 
167 | if __name__ == "__main__":
168 |     # quantizer = VectorQuantize(
169 |     #     codebook_size = 1024,
170 |     #     codebook_dim = 512,
171 |     #     kmeans_init = True,
172 |     #     kmeans_iters = 10,
173 |     #     use_cosine_sim = False,
174 |     #     channel_last = False,
175 |     #     accept_image_fmap = False,
176 |     #     commitment_beta = 0.25,
177 |     #     orthogonal_reg_weight = 10.,
178 |     #     use_ddp = False,
179 |     # )
180 | 
181 |     # # x = torch.randn(10, 512, 16, 16)
182 |     # x = torch.randn(10, 512, 120)
183 | 
184 |     # x_q, loss, (_, _, embed_ind) = quantizer(x, 0.)
185 |     # print(loss)
186 |     pass
--------------------------------------------------------------------------------
/modules/vqvae/quantize2.py:
--------------------------------------------------------------------------------
1 | # NOTE: this version changes how the input z is passed into the VQ-VAE so that it fits our mask mechanism
2 | 
3 | import torch
4 | import torch.nn as nn
5 | import numpy as np
6 | from einops import rearrange
7 | 
8 | class VectorQuantizer2(nn.Module):
9 |     """
10 |     Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
11 |     avoids costly matrix multiplications and allows for post-hoc remapping of indices.
12 |     """
13 |     # NOTE: due to a bug the beta term was applied to the wrong term. for
14 |     # backwards compatibility we use the buggy version by default, but you can
15 |     # specify legacy=False to fix it.
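    # In symbols (sg = stop-gradient): legacy=False gives
    #     loss = beta * mean((sg[z_q] - z)^2) + mean((z_q - sg[z])^2),
    # i.e. beta weights the commitment term that trains the encoder, while the
    # default legacy=True keeps the historical behaviour and puts beta on the
    # codebook term instead:
    #     loss = mean((sg[z_q] - z)^2) + beta * mean((z_q - sg[z])^2).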
16 | def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random", 17 | legacy=True): 18 | super().__init__() 19 | self.n_e = n_e 20 | self.e_dim = e_dim 21 | self.beta = beta 22 | self.legacy = legacy 23 | 24 | self.embedding = nn.Embedding(self.n_e, self.e_dim) 25 | self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) 26 | 27 | self.remap = remap 28 | if self.remap is not None: 29 | self.register_buffer("used", torch.tensor(np.load(self.remap))) 30 | self.re_embed = self.used.shape[0] 31 | self.unknown_index = unknown_index # "random" or "extra" or integer 32 | if self.unknown_index == "extra": 33 | self.unknown_index = self.re_embed 34 | self.re_embed = self.re_embed+1 35 | print(f"Remapping {self.n_e} indices to {self.re_embed} indices. " 36 | f"Using {self.unknown_index} for unknown indices.") 37 | else: 38 | self.re_embed = n_e 39 | 40 | def remap_to_used(self, inds): 41 | ishape = inds.shape 42 | assert len(ishape)>1 43 | inds = inds.reshape(ishape[0],-1) 44 | used = self.used.to(inds) 45 | match = (inds[:,:,None]==used[None,None,...]).long() 46 | new = match.argmax(-1) 47 | unknown = match.sum(2)<1 48 | if self.unknown_index == "random": 49 | new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device) 50 | else: 51 | new[unknown] = self.unknown_index 52 | return new.reshape(ishape) 53 | 54 | def unmap_to_all(self, inds): 55 | ishape = inds.shape 56 | assert len(ishape)>1 57 | inds = inds.reshape(ishape[0],-1) 58 | used = self.used.to(inds) 59 | if self.re_embed > self.used.shape[0]: # extra token 60 | inds[inds>=self.used.shape[0]] = 0 # simply set to zero 61 | back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds) 62 | return back.reshape(ishape) 63 | 64 | def forward(self, z, temp=None, rescale_logits=False, return_logits=False): 65 | assert temp is None or temp==1.0, "Only for interface compatible with Gumbel" 66 | assert rescale_logits==False, "Only for interface compatible with Gumbel" 67 | assert return_logits==False, "Only for interface compatible with Gumbel" 68 | # reshape z -> (batch, length, channel) and flatten 69 | z = rearrange(z, 'b c l -> b l c').contiguous() 70 | z_flattened = z.view(-1, self.e_dim) 71 | # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z 72 | 73 | d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \ 74 | torch.sum(self.embedding.weight**2, dim=1) - 2 * \ 75 | torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n')) 76 | 77 | min_encoding_indices = torch.argmin(d, dim=1) 78 | z_q = self.embedding(min_encoding_indices).view(z.shape) 79 | perplexity = None 80 | min_encodings = None 81 | 82 | # compute loss for embedding 83 | if not self.legacy: 84 | loss = self.beta * torch.mean((z_q.detach()-z)**2) + \ 85 | torch.mean((z_q - z.detach()) ** 2) 86 | else: 87 | loss = torch.mean((z_q.detach()-z)**2) + self.beta * \ 88 | torch.mean((z_q - z.detach()) ** 2) 89 | 90 | # preserve gradients 91 | z_q = z + (z_q - z).detach() 92 | 93 | # reshape back to match original input shape 94 | z_q = rearrange(z_q, 'b l c -> b c l').contiguous() 95 | 96 | if self.remap is not None: 97 | min_encoding_indices = min_encoding_indices.reshape(z.shape[0],-1) # add batch axis 98 | min_encoding_indices = self.remap_to_used(min_encoding_indices) 99 | min_encoding_indices = min_encoding_indices.reshape(-1,1) # flatten 100 | 101 | return z_q, loss, (perplexity, min_encodings, min_encoding_indices) 102 | 103 | def get_codebook_entry(self, indices, shape): 104 | # 
shape specifying (batch, length, channel)
105 |         if self.remap is not None:
106 |             indices = indices.reshape(shape[0],-1) # add batch axis
107 |             indices = self.unmap_to_all(indices)
108 |             indices = indices.reshape(-1) # flatten again
109 | 
110 |         # get quantized latent vectors
111 |         z_q = self.embedding(indices)
112 | 
113 |         if shape is not None:
114 |             z_q = z_q.view(shape)
115 |             # reshape back to match original input shape
116 |             z_q = z_q.permute(0, 2, 1).contiguous()
117 | 
118 |         return z_q
--------------------------------------------------------------------------------
/utils/logger.py:
--------------------------------------------------------------------------------
1 | # for pytorch_lightning ModelCheckpoint, Callback, LearningRateMonitor, ... modules
2 | import os
3 | import wandb
4 | from omegaconf import OmegaConf
5 | import numpy as np
6 | import pytorch_lightning as pl
7 | from pytorch_lightning.callbacks import ModelCheckpoint, Callback, LearningRateMonitor
8 | from pytorch_lightning.utilities.distributed import rank_zero_only
9 | import torch
10 | import torchvision
11 | from PIL import Image
12 | 
13 | class SetupCallback(Callback):
14 |     def __init__(self, resume, now, logdir, ckptdir, cfgdir, config, lightning_config, argv_content=None):
15 |         super().__init__()
16 |         self.resume = resume
17 |         self.now = now
18 |         self.logdir = logdir
19 |         self.ckptdir = ckptdir
20 |         self.cfgdir = cfgdir
21 |         self.config = config
22 |         self.lightning_config = lightning_config
23 | 
24 |         self.argv_content = argv_content
25 | 
26 |     # Called at the start of the pretrain routine.
27 |     def on_pretrain_routine_start(self, trainer, pl_module):
28 |         if trainer.global_rank == 0:
29 |             # Create logdirs and save configs
30 |             os.makedirs(self.logdir, exist_ok=True)
31 |             os.makedirs(self.ckptdir, exist_ok=True)
32 |             os.makedirs(self.cfgdir, exist_ok=True)
33 | 
34 |             print("Project config")
35 |             print(OmegaConf.to_yaml(self.config))
36 |             OmegaConf.save(self.config,
37 |                            os.path.join(self.cfgdir, "{}-project.yaml".format(self.now)))
38 | 
39 |             print("Lightning config")
40 |             print(OmegaConf.to_yaml(self.lightning_config))
41 |             OmegaConf.save(OmegaConf.create({"lightning": self.lightning_config}),
42 |                            os.path.join(self.cfgdir, "{}-lightning.yaml".format(self.now)))
43 | 
44 |             with open(os.path.join(self.logdir, "argv_content.txt"), "w") as f:
45 |                 f.write(str(self.argv_content))
46 |         else:
47 |             # ModelCheckpoint callback created log directory --- remove it
48 |             if not self.resume and os.path.exists(self.logdir):
49 |                 dst, name = os.path.split(self.logdir)
50 |                 dst = os.path.join(dst, "child_runs", name)
51 |                 os.makedirs(os.path.split(dst)[0], exist_ok=True)
52 |                 try:
53 |                     os.rename(self.logdir, dst)
54 |                 except FileNotFoundError:
55 |                     pass
56 | 
57 | class CaptionImageLogger(Callback):
58 |     def __init__(self, batch_frequency, max_images, clamp, type="wandb"):
59 |         self.batch_freq = batch_frequency
60 |         self.max_images = max_images
61 |         self.clamp = clamp
62 |         self.logger_log_images = {
63 |             pl.loggers.WandbLogger: self._wandb,
64 |             pl.loggers.TensorBoardLogger: self._tensorboard,
65 |         }
66 |         self.type = type # wandb or tensorboard
67 | 
68 | 
69 |     def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, *kwargs):
70 |         self.log_img(pl_module, batch, batch_idx, split="train")
71 | 
72 |     def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, *kwargs):
73 |         self.log_img(pl_module, batch, batch_idx, split="val")
74 | 
75 |     @rank_zero_only
76 |     def _wandb(self, pl_module, images, batch_idx, split):
77 |         grids = dict()
78 |         for k in images:
79 |             grid = torchvision.utils.make_grid(images[k], normalize=True)
80 |             grids[f"{split}/{k}"] = wandb.Image(grid)
81 |         pl_module.logger.experiment.log(grids, commit=False)
82 | 
83 |     @rank_zero_only
84 |     def _tensorboard(self, pl_module, images, batch_idx, split):
85 |         for k in images:
86 |             grid = torchvision.utils.make_grid(images[k], nrow=4, normalize=True)
87 |             tag = f"{split}/{k}"
88 |             pl_module.logger.experiment.add_image(tag, grid, global_step=pl_module.global_step)
89 | 
90 |     @rank_zero_only
91 |     def log_img(self, pl_module, batch, batch_idx, split="train"):
92 |         if (batch_idx % self.batch_freq == 0) and hasattr(pl_module, "log_images") and callable(pl_module.log_images) and (self.max_images > 0):
93 |             logger = type(pl_module.logger)
94 |             is_train = pl_module.training
95 |             if is_train:
96 |                 pl_module.eval()
97 |             with torch.no_grad():
98 |                 images = pl_module.log_images(batch, split=split)
99 | 
100 |             # NOTE: the paths on the cluster always have bugs!
101 |             if "groundtruth_captions" in images:
102 |                 # if self.type == "wandb":
103 |                 #     pl_module.logger.log_text(key="samples_{}".format(pl_module.global_step), columns=["{}_groundtruth_captions".format(split)], data=images['groundtruth_captions'])
104 |                 # else:
105 |                 #     pl_module.logger.experiment.add_text("{}_groundtruth_captions".format(split), str(images['groundtruth_captions']), global_step=pl_module.global_step)
106 |                 del images['groundtruth_captions']
107 | 
108 |             if "dest_captions" in images:
109 |                 # if self.type == "wandb":
110 |                 #     pl_module.logger.log_text(key="samples_{}".format(pl_module.global_step), columns=["{}_dest_captions".format(split)], data=images['dest_captions'])
111 |                 # else:
112 |                 #     pl_module.logger.experiment.add_text("{}_dest_captions".format(split), str(images['dest_captions']), global_step=pl_module.global_step)
113 |                 del images['dest_captions']
114 | 
115 |             if "sample_captions" in images:
116 |                 # if self.type == "wandb":
117 |                 #     pl_module.logger.log_text(key="samples_{}".format(pl_module.global_step), columns=["{}_sample_captions".format(split)], data=images['sample_captions'])
118 |                 # else:
119 |                 #     pl_module.logger.experiment.add_text("{}_sample_captions".format(split), str(images['sample_captions']), global_step=pl_module.global_step)
120 |                 del images['sample_captions']
121 | 
122 |             for k in images:
123 |                 N = min(images[k].shape[0], self.max_images)
124 |                 images[k] = images[k][:N]
125 |                 if isinstance(images[k], torch.Tensor):
126 |                     images[k] = images[k].detach().cpu()
127 |                     if self.clamp:
128 |                         images[k] = torch.clamp(images[k], -1., 1.)
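            # log_local below writes the clamped grids to <save_dir>/images/<split>/
            # as PNG files; the images are then forwarded to the wandb or tensorboard
            # hook that matches the active logger type (and silently skipped for any
            # other logger).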
129 | self.log_local(pl_module.logger.save_dir, split, images, pl_module.global_step, pl_module.current_epoch, batch_idx) 130 | logger_log_images = self.logger_log_images.get(logger, lambda *args, **kwargs: None) 131 | logger_log_images(pl_module, images, pl_module.global_step, split) 132 | 133 | if is_train: 134 | pl_module.train() 135 | 136 | @rank_zero_only 137 | def log_local(self, save_dir, split, images, global_step, current_epoch, batch_idx): 138 | root = os.path.join(save_dir, "images", split) 139 | for k in images: 140 | grid = torchvision.utils.make_grid(images[k], nrow=4, normalize=True) 141 | grid = grid.transpose(0,1).transpose(1,2).squeeze(-1) 142 | grid = grid.numpy() 143 | grid = (grid*255).astype(np.uint8) 144 | filename = "Step_{:06}-Epoch_{:03}-Batch_{:06}-{}.png".format(global_step,current_epoch,batch_idx,k) 145 | path = os.path.join(root, filename) 146 | os.makedirs(os.path.split(path)[0], exist_ok=True) 147 | Image.fromarray(grid).save(path) 148 | 149 | 150 | if __name__ == "__main__": 151 | pass --------------------------------------------------------------------------------
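For reference, here is a minimal sketch of how SetupCallback and CaptionImageLogger could be wired into a PyTorch Lightning trainer. It is an illustration only: the run-directory layout, the chosen config file, the assumption that it contains a `lightning` section, the logging frequency, and the commented-out `model` / `datamodule` objects are placeholders rather than the repository's actual train.py logic.

# Hypothetical wiring sketch; paths, hyper-parameters and the model/datamodule
# objects are placeholders, not the project's real entry point.
import datetime
import sys

import pytorch_lightning as pl
from omegaconf import OmegaConf
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint

from utils.logger import CaptionImageLogger, SetupCallback

now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
logdir = "logs/{}_example_run".format(now)        # assumed run directory
ckptdir = "{}/checkpoints".format(logdir)
cfgdir = "{}/configs".format(logdir)

config = OmegaConf.load("configs/stage1/mqvae_imagenet_f8_r30.yml")
lightning_config = config.pop("lightning", OmegaConf.create())  # assumed key

callbacks = [
    SetupCallback(resume=False, now=now, logdir=logdir, ckptdir=ckptdir,
                  cfgdir=cfgdir, config=config, lightning_config=lightning_config,
                  argv_content=sys.argv),
    CaptionImageLogger(batch_frequency=1000, max_images=8, clamp=True,
                       type="tensorboard"),
    ModelCheckpoint(dirpath=ckptdir, save_last=True),
    LearningRateMonitor(logging_interval="step"),
]

# trainer = pl.Trainer(
#     callbacks=callbacks,
#     logger=pl.loggers.TensorBoardLogger(save_dir=logdir),
# )
# trainer.fit(model, datamodule)   # model / datamodule are built elsewhere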