├── .gitignore ├── LICENSE ├── README.md ├── assets ├── dynamic_framework.png └── dynamic_visual2.png ├── configs ├── stage1 │ ├── dqvae-dual-r-05_imagenet.yml │ ├── dqvae-entropy-dual-r05_imagenet.yml │ └── dqvae-triple-r-03-03_imagenet.yml └── stage2 │ ├── class_imagenet_p6c18.yml │ └── uncond_imagenet_p6c18.yml ├── data ├── build.py ├── data_utils.py ├── default.py ├── faceshq.py ├── ffhq_lmdb.py ├── imagenet.py └── imagenet_base.py ├── environment.yml ├── models ├── stage1 │ ├── rqvae.py │ ├── utils.py │ ├── vqgan.py │ └── vqgan_multivq.py ├── stage1_dynamic │ ├── dqvae_dual_entropy.py │ ├── dqvae_dual_feat.py │ └── dqvae_triple_feat.py ├── stage2 │ ├── class_transformer.py │ ├── text2image_transformer.py │ ├── text2image_transformer2.py │ ├── uncond_rqtransformer.py │ ├── uncond_transformer.py │ └── utils.py └── stage2_dynamic │ ├── dqtransformer_class.py │ ├── dqtransformer_class2_entropy.py │ ├── dqtransformer_t2i.py │ └── dqtransformer_uncond_entropy.py ├── modules ├── diffusionmodules │ ├── attn_model.py │ └── model.py ├── discriminator │ ├── model.py │ ├── stylegan.py │ └── stylegan_lucidrains.py ├── dynamic_modules │ ├── Decoder.py │ ├── DecoderPositional.py │ ├── EncoderDual.py │ ├── EncoderTriple.py │ ├── RouterDual.py │ ├── RouterTriple.py │ ├── budget.py │ ├── fourier_embedding.py │ ├── label_provider.py │ ├── permuter.py │ ├── stackgpt.py │ ├── tools.py │ └── utils.py ├── losses │ ├── lpips.py │ ├── vqperceptual.py │ ├── vqperceptual_budget.py │ ├── vqperceptual_epoch.py │ └── vqperceptual_multidisc.py ├── lpips │ └── vgg.pth ├── scheduler │ ├── lr_scheduler.py │ └── scheduler.py ├── text_encoders │ ├── clip_text_encoder │ │ ├── base_embedding.py │ │ ├── clip │ │ │ ├── README.md │ │ │ ├── clip.py │ │ │ ├── clip_tokenizer.py │ │ │ ├── model.py │ │ │ └── simple_tokenizer.py │ │ ├── clip_text_embedding.py │ │ └── my_tokenizer │ │ │ ├── base_codec.py │ │ │ └── my_tokenize.py │ ├── modules.py │ └── x_transformers.py ├── transformer │ ├── hybrid_decoders.py │ ├── mask_attention.py │ ├── mask_attention_decoders.py │ ├── mingpt.py │ ├── mingpt_t2i.py │ ├── modules.py │ ├── permuter.py │ ├── position_aware_mingpt.py │ ├── position_embeddings.py │ ├── stacked_mingpt.py │ ├── vit.py │ └── vit_modules.py ├── vector_quantization │ ├── common_utils.py │ ├── quantize.py │ ├── quantize2.py │ ├── quantize2_list.py │ ├── quantize2_mask.py │ ├── quantize_codebook_mask.py │ ├── quantize_lucidrains.py │ ├── quantize_rqvae.py │ └── quantize_vqgan.py └── vqvae │ └── quantize2.py ├── scripts ├── sample_images │ └── sample_dynamic_uncond.py ├── sample_val │ └── sample_dynamic_uncond.py └── tools │ ├── calculate_entropy_thresholds.py │ ├── codebook_pca.py │ ├── codebook_usage_dqvae.py │ ├── thresholds │ ├── entropy_thresholds_ffhq_train_patch-16.json │ ├── entropy_thresholds_imagenet_train_patch-16.json │ └── entropy_thresholds_imagenet_val_patch-16.json │ └── visualize_dual_grain.py ├── train.py └── utils ├── logger.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files 
are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | 132 | /logs/ 133 | /temps/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 CrossmodalGroup 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /assets/dynamic_framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CrossmodalGroup/DynamicVectorQuantization/be6dc36c2cc0f238acff76d19ae2e26d0b98788d/assets/dynamic_framework.png -------------------------------------------------------------------------------- /assets/dynamic_visual2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CrossmodalGroup/DynamicVectorQuantization/be6dc36c2cc0f238acff76d19ae2e26d0b98788d/assets/dynamic_visual2.png -------------------------------------------------------------------------------- /configs/stage1/dqvae-dual-r-05_imagenet.yml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: models.stage1_dynamic.dqvae_dual_feat.DualGrainVQModel 4 | params: 5 | encoderconfig: 6 | target: modules.dynamic_modules.EncoderDual.DualGrainEncoder 7 | params: 8 | ch: 128 9 | ch_mult: [1,1,2,2,4] 10 | num_res_blocks: 2 11 | attn_resolutions: [16, 32] 12 | dropout: 0.0 13 | resamp_with_conv: true 14 | in_channels: 3 15 | resolution: 256 16 | z_channels: 256 17 | router_config: 18 | target: modules.dynamic_modules.RouterDual.DualGrainFeatureRouter 19 | params: 20 | num_channels: 256 21 | normalization_type: group-32 22 | gate_type: 2layer-fc-SiLu 23 | decoderconfig: 24 | target: modules.dynamic_modules.DecoderPositional.Decoder 25 | params: 26 | ch: 128 27 | in_ch: 256 28 | out_ch: 3 29 | ch_mult: [1,1,2,2] 30 | num_res_blocks: 2 31 | resolution: 256 32 | attn_resolutions: [32] 33 | latent_size: 32 34 | window_size: 2 35 | position_type: fourier+learned 36 | lossconfig: 37 | target: modules.losses.vqperceptual_multidisc.VQLPIPSWithDiscriminator 38 | params: 39 | disc_start: 0 40 | disc_config: 41 | target: modules.discriminator.model.NLayerDiscriminator 42 | params: 43 | input_nc: 3 44 | ndf: 64 45 | n_layers: 3 46 | use_actnorm: false 47 | disc_init: true 48 | codebook_weight: 1.0 49 | pixelloss_weight: 1.0 50 | disc_factor: 1.0 51 | disc_weight: 1.0 52 | perceptual_weight: 1.0 53 | disc_conditional: false 54 | disc_loss: hinge 55 | disc_weight_max: 0.75 56 | budget_loss_config: 57 | target: modules.dynamic_modules.budget.BudgetConstraint_RatioMSE_DualGrain 58 | params: 59 | target_ratio: 0.5 60 | gamma: 10.0 61 | min_grain_size: 16 62 | max_grain_size: 32 63 | calculate_all: True 64 | vqconfig: 65 | target: modules.vector_quantization.quantize2_mask.VectorQuantize2 66 | params: 67 | codebook_size: 1024 68 | codebook_dim: 256 69 | channel_last: false 70 | accept_image_fmap: true 71 | commitment_beta: 0.25 72 | decay: 0.99 73 | restart_unused_codes: True 74 | quant_before_dim: 256 75 | quant_after_dim: 256 76 | quant_sample_temperature: 0.0 77 | image_key: image 78 | monitor: val_rec_loss 79 | warmup_epochs: 0.1 80 | scheduler_type: linear-warmup_cosine-decay 81 | 82 | data: 83 | target: data.build.DataModuleFromConfig 84 | params: 85 | batch_size: 30 86 | num_workers: 8 87 | train: 88 | target: data.imagenet.ImageNetTrain 89 | params: 90 | config: 91 | is_eval: False 92 | size: 256 93 | validation: 94 | target: data.imagenet.ImageNetValidation 95 | params: 96 | config: 97 | is_eval: True 98 | size: 256 -------------------------------------------------------------------------------- /configs/stage1/dqvae-entropy-dual-r05_imagenet.yml: 
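The stage-1 YAML above follows the repo's target/params convention (the same one handled by instantiate_from_config in data/data_utils.py and imported from utils/utils.py in data/build.py). A minimal sketch of loading this config outside train.py, assuming OmegaConf and that the top-level model and data entries are instantiated directly:

    from omegaconf import OmegaConf
    from utils.utils import instantiate_from_config   # same helper data/build.py imports

    config = OmegaConf.load("configs/stage1/dqvae-dual-r-05_imagenet.yml")
    model = instantiate_from_config(config.model)      # DualGrainVQModel; nested encoder/router/decoder configs are built inside the model
    data = instantiate_from_config(config.data)        # DataModuleFromConfig with the ImageNet train/val splits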
-------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: models.stage1_dynamic.dqvae_dual_entropy.DualGrainVQModel 4 | params: 5 | encoderconfig: 6 | target: modules.dynamic_modules.EncoderDual.DualGrainEncoder 7 | params: 8 | ch: 128 9 | ch_mult: [1,1,2,2,4] 10 | num_res_blocks: 2 11 | attn_resolutions: [16, 32] 12 | dropout: 0.0 13 | resamp_with_conv: true 14 | in_channels: 3 15 | resolution: 256 16 | z_channels: 256 17 | update_router: False 18 | router_config: 19 | target: modules.dynamic_modules.RouterDual.DualGrainFixedEntropyRouter 20 | params: 21 | json_path: scripts/tools/thresholds/entropy_thresholds_imagenet_train_patch-16.json 22 | fine_grain_ratito: 0.5 23 | decoderconfig: 24 | target: modules.dynamic_modules.DecoderPositional.Decoder 25 | params: 26 | ch: 128 27 | in_ch: 256 28 | out_ch: 3 29 | ch_mult: [1,1,2,2] 30 | num_res_blocks: 2 31 | resolution: 256 32 | attn_resolutions: [32] 33 | latent_size: 32 34 | window_size: 2 35 | position_type: fourier+learned 36 | lossconfig: 37 | target: modules.losses.vqperceptual_multidisc.VQLPIPSWithDiscriminator 38 | params: 39 | disc_start: 0 40 | disc_config: 41 | target: modules.discriminator.model.NLayerDiscriminator 42 | params: 43 | input_nc: 3 44 | ndf: 64 45 | n_layers: 3 46 | use_actnorm: false 47 | disc_init: true 48 | codebook_weight: 1.0 49 | pixelloss_weight: 1.0 50 | disc_factor: 1.0 51 | disc_weight: 1.0 52 | perceptual_weight: 1.0 53 | disc_conditional: false 54 | disc_loss: hinge 55 | disc_weight_max: 0.75 56 | # budget_loss_config: 57 | # target: modules.dynamic_modules.budget.BudgetConstraint_RatioMSE_DualGrain 58 | # params: 59 | # target_ratio: 0.5 60 | # gamma: 10.0 61 | # min_grain_size: 16 62 | # max_grain_size: 32 63 | # calculate_all: True 64 | vqconfig: 65 | target: modules.vector_quantization.quantize2_mask.VectorQuantize2 66 | # target: modules.vector_quantization.quantize_codebook_mask.MaskVectorQuantize 67 | params: 68 | codebook_size: 1024 69 | codebook_dim: 256 70 | channel_last: false 71 | accept_image_fmap: true 72 | commitment_beta: 0.25 73 | decay: 0.99 74 | restart_unused_codes: True 75 | quant_before_dim: 256 76 | quant_after_dim: 256 77 | quant_sample_temperature: 0.0 78 | image_key: image 79 | monitor: val_rec_loss 80 | warmup_epochs: 0.1 81 | scheduler_type: linear-warmup_cosine-decay 82 | 83 | data: 84 | target: data.build.DataModuleFromConfig 85 | params: 86 | batch_size: 30 87 | num_workers: 8 88 | train: 89 | target: data.imagenet.ImageNetTrain 90 | params: 91 | config: 92 | is_eval: False 93 | size: 256 94 | validation: 95 | target: data.imagenet.ImageNetValidation 96 | params: 97 | config: 98 | is_eval: True 99 | size: 256 100 | -------------------------------------------------------------------------------- /configs/stage1/dqvae-triple-r-03-03_imagenet.yml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: models.stage1_dynamic.dqvae_triple_feat.TripleGrainVQModel 4 | params: 5 | encoderconfig: 6 | target: modules.dynamic_modules.EncoderTriple.TripleGrainEncoder 7 | params: 8 | ch: 128 9 | ch_mult: [1,1,2,2,4,4] 10 | num_res_blocks: 2 11 | attn_resolutions: [8,16,32] 12 | dropout: 0.0 13 | resamp_with_conv: true 14 | in_channels: 3 15 | resolution: 256 16 | z_channels: 256 17 | router_config: 18 | target: modules.dynamic_modules.RouterTriple.TripleGrainFeatureRouter 19 | params: 20 | num_channels: 256 21 | normalization_type: 
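With the dual-grain settings above (16x16 coarse grid, 32x32 fine grid, fine ratio 0.5), and assuming each coarse cell routed to fine grain expands to a 2x2 block of tokens (window_size 2), the expected sequence length is 256 * (1 + 3r). A quick check of that arithmetic, illustrative only:

    cells = 16 * 16                                # coarse grid (coarse_hw / latent_size above)
    r = 0.5                                        # fine_grain_ratito / target_ratio in these configs
    expected_len = cells * ((1 - r) * 1 + r * 4)   # 1 token per coarse cell, 4 per fine cell
    print(expected_len)                            # 640.0 -> vs. 1024 tokens for a uniform 32x32 grid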
group-32 22 | gate_type: 2layer-fc-SiLu 23 | decoderconfig: 24 | target: modules.dynamic_modules.DecoderPositional.Decoder 25 | params: 26 | ch: 128 27 | in_ch: 256 28 | out_ch: 3 29 | ch_mult: [1,1,2,2] 30 | num_res_blocks: 2 31 | resolution: 256 32 | attn_resolutions: [32] 33 | latent_size: 32 34 | window_size: 2 35 | position_type: fourier+learned 36 | lossconfig: 37 | target: modules.losses.vqperceptual_multidisc.VQLPIPSWithDiscriminator 38 | params: 39 | disc_start: 0 40 | disc_config: 41 | target: modules.discriminator.model.NLayerDiscriminator 42 | params: 43 | input_nc: 3 44 | ndf: 64 45 | n_layers: 3 46 | use_actnorm: false 47 | disc_init: true 48 | codebook_weight: 1.0 49 | pixelloss_weight: 1.0 50 | disc_factor: 1.0 51 | disc_weight: 1.0 52 | perceptual_weight: 1.0 53 | disc_conditional: false 54 | disc_loss: hinge 55 | disc_weight_max: 0.75 56 | budget_loss_config: 57 | target: modules.dynamic_modules.budget.BudgetConstraint_NormedSeperateRatioMSE_TripleGrain 58 | params: 59 | target_fine_ratio: 0.3 60 | target_median_ratio: 0.3 61 | gamma: 1.0 62 | min_grain_size: 8 63 | median_grain_size: 16 64 | max_grain_size: 32 65 | 66 | vqconfig: 67 | target: modules.vector_quantization.quantize2_mask.VectorQuantize2 68 | params: 69 | codebook_size: 1024 70 | codebook_dim: 256 71 | channel_last: false 72 | accept_image_fmap: true 73 | commitment_beta: 0.25 74 | decay: 0.99 75 | restart_unused_codes: True 76 | quant_before_dim: 256 77 | quant_after_dim: 256 78 | quant_sample_temperature: 0.0 79 | image_key: image 80 | monitor: val_rec_loss 81 | warmup_epochs: 0.1 82 | scheduler_type: linear-warmup_cosine-decay 83 | 84 | data: 85 | target: data.build.DataModuleFromConfig 86 | params: 87 | batch_size: 30 88 | num_workers: 8 89 | train: 90 | target: data.imagenet.ImageNetTrain 91 | params: 92 | config: 93 | is_eval: False 94 | size: 256 95 | validation: 96 | target: data.imagenet.ImageNetValidation 97 | params: 98 | config: 99 | is_eval: True 100 | size: 256 -------------------------------------------------------------------------------- /configs/stage2/class_imagenet_p6c18.yml: -------------------------------------------------------------------------------- 1 | model: 2 | learning_rate: 0.0005 3 | min_learning_rate: 0.0 4 | target: models.stage2_dynamic.dqtransformer_class2_entropy.Dualformer 5 | params: 6 | transformer_config: 7 | target: modules.dynamic_modules.stackgpt.StackGPT 8 | params: 9 | vocab_size: 2026 # 1024 + 1 (pad) + 1 (eos) + 1000 (class number) 10 | coarse_position_size: 1258 # 256 + 1 (pad) + 1 (eos) + 1000 (class number) 11 | fine_position_size: 2026 # 1024 + 1 (pad) + 1 (eos) + 1000 (class number) 12 | segment_size: 2 # coarse and fine 13 | block_size: 2048 # as large as possible 14 | position_layer: 6 # 12 15 | content_layer: 18 # 12 16 | n_head: 8 17 | n_embd: 1024 18 | embd_pdrop: 0.1 19 | resid_pdrop: 0.1 20 | attn_pdrop: 0.1 21 | content_pad_code: 1024 22 | coarse_position_pad_code: 256 23 | fine_position_pad_code: 1024 24 | activate_pad_ignore: True 25 | 26 | first_stage_config: 27 | target: models.stage1_dynamic.dqvae_dual_entropy.DualGrainVQModel 28 | params: 29 | ckpt_path: "put your pre-trained DQ-VAE ckpt path" 30 | encoderconfig: 31 | target: modules.dynamic_modules.EncoderDual.DualGrainEncoder 32 | params: 33 | ch: 128 34 | ch_mult: [1,1,2,2,4] 35 | num_res_blocks: 2 36 | attn_resolutions: [16, 32] 37 | dropout: 0.0 38 | resamp_with_conv: true 39 | in_channels: 3 40 | resolution: 256 41 | z_channels: 256 42 | update_router: False 43 | router_config: 44 | 
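For the triple-grain budget above (target_fine_ratio 0.3, target_median_ratio 0.3, grain sizes 8/16/32), a back-of-the-envelope token count, assuming each cell of the 8x8 coarse grid expands to 1, 4, or 16 tokens at grain 8, 16, or 32 respectively:

    ratios = {"coarse": 0.4, "median": 0.3, "fine": 0.3}    # coarse share = 1 - 0.3 - 0.3
    tokens_per_cell = {"coarse": 1, "median": 4, "fine": 16}
    cells = 8 * 8
    expected_len = cells * sum(ratios[g] * tokens_per_cell[g] for g in ratios)
    print(expected_len)    # 409.6 -> roughly 410 tokens on average vs. 1024 for a flat 32x32 grid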
target: modules.dynamic_modules.RouterDual.DualGrainFixedEntropyRouter 45 | params: 46 | json_path: scripts/tools/thresholds/entropy_thresholds_imagenet_train_patch-16.json 47 | fine_grain_ratito: 0.5 48 | decoderconfig: 49 | target: modules.dynamic_modules.DecoderPositional.Decoder 50 | params: 51 | ch: 128 52 | in_ch: 256 53 | out_ch: 3 54 | ch_mult: [1,1,2,2] 55 | num_res_blocks: 2 56 | resolution: 256 57 | attn_resolutions: [32] 58 | latent_size: 32 59 | window_size: 2 60 | position_type: fourier+learned 61 | lossconfig: 62 | target: modules.losses.vqperceptual.DummyLoss 63 | 64 | vqconfig: 65 | target: modules.vector_quantization.quantize2_mask.VectorQuantize2 66 | params: 67 | codebook_size: 1024 68 | codebook_dim: 256 69 | channel_last: false 70 | accept_image_fmap: true 71 | commitment_beta: 0.25 72 | decay: 0.99 73 | restart_unused_codes: True 74 | 75 | quant_before_dim: 256 76 | quant_after_dim: 256 77 | quant_sample_temperature: 0.0 78 | image_key: image 79 | # monitor: val_rec_loss 80 | # warmup_epochs: 0.1 81 | # scheduler_type: linear-warmup_cosine-decay 82 | 83 | class_cond_stage_config: 84 | target: modules.dynamic_modules.label_provider.ClassAwareSOSProvider 85 | params: 86 | n_classes: 1000 87 | threshold_content: 1026 88 | threshold_coarse_position: 258 89 | threshold_fine_position: 1026 90 | coarse_seg_sos: 0 91 | fine_seg_sos: 1 92 | 93 | permuter_config: 94 | target: modules.dynamic_modules.permuter.DualGrainSeperatePermuter 95 | params: 96 | coarse_hw: 16 97 | fine_hw: 32 98 | content_pad_code: 1024 99 | content_eos_code: 1025 100 | coarse_position_pad_code: 256 101 | coarse_position_eos_code: 257 102 | fine_position_pad_code: 1024 103 | fine_position_eos_code: 1025 104 | fine_position_order: row-first 105 | 106 | content_loss_weight: 1.0 107 | position_loss_weight: 1.0 108 | activate_sos_for_fine_sequence: True 109 | weight_decay: 0.01 110 | warmup_epochs: 0 111 | monitor: val_loss 112 | 113 | data: 114 | target: data.build.DataModuleFromConfig 115 | params: 116 | batch_size: 30 117 | num_workers: 8 118 | train: 119 | target: data.imagenet.ImageNetTrain 120 | params: 121 | config: 122 | is_eval: False 123 | size: 256 124 | validation: 125 | target: data.imagenet.ImageNetValidation 126 | params: 127 | config: 128 | is_eval: True 129 | size: 256 130 | -------------------------------------------------------------------------------- /configs/stage2/uncond_imagenet_p6c18.yml: -------------------------------------------------------------------------------- 1 | model: 2 | learning_rate: 0.0005 3 | min_learning_rate: 0.0 4 | target: models.stage2_dynamic.dqtransformer_uncond_entropy.Dualformer 5 | params: 6 | transformer_config: 7 | target: modules.dynamic_modules.stackgpt.StackGPT 8 | params: 9 | vocab_size: 1027 # 1024 + 1 (pad) + 1 (sos) + 1 (eos) 10 | coarse_position_size: 259 # 256 + 1 (pad) + 1 (sos) + 1 (eos) 11 | fine_position_size: 1027 # 1024 + 1 (pad) + 1 (sos) + 1 (eos) 12 | segment_size: 2 # coarse and fine 13 | block_size: 2048 # as large as possible 14 | position_layer: 6 # 12 15 | content_layer: 18 # 12 16 | n_head: 8 17 | n_embd: 1024 18 | embd_pdrop: 0.1 19 | resid_pdrop: 0.1 20 | attn_pdrop: 0.1 21 | content_pad_code: 1024 22 | coarse_position_pad_code: 256 23 | fine_position_pad_code: 1024 24 | activate_pad_ignore: True 25 | 26 | first_stage_config: 27 | target: models.stage1_dynamic.dqvae_dual_entropy.DualGrainVQModel 28 | params: 29 | ckpt_path: "put your pre-trained DQ-VAE ckpt path" 30 | encoderconfig: 31 | target: 
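A hedged reading of the class-conditional vocabulary layout in the config above, inferred from the inline comments rather than from the provider code: content tokens 0-1023 are codebook indices, 1024 is the pad code, 1025 the eos code, and an ImageNet label c enters the sequence as the start token threshold_content + c, which is consistent with vocab_size = 1026 + 1000 = 2026 and coarse_position_size = 258 + 1000 = 1258.

    # illustrative helper, not part of the repo
    def class_to_sos_token(c, threshold_content=1026):
        return threshold_content + c    # e.g. class 7 -> token 1033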
modules.dynamic_modules.EncoderDual.DualGrainEncoder 32 | params: 33 | ch: 128 34 | ch_mult: [1,1,2,2,4] 35 | num_res_blocks: 2 36 | attn_resolutions: [16, 32] 37 | dropout: 0.0 38 | resamp_with_conv: true 39 | in_channels: 3 40 | resolution: 256 41 | z_channels: 256 42 | update_router: False 43 | router_config: 44 | target: modules.dynamic_modules.RouterDual.DualGrainFixedEntropyRouter 45 | params: 46 | json_path: scripts/tools/thresholds/entropy_thresholds_imagenet_train_patch-16.json 47 | fine_grain_ratito: 0.5 48 | decoderconfig: 49 | target: modules.dynamic_modules.DecoderPositional.Decoder 50 | params: 51 | ch: 128 52 | in_ch: 256 53 | out_ch: 3 54 | ch_mult: [1,1,2,2] 55 | num_res_blocks: 2 56 | resolution: 256 57 | attn_resolutions: [32] 58 | latent_size: 32 59 | window_size: 2 60 | position_type: fourier+learned 61 | lossconfig: 62 | target: modules.losses.vqperceptual.DummyLoss 63 | 64 | vqconfig: 65 | target: modules.vector_quantization.quantize2_mask.VectorQuantize2 66 | params: 67 | codebook_size: 1024 68 | codebook_dim: 256 69 | channel_last: false 70 | accept_image_fmap: true 71 | commitment_beta: 0.25 72 | decay: 0.99 73 | restart_unused_codes: True 74 | 75 | quant_before_dim: 256 76 | quant_after_dim: 256 77 | quant_sample_temperature: 0.0 78 | image_key: image 79 | # monitor: val_rec_loss 80 | # warmup_epochs: 0.1 81 | # scheduler_type: linear-warmup_cosine-decay 82 | 83 | uncond_stage_config: 84 | target: modules.dynamic_modules.label_provider.PositionAwareSOSProvider 85 | params: 86 | coarse_sos: 1026 87 | coarse_pos_sos: 258 88 | fine_sos: 1026 89 | fine_pos_sos: 1026 90 | coarse_seg_sos: 0 91 | fine_seg_sos: 1 92 | 93 | permuter_config: 94 | target: modules.dynamic_modules.permuter.DualGrainSeperatePermuter 95 | params: 96 | coarse_hw: 16 97 | fine_hw: 32 98 | content_pad_code: 1024 99 | content_eos_code: 1025 100 | coarse_position_pad_code: 256 101 | coarse_position_eos_code: 257 102 | fine_position_pad_code: 1024 103 | fine_position_eos_code: 1025 104 | fine_position_order: row-first 105 | 106 | content_loss_weight: 1.0 107 | position_loss_weight: 1.0 108 | activate_sos_for_fine_sequence: True 109 | weight_decay: 0.01 110 | warmup_epochs: 0 111 | monitor: val_loss 112 | 113 | data: 114 | target: data.build.DataModuleFromConfig 115 | params: 116 | batch_size: 30 117 | num_workers: 8 118 | train: 119 | target: data.imagenet.ImageNetTrain 120 | params: 121 | config: 122 | is_eval: False 123 | size: 256 124 | validation: 125 | target: data.imagenet.ImageNetValidation 126 | params: 127 | config: 128 | is_eval: True 129 | size: 256 130 | -------------------------------------------------------------------------------- /data/build.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader, Dataset 2 | import pytorch_lightning as pl 3 | from utils.utils import instantiate_from_config 4 | 5 | class WrappedDataset(Dataset): 6 | """Wraps an arbitrary object with __len__ and __getitem__ into a pytorch dataset""" 7 | def __init__(self, dataset): 8 | self.data = dataset 9 | 10 | def __len__(self): 11 | return len(self.data) 12 | 13 | def __getitem__(self, idx): 14 | return self.data[idx] 15 | 16 | class DataModuleFromConfig(pl.LightningDataModule): 17 | def __init__(self, batch_size, train=None, validation=None, test=None, 18 | wrap=False, num_workers=None, train_val=False): 19 | super().__init__() 20 | self.batch_size = batch_size 21 | self.train_val = train_val 22 | self.dataset_configs = dict() 23 | 
self.num_workers = num_workers if num_workers is not None else batch_size*2 24 | if train is not None: 25 | self.dataset_configs["train"] = train 26 | self.train_dataloader = self._train_dataloader 27 | if validation is not None: 28 | self.dataset_configs["validation"] = validation 29 | self.val_dataloader = self._val_dataloader 30 | if test is not None: 31 | self.dataset_configs["test"] = test 32 | self.test_dataloader = self._test_dataloader 33 | self.wrap = wrap 34 | 35 | # move setup here to avoid warning 36 | self.datasets = dict( 37 | (k, instantiate_from_config(self.dataset_configs[k])) 38 | for k in self.dataset_configs) 39 | if self.wrap: 40 | for k in self.datasets: 41 | self.datasets[k] = WrappedDataset(self.datasets[k]) 42 | if self.train_val: 43 | if "train" in self.datasets.keys() and "validation" in self.datasets.keys(): 44 | self.datasets["train"] = self.datasets["train"] + self.datasets["validation"] 45 | for k in self.datasets.keys(): 46 | print("dataset: ", k, len(self.datasets[k])) 47 | 48 | def prepare_data(self): 49 | for data_cfg in self.dataset_configs.values(): 50 | print("instantiate from: ", data_cfg) 51 | instantiate_from_config(data_cfg) 52 | 53 | # def setup(self, stage=None): 54 | # self.datasets = dict( 55 | # (k, instantiate_from_config(self.dataset_configs[k])) 56 | # for k in self.dataset_configs) 57 | # if self.wrap: 58 | # for k in self.datasets: 59 | # self.datasets[k] = WrappedDataset(self.datasets[k]) 60 | # if self.train_val: 61 | # if "train" in self.datasets.keys() and "validation" in self.datasets.keys(): 62 | # self.datasets["train"] = self.datasets["train"] + self.datasets["validation"] 63 | # for k in self.datasets.keys(): 64 | # print("dataset: ", k, len(self.datasets[k])) 65 | 66 | def _train_dataloader(self): 67 | if hasattr(self.datasets["train"], "collate_fn"): 68 | return DataLoader(self.datasets["train"], batch_size=self.batch_size, 69 | num_workers=self.num_workers, shuffle=True, 70 | collate_fn=self.datasets["train"].collate_fn) 71 | else: 72 | return DataLoader(self.datasets["train"], batch_size=self.batch_size, 73 | num_workers=self.num_workers, shuffle=True) 74 | 75 | 76 | def _val_dataloader(self): 77 | if hasattr(self.datasets['validation'], "collate_fn"): 78 | return DataLoader(self.datasets["validation"], batch_size=self.batch_size, 79 | num_workers=self.num_workers, collate_fn=self.datasets["validation"].collate_fn) 80 | else: 81 | return DataLoader(self.datasets["validation"], batch_size=self.batch_size, 82 | num_workers=self.num_workers) 83 | 84 | def _test_dataloader(self): 85 | if hasattr(self.datasets["test"], "collate_fn"): 86 | return DataLoader(self.datasets["test"], batch_size=self.batch_size, 87 | num_workers=self.num_workers, collate_fn=self.datasets["test"].collate_fn) 88 | else: 89 | return DataLoader(self.datasets["test"], batch_size=self.batch_size, 90 | num_workers=self.num_workers) -------------------------------------------------------------------------------- /data/data_utils.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import importlib 3 | import bisect 4 | import numpy as np 5 | import albumentations 6 | from PIL import Image 7 | from torch.utils.data import Dataset, ConcatDataset 8 | 9 | # data/utils 10 | def mark_prepared(root): 11 | Path(root).joinpath(".ready").touch() 12 | 13 | def is_prepared(root): 14 | return Path(root).joinpath(".ready").exists() 15 | 16 | def instantiate_from_config(config): 17 | if not "target" in config: 
18 | raise KeyError("Expected key `target` to instantiate.") 19 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 20 | 21 | def get_obj_from_str(string, reload=False): 22 | module, cls = string.rsplit(".", 1) 23 | if reload: 24 | module_imp = importlib.import_module(module) 25 | importlib.reload(module_imp) 26 | return getattr(importlib.import_module(module, package=None), cls) 27 | 28 | # utils 29 | class KeyNotFoundError(Exception): 30 | def __init__(self, cause, keys=None, visited=None): 31 | self.cause = cause 32 | self.keys = keys 33 | self.visited = visited 34 | messages = list() 35 | if keys is not None: 36 | messages.append("Key not found: {}".format(keys)) 37 | if visited is not None: 38 | messages.append("Visited: {}".format(visited)) 39 | messages.append("Cause:\n{}".format(cause)) 40 | message = "\n".join(messages) 41 | super().__init__(message) 42 | 43 | def retrieve( 44 | list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False 45 | ): 46 | """Given a nested list or dict return the desired value at key expanding 47 | callable nodes if necessary and :attr:`expand` is ``True``. The expansion 48 | is done in-place. 49 | 50 | Parameters 51 | ---------- 52 | list_or_dict : list or dict 53 | Possibly nested list or dictionary. 54 | key : str 55 | key/to/value, path like string describing all keys necessary to 56 | consider to get to the desired value. List indices can also be 57 | passed here. 58 | splitval : str 59 | String that defines the delimiter between keys of the 60 | different depth levels in `key`. 61 | default : obj 62 | Value returned if :attr:`key` is not found. 63 | expand : bool 64 | Whether to expand callable nodes on the path or not. 65 | 66 | Returns 67 | ------- 68 | The desired value or if :attr:`default` is not ``None`` and the 69 | :attr:`key` is not found returns ``default``. 70 | 71 | Raises 72 | ------ 73 | Exception if ``key`` not in ``list_or_dict`` and :attr:`default` is 74 | ``None``. 75 | """ 76 | 77 | keys = key.split(splitval) 78 | 79 | success = True 80 | try: 81 | visited = [] 82 | parent = None 83 | last_key = None 84 | for key in keys: 85 | if callable(list_or_dict): 86 | if not expand: 87 | raise KeyNotFoundError( 88 | ValueError( 89 | "Trying to get past callable node with expand=False." 
90 | ), 91 | keys=keys, 92 | visited=visited, 93 | ) 94 | list_or_dict = list_or_dict() 95 | parent[last_key] = list_or_dict 96 | 97 | last_key = key 98 | parent = list_or_dict 99 | 100 | try: 101 | if isinstance(list_or_dict, dict): 102 | list_or_dict = list_or_dict[key] 103 | else: 104 | list_or_dict = list_or_dict[int(key)] 105 | except (KeyError, IndexError, ValueError) as e: 106 | raise KeyNotFoundError(e, keys=keys, visited=visited) 107 | 108 | visited += [key] 109 | # final expansion of retrieved value 110 | if expand and callable(list_or_dict): 111 | list_or_dict = list_or_dict() 112 | parent[last_key] = list_or_dict 113 | except KeyNotFoundError as e: 114 | if default is None: 115 | raise e 116 | else: 117 | list_or_dict = default 118 | success = False 119 | 120 | if not pass_success: 121 | return list_or_dict 122 | else: 123 | return list_or_dict, success 124 | 125 | class ConcatDatasetWithIndex(ConcatDataset): 126 | """Modified from original pytorch code to return dataset idx""" 127 | def __getitem__(self, idx): 128 | if idx < 0: 129 | if -idx > len(self): 130 | raise ValueError("absolute value of index should not exceed dataset length") 131 | idx = len(self) + idx 132 | dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) 133 | if dataset_idx == 0: 134 | sample_idx = idx 135 | else: 136 | sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] 137 | return self.datasets[dataset_idx][sample_idx], dataset_idx 138 | 139 | 140 | class ImagePaths(Dataset): 141 | def __init__(self, paths, size=None, random_crop=False, labels=None): 142 | self.size = size 143 | self.random_crop = random_crop 144 | 145 | self.labels = dict() if labels is None else labels 146 | self.labels["file_path_"] = paths 147 | self._length = len(paths) 148 | 149 | if self.size is not None and self.size > 0: 150 | self.rescaler = albumentations.SmallestMaxSize(max_size = self.size) 151 | if not self.random_crop: 152 | self.cropper = albumentations.CenterCrop(height=self.size,width=self.size) 153 | else: 154 | self.cropper = albumentations.RandomCrop(height=self.size,width=self.size) 155 | self.preprocessor = albumentations.Compose([self.rescaler, self.cropper]) 156 | else: 157 | self.preprocessor = lambda **kwargs: kwargs 158 | 159 | def __len__(self): 160 | return self._length 161 | 162 | def preprocess_image(self, image_path): 163 | image = Image.open(image_path) 164 | if not image.mode == "RGB": 165 | image = image.convert("RGB") 166 | image = np.array(image).astype(np.uint8) 167 | image = self.preprocessor(image=image)["image"] 168 | image = (image/127.5 - 1.0).astype(np.float32) 169 | return image 170 | 171 | def __getitem__(self, i): 172 | example = dict() 173 | example["image"] = self.preprocess_image(self.labels["file_path_"][i]) 174 | for k in self.labels: 175 | example[k] = self.labels[k][i] 176 | return example 177 | 178 | 179 | class NumpyPaths(ImagePaths): 180 | def preprocess_image(self, image_path): 181 | image = np.load(image_path).squeeze(0) # 3 x 1024 x 1024 182 | image = np.transpose(image, (1,2,0)) 183 | image = Image.fromarray(image, mode="RGB") 184 | image = np.array(image).astype(np.uint8) 185 | image = self.preprocessor(image=image)["image"] 186 | image = (image/127.5 - 1.0).astype(np.float32) 187 | return image -------------------------------------------------------------------------------- /data/default.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | DefaultDataPath = edict() 4 | 5 | 
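A short usage sketch for the retrieve() helper documented above (values are illustrative):

    cfg = {"model": {"params": {"ch_mult": [1, 1, 2, 2]}}}
    retrieve(cfg, "model/params/ch_mult")               # -> [1, 1, 2, 2]
    retrieve(cfg, "model/params/ch_mult/3")             # list indices use the same path syntax -> 2
    retrieve(cfg, "model/params/missing", default=0)    # falls back to the default instead of raising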
DefaultDataPath.ImageNet = edict() 6 | 7 | # DefaultDataPath.ImageNet.root = "Your Data Path/Datasets/ImageNet" 8 | # DefaultDataPath.ImageNet.train_write_root = "Your Data Path/Datasets/ImageNet/train" 9 | # DefaultDataPath.ImageNet.val_write_root = "Your Data Path/Datasets/ImageNet/val" 10 | 11 | DefaultDataPath.ImageNet.root = "/home/huangmq/Datasets/ImageNet" 12 | DefaultDataPath.ImageNet.train_write_root = "/home/huangmq/Datasets/ImageNet/train" 13 | DefaultDataPath.ImageNet.val_write_root = "/home/huangmq/Datasets/ImageNet/val" -------------------------------------------------------------------------------- /data/ffhq_lmdb.py: -------------------------------------------------------------------------------- 1 | import io 2 | import os 3 | import pickle 4 | import string 5 | from pathlib import Path 6 | 7 | import lmdb 8 | import torchvision 9 | import torchvision.transforms as transforms 10 | from PIL import Image 11 | 12 | import os, sys 13 | sys.path.append(os.getcwd()) 14 | from data.default import DefaultDataPath 15 | 16 | class FFHQ_LMDB(torchvision.datasets.VisionDataset): 17 | 18 | def __init__(self, split="train", resolution=256, is_eval=False, **kwargs): 19 | 20 | if split == "train": 21 | lmdb_path = DefaultDataPath.FFHQ.train_lmdb 22 | elif split == "val": 23 | lmdb_path = DefaultDataPath.FFHQ.val_lmdb 24 | else: 25 | raise ValueError() 26 | 27 | root = str(Path(lmdb_path)) 28 | super().__init__(root, **kwargs) 29 | 30 | self.env = lmdb.open(root, max_readers=1, readonly=True, lock=False, readahead=False, meminit=False) 31 | with self.env.begin(write=False) as txn: 32 | self.length = txn.stat()["entries"] 33 | 34 | cache_file = "_cache_" + "".join(c for c in root if c in string.ascii_letters) 35 | cache_file = os.path.join(root, cache_file) 36 | if os.path.isfile(cache_file): 37 | self.keys = pickle.load(open(cache_file, "rb")) 38 | else: 39 | with self.env.begin(write=False) as txn: 40 | self.keys = [key for key in txn.cursor().iternext(keys=True, values=False)] 41 | pickle.dump(self.keys, open(cache_file, "wb")) 42 | 43 | if split == "train" and not is_eval: 44 | transforms_ = [ 45 | transforms.RandomResizedCrop(resolution, scale=(0.75, 1.0), ratio=(1.0, 1.0)), 46 | transforms.RandomHorizontalFlip(p=0.5), 47 | transforms.ToTensor(), 48 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), 49 | ] 50 | else: 51 | transforms_ = [ 52 | transforms.Resize(resolution), 53 | transforms.CenterCrop(resolution), 54 | transforms.ToTensor(), 55 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) 56 | ] 57 | self.transforms = transforms.Compose(transforms_) 58 | 59 | def __getitem__(self, index: int): 60 | env = self.env 61 | with env.begin(write=False) as txn: 62 | imgbuf = txn.get(self.keys[index]) 63 | 64 | buf = io.BytesIO() 65 | buf.write(imgbuf) 66 | buf.seek(0) 67 | img = Image.open(buf).convert("RGB") 68 | 69 | if self.transforms is not None: 70 | img = self.transforms(img) 71 | 72 | return { 73 | "image": img 74 | } 75 | 76 | def __len__(self): 77 | return self.length 78 | 79 | 80 | if __name__ == "__main__": 81 | dataset = FFHQ_LMDB(split='train', resolution=256, is_eval=False) 82 | dataset_val = FFHQ_LMDB(split='val', resolution=256, is_eval=False) 83 | 84 | print(len(dataset)) 85 | print(len(dataset_val)) 86 | 87 | # sample = dataset.__getitem__(0) 88 | 89 | # torchvision.utils.save_image(sample["image"], "sample_ffhq.png", normalize=True) 90 | -------------------------------------------------------------------------------- /data/imagenet_base.py: 
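Note that FFHQ_LMDB above reads DefaultDataPath.FFHQ.train_lmdb and DefaultDataPath.FFHQ.val_lmdb, while data/default.py only defines ImageNet paths. A sketch of the entries you would likely need to add to data/default.py before using the FFHQ loader (paths are placeholders):

    DefaultDataPath.FFHQ = edict()
    DefaultDataPath.FFHQ.train_lmdb = "/path/to/FFHQ/train_lmdb"
    DefaultDataPath.FFHQ.val_lmdb = "/path/to/FFHQ/val_lmdb"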
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import albumentations 3 | from PIL import Image 4 | from torch.utils.data import Dataset 5 | import torchvision.transforms as transforms 6 | 7 | class ImagePaths(Dataset): 8 | def __init__(self, split, is_val, paths, size=None, random_crop=False, labels=None): 9 | self.size = size 10 | self.random_crop = random_crop 11 | 12 | self.labels = dict() if labels is None else labels 13 | self.labels["file_path_"] = paths 14 | self._length = len(paths) 15 | 16 | if split == "train" and not is_val: 17 | transforms_ = [ 18 | transforms.Resize(256), 19 | transforms.RandomCrop(256), 20 | transforms.RandomHorizontalFlip(p=0.5), 21 | transforms.ToTensor(), 22 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) 23 | ] 24 | else: 25 | transforms_ = [ 26 | transforms.Resize(256), 27 | transforms.CenterCrop(256), 28 | transforms.Resize((256, 256)), 29 | transforms.ToTensor(), 30 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) 31 | ] 32 | self.transforms = transforms.Compose(transforms_) 33 | 34 | # if self.size is not None and self.size > 0: 35 | # self.rescaler = albumentations.SmallestMaxSize(max_size = self.size) 36 | # if not self.random_crop: 37 | # self.cropper = albumentations.CenterCrop(height=self.size,width=self.size) 38 | # else: 39 | # self.cropper = albumentations.RandomCrop(height=self.size,width=self.size) 40 | # self.preprocessor = albumentations.Compose([self.rescaler, self.cropper]) 41 | # else: 42 | # self.preprocessor = lambda **kwargs: kwargs 43 | 44 | def __len__(self): 45 | return self._length 46 | 47 | def preprocess_image(self, image_path): 48 | image = Image.open(image_path) 49 | if not image.mode == "RGB": 50 | image = image.convert("RGB") 51 | # image = np.array(image).astype(np.uint8) 52 | # we replace the original taming version image preprocess 53 | # with the one in RQVAE 54 | # image = self.preprocessor(image=image)["image"] 55 | # image = (image/127.5 - 1.0).astype(np.float32) 56 | 57 | image = self.transforms(image) 58 | return image 59 | 60 | def __getitem__(self, i): 61 | example = dict() 62 | example["image"] = self.preprocess_image(self.labels["file_path_"][i]) 63 | for k in self.labels: 64 | example[k] = self.labels[k][i] 65 | return example -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: vq 2 | channels: 3 | - defaults 4 | dependencies: 5 | - _libgcc_mutex=0.1=main 6 | - bzip2=1.0.8=h7b6447c_0 7 | - ca-certificates=2022.10.11=h06a4308_0 8 | - certifi=2022.12.7=py310h06a4308_0 9 | - ld_impl_linux-64=2.38=h1181459_1 10 | - libffi=3.3=he6710b0_2 11 | - libgcc-ng=9.1.0=hdf63c60_0 12 | - libstdcxx-ng=9.1.0=hdf63c60_0 13 | - libuuid=1.0.3=h7f8727e_2 14 | - ncurses=6.3=h7f8727e_2 15 | - openssl=1.1.1s=h7f8727e_0 16 | - pip=22.3.1=py310h06a4308_0 17 | - python=3.10.4=h12debd9_0 18 | - readline=8.1.2=h7f8727e_1 19 | - setuptools=65.6.3=py310h06a4308_0 20 | - sqlite=3.38.5=hc218d9a_0 21 | - tk=8.6.12=h1ccaba5_0 22 | - tzdata=2022g=h04d1e81_0 23 | - wheel=0.37.1=pyhd3eb1b0_0 24 | - xz=5.2.5=h7f8727e_1 25 | - zlib=1.2.12=h7f8727e_2 26 | - pip: 27 | - absl-py==1.3.0 28 | - academictorrents==2.3.3 29 | - aiohttp==3.8.3 30 | - aiosignal==1.3.1 31 | - albumentations==1.3.0 32 | - aniso8601==9.0.1 33 | - antlr4-python3-runtime==4.9.3 34 | - anyio==3.6.2 35 | - argon2-cffi==21.3.0 36 | - argon2-cffi-bindings==21.2.0 37 | - 
arrow==1.2.3 38 | - astroid==2.14.1 39 | - asttokens==2.2.1 40 | - astunparse==1.6.3 41 | - async-timeout==4.0.2 42 | - attrs==22.2.0 43 | - autofaiss==2.15.5 44 | - backcall==0.2.0 45 | - beautifulsoup4==4.11.2 46 | - bencode.py==2.0.0 47 | - bitstring==3.1.5 48 | - bleach==6.0.0 49 | - braceexpand==0.1.7 50 | - cachetools==5.2.1 51 | - cffi==1.15.1 52 | - charset-normalizer==2.1.1 53 | - click==8.1.3 54 | - clip-anytorch==2.5.0 55 | - clip-retrieval==2.36.1 56 | - comm==0.1.2 57 | - contourpy==1.0.6 58 | - cycler==0.11.0 59 | - dataclasses==0.6 60 | - debugpy==1.6.6 61 | - decorator==5.1.1 62 | - defusedxml==0.7.1 63 | - diffusers==0.16.1 64 | - dill==0.3.6 65 | - docker-pycreds==0.4.0 66 | - easydict==1.10 67 | - einops==0.6.0 68 | - embedding-reader==1.5.0 69 | - executing==1.2.0 70 | - exifread-nocycle==3.0.1 71 | - faiss-cpu==1.7.3 72 | - fastjsonschema==2.16.3 73 | - filelock==3.9.0 74 | - fire==0.4.0 75 | - flask==2.2.2 76 | - flask-cors==3.0.10 77 | - flask-restful==0.3.9 78 | - flatbuffers==23.3.3 79 | - fonttools==4.38.0 80 | - fqdn==1.5.1 81 | - frozenlist==1.3.3 82 | - fsspec==2022.11.0 83 | - ftfy==6.1.1 84 | - future==0.16.0 85 | - gast==0.4.0 86 | - gitdb==4.0.10 87 | - gitpython==3.1.30 88 | - google-auth==2.16.0 89 | - google-auth-oauthlib==0.4.6 90 | - google-pasta==0.2.0 91 | - grpcio==1.51.1 92 | - h5py==3.8.0 93 | - huggingface-hub==0.14.1 94 | - idna==3.4 95 | - imageio==2.25.1 96 | - img2dataset==1.41.0 97 | - importlib-metadata==6.6.0 98 | - ipykernel==6.21.3 99 | - ipython==8.10.0 100 | - ipython-genutils==0.2.0 101 | - ipywidgets==8.0.4 102 | - isoduration==20.11.0 103 | - isort==5.12.0 104 | - itsdangerous==2.1.2 105 | - jedi==0.18.2 106 | - jinja2==3.1.2 107 | - joblib==1.2.0 108 | - jsonpointer==2.3 109 | - jsonschema==4.17.3 110 | - jupyter==1.0.0 111 | - jupyter-console==6.6.3 112 | - jupyter-events==0.6.3 113 | - jupyter_client==8.0.3 114 | - jupyter_core==5.3.0 115 | - jupyter_server==2.5.0 116 | - jupyter_server_terminals==0.4.4 117 | - jupyterlab-pygments==0.2.2 118 | - jupyterlab-widgets==3.0.5 119 | - keras==2.11.0 120 | - kiwisolver==1.4.4 121 | - kornia==0.6.9 122 | - lazy-object-proxy==1.9.0 123 | - libclang==15.0.6.1 124 | - lmdb==1.4.0 125 | - lpips==0.1.4 126 | - markdown==3.4.1 127 | - markupsafe==2.1.1 128 | - matplotlib==3.6.2 129 | - matplotlib-inline==0.1.6 130 | - mccabe==0.7.0 131 | - mistune==2.0.5 132 | - multidict==6.0.4 133 | - multilingual-clip==1.0.10 134 | - nbclassic==0.5.3 135 | - nbclient==0.7.2 136 | - nbconvert==7.2.10 137 | - nbformat==5.7.3 138 | - nest-asyncio==1.5.6 139 | - networkx==3.0 140 | - nltk==3.8.1 141 | - notebook==6.5.3 142 | - notebook_shim==0.2.2 143 | - numpy==1.24.1 144 | - nvidia-cublas-cu11==11.10.3.66 145 | - nvidia-cuda-nvrtc-cu11==11.7.99 146 | - nvidia-cuda-runtime-cu11==11.7.99 147 | - nvidia-cudnn-cu11==8.5.0.96 148 | - oauthlib==3.2.2 149 | - omegaconf==2.3.0 150 | - open-clip-torch==2.13.0 151 | - opencv-python==4.7.0.68 152 | - opencv-python-headless==4.7.0.68 153 | - opt-einsum==3.3.0 154 | - packaging==23.0 155 | - pandas==1.5.3 156 | - pandocfilters==1.5.0 157 | - parso==0.8.3 158 | - pathtools==0.1.2 159 | - pexpect==4.8.0 160 | - pickleshare==0.7.5 161 | - pillow==9.4.0 162 | - platformdirs==3.0.0 163 | - prometheus-client==0.16.0 164 | - promise==2.3 165 | - prompt-toolkit==3.0.36 166 | - protobuf==3.19.6 167 | - psutil==5.9.4 168 | - ptyprocess==0.7.0 169 | - pure-eval==0.2.2 170 | - pyarrow==7.0.0 171 | - pyasn1==0.4.8 172 | - pyasn1-modules==0.2.8 173 | - pycocotools==2.0.6 174 | - 
pycparser==2.21 175 | - pydeprecate==0.3.1 176 | - pygments==2.14.0 177 | - pylint==2.16.1 178 | - pyparsing==3.0.9 179 | - pypubsub==3.3.0 180 | - pyrsistent==0.19.3 181 | - python-dateutil==2.8.2 182 | - python-json-logger==2.0.7 183 | - python-version==0.0.2 184 | - pytorch-lightning==1.5.6 185 | - pytz==2022.7 186 | - pywavelets==1.4.1 187 | - pyyaml==6.0 188 | - pyzmq==25.0.1 189 | - qtconsole==5.4.1 190 | - qtpy==2.3.0 191 | - qudida==0.0.4 192 | - regex==2022.10.31 193 | - requests==2.28.1 194 | - requests-oauthlib==1.3.1 195 | - rfc3339-validator==0.1.4 196 | - rfc3986-validator==0.1.1 197 | - rsa==4.9 198 | - scikit-image==0.19.3 199 | - scikit-learn==1.2.1 200 | - scipy==1.10.0 201 | - seaborn==0.12.2 202 | - send2trash==1.8.0 203 | - sentence-transformers==2.2.2 204 | - sentencepiece==0.1.97 205 | - sentry-sdk==1.12.1 206 | - setproctitle==1.3.2 207 | - shortuuid==1.0.11 208 | - six==1.16.0 209 | - sklearn==0.0.post1 210 | - smmap==5.0.0 211 | - sniffio==1.3.0 212 | - soupsieve==2.4 213 | - stack-data==0.6.2 214 | - tensorboard==2.11.0 215 | - tensorboard-data-server==0.6.1 216 | - tensorboard-plugin-wit==1.8.1 217 | - tensorflow-estimator==2.11.0 218 | - tensorflow-io-gcs-filesystem==0.31.0 219 | - termcolor==2.2.0 220 | - terminado==0.17.1 221 | - test-tube==0.7.5 222 | - threadpoolctl==3.1.0 223 | - tifffile==2023.2.3 224 | - timm==0.6.12 225 | - tinycss2==1.2.1 226 | - tokenizers==0.13.2 227 | - tomli==2.0.1 228 | - tomlkit==0.11.6 229 | - torch==1.13.1 230 | - torchmetrics==0.11.0 231 | - torchvision==0.14.1 232 | - tornado==6.2 233 | - tqdm==4.64.1 234 | - traitlets==5.9.0 235 | - transformers==4.26.1 236 | - typing_extensions==4.4.0 237 | - uri-template==1.2.0 238 | - urllib3==1.26.13 239 | - wandb==0.12.21 240 | - wcwidth==0.2.5 241 | - webcolors==1.12 242 | - webdataset==0.2.33 243 | - webencodings==0.5.1 244 | - websocket-client==1.5.1 245 | - werkzeug==2.2.2 246 | - widgetsnbextension==4.0.5 247 | - wrapt==1.14.1 248 | - yarl==1.8.2 249 | - zipp==3.15.0 250 | prefix: /home/huangmq/anaconda3/envs/vq 251 | -------------------------------------------------------------------------------- /models/stage1/rqvae.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import numpy as np 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | from modules.vector_quantization.quantize_rqvae import RQBottleneck 8 | from modules.diffusionmodules.model import Encoder, Decoder, ResnetBlock 9 | 10 | class Stage1Model(nn.Module, metaclass=abc.ABCMeta): 11 | 12 | @abc.abstractmethod 13 | def get_codes(self, *args, **kwargs): 14 | """Generate the code from the input.""" 15 | pass 16 | 17 | @abc.abstractmethod 18 | def decode_code(self, *args, **kwargs): 19 | """Generate the decoded image from the given code.""" 20 | pass 21 | 22 | @abc.abstractmethod 23 | def get_recon_imgs(self, *args, **kwargs): 24 | """Scales the real and recon images properly. 25 | """ 26 | pass 27 | 28 | @abc.abstractmethod 29 | def compute_loss(self, *args, **kwargs): 30 | """Compute the losses necessary for training. 31 | 32 | return { 33 | 'loss_total': ..., 34 | 'loss_recon': ..., 35 | 'loss_latent': ..., 36 | 'codes': ..., 37 | ... 
38 | } 39 | """ 40 | pass 41 | 42 | class RQVAE(Stage1Model): 43 | def __init__(self, 44 | *, 45 | embed_dim=64, 46 | n_embed=512, 47 | decay=0.99, 48 | loss_type='mse', 49 | latent_loss_weight=0.25, 50 | bottleneck_type='rq', 51 | ddconfig=None, 52 | checkpointing=False, 53 | ckpt_path=None, 54 | ignore_keys=[], 55 | **kwargs): 56 | super().__init__() 57 | 58 | assert loss_type in ['mse', 'l1'] 59 | 60 | self.encoder = Encoder(**ddconfig) 61 | self.decoder = Decoder(**ddconfig) 62 | 63 | def set_checkpointing(m): 64 | if isinstance(m, ResnetBlock): 65 | m.checkpointing = checkpointing 66 | 67 | self.encoder.apply(set_checkpointing) 68 | self.decoder.apply(set_checkpointing) 69 | 70 | if bottleneck_type == 'rq': 71 | latent_shape = kwargs['latent_shape'] 72 | code_shape = kwargs['code_shape'] 73 | shared_codebook = kwargs['shared_codebook'] 74 | restart_unused_codes = kwargs['restart_unused_codes'] 75 | self.quantizer = RQBottleneck(latent_shape=latent_shape, 76 | code_shape=code_shape, 77 | n_embed=n_embed, 78 | decay=decay, 79 | shared_codebook=shared_codebook, 80 | restart_unused_codes=restart_unused_codes, 81 | ) 82 | self.code_shape = code_shape 83 | else: 84 | raise ValueError("invalid 'bottleneck_type' (must be 'rq')") 85 | 86 | self.quant_conv = nn.Conv2d(ddconfig["z_channels"], embed_dim, 1) 87 | self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) 88 | 89 | self.loss_type = loss_type 90 | self.latent_loss_weight = latent_loss_weight 91 | 92 | if ckpt_path is not None: 93 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) 94 | 95 | def init_from_ckpt(self, path, ignore_keys=list()): 96 | sd = torch.load(path, map_location="cpu")["state_dict"] 97 | keys = list(sd.keys()) 98 | for k in keys: 99 | for ik in ignore_keys: 100 | if k.startswith(ik): 101 | print("Deleting key {} from state_dict.".format(k)) 102 | del sd[k] 103 | self.load_state_dict(sd, strict=False) 104 | print(f"Restored from {path}") 105 | 106 | def forward(self, xs): 107 | z_e = self.encode(xs) 108 | z_q, quant_loss, code = self.quantizer(z_e) 109 | out = self.decode(z_q) 110 | return out, quant_loss, code 111 | 112 | def encode(self, x): 113 | z_e = self.encoder(x) 114 | z_e = self.quant_conv(z_e).permute(0, 2, 3, 1).contiguous() 115 | return z_e 116 | 117 | def decode(self, z_q): 118 | z_q = z_q.permute(0, 3, 1, 2).contiguous() 119 | z_q = self.post_quant_conv(z_q) 120 | out = self.decoder(z_q) 121 | return out 122 | 123 | @torch.no_grad() 124 | def get_codes(self, xs): 125 | z_e = self.encode(xs) 126 | _, _, code = self.quantizer(z_e) 127 | return code 128 | 129 | @torch.no_grad() 130 | def get_soft_codes(self, xs, temp=1.0, stochastic=False): 131 | assert hasattr(self.quantizer, 'get_soft_codes') 132 | 133 | z_e = self.encode(xs) 134 | soft_code, code = self.quantizer.get_soft_codes(z_e, temp=temp, stochastic=stochastic) 135 | return soft_code, code 136 | 137 | @torch.no_grad() 138 | def decode_code(self, code): 139 | z_q = self.quantizer.embed_code(code) 140 | decoded = self.decode(z_q) 141 | return decoded 142 | 143 | def get_recon_imgs(self, xs_real, xs_recon): 144 | 145 | xs_real = xs_real * 0.5 + 0.5 146 | xs_recon = xs_recon * 0.5 + 0.5 147 | xs_recon = torch.clamp(xs_recon, 0, 1) 148 | 149 | return xs_real, xs_recon 150 | 151 | def compute_loss(self, out, quant_loss, code, xs=None, valid=False): 152 | 153 | if self.loss_type == 'mse': 154 | loss_recon = F.mse_loss(out, xs, reduction='mean') 155 | elif self.loss_type == 'l1': 156 | loss_recon = F.l1_loss(out, xs, 
reduction='mean') 157 | else: 158 | raise ValueError('incompatible loss type') 159 | 160 | loss_latent = quant_loss 161 | 162 | if valid: 163 | loss_recon = loss_recon * xs.shape[0] * xs.shape[1] 164 | loss_latent = loss_latent * xs.shape[0] 165 | 166 | loss_total = loss_recon + self.latent_loss_weight * loss_latent 167 | 168 | return { 169 | 'loss_total': loss_total, 170 | 'loss_recon': loss_recon, 171 | 'loss_latent': loss_latent, 172 | 'codes': [code] 173 | } 174 | 175 | def get_last_layer(self): 176 | return self.decoder.conv_out.weight 177 | 178 | @torch.no_grad() 179 | def get_code_emb_with_depth(self, code): 180 | return self.quantizer.embed_code_with_depth(code) 181 | 182 | @torch.no_grad() 183 | def decode_partial_code(self, code, code_idx, decode_type='select'): 184 | r""" 185 | Use partial codebooks and decode the codebook features. 186 | If decode_type == 'select', the (code_idx)-th codebook features are decoded. 187 | If decode_type == 'add', the [0,1,...,code_idx]-th codebook features are added and decoded. 188 | """ 189 | z_q = self.quantizer.embed_partial_code(code, code_idx, decode_type) 190 | decoded = self.decode(z_q) 191 | return decoded 192 | 193 | @torch.no_grad() 194 | def forward_partial_code(self, xs, code_idx, decode_type='select'): 195 | r""" 196 | Reconstuct an input using partial codebooks. 197 | """ 198 | code = self.get_codes(xs) 199 | out = self.decode_partial_code(code, code_idx, decode_type) 200 | return out 201 | 202 | 203 | if __name__ == "__main__": 204 | pass -------------------------------------------------------------------------------- /models/stage1/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from functools import partial 4 | 5 | # step scheduler 6 | def fn_LinearWarmup(warmup_steps, step): 7 | if step < warmup_steps: # linear warmup 8 | return float(step) / float(max(1, warmup_steps)) 9 | else: 10 | return 1.0 11 | 12 | def Scheduler_LinearWarmup(warmup_steps): 13 | return partial(fn_LinearWarmup, warmup_steps) 14 | 15 | 16 | def fn_LinearWarmup_CosineDecay(warmup_steps, max_steps, multipler_min, step): 17 | if step < warmup_steps: # linear warmup 18 | return float(step) / float(max(1, warmup_steps)) 19 | else: # cosine learning rate schedule 20 | multipler = 0.5 * (math.cos((step - warmup_steps) / (max_steps - warmup_steps) * math.pi) + 1) 21 | return max(multipler, multipler_min) 22 | 23 | def Scheduler_LinearWarmup_CosineDecay(warmup_steps, max_steps, multipler_min): 24 | return partial(fn_LinearWarmup_CosineDecay, warmup_steps, max_steps, multipler_min) 25 | -------------------------------------------------------------------------------- /models/stage2/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from functools import partial 4 | 5 | # step scheduler 6 | def fn(warmup_steps, max_steps, multipler_min, step): 7 | if step < warmup_steps: # linear warmup 8 | return float(step) / float(max(1, warmup_steps)) 9 | else: # cosine learning rate schedule 10 | multipler = 0.5 * (math.cos((step - warmup_steps) / (max_steps - warmup_steps) * math.pi) + 1) 11 | return max(multipler, multipler_min) 12 | 13 | def learning_rate_schedule(warmup_steps, max_steps, multipler_min): 14 | return partial(fn, warmup_steps, max_steps, multipler_min) 15 | 16 | def disabled_train(self, mode=True): 17 | """Overwrite model.train with this function to make sure train/eval mode 18 | does not change anymore.""" 19 
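The scheduler factories in models/stage1/utils.py above return step-indexed multiplier functions, which is the shape torch.optim.lr_scheduler.LambdaLR expects. A minimal, self-contained sketch of wiring one up (step counts are illustrative; the Lightning modules presumably do the equivalent in configure_optimizers):

    import torch
    from models.stage1.utils import Scheduler_LinearWarmup_CosineDecay

    net = torch.nn.Linear(4, 4)                     # stand-in for the VQ model
    optimizer = torch.optim.Adam(net.parameters(), lr=4.5e-6)
    lr_lambda = Scheduler_LinearWarmup_CosineDecay(5000, 100000, 0.0)   # warmup steps, max steps, min multiplier
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
    for step in range(3):
        optimizer.step()
        scheduler.step()                            # multiplier ramps linearly over the first 5000 steps, then decays with cosine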
| return self 20 | 21 | # commonly used sample functions 22 | def top_k_logits(logits, k): 23 | v, ix = torch.topk(logits, k) 24 | out = logits.clone() 25 | out[out < v[..., [-1]]] = -float('Inf') 26 | return out 27 | 28 | def top_p_logits(probs, p): 29 | sorted_probs, sorted_indices = torch.sort(probs, dim=-1, descending=True) 30 | cum_probs = torch.cumsum(sorted_probs, dim=-1) 31 | 32 | sorted_idx_remove_cond = cum_probs >= p 33 | 34 | sorted_idx_remove_cond[..., 1:] = sorted_idx_remove_cond[..., :-1].clone() 35 | sorted_idx_remove_cond[..., 0] = 0 36 | 37 | indices_to_remove = sorted_idx_remove_cond.scatter(-1, sorted_indices, sorted_idx_remove_cond) 38 | probs = probs.masked_fill(indices_to_remove, 0.0) 39 | norm_probs = probs / torch.sum(probs, dim=-1, keepdim=True) 40 | return norm_probs -------------------------------------------------------------------------------- /modules/discriminator/model.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import torch.nn as nn 3 | 4 | 5 | from utils.utils import ActNorm 6 | 7 | 8 | def weights_init(m): 9 | classname = m.__class__.__name__ 10 | if classname.find('Conv') != -1: 11 | nn.init.normal_(m.weight.data, 0.0, 0.02) 12 | elif classname.find('BatchNorm') != -1: 13 | nn.init.normal_(m.weight.data, 1.0, 0.02) 14 | nn.init.constant_(m.bias.data, 0) 15 | 16 | 17 | class NLayerDiscriminator(nn.Module): 18 | """Defines a PatchGAN discriminator as in Pix2Pix 19 | --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py 20 | """ 21 | def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False): 22 | """Construct a PatchGAN discriminator 23 | Parameters: 24 | input_nc (int) -- the number of channels in input images 25 | ndf (int) -- the number of filters in the last conv layer 26 | n_layers (int) -- the number of conv layers in the discriminator 27 | norm_layer -- normalization layer 28 | """ 29 | super(NLayerDiscriminator, self).__init__() 30 | if not use_actnorm: 31 | norm_layer = nn.BatchNorm2d 32 | else: 33 | norm_layer = ActNorm 34 | if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters 35 | use_bias = norm_layer.func != nn.BatchNorm2d 36 | else: 37 | use_bias = norm_layer != nn.BatchNorm2d 38 | 39 | kw = 4 40 | padw = 1 41 | sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] 42 | nf_mult = 1 43 | nf_mult_prev = 1 44 | for n in range(1, n_layers): # gradually increase the number of filters 45 | nf_mult_prev = nf_mult 46 | nf_mult = min(2 ** n, 8) 47 | sequence += [ 48 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), 49 | norm_layer(ndf * nf_mult), 50 | nn.LeakyReLU(0.2, True) 51 | ] 52 | 53 | nf_mult_prev = nf_mult 54 | nf_mult = min(2 ** n_layers, 8) 55 | sequence += [ 56 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias), 57 | norm_layer(ndf * nf_mult), 58 | nn.LeakyReLU(0.2, True) 59 | ] 60 | 61 | sequence += [ 62 | nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map 63 | self.main = nn.Sequential(*sequence) 64 | 65 | def forward(self, input): 66 | """Standard forward.""" 67 | return self.main(input) 68 | -------------------------------------------------------------------------------- /modules/discriminator/stylegan_lucidrains.py: 
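A usage sketch for the sampling helpers in models/stage2/utils.py above; note that top_k_logits filters raw logits, while top_p_logits expects probabilities and returns them renormalized:

    import torch
    from models.stage2.utils import top_k_logits, top_p_logits

    logits = torch.randn(2, 1027)               # batch of 2 over the unconditional vocab_size used above
    logits = top_k_logits(logits, k=300)        # keep only the 300 largest logits per row
    probs = torch.softmax(logits, dim=-1)
    probs = top_p_logits(probs, p=0.95)         # nucleus filtering on the probability mass
    next_token = torch.multinomial(probs, num_samples=1)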
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | def default(val, d): 5 | return val if val is not None else d 6 | 7 | class ChanLayerNorm(nn.Module): 8 | def __init__( 9 | self, 10 | dim, 11 | eps = 1e-5 12 | ): 13 | super().__init__() 14 | self.eps = eps 15 | self.gamma = nn.Parameter(torch.ones(1, dim, 1, 1)) 16 | 17 | def forward(self, x): 18 | var = torch.var(x, dim = 1, unbiased = False, keepdim = True) 19 | mean = torch.mean(x, dim = 1, keepdim = True) 20 | return (x - mean) * (var + self.eps).rsqrt() * self.gamma 21 | 22 | class CrossEmbedLayer(nn.Module): 23 | def __init__( 24 | self, 25 | dim_in, 26 | kernel_sizes, 27 | dim_out = None, 28 | stride = 2 29 | ): 30 | super().__init__() 31 | assert all([*map(lambda t: (t % 2) == (stride % 2), kernel_sizes)]) 32 | dim_out = default(dim_out, dim_in) 33 | 34 | kernel_sizes = sorted(kernel_sizes) 35 | num_scales = len(kernel_sizes) 36 | 37 | # calculate the dimension at each scale 38 | dim_scales = [int(dim_out / (2 ** i)) for i in range(1, num_scales)] 39 | dim_scales = [*dim_scales, dim_out - sum(dim_scales)] 40 | 41 | self.convs = nn.ModuleList([]) 42 | for kernel, dim_scale in zip(kernel_sizes, dim_scales): 43 | self.convs.append(nn.Conv2d(dim_in, dim_scale, kernel, stride = stride, padding = (kernel - stride) // 2)) 44 | 45 | def forward(self, x): 46 | fmaps = tuple(map(lambda conv: conv(x), self.convs)) 47 | return torch.cat(fmaps, dim = 1) 48 | 49 | class Block(nn.Module): 50 | def __init__( 51 | self, 52 | dim, 53 | dim_out, 54 | groups = 8 55 | ): 56 | super().__init__() 57 | self.groupnorm = nn.GroupNorm(groups, dim) 58 | self.activation = nn.LeakyReLU(0.1) 59 | self.project = nn.Conv2d(dim, dim_out, 3, padding = 1) 60 | 61 | def forward(self, x, scale_shift = None): 62 | x = self.groupnorm(x) 63 | x = self.activation(x) 64 | return self.project(x) 65 | 66 | class ResnetBlock(nn.Module): 67 | def __init__( 68 | self, 69 | dim, 70 | dim_out = None, 71 | *, 72 | groups = 8 73 | ): 74 | super().__init__() 75 | dim_out = default(dim_out, dim) 76 | self.block = Block(dim, dim_out, groups = groups) 77 | self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity() 78 | 79 | def forward(self, x): 80 | h = self.block(x) 81 | return h + self.res_conv(x) 82 | 83 | # discriminator 84 | 85 | class Discriminator(nn.Module): 86 | def __init__( 87 | self, 88 | dim = 256, 89 | discr_layers = 6, 90 | channels = 3, 91 | groups = 8, 92 | cross_embed_kernel_sizes = (3, 7, 15) 93 | ): 94 | super().__init__() 95 | 96 | layer_mults = list(map(lambda t: 2 ** t, range(discr_layers))) 97 | layer_dims = [dim * mult for mult in layer_mults] 98 | dims = (dim, *layer_dims) 99 | 100 | init_dim, *_, final_dim = dims 101 | dim_pairs = zip(dims[:-1], dims[1:]) 102 | 103 | self.layers = nn.ModuleList([nn.Sequential( 104 | CrossEmbedLayer(channels, cross_embed_kernel_sizes, init_dim, stride = 1), 105 | nn.LeakyReLU(0.1) 106 | )]) 107 | 108 | for dim_in, dim_out in dim_pairs: 109 | self.layers.append(nn.Sequential( 110 | nn.Conv2d(dim_in, dim_out, 4, stride = 2, padding = 1), 111 | nn.LeakyReLU(0.1), 112 | nn.GroupNorm(groups, dim_out), 113 | ResnetBlock(dim_out, dim_out), 114 | )) 115 | 116 | self.to_logits = nn.Sequential( # return 5 x 5, for PatchGAN-esque training 117 | nn.Conv2d(final_dim, final_dim, 1), 118 | nn.LeakyReLU(0.1), 119 | nn.Conv2d(final_dim, 1, 4) 120 | ) 121 | 122 | def forward(self, x): 123 | for net in self.layers: 124 | x = net(x) 125 | 126 | 
return self.to_logits(x) 127 | 128 | if __name__ == "__main__": 129 | x = torch.randn(10, 3, 256, 256) 130 | 131 | D = Discriminator( 132 | dim = 256, 133 | discr_layers = 6, 134 | channels = 3, 135 | groups = 8, 136 | cross_embed_kernel_sizes = (3, 7, 15) 137 | ) 138 | 139 | y = D(x) 140 | print(y.size()) -------------------------------------------------------------------------------- /modules/dynamic_modules/Decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | import os, sys 6 | sys.path.append(os.getcwd()) 7 | from modules.diffusionmodules.model import ResnetBlock, AttnBlock, Upsample, Normalize, nonlinearity 8 | 9 | 10 | class Decoder(nn.Module): 11 | def __init__(self, 12 | ch, in_ch, out_ch, ch_mult, num_res_blocks, resolution, 13 | attn_resolutions, dropout = 0.0, resamp_with_conv = True, give_pre_end = False, 14 | ): 15 | super().__init__() 16 | self.num_resolutions = len(ch_mult) 17 | self.num_res_blocks = num_res_blocks 18 | self.resolution = resolution 19 | self.in_ch = in_ch 20 | self.temb_ch = 0 21 | self.ch = ch 22 | self.give_pre_end = give_pre_end 23 | 24 | # compute block_in and curr_res at lowest res 25 | block_in = ch*ch_mult[self.num_resolutions-1] 26 | curr_res = resolution // 2**(self.num_resolutions-1) 27 | self.z_shape = (1,in_ch,curr_res,curr_res) 28 | print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape))) 29 | 30 | self.conv_in = torch.nn.Conv2d(in_ch, block_in, kernel_size=3, stride=1, padding=1) 31 | 32 | # middle 33 | self.mid = nn.Module() 34 | self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout) 35 | self.mid.attn_1 = AttnBlock(block_in) 36 | self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout) 37 | 38 | # upsampling 39 | self.up = nn.ModuleList() 40 | for i_level in reversed(range(self.num_resolutions)): 41 | block = nn.ModuleList() 42 | attn = nn.ModuleList() 43 | block_out = ch*ch_mult[i_level] 44 | for i_block in range(self.num_res_blocks+1): 45 | block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout)) 46 | block_in = block_out 47 | if curr_res in attn_resolutions: 48 | attn.append(AttnBlock(block_in)) 49 | up = nn.Module() 50 | up.block = block 51 | up.attn = attn 52 | if i_level != 0: 53 | up.upsample = Upsample(block_in, resamp_with_conv) 54 | curr_res = curr_res * 2 55 | self.up.insert(0, up) # prepend to get consistent order 56 | 57 | # end 58 | self.norm_out = Normalize(block_in) 59 | self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) 60 | 61 | def forward(self, h, grain_indices): 62 | # z to block_in 63 | temb = None 64 | h = self.conv_in(h) 65 | 66 | # middle 67 | h = self.mid.block_1(h, temb) 68 | h = self.mid.attn_1(h) 69 | h = self.mid.block_2(h, temb) 70 | 71 | # upsampling 72 | for i_level in reversed(range(self.num_resolutions)): 73 | for i_block in range(self.num_res_blocks+1): 74 | h = self.up[i_level].block[i_block](h, temb) 75 | if len(self.up[i_level].attn) > 0: 76 | h = self.up[i_level].attn[i_block](h) 77 | if i_level != 0: 78 | h = self.up[i_level].upsample(h) 79 | 80 | # end 81 | if self.give_pre_end: 82 | return h 83 | 84 | h = self.norm_out(h) 85 | h = nonlinearity(h) 86 | h = self.conv_out(h) 87 | 88 | return h 
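if __name__ == "__main__":
    # Hedged usage sketch (added; not part of the original file). The hyperparameters below are
    # illustrative assumptions rather than the values from the released configs: a 2x256x16x16
    # latent is decoded back to 2x3x256x256 images. Note that `grain_indices` is accepted by
    # forward() but is not used inside this decoder.
    decoder = Decoder(
        ch=128, in_ch=256, out_ch=3, ch_mult=(1, 1, 2, 2, 4),
        num_res_blocks=2, resolution=256, attn_resolutions=[16],
    )
    h = torch.randn(2, 256, 16, 16)   # latent feature map with in_ch channels
    rec = decoder(h, grain_indices=None)
    print(rec.shape)                  # expected: torch.Size([2, 3, 256, 256])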
-------------------------------------------------------------------------------- /modules/dynamic_modules/DecoderPositional.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import torch.nn as nn 4 | import numpy as np 5 | from einops import rearrange 6 | 7 | import os, sys 8 | sys.path.append(os.getcwd()) 9 | from modules.diffusionmodules.model import ResnetBlock, AttnBlock, Upsample, Normalize, nonlinearity 10 | from modules.dynamic_modules.tools import trunc_normal_ 11 | from modules.dynamic_modules.fourier_embedding import FourierPositionEmbedding 12 | 13 | class PositionEmbedding2DLearned(nn.Module): 14 | def __init__(self, n_row, feats_dim, n_col=None): 15 | super().__init__() 16 | n_col = n_col if n_col is not None else n_row 17 | self.row_embed = nn.Embedding(n_row, feats_dim) 18 | self.col_embed = nn.Embedding(n_col, feats_dim) 19 | self.reset_parameters() 20 | 21 | def reset_parameters(self): 22 | # nn.init.uniform_(self.row_embed.weight) 23 | # nn.init.uniform_(self.col_embed.weight) 24 | trunc_normal_(self.row_embed.weight) 25 | trunc_normal_(self.col_embed.weight) 26 | 27 | def forward(self, x): 28 | h, w = x.shape[-2:] 29 | i = torch.arange(w, device=x.device) 30 | j = torch.arange(h, device=x.device) 31 | x_emb = self.col_embed(i).unsqueeze(0).repeat(h, 1, 1) 32 | y_emb = self.row_embed(j).unsqueeze(1).repeat(1, w, 1) 33 | pos = (x_emb + y_emb).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1) 34 | 35 | if x.dim() == 5: # relative position embedding 36 | pos = pos.unsqueeze(-3) 37 | 38 | x = x + pos 39 | return x 40 | 41 | class Decoder(nn.Module): 42 | def __init__(self, 43 | ch, in_ch, out_ch, ch_mult, num_res_blocks, resolution, 44 | attn_resolutions, dropout = 0.0, resamp_with_conv = True, give_pre_end = False, 45 | latent_size = 32, window_size = 2, position_type = "relative" 46 | ): 47 | super().__init__() 48 | self.num_resolutions = len(ch_mult) 49 | self.num_res_blocks = num_res_blocks 50 | self.resolution = resolution 51 | self.in_ch = in_ch 52 | self.temb_ch = 0 53 | self.ch = ch 54 | self.give_pre_end = give_pre_end 55 | 56 | # compute block_in and curr_res at lowest res 57 | block_in = ch*ch_mult[self.num_resolutions-1] 58 | curr_res = resolution // 2**(self.num_resolutions-1) 59 | self.z_shape = (1,in_ch,curr_res,curr_res) 60 | print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape))) 61 | 62 | self.conv_in = torch.nn.Conv2d(in_ch, block_in, kernel_size=3, stride=1, padding=1) 63 | 64 | # middle 65 | self.mid = nn.Module() 66 | self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout) 67 | self.mid.attn_1 = AttnBlock(block_in) 68 | self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout) 69 | 70 | # upsampling 71 | self.up = nn.ModuleList() 72 | for i_level in reversed(range(self.num_resolutions)): 73 | block = nn.ModuleList() 74 | attn = nn.ModuleList() 75 | block_out = ch*ch_mult[i_level] 76 | for i_block in range(self.num_res_blocks+1): 77 | block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout)) 78 | block_in = block_out 79 | if curr_res in attn_resolutions: 80 | attn.append(AttnBlock(block_in)) 81 | up = nn.Module() 82 | up.block = block 83 | up.attn = attn 84 | if i_level != 0: 85 | up.upsample = Upsample(block_in, resamp_with_conv) 86 | curr_res = 
curr_res * 2 87 | self.up.insert(0, up) # prepend to get consistent order 88 | 89 | # end 90 | self.norm_out = Normalize(block_in) 91 | self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) 92 | 93 | # relative pos embeddings 94 | self.position_type = position_type 95 | if self.position_type == "learned": 96 | self.position_bias = PositionEmbedding2DLearned(n_row=latent_size, feats_dim=in_ch) 97 | elif self.position_type == "learned-relative": 98 | self.position_bias = PositionEmbedding2DLearned(n_row=window_size, feats_dim=in_ch) 99 | self.window_size = window_size 100 | self.window_num = latent_size // window_size 101 | elif self.position_type == "fourier": 102 | self.position_bias = FourierPositionEmbedding(coord_size=latent_size, hidden_size=in_ch) 103 | elif self.position_type == "fourier+learned": 104 | self.position_bias_fourier = FourierPositionEmbedding(coord_size=latent_size, hidden_size=in_ch) 105 | self.position_bias_learned = PositionEmbedding2DLearned(n_row=latent_size, feats_dim=in_ch) 106 | else: 107 | raise NotImplementedError() 108 | 109 | def forward(self, h, grain_indices): 110 | if self.position_type == "full" or self.position_type == "fourier": 111 | h = self.position_bias(h) 112 | elif self.position_type == "relative": 113 | h = rearrange(h, "B C (n1 nH) (n2 nW) -> B C (n1 n2) nH nW", n1=self.window_num, nH=self.window_size, n2=self.window_num, nW=self.window_size) 114 | h = self.position_bias(h) 115 | h = rearrange(h, "B C (n1 n2) nH nW -> B C (n1 nH) (n2 nW)", n1=self.window_num, nH=self.window_size) 116 | elif self.position_type == "fourier+learned": 117 | h = self.position_bias_fourier(h) 118 | h = self.position_bias_learned(h) 119 | 120 | # z to block_in 121 | temb = None 122 | h = self.conv_in(h) 123 | 124 | # middle 125 | h = self.mid.block_1(h, temb) 126 | h = self.mid.attn_1(h) 127 | h = self.mid.block_2(h, temb) 128 | 129 | # upsampling 130 | for i_level in reversed(range(self.num_resolutions)): 131 | for i_block in range(self.num_res_blocks+1): 132 | h = self.up[i_level].block[i_block](h, temb) 133 | if len(self.up[i_level].attn) > 0: 134 | h = self.up[i_level].attn[i_block](h) 135 | if i_level != 0: 136 | h = self.up[i_level].upsample(h) 137 | 138 | # end 139 | if self.give_pre_end: 140 | return h 141 | 142 | h = self.norm_out(h) 143 | h = nonlinearity(h) 144 | h = self.conv_out(h) 145 | 146 | return h -------------------------------------------------------------------------------- /modules/dynamic_modules/EncoderDual.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | sys.path.append(os.getcwd()) 9 | 10 | import pytorch_lightning as pl 11 | from modules.diffusionmodules.model import (AttnBlock, Downsample, Normalize, ResnetBlock, nonlinearity) 12 | from utils.utils import instantiate_from_config 13 | 14 | # the last two grains 15 | class DualGrainEncoder(pl.LightningModule): 16 | def __init__(self, 17 | *, 18 | ch, 19 | ch_mult=(1,2,4,8), 20 | num_res_blocks, 21 | attn_resolutions, 22 | dropout=0.0, 23 | resamp_with_conv=True, 24 | in_channels, 25 | resolution, 26 | z_channels, 27 | router_config=None, 28 | update_router=True, 29 | **ignore_kwargs 30 | ): 31 | super().__init__() 32 | 33 | self.ch = ch 34 | self.temb_ch = 0 35 | self.num_resolutions = len(ch_mult) 36 | self.num_res_blocks = num_res_blocks 37 | self.resolution = resolution 38 | self.in_channels = in_channels 
39 | 40 | # downsampling 41 | self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1) 42 | 43 | curr_res = resolution 44 | in_ch_mult = (1,)+tuple(ch_mult) 45 | self.down = nn.ModuleList() 46 | for i_level in range(self.num_resolutions): 47 | block = nn.ModuleList() 48 | attn = nn.ModuleList() 49 | block_in = ch*in_ch_mult[i_level] 50 | block_out = ch*ch_mult[i_level] 51 | for i_block in range(self.num_res_blocks): 52 | block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout)) 53 | block_in = block_out 54 | if curr_res in attn_resolutions: 55 | attn.append(AttnBlock(block_in)) 56 | down = nn.Module() 57 | down.block = block 58 | down.attn = attn 59 | if i_level != self.num_resolutions-1: 60 | down.downsample = Downsample(block_in, resamp_with_conv) 61 | curr_res = curr_res // 2 62 | self.down.append(down) 63 | 64 | # middle for the coarse grain 65 | self.mid_coarse = nn.Module() 66 | self.mid_coarse.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout) 67 | self.mid_coarse.attn_1 = AttnBlock(block_in) 68 | self.mid_coarse.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout) 69 | 70 | # end for the coarse grain 71 | self.norm_out_coarse = Normalize(block_in) 72 | self.conv_out_coarse = torch.nn.Conv2d(block_in, z_channels, kernel_size=3, stride=1, padding=1) 73 | 74 | block_in_finegrain = block_in // (ch_mult[-1] // ch_mult[-2]) 75 | # middle for the fine grain 76 | self.mid_fine = nn.Module() 77 | self.mid_fine.block_1 = ResnetBlock(in_channels=block_in_finegrain, out_channels=block_in_finegrain, temb_channels=self.temb_ch, dropout=dropout) 78 | self.mid_fine.attn_1 = AttnBlock(block_in_finegrain) 79 | self.mid_fine.block_2 = ResnetBlock(in_channels=block_in_finegrain, out_channels=block_in_finegrain, temb_channels=self.temb_ch, dropout=dropout) 80 | 81 | # end for the fine grain 82 | self.norm_out_fine = Normalize(block_in_finegrain) 83 | self.conv_out_fine = torch.nn.Conv2d(block_in_finegrain, z_channels, kernel_size=3, stride=1, padding=1) 84 | 85 | self.router = instantiate_from_config(router_config) 86 | self.update_router = update_router 87 | 88 | 89 | def forward(self, x, x_entropy): 90 | assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution) 91 | 92 | # timestep embedding 93 | temb = None 94 | 95 | # downsampling 96 | hs = [self.conv_in(x)] 97 | for i_level in range(self.num_resolutions): 98 | for i_block in range(self.num_res_blocks): 99 | h = self.down[i_level].block[i_block](hs[-1], temb) 100 | if len(self.down[i_level].attn) > 0: 101 | h = self.down[i_level].attn[i_block](h) 102 | hs.append(h) 103 | if i_level != self.num_resolutions-1: 104 | hs.append(self.down[i_level].downsample(hs[-1])) 105 | if i_level == self.num_resolutions-2: 106 | h_fine = h 107 | 108 | h_coarse = hs[-1] 109 | 110 | # middle for the h_coarse 111 | h_coarse = self.mid_coarse.block_1(h_coarse, temb) 112 | h_coarse = self.mid_coarse.attn_1(h_coarse) 113 | h_coarse = self.mid_coarse.block_2(h_coarse, temb) 114 | 115 | # end for the h_coarse 116 | h_coarse = self.norm_out_coarse(h_coarse) 117 | h_coarse = nonlinearity(h_coarse) 118 | h_coarse = self.conv_out_coarse(h_coarse) 119 | 120 | # middle for the h_fine 121 | h_fine = self.mid_fine.block_1(h_fine, temb) 122 | h_fine = self.mid_fine.attn_1(h_fine) 123 | h_fine = 
self.mid_fine.block_2(h_fine, temb) 124 | 125 | # end for the h_coarse 126 | h_fine = self.norm_out_fine(h_fine) 127 | h_fine = nonlinearity(h_fine) 128 | h_fine = self.conv_out_fine(h_fine) 129 | 130 | # dynamic routing 131 | gate = self.router(h_fine=h_fine, h_coarse=h_coarse, entropy=x_entropy) 132 | if self.update_router and self.training: 133 | gate = F.gumbel_softmax(gate, dim=-1, hard=True) 134 | gate = gate.permute(0,3,1,2) 135 | indices = gate.argmax(dim=1) 136 | 137 | h_coarse = h_coarse.repeat_interleave(2, dim=-1).repeat_interleave(2, dim=-2) 138 | indices_repeat = indices.repeat_interleave(2, dim=-1).repeat_interleave(2, dim=-2).unsqueeze(1) 139 | # 0 for coarse-grained and 1 for fine-grained 140 | h_dual = torch.where(indices_repeat==0, h_coarse, h_fine) 141 | 142 | if self.update_router and self.training: 143 | gate_grad = gate.max(dim=1, keepdim=True)[0] 144 | gate_grad = gate_grad.repeat_interleave(2, dim=-1).repeat_interleave(2, dim=-2) 145 | h_dual = h_dual * gate_grad 146 | 147 | coarse_mask = 0.25 * torch.ones_like(indices_repeat).to(h_dual.device) 148 | fine_mask = 1.0 * torch.ones_like(indices_repeat).to(h_dual.device) 149 | codebook_mask = torch.where(indices_repeat==0, coarse_mask, fine_mask) 150 | 151 | return { 152 | "h_dual": h_dual, 153 | "indices": indices, 154 | "codebook_mask": codebook_mask, 155 | "gate": gate, 156 | } -------------------------------------------------------------------------------- /modules/dynamic_modules/RouterDual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import json 5 | 6 | class DualGrainFeatureRouter(nn.Module): 7 | def __init__(self, num_channels, normalization_type="none", gate_type="1layer-fc"): 8 | super().__init__() 9 | self.gate_pool = nn.AvgPool2d(2, 2) 10 | self.gate_type = gate_type 11 | if gate_type == "1layer-fc": 12 | self.gate = nn.Linear(num_channels * 2, 2) 13 | elif gate_type == "2layer-fc-SiLu": 14 | self.gate = nn.Sequential( 15 | nn.Linear(num_channels * 2, num_channels * 2), 16 | nn.SiLU(inplace=True), 17 | nn.Linear(num_channels * 2, 2), 18 | ) 19 | else: 20 | raise NotImplementedError() 21 | 22 | self.num_splits = 2 23 | self.normalization_type = normalization_type 24 | if self.normalization_type == "none": 25 | self.feature_norm_fine = nn.Identity() 26 | self.feature_norm_coarse = nn.Identity() 27 | elif "group" in self.normalization_type: # like "group-32" 28 | num_groups = int(self.normalization_type.split("-")[-1]) 29 | self.feature_norm_fine = nn.GroupNorm(num_groups=num_groups, num_channels=num_channels, eps=1e-6, affine=True) 30 | self.feature_norm_coarse = nn.GroupNorm(num_groups=num_groups, num_channels=num_channels, eps=1e-6, affine=True) 31 | else: 32 | raise NotImplementedError() 33 | 34 | 35 | def forward(self, h_fine, h_coarse, entropy=None): 36 | h_fine = self.feature_norm_fine(h_fine) 37 | h_coarse = self.feature_norm_coarse(h_coarse) 38 | 39 | avg_h_fine = self.gate_pool(h_fine) 40 | h_logistic = torch.cat([h_coarse, avg_h_fine], dim=1).permute(0,2,3,1) 41 | 42 | gate = self.gate(h_logistic) # torch.Size([30, 16, 16, 2]) 43 | return gate 44 | 45 | 46 | class DualGrainFixedEntropyRouter(nn.Module): 47 | def __init__(self, json_path, fine_grain_ratito,): 48 | super().__init__() 49 | with open(json_path, "r", encoding="utf-8") as f: 50 | content = json.load(f) 51 | self.fine_grain_threshold = content["{}".format(str(int(100 - fine_grain_ratito * 100)))] 52 | 53 | def forward(self, 
h_fine=None, h_coarse=None, entropy=None): 54 | gate_fine = (entropy > self.fine_grain_threshold).bool().long().unsqueeze(-1) 55 | gate_coarse = (entropy <= self.fine_grain_threshold).bool().long().unsqueeze(-1) 56 | gate = torch.cat([gate_coarse, gate_fine], dim=-1) 57 | return gate 58 | 59 | # class DualGrainDynamicEntropyRouter(nn.Module): 60 | # def __init__(self, json_path, fine_grain_ratito_min=0.01, fine_grain_ratito_max=0.99): 61 | # super().__init__() 62 | # with open(json_path, "r", encoding="utf-8") as f: 63 | # self.content = json.load(f) 64 | # self.fine_grain_ratito_min = int(fine_grain_ratito_min * 100) # inclusive 65 | # self.fine_grain_ratito_max = int(fine_grain_ratito_max * 100) + 1 # exclusive 66 | 67 | # def forward(self, h_fine=None, h_coarse=None, entropy=None): 68 | # # fine_grain_ratito = torch.randint(low=self.fine_grain_ratito_min, high=self.fine_grain_ratito_max, size=(1)) 69 | # fine_grain_ratito = np.random.randint(low=self.fine_grain_ratito_min, high=self.fine_grain_ratito_max) 70 | # fine_grain_threshold = self.content["{}".format(str(fine_grain_ratito))] 71 | 72 | # gate_fine = (entropy > fine_grain_threshold).bool().long().unsqueeze(-1) 73 | # gate_coarse = (entropy <= fine_grain_threshold).bool().long().unsqueeze(-1) 74 | # gate = torch.cat([gate_coarse, gate_fine], dim=-1) 75 | # return gate 76 | 77 | # if __name__ == "__main__": 78 | # model = DualGrainFixedEntropyRouter(json_path="scripts/tools/thresholds/entropy_thresholds_imagenet_train_patch-16.json", fine_grain_ratito=0.5) -------------------------------------------------------------------------------- /modules/dynamic_modules/RouterTriple.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import json 5 | 6 | class TripleGrainFeatureRouter(nn.Module): 7 | def __init__(self, num_channels, normalization_type="none", gate_type="1layer-fc"): 8 | super().__init__() 9 | self.gate_median_pool = nn.AvgPool2d(2, 2) 10 | self.gate_fine_pool = nn.AvgPool2d(4, 4) 11 | 12 | self.num_splits = 3 13 | 14 | self.gate_type = gate_type 15 | if gate_type == "1layer-fc": 16 | self.gate = nn.Linear(num_channels * self.num_splits, self.num_splits) 17 | elif gate_type == "2layer-fc-SiLu": 18 | self.gate = nn.Sequential( 19 | nn.Linear(num_channels * self.num_splits, num_channels * self.num_splits), 20 | nn.SiLU(inplace=True), 21 | nn.Linear(num_channels * self.num_splits, self.num_splits), 22 | ) 23 | elif gate_type == "2layer-fc-ReLu": 24 | self.gate = nn.Sequential( 25 | nn.Linear(num_channels * self.num_splits, num_channels * self.num_splits), 26 | nn.ReLU(inplace=True), 27 | nn.Linear(num_channels * self.num_splits, self.num_splits), 28 | ) 29 | else: 30 | raise NotImplementedError() 31 | 32 | self.normalization_type = normalization_type 33 | if self.normalization_type == "none": 34 | self.feature_norm_fine = nn.Identity() 35 | self.feature_norm_median = nn.Identity() 36 | self.feature_norm_coarse = nn.Identity() 37 | elif "group" in self.normalization_type: # like "group-32" 38 | num_groups = int(self.normalization_type.split("-")[-1]) 39 | self.feature_norm_fine = nn.GroupNorm(num_groups=num_groups, num_channels=num_channels, eps=1e-6, affine=True) 40 | self.feature_norm_median = nn.GroupNorm(num_groups=num_groups, num_channels=num_channels, eps=1e-6, affine=True) 41 | self.feature_norm_coarse = nn.GroupNorm(num_groups=num_groups, num_channels=num_channels, eps=1e-6, affine=True) 42 | else: 43 | raise 
NotImplementedError() 44 | 45 | 46 | def forward(self, h_fine, h_median, h_coarse, entropy=None): 47 | h_fine = self.feature_norm_fine(h_fine) 48 | h_median = self.feature_norm_median(h_median) 49 | h_coarse = self.feature_norm_coarse(h_coarse) 50 | 51 | avg_h_fine = self.gate_fine_pool(h_fine) 52 | avg_h_median = self.gate_median_pool(h_median) 53 | 54 | h_logistic = torch.cat([h_coarse, avg_h_median, avg_h_fine], dim=1).permute(0,2,3,1) 55 | gate = self.gate(h_logistic) 56 | return gate -------------------------------------------------------------------------------- /modules/dynamic_modules/budget.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class BudgetConstraint_RatioMSE_DualGrain(nn.Module): 5 | def __init__(self, target_ratio=0., gamma=1.0, min_grain_size=8, max_grain_size=16, calculate_all=True): 6 | super().__init__() 7 | self.target_ratio = target_ratio # e.g., 0.8 means 80% are fine-grained 8 | self.gamma = gamma 9 | self.calculate_all = calculate_all # calculate all grains 10 | self.loss = nn.MSELoss() 11 | 12 | self.const = min_grain_size * min_grain_size 13 | self.max_const = max_grain_size * max_grain_size - self.const 14 | 15 | def forward(self, gate): 16 | # 0 for coarse-grained and 1 for fine-grained 17 | # gate: (batch, 2, min_grain_size, min_grain_size) 18 | beta = 1.0 * gate[:, 0, :, :] + 4.0 * gate[:, 1, :, :] 19 | beta = (beta.sum() / gate.size(0)) - self.const 20 | budget_ratio = beta / self.max_const 21 | target_ratio = self.target_ratio * torch.ones_like(budget_ratio).to(gate.device) 22 | loss_budget = self.gamma * self.loss(budget_ratio, target_ratio) 23 | 24 | if self.calculate_all: 25 | loss_budget_last = self.gamma * self.loss(1 - budget_ratio, 1 - target_ratio) 26 | return loss_budget_last + loss_budget_last 27 | 28 | return loss_budget 29 | 30 | class BudgetConstraint_NormedSeperateRatioMSE_TripleGrain(nn.Module): 31 | def __init__(self, target_fine_ratio=0., target_median_ratio=0., gamma=1.0, min_grain_size=8, median_grain_size=16, max_grain_size=32): 32 | super().__init__() 33 | assert target_fine_ratio + target_median_ratio <= 1.0 34 | self.target_fine_ratio = target_fine_ratio # e.g., 0.8 means 80% are fine-grained 35 | self.target_median_ratio = target_median_ratio 36 | self.gamma = gamma 37 | self.loss = nn.MSELoss() 38 | 39 | self.min_const = min_grain_size * min_grain_size 40 | self.median_const = median_grain_size * median_grain_size - self.min_const 41 | self.max_const = max_grain_size * max_grain_size - self.min_const 42 | 43 | def forward(self, gate): 44 | # 0 for coarse-grained, 1 for median-grained, 2 for fine grained 45 | # gate: (batch, 3, min_grain_size, min_grain_size) 46 | beta_median = 1.0 * gate[:, 0, :, :] + 4.0 * gate[:, 1, :, :] + 1.0 * gate[:, 2, :, :] # the last term is the compensation for median ratio 47 | beta_median = (beta_median.sum() / gate.size(0)) - self.min_const 48 | budget_ratio_median = beta_median / self.median_const 49 | 50 | target_ratio_median = self.target_median_ratio * torch.ones_like(budget_ratio_median).to(gate.device) 51 | loss_budget_median = self.loss(budget_ratio_median, target_ratio_median) 52 | 53 | beta_fine = 1.0 * gate[:, 0, :, :] + 16.0 * gate[:, 2, :, :] + 1.0 * gate[:, 1, :, :] # the last term is the compensation for fine ratio 54 | beta_fine = (beta_fine.sum() / gate.size(0)) - self.min_const 55 | budget_ratio_fine = beta_fine / self.max_const 56 | 57 | target_ratio_fine = self.target_fine_ratio * 
torch.ones_like(budget_ratio_fine).to(gate.device) 58 | loss_budget_fine = self.gamma * self.loss(budget_ratio_fine, target_ratio_fine) 59 | 60 | return loss_budget_fine + loss_budget_median -------------------------------------------------------------------------------- /modules/dynamic_modules/fourier_embedding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | def convert_to_coord_format(b, h, w, device='cpu', integer_values=False): 6 | if integer_values: 7 | x_channel = torch.arange(w, dtype=torch.float, device=device).view(1, 1, 1, -1).repeat(b, 1, w, 1) 8 | y_channel = torch.arange(h, dtype=torch.float, device=device).view(1, 1, -1, 1).repeat(b, 1, 1, h) 9 | else: 10 | x_channel = torch.linspace(-1, 1, w, device=device).view(1, 1, 1, -1).repeat(b, 1, w, 1) 11 | y_channel = torch.linspace(-1, 1, h, device=device).view(1, 1, -1, 1).repeat(b, 1, 1, h) 12 | return torch.cat((x_channel, y_channel), dim=1) 13 | 14 | class ConLinear(nn.Module): 15 | def __init__(self, ch_in, ch_out, is_first=False, bias=True): 16 | super(ConLinear, self).__init__() 17 | self.conv = nn.Conv2d(ch_in, ch_out, kernel_size=1, padding=0, bias=bias) 18 | if is_first: 19 | nn.init.uniform_(self.conv.weight, -np.sqrt(9 / ch_in), np.sqrt(9 / ch_in)) 20 | else: 21 | nn.init.uniform_(self.conv.weight, -np.sqrt(3 / ch_in), np.sqrt(3 / ch_in)) 22 | 23 | def forward(self, x): 24 | return self.conv(x) 25 | 26 | class SinActivation(nn.Module): 27 | def __init__(self,): 28 | super(SinActivation, self).__init__() 29 | 30 | def forward(self, x): 31 | return torch.sin(x) 32 | 33 | 34 | class LFF(nn.Module): 35 | def __init__(self, hidden_size, ): 36 | super(LFF, self).__init__() 37 | self.ffm = ConLinear(2, hidden_size, is_first=True) 38 | self.activation = SinActivation() 39 | 40 | def forward(self, x): 41 | x = self.ffm(x) 42 | x = self.activation(x) 43 | return x 44 | 45 | class FourierPositionEmbedding(nn.Module): 46 | def __init__(self, coord_size, hidden_size, integer_values=False): 47 | super().__init__() 48 | self.coord = convert_to_coord_format(1, coord_size, coord_size, "cpu", integer_values) 49 | self.lff = LFF(hidden_size) 50 | 51 | def forward(self, x): 52 | coord = self.coord.to(x.device) 53 | fourier_features = self.lff(coord) 54 | x = x + fourier_features 55 | return x 56 | 57 | if __name__ == "__main__": 58 | x = torch.randn(10, 64, 32, 32) 59 | module = FourierPositionEmbedding(coord_size=32, hidden_size=64) 60 | module(x) -------------------------------------------------------------------------------- /modules/dynamic_modules/label_provider.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class AbstractEncoder(nn.Module): 5 | def __init__(self): 6 | super().__init__() 7 | 8 | def encode(self, *args, **kwargs): 9 | raise NotImplementedError 10 | 11 | class PositionAwareSOSProvider(AbstractEncoder): 12 | # for unconditional training with dynamic grainularity quantized transformer 13 | def __init__(self, coarse_sos, coarse_pos_sos, fine_sos=None, fine_pos_sos=None, coarse_seg_sos=None, fine_seg_sos=None): 14 | super().__init__() 15 | self.coarse_sos = coarse_sos 16 | self.fine_sos = fine_sos 17 | self.coarse_pos_sos = coarse_pos_sos 18 | self.fine_pos_sos = fine_pos_sos 19 | self.activate_seg = True if coarse_seg_sos is not None else False 20 | if self.activate_seg: 21 | self.coarse_seg_sos = coarse_seg_sos 22 | 
self.fine_seg_sos = fine_seg_sos 23 | 24 | def encode(self, x): 25 | # get batch size from data and replicate sos_token 26 | batch_size = x.size(0) 27 | device = x.device 28 | 29 | c_coarse = (torch.ones(batch_size, 1) * self.coarse_sos).long().to(device) 30 | if self.fine_sos is not None: 31 | c_fine = (torch.ones(batch_size, 1) * self.fine_sos).long().to(device) 32 | else: 33 | c_fine = None 34 | 35 | c_pos_coarse = (torch.ones(batch_size, 1) * self.coarse_pos_sos).long().to(device) 36 | if self.fine_pos_sos is not None: 37 | c_pos_fine = (torch.ones(batch_size, 1) * self.fine_pos_sos).long().to(device) 38 | else: 39 | c_pos_fine = None 40 | 41 | if self.activate_seg: 42 | c_seg_coarse = (torch.ones(batch_size, 1) * self.coarse_seg_sos).long().to(device) 43 | c_seg_fine = (torch.ones(batch_size, 1) * self.fine_seg_sos).long().to(device) 44 | return c_coarse, c_fine, c_pos_coarse, c_pos_fine, c_seg_coarse, c_seg_fine 45 | 46 | return c_coarse, c_fine, c_pos_coarse, c_pos_fine, None, None 47 | 48 | class ClassForContentOnlyPositionAwareSOSProvider(AbstractEncoder): 49 | # for class-conditional training with dynamic grainularity quantized transformer 50 | # compared with unconditional, we replace [coarse_sos, fine_sos] by class-label 51 | # class-label += threshold 52 | def __init__(self, n_classes, threshold, coarse_pos_sos, fine_pos_sos=None, coarse_seg_sos=None, fine_seg_sos=None): 53 | super().__init__() 54 | self.n_classes = n_classes 55 | self.threshold = threshold 56 | 57 | self.coarse_pos_sos = coarse_pos_sos 58 | self.fine_pos_sos = fine_pos_sos 59 | self.activate_seg = True if coarse_seg_sos is not None else False 60 | if self.activate_seg: 61 | self.coarse_seg_sos = coarse_seg_sos 62 | self.fine_seg_sos = fine_seg_sos 63 | 64 | def encode(self, x): 65 | # get batch size from data and replicate sos_token 66 | batch_size = x.size(0) 67 | device = x.device 68 | 69 | x = x[:,None] 70 | 71 | c_coarse = x + self.threshold 72 | if self.fine_pos_sos is not None: 73 | c_fine = x + self.threshold 74 | else: 75 | c_fine = None 76 | 77 | c_pos_coarse = (torch.ones(batch_size, 1) * self.coarse_pos_sos).long().to(device) 78 | if self.fine_pos_sos is not None: 79 | c_pos_fine = (torch.ones(batch_size, 1) * self.fine_pos_sos).long().to(device) 80 | else: 81 | c_pos_fine = None 82 | 83 | if self.activate_seg: 84 | c_seg_coarse = (torch.ones(batch_size, 1) * self.coarse_seg_sos).long().to(device) 85 | c_seg_fine = (torch.ones(batch_size, 1) * self.fine_seg_sos).long().to(device) 86 | return c_coarse, c_fine, c_pos_coarse, c_pos_fine, c_seg_coarse, c_seg_fine 87 | 88 | return c_coarse, c_fine, c_pos_coarse, c_pos_fine, None, None 89 | 90 | class ClassAwareSOSProvider(AbstractEncoder): 91 | # for class-conditional training with dynamic grainularity quantized transformer 92 | # compared with unconditional, we replace [coarse_sos, fine_sos, coarse_pos_sos, fine_pos_sos] by class-label 93 | # class-label += threshold 94 | def __init__(self, n_classes, threshold_content, threshold_coarse_position, threshold_fine_position, coarse_seg_sos=None, fine_seg_sos=None): 95 | super().__init__() 96 | self.n_classes = n_classes 97 | self.threshold_content = threshold_content 98 | self.threshold_coarse_position = threshold_coarse_position 99 | self.threshold_fine_position = threshold_fine_position 100 | 101 | self.activate_seg = True if coarse_seg_sos is not None else False 102 | self.coarse_seg_sos = coarse_seg_sos 103 | self.fine_seg_sos = fine_seg_sos 104 | 105 | def encode(self, x): 106 | # get batch size from 
data and replicate sos_token 107 | batch_size = x.size(0) 108 | device = x.device 109 | 110 | x = x[:,None] 111 | 112 | c_coarse = x + self.threshold_content 113 | if self.fine_seg_sos is not None: 114 | c_fine = x + self.threshold_content 115 | else: 116 | c_fine = None 117 | 118 | c_pos_coarse = x + self.threshold_coarse_position 119 | if self.fine_seg_sos is not None: 120 | c_pos_fine = x + self.threshold_fine_position 121 | else: 122 | c_pos_fine = None 123 | 124 | if self.activate_seg: 125 | c_seg_coarse = (torch.ones(batch_size, 1) * self.coarse_seg_sos).long().to(device) 126 | c_seg_fine = (torch.ones(batch_size, 1) * self.fine_seg_sos).long().to(device) 127 | return c_coarse, c_fine, c_pos_coarse, c_pos_fine, c_seg_coarse, c_seg_fine 128 | 129 | return c_coarse, c_fine, c_pos_coarse, c_pos_fine, None, None -------------------------------------------------------------------------------- /modules/dynamic_modules/tools.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import warnings 4 | 5 | def _no_grad_trunc_normal_(tensor, mean, std, a, b): 6 | # Cut & paste from PyTorch official master until it's in a few official releases - RW 7 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf 8 | def norm_cdf(x): 9 | # Computes standard normal cumulative distribution function 10 | return (1. + math.erf(x / math.sqrt(2.))) / 2. 11 | 12 | if (mean < a - 2 * std) or (mean > b + 2 * std): 13 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " 14 | "The distribution of values may be incorrect.", 15 | stacklevel=2) 16 | 17 | with torch.no_grad(): 18 | # Values are generated by using a truncated uniform distribution and 19 | # then using the inverse CDF for the normal distribution. 20 | # Get upper and lower cdf values 21 | l = norm_cdf((a - mean) / std) 22 | u = norm_cdf((b - mean) / std) 23 | 24 | # Uniformly fill tensor with values from [l, u], then translate to 25 | # [2l-1, 2u-1]. 26 | tensor.uniform_(2 * l - 1, 2 * u - 1) 27 | 28 | # Use inverse cdf transform for normal distribution to get truncated 29 | # standard normal 30 | tensor.erfinv_() 31 | 32 | # Transform to proper mean, std 33 | tensor.mul_(std * math.sqrt(2.)) 34 | tensor.add_(mean) 35 | 36 | # Clamp to ensure it's in the proper range 37 | tensor.clamp_(min=a, max=b) 38 | return tensor 39 | 40 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): 41 | # type: (Tensor, float, float, float, float) -> Tensor 42 | r"""Fills the input Tensor with values drawn from a truncated 43 | normal distribution. The values are effectively drawn from the 44 | normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` 45 | with values outside :math:`[a, b]` redrawn until they are within 46 | the bounds. The method used for generating the random values works 47 | best when :math:`a \leq \text{mean} \leq b`. 
48 | Args: 49 | tensor: an n-dimensional `torch.Tensor` 50 | mean: the mean of the normal distribution 51 | std: the standard deviation of the normal distribution 52 | a: the minimum cutoff value 53 | b: the maximum cutoff value 54 | Examples: 55 | >>> w = torch.empty(3, 5) 56 | >>> nn.init.trunc_normal_(w) 57 | """ 58 | return _no_grad_trunc_normal_(tensor, mean, std, a, b) -------------------------------------------------------------------------------- /modules/losses/lpips.py: -------------------------------------------------------------------------------- 1 | """Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models""" 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torchvision import models 6 | from collections import namedtuple 7 | 8 | from utils.utils import get_ckpt_path 9 | 10 | 11 | class LPIPS(nn.Module): 12 | # Learned perceptual metric 13 | def __init__(self, use_dropout=True): 14 | super().__init__() 15 | self.scaling_layer = ScalingLayer() 16 | self.chns = [64, 128, 256, 512, 512] # vg16 features 17 | self.net = vgg16(pretrained=True, requires_grad=False) 18 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) 19 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) 20 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) 21 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) 22 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) 23 | self.load_from_pretrained() 24 | for param in self.parameters(): 25 | param.requires_grad = False 26 | 27 | def load_from_pretrained(self, name="vgg_lpips"): 28 | ckpt = get_ckpt_path(name, "modules/lpips") 29 | self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) 30 | print("loaded pretrained LPIPS loss from {}".format(ckpt)) 31 | 32 | @classmethod 33 | def from_pretrained(cls, name="vgg_lpips"): 34 | if name != "vgg_lpips": 35 | raise NotImplementedError 36 | model = cls() 37 | ckpt = get_ckpt_path(name) 38 | model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) 39 | return model 40 | 41 | def forward(self, input, target): 42 | in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target)) 43 | outs0, outs1 = self.net(in0_input), self.net(in1_input) 44 | feats0, feats1, diffs = {}, {}, {} 45 | lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] 46 | for kk in range(len(self.chns)): 47 | feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk]) 48 | diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 49 | 50 | res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))] 51 | val = res[0] 52 | for l in range(1, len(self.chns)): 53 | val += res[l] 54 | return val 55 | 56 | 57 | class ScalingLayer(nn.Module): 58 | def __init__(self): 59 | super(ScalingLayer, self).__init__() 60 | self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) 61 | self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None]) 62 | 63 | def forward(self, inp): 64 | return (inp - self.shift) / self.scale 65 | 66 | 67 | class NetLinLayer(nn.Module): 68 | """ A single linear layer which does a 1x1 conv """ 69 | def __init__(self, chn_in, chn_out=1, use_dropout=False): 70 | super(NetLinLayer, self).__init__() 71 | layers = [nn.Dropout(), ] if (use_dropout) else [] 72 | layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ] 73 | self.model = 
nn.Sequential(*layers) 74 | 75 | 76 | class vgg16(torch.nn.Module): 77 | def __init__(self, requires_grad=False, pretrained=True): 78 | super(vgg16, self).__init__() 79 | vgg_pretrained_features = models.vgg16(pretrained=pretrained).features 80 | self.slice1 = torch.nn.Sequential() 81 | self.slice2 = torch.nn.Sequential() 82 | self.slice3 = torch.nn.Sequential() 83 | self.slice4 = torch.nn.Sequential() 84 | self.slice5 = torch.nn.Sequential() 85 | self.N_slices = 5 86 | for x in range(4): 87 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 88 | for x in range(4, 9): 89 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 90 | for x in range(9, 16): 91 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 92 | for x in range(16, 23): 93 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 94 | for x in range(23, 30): 95 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 96 | if not requires_grad: 97 | for param in self.parameters(): 98 | param.requires_grad = False 99 | 100 | def forward(self, X): 101 | h = self.slice1(X) 102 | h_relu1_2 = h 103 | h = self.slice2(h) 104 | h_relu2_2 = h 105 | h = self.slice3(h) 106 | h_relu3_3 = h 107 | h = self.slice4(h) 108 | h_relu4_3 = h 109 | h = self.slice5(h) 110 | h_relu5_3 = h 111 | vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) 112 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) 113 | return out 114 | 115 | 116 | def normalize_tensor(x,eps=1e-10): 117 | norm_factor = torch.sqrt(torch.sum(x**2,dim=1,keepdim=True)) 118 | return x/(norm_factor+eps) 119 | 120 | 121 | def spatial_average(x, keepdim=True): 122 | return x.mean([2,3],keepdim=keepdim) 123 | 124 | -------------------------------------------------------------------------------- /modules/losses/vqperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from modules.losses.lpips import LPIPS 6 | from modules.discriminator.model import NLayerDiscriminator, weights_init 7 | 8 | 9 | class DummyLoss(nn.Module): 10 | def __init__(self): 11 | super().__init__() 12 | 13 | 14 | def adopt_weight(weight, global_step, threshold=0, value=0.): 15 | if global_step < threshold: 16 | weight = value 17 | return weight 18 | 19 | 20 | def hinge_d_loss(logits_real, logits_fake): 21 | loss_real = torch.mean(F.relu(1. - logits_real)) 22 | loss_fake = torch.mean(F.relu(1. 
+ logits_fake)) 23 | d_loss = 0.5 * (loss_real + loss_fake) 24 | return d_loss 25 | 26 | 27 | def vanilla_d_loss(logits_real, logits_fake): 28 | d_loss = 0.5 * ( 29 | torch.mean(torch.nn.functional.softplus(-logits_real)) + 30 | torch.mean(torch.nn.functional.softplus(logits_fake))) 31 | return d_loss 32 | 33 | 34 | class VQLPIPSWithDiscriminator(nn.Module): 35 | def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0, 36 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, 37 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, 38 | disc_ndf=64, disc_loss="hinge", disc_weight_max=None): 39 | super().__init__() 40 | assert disc_loss in ["hinge", "vanilla"] 41 | self.codebook_weight = codebook_weight 42 | self.pixel_weight = pixelloss_weight 43 | self.perceptual_loss = LPIPS().eval() 44 | self.perceptual_weight = perceptual_weight 45 | 46 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, 47 | n_layers=disc_num_layers, 48 | use_actnorm=use_actnorm, 49 | ndf=disc_ndf 50 | ).apply(weights_init) 51 | self.discriminator_iter_start = disc_start 52 | if disc_loss == "hinge": 53 | self.disc_loss = hinge_d_loss 54 | elif disc_loss == "vanilla": 55 | self.disc_loss = vanilla_d_loss 56 | else: 57 | raise ValueError(f"Unknown GAN loss '{disc_loss}'.") 58 | print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.") 59 | self.disc_factor = disc_factor 60 | self.discriminator_weight = disc_weight 61 | self.disc_conditional = disc_conditional 62 | 63 | self.disc_weight_max = disc_weight_max 64 | # recommend: 1000 for [churches, bedrooms], 1 for [ffhq] 65 | # paper: Unleashing Transformers: Parallel Token Prediction with Discrete Absorbing Diffusion for Fast High-Resolution Image Generation from Vector-Quantized Codes 66 | 67 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 68 | if last_layer is not None: 69 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 70 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 71 | else: 72 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] 73 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] 74 | 75 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 76 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 77 | d_weight = d_weight * self.discriminator_weight 78 | return d_weight 79 | 80 | def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx, global_step, last_layer=None, cond=None, split="train"): 81 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 82 | if self.perceptual_weight > 0: 83 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) 84 | rec_loss = rec_loss + self.perceptual_weight * p_loss 85 | else: 86 | p_loss = torch.tensor([0.0]) 87 | 88 | nll_loss = rec_loss 89 | #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 90 | nll_loss = torch.mean(nll_loss) 91 | 92 | # now the GAN part 93 | if optimizer_idx == 0: 94 | # generator update 95 | if cond is None: 96 | assert not self.disc_conditional 97 | logits_fake = self.discriminator(reconstructions.contiguous()) 98 | else: 99 | assert self.disc_conditional 100 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) 101 | g_loss = -torch.mean(logits_fake) 102 | 103 | try: 104 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) 105 | except 
RuntimeError: 106 | assert not self.training 107 | d_weight = torch.tensor(0.0) 108 | 109 | # limit the maximum value of disc_weight 110 | if self.disc_weight_max is not None: 111 | d_weight.clamp_max_(self.disc_weight_max) 112 | 113 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 114 | loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean() 115 | 116 | log = {"{}_total_loss".format(split): loss.clone().detach().mean(), 117 | "{}_quant_loss".format(split): codebook_loss.detach().mean(), 118 | "{}_nll_loss".format(split): nll_loss.detach().mean(), 119 | "{}_rec_loss".format(split): rec_loss.detach().mean(), 120 | "{}_p_loss".format(split): p_loss.detach().mean(), 121 | "{}_d_weight".format(split): d_weight.detach(), 122 | "{}_disc_factor".format(split): torch.tensor(disc_factor), 123 | "{}_g_loss".format(split): g_loss.detach().mean(), 124 | } 125 | return loss, log 126 | 127 | if optimizer_idx == 1: 128 | # second pass for discriminator update 129 | if cond is None: 130 | logits_real = self.discriminator(inputs.contiguous().detach()) 131 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 132 | else: 133 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) 134 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) 135 | 136 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 137 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 138 | 139 | log = {"{}_disc_loss".format(split): d_loss.clone().detach().mean(), 140 | "{}_logits_real".format(split): logits_real.detach().mean(), 141 | "{}_logits_fake".format(split): logits_fake.detach().mean() 142 | } 143 | return d_loss, log 144 | -------------------------------------------------------------------------------- /modules/losses/vqperceptual_epoch.py: -------------------------------------------------------------------------------- 1 | # Discarded: moved to vqperceptual_multidisc 2 | # since we start adversarial training just at the beginning, no need to check the epoch 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from modules.losses.lpips import LPIPS 9 | from modules.discriminator.model import NLayerDiscriminator, weights_init 10 | 11 | 12 | class DummyLoss(nn.Module): 13 | def __init__(self): 14 | super().__init__() 15 | 16 | 17 | def adopt_weight(weight, global_step, threshold=0, value=0.): 18 | if global_step < threshold: 19 | weight = value 20 | return weight 21 | 22 | 23 | def hinge_d_loss(logits_real, logits_fake): 24 | loss_real = torch.mean(F.relu(1. - logits_real)) 25 | loss_fake = torch.mean(F.relu(1. + logits_fake)) 26 | d_loss = 0.5 * (loss_real + loss_fake) 27 | return d_loss 28 | 29 | 30 | def vanilla_d_loss(logits_real, logits_fake): 31 | d_loss = 0.5 * ( 32 | torch.mean(torch.nn.functional.softplus(-logits_real)) + 33 | torch.mean(torch.nn.functional.softplus(logits_fake))) 34 | return d_loss 35 | 36 | 37 | # we assume the disc_start denotes the start epoch, i.e. [0,1,2,3,4,...] 
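# Illustrative note (added; not in the original file): in this epoch-gated variant the
# existing `adopt_weight` helper is simply fed the current epoch instead of the global step,
# so the adversarial term stays switched off until `current_epoch` reaches `disc_start_epoch`.
# For example, with disc_start_epoch = 2:
#   adopt_weight(1.0, 0, threshold=2)  ->  0.0   (epochs 0-1: discriminator loss disabled)
#   adopt_weight(1.0, 2, threshold=2)  ->  1.0   (from epoch 2 on: discriminator loss active)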
38 | class VQLPIPSWithDiscriminator(nn.Module): 39 | def __init__(self, disc_start_epoch, codebook_weight=1.0, pixelloss_weight=1.0, 40 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, 41 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, 42 | disc_ndf=64, disc_loss="hinge", disc_weight_max=None): 43 | super().__init__() 44 | assert disc_loss in ["hinge", "vanilla"] 45 | self.codebook_weight = codebook_weight 46 | self.pixel_weight = pixelloss_weight 47 | self.perceptual_loss = LPIPS().eval() 48 | self.perceptual_weight = perceptual_weight 49 | 50 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, 51 | n_layers=disc_num_layers, 52 | use_actnorm=use_actnorm, 53 | ndf=disc_ndf 54 | ).apply(weights_init) 55 | self.discriminator_iter_start_epoch = disc_start_epoch 56 | if disc_loss == "hinge": 57 | self.disc_loss = hinge_d_loss 58 | elif disc_loss == "vanilla": 59 | self.disc_loss = vanilla_d_loss 60 | else: 61 | raise ValueError(f"Unknown GAN loss '{disc_loss}'.") 62 | print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.") 63 | self.disc_factor = disc_factor 64 | self.discriminator_weight = disc_weight 65 | self.disc_conditional = disc_conditional 66 | 67 | self.disc_weight_max = disc_weight_max 68 | # recommend: 1000 for [churches, bedrooms], 1 for [ffhq] 69 | # paper: Unleashing Transformers: Parallel Token Prediction with Discrete Absorbing Diffusion for Fast High-Resolution Image Generation from Vector-Quantized Codes 70 | 71 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 72 | if last_layer is not None: 73 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 74 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 75 | else: 76 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] 77 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] 78 | 79 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 80 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 81 | d_weight = d_weight * self.discriminator_weight 82 | return d_weight 83 | 84 | def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx, current_epoch, last_layer=None, cond=None, split="train"): 85 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 86 | if self.perceptual_weight > 0: 87 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) 88 | rec_loss = rec_loss + self.perceptual_weight * p_loss 89 | else: 90 | p_loss = torch.tensor([0.0]) 91 | 92 | nll_loss = rec_loss 93 | #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 94 | nll_loss = torch.mean(nll_loss) 95 | 96 | # now the GAN part 97 | if optimizer_idx == 0: 98 | # generator update 99 | if cond is None: 100 | assert not self.disc_conditional 101 | logits_fake = self.discriminator(reconstructions.contiguous()) 102 | else: 103 | assert self.disc_conditional 104 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) 105 | g_loss = -torch.mean(logits_fake) 106 | 107 | try: 108 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) 109 | except RuntimeError: 110 | assert not self.training 111 | d_weight = torch.tensor(0.0) 112 | 113 | # limit the maximum value of disc_weight 114 | if self.disc_weight_max is not None: 115 | d_weight.clamp_max_(self.disc_weight_max) 116 | 117 | disc_factor = adopt_weight( 118 | self.disc_factor, current_epoch, 
threshold=self.discriminator_iter_start_epoch 119 | ) 120 | loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean() 121 | 122 | log = {"{}_total_loss".format(split): loss.clone().detach().mean(), 123 | "{}_quant_loss".format(split): codebook_loss.detach().mean(), 124 | "{}_nll_loss".format(split): nll_loss.detach().mean(), 125 | "{}_rec_loss".format(split): rec_loss.detach().mean(), 126 | "{}_p_loss".format(split): p_loss.detach().mean(), 127 | "{}_d_weight".format(split): d_weight.detach(), 128 | "{}_disc_factor".format(split): torch.tensor(disc_factor), 129 | "{}_g_loss".format(split): g_loss.detach().mean(), 130 | } 131 | return loss, log 132 | 133 | if optimizer_idx == 1: 134 | # second pass for discriminator update 135 | if cond is None: 136 | logits_real = self.discriminator(inputs.contiguous().detach()) 137 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 138 | else: 139 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) 140 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) 141 | 142 | disc_factor = adopt_weight( 143 | self.disc_factor, current_epoch, threshold=self.discriminator_iter_start_epoch 144 | ) 145 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 146 | 147 | log = {"{}_disc_loss".format(split): d_loss.clone().detach().mean(), 148 | "{}_logits_real".format(split): logits_real.detach().mean(), 149 | "{}_logits_fake".format(split): logits_fake.detach().mean() 150 | } 151 | return d_loss, log 152 | -------------------------------------------------------------------------------- /modules/lpips/vgg.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CrossmodalGroup/DynamicVectorQuantization/be6dc36c2cc0f238acff76d19ae2e26d0b98788d/modules/lpips/vgg.pth -------------------------------------------------------------------------------- /modules/scheduler/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class LinearWarmUpScheduler: 5 | def __init__(self, ): 6 | pass 7 | 8 | 9 | 10 | class LambdaWarmUpCosineScheduler: 11 | """ 12 | note: use with a base_lr of 1.0 13 | """ 14 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): 15 | self.lr_warm_up_steps = warm_up_steps 16 | self.lr_start = lr_start 17 | self.lr_min = lr_min 18 | self.lr_max = lr_max 19 | self.lr_max_decay_steps = max_decay_steps 20 | self.last_lr = 0. 
21 | self.verbosity_interval = verbosity_interval 22 | 23 | def schedule(self, n): 24 | if self.verbosity_interval > 0: 25 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") 26 | if n < self.lr_warm_up_steps: 27 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start 28 | self.last_lr = lr 29 | return lr 30 | else: 31 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) 32 | t = min(t, 1.0) 33 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( 34 | 1 + np.cos(t * np.pi)) 35 | self.last_lr = lr 36 | return lr 37 | 38 | def __call__(self, n): 39 | return self.schedule(n) 40 | 41 | -------------------------------------------------------------------------------- /modules/scheduler/scheduler.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | from torch.optim.lr_scheduler import CosineAnnealingLR 4 | 5 | 6 | def create_scheduler(optimizer, scheduler_config, steps_per_epoch, max_epoch): 7 | 8 | multiplier = scheduler_config.multiplier 9 | warmup_steps = scheduler_config.warmip_epoch * steps_per_epoch 10 | buffer_steps = scheduler_config.buffer_epoch * steps_per_epoch 11 | final_steps = max_epoch * steps_per_epoch 12 | min_lr = scheduler_config.min_lr 13 | warmup_policy = scheduler_config.warmup_policy 14 | start_from_zero = scheduler_config.start_from_zero 15 | 16 | scheduler = CosineAnnealingLR( 17 | optimizer, T_max=final_steps - warmup_steps - buffer_steps, eta_min=min_lr 18 | ) 19 | 20 | if warmup_steps > 0.0: 21 | if warmup_policy == 'linear': 22 | multiplier = max(1.0, multiplier) 23 | elif warmup_policy == 'sqrt': 24 | multiplier = max(1.0, multiplier) 25 | elif warmup_policy == 'fix': 26 | multiplier = max(1.0, multiplier) 27 | elif warmup_policy == 'none': 28 | pass 29 | else: 30 | raise NotImplementedError(f'{warmup_policy} is not a valid warmup policy') 31 | warmup = GradualWarmup( 32 | optimizer, 33 | steps=warmup_steps, 34 | buffer_steps=buffer_steps, 35 | multiplier=multiplier, 36 | start_from_zero=start_from_zero 37 | ) 38 | else: 39 | warmup = None 40 | 41 | scheduler = Scheduler(warmup_scheduler=warmup, after_scheduler=scheduler) 42 | 43 | return scheduler 44 | 45 | 46 | class GradualWarmup(torch.optim.lr_scheduler._LRScheduler): 47 | def __init__(self, optimizer, steps, buffer_steps, multiplier, start_from_zero=True, last_epoch=-1): 48 | self.steps = steps 49 | self.t_steps = steps + buffer_steps 50 | self.multiplier = multiplier 51 | self.start_from_zero = start_from_zero 52 | 53 | super(GradualWarmup, self).__init__(optimizer, last_epoch) 54 | 55 | def get_lr(self): 56 | if self.last_epoch > self.steps: 57 | return [group['lr'] for group in self.optimizer.param_groups] 58 | 59 | if self.start_from_zero: 60 | multiplier = self.multiplier * min(1.0, (self.last_epoch / self.steps)) 61 | else: 62 | multiplier = 1 + ((self.multiplier - 1) * min(1.0, (self.last_epoch / self.steps))) 63 | return [lr * multiplier for lr in self.base_lrs] 64 | 65 | 66 | class Scheduler: 67 | def __init__(self, warmup_scheduler, after_scheduler): 68 | self.warmup_scheduler = warmup_scheduler 69 | self.after_scheduler = after_scheduler 70 | 71 | def step(self, epoch=None): 72 | if self.warmup_scheduler is not None: 73 | self.warmup_scheduler.step(epoch=epoch) 74 | 75 | if self.warmup_scheduler is None or \ 76 | self.warmup_scheduler.last_epoch > self.warmup_scheduler.t_steps: 77 | 
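            # The warmup scheduler owns the learning rate for the first
            # `steps + buffer_steps` optimizer steps; only after its counter passes
            # t_steps (or when there is no warmup at all) is the cosine
            # after_scheduler stepped as well.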
self.after_scheduler.step(epoch=epoch) 78 | 79 | def get_last_lr(self): 80 | if self.warmup_scheduler is not None and \ 81 | self.warmup_scheduler.last_epoch <= self.warmup_scheduler.t_steps: 82 | print("self.warmup_scheduler.get_last_lr(): ", self.warmup_scheduler.get_last_lr()) 83 | return self.warmup_scheduler.get_last_lr() 84 | else: 85 | return self.after_scheduler.get_last_lr() 86 | 87 | def state_dict(self): 88 | return { 89 | 'warmup': None if self.warmup_scheduler is None else self.warmup_scheduler.state_dict(), 90 | 'after': self.after_scheduler.state_dict() 91 | } 92 | 93 | def load_state_dict(self, state_dict): 94 | if self.warmup_scheduler is not None: 95 | self.warmup_scheduler.load_state_dict(state_dict['warmup']) 96 | self.after_scheduler.load_state_dict(state_dict['after']) -------------------------------------------------------------------------------- /modules/text_encoders/clip_text_encoder/base_embedding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class BaseEmbedding(nn.Module): 6 | 7 | def get_loss(self): 8 | return None 9 | 10 | def forward(self, **kwargs): 11 | raise NotImplementedError 12 | 13 | def train(self, mode=True): 14 | self.training = mode 15 | if self.trainable and mode: 16 | super().train() 17 | return self 18 | 19 | def _set_trainable(self): 20 | if not self.trainable: 21 | for pn, p in self.named_parameters(): 22 | p.requires_grad = False 23 | self.eval() 24 | 25 | -------------------------------------------------------------------------------- /modules/text_encoders/clip_text_encoder/clip/README.md: -------------------------------------------------------------------------------- 1 | https://github.com/openai/CLIP -------------------------------------------------------------------------------- /modules/text_encoders/clip_text_encoder/clip/clip_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | """ 18 | Returns list of utf-8 byte and a corresponding list of unicode strings. 19 | The reversible bpe codes work on unicode strings. 20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 22 | This is a signficant percentage of your normal, say, 32K bpe vocab. 23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 24 | And avoids mapping to whitespace/control characters the bpe code barfs on. 25 | """ 26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 27 | cs = bs[:] 28 | n = 0 29 | for b in range(2**8): 30 | if b not in bs: 31 | bs.append(b) 32 | cs.append(2**8+n) 33 | n += 1 34 | cs = [chr(n) for n in cs] 35 | return dict(zip(bs, cs)) 36 | 37 | 38 | def get_pairs(word): 39 | """Return set of symbol pairs in a word. 40 | Word is represented as tuple of symbols (symbols being variable-length strings). 
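    For example (illustrative): get_pairs(('h', 'e', 'l', 'l', 'o'))
    returns {('h', 'e'), ('e', 'l'), ('l', 'l'), ('l', 'o')}.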
41 | """ 42 | pairs = set() 43 | prev_char = word[0] 44 | for char in word[1:]: 45 | pairs.add((prev_char, char)) 46 | prev_char = char 47 | return pairs 48 | 49 | 50 | def basic_clean(text): 51 | text = ftfy.fix_text(text) 52 | text = html.unescape(html.unescape(text)) 53 | return text.strip() 54 | 55 | 56 | def whitespace_clean(text): 57 | text = re.sub(r'\s+', ' ', text) 58 | text = text.strip() 59 | return text 60 | 61 | 62 | class SimpleTokenizer(object): 63 | def __init__(self, end_idx=49152, bpe_path: str = default_bpe()): 64 | self.byte_encoder = bytes_to_unicode() 65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 67 | # merges = merges[1:49152-256-2+1] 68 | 69 | # end_idx can be 49152 for CLIP 70 | # or 16384 for DALL-E 71 | merges = merges[1:end_idx-256-2+1] 72 | merges = [tuple(merge.split()) for merge in merges] 73 | vocab = list(bytes_to_unicode().values()) 74 | vocab = vocab + [v+'' for v in vocab] # with length 256 75 | for merge in merges: 76 | vocab.append(''.join(merge)) 77 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) # with length = end_idx+256 78 | self.encoder = dict(zip(vocab, range(len(vocab)))) 79 | self.decoder = {v: k for k, v in self.encoder.items()} 80 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 81 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 82 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 83 | 84 | def bpe(self, token): 85 | if token in self.cache: 86 | return self.cache[token] 87 | word = tuple(token[:-1]) + (token[-1] + '',) 88 | pairs = get_pairs(word) 89 | 90 | if not pairs: 91 | return token + '' 92 | 93 | while True: 94 | bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf'))) 95 | if bigram not in self.bpe_ranks: 96 | break 97 | first, second = bigram 98 | new_word = [] 99 | i = 0 100 | while i < len(word): 101 | try: 102 | j = word.index(first, i) 103 | new_word.extend(word[i:j]) 104 | i = j 105 | except: 106 | new_word.extend(word[i:]) 107 | break 108 | 109 | if word[i] == first and i < len(word) - 1 and word[i+1] == second: 110 | new_word.append(first+second) 111 | i += 2 112 | else: 113 | new_word.append(word[i]) 114 | i += 1 115 | new_word = tuple(new_word) 116 | word = new_word 117 | if len(word) == 1: 118 | break 119 | else: 120 | pairs = get_pairs(word) 121 | word = ' '.join(word) 122 | self.cache[token] = word 123 | return word 124 | 125 | def encode(self, text): 126 | bpe_tokens = [] 127 | text = whitespace_clean(basic_clean(text)).lower() 128 | for token in re.findall(self.pat, text): 129 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 130 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 131 | return bpe_tokens 132 | 133 | def decode(self, tokens): 134 | text = ''.join([self.decoder[token] for token in tokens]) 135 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 136 | return text 137 | -------------------------------------------------------------------------------- /modules/text_encoders/clip_text_encoder/clip/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 
11 | def default_bpe(): 12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | """ 18 | Returns list of utf-8 byte and a corresponding list of unicode strings. 19 | The reversible bpe codes work on unicode strings. 20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 22 | This is a signficant percentage of your normal, say, 32K bpe vocab. 23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 24 | And avoids mapping to whitespace/control characters the bpe code barfs on. 25 | """ 26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 27 | cs = bs[:] 28 | n = 0 29 | for b in range(2**8): 30 | if b not in bs: 31 | bs.append(b) 32 | cs.append(2**8+n) 33 | n += 1 34 | cs = [chr(n) for n in cs] 35 | return dict(zip(bs, cs)) 36 | 37 | 38 | def get_pairs(word): 39 | """Return set of symbol pairs in a word. 40 | Word is represented as tuple of symbols (symbols being variable-length strings). 41 | """ 42 | pairs = set() 43 | prev_char = word[0] 44 | for char in word[1:]: 45 | pairs.add((prev_char, char)) 46 | prev_char = char 47 | return pairs 48 | 49 | 50 | def basic_clean(text): 51 | text = ftfy.fix_text(text) 52 | text = html.unescape(html.unescape(text)) 53 | return text.strip() 54 | 55 | 56 | def whitespace_clean(text): 57 | text = re.sub(r'\s+', ' ', text) 58 | text = text.strip() 59 | return text 60 | 61 | 62 | class SimpleTokenizer(object): 63 | def __init__(self, end_idx=49152, bpe_path: str = default_bpe()): 64 | self.byte_encoder = bytes_to_unicode() 65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 67 | # merges = merges[1:49152-256-2+1] 68 | 69 | # end_idx can be 49152 for CLIP 70 | # or 16384 for DALL-E 71 | merges = merges[1:end_idx-256-2+1] 72 | merges = [tuple(merge.split()) for merge in merges] 73 | vocab = list(bytes_to_unicode().values()) 74 | vocab = vocab + [v+'' for v in vocab] # with length 256 75 | for merge in merges: 76 | vocab.append(''.join(merge)) 77 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) # with length = end_idx+256 78 | self.encoder = dict(zip(vocab, range(len(vocab)))) 79 | self.decoder = {v: k for k, v in self.encoder.items()} 80 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 81 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 82 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 83 | 84 | def bpe(self, token): 85 | if token in self.cache: 86 | return self.cache[token] 87 | word = tuple(token[:-1]) + (token[-1] + '',) 88 | pairs = get_pairs(word) 89 | 90 | if not pairs: 91 | return token + '' 92 | 93 | while True: 94 | bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf'))) 95 | if bigram not in self.bpe_ranks: 96 | break 97 | first, second = bigram 98 | new_word = [] 99 | i = 0 100 | while i < len(word): 101 | try: 102 | j = word.index(first, i) 103 | new_word.extend(word[i:j]) 104 | i = j 105 | except: 106 | new_word.extend(word[i:]) 107 | break 108 | 109 | if word[i] == first and i < len(word) - 1 and word[i+1] == second: 110 | new_word.append(first+second) 111 | i += 2 
112 | else: 113 | new_word.append(word[i]) 114 | i += 1 115 | new_word = tuple(new_word) 116 | word = new_word 117 | if len(word) == 1: 118 | break 119 | else: 120 | pairs = get_pairs(word) 121 | word = ' '.join(word) 122 | self.cache[token] = word 123 | return word 124 | 125 | def encode(self, text): 126 | bpe_tokens = [] 127 | text = whitespace_clean(basic_clean(text)).lower() 128 | for token in re.findall(self.pat, text): 129 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 130 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 131 | return bpe_tokens 132 | 133 | def decode(self, tokens): 134 | text = ''.join([self.decoder[token] for token in tokens]) 135 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 136 | return text 137 | -------------------------------------------------------------------------------- /modules/text_encoders/clip_text_encoder/clip_text_embedding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import os, sys 4 | sys.path.append(os.getcwd()) 5 | from modules.clip_text_encoder.clip import clip 6 | from modules.clip_text_encoder.clip import model as clip_model 7 | from modules.clip_text_encoder.base_embedding import BaseEmbedding 8 | 9 | class CLIPTextEmbedding(BaseEmbedding): 10 | def __init__(self, 11 | clip_name='ViT-B/32', 12 | num_embed=49408, 13 | normalize=True, 14 | pick_last_embedding=True, 15 | keep_seq_len_dim=False, 16 | additional_last_embedding=False, 17 | embed_dim=512, 18 | ): 19 | super().__init__() 20 | self.num_embed = num_embed 21 | self.clip_name = clip_name 22 | self.normalize = normalize 23 | self.pick_last_embedding = pick_last_embedding 24 | self.keep_seq_len_dim = keep_seq_len_dim 25 | self.additional_last_embedding = additional_last_embedding 26 | 27 | model, _ = clip.load(clip_name, device='cpu',jit=False) 28 | model = clip_model.build_model(model.state_dict()) 29 | 30 | self.token_embedding = model.token_embedding 31 | self.positional_embedding = model.positional_embedding 32 | self.transformer = model.transformer 33 | self.ln_final = model.ln_final 34 | self.text_projection = model.text_projection 35 | 36 | if embed_dim == 1024: 37 | self.embed_dim = self.text_projection.shape[1]*2 # to fit 1024 dimension of image embedding 38 | else: 39 | self.embed_dim = self.text_projection.shape[1] # original output, 512 dim 40 | 41 | self.trainable = False 42 | self._set_trainable() 43 | 44 | @property 45 | def dtype(self): 46 | return self.transformer.resblocks[0].attn.in_proj_weight.dtype 47 | 48 | def encode_text(self, text): 49 | text[text < 0] = 0 # some padded text token maybe negative, so set them to 0 50 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] 51 | 52 | x = x + self.positional_embedding.type(self.dtype) 53 | x = x.permute(1, 0, 2) # NLD -> LND 54 | x = self.transformer(x) 55 | x = x.permute(1, 0, 2) # LND -> NLD 56 | x = self.ln_final(x).type(self.dtype) 57 | 58 | # x.shape = [batch_size, n_ctx, transformer.width] 59 | if self.pick_last_embedding: 60 | # take features from the eot embedding (eot_token is the highest number in each sequence) 61 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection # [batch_size, transformer.width] 62 | if self.keep_seq_len_dim: 63 | x = x.unsqueeze(dim=1) # [batch_size, 1, transformer.width] 64 | return x 65 | 66 | 67 | 68 | def forward(self, index, 
**kwargs): 69 | """ 70 | index: B x L, index 71 | mask: B x L, bool type. The value of False indicating padded index 72 | """ 73 | assert index.dim() == 2 # B x L 74 | text_feature = self.encode_text(index) 75 | 76 | if self.embed_dim == 1024: 77 | text_features = torch.cat((text_feature, text_feature), dim=2) 78 | else: 79 | text_features = text_feature 80 | if self.normalize: 81 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 82 | 83 | if self.additional_last_embedding == True: 84 | last_feature = text_feature[torch.arange(text_feature.shape[0]), index.argmax(dim=-1)] @ self.text_projection 85 | if self.keep_seq_len_dim: 86 | last_feature = last_feature.unsqueeze(dim=1) 87 | return text_features, last_feature 88 | 89 | 90 | return text_features 91 | 92 | 93 | if __name__ == "__main__": 94 | model = CLIPTextEmbedding() -------------------------------------------------------------------------------- /modules/text_encoders/clip_text_encoder/my_tokenizer/base_codec.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class BaseCodec(nn.Module): 6 | 7 | def get_tokens(self, x, **kwargs): 8 | """ 9 | Input: 10 | x: input data 11 | Return: 12 | indices: B x L, the codebook indices, where L is the length 13 | of flattened feature map size 14 | """ 15 | raise NotImplementedError 16 | 17 | def get_number_of_tokens(self): 18 | """ 19 | Return: int, the number of tokens 20 | """ 21 | raise NotImplementedError 22 | 23 | def encode(self, img): 24 | raise NotImplementedError 25 | 26 | def decode(self, img_seq): 27 | raise NotImplementedError 28 | 29 | def forward(self, **kwargs): 30 | raise NotImplementedError 31 | 32 | def train(self, mode=True): 33 | self.training = mode 34 | if self.trainable and mode: 35 | return super().train(True) 36 | else: 37 | return super().train(False) 38 | 39 | def _set_trainable(self): 40 | if not self.trainable: 41 | for pn, p in self.named_parameters(): 42 | p.requires_grad = False 43 | self.eval() -------------------------------------------------------------------------------- /modules/text_encoders/clip_text_encoder/my_tokenizer/my_tokenize.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os, sys 3 | sys.path.append(os.getcwd()) 4 | from modules.clip_text_encoder.clip.clip import tokenize 5 | from modules.clip_text_encoder.my_tokenizer.base_codec import BaseCodec 6 | from utils.utils import instantiate_from_config 7 | 8 | class Tokenize(BaseCodec): 9 | def __init__(self, 10 | context_length:int = 77, 11 | add_start_and_end:bool = True, 12 | just_token = False, 13 | with_mask:bool = True, 14 | pad_value:int = 0, 15 | clip_embedding = False, 16 | condition_emb_config = None, 17 | tokenizer_config={ 18 | 'target': 'modules.clip_text_encoder.clip.simple_tokenizer.SimpleTokenizer', 19 | 'params':{ 20 | 'end_idx': 49152 # 16384 fo DALL-E 21 | }, 22 | }, 23 | ): 24 | """ 25 | This is a wrapper class for tokenize of texts. 
26 | For CLIP and DALLE-pytorch tokenize, the default 27 | arguments are different: 28 | 29 | CLIP based: 30 | context_length: 77 31 | add_start_and_end: True 32 | 33 | DALLE-pytorch based: 34 | context_length: 256 35 | add_start_and_end: False 36 | 37 | """ 38 | super().__init__() 39 | self.context_length = context_length 40 | self.add_start_and_end = add_start_and_end 41 | self.with_mask = with_mask 42 | self.pad_value = pad_value 43 | self.just_token = just_token 44 | self.trainable = False 45 | self.condition_emb = None 46 | self.clip_embedding = clip_embedding 47 | if self.clip_embedding == True: 48 | assert condition_emb_config != None 49 | self.condition_emb = instantiate_from_config(condition_emb_config) 50 | 51 | self.tokenizer = instantiate_from_config(tokenizer_config) 52 | 53 | def __repr__(self): 54 | rep = "Tokenize for text\n\tcontent_length: {}\n\tadd_start_and_end: {}\n\twith_mask: {}"\ 55 | .format(self.context_length, self.add_start_and_end, self.with_mask) 56 | return rep 57 | 58 | def check_length(self, token): 59 | return len(token) <= self.context_length 60 | 61 | def get_tokens(self, text, **kwargs): 62 | text_token = tokenize(text, context_length=self.context_length, 63 | add_start_and_end=self.add_start_and_end, 64 | with_mask=self.with_mask, pad_value=self.pad_value, 65 | tokenizer=self.tokenizer, 66 | just_token=self.just_token) 67 | if self.clip_embedding == False: 68 | return text_token 69 | else: 70 | if self.condition_emb.additional_last_embedding == True: 71 | with torch.no_grad(): 72 | cond_emb, last_embedding = self.condition_emb(text_token['token'].cuda()) 73 | text_token['embed_token'] = cond_emb.detach() 74 | text_token['last_embed'] = last_embedding 75 | else: 76 | with torch.no_grad(): 77 | cond_emb = self.condition_emb(text_token['token'].cuda()) 78 | text_token['embed_token'] = cond_emb.detach() 79 | 80 | return text_token 81 | 82 | 83 | 84 | 85 | 86 | if __name__ == "__main__": 87 | text = "this is a test !" 
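    # With the defaults above (clip_embedding=False), get_tokens simply returns what
    # clip.tokenize produces — presumably the padded 'token' ids plus a 'mask' when
    # with_mask=True; setting clip_embedding=True would additionally attach a detached
    # 'embed_token' (and optionally 'last_embed') computed by the CLIP text encoder,
    # as handled inside get_tokens above.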
88 | tokenizer = Tokenize() 89 | 90 | ids = tokenizer.get_tokens(text) 91 | print(text) 92 | print(ids) -------------------------------------------------------------------------------- /modules/transformer/hybrid_decoders.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from einops import rearrange 4 | from timm.models.layers import trunc_normal_ 5 | 6 | from modules.transformer.modules import Block 7 | from modules.transformer.position_embeddings import build_position_embed 8 | from utils.utils import instantiate_from_config 9 | 10 | class VisionTransformerDecoder(nn.Module): 11 | def __init__(self, image_size, patch_size, pos_embed_type, embed_dim, 12 | depth, num_heads, attn_drop_rate=0., drop_rate=0.,init_type="default", 13 | mlp_ratio=4, act_type="GELU", norm_type="layer", attn_type="sa", init_values=0): 14 | super().__init__() 15 | 16 | self.hw = image_size // patch_size 17 | self.pos_emb = build_position_embed(embed_type=pos_embed_type, feats_dim=embed_dim, dropout=drop_rate, n_row=self.hw) 18 | self.blocks = nn.ModuleList([ 19 | Block( 20 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, drop=drop_rate, 21 | attn_drop=attn_drop_rate, init_values=init_values, act_type=act_type, 22 | norm_type=norm_type, attn_type=attn_type, size=int(self.hw)) 23 | for i in range(depth)]) 24 | 25 | self.patch_size = patch_size 26 | 27 | if init_type == "default": 28 | self.apply(self._init_weights) 29 | 30 | def _init_weights(self, m): 31 | if isinstance(m, nn.Linear): 32 | trunc_normal_(m.weight, std=.02) 33 | if isinstance(m, nn.Linear) and m.bias is not None: 34 | nn.init.constant_(m.bias, 0) 35 | elif isinstance(m, nn.LayerNorm): 36 | nn.init.constant_(m.bias, 0) 37 | nn.init.constant_(m.weight, 1.0) 38 | 39 | def forward(self, x): 40 | B = x.size(0) 41 | x = self.pos_emb(x) 42 | 43 | for blk in self.blocks: 44 | x = blk(x) 45 | 46 | x = rearrange(x, "B (H W) C -> B C H W", H=self.hw, W=self.hw).contiguous() 47 | return x 48 | 49 | 50 | class HybrdDecoder(nn.Module): 51 | def __init__(self, transformer_config, cnn_config): 52 | super().__init__() 53 | self.transformer = instantiate_from_config(transformer_config) 54 | self.cnn = instantiate_from_config(cnn_config) 55 | self.conv_out = self.cnn.conv_out 56 | 57 | def forward(self, x): 58 | x = self.transformer(x) 59 | x = self.cnn(x) 60 | return x 61 | 62 | # this variant additionally feeds a mask into the transformer module to enable customized behavior 63 | class HybrdDecoder_V2(nn.Module): 64 | def __init__(self, transformer_config, cnn_config): 65 | super().__init__() 66 | self.transformer = instantiate_from_config(transformer_config) 67 | self.cnn = instantiate_from_config(cnn_config) 68 | self.conv_out = self.cnn.conv_out 69 | 70 | def forward(self, x, mask): 71 | x = self.transformer(x, mask) 72 | x = self.cnn(x) 73 | return x 74 | 75 | if __name__ == "__main__": 76 | x = torch.randn(10, 512, 16, 16) 77 | vit_decoder = VisionTransformerDecoder( 78 | image_size=256, patch_size=16, pos_embed_type="learned-2d", embed_dim=512, drop_rate=0., 79 | depth=6, num_heads=4, attn_drop_rate=0., init_type="default" 80 | ) 81 | y = vit_decoder(x) 82 | print(y.size()) -------------------------------------------------------------------------------- /modules/transformer/mask_attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | # we want the information flow in self-attention to be biased from unmasked tokens -> masked tokens 5 | class
MaskSelfAttention_SquareGrowth(nn.Module): 6 | def __init__(self, dim, num_heads, attn_drop=0., proj_drop=0.): 7 | super().__init__() 8 | self.num_heads = num_heads 9 | head_dim = dim // num_heads 10 | self.scale = head_dim ** -0.5 11 | self.qkv = nn.Linear(dim, dim * 3, bias=False) 12 | self.proj = nn.Linear(dim, dim) 13 | 14 | self.attn_drop = nn.Dropout(attn_drop) 15 | self.proj_drop = nn.Dropout(proj_drop) 16 | 17 | def forward(self, h, mask=None): 18 | # mask (_type_, optional): [batch_size, length] 19 | B, N, C = h.shape 20 | qkv = self.qkv(h).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 21 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 22 | attn = (q @ k.transpose(-2, -1)).contiguous() * self.scale 23 | attn = attn.softmax(dim=-1) 24 | 25 | if mask is not None: 26 | unsqueezed_mask = mask.unsqueeze(-2).unsqueeze(-2) 27 | attn = attn * unsqueezed_mask 28 | 29 | # update mask with SquareGrowth 30 | new_mask = torch.sqrt(mask) 31 | 32 | attn = self.attn_drop(attn) 33 | h = (attn @ v).transpose(1, 2).contiguous().reshape(B, N, C) 34 | h = self.proj(h) 35 | h = self.proj_drop(h) 36 | return h, new_mask 37 | 38 | 39 | if __name__ == "__main__": 40 | batch_size = 1 41 | dim = 256 42 | height = 2 43 | model = MaskSelfAttention_SquareGrowth(dim=256, num_heads=4) 44 | 45 | h = torch.randn(batch_size, height*height, 256) # (10,256,16,16) 46 | mask = torch.randint(0, 2, (batch_size, height*height)) 47 | print(mask) 48 | 49 | new_mask = mask + 0.02 * (1 - mask) 50 | print(new_mask) 51 | 52 | 53 | model(h, new_mask) -------------------------------------------------------------------------------- /modules/transformer/mask_attention_decoders.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from einops import rearrange 4 | from timm.models.layers import trunc_normal_ 5 | 6 | import os, sys 7 | sys.path.append(os.getcwd()) 8 | 9 | from modules.transformer.modules import norm, Mlp 10 | from modules.transformer.mask_attention import MaskSelfAttention_SquareGrowth 11 | from modules.transformer.position_embeddings import build_position_embed 12 | 13 | class MaskBlock(nn.Module): 14 | def __init__(self, dim, num_heads, mlp_ratio=4., drop=0., attn_drop=0., init_values=None, 15 | act_type="GELU", norm_type="layer", attn_type="msa-sg", size=None): 16 | super().__init__() 17 | self.norm1 = norm(norm_type, dim) 18 | 19 | self.norm2 = norm(norm_type, dim) 20 | self.mlp = Mlp(dim, hidden_radio=mlp_ratio, act_type=act_type, drop=drop) 21 | 22 | if attn_type == "msa-sg": 23 | self.attn = MaskSelfAttention_SquareGrowth( 24 | dim=dim, num_heads=num_heads, attn_drop=attn_drop, proj_drop=drop 25 | ) 26 | else: 27 | raise ValueError 28 | 29 | if init_values > 0: 30 | self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True) 31 | self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True) 32 | else: 33 | self.gamma_1, self.gamma_2 = None, None 34 | 35 | def forward(self, x, mask, **ignore_kwargs): 36 | if self.gamma_1 is None: 37 | attn, new_mask = self.attn(h=self.norm1(x), mask=mask) 38 | x = x + attn 39 | x = x + self.mlp(self.norm2(x)) 40 | else: 41 | attn, new_mask = self.attn(h=self.norm1(x), mask=mask) 42 | x = x + self.gamma_1 * attn 43 | x = x + self.gamma_2 * self.mlp(self.norm2(x)) 44 | return x, new_mask 45 | 46 | class MaskVisionTransformerDecoder(nn.Module): 47 | def __init__(self, image_size, patch_size, 
pos_embed_type, embed_dim, 48 | depth, num_heads, attn_drop_rate=0., drop_rate=0.,init_type="default", 49 | mlp_ratio=4, act_type="GELU", norm_type="layer", attn_type="sa", init_values=0): 50 | super().__init__() 51 | 52 | self.hw = image_size // patch_size 53 | self.pos_emb = build_position_embed(embed_type=pos_embed_type, feats_dim=embed_dim, dropout=drop_rate, n_row=self.hw) 54 | self.blocks = nn.ModuleList([ 55 | MaskBlock( 56 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, drop=drop_rate, 57 | attn_drop=attn_drop_rate, init_values=init_values, act_type=act_type, 58 | norm_type=norm_type, attn_type=attn_type, size=int(self.hw)) 59 | for i in range(depth)]) 60 | 61 | self.patch_size = patch_size 62 | 63 | if init_type == "default": 64 | self.apply(self._init_weights) 65 | 66 | def _init_weights(self, m): 67 | if isinstance(m, nn.Linear): 68 | trunc_normal_(m.weight, std=.02) 69 | if isinstance(m, nn.Linear) and m.bias is not None: 70 | nn.init.constant_(m.bias, 0) 71 | elif isinstance(m, nn.LayerNorm): 72 | nn.init.constant_(m.bias, 0) 73 | nn.init.constant_(m.weight, 1.0) 74 | 75 | def forward(self, x, mask): 76 | B = x.size(0) 77 | x = self.pos_emb(x) 78 | 79 | new_mask = mask + 0.02 * (1 - mask) # initialize masked positions (0) to a small value, 0.02 80 | for blk in self.blocks: 81 | # print(new_mask) 82 | x, new_mask = blk(x=x, mask=new_mask) 83 | x = rearrange(x, "B (H W) C -> B C H W", H=self.hw, W=self.hw).contiguous() 84 | return x -------------------------------------------------------------------------------- /modules/transformer/position_embeddings.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | import numpy as np 5 | from einops import rearrange 6 | 7 | # input is (B,C,H,W), output is (B,HW,C) 8 | def build_position_embed(embed_type='learned', feats_dim=512, dropout=0., n_row=16): 9 | if embed_type == 'sine-1d': 10 | pos_embed = PositionalEncoding1d(emb_dim=feats_dim, dropout=dropout) 11 | elif embed_type == "sine-2d": 12 | pos_embed = PositionalEncoding2d(emb_dim=feats_dim, dropout=dropout) 13 | elif embed_type == "learned-2d": 14 | pos_embed = PositionEmbeddingLearned(n_row=n_row, feats_dim=feats_dim, dropout=dropout) 15 | else: 16 | raise ValueError(f"not supported: {embed_type}") 17 | return pos_embed 18 | 19 | 20 | ###################################################################################### 21 | # 1D position embedding 22 | ###################################################################################### 23 | class PositionalEncoding1d(nn.Module): 24 | def __init__(self, emb_dim, dropout=0.1, max_len=5000): 25 | super(PositionalEncoding1d, self).__init__() 26 | self.dropout = nn.Dropout(p=dropout) 27 | 28 | pe = torch.zeros(max_len, emb_dim) 29 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) 30 | div_term = torch.exp(torch.arange(0, emb_dim, 2).float() * (-math.log(10000.0) / emb_dim)) 31 | pe[:, 0::2] = torch.sin(position * div_term) 32 | pe[:, 1::2] = torch.cos(position * div_term) 33 | pe = pe.unsqueeze(0).transpose(0, 1) 34 | self.register_buffer('pe', pe) 35 | 36 | def forward(self, x): 37 | x = rearrange(x, "B C H W -> B (H W) C").contiguous() 38 | x = x + self.pe[:x.size(0), :].to(x.device) 39 | return self.dropout(x) 40 | 41 | ###################################################################################### 42 | # 2D position embedding 43 | ###################################################################################### 44 | class
PositionalEncoding2d(nn.Module): 45 | def __init__(self, emb_dim, dropout, max_len=5000): 46 | super(PositionalEncoding2d, self).__init__() 47 | 48 | self.dropout = nn.Dropout(dropout) 49 | 50 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) 51 | div_term = torch.exp(torch.arange(0, (emb_dim//2), 2).float() * (-math.log(10000.0) / (emb_dim//2))) 52 | pe_x = torch.zeros(max_len, emb_dim//2) 53 | pe_x[:, 0::2] = torch.sin(position * div_term) 54 | pe_x[:, 1::2] = torch.cos(position * div_term) 55 | 56 | pe_y = torch.zeros(max_len, emb_dim//2) 57 | pe_y[:, 1::2] = torch.sin(position * div_term) 58 | pe_y[:, 0::2] = torch.cos(position * div_term) 59 | 60 | self.register_buffer('pe_x', pe_x) 61 | self.register_buffer('pe_y', pe_y) 62 | 63 | def forward(self, x): 64 | _, _, h, w = x.shape 65 | add_x = self.pe_x[:h, :].unsqueeze(1).repeat(1,w,1) 66 | add_y = self.pe_y[:w, :].unsqueeze(0).repeat(h,1,1) 67 | add = torch.cat([add_x, add_y], dim=-1) #shape: h x w x dim 68 | add = add.permute(2, 0, 1).unsqueeze(0) 69 | 70 | x = x + add 71 | x = rearrange(x, "B C H W -> B (H W) C").contiguous() 72 | return self.dropout(x) 73 | 74 | class PositionEmbeddingLearned(nn.Module): 75 | """ 76 | This is a learned version of the position embedding 77 | """ 78 | def __init__(self, n_row, feats_dim, dropout, n_col=None): 79 | super().__init__() 80 | self.dropout = nn.Dropout(dropout) 81 | 82 | n_col = n_col if n_col is not None else n_row 83 | self.row_embed = nn.Embedding(n_row, feats_dim) 84 | self.col_embed = nn.Embedding(n_col, feats_dim) 85 | self.reset_parameters() 86 | 87 | def reset_parameters(self): 88 | nn.init.uniform_(self.row_embed.weight) 89 | nn.init.uniform_(self.col_embed.weight) 90 | 91 | def forward(self, x): 92 | h, w = x.shape[-2:] 93 | i = torch.arange(w, device=x.device) 94 | j = torch.arange(h, device=x.device) 95 | x_emb = self.col_embed(i).unsqueeze(0).repeat(h, 1, 1) 96 | y_emb = self.row_embed(j).unsqueeze(1).repeat(1, w, 1) 97 | pos = (x_emb + y_emb).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1) 98 | 99 | x = x + pos 100 | x = rearrange(x, "B C H W -> B (H W) C").contiguous() 101 | return self.dropout(x) -------------------------------------------------------------------------------- /modules/transformer/vit.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from einops import rearrange 4 | from timm.models.layers import trunc_normal_ 5 | 6 | 7 | from modules.transformer.modules import PatchEmbed, Block 8 | from modules.transformer.position_embeddings import build_position_embed 9 | 10 | # vision transformer 11 | class VisionTransformerEncoder(nn.Module): 12 | def __init__(self, image_size, patch_size, input_channel, embed_dim, init_type, 13 | pos_embed_type, attn_drop_rate, drop_rate, depth, num_heads, 14 | mlp_ratio=4, norm_type="layer", act_type="GELU", init_values=0, attn_type="sa"): 15 | super().__init__() 16 | 17 | self.patch_embed = PatchEmbed( 18 | img_size=image_size, patch_size=patch_size, in_chans=input_channel, embed_dim=embed_dim 19 | ) 20 | 21 | self.hw = image_size // patch_size 22 | self.pos_emb = build_position_embed( 23 | embed_type=pos_embed_type, feats_dim=embed_dim, dropout=drop_rate, n_row=self.hw 24 | ) 25 | 26 | self.blocks = nn.ModuleList([ 27 | Block( 28 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, drop=drop_rate, 29 | attn_drop=attn_drop_rate, init_values=init_values, act_type=act_type, 30 | norm_type=norm_type, attn_type=attn_type, 
size=int(self.hw)) 31 | for i in range(depth)]) 32 | 33 | if init_type == "default": 34 | self.apply(self._init_weights) 35 | 36 | def _init_weights(self, m): 37 | if isinstance(m, nn.Linear): 38 | trunc_normal_(m.weight, std=.02) 39 | if isinstance(m, nn.Linear) and m.bias is not None: 40 | nn.init.constant_(m.bias, 0) 41 | elif isinstance(m, nn.LayerNorm): 42 | nn.init.constant_(m.bias, 0) 43 | nn.init.constant_(m.weight, 1.0) 44 | 45 | def forward(self, images): 46 | x = self.patch_embed(images) 47 | x = rearrange(x, "B (H W) C -> B C H W", H=self.hw, W=self.hw).contiguous() 48 | x = self.pos_emb(x) 49 | for blk in self.blocks: 50 | x = blk(x) 51 | x = rearrange(x, "B (H W) C -> B C H W", H=self.hw, W=self.hw).contiguous() 52 | return x 53 | 54 | 55 | # vision transformer decoder 56 | class VisionTransformerDecoder(nn.Module): 57 | def __init__(self, image_size, patch_size, pos_embed_type, embed_dim, drop_rate, 58 | depth, num_heads, attn_drop_rate, output_channel, init_type, 59 | mlp_ratio=4, act_type="GELU", norm_type="layer", attn_type="sa", init_values=0): 60 | super().__init__() 61 | 62 | self.hw = image_size // patch_size 63 | self.pos_emb = build_position_embed(embed_type=pos_embed_type, feats_dim=embed_dim, dropout=drop_rate, n_row=self.hw) 64 | self.blocks = nn.ModuleList([ 65 | Block( 66 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, drop=drop_rate, 67 | attn_drop=attn_drop_rate, init_values=init_values, act_type=act_type, 68 | norm_type=norm_type, attn_type=attn_type, size=int(self.hw)) 69 | for i in range(depth)]) 70 | 71 | # see: VECTOR-QUANTIZED IMAGE MODELING WITH IMPROVED VQGAN 72 | # self.output_linear = nn.Linear(embed_dim, patch_size*patch_size*output_channel) 73 | # self.output_linear = nn.Sequential( 74 | # nn.Linear(embed_dim, embed_dim), 75 | # nn.Tanh(), 76 | # nn.Linear(embed_dim, patch_size*patch_size*output_channel) 77 | # ) 78 | self.output_linear1 = nn.Linear(embed_dim, patch_size*patch_size*output_channel) 79 | self.conv_out = nn.Linear(patch_size*patch_size*output_channel, patch_size*patch_size*output_channel) 80 | # align with VQGAN 81 | 82 | self.output_channel = output_channel 83 | self.patch_size = patch_size 84 | 85 | if init_type == "default": 86 | self.apply(self._init_weights) 87 | 88 | def _init_weights(self, m): 89 | if isinstance(m, nn.Linear): 90 | trunc_normal_(m.weight, std=.02) 91 | if isinstance(m, nn.Linear) and m.bias is not None: 92 | nn.init.constant_(m.bias, 0) 93 | elif isinstance(m, nn.LayerNorm): 94 | nn.init.constant_(m.bias, 0) 95 | nn.init.constant_(m.weight, 1.0) 96 | 97 | def forward(self, x): 98 | B = x.size(0) 99 | x = self.pos_emb(x) 100 | 101 | for blk in self.blocks: 102 | x = blk(x) 103 | 104 | # x = self.output_linear(x) 105 | x = self.output_linear1(x) 106 | x = nn.Tanh()(x) 107 | x = self.conv_out(x) 108 | x = rearrange( 109 | x, "B (h w) (h_size w_size c) -> B c (h h_size) (w w_size)", 110 | B=B, h=self.hw, w=self.hw, h_size=self.patch_size, w_size=self.patch_size, c=self.output_channel).contiguous() 111 | return x 112 | 113 | if __name__ == "__main__": 114 | images = torch.randn(10,3,256,256) 115 | vit_encoder = VisionTransformerEncoder( 116 | image_size=256, patch_size=16, input_channel=3, embed_dim=512, init_type="default", 117 | pos_embed_type="learned-2d", attn_drop_rate=0., drop_rate=0., depth=6, num_heads=4 118 | ) 119 | x = vit_encoder(images) 120 | print(x.size()) 121 | 122 | vit_decoder = VisionTransformerDecoder( 123 | image_size=256, patch_size=16, pos_embed_type="learned-2d", embed_dim=512, 
drop_rate=0., 124 | depth=6, num_heads=4, attn_drop_rate=0., output_channel=3, init_type="default" 125 | ) 126 | y = vit_decoder(x) 127 | print(y.size()) -------------------------------------------------------------------------------- /modules/vector_quantization/common_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, einsum 3 | import torch.nn.functional as F 4 | import torch.distributed as distributed 5 | from einops import rearrange, repeat 6 | 7 | def exists(val): 8 | return val is not None 9 | 10 | def default(val, d): 11 | return val if exists(val) else d 12 | 13 | def noop(*args, **kwargs): 14 | pass 15 | 16 | def l2norm(t): 17 | return F.normalize(t, p = 2, dim = -1) 18 | 19 | def log(t, eps = 1e-20): 20 | return torch.log(t.clamp(min = eps)) 21 | 22 | def uniform_init(*shape): 23 | t = torch.empty(shape) 24 | nn.init.kaiming_uniform_(t) 25 | return t 26 | 27 | def gumbel_noise(t): 28 | noise = torch.zeros_like(t).uniform_(0, 1) 29 | return -log(-log(noise)) 30 | 31 | def gumbel_sample(t, temperature = 1., dim = -1): 32 | if temperature == 0: 33 | return t.argmax(dim = dim) 34 | 35 | return ((t / temperature) + gumbel_noise(t)).argmax(dim = dim) 36 | 37 | def ema_inplace(moving_avg, new, decay): 38 | moving_avg.data.mul_(decay).add_(new, alpha = (1 - decay)) 39 | 40 | def laplace_smoothing(x, n_categories, eps = 1e-5): 41 | return (x + eps) / (x.sum() + n_categories * eps) 42 | 43 | def sample_vectors(samples, num): 44 | num_samples, device = samples.shape[0], samples.device 45 | if num_samples >= num: 46 | indices = torch.randperm(num_samples, device = device)[:num] 47 | else: 48 | indices = torch.randint(0, num_samples, (num,), device = device) 49 | 50 | return samples[indices] 51 | 52 | def batched_sample_vectors(samples, num): 53 | return torch.stack([sample_vectors(sample, num) for sample in samples.unbind(dim = 0)], dim = 0) 54 | 55 | def pad_shape(shape, size, dim = 0): 56 | return [size if i == dim else s for i, s in enumerate(shape)] 57 | 58 | def sample_multinomial(total_count, probs): 59 | device = probs.device 60 | probs = probs.cpu() 61 | 62 | total_count = probs.new_full((), total_count) 63 | remainder = probs.new_ones(()) 64 | sample = torch.empty_like(probs, dtype = torch.long) 65 | 66 | for i, p in enumerate(probs): 67 | s = torch.binomial(total_count, p / remainder) 68 | sample[i] = s 69 | total_count -= s 70 | remainder -= p 71 | 72 | return sample.to(device) 73 | 74 | def all_gather_sizes(x, dim): 75 | size = torch.tensor(x.shape[dim], dtype = torch.long, device = x.device) 76 | all_sizes = [torch.empty_like(size) for _ in range(distributed.get_world_size())] 77 | distributed.all_gather(all_sizes, size) 78 | 79 | return torch.stack(all_sizes) 80 | 81 | def all_gather_variably_sized(x, sizes, dim = 0): 82 | rank = distributed.get_rank() 83 | all_x = [] 84 | 85 | for i, size in enumerate(sizes): 86 | t = x if i == rank else x.new_empty(pad_shape(x.shape, size, dim)) 87 | distributed.broadcast(t, src = i, async_op = True) 88 | all_x.append(t) 89 | 90 | distributed.barrier() 91 | return all_x 92 | 93 | def sample_vectors_distributed(local_samples, num): 94 | rank = distributed.get_rank() 95 | all_num_samples = all_gather_sizes(local_samples, dim = 0) 96 | 97 | if rank == 0: 98 | samples_per_rank = sample_multinomial(num, all_num_samples / all_num_samples.sum()) 99 | else: 100 | samples_per_rank = torch.empty_like(all_num_samples) 101 | 102 | distributed.broadcast(samples_per_rank, 
src = 0) 103 | samples_per_rank = samples_per_rank.tolist() 104 | 105 | local_samples = batched_sample_vectors(local_samples, samples_per_rank[rank]) 106 | all_samples = all_gather_variably_sized(local_samples, samples_per_rank, dim = 0) 107 | return torch.cat(all_samples, dim = 0) 108 | 109 | def batched_bincount(x, *, minlength): 110 | batch, dtype, device = x.shape[0], x.dtype, x.device 111 | target = torch.zeros(batch, minlength, dtype = dtype, device = device) 112 | values = torch.ones_like(x) 113 | target.scatter_add_(-1, x, values) 114 | return target 115 | 116 | def kmeans( 117 | samples, 118 | num_clusters, 119 | num_iters = 10, 120 | use_cosine_sim = False, 121 | sample_fn = batched_sample_vectors, 122 | all_reduce_fn = noop 123 | ): 124 | num_codebooks, dim, dtype, device = samples.shape[0], samples.shape[-1], samples.dtype, samples.device 125 | 126 | means = sample_fn(samples, num_clusters) 127 | 128 | for _ in range(num_iters): 129 | if use_cosine_sim: 130 | dists = samples @ rearrange(means, 'h n d -> h d n') 131 | else: 132 | dists = -torch.cdist(samples, means, p = 2) 133 | 134 | buckets = torch.argmax(dists, dim = -1) 135 | bins = batched_bincount(buckets, minlength = num_clusters) 136 | all_reduce_fn(bins) 137 | 138 | zero_mask = bins == 0 139 | bins_min_clamped = bins.masked_fill(zero_mask, 1) 140 | 141 | new_means = buckets.new_zeros(num_codebooks, num_clusters, dim, dtype = dtype) 142 | 143 | new_means.scatter_add_(1, repeat(buckets, 'h n -> h n d', d = dim), samples) 144 | new_means = new_means / rearrange(bins_min_clamped, '... -> ... 1') 145 | all_reduce_fn(new_means) 146 | 147 | if use_cosine_sim: 148 | new_means = l2norm(new_means) 149 | 150 | means = torch.where( 151 | rearrange(zero_mask, '... -> ... 1'), 152 | means, 153 | new_means 154 | ) 155 | 156 | return means, bins 157 | 158 | def batched_embedding(indices, embeds): 159 | batch, dim = indices.shape[1], embeds.shape[-1] 160 | indices = repeat(indices, 'h b n -> h b n d', d = dim) 161 | embeds = repeat(embeds, 'h c d -> h b c d', b = batch) 162 | return embeds.gather(2, indices) 163 | -------------------------------------------------------------------------------- /modules/vector_quantization/quantize.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from einops import rearrange 8 | from torch import einsum 9 | 10 | sys.path.append(os.getcwd()) 11 | 12 | import modules.vector_quantization.common_utils as utils 13 | 14 | # support: cosine-similarity; kmeans centroids initialization; orthogonal_reg_weight, temperature sample 15 | class VectorQuantize(nn.Module): 16 | def __init__( 17 | self, 18 | codebook_size, 19 | codebook_dim = None, 20 | kmeans_init = False, 21 | kmeans_iters = 10, 22 | use_cosine_sim = False, 23 | use_cosine_distance = False, 24 | channel_last = False, 25 | accept_image_fmap = True, 26 | commitment_beta = 0.25, 27 | orthogonal_reg_weight = 0., 28 | ): 29 | super().__init__() 30 | 31 | self.codebook_size = codebook_size 32 | self.codebook_dim = codebook_dim 33 | self.accept_image_fmap = accept_image_fmap 34 | self.channel_last = channel_last 35 | self.use_cosine_sim = use_cosine_sim 36 | self.use_cosine_distance = use_cosine_distance 37 | self.beta = commitment_beta 38 | 39 | self.embedding = nn.Embedding(self.codebook_size, self.codebook_dim) 40 | 41 | # codebook initialization 42 | if not kmeans_init: 43 | 
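            # Two initialisation modes: without k-means the codebook gets the usual
            # small uniform init below, while kmeans_init=True starts the weights at
            # zero and lets init_embed_() overwrite them with k-means centroids of the
            # first batch, gated by the `initted` buffer registered further down.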
self.embedding.weight.data.uniform_(-1.0 / self.codebook_size, 1.0 / self.codebook_size) 44 | else: 45 | self.embedding.weight.data.zero_() 46 | self.kmeans_iters = kmeans_iters 47 | self.register_buffer('initted', torch.Tensor([not kmeans_init])) 48 | self.register_buffer('cluster_size', torch.zeros(1, codebook_size)) 49 | 50 | self.sample_fn = utils.batched_sample_vectors 51 | self.all_reduce_fn = utils.noop 52 | 53 | # codebook_orthogonal_loss 54 | self.orthogonal_reg_weight = orthogonal_reg_weight 55 | 56 | def init_embed_(self, data): 57 | if self.initted: 58 | return 59 | 60 | data = rearrange(data, '... -> 1 ...').contiguous() 61 | data = rearrange(data, 'h ... d -> h (...) d').contiguous() 62 | 63 | embed, cluster_size = utils.kmeans( 64 | data, 65 | self.codebook_size, 66 | self.kmeans_iters, 67 | sample_fn = self.sample_fn, 68 | all_reduce_fn = self.all_reduce_fn 69 | ) 70 | 71 | self.embedding.weight.data.copy_(embed.squeeze(0)) 72 | self.cluster_size.data.copy_(cluster_size) 73 | self.initted.data.copy_(torch.Tensor([True])) 74 | 75 | def forward(self, x, temp=0.): 76 | need_transpose = not self.channel_last and not self.accept_image_fmap 77 | 78 | if self.accept_image_fmap: 79 | height, width = x.shape[-2:] 80 | x = rearrange(x, 'b c h w -> b (h w) c').contiguous() 81 | 82 | if need_transpose: 83 | x = rearrange(x, 'b d n -> b n d').contiguous() 84 | shape, device, dtype = x.shape, x.device, x.dtype 85 | flatten = rearrange(x, 'h ... d -> h (...) d').contiguous() 86 | 87 | # if use cosine_sim, whether should norm the feature before k-means initialization ? 88 | # if self.use_cosine_sim: 89 | # flatten = F.normalize(flatten, p = 2, dim = -1) 90 | self.init_embed_(flatten) 91 | 92 | # calculate the distance 93 | if self.use_cosine_sim: # cosine similarity 94 | flatten_norm = F.normalize(flatten, p = 2, dim = -1) 95 | weight_norm = F.normalize(self.embedding.weight, p = 2, dim = -1).unsqueeze(0) 96 | 97 | dist = torch.matmul(flatten_norm, weight_norm.transpose(1, 2)) 98 | # dist = einsum('h n d, h c d -> h n c', flatten_norm, weight_norm) 99 | elif self.use_cosine_distance: # l2 normed distance, NOTE: experimental 100 | flatten_norm = F.normalize(flatten, p = 2, dim = -1).view(-1, self.codebook_dim) 101 | weight_norm = F.normalize(self.embedding.weight, p = 2, dim = -1) 102 | 103 | dist = - torch.sum(flatten_norm ** 2, dim=1, keepdim=True) - \ 104 | torch.sum(weight_norm**2, dim=1) + 2 * \ 105 | torch.einsum('bd,dn->bn', flatten_norm, rearrange(weight_norm, 'n d -> d n')) # more efficient, add "-" for argmax gumbel sample 106 | else: # L2 distance 107 | flatten = flatten.view(-1, self.codebook_dim) 108 | dist = - torch.sum(flatten ** 2, dim=1, keepdim=True) - \ 109 | torch.sum(self.embedding.weight**2, dim=1) + 2 * \ 110 | torch.einsum('bd,dn->bn', flatten, rearrange(self.embedding.weight, 'n d -> d n')) # more efficient, add "-" for argmax gumbel sample 111 | 112 | embed_ind = utils.gumbel_sample(dist, dim = -1, temperature = temp) 113 | # embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype) 114 | embed_ind = embed_ind.view(*shape[:-1]) 115 | 116 | x_q = self.embedding(embed_ind) 117 | 118 | # compute loss for embedding 119 | loss = self.beta * torch.mean((x_q.detach()-x)**2) + torch.mean((x_q - x.detach()) ** 2) 120 | 121 | # ortho reg term 122 | if self.orthogonal_reg_weight > 0. 
: 123 | # eq (2) from https://arxiv.org/abs/2112.00384 124 | emb_weight_after_norm = F.normalize(self.embedding.weight, p = 2, dim = -1) 125 | diff = torch.mm(emb_weight_after_norm, torch.transpose(emb_weight_after_norm, 0, 1)) - torch.eye(self.codebook_size, self.codebook_size).type_as(emb_weight_after_norm) 126 | ortho_reg_term = self.orthogonal_reg_weight * torch.sum(diff**2) / (diff.size(0)**2) 127 | 128 | # diff = torch.mm(self.embedding.weight, torch.transpose(self.embedding.weight, 0, 1)) - torch.eye(self.codebook_size, self.codebook_size).type_as(self.embedding.weight) 129 | # ortho_reg_term = self.orthogonal_reg_weight * torch.sum(diff**2) / (diff.size(0)**2) 130 | loss = loss + ortho_reg_term 131 | 132 | # preserve gradients 133 | x_q = x + (x_q - x).detach() 134 | 135 | if need_transpose: 136 | x_q = rearrange(x_q, 'b n d -> b d n').contiguous() 137 | 138 | if self.accept_image_fmap: 139 | x_q = rearrange(x_q, 'b (h w) c -> b c h w', h = height, w = width).contiguous() 140 | embed_ind = rearrange(embed_ind, 'b (h w) ... -> b h w ...', h = height, w = width).contiguous() 141 | 142 | return x_q, loss, (None, None, embed_ind) 143 | 144 | def get_codebook_entry(self, indices, *kwargs): 145 | # get quantized latent vectors 146 | z_q = self.embedding(indices) # (batch, height, width, channel) 147 | return z_q 148 | 149 | 150 | if __name__ == "__main__": 151 | quantizer = VectorQuantize( 152 | codebook_size = 1024, 153 | codebook_dim = 512, 154 | kmeans_init = True, 155 | kmeans_iters = 10, 156 | use_cosine_sim = False, 157 | use_cosine_distance = True, 158 | channel_last = False, 159 | accept_image_fmap = False, 160 | commitment_beta = 0.25, 161 | orthogonal_reg_weight = 10., 162 | ) 163 | 164 | # x = torch.randn(10, 512, 16, 16) 165 | x = torch.randn(10, 512, 120) 166 | 167 | x_q, loss, (_, _, embed_ind) = quantizer(x, 0.) 
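    # Expected shapes for the call above (a small sanity-check sketch): with
    # accept_image_fmap=False and channel_last=False the (10, 512, 120) input is
    # treated as (batch, dim, length).
    assert x_q.shape == x.shape           # quantized features, same layout: (10, 512, 120)
    assert embed_ind.shape == (10, 120)   # one codebook index per position
    assert loss.dim() == 0                # scalar commitment + orthogonal-reg loss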
168 | print(loss) 169 | -------------------------------------------------------------------------------- /modules/vector_quantization/quantize2_list.py: -------------------------------------------------------------------------------- 1 | # add the random restart of rqvae 2 | import numpy as np 3 | import torch 4 | import torch.distributed as dist 5 | from torch import nn 6 | from torch.nn import functional as F 7 | from einops import rearrange 8 | 9 | 10 | class VQEmbedding(nn.Embedding): 11 | """VQ embedding module with ema update.""" 12 | 13 | def __init__(self, n_embed, embed_dim, ema=True, decay=0.99, restart_unused_codes=True, eps=1e-5): 14 | super().__init__(n_embed + 1, embed_dim, padding_idx=n_embed) 15 | 16 | self.ema = ema 17 | self.decay = decay 18 | self.eps = eps 19 | self.restart_unused_codes = restart_unused_codes 20 | self.n_embed = n_embed 21 | 22 | if self.ema: 23 | _ = [p.requires_grad_(False) for p in self.parameters()] 24 | 25 | # padding index is not updated by EMA 26 | self.register_buffer('cluster_size_ema', torch.zeros(n_embed)) 27 | self.register_buffer('embed_ema', self.weight[:-1, :].detach().clone()) 28 | 29 | @torch.no_grad() 30 | def compute_distances(self, inputs): 31 | codebook_t = self.weight[:-1, :].t() 32 | 33 | (embed_dim, _) = codebook_t.shape 34 | inputs_shape = inputs.shape 35 | assert inputs_shape[-1] == embed_dim 36 | 37 | inputs_flat = inputs.reshape(-1, embed_dim) 38 | 39 | inputs_norm_sq = inputs_flat.pow(2.).sum(dim=1, keepdim=True) 40 | codebook_t_norm_sq = codebook_t.pow(2.).sum(dim=0, keepdim=True) 41 | distances = torch.addmm( 42 | inputs_norm_sq + codebook_t_norm_sq, 43 | inputs_flat, 44 | codebook_t, 45 | alpha=-2.0, 46 | ) 47 | distances = distances.reshape(*inputs_shape[:-1], -1) # [B, h, w, n_embed or n_embed+1] 48 | return distances 49 | 50 | @torch.no_grad() 51 | def find_nearest_embedding(self, inputs): 52 | distances = self.compute_distances(inputs) # [B, h, w, n_embed or n_embed+1] 53 | embed_idxs = distances.argmin(dim=-1) # use padding index or not 54 | 55 | return embed_idxs 56 | 57 | @torch.no_grad() 58 | def _tile_with_noise(self, x, target_n): 59 | B, embed_dim = x.shape 60 | n_repeats = (target_n + B -1) // B 61 | std = x.new_ones(embed_dim) * 0.01 / np.sqrt(embed_dim) 62 | x = x.repeat(n_repeats, 1) 63 | x = x + torch.rand_like(x) * std 64 | return x 65 | 66 | @torch.no_grad() 67 | def _update_buffers(self, vectors, idxs): 68 | 69 | n_embed, embed_dim = self.weight.shape[0]-1, self.weight.shape[-1] 70 | 71 | vectors = vectors.reshape(-1, embed_dim) 72 | idxs = idxs.reshape(-1) 73 | 74 | n_vectors = vectors.shape[0] 75 | n_total_embed = n_embed 76 | 77 | one_hot_idxs = vectors.new_zeros(n_total_embed, n_vectors) 78 | one_hot_idxs.scatter_(dim=0, 79 | index=idxs.unsqueeze(0), 80 | src=vectors.new_ones(1, n_vectors) 81 | ) 82 | 83 | cluster_size = one_hot_idxs.sum(dim=1) 84 | vectors_sum_per_cluster = one_hot_idxs @ vectors 85 | 86 | if dist.is_initialized(): 87 | dist.all_reduce(vectors_sum_per_cluster, op=dist.ReduceOp.SUM) 88 | dist.all_reduce(cluster_size, op=dist.ReduceOp.SUM) 89 | 90 | self.cluster_size_ema.mul_(self.decay).add_(cluster_size, alpha=1 - self.decay) 91 | self.embed_ema.mul_(self.decay).add_(vectors_sum_per_cluster, alpha=1 - self.decay) 92 | 93 | if self.restart_unused_codes: 94 | if n_vectors < n_embed: 95 | vectors = self._tile_with_noise(vectors, n_embed) 96 | n_vectors = vectors.shape[0] 97 | _vectors_random = vectors[torch.randperm(n_vectors, device=vectors.device)][:n_embed] 98 | 99 | if 
dist.is_initialized(): 100 | dist.broadcast(_vectors_random, 0) 101 | 102 | usage = (self.cluster_size_ema.view(-1, 1) >= 1).float() 103 | self.embed_ema.mul_(usage).add_(_vectors_random * (1-usage)) 104 | self.cluster_size_ema.mul_(usage.view(-1)) 105 | self.cluster_size_ema.add_(torch.ones_like(self.cluster_size_ema) * (1-usage).view(-1)) 106 | 107 | @torch.no_grad() 108 | def _update_embedding(self): 109 | 110 | n_embed = self.weight.shape[0] - 1 111 | n = self.cluster_size_ema.sum() 112 | normalized_cluster_size = ( 113 | n * (self.cluster_size_ema + self.eps) / (n + n_embed * self.eps) 114 | ) 115 | self.weight[:-1, :] = self.embed_ema / normalized_cluster_size.reshape(-1, 1) 116 | 117 | def forward(self, inputs): 118 | embed_idxs = self.find_nearest_embedding(inputs) 119 | if self.training: 120 | if self.ema: 121 | self._update_buffers(inputs, embed_idxs) 122 | 123 | embeds = self.embed(embed_idxs) 124 | 125 | if self.ema and self.training: 126 | self._update_embedding() 127 | 128 | return embeds, embed_idxs 129 | 130 | def embed(self, idxs): 131 | embeds = super().forward(idxs) 132 | return embeds 133 | 134 | # simplified version with random restart unused, accept list features input 135 | class VectorQuantize2(nn.Module): 136 | def __init__(self, 137 | codebook_size, 138 | codebook_dim = None, 139 | commitment_beta = 0.25, 140 | decay = 0.99, 141 | restart_unused_codes = True, 142 | ): 143 | super().__init__() 144 | self.beta = commitment_beta 145 | self.restart_unused_codes = restart_unused_codes 146 | self.codebook = VQEmbedding(codebook_size, 147 | codebook_dim, 148 | decay = decay, 149 | restart_unused_codes = restart_unused_codes, 150 | ) 151 | self.codebook.weight.data.uniform_(-1.0 / codebook_size, 1.0 / codebook_size) 152 | 153 | def forward(self, x_list, *ignorewargs, **ignorekwargs): 154 | batch_size = len(x_list) 155 | 156 | x_q_list, x_code_list = [], [] 157 | loss = 0. 158 | for i in range(batch_size): 159 | x_q_i, x_code_i = self.codebook(x_list[i]) 160 | 161 | loss += self.beta * torch.mean((x_q_i.detach()-x_list[i])**2) + torch.mean((x_q_i - x_list[i].detach()) ** 2) 162 | 163 | # preserve gradients 164 | x_q_i = x_list[i] + (x_q_i - x_list[i]).detach() 165 | 166 | x_q_list.append(x_q_i) 167 | x_code_list.append(x_code_i) 168 | 169 | loss /= batch_size 170 | return x_q_list, loss, (None, None, x_code_list) 171 | 172 | @torch.no_grad() 173 | def get_soft_codes(self, x, temp=1.0, stochastic=False): 174 | distances = self.codebook.compute_distances(x) 175 | soft_code = F.softmax(-distances / temp, dim=-1) 176 | 177 | if stochastic: 178 | soft_code_flat = soft_code.reshape(-1, soft_code.shape[-1]) 179 | code = torch.multinomial(soft_code_flat, 1) 180 | code = code.reshape(*soft_code.shape[:-1]) 181 | else: 182 | code = distances.argmin(dim=-1) 183 | 184 | return soft_code, code 185 | 186 | def get_codebook_entry(self, indices, *kwargs): 187 | # get quantized latent vectors 188 | z_q = self.codebook.embed(indices) # (batch, height, width, channel) 189 | return z_q -------------------------------------------------------------------------------- /modules/vqvae/quantize2.py: -------------------------------------------------------------------------------- 1 | # NOTE: this version changes how the input z is fed into the VQ-VAE so that it is compatible with our mask mechanism 2 | 3 | import torch 4 | import torch.nn as nn 5 | import numpy as np 6 | from einops import rearrange 7 | 8 | class VectorQuantizer2(nn.Module): 9 | """ 10 | Improved version over VectorQuantizer, can be used as a drop-in replacement.
Mostly 11 | avoids costly matrix multiplications and allows for post-hoc remapping of indices. 12 | """ 13 | # NOTE: due to a bug the beta term was applied to the wrong term. for 14 | # backwards compatibility we use the buggy version by default, but you can 15 | # specify legacy=False to fix it. 16 | def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random", 17 | legacy=True): 18 | super().__init__() 19 | self.n_e = n_e 20 | self.e_dim = e_dim 21 | self.beta = beta 22 | self.legacy = legacy 23 | 24 | self.embedding = nn.Embedding(self.n_e, self.e_dim) 25 | self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) 26 | 27 | self.remap = remap 28 | if self.remap is not None: 29 | self.register_buffer("used", torch.tensor(np.load(self.remap))) 30 | self.re_embed = self.used.shape[0] 31 | self.unknown_index = unknown_index # "random" or "extra" or integer 32 | if self.unknown_index == "extra": 33 | self.unknown_index = self.re_embed 34 | self.re_embed = self.re_embed+1 35 | print(f"Remapping {self.n_e} indices to {self.re_embed} indices. " 36 | f"Using {self.unknown_index} for unknown indices.") 37 | else: 38 | self.re_embed = n_e 39 | 40 | def remap_to_used(self, inds): 41 | ishape = inds.shape 42 | assert len(ishape)>1 43 | inds = inds.reshape(ishape[0],-1) 44 | used = self.used.to(inds) 45 | match = (inds[:,:,None]==used[None,None,...]).long() 46 | new = match.argmax(-1) 47 | unknown = match.sum(2)<1 48 | if self.unknown_index == "random": 49 | new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device) 50 | else: 51 | new[unknown] = self.unknown_index 52 | return new.reshape(ishape) 53 | 54 | def unmap_to_all(self, inds): 55 | ishape = inds.shape 56 | assert len(ishape)>1 57 | inds = inds.reshape(ishape[0],-1) 58 | used = self.used.to(inds) 59 | if self.re_embed > self.used.shape[0]: # extra token 60 | inds[inds>=self.used.shape[0]] = 0 # simply set to zero 61 | back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds) 62 | return back.reshape(ishape) 63 | 64 | def forward(self, z, temp=None, rescale_logits=False, return_logits=False): 65 | assert temp is None or temp==1.0, "Only for interface compatible with Gumbel" 66 | assert rescale_logits==False, "Only for interface compatible with Gumbel" 67 | assert return_logits==False, "Only for interface compatible with Gumbel" 68 | # reshape z -> (batch, length, channel) and flatten 69 | z = rearrange(z, 'b c l -> b l c').contiguous() 70 | z_flattened = z.view(-1, self.e_dim) 71 | # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z 72 | 73 | d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \ 74 | torch.sum(self.embedding.weight**2, dim=1) - 2 * \ 75 | torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n')) 76 | 77 | min_encoding_indices = torch.argmin(d, dim=1) 78 | z_q = self.embedding(min_encoding_indices).view(z.shape) 79 | perplexity = None 80 | min_encodings = None 81 | 82 | # compute loss for embedding 83 | if not self.legacy: 84 | loss = self.beta * torch.mean((z_q.detach()-z)**2) + \ 85 | torch.mean((z_q - z.detach()) ** 2) 86 | else: 87 | loss = torch.mean((z_q.detach()-z)**2) + self.beta * \ 88 | torch.mean((z_q - z.detach()) ** 2) 89 | 90 | # preserve gradients 91 | z_q = z + (z_q - z).detach() 92 | 93 | # reshape back to match original input shape 94 | z_q = rearrange(z_q, 'b l c -> b c l').contiguous() 95 | 96 | if self.remap is not None: 97 | min_encoding_indices = min_encoding_indices.reshape(z.shape[0],-1) # add batch 
axis 98 | min_encoding_indices = self.remap_to_used(min_encoding_indices) 99 | min_encoding_indices = min_encoding_indices.reshape(-1,1) # flatten 100 | 101 | return z_q, loss, (perplexity, min_encodings, min_encoding_indices) 102 | 103 | def get_codebook_entry(self, indices, shape): 104 | # shape specifying (batch, height, width, channel) 105 | if self.remap is not None: 106 | indices = indices.reshape(shape[0],-1) # add batch axis 107 | indices = self.unmap_to_all(indices) 108 | indices = indices.reshape(-1) # flatten again 109 | 110 | # get quantized latent vectors 111 | z_q = self.embedding(indices) 112 | 113 | if shape is not None: 114 | z_q = z_q.view(shape) 115 | # reshape back to match original input shape 116 | z_q = z_q.permute(0, 2, 1).contiguous() 117 | 118 | return z_q -------------------------------------------------------------------------------- /scripts/sample_images/sample_dynamic_uncond.py: -------------------------------------------------------------------------------- 1 | from locale import normalize 2 | import os 3 | import sys 4 | 5 | sys.path.append(os.getcwd()) 6 | import argparse 7 | import pickle 8 | 9 | import torch 10 | from omegaconf import OmegaConf 11 | from tqdm import trange 12 | from utils.utils import instantiate_from_config 13 | 14 | import datetime 15 | import pytz 16 | import torchvision 17 | 18 | def get_parser(**parser_kwargs): 19 | parser = argparse.ArgumentParser(**parser_kwargs) 20 | parser.add_argument("--yaml_path", type=str, default="") 21 | parser.add_argument("--model_path", type=str, default="") 22 | parser.add_argument("--sample_with_fixed_pos", action="store_true", default=False) 23 | 24 | parser.add_argument("--batch_size", type=int, default=50) 25 | parser.add_argument("--temperature", type=float, default=1.0) 26 | parser.add_argument("--top_k", type=int, default=300) 27 | parser.add_argument("--top_k_pos", type=int, default=1024) # 50 28 | parser.add_argument("--top_p", type=float, default=1.0) 29 | parser.add_argument("--top_p_pos", type=float, default=1.0) 30 | parser.add_argument("--sample_num", type=int, default=5000) 31 | 32 | return parser 33 | 34 | if __name__ == "__main__": 35 | Shanghai = pytz.timezone("Asia/Shanghai") 36 | now = datetime.datetime.now().astimezone(Shanghai).strftime("%m-%dT%H-%M-%S") 37 | 38 | parser = get_parser() 39 | opt, unknown = parser.parse_known_args() 40 | 41 | save_path = opt.model_path.replace(".ckpt", "") + "_{}_Num-{}/".format(now, opt.sample_num) 42 | if opt.sample_with_fixed_pos: 43 | save_path_image = save_path + "fixed_TopK-{}-{}_TopP-{}-{}_Temp-{}_image".format(opt.top_k, opt.top_k_pos, opt.top_p, opt.top_p_pos, opt.temperature) 44 | save_path_pickle = save_path + "fixed_TopK-{}-{}_TopP-{}-{}_Temp-{}_pickle".format(opt.top_k, opt.top_k_pos, opt.top_p, opt.top_p_pos, opt.temperature) 45 | os.makedirs(save_path_image, exist_ok=True) 46 | else: 47 | save_path_image = save_path + "TopK-{}-{}_TopP-{}-{}_Temp-{}_image".format(opt.top_k, opt.top_k_pos, opt.top_p, opt.top_p_pos, opt.temperature) 48 | save_path_pickle = save_path + "TopK-{}-{}_TopP-{}-{}_Temp-{}_pickle".format(opt.top_k, opt.top_k_pos, opt.top_p, opt.top_p_pos, opt.temperature) 49 | os.makedirs(save_path_image, exist_ok=True) 50 | 51 | # init and save configs 52 | configs = OmegaConf.load(opt.yaml_path) 53 | # model 54 | model = instantiate_from_config(configs.model) 55 | state_dict = torch.load(opt.model_path)['state_dict'] 56 | model.load_state_dict(state_dict) 57 | model.eval().cuda() 58 | 59 | if opt.sample_num % opt.batch_size == 
0: 60 | total_batch = opt.sample_num // opt.batch_size 61 | else: 62 | total_batch = opt.sample_num // opt.batch_size + 1 63 | 64 | batch_size = opt.batch_size 65 | for i in trange(total_batch): 66 | if opt.sample_num % opt.batch_size != 0 and i == total_batch - 1: 67 | batch_size = opt.sample_num % opt.batch_size 68 | x0 = torch.randn(batch_size, ).cuda() 69 | c_coarse, c_fine, c_pos_coarse, c_pos_fine, c_seg_coarse, c_seg_fine = model.encode_to_c(x0) 70 | 71 | if opt.sample_with_fixed_pos: 72 | coarse_content, fine_content, coarse_position, fine_position = model.sample_from_scratch( 73 | c_coarse, c_fine, c_pos_coarse, c_pos_fine, c_seg_coarse, c_seg_fine, 74 | temperature = opt.temperature, 75 | sample = True, 76 | top_k = opt.top_k, 77 | top_p = opt.top_p, 78 | top_k_pos = opt.top_k_pos, 79 | top_p_pos = opt.top_p_pos, 80 | process = True, 81 | fix_fine_position = True, 82 | ) 83 | samples = model.decode_to_img(coarse_content, fine_content, coarse_position, fine_position) 84 | sample = torch.clamp((samples * 0.5 + 0.5), 0, 1) 85 | 86 | for batch_i in range(batch_size): 87 | torchvision.utils.save_image(samples[batch_i], "{}/batch_{}_{}.png".format(save_path_image, i, batch_i), normalize=True) 88 | else: 89 | coarse_content, fine_content, coarse_position, fine_position = model.sample_from_scratch( 90 | c_coarse, c_fine, c_pos_coarse, c_pos_fine, c_seg_coarse, c_seg_fine, 91 | temperature = opt.temperature, 92 | sample = True, 93 | top_k = opt.top_k, 94 | top_p = opt.top_p, 95 | top_k_pos = opt.top_k_pos, 96 | top_p_pos = opt.top_p_pos, 97 | process = True, 98 | fix_fine_position = False, 99 | ) 100 | samples = model.decode_to_img(coarse_content, fine_content, coarse_position, fine_position) 101 | samples = torch.clamp((samples * 0.5 + 0.5), 0, 1) 102 | for batch_i in range(batch_size): 103 | torchvision.utils.save_image(samples[batch_i], "{}/batch_{}_{}.png".format(save_path_image, i, batch_i), normalize=True) -------------------------------------------------------------------------------- /scripts/sample_val/sample_dynamic_uncond.py: -------------------------------------------------------------------------------- 1 | from ast import parse 2 | import os 3 | import sys 4 | 5 | sys.path.append(os.getcwd()) 6 | import argparse 7 | import pickle 8 | 9 | import torch 10 | from omegaconf import OmegaConf 11 | from tqdm import trange 12 | from utils.utils import instantiate_from_config 13 | 14 | import datetime 15 | import pytz 16 | import torchvision 17 | 18 | def save_pickle(fname, data): 19 | with open(fname, 'wb') as fp: 20 | pickle.dump(data, fp, pickle.HIGHEST_PROTOCOL) 21 | 22 | def get_parser(**parser_kwargs): 23 | parser = argparse.ArgumentParser(**parser_kwargs) 24 | parser.add_argument("--yaml_path", type=str, default="") 25 | parser.add_argument("--model_path", type=str, default="") 26 | parser.add_argument("--sample_with_fixed_pos", action="store_true", default=False) 27 | parser.add_argument("--save_image", action="store_true", default=False) 28 | 29 | parser.add_argument("--batch_size", type=int, default=50) 30 | parser.add_argument("--temperature", type=float, default=1.0) 31 | parser.add_argument("--top_k", type=int, default=300) 32 | parser.add_argument("--top_k_pos", type=int, default=1024) # 50 33 | parser.add_argument("--top_p", type=float, default=1.0) 34 | parser.add_argument("--top_p_pos", type=float, default=1.0) 35 | parser.add_argument("--sample_num", type=int, default=5000) 36 | 37 | return parser 38 | 39 | if __name__ == "__main__": 40 | Shanghai = 
pytz.timezone("Asia/Shanghai") 41 | now = datetime.datetime.now().astimezone(Shanghai).strftime("%m-%dT%H-%M-%S") 42 | 43 | parser = get_parser() 44 | opt, unknown = parser.parse_known_args() 45 | 46 | save_path = opt.model_path.replace(".ckpt", "") + "_{}_Num-{}/".format(now, opt.sample_num) 47 | if opt.sample_with_fixed_pos: 48 | save_path_image_fixed = save_path + "fixed_TopK-{}-{}_TopP-{}-{}_Temp-{}_image".format(opt.top_k, opt.top_k_pos, opt.top_p, opt.top_p_pos, opt.temperature) 49 | save_path_pickle_fixed = save_path + "fixed_TopK-{}-{}_TopP-{}-{}_Temp-{}_pickle".format(opt.top_k, opt.top_k_pos, opt.top_p, opt.top_p_pos, opt.temperature) 50 | if opt.save_image: 51 | os.makedirs(save_path_image_fixed, exist_ok=True) 52 | os.makedirs(save_path_pickle_fixed, exist_ok=True) 53 | else: 54 | save_path_image = save_path + "TopK-{}-{}_TopP-{}-{}_Temp-{}_image".format(opt.top_k, opt.top_k_pos, opt.top_p, opt.top_p_pos, opt.temperature) 55 | save_path_pickle = save_path + "TopK-{}-{}_TopP-{}-{}_Temp-{}_pickle".format(opt.top_k, opt.top_k_pos, opt.top_p, opt.top_p_pos, opt.temperature) 56 | if opt.save_image: 57 | os.makedirs(save_path_image, exist_ok=True) 58 | os.makedirs(save_path_pickle, exist_ok=True) 59 | 60 | # init and save configs 61 | configs = OmegaConf.load(opt.yaml_path) 62 | # model 63 | model = instantiate_from_config(configs.model) 64 | state_dict = torch.load(opt.model_path)['state_dict'] 65 | model.load_state_dict(state_dict) 66 | model.eval().cuda() 67 | 68 | if opt.sample_num % opt.batch_size == 0: 69 | total_batch = opt.sample_num // opt.batch_size 70 | else: 71 | total_batch = opt.sample_num // opt.batch_size + 1 72 | 73 | batch_size = opt.batch_size 74 | for i in trange(total_batch): 75 | if opt.sample_num % opt.batch_size != 0 and i == total_batch - 1: 76 | batch_size = opt.sample_num % opt.batch_size 77 | x0 = torch.randn(batch_size, ).cuda() 78 | c_coarse, c_fine, c_pos_coarse, c_pos_fine, c_seg_coarse, c_seg_fine = model.encode_to_c(x0) 79 | 80 | if opt.sample_with_fixed_pos: 81 | coarse_content, fine_content, coarse_position, fine_position = model.sample_from_scratch( 82 | c_coarse, c_fine, c_pos_coarse, c_pos_fine, c_seg_coarse, c_seg_fine, 83 | temperature = opt.temperature, 84 | sample = True, 85 | top_k = opt.top_k, 86 | top_p = opt.top_p, 87 | top_k_pos = opt.top_k_pos, 88 | top_p_pos = opt.top_p_pos, 89 | process = True, 90 | fix_fine_position = True, 91 | ) 92 | samples_fixed_fine_position = model.decode_to_img(coarse_content, fine_content, coarse_position, fine_position) 93 | samples_fixed_fine_position = torch.clamp((samples_fixed_fine_position * 0.5 + 0.5), 0, 1) 94 | if opt.save_image: 95 | torchvision.utils.save_image(samples_fixed_fine_position, "{}/batch_{}.png".format(save_path_image_fixed, i)) 96 | save_pickle( 97 | os.path.join(save_path_pickle_fixed, 'samples_({}_{}).pkl'.format(i, total_batch)), 98 | samples_fixed_fine_position.cpu().numpy(), 99 | ) 100 | else: 101 | coarse_content, fine_content, coarse_position, fine_position = model.sample_from_scratch( 102 | c_coarse, c_fine, c_pos_coarse, c_pos_fine, c_seg_coarse, c_seg_fine, 103 | temperature = opt.temperature, 104 | sample = True, 105 | top_k = opt.top_k, 106 | top_p = opt.top_p, 107 | top_k_pos = opt.top_k_pos, 108 | top_p_pos = opt.top_p_pos, 109 | process = True, 110 | fix_fine_position = False, 111 | ) 112 | samples = model.decode_to_img(coarse_content, fine_content, coarse_position, fine_position) 113 | samples = torch.clamp((samples * 0.5 + 0.5), 0, 1) 114 | if opt.save_image: 115 | 
torchvision.utils.save_image(samples, "{}/batch_{}.png".format(save_path_image, i)) 116 | save_pickle( 117 | os.path.join(save_path_pickle, 'samples_({}_{}).pkl'.format(i, total_batch)), 118 | samples.cpu().numpy(), 119 | ) -------------------------------------------------------------------------------- /scripts/tools/calculate_entropy_thresholds.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | sys.path.append(os.getcwd()) 5 | import argparse 6 | 7 | import torch 8 | from einops import rearrange 9 | from torch import nn, Tensor 10 | import torchvision 11 | from data.imagenet_lmdb import Imagenet_LMDB 12 | from data.ffhq_lmdb import FFHQ_LMDB 13 | from tqdm import tqdm 14 | import numpy as np 15 | np.set_printoptions(suppress=True) 16 | import json 17 | 18 | def get_parser(**parser_kwargs): 19 | parser = argparse.ArgumentParser(**parser_kwargs) 20 | parser.add_argument("--batch_size", type=int, default=10) 21 | parser.add_argument("--dataset_type", type=str, default="ffhq") 22 | parser.add_argument("--split", type=str, default="val") 23 | parser.add_argument("--patch_size", type=int, default=16) 24 | parser.add_argument("--image_size", type=int, default=256) 25 | return parser 26 | 27 | class Entropy(nn.Sequential): 28 | def __init__(self, patch_size, image_width, image_height): 29 | super(Entropy, self).__init__() 30 | self.width = image_width 31 | self.height = image_height 32 | self.psize = patch_size 33 | # number of patches per image 34 | self.patch_num = int(self.width * self.height / self.psize ** 2) 35 | self.hw = int(self.width // self.psize) 36 | # unfolding image to non overlapping patches 37 | self.unfold = torch.nn.Unfold(kernel_size=(self.psize, self.psize), stride=self.psize) 38 | 39 | def entropy(self, values: torch.Tensor, bins: torch.Tensor, sigma: torch.Tensor, batch: int) -> torch.Tensor: 40 | """Function that calculates the entropy using marginal probability distribution function of the input tensor 41 | based on the number of histogram bins. 42 | Args: 43 | values: shape [BxNx1]. 44 | bins: shape [NUM_BINS]. 45 | sigma: shape [1], gaussian smoothing factor. 
46 | batch: int, size of the batch 47 | Returns: 48 | torch.Tensor: 49 | """ 50 | epsilon = 1e-40 51 | values = values.unsqueeze(2) 52 | residuals = values - bins.unsqueeze(0).unsqueeze(0) 53 | kernel_values = torch.exp(-0.5 * (residuals / sigma).pow(2)) 54 | 55 | pdf = torch.mean(kernel_values, dim=1) 56 | normalization = torch.sum(pdf, dim=1).unsqueeze(1) + epsilon 57 | pdf = pdf / normalization + epsilon 58 | entropy = - torch.sum(pdf * torch.log(pdf), dim=1) 59 | entropy = entropy.reshape((batch, -1)) 60 | entropy = rearrange(entropy, "B (H W) -> B H W", H=self.hw, W=self.hw) 61 | return entropy 62 | 63 | def forward(self, inputs: Tensor) -> torch.Tensor: 64 | batch_size = inputs.shape[0] 65 | gray_images = 0.2989 * inputs[:, 0:1, :, :] + 0.5870 * inputs[:, 1:2, :, :] + 0.1140 * inputs[:, 2:, :, :] 66 | 67 | # create patches of size (batch x patch_size*patch_size x h*w/ (patch_size*patch_size)) 68 | unfolded_images = self.unfold(gray_images) 69 | # reshape to (batch * h*w/ (patch_size*patch_size) x (patch_size*patch_size) 70 | unfolded_images = unfolded_images.transpose(1, 2) 71 | unfolded_images = torch.reshape(unfolded_images.unsqueeze(2), 72 | (unfolded_images.shape[0] * self.patch_num, unfolded_images.shape[2])) 73 | 74 | entropy = self.entropy(unfolded_images, bins=torch.linspace(0, 1, 32).to(device=inputs.device), 75 | sigma=torch.tensor(0.01), batch=batch_size) 76 | 77 | return entropy 78 | 79 | if __name__ == "__main__": 80 | parser = get_parser() 81 | opt, unknown = parser.parse_known_args() 82 | 83 | if opt.dataset_type == "ffhq": 84 | dset = FFHQ_LMDB(split=opt.split, resolution=opt.image_size, is_eval=True) 85 | dloader = torch.utils.data.DataLoader(dset, batch_size=opt.batch_size, num_workers=0, shuffle=False) 86 | elif opt.dataset_type == "imagenet": 87 | dset = Imagenet_LMDB(split=opt.split, resolution=opt.image_size, is_eval=True) 88 | dloader = torch.utils.data.DataLoader(dset, batch_size=opt.batch_size, num_workers=0, shuffle=False) 89 | 90 | model = Entropy(opt.patch_size, opt.image_size, opt.image_size).cuda() 91 | 92 | with torch.no_grad(): 93 | for i, data in tqdm(enumerate(dloader)): 94 | image = data["image"].cuda() 95 | if i == 0: 96 | entropy_numpy = model(image).view(-1).cpu().numpy() 97 | else: 98 | entropy_numpy = np.concatenate((entropy_numpy, model(image).view(-1).cpu().numpy())) 99 | 100 | entropy_numpy = np.sort(entropy_numpy) 101 | size = entropy_numpy.shape[0] 102 | print(size) 103 | 104 | with open("scripts/tools/thresholds/entropy_thresholds_{}_{}_patch-{}.json".format(opt.dataset_type, opt.split, opt.patch_size), "w") as f: 105 | data = {} 106 | for i in range(99): 107 | cur_threshold = entropy_numpy[int((size * (i + 1)) // 100)] 108 | cur_threshold = cur_threshold.item() 109 | data["{}".format(str(i+1))] = cur_threshold 110 | json.dump(data, f) -------------------------------------------------------------------------------- /scripts/tools/codebook_pca.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | sys.path.append(os.getcwd()) 5 | import argparse 6 | 7 | import torch 8 | import torchvision 9 | from data.imagenet_lmdb import Imagenet_LMDB 10 | from data.ffhq_lmdb import FFHQ_LMDB 11 | from omegaconf import OmegaConf 12 | from utils.utils import instantiate_from_config 13 | from modules.masked_quantization.tools import build_score_image 14 | from tqdm import tqdm 15 | import matplotlib.pyplot as plt 16 | from sklearn.decomposition import PCA 17 | 18 | 19 | def 
get_parser(**parser_kwargs): 20 | parser = argparse.ArgumentParser(**parser_kwargs) 21 | parser.add_argument("--yaml_path", type=str, default="") 22 | parser.add_argument("--model_path", type=str, default="") 23 | parser.add_argument("--title", type=str, default="PCA of codebook") 24 | parser.add_argument("--save_name", type=str, default="") 25 | 26 | return parser 27 | 28 | if __name__ == "__main__": 29 | parser = get_parser() 30 | opt, unknown = parser.parse_known_args() 31 | 32 | # init and save configs 33 | configs = OmegaConf.load(opt.yaml_path) 34 | # model 35 | model = instantiate_from_config(configs.model) 36 | state_dict = torch.load(opt.model_path)['state_dict'] 37 | model.load_state_dict(state_dict) 38 | model.eval().cuda() 39 | 40 | try: 41 | embedding_data = model.quantize.embedding.weight.data 42 | except: 43 | try: 44 | embedding_data = model.quantize.codebook.weight.data 45 | except: 46 | embedding_data = model.quantizer.codebooks[0].weight.data 47 | embedding_data = embedding_data.cpu().numpy() 48 | print(embedding_data) 49 | 50 | pca = PCA(n_components=2) # instantiate the PCA model 51 | pca = pca.fit(embedding_data) # fit the model 52 | x_dr = pca.transform(embedding_data) # get the transformed matrix 53 | 54 | print(x_dr.shape) 55 | 56 | plt.figure() # create a figure 57 | plt.scatter(x_dr[:,0],x_dr[:,1],c="red") # plt.scatter(x_dr[y==0,0],x_dr[y==0,1],c="red",label = iris.target_names[0]) 58 | # plt.legend() # show the legend 59 | plt.title("{}".format(opt.title)) # set the title 60 | plt.savefig("{}.png".format(opt.save_name)) -------------------------------------------------------------------------------- /scripts/tools/codebook_usage_dqvae.py: -------------------------------------------------------------------------------- 1 | # The codebook usage is calculated as the percentage of 2 | # used codes given a batch of 256 test images averaged over the entire test set. 
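# A minimal sketch of that statistic, assuming `index_batches` is an iterable of
# LongTensors of code indices (hypothetical helper, not used by the script below):
# usage is the fraction of distinct codebook entries hit at least once.
def _codebook_usage(index_batches, codebook_size):
    used = set()
    for indices in index_batches:
        # collect every distinct code index seen so far
        used.update(indices.view(-1).cpu().numpy().tolist())
    return len(used) / codebook_size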
3 | 4 | import os 5 | import sys 6 | 7 | sys.path.append(os.getcwd()) 8 | import argparse 9 | 10 | import torch 11 | import torchvision 12 | from data.imagenet_lmdb import Imagenet_LMDB 13 | from data.ffhq_lmdb import FFHQ_LMDB 14 | from omegaconf import OmegaConf 15 | from utils.utils import instantiate_from_config 16 | from modules.tokenizers.tools import build_score_image 17 | from tqdm import tqdm 18 | 19 | def get_parser(**parser_kwargs): 20 | parser = argparse.ArgumentParser(**parser_kwargs) 21 | parser.add_argument("--yaml_path", type=str, 22 | default="logs/03-26T21-19-43_vqmask-50_coco_f16_1024/configs/03-26T21-19-43-project.yaml") 23 | parser.add_argument("--model_path", type=str, 24 | default="logs/03-26T21-19-43_vqmask-50_coco_f16_1024/checkpoints/last.ckpt") 25 | 26 | parser.add_argument("--batch_size", type=int, default=100) 27 | parser.add_argument("--dataset_type", type=str, default="ffhq") 28 | parser.add_argument("--codebook_size", type=int, default=1024) 29 | 30 | return parser 31 | 32 | 33 | if __name__ == "__main__": 34 | parser = get_parser() 35 | opt, unknown = parser.parse_known_args() 36 | 37 | # init and save configs 38 | configs = OmegaConf.load(opt.yaml_path) 39 | # model 40 | model = instantiate_from_config(configs.model) 41 | state_dict = torch.load(opt.model_path)['state_dict'] 42 | model.load_state_dict(state_dict) 43 | model.eval().cuda() 44 | 45 | if opt.dataset_type == "ffhq": 46 | dset = FFHQ_LMDB(split="val", resolution=256, is_eval=True) 47 | dloader = torch.utils.data.DataLoader(dset, batch_size=opt.batch_size, num_workers=0, shuffle=False) 48 | elif opt.dataset_type == "imagenet": 49 | dset = Imagenet_LMDB(split="val", resolution=256, is_eval=True) 50 | dloader = torch.utils.data.DataLoader(dset, batch_size=opt.batch_size, num_workers=0, shuffle=False) 51 | 52 | with torch.no_grad(): 53 | for i,data in tqdm(enumerate(dloader)): 54 | image = data["image"].float().cuda() 55 | # image = image.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format) 56 | 57 | quant, emb_loss, info, _, _ = model.encode(image) 58 | 59 | min_encoding_indices = info[-1].view(-1) 60 | # print(min_encoding_indices.size()) 61 | min_encoding_indices = min_encoding_indices.cpu().numpy().tolist() 62 | 63 | if i == 0: 64 | codebook_register = list(set(min_encoding_indices)) 65 | else: 66 | codebook_register = list(set(codebook_register + min_encoding_indices)) 67 | 68 | print(len(codebook_register)) 69 | print("usage: ", 1 - len(codebook_register) / opt.codebook_size) -------------------------------------------------------------------------------- /scripts/tools/thresholds/entropy_thresholds_ffhq_train_patch-16.json: -------------------------------------------------------------------------------- 1 | {"1": 2.947292990206026e-37, "2": 2.947292990206026e-37, "3": 2.947292990206026e-37, "4": 2.947292990206026e-37, "5": 2.947292990206026e-37, "6": 2.947292990206026e-37, "7": 2.947292990206026e-37, "8": 2.947292990206026e-37, "9": 2.947292990206026e-37, "10": 2.947292990206026e-37, "11": 2.947292990206026e-37, "12": 2.947292990206026e-37, "13": 2.947292990206026e-37, "14": 2.947292990206026e-37, "15": 2.947292990206026e-37, "16": 2.947292990206026e-37, "17": 2.947292990206026e-37, "18": 2.947292990206026e-37, "19": 2.947292990206026e-37, "20": 2.947292990206026e-37, "21": 2.947292990206026e-37, "22": 2.947292990206026e-37, "23": 2.947292990206026e-37, "24": 2.947292990206026e-37, "25": 4.656185173209683e-14, "26": 1.6099059585439335e-11, "27": 4.941910436428998e-09, "28": 
5.954901212135155e-07, "29": 4.1580264223739505e-05, "30": 0.0015554522396996617, "31": 0.03318123519420624, "32": 0.12301039695739746, "33": 0.25411027669906616, "34": 0.39538219571113586, "35": 0.531684160232544, "36": 0.6324535608291626, "37": 0.6926922798156738, "38": 0.7654862403869629, "39": 0.8662536144256592, "40": 0.9582383632659912, "41": 1.0375746488571167, "42": 1.1040228605270386, "43": 1.1735278367996216, "44": 1.239598035812378, "45": 1.2988930940628052, "46": 1.3522948026657104, "47": 1.4027752876281738, "48": 1.4538252353668213, "49": 1.5014616250991821, "50": 1.545932412147522, "51": 1.587462306022644, "52": 1.6278789043426514, "53": 1.6679913997650146, "54": 1.7059961557388306, "55": 1.7421966791152954, "56": 1.7769132852554321, "57": 1.8112608194351196, "58": 1.844649314880371, "59": 1.8769294023513794, "60": 1.9083247184753418, "61": 1.93890380859375, "62": 1.9691890478134155, "63": 1.9986827373504639, "64": 2.0271928310394287, "65": 2.0552756786346436, "66": 2.083155632019043, "67": 2.110912322998047, "68": 2.1379952430725098, "69": 2.1645619869232178, "70": 2.1908888816833496, "71": 2.217161178588867, "72": 2.243135929107666, "73": 2.2687926292419434, "74": 2.29431414604187, "75": 2.319988965988159, "76": 2.345322608947754, "77": 2.3707399368286133, "78": 2.396108627319336, "79": 2.421663522720337, "80": 2.4472193717956543, "81": 2.4729485511779785, "82": 2.4989397525787354, "83": 2.5251026153564453, "84": 2.5515804290771484, "85": 2.578578233718872, "86": 2.605985164642334, "87": 2.634187698364258, "88": 2.6630964279174805, "89": 2.6928598880767822, "90": 2.7234365940093994, "91": 2.755277156829834, "92": 2.7889676094055176, "93": 2.8244919776916504, "94": 2.86238431930542, "95": 2.903606414794922, "96": 2.949753999710083, "97": 3.002513885498047, "98": 3.066530704498291, "99": 3.1541624069213867} -------------------------------------------------------------------------------- /scripts/tools/thresholds/entropy_thresholds_imagenet_train_patch-16.json: -------------------------------------------------------------------------------- 1 | {"1": 2.947292990206026e-37, "2": 2.947292990206026e-37, "3": 2.947292990206026e-37, "4": 2.947292990206026e-37, "5": 2.947292990206026e-37, "6": 2.947292990206026e-37, "7": 2.947292990206026e-37, "8": 2.947292990206026e-37, "9": 2.947292990206026e-37, "10": 2.947292990206026e-37, "11": 2.947292990206026e-37, "12": 2.947292990206026e-37, "13": 2.947292990206026e-37, "14": 2.947292990206026e-37, "15": 2.947292990206026e-37, "16": 2.947292990206026e-37, "17": 2.947292990206026e-37, "18": 2.947292990206026e-37, "19": 2.947292990206026e-37, "20": 3.2971225034060525e-14, "21": 2.5531342665030543e-11, "22": 1.3705702350819138e-08, "23": 2.3880370463302825e-06, "24": 0.0002516355598345399, "25": 0.010697992518544197, "26": 0.03487185761332512, "27": 0.11292985081672668, "28": 0.25502198934555054, "29": 0.3953641355037689, "30": 0.5240799784660339, "31": 0.6198657751083374, "32": 0.6835920214653015, "33": 0.7304143309593201, "34": 0.803577721118927, "35": 0.8802925944328308, "36": 0.9519385099411011, "37": 1.019404411315918, "38": 1.080482006072998, "39": 1.1392621994018555, "40": 1.1998910903930664, "41": 1.2569533586502075, "42": 1.3105838298797607, "43": 1.3611900806427002, "44": 1.4104721546173096, "45": 1.4596047401428223, "46": 1.5065158605575562, "47": 1.5511703491210938, "48": 1.594150424003601, "49": 1.6364712715148926, "50": 1.6777750253677368, "51": 1.7174620628356934, "52": 1.7558200359344482, "53": 1.7933151721954346, "54": 
1.8301582336425781, "55": 1.8659451007843018, "56": 1.9007411003112793, "57": 1.93480384349823, "58": 1.968318223953247, "59": 2.001042127609253, "60": 2.0330491065979004, "61": 2.0644726753234863, "62": 2.0954644680023193, "63": 2.125847816467285, "64": 2.155766725540161, "65": 2.1852800846099854, "66": 2.2144200801849365, "67": 2.243102788925171, "68": 2.2714436054229736, "69": 2.2994747161865234, "70": 2.3272018432617188, "71": 2.354642629623413, "72": 2.381889820098877, "73": 2.4089431762695312, "74": 2.435821533203125, "75": 2.4625627994537354, "76": 2.4891631603240967, "77": 2.5156636238098145, "78": 2.5421223640441895, "79": 2.568582534790039, "80": 2.5950534343719482, "81": 2.621587038040161, "82": 2.6482317447662354, "83": 2.675014019012451, "84": 2.701997756958008, "85": 2.7292490005493164, "86": 2.7568156719207764, "87": 2.7847416400909424, "88": 2.813140869140625, "89": 2.8421614170074463, "90": 2.871901512145996, "91": 2.902510166168213, "92": 2.9341444969177246, "93": 2.9671168327331543, "94": 3.0017991065979004, "95": 3.0387187004089355, "96": 3.078676223754883, "97": 3.12296199798584, "98": 3.1741182804107666, "99": 3.238837242126465} -------------------------------------------------------------------------------- /scripts/tools/thresholds/entropy_thresholds_imagenet_val_patch-16.json: -------------------------------------------------------------------------------- 1 | {"1": 2.947292990206026e-37, "2": 2.947292990206026e-37, "3": 2.947292990206026e-37, "4": 2.947292990206026e-37, "5": 2.947292990206026e-37, "6": 2.947292990206026e-37, "7": 2.947292990206026e-37, "8": 2.947292990206026e-37, "9": 2.947292990206026e-37, "10": 2.947292990206026e-37, "11": 2.947292990206026e-37, "12": 2.947292990206026e-37, "13": 2.947292990206026e-37, "14": 2.947292990206026e-37, "15": 2.947292990206026e-37, "16": 2.947292990206026e-37, "17": 2.947292990206026e-37, "18": 2.947292990206026e-37, "19": 2.947292990206026e-37, "20": 4.216345358249406e-15, "21": 3.7225322650769055e-12, "22": 2.2681099220989154e-09, "23": 6.084861752242432e-07, "24": 7.31039690435864e-05, "25": 0.003753718687221408, "26": 0.03487185761332512, "27": 0.11335359513759613, "28": 0.26401612162590027, "29": 0.40536415576934814, "30": 0.5326341390609741, "31": 0.6258067488670349, "32": 0.6868937611579895, "33": 0.7352866530418396, "34": 0.8084094524383545, "35": 0.884431004524231, "36": 0.955210268497467, "37": 1.0215826034545898, "38": 1.0819340944290161, "39": 1.1402863264083862, "40": 1.200192928314209, "41": 1.2568758726119995, "42": 1.3101763725280762, "43": 1.3604388236999512, "44": 1.4093680381774902, "45": 1.4581224918365479, "46": 1.5047249794006348, "47": 1.5491831302642822, "48": 1.5919694900512695, "49": 1.6340652704238892, "50": 1.6750011444091797, "51": 1.714430809020996, "52": 1.7526190280914307, "53": 1.7898746728897095, "54": 1.8267024755477905, "55": 1.8624517917633057, "56": 1.8970558643341064, "57": 1.9309626817703247, "58": 1.964379072189331, "59": 1.996957778930664, "60": 2.0287485122680664, "61": 2.0600745677948, "62": 2.090888500213623, "63": 2.121279716491699, "64": 2.151048421859741, "65": 2.180299758911133, "66": 2.209294080734253, "67": 2.2378087043762207, "68": 2.265988826751709, "69": 2.2937614917755127, "70": 2.321420669555664, "71": 2.348869800567627, "72": 2.3760454654693604, "73": 2.402902603149414, "74": 2.4297287464141846, "75": 2.4564146995544434, "76": 2.482914924621582, "77": 2.50938081741333, "78": 2.5357770919799805, "79": 2.5621705055236816, "80": 2.588665008544922, "81": 
2.61538028717041, "82": 2.6420035362243652, "83": 2.6689531803131104, "84": 2.6960394382476807, "85": 2.7232048511505127, "86": 2.750839948654175, "87": 2.7786669731140137, "88": 2.8069705963134766, "89": 2.835883140563965, "90": 2.8655805587768555, "91": 2.896111011505127, "92": 2.927741050720215, "93": 2.960714101791382, "94": 2.9956445693969727, "95": 3.0325424671173096, "96": 3.0726613998413086, "97": 3.1168980598449707, "98": 3.168454170227051, "99": 3.233855962753296} -------------------------------------------------------------------------------- /scripts/tools/visualize_dual_grain.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | 5 | import torch 6 | import torchvision 7 | from omegaconf import OmegaConf 8 | from tqdm import tqdm 9 | import numpy as np 10 | 11 | sys.path.append(os.getcwd()) 12 | 13 | from data.imagenet import ImageNetValidation 14 | from utils.utils import instantiate_from_config 15 | from modules.dynamic_modules.utils import draw_dual_grain_256res_color 16 | 17 | def get_parser(**parser_kwargs): 18 | parser = argparse.ArgumentParser(**parser_kwargs) 19 | parser.add_argument("--yaml_path", type=str, default="") 20 | parser.add_argument("--model_path", type=str, default="") 21 | parser.add_argument("--batch_size", type=int, default=4) 22 | parser.add_argument("--image_save_path", type=str, default="") 23 | 24 | return parser 25 | 26 | if __name__ == "__main__": 27 | parser = get_parser() 28 | opt, unknown = parser.parse_known_args() 29 | 30 | dataset = ImageNetValidation(config={"size" : 256, "is_eval" :True}) 31 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batch_size, num_workers=1) 32 | 33 | # init and save configs 34 | configs = OmegaConf.load(opt.yaml_path) 35 | # model 36 | model = instantiate_from_config(configs.model) 37 | state_dict = torch.load(opt.model_path)['state_dict'] 38 | model.load_state_dict(state_dict) 39 | model.eval().cuda() 40 | 41 | result_list = [] 42 | for i, data in tqdm(enumerate(dataloader)): 43 | images = data["image"].float().cuda() 44 | dec, diff, grain_indices, gate, _ = model(images) 45 | 46 | sequence_length = 1 * (grain_indices == 0) + 4 * (grain_indices == 1) 47 | sequence_length = sequence_length.sum(-1).sum(-1) 48 | 49 | result_list += sequence_length.cpu().numpy().tolist() 50 | 51 | grain_map = draw_dual_grain_256res_color(images=images.clone(), indices=grain_indices, scaler=0.7) 52 | torchvision.utils.save_image(grain_map, "{}/grain_images_{}.png".format(opt.image_save_path, i)) 53 | 54 | print("mean: ", np.mean(result_list)) 55 | print("variance: ", np.var(result_list)) 56 | print("max: ", max(result_list)) 57 | print("min: ", min(result_list)) --------------------------------------------------------------------------------