├── .idea ├── .gitignore ├── Asymmetric_VQGAN.iml ├── inspectionProfiles │ └── profiles_settings.xml ├── modules.xml └── vcs.xml ├── LICENSE ├── README.md ├── assets ├── a-painting-of-a-fire.png ├── a-photograph-of-a-fire.png ├── a-shirt-with-a-fire-printed-on-it.png ├── a-shirt-with-the-inscription-'fire'.png ├── a-watercolor-painting-of-a-fire.png ├── birdhouse.png ├── fire.png ├── inpainting.png ├── modelfigure.png ├── rdm-preview.jpg ├── reconstruction1.png ├── reconstruction2.png ├── results.gif ├── the-earth-is-on-fire,-oil-on-canvas.png ├── txt2img-convsample.png └── txt2img-preview.png ├── configs ├── autoencoder │ ├── autoencoder_kl_32x32x4.yaml │ ├── autoencoder_kl_32x32x4_large.yaml │ ├── autoencoder_kl_32x32x4_large2.yaml │ ├── autoencoder_kl_32x32x4_large2_train.yaml │ ├── autoencoder_kl_32x32x4_large_train.yaml │ ├── autoencoder_kl_32x32x4_train.yaml │ ├── autoencoder_kl_woc_32x32x4.yaml │ ├── autoencoder_kl_woc_32x32x4_large.yaml │ ├── autoencoder_kl_woc_32x32x4_large2.yaml │ ├── eval2_gpu.yaml │ ├── random_thick_256.yaml │ ├── v1-inpainting-inference.yaml │ └── v1-t2i-inference.yaml ├── latent-diffusion │ ├── celebahq-ldm-vq-4.yaml │ ├── cin-ldm-vq-f8.yaml │ ├── cin256-v2.yaml │ ├── ffhq-ldm-vq-4.yaml │ ├── lsun_bedrooms-ldm-vq-4.yaml │ ├── lsun_churches-ldm-kl-8.yaml │ └── txt2img-1p4B-eval.yaml └── retrieval-augmented-diffusion │ └── 768x768.yaml ├── data ├── DejaVuSans.ttf ├── example_conditioning │ ├── superresolution │ │ └── sample_0.jpg │ └── text_conditional │ │ └── sample_0.txt ├── imagenet_clsidx_to_label.txt ├── imagenet_train_hr_indices.p ├── imagenet_val_hr_indices.p ├── index_synset.yaml └── inpainting_examples │ ├── 6458524847_2f4c361183_k.png │ ├── 6458524847_2f4c361183_k_mask.png │ ├── 8399166846_f6fb4e4b8e_k.png │ ├── 8399166846_f6fb4e4b8e_k_mask.png │ ├── alex-iby-G_Pk4D9rMLs.png │ ├── alex-iby-G_Pk4D9rMLs_mask.png │ ├── bench2.png │ ├── bench2_mask.png │ ├── bertrand-gabioud-CpuFzIsHYJ0.png │ ├── bertrand-gabioud-CpuFzIsHYJ0_mask.png │ ├── billow926-12-Wc-Zgx6Y.png │ ├── billow926-12-Wc-Zgx6Y_mask.png │ ├── overture-creations-5sI6fQgYIuo.png │ ├── overture-creations-5sI6fQgYIuo_mask.png │ ├── photo-1583445095369-9c651e7e5d34.png │ └── photo-1583445095369-9c651e7e5d34_mask.png ├── demo.pkl ├── environment.yaml ├── inpaint_st.py ├── ldm ├── data │ ├── __init__.py │ ├── base.py │ ├── imagenet.py │ └── lsun.py ├── lr_scheduler.py ├── models │ ├── autoencoder.py │ ├── autoencoder_large.py │ ├── autoencoder_large2.py │ └── diffusion │ │ ├── __init__.py │ │ ├── classifier.py │ │ ├── ddim.py │ │ ├── ddpm.py │ │ └── plms.py ├── modules │ ├── attention.py │ ├── diffusionmodules │ │ ├── __init__.py │ │ ├── model.py │ │ ├── my_decoder.py │ │ ├── my_decoder_large.py │ │ ├── my_decoder_large2.py │ │ ├── openaimodel.py │ │ └── util.py │ ├── distributions │ │ ├── __init__.py │ │ └── distributions.py │ ├── ema.py │ ├── encoders │ │ ├── __init__.py │ │ └── modules.py │ ├── image_degradation │ │ ├── __init__.py │ │ ├── bsrgan.py │ │ ├── bsrgan_light.py │ │ ├── utils │ │ │ └── test.png │ │ └── utils_image.py │ ├── losses │ │ ├── __init__.py │ │ ├── contperceptual.py │ │ └── vqperceptual.py │ └── x_transformer.py └── util.py ├── main.py ├── mini.py ├── models ├── ade20k │ ├── __init__.py │ ├── base.py │ ├── color150.mat │ ├── mobilenet.py │ ├── object150_info.csv │ ├── resnet.py │ ├── segm_lib │ │ ├── nn │ │ │ ├── __init__.py │ │ │ ├── modules │ │ │ │ ├── __init__.py │ │ │ │ ├── batchnorm.py │ │ │ │ ├── comm.py │ │ │ │ ├── replicate.py │ │ │ │ ├── tests │ │ │ │ │ ├── 
test_numeric_batchnorm.py │ │ │ │ │ └── test_sync_batchnorm.py │ │ │ │ └── unittest.py │ │ │ └── parallel │ │ │ │ ├── __init__.py │ │ │ │ └── data_parallel.py │ │ └── utils │ │ │ ├── __init__.py │ │ │ ├── data │ │ │ ├── __init__.py │ │ │ ├── dataloader.py │ │ │ ├── dataset.py │ │ │ ├── distributed.py │ │ │ └── sampler.py │ │ │ └── th.py │ └── utils.py ├── first_stage_models │ ├── kl-f16 │ │ └── config.yaml │ ├── kl-f32 │ │ └── config.yaml │ ├── kl-f4 │ │ └── config.yaml │ ├── kl-f8 │ │ └── config.yaml │ ├── vq-f16 │ │ └── config.yaml │ ├── vq-f4-noattn │ │ └── config.yaml │ ├── vq-f4 │ │ └── config.yaml │ ├── vq-f8-n256 │ │ └── config.yaml │ └── vq-f8 │ │ └── config.yaml ├── ldm │ ├── bsr_sr │ │ └── config.yaml │ ├── celeba256 │ │ └── config.yaml │ ├── cin256 │ │ └── config.yaml │ ├── ffhq256 │ │ └── config.yaml │ ├── inpainting_big │ │ └── config.yaml │ ├── layout2img-openimages256 │ │ └── config.yaml │ ├── lsun_beds256 │ │ └── config.yaml │ ├── lsun_churches256 │ │ └── config.yaml │ ├── semantic_synthesis256 │ │ └── config.yaml │ ├── semantic_synthesis512 │ │ └── config.yaml │ └── text2img256 │ │ └── config.yaml └── lpips_models │ ├── alex.pth │ ├── squeeze.pth │ └── vgg.pth ├── notebook_helpers.py ├── outputs └── inpainting_results │ ├── 6458524847_2f4c361183_k.png │ ├── 8399166846_f6fb4e4b8e_k.png │ ├── alex-iby-G_Pk4D9rMLs.png │ ├── bench2.png │ ├── bertrand-gabioud-CpuFzIsHYJ0.png │ ├── billow926-12-Wc-Zgx6Y.png │ ├── overture-creations-5sI6fQgYIuo.png │ └── photo-1583445095369-9c651e7e5d34.png ├── requirements.txt ├── saicinpainting ├── __init__.py ├── evaluation │ ├── __init__.py │ ├── data.py │ ├── evaluator.py │ ├── losses │ │ ├── __init__.py │ │ ├── base_loss.py │ │ ├── fid │ │ │ ├── __init__.py │ │ │ ├── fid_score.py │ │ │ └── inception.py │ │ ├── lpips.py │ │ └── ssim.py │ ├── masks │ │ ├── README.md │ │ ├── __init__.py │ │ ├── countless │ │ │ ├── .gitignore │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ ├── countless2d.py │ │ │ ├── countless3d.py │ │ │ ├── images │ │ │ │ ├── gcim.jpg │ │ │ │ ├── gray_segmentation.png │ │ │ │ ├── segmentation.png │ │ │ │ └── sparse.png │ │ │ ├── memprof │ │ │ │ ├── countless2d_gcim_N_1000.png │ │ │ │ ├── countless2d_quick_gcim_N_1000.png │ │ │ │ ├── countless3d.png │ │ │ │ ├── countless3d_dynamic.png │ │ │ │ ├── countless3d_dynamic_generalized.png │ │ │ │ └── countless3d_generalized.png │ │ │ ├── requirements.txt │ │ │ └── test.py │ │ └── mask.py │ ├── refinement.py │ ├── utils.py │ └── vis.py ├── training │ ├── __init__.py │ ├── data │ │ ├── __init__.py │ │ ├── aug.py │ │ ├── datasets.py │ │ └── masks.py │ ├── losses │ │ ├── __init__.py │ │ ├── adversarial.py │ │ ├── constants.py │ │ ├── distance_weighting.py │ │ ├── feature_matching.py │ │ ├── perceptual.py │ │ ├── segmentation.py │ │ └── style_loss.py │ ├── modules │ │ ├── __init__.py │ │ ├── base.py │ │ ├── depthwise_sep_conv.py │ │ ├── fake_fakes.py │ │ ├── ffc.py │ │ ├── multidilated_conv.py │ │ ├── multiscale.py │ │ ├── pix2pixhd.py │ │ ├── spatial_transform.py │ │ └── squeeze_excitation.py │ ├── trainers │ │ ├── __init__.py │ │ ├── base.py │ │ └── default.py │ └── visualizers │ │ ├── __init__.py │ │ ├── base.py │ │ ├── colors.py │ │ ├── directory.py │ │ └── noop.py └── utils.py ├── scripts ├── download_first_stages.sh ├── download_models.sh ├── inpaint.py ├── inpaint_st.py ├── knn2img.py ├── latent_imagenet_diffusion.ipynb ├── sample_diffusion.py └── train_searcher.py ├── setup.py ├── taming └── modules │ └── autoencoder │ └── lpips │ └── vgg.pth ├── teaser.png ├── text2img_visual.png 
├── txt2img.py └── visual.png

/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 | # Editor-based HTTP Client requests
5 | /httpRequests/
6 | # Datasource local storage ignored files
7 | /dataSources/
8 | /dataSources.local.xml
9 |
--------------------------------------------------------------------------------
/.idea/Asymmetric_VQGAN.iml:
--------------------------------------------------------------------------------
(IDE module definition; the XML markup was not preserved in this export)
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
(IDE inspection profile; the XML markup was not preserved in this export)
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
(IDE modules index; the XML markup was not preserved in this export)
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
(IDE VCS mapping; the XML markup was not preserved in this export)
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Machine Vision and Learning Group, LMU Munich
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /assets/a-painting-of-a-fire.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/assets/a-painting-of-a-fire.png -------------------------------------------------------------------------------- /assets/a-photograph-of-a-fire.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/assets/a-photograph-of-a-fire.png -------------------------------------------------------------------------------- /assets/a-shirt-with-a-fire-printed-on-it.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/assets/a-shirt-with-a-fire-printed-on-it.png -------------------------------------------------------------------------------- /assets/a-shirt-with-the-inscription-'fire'.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/assets/a-shirt-with-the-inscription-'fire'.png -------------------------------------------------------------------------------- /assets/a-watercolor-painting-of-a-fire.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/assets/a-watercolor-painting-of-a-fire.png -------------------------------------------------------------------------------- /assets/birdhouse.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/assets/birdhouse.png -------------------------------------------------------------------------------- /assets/fire.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/assets/fire.png -------------------------------------------------------------------------------- /assets/inpainting.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/assets/inpainting.png -------------------------------------------------------------------------------- /assets/modelfigure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/assets/modelfigure.png -------------------------------------------------------------------------------- /assets/rdm-preview.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/assets/rdm-preview.jpg -------------------------------------------------------------------------------- /assets/reconstruction1.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/assets/reconstruction1.png -------------------------------------------------------------------------------- /assets/reconstruction2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/assets/reconstruction2.png -------------------------------------------------------------------------------- /assets/results.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/assets/results.gif -------------------------------------------------------------------------------- /assets/the-earth-is-on-fire,-oil-on-canvas.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/assets/the-earth-is-on-fire,-oil-on-canvas.png -------------------------------------------------------------------------------- /assets/txt2img-convsample.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/assets/txt2img-convsample.png -------------------------------------------------------------------------------- /assets/txt2img-preview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/assets/txt2img-preview.png -------------------------------------------------------------------------------- /configs/autoencoder/autoencoder_kl_32x32x4.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-6 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | # monitor: val/rec_loss 6 | # ckpt_path: /mnt/output/pre_models/stable_diff_vqauto_text.ckpt 7 | embed_dim: 4 8 | num_gpus: 8 9 | 10 | scheduler_config: 11 | target: ldm.lr_scheduler.LambdaWarmUpCosineScheduler 12 | params: 13 | verbosity_interval: 0 14 | warm_up_steps: 5000 15 | lr_start: 0.01 16 | lr_max: 1 17 | 18 | scheduler_config_d: 19 | target: ldm.lr_scheduler.LambdaWarmUpCosineScheduler_D 20 | params: 21 | verbosity_interval: 0 22 | warm_up_steps: 5000 23 | lr_start: 0.01 24 | lr_max: 1 25 | 26 | lossconfig: 27 | target: ldm.modules.losses.LPIPSWithDiscriminator 28 | params: 29 | disc_start: 50001 30 | kl_weight: 0.000001 31 | disc_weight: 0.8 32 | 33 | ddconfig: 34 | double_z: True 35 | z_channels: 4 36 | resolution: 256 37 | in_channels: 3 38 | out_ch: 3 39 | ch: 128 40 | ch_mult: [ 1,2,4,4 ] # num_down = len(ch_mult)-1 41 | num_res_blocks: 2 42 | attn_resolutions: [ ] 43 | dropout: 0.0 44 | 45 | data: 46 | target: main.DataModuleFromConfig 47 | params: 48 | batch_size: 10 49 | wrap: True 50 | train: 51 | target: ldm.data.imagenet.ImageNetSRTrain 52 | params: 53 | size: 256 54 | degradation: pil_nearest 55 | Cf_r: 0.5 56 | validation: 57 | target: ldm.data.imagenet.ImageNetsv 58 | params: 59 | indir: /mnt/output/myvalsam 60 | 61 | lightning: 62 | callbacks: 63 | image_logger: 64 | target: main.ImageLogger 65 | params: 66 | batch_frequency: 1000 67 | max_images: 8 68 | increase_log_steps: True 69 | 70 | trainer: 71 | 
benchmark: True 72 | accumulate_grad_batches: 1 73 | -------------------------------------------------------------------------------- /configs/autoencoder/autoencoder_kl_32x32x4_large.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-6 3 | target: ldm.models.autoencoder_large.AutoencoderKL 4 | params: 5 | # monitor: val/rec_loss 6 | # ckpt_path: /mnt/output/pre_models/stable_diff_vqauto_text.ckpt 7 | embed_dim: 4 8 | num_gpus: 8 9 | 10 | scheduler_config: 11 | target: ldm.lr_scheduler.LambdaWarmUpCosineScheduler 12 | params: 13 | verbosity_interval: 0 14 | warm_up_steps: 5000 15 | lr_start: 0.01 16 | lr_max: 1 17 | 18 | scheduler_config_d: 19 | target: ldm.lr_scheduler.LambdaWarmUpCosineScheduler_D 20 | params: 21 | verbosity_interval: 0 22 | warm_up_steps: 5000 23 | lr_start: 0.01 24 | lr_max: 1 25 | 26 | lossconfig: 27 | target: ldm.modules.losses.LPIPSWithDiscriminator 28 | params: 29 | disc_start: 50001 30 | kl_weight: 0.000001 31 | disc_weight: 0.8 32 | 33 | ddconfig: 34 | double_z: True 35 | z_channels: 4 36 | resolution: 256 37 | in_channels: 3 38 | out_ch: 3 39 | ch: 128 40 | ch_mult: [ 1,2,4,4 ] # num_down = len(ch_mult)-1 41 | num_res_blocks: 2 42 | attn_resolutions: [ ] 43 | dropout: 0.0 44 | 45 | data: 46 | target: main.DataModuleFromConfig 47 | params: 48 | batch_size: 4 49 | wrap: True 50 | train: 51 | target: ldm.data.imagenet.ImageNetSRTrain 52 | params: 53 | size: 256 54 | degradation: pil_nearest 55 | Cf_r: 0.5 56 | validation: 57 | target: ldm.data.imagenet.ImageNetsv 58 | params: 59 | indir: /mnt/output/myvalsam 60 | 61 | lightning: 62 | callbacks: 63 | image_logger: 64 | target: main.ImageLogger 65 | params: 66 | batch_frequency: 1000 67 | max_images: 8 68 | increase_log_steps: True 69 | 70 | trainer: 71 | benchmark: True 72 | accumulate_grad_batches: 1 73 | -------------------------------------------------------------------------------- /configs/autoencoder/autoencoder_kl_32x32x4_large2.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-6 3 | target: ldm.models.autoencoder_large2.AutoencoderKL 4 | params: 5 | # monitor: val/rec_loss 6 | # ckpt_path: /mnt/output/pre_models/stable_diff_vqauto_text.ckpt 7 | embed_dim: 4 8 | num_gpus: 8 9 | 10 | scheduler_config: 11 | target: ldm.lr_scheduler.LambdaWarmUpCosineScheduler 12 | params: 13 | verbosity_interval: 0 14 | warm_up_steps: 5000 15 | lr_start: 0.01 16 | lr_max: 1 17 | 18 | scheduler_config_d: 19 | target: ldm.lr_scheduler.LambdaWarmUpCosineScheduler_D 20 | params: 21 | verbosity_interval: 0 22 | warm_up_steps: 5000 23 | lr_start: 0.01 24 | lr_max: 1 25 | 26 | lossconfig: 27 | target: ldm.modules.losses.LPIPSWithDiscriminator 28 | params: 29 | disc_start: 50001 30 | kl_weight: 0.000001 31 | disc_weight: 0.8 32 | 33 | ddconfig: 34 | double_z: True 35 | z_channels: 4 36 | resolution: 256 37 | in_channels: 3 38 | out_ch: 3 39 | ch: 128 40 | ch_mult: [ 1,2,4,4 ] # num_down = len(ch_mult)-1 41 | num_res_blocks: 2 42 | attn_resolutions: [ ] 43 | dropout: 0.0 44 | 45 | data: 46 | target: main.DataModuleFromConfig 47 | params: 48 | batch_size: 2 49 | wrap: True 50 | train: 51 | target: ldm.data.imagenet.ImageNetSRTrain 52 | params: 53 | size: 256 54 | degradation: pil_nearest 55 | Cf_r: 0.5 56 | validation: 57 | target: ldm.data.imagenet.ImageNetsv 58 | params: 59 | indir: /mnt/output/myvalsam 60 | 61 | lightning: 62 | callbacks: 63 | image_logger: 64 | target: 
main.ImageLogger 65 | params: 66 | batch_frequency: 1000 67 | max_images: 8 68 | increase_log_steps: True 69 | 70 | trainer: 71 | benchmark: True 72 | accumulate_grad_batches: 1 73 | -------------------------------------------------------------------------------- /configs/autoencoder/autoencoder_kl_32x32x4_large2_train.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-6 3 | target: ldm.models.autoencoder_large2.AutoencoderKL 4 | params: 5 | # monitor: val/rec_loss 6 | ckpt_path: stable_vqgan.ckpt 7 | ignore_keys: ["decoder"] 8 | embed_dim: 4 9 | num_gpus: 8 10 | 11 | scheduler_config: 12 | target: ldm.lr_scheduler.LambdaWarmUpCosineScheduler 13 | params: 14 | verbosity_interval: 0 15 | warm_up_steps: 5000 16 | lr_start: 0.01 17 | lr_max: 1 18 | 19 | scheduler_config_d: 20 | target: ldm.lr_scheduler.LambdaWarmUpCosineScheduler_D 21 | params: 22 | verbosity_interval: 0 23 | warm_up_steps: 5000 24 | lr_start: 0.01 25 | lr_max: 1 26 | 27 | lossconfig: 28 | target: ldm.modules.losses.LPIPSWithDiscriminator 29 | params: 30 | disc_start: 50001 31 | kl_weight: 0.000001 32 | disc_weight: 0.8 33 | 34 | ddconfig: 35 | double_z: True 36 | z_channels: 4 37 | resolution: 256 38 | in_channels: 3 39 | out_ch: 3 40 | ch: 128 41 | ch_mult: [ 1,2,4,4 ] # num_down = len(ch_mult)-1 42 | num_res_blocks: 2 43 | attn_resolutions: [ ] 44 | dropout: 0.0 45 | 46 | data: 47 | target: main.DataModuleFromConfig 48 | params: 49 | batch_size: 2 50 | wrap: True 51 | train: 52 | target: ldm.data.imagenet.ImageNetSRTrain 53 | params: 54 | size: 256 55 | degradation: pil_nearest 56 | Cf_r: 0.5 57 | validation: 58 | target: ldm.data.imagenet.ImageNetsv 59 | params: 60 | indir: /mnt/output/myvalsam 61 | 62 | lightning: 63 | callbacks: 64 | image_logger: 65 | target: main.ImageLogger 66 | params: 67 | batch_frequency: 1000 68 | max_images: 8 69 | increase_log_steps: True 70 | 71 | trainer: 72 | benchmark: True 73 | accumulate_grad_batches: 1 74 | -------------------------------------------------------------------------------- /configs/autoencoder/autoencoder_kl_32x32x4_large_train.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-6 3 | target: ldm.models.autoencoder_large.AutoencoderKL 4 | params: 5 | # monitor: val/rec_loss 6 | ckpt_path: stable_vqgan.ckpt 7 | ignore_keys: ["decoder"] 8 | embed_dim: 4 9 | num_gpus: 8 10 | 11 | scheduler_config: 12 | target: ldm.lr_scheduler.LambdaWarmUpCosineScheduler 13 | params: 14 | verbosity_interval: 0 15 | warm_up_steps: 5000 16 | lr_start: 0.01 17 | lr_max: 1 18 | 19 | scheduler_config_d: 20 | target: ldm.lr_scheduler.LambdaWarmUpCosineScheduler_D 21 | params: 22 | verbosity_interval: 0 23 | warm_up_steps: 5000 24 | lr_start: 0.01 25 | lr_max: 1 26 | 27 | lossconfig: 28 | target: ldm.modules.losses.LPIPSWithDiscriminator 29 | params: 30 | disc_start: 50001 31 | kl_weight: 0.000001 32 | disc_weight: 0.8 33 | 34 | ddconfig: 35 | double_z: True 36 | z_channels: 4 37 | resolution: 256 38 | in_channels: 3 39 | out_ch: 3 40 | ch: 128 41 | ch_mult: [ 1,2,4,4 ] # num_down = len(ch_mult)-1 42 | num_res_blocks: 2 43 | attn_resolutions: [ ] 44 | dropout: 0.0 45 | 46 | data: 47 | target: main.DataModuleFromConfig 48 | params: 49 | batch_size: 4 50 | wrap: True 51 | train: 52 | target: ldm.data.imagenet.ImageNetSRTrain 53 | params: 54 | size: 256 55 | degradation: pil_nearest 56 | Cf_r: 0.5 57 | validation: 58 | target: ldm.data.imagenet.ImageNetsv 
59 | params: 60 | indir: /mnt/output/myvalsam 61 | 62 | lightning: 63 | callbacks: 64 | image_logger: 65 | target: main.ImageLogger 66 | params: 67 | batch_frequency: 1000 68 | max_images: 8 69 | increase_log_steps: True 70 | 71 | trainer: 72 | benchmark: True 73 | accumulate_grad_batches: 1 74 | -------------------------------------------------------------------------------- /configs/autoencoder/autoencoder_kl_32x32x4_train.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-6 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | # monitor: val/rec_loss 6 | ckpt_path: stable_vqgan.ckpt 7 | embed_dim: 4 8 | num_gpus: 8 9 | 10 | scheduler_config: 11 | target: ldm.lr_scheduler.LambdaWarmUpCosineScheduler 12 | params: 13 | verbosity_interval: 0 14 | warm_up_steps: 5000 15 | lr_start: 0.01 16 | lr_max: 1 17 | 18 | scheduler_config_d: 19 | target: ldm.lr_scheduler.LambdaWarmUpCosineScheduler_D 20 | params: 21 | verbosity_interval: 0 22 | warm_up_steps: 5000 23 | lr_start: 0.01 24 | lr_max: 1 25 | 26 | lossconfig: 27 | target: ldm.modules.losses.LPIPSWithDiscriminator 28 | params: 29 | disc_start: 50001 30 | kl_weight: 0.000001 31 | disc_weight: 0.8 32 | 33 | ddconfig: 34 | double_z: True 35 | z_channels: 4 36 | resolution: 256 37 | in_channels: 3 38 | out_ch: 3 39 | ch: 128 40 | ch_mult: [ 1,2,4,4 ] # num_down = len(ch_mult)-1 41 | num_res_blocks: 2 42 | attn_resolutions: [ ] 43 | dropout: 0.0 44 | 45 | data: 46 | target: main.DataModuleFromConfig 47 | params: 48 | batch_size: 10 49 | wrap: True 50 | train: 51 | target: ldm.data.imagenet.ImageNetSRTrain 52 | params: 53 | size: 256 54 | degradation: pil_nearest 55 | Cf_r: 0.5 56 | validation: 57 | target: ldm.data.imagenet.ImageNetsv 58 | params: 59 | indir: /mnt/output/myvalsam 60 | 61 | lightning: 62 | callbacks: 63 | image_logger: 64 | target: main.ImageLogger 65 | params: 66 | batch_frequency: 1000 67 | max_images: 8 68 | increase_log_steps: True 69 | 70 | trainer: 71 | benchmark: True 72 | accumulate_grad_batches: 1 73 | -------------------------------------------------------------------------------- /configs/autoencoder/autoencoder_kl_woc_32x32x4.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-6 3 | target: ldm.models.autoencoder.AutoencoderKL_wocondition 4 | params: 5 | # monitor: val/rec_loss 6 | # ckpt_path: /mnt/output/pre_models/stable_diff_vqauto_text.ckpt 7 | embed_dim: 4 8 | num_gpus: 8 9 | 10 | scheduler_config: 11 | target: ldm.lr_scheduler.LambdaWarmUpCosineScheduler 12 | params: 13 | verbosity_interval: 0 14 | warm_up_steps: 5000 15 | lr_start: 0.01 16 | lr_max: 1 17 | 18 | scheduler_config_d: 19 | target: ldm.lr_scheduler.LambdaWarmUpCosineScheduler_D 20 | params: 21 | verbosity_interval: 0 22 | warm_up_steps: 5000 23 | lr_start: 0.01 24 | lr_max: 1 25 | 26 | lossconfig: 27 | target: ldm.modules.losses.LPIPSWithDiscriminator 28 | params: 29 | disc_start: 50001 30 | kl_weight: 0.000001 31 | disc_weight: 0.8 32 | 33 | ddconfig: 34 | double_z: True 35 | z_channels: 4 36 | resolution: 256 37 | in_channels: 3 38 | out_ch: 3 39 | ch: 128 40 | ch_mult: [ 1,2,4,4 ] # num_down = len(ch_mult)-1 41 | num_res_blocks: 2 42 | attn_resolutions: [ ] 43 | dropout: 0.0 44 | 45 | data: 46 | target: main.DataModuleFromConfig 47 | params: 48 | batch_size: 10 49 | wrap: True 50 | train: 51 | target: ldm.data.imagenet.ImageNetSRTrain 52 | params: 53 | size: 256 54 | degradation: 
pil_nearest 55 | Cf_r: 0.5 56 | validation: 57 | target: ldm.data.imagenet.ImageNetsv 58 | params: 59 | indir: /mnt/output/myvalsam 60 | 61 | lightning: 62 | callbacks: 63 | image_logger: 64 | target: main.ImageLogger 65 | params: 66 | batch_frequency: 1000 67 | max_images: 8 68 | increase_log_steps: True 69 | 70 | trainer: 71 | benchmark: True 72 | accumulate_grad_batches: 1 73 | -------------------------------------------------------------------------------- /configs/autoencoder/autoencoder_kl_woc_32x32x4_large.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-6 3 | target: ldm.models.autoencoder_large.AutoencoderKL_wocondition 4 | params: 5 | # monitor: val/rec_loss 6 | # ckpt_path: /mnt/output/pre_models/stable_diff_vqauto_text.ckpt 7 | embed_dim: 4 8 | num_gpus: 8 9 | 10 | scheduler_config: 11 | target: ldm.lr_scheduler.LambdaWarmUpCosineScheduler 12 | params: 13 | verbosity_interval: 0 14 | warm_up_steps: 5000 15 | lr_start: 0.01 16 | lr_max: 1 17 | 18 | scheduler_config_d: 19 | target: ldm.lr_scheduler.LambdaWarmUpCosineScheduler_D 20 | params: 21 | verbosity_interval: 0 22 | warm_up_steps: 5000 23 | lr_start: 0.01 24 | lr_max: 1 25 | 26 | lossconfig: 27 | target: ldm.modules.losses.LPIPSWithDiscriminator 28 | params: 29 | disc_start: 50001 30 | kl_weight: 0.000001 31 | disc_weight: 0.8 32 | 33 | ddconfig: 34 | double_z: True 35 | z_channels: 4 36 | resolution: 256 37 | in_channels: 3 38 | out_ch: 3 39 | ch: 128 40 | ch_mult: [ 1,2,4,4 ] # num_down = len(ch_mult)-1 41 | num_res_blocks: 2 42 | attn_resolutions: [ ] 43 | dropout: 0.0 44 | 45 | data: 46 | target: main.DataModuleFromConfig 47 | params: 48 | batch_size: 4 49 | wrap: True 50 | train: 51 | target: ldm.data.imagenet.ImageNetSRTrain 52 | params: 53 | size: 256 54 | degradation: pil_nearest 55 | Cf_r: 0.5 56 | validation: 57 | target: ldm.data.imagenet.ImageNetsv 58 | params: 59 | indir: /mnt/output/myvalsam 60 | 61 | lightning: 62 | callbacks: 63 | image_logger: 64 | target: main.ImageLogger 65 | params: 66 | batch_frequency: 1000 67 | max_images: 8 68 | increase_log_steps: True 69 | 70 | trainer: 71 | benchmark: True 72 | accumulate_grad_batches: 1 73 | -------------------------------------------------------------------------------- /configs/autoencoder/autoencoder_kl_woc_32x32x4_large2.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-6 3 | target: ldm.models.autoencoder_large2.AutoencoderKL_wocondition 4 | params: 5 | # monitor: val/rec_loss 6 | # ckpt_path: /mnt/output/pre_models/stable_diff_vqauto_text.ckpt 7 | embed_dim: 4 8 | num_gpus: 8 9 | 10 | scheduler_config: 11 | target: ldm.lr_scheduler.LambdaWarmUpCosineScheduler 12 | params: 13 | verbosity_interval: 0 14 | warm_up_steps: 5000 15 | lr_start: 0.01 16 | lr_max: 1 17 | 18 | scheduler_config_d: 19 | target: ldm.lr_scheduler.LambdaWarmUpCosineScheduler_D 20 | params: 21 | verbosity_interval: 0 22 | warm_up_steps: 5000 23 | lr_start: 0.01 24 | lr_max: 1 25 | 26 | lossconfig: 27 | target: ldm.modules.losses.LPIPSWithDiscriminator 28 | params: 29 | disc_start: 50001 30 | kl_weight: 0.000001 31 | disc_weight: 0.8 32 | 33 | ddconfig: 34 | double_z: True 35 | z_channels: 4 36 | resolution: 256 37 | in_channels: 3 38 | out_ch: 3 39 | ch: 128 40 | ch_mult: [ 1,2,4,4 ] # num_down = len(ch_mult)-1 41 | num_res_blocks: 2 42 | attn_resolutions: [ ] 43 | dropout: 0.0 44 | 45 | data: 46 | target: main.DataModuleFromConfig 47 | 
params: 48 | batch_size: 2 49 | wrap: True 50 | train: 51 | target: ldm.data.imagenet.ImageNetSRTrain 52 | params: 53 | size: 256 54 | degradation: pil_nearest 55 | Cf_r: 0.5 56 | validation: 57 | target: ldm.data.imagenet.ImageNetsv 58 | params: 59 | indir: /mnt/output/myvalsam 60 | 61 | lightning: 62 | callbacks: 63 | image_logger: 64 | target: main.ImageLogger 65 | params: 66 | batch_frequency: 1000 67 | max_images: 8 68 | increase_log_steps: True 69 | 70 | trainer: 71 | benchmark: True 72 | accumulate_grad_batches: 1 73 | -------------------------------------------------------------------------------- /configs/autoencoder/eval2_gpu.yaml: -------------------------------------------------------------------------------- 1 | evaluator_kwargs: 2 | batch_size: 8 3 | 4 | dataset_kwargs: 5 | img_suffix: .png 6 | inpainted_suffix: .png -------------------------------------------------------------------------------- /configs/autoencoder/random_thick_256.yaml: -------------------------------------------------------------------------------- 1 | generator_kind: random 2 | 3 | mask_generator_kwargs: 4 | irregular_proba: 1 5 | irregular_kwargs: 6 | min_times: 1 7 | max_times: 5 8 | max_width: 100 9 | max_angle: 4 10 | max_len: 200 11 | 12 | box_proba: 0.3 13 | box_kwargs: 14 | margin: 10 15 | bbox_min_size: 30 16 | bbox_max_size: 150 17 | max_times: 3 18 | min_times: 1 19 | 20 | segm_proba: 0 21 | squares_proba: 0 22 | 23 | variants_n: 5 24 | 25 | max_masks_per_image: 1 26 | 27 | cropping: 28 | out_min_size: 256 29 | handle_small_mode: upscale 30 | out_square_crop: True 31 | crop_min_overlap: 1 32 | 33 | max_tamper_area: 0.5 34 | -------------------------------------------------------------------------------- /configs/autoencoder/v1-inpainting-inference.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 7.5e-05 3 | target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.0120 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: "jpg" 11 | cond_stage_key: "txt" 12 | image_size: 64 13 | channels: 4 14 | cond_stage_trainable: false # Note: different from the one we trained before 15 | conditioning_key: hybrid # important 16 | # conditioning_key: concat 17 | monitor: val/loss_simple_ema 18 | scale_factor: 0.18215 19 | finetune_keys: null 20 | 21 | scheduler_config: # 10000 warmup steps 22 | target: ldm.lr_scheduler.LambdaLinearScheduler 23 | params: 24 | warm_up_steps: [ 2500 ] # NOTE for resuming. use 10000 if starting from scratch 25 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases 26 | f_start: [ 1.e-6 ] 27 | f_max: [ 1. ] 28 | f_min: [ 1. 
] 29 | 30 | unet_config: 31 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 32 | params: 33 | image_size: 32 # unused 34 | in_channels: 9 # 4 data + 4 downscaled image + 1 mask 35 | out_channels: 4 36 | model_channels: 320 37 | attention_resolutions: [ 4, 2, 1 ] 38 | num_res_blocks: 2 39 | channel_mult: [ 1, 2, 4, 4 ] 40 | num_heads: 8 41 | use_spatial_transformer: True 42 | transformer_depth: 1 43 | context_dim: 768 44 | use_checkpoint: True 45 | legacy: False 46 | 47 | first_stage_config: 48 | target: ldm.models.autoencoder.AutoencoderKL 49 | params: 50 | embed_dim: 4 51 | monitor: val/rec_loss 52 | ddconfig: 53 | double_z: true 54 | z_channels: 4 55 | resolution: 256 56 | in_channels: 3 57 | out_ch: 3 58 | ch: 128 59 | ch_mult: 60 | - 1 61 | - 2 62 | - 4 63 | - 4 64 | num_res_blocks: 2 65 | attn_resolutions: [] 66 | dropout: 0.0 67 | lossconfig: 68 | target: torch.nn.Identity 69 | 70 | cond_stage_config: 71 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder 72 | 73 | -------------------------------------------------------------------------------- /configs/autoencoder/v1-t2i-inference.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-04 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.0120 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: "jpg" 11 | cond_stage_key: "txt" 12 | image_size: 64 13 | channels: 4 14 | cond_stage_trainable: false # Note: different from the one we trained before 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | scale_factor: 0.18215 18 | use_ema: False 19 | 20 | scheduler_config: # 10000 warmup steps 21 | target: ldm.lr_scheduler.LambdaLinearScheduler 22 | params: 23 | warm_up_steps: [ 10000 ] 24 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases 25 | f_start: [ 1.e-6 ] 26 | f_max: [ 1. ] 27 | f_min: [ 1. 
] 28 | 29 | unet_config: 30 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 31 | params: 32 | image_size: 32 # unused 33 | in_channels: 4 34 | out_channels: 4 35 | model_channels: 320 36 | attention_resolutions: [ 4, 2, 1 ] 37 | num_res_blocks: 2 38 | channel_mult: [ 1, 2, 4, 4 ] 39 | num_heads: 8 40 | use_spatial_transformer: True 41 | transformer_depth: 1 42 | context_dim: 768 43 | use_checkpoint: True 44 | legacy: False 45 | 46 | first_stage_config: 47 | target: ldm.models.autoencoder.AutoencoderKL 48 | params: 49 | embed_dim: 4 50 | monitor: val/rec_loss 51 | ddconfig: 52 | double_z: true 53 | z_channels: 4 54 | resolution: 256 55 | in_channels: 3 56 | out_ch: 3 57 | ch: 128 58 | ch_mult: 59 | - 1 60 | - 2 61 | - 4 62 | - 4 63 | num_res_blocks: 2 64 | attn_resolutions: [] 65 | dropout: 0.0 66 | lossconfig: 67 | target: torch.nn.Identity 68 | 69 | cond_stage_config: 70 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder 71 | -------------------------------------------------------------------------------- /configs/latent-diffusion/celebahq-ldm-vq-4.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 2.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | image_size: 64 12 | channels: 3 13 | monitor: val/loss_simple_ema 14 | 15 | unet_config: 16 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 17 | params: 18 | image_size: 64 19 | in_channels: 3 20 | out_channels: 3 21 | model_channels: 224 22 | attention_resolutions: 23 | # note: this isn\t actually the resolution but 24 | # the downsampling factor, i.e. 
this corresnponds to 25 | # attention on spatial resolution 8,16,32, as the 26 | # spatial reolution of the latents is 64 for f4 27 | - 8 28 | - 4 29 | - 2 30 | num_res_blocks: 2 31 | channel_mult: 32 | - 1 33 | - 2 34 | - 3 35 | - 4 36 | num_head_channels: 32 37 | first_stage_config: 38 | target: ldm.models.autoencoder.VQModelInterface 39 | params: 40 | embed_dim: 3 41 | n_embed: 8192 42 | ckpt_path: models/first_stage_models/vq-f4/model.ckpt 43 | ddconfig: 44 | double_z: false 45 | z_channels: 3 46 | resolution: 256 47 | in_channels: 3 48 | out_ch: 3 49 | ch: 128 50 | ch_mult: 51 | - 1 52 | - 2 53 | - 4 54 | num_res_blocks: 2 55 | attn_resolutions: [] 56 | dropout: 0.0 57 | lossconfig: 58 | target: torch.nn.Identity 59 | cond_stage_config: __is_unconditional__ 60 | data: 61 | target: main.DataModuleFromConfig 62 | params: 63 | batch_size: 48 64 | num_workers: 5 65 | wrap: false 66 | train: 67 | target: taming.data.faceshq.CelebAHQTrain 68 | params: 69 | size: 256 70 | validation: 71 | target: taming.data.faceshq.CelebAHQValidation 72 | params: 73 | size: 256 74 | 75 | 76 | lightning: 77 | callbacks: 78 | image_logger: 79 | target: main.ImageLogger 80 | params: 81 | batch_frequency: 5000 82 | max_images: 8 83 | increase_log_steps: False 84 | 85 | trainer: 86 | benchmark: True -------------------------------------------------------------------------------- /configs/latent-diffusion/cin-ldm-vq-f8.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: class_label 12 | image_size: 32 13 | channels: 4 14 | cond_stage_trainable: true 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | unet_config: 18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 19 | params: 20 | image_size: 32 21 | in_channels: 4 22 | out_channels: 4 23 | model_channels: 256 24 | attention_resolutions: 25 | #note: this isn\t actually the resolution but 26 | # the downsampling factor, i.e. 
this corresnponds to 27 | # attention on spatial resolution 8,16,32, as the 28 | # spatial reolution of the latents is 32 for f8 29 | - 4 30 | - 2 31 | - 1 32 | num_res_blocks: 2 33 | channel_mult: 34 | - 1 35 | - 2 36 | - 4 37 | num_head_channels: 32 38 | use_spatial_transformer: true 39 | transformer_depth: 1 40 | context_dim: 512 41 | first_stage_config: 42 | target: ldm.models.autoencoder.VQModelInterface 43 | params: 44 | embed_dim: 4 45 | n_embed: 16384 46 | ckpt_path: configs/first_stage_models/vq-f8/model.yaml 47 | ddconfig: 48 | double_z: false 49 | z_channels: 4 50 | resolution: 256 51 | in_channels: 3 52 | out_ch: 3 53 | ch: 128 54 | ch_mult: 55 | - 1 56 | - 2 57 | - 2 58 | - 4 59 | num_res_blocks: 2 60 | attn_resolutions: 61 | - 32 62 | dropout: 0.0 63 | lossconfig: 64 | target: torch.nn.Identity 65 | cond_stage_config: 66 | target: ldm.modules.encoders.modules.ClassEmbedder 67 | params: 68 | embed_dim: 512 69 | key: class_label 70 | data: 71 | target: main.DataModuleFromConfig 72 | params: 73 | batch_size: 64 74 | num_workers: 12 75 | wrap: false 76 | train: 77 | target: ldm.data.imagenet.ImageNetTrain 78 | params: 79 | config: 80 | size: 256 81 | validation: 82 | target: ldm.data.imagenet.ImageNetValidation 83 | params: 84 | config: 85 | size: 256 86 | 87 | 88 | lightning: 89 | callbacks: 90 | image_logger: 91 | target: main.ImageLogger 92 | params: 93 | batch_frequency: 5000 94 | max_images: 8 95 | increase_log_steps: False 96 | 97 | trainer: 98 | benchmark: True -------------------------------------------------------------------------------- /configs/latent-diffusion/cin256-v2.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 0.0001 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: class_label 12 | image_size: 64 13 | channels: 3 14 | cond_stage_trainable: true 15 | conditioning_key: crossattn 16 | monitor: val/loss 17 | use_ema: False 18 | 19 | unet_config: 20 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 21 | params: 22 | image_size: 64 23 | in_channels: 3 24 | out_channels: 3 25 | model_channels: 192 26 | attention_resolutions: 27 | - 8 28 | - 4 29 | - 2 30 | num_res_blocks: 2 31 | channel_mult: 32 | - 1 33 | - 2 34 | - 3 35 | - 5 36 | num_heads: 1 37 | use_spatial_transformer: true 38 | transformer_depth: 1 39 | context_dim: 512 40 | 41 | first_stage_config: 42 | target: ldm.models.autoencoder.VQModelInterface 43 | params: 44 | embed_dim: 3 45 | n_embed: 8192 46 | ddconfig: 47 | double_z: false 48 | z_channels: 3 49 | resolution: 256 50 | in_channels: 3 51 | out_ch: 3 52 | ch: 128 53 | ch_mult: 54 | - 1 55 | - 2 56 | - 4 57 | num_res_blocks: 2 58 | attn_resolutions: [] 59 | dropout: 0.0 60 | lossconfig: 61 | target: torch.nn.Identity 62 | 63 | cond_stage_config: 64 | target: ldm.modules.encoders.modules.ClassEmbedder 65 | params: 66 | n_classes: 1001 67 | embed_dim: 512 68 | key: class_label 69 | -------------------------------------------------------------------------------- /configs/latent-diffusion/ffhq-ldm-vq-4.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 2.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | 
timesteps: 1000 10 | first_stage_key: image 11 | image_size: 64 12 | channels: 3 13 | monitor: val/loss_simple_ema 14 | unet_config: 15 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 16 | params: 17 | image_size: 64 18 | in_channels: 3 19 | out_channels: 3 20 | model_channels: 224 21 | attention_resolutions: 22 | # note: this isn\t actually the resolution but 23 | # the downsampling factor, i.e. this corresnponds to 24 | # attention on spatial resolution 8,16,32, as the 25 | # spatial reolution of the latents is 64 for f4 26 | - 8 27 | - 4 28 | - 2 29 | num_res_blocks: 2 30 | channel_mult: 31 | - 1 32 | - 2 33 | - 3 34 | - 4 35 | num_head_channels: 32 36 | first_stage_config: 37 | target: ldm.models.autoencoder.VQModelInterface 38 | params: 39 | embed_dim: 3 40 | n_embed: 8192 41 | ckpt_path: configs/first_stage_models/vq-f4/model.yaml 42 | ddconfig: 43 | double_z: false 44 | z_channels: 3 45 | resolution: 256 46 | in_channels: 3 47 | out_ch: 3 48 | ch: 128 49 | ch_mult: 50 | - 1 51 | - 2 52 | - 4 53 | num_res_blocks: 2 54 | attn_resolutions: [] 55 | dropout: 0.0 56 | lossconfig: 57 | target: torch.nn.Identity 58 | cond_stage_config: __is_unconditional__ 59 | data: 60 | target: main.DataModuleFromConfig 61 | params: 62 | batch_size: 42 63 | num_workers: 5 64 | wrap: false 65 | train: 66 | target: taming.data.faceshq.FFHQTrain 67 | params: 68 | size: 256 69 | validation: 70 | target: taming.data.faceshq.FFHQValidation 71 | params: 72 | size: 256 73 | 74 | 75 | lightning: 76 | callbacks: 77 | image_logger: 78 | target: main.ImageLogger 79 | params: 80 | batch_frequency: 5000 81 | max_images: 8 82 | increase_log_steps: False 83 | 84 | trainer: 85 | benchmark: True -------------------------------------------------------------------------------- /configs/latent-diffusion/lsun_bedrooms-ldm-vq-4.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 2.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | image_size: 64 12 | channels: 3 13 | monitor: val/loss_simple_ema 14 | unet_config: 15 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 16 | params: 17 | image_size: 64 18 | in_channels: 3 19 | out_channels: 3 20 | model_channels: 224 21 | attention_resolutions: 22 | # note: this isn\t actually the resolution but 23 | # the downsampling factor, i.e. 
this corresnponds to 24 | # attention on spatial resolution 8,16,32, as the 25 | # spatial reolution of the latents is 64 for f4 26 | - 8 27 | - 4 28 | - 2 29 | num_res_blocks: 2 30 | channel_mult: 31 | - 1 32 | - 2 33 | - 3 34 | - 4 35 | num_head_channels: 32 36 | first_stage_config: 37 | target: ldm.models.autoencoder.VQModelInterface 38 | params: 39 | ckpt_path: configs/first_stage_models/vq-f4/model.yaml 40 | embed_dim: 3 41 | n_embed: 8192 42 | ddconfig: 43 | double_z: false 44 | z_channels: 3 45 | resolution: 256 46 | in_channels: 3 47 | out_ch: 3 48 | ch: 128 49 | ch_mult: 50 | - 1 51 | - 2 52 | - 4 53 | num_res_blocks: 2 54 | attn_resolutions: [] 55 | dropout: 0.0 56 | lossconfig: 57 | target: torch.nn.Identity 58 | cond_stage_config: __is_unconditional__ 59 | data: 60 | target: main.DataModuleFromConfig 61 | params: 62 | batch_size: 48 63 | num_workers: 5 64 | wrap: false 65 | train: 66 | target: ldm.data.lsun.LSUNBedroomsTrain 67 | params: 68 | size: 256 69 | validation: 70 | target: ldm.data.lsun.LSUNBedroomsValidation 71 | params: 72 | size: 256 73 | 74 | 75 | lightning: 76 | callbacks: 77 | image_logger: 78 | target: main.ImageLogger 79 | params: 80 | batch_frequency: 5000 81 | max_images: 8 82 | increase_log_steps: False 83 | 84 | trainer: 85 | benchmark: True -------------------------------------------------------------------------------- /configs/latent-diffusion/lsun_churches-ldm-kl-8.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 5.0e-5 # set to target_lr by starting main.py with '--scale_lr False' 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0155 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | loss_type: l1 11 | first_stage_key: "image" 12 | cond_stage_key: "image" 13 | image_size: 32 14 | channels: 4 15 | cond_stage_trainable: False 16 | concat_mode: False 17 | scale_by_std: True 18 | monitor: 'val/loss_simple_ema' 19 | 20 | scheduler_config: # 10000 warmup steps 21 | target: ldm.lr_scheduler.LambdaLinearScheduler 22 | params: 23 | warm_up_steps: [10000] 24 | cycle_lengths: [10000000000000] 25 | f_start: [1.e-6] 26 | f_max: [1.] 27 | f_min: [ 1.] 
28 | 29 | unet_config: 30 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 31 | params: 32 | image_size: 32 33 | in_channels: 4 34 | out_channels: 4 35 | model_channels: 192 36 | attention_resolutions: [ 1, 2, 4, 8 ] # 32, 16, 8, 4 37 | num_res_blocks: 2 38 | channel_mult: [ 1,2,2,4,4 ] # 32, 16, 8, 4, 2 39 | num_heads: 8 40 | use_scale_shift_norm: True 41 | resblock_updown: True 42 | 43 | first_stage_config: 44 | target: ldm.models.autoencoder.AutoencoderKL 45 | params: 46 | embed_dim: 4 47 | monitor: "val/rec_loss" 48 | ckpt_path: "models/first_stage_models/kl-f8/model.ckpt" 49 | ddconfig: 50 | double_z: True 51 | z_channels: 4 52 | resolution: 256 53 | in_channels: 3 54 | out_ch: 3 55 | ch: 128 56 | ch_mult: [ 1,2,4,4 ] # num_down = len(ch_mult)-1 57 | num_res_blocks: 2 58 | attn_resolutions: [ ] 59 | dropout: 0.0 60 | lossconfig: 61 | target: torch.nn.Identity 62 | 63 | cond_stage_config: "__is_unconditional__" 64 | 65 | data: 66 | target: main.DataModuleFromConfig 67 | params: 68 | batch_size: 96 69 | num_workers: 5 70 | wrap: False 71 | train: 72 | target: ldm.data.lsun.LSUNChurchesTrain 73 | params: 74 | size: 256 75 | validation: 76 | target: ldm.data.lsun.LSUNChurchesValidation 77 | params: 78 | size: 256 79 | 80 | lightning: 81 | callbacks: 82 | image_logger: 83 | target: main.ImageLogger 84 | params: 85 | batch_frequency: 5000 86 | max_images: 8 87 | increase_log_steps: False 88 | 89 | 90 | trainer: 91 | benchmark: True -------------------------------------------------------------------------------- /configs/latent-diffusion/txt2img-1p4B-eval.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 5.0e-05 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.012 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: caption 12 | image_size: 32 13 | channels: 4 14 | cond_stage_trainable: true 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | scale_factor: 0.18215 18 | use_ema: False 19 | 20 | unet_config: 21 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 22 | params: 23 | image_size: 32 24 | in_channels: 4 25 | out_channels: 4 26 | model_channels: 320 27 | attention_resolutions: 28 | - 4 29 | - 2 30 | - 1 31 | num_res_blocks: 2 32 | channel_mult: 33 | - 1 34 | - 2 35 | - 4 36 | - 4 37 | num_heads: 8 38 | use_spatial_transformer: true 39 | transformer_depth: 1 40 | context_dim: 1280 41 | use_checkpoint: true 42 | legacy: False 43 | 44 | first_stage_config: 45 | target: ldm.models.autoencoder.AutoencoderKL 46 | params: 47 | embed_dim: 4 48 | monitor: val/rec_loss 49 | ddconfig: 50 | double_z: true 51 | z_channels: 4 52 | resolution: 256 53 | in_channels: 3 54 | out_ch: 3 55 | ch: 128 56 | ch_mult: 57 | - 1 58 | - 2 59 | - 4 60 | - 4 61 | num_res_blocks: 2 62 | attn_resolutions: [] 63 | dropout: 0.0 64 | lossconfig: 65 | target: torch.nn.Identity 66 | 67 | cond_stage_config: 68 | target: ldm.modules.encoders.modules.BERTEmbedder 69 | params: 70 | n_embed: 1280 71 | n_layer: 32 72 | -------------------------------------------------------------------------------- /configs/retrieval-augmented-diffusion/768x768.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 0.0001 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.015 7 | 
num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: jpg 11 | cond_stage_key: nix 12 | image_size: 48 13 | channels: 16 14 | cond_stage_trainable: false 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | scale_by_std: false 18 | scale_factor: 0.22765929 19 | unet_config: 20 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 21 | params: 22 | image_size: 48 23 | in_channels: 16 24 | out_channels: 16 25 | model_channels: 448 26 | attention_resolutions: 27 | - 4 28 | - 2 29 | - 1 30 | num_res_blocks: 2 31 | channel_mult: 32 | - 1 33 | - 2 34 | - 3 35 | - 4 36 | use_scale_shift_norm: false 37 | resblock_updown: false 38 | num_head_channels: 32 39 | use_spatial_transformer: true 40 | transformer_depth: 1 41 | context_dim: 768 42 | use_checkpoint: true 43 | first_stage_config: 44 | target: ldm.models.autoencoder.AutoencoderKL 45 | params: 46 | monitor: val/rec_loss 47 | embed_dim: 16 48 | ddconfig: 49 | double_z: true 50 | z_channels: 16 51 | resolution: 256 52 | in_channels: 3 53 | out_ch: 3 54 | ch: 128 55 | ch_mult: 56 | - 1 57 | - 1 58 | - 2 59 | - 2 60 | - 4 61 | num_res_blocks: 2 62 | attn_resolutions: 63 | - 16 64 | dropout: 0.0 65 | lossconfig: 66 | target: torch.nn.Identity 67 | cond_stage_config: 68 | target: torch.nn.Identity -------------------------------------------------------------------------------- /data/DejaVuSans.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/data/DejaVuSans.ttf -------------------------------------------------------------------------------- /data/example_conditioning/superresolution/sample_0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/data/example_conditioning/superresolution/sample_0.jpg -------------------------------------------------------------------------------- /data/example_conditioning/text_conditional/sample_0.txt: -------------------------------------------------------------------------------- 1 | A basket of cerries 2 | -------------------------------------------------------------------------------- /data/imagenet_train_hr_indices.p: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/data/imagenet_train_hr_indices.p -------------------------------------------------------------------------------- /data/imagenet_val_hr_indices.p: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/data/imagenet_val_hr_indices.p -------------------------------------------------------------------------------- /data/inpainting_examples/6458524847_2f4c361183_k.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/data/inpainting_examples/6458524847_2f4c361183_k.png -------------------------------------------------------------------------------- /data/inpainting_examples/6458524847_2f4c361183_k_mask.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/data/inpainting_examples/6458524847_2f4c361183_k_mask.png -------------------------------------------------------------------------------- /data/inpainting_examples/8399166846_f6fb4e4b8e_k.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/data/inpainting_examples/8399166846_f6fb4e4b8e_k.png -------------------------------------------------------------------------------- /data/inpainting_examples/8399166846_f6fb4e4b8e_k_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/data/inpainting_examples/8399166846_f6fb4e4b8e_k_mask.png -------------------------------------------------------------------------------- /data/inpainting_examples/alex-iby-G_Pk4D9rMLs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/data/inpainting_examples/alex-iby-G_Pk4D9rMLs.png -------------------------------------------------------------------------------- /data/inpainting_examples/alex-iby-G_Pk4D9rMLs_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/data/inpainting_examples/alex-iby-G_Pk4D9rMLs_mask.png -------------------------------------------------------------------------------- /data/inpainting_examples/bench2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/data/inpainting_examples/bench2.png -------------------------------------------------------------------------------- /data/inpainting_examples/bench2_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/data/inpainting_examples/bench2_mask.png -------------------------------------------------------------------------------- /data/inpainting_examples/bertrand-gabioud-CpuFzIsHYJ0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/data/inpainting_examples/bertrand-gabioud-CpuFzIsHYJ0.png -------------------------------------------------------------------------------- /data/inpainting_examples/bertrand-gabioud-CpuFzIsHYJ0_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/data/inpainting_examples/bertrand-gabioud-CpuFzIsHYJ0_mask.png -------------------------------------------------------------------------------- /data/inpainting_examples/billow926-12-Wc-Zgx6Y.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/data/inpainting_examples/billow926-12-Wc-Zgx6Y.png 
-------------------------------------------------------------------------------- /data/inpainting_examples/billow926-12-Wc-Zgx6Y_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/data/inpainting_examples/billow926-12-Wc-Zgx6Y_mask.png -------------------------------------------------------------------------------- /data/inpainting_examples/overture-creations-5sI6fQgYIuo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png -------------------------------------------------------------------------------- /data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png -------------------------------------------------------------------------------- /data/inpainting_examples/photo-1583445095369-9c651e7e5d34.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/data/inpainting_examples/photo-1583445095369-9c651e7e5d34.png -------------------------------------------------------------------------------- /data/inpainting_examples/photo-1583445095369-9c651e7e5d34_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/data/inpainting_examples/photo-1583445095369-9c651e7e5d34_mask.png -------------------------------------------------------------------------------- /demo.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/demo.pkl -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | #name: ldm 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - python=3.8.5 7 | - pip=20.3 8 | - cudatoolkit=11.0 9 | - pytorch=1.7.0 10 | - torchvision=0.8.1 11 | - numpy=1.19.2 12 | - pip: 13 | - albumentations==0.4.3 14 | - easydict==1.10 15 | - scikit-learn==1.2.0 16 | - opencv-python==4.1.2.30 17 | - pudb==2019.2 18 | - imageio==2.9.0 19 | - imageio-ffmpeg==0.4.2 20 | - pytorch-lightning==1.4.2 21 | - omegaconf==2.1.1 22 | - test-tube>=0.7.5 23 | - streamlit>=0.73.1 24 | - einops==0.3.0 25 | - torch-fidelity==0.3.0 26 | - transformers==4.6.0 27 | - torchmetrics==0.6 28 | - academictorrents==2.3.3 29 | # - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers 30 | # - -e git+https://github.com/openai/CLIP.git@main#egg=clip 31 | # - -e . 
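The environment.yaml above pins exact versions for the core dependencies (python 3.8.5, pytorch 1.7.0, pytorch-lightning 1.4.2, omegaconf 2.1.1, and so on). Below is a minimal sketch, assuming the environment has already been created from this file, of checking a few of those pip pins at runtime; the helper is illustrative only and is not part of the repository.

# Illustrative check of a subset of the pins in environment.yaml (not a repo file).
from importlib.metadata import PackageNotFoundError, version

# Versions copied from the pip section of environment.yaml above; extend as needed.
PINNED = {
    "albumentations": "0.4.3",
    "pytorch-lightning": "1.4.2",
    "omegaconf": "2.1.1",
    "einops": "0.3.0",
    "transformers": "4.6.0",
    "torch-fidelity": "0.3.0",
}

for name, expected in PINNED.items():
    try:
        installed = version(name)
    except PackageNotFoundError:
        print(f"{name}: not installed (expected {expected})")
        continue
    note = "ok" if installed == expected else f"differs from pinned {expected}"
    print(f"{name}: {installed} ({note})")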
-------------------------------------------------------------------------------- /ldm/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/ldm/data/__init__.py -------------------------------------------------------------------------------- /ldm/data/base.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset 3 | 4 | 5 | class Txt2ImgIterableBaseDataset(IterableDataset): 6 | ''' 7 | Define an interface to make the IterableDatasets for text2img data chainable 8 | ''' 9 | def __init__(self, num_records=0, valid_ids=None, size=256): 10 | super().__init__() 11 | self.num_records = num_records 12 | self.valid_ids = valid_ids 13 | self.sample_ids = valid_ids 14 | self.size = size 15 | 16 | print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.') 17 | 18 | def __len__(self): 19 | return self.num_records 20 | 21 | @abstractmethod 22 | def __iter__(self): 23 | pass -------------------------------------------------------------------------------- /ldm/data/lsun.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import PIL 4 | from PIL import Image 5 | from torch.utils.data import Dataset 6 | from torchvision import transforms 7 | 8 | 9 | class LSUNBase(Dataset): 10 | def __init__(self, 11 | txt_file, 12 | data_root, 13 | size=None, 14 | interpolation="bicubic", 15 | flip_p=0.5 16 | ): 17 | self.data_paths = txt_file 18 | self.data_root = data_root 19 | with open(self.data_paths, "r") as f: 20 | self.image_paths = f.read().splitlines() 21 | self._length = len(self.image_paths) 22 | self.labels = { 23 | "relative_file_path_": [l for l in self.image_paths], 24 | "file_path_": [os.path.join(self.data_root, l) 25 | for l in self.image_paths], 26 | } 27 | 28 | self.size = size 29 | self.interpolation = {"linear": PIL.Image.LINEAR, 30 | "bilinear": PIL.Image.BILINEAR, 31 | "bicubic": PIL.Image.BICUBIC, 32 | "lanczos": PIL.Image.LANCZOS, 33 | }[interpolation] 34 | self.flip = transforms.RandomHorizontalFlip(p=flip_p) 35 | 36 | def __len__(self): 37 | return self._length 38 | 39 | def __getitem__(self, i): 40 | example = dict((k, self.labels[k][i]) for k in self.labels) 41 | image = Image.open(example["file_path_"]) 42 | if not image.mode == "RGB": 43 | image = image.convert("RGB") 44 | 45 | # default to score-sde preprocessing 46 | img = np.array(image).astype(np.uint8) 47 | crop = min(img.shape[0], img.shape[1]) 48 | h, w, = img.shape[0], img.shape[1] 49 | img = img[(h - crop) // 2:(h + crop) // 2, 50 | (w - crop) // 2:(w + crop) // 2] 51 | 52 | image = Image.fromarray(img) 53 | if self.size is not None: 54 | image = image.resize((self.size, self.size), resample=self.interpolation) 55 | 56 | image = self.flip(image) 57 | image = np.array(image).astype(np.uint8) 58 | example["image"] = (image / 127.5 - 1.0).astype(np.float32) 59 | return example 60 | 61 | 62 | class LSUNChurchesTrain(LSUNBase): 63 | def __init__(self, **kwargs): 64 | super().__init__(txt_file="data/lsun/church_outdoor_train.txt", data_root="data/lsun/churches", **kwargs) 65 | 66 | 67 | class LSUNChurchesValidation(LSUNBase): 68 | def __init__(self, flip_p=0., **kwargs): 69 | super().__init__(txt_file="data/lsun/church_outdoor_val.txt", 
data_root="data/lsun/churches", 70 | flip_p=flip_p, **kwargs) 71 | 72 | 73 | class LSUNBedroomsTrain(LSUNBase): 74 | def __init__(self, **kwargs): 75 | super().__init__(txt_file="data/lsun/bedrooms_train.txt", data_root="data/lsun/bedrooms", **kwargs) 76 | 77 | 78 | class LSUNBedroomsValidation(LSUNBase): 79 | def __init__(self, flip_p=0.0, **kwargs): 80 | super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms", 81 | flip_p=flip_p, **kwargs) 82 | 83 | 84 | class LSUNCatsTrain(LSUNBase): 85 | def __init__(self, **kwargs): 86 | super().__init__(txt_file="data/lsun/cat_train.txt", data_root="data/lsun/cats", **kwargs) 87 | 88 | 89 | class LSUNCatsValidation(LSUNBase): 90 | def __init__(self, flip_p=0., **kwargs): 91 | super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats", 92 | flip_p=flip_p, **kwargs) 93 | -------------------------------------------------------------------------------- /ldm/models/diffusion/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/ldm/models/diffusion/__init__.py -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/ldm/modules/diffusionmodules/__init__.py -------------------------------------------------------------------------------- /ldm/modules/distributions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/ldm/modules/distributions/__init__.py -------------------------------------------------------------------------------- /ldm/modules/distributions/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 34 | 35 | def sample(self): 36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) 37 | return x 38 | 39 | def kl(self, other=None): 40 | if self.deterministic: 41 | return torch.Tensor([0.]) 42 | else: 43 | if other is None: 44 | return 0.5 * torch.sum(torch.pow(self.mean, 2) 45 | + self.var - 1.0 - self.logvar, 46 | dim=[1, 2, 3]) 47 | else: 48 | return 0.5 * torch.sum( 49 | torch.pow(self.mean - other.mean, 2) / other.var 50 | + self.var / other.var - 1.0 - self.logvar + 
other.logvar, 51 | dim=[1, 2, 3]) 52 | 53 | def nll(self, sample, dims=[1,2,3]): 54 | if self.deterministic: 55 | return torch.Tensor([0.]) 56 | logtwopi = np.log(2.0 * np.pi) 57 | return 0.5 * torch.sum( 58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 59 | dim=dims) 60 | 61 | def mode(self): 62 | return self.mean 63 | 64 | 65 | def normal_kl(mean1, logvar1, mean2, logvar2): 66 | """ 67 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 68 | Compute the KL divergence between two gaussians. 69 | Shapes are automatically broadcasted, so batches can be compared to 70 | scalars, among other use cases. 71 | """ 72 | tensor = None 73 | for obj in (mean1, logvar1, mean2, logvar2): 74 | if isinstance(obj, torch.Tensor): 75 | tensor = obj 76 | break 77 | assert tensor is not None, "at least one argument must be a Tensor" 78 | 79 | # Force variances to be Tensors. Broadcasting helps convert scalars to 80 | # Tensors, but it does not work for torch.exp(). 81 | logvar1, logvar2 = [ 82 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 83 | for x in (logvar1, logvar2) 84 | ] 85 | 86 | return 0.5 * ( 87 | -1.0 88 | + logvar2 89 | - logvar1 90 | + torch.exp(logvar1 - logvar2) 91 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 92 | ) 93 | -------------------------------------------------------------------------------- /ldm/modules/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError('Decay must be between 0 and 1') 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates 14 | else torch.tensor(-1,dtype=torch.int)) 15 | 16 | for name, p in model.named_parameters(): 17 | if p.requires_grad: 18 | #remove as '.'-character is not allowed in buffers 19 | s_name = name.replace('.','') 20 | self.m_name2s_name.update({name:s_name}) 21 | self.register_buffer(s_name,p.clone().detach().data) 22 | 23 | self.collected_params = [] 24 | 25 | def forward(self,model): 26 | decay = self.decay 27 | 28 | if self.num_updates >= 0: 29 | self.num_updates += 1 30 | decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates)) 31 | 32 | one_minus_decay = 1.0 - decay 33 | 34 | with torch.no_grad(): 35 | m_param = dict(model.named_parameters()) 36 | shadow_params = dict(self.named_buffers()) 37 | 38 | for key in m_param: 39 | if m_param[key].requires_grad: 40 | sname = self.m_name2s_name[key] 41 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 42 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) 43 | else: 44 | assert not key in self.m_name2s_name 45 | 46 | def copy_to(self, model): 47 | m_param = dict(model.named_parameters()) 48 | shadow_params = dict(self.named_buffers()) 49 | for key in m_param: 50 | if m_param[key].requires_grad: 51 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 52 | else: 53 | assert not key in self.m_name2s_name 54 | 55 | def store(self, parameters): 56 | """ 57 | Save the current parameters for restoring later. 
58 | Args: 59 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 60 | temporarily stored. 61 | """ 62 | self.collected_params = [param.clone() for param in parameters] 63 | 64 | def restore(self, parameters): 65 | """ 66 | Restore the parameters stored with the `store` method. 67 | Useful to validate the model with EMA parameters without affecting the 68 | original optimization process. Store the parameters before the 69 | `copy_to` method. After validation (or model saving), use this to 70 | restore the former parameters. 71 | Args: 72 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 73 | updated with the stored parameters. 74 | """ 75 | for c_param, param in zip(self.collected_params, parameters): 76 | param.data.copy_(c_param.data) 77 | -------------------------------------------------------------------------------- /ldm/modules/encoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/ldm/modules/encoders/__init__.py -------------------------------------------------------------------------------- /ldm/modules/image_degradation/__init__.py: -------------------------------------------------------------------------------- 1 | from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr 2 | from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light 3 | -------------------------------------------------------------------------------- /ldm/modules/image_degradation/utils/test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/ldm/modules/image_degradation/utils/test.png -------------------------------------------------------------------------------- /ldm/modules/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator -------------------------------------------------------------------------------- /mini.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | import torch.backends.cudnn as cudnn 4 | import os 5 | import random 6 | import argparse 7 | import numpy as np 8 | import datetime 9 | import warnings 10 | 11 | warnings.filterwarnings('ignore', 12 | 'Argument interpolation should be of type InterpolationMode instead of int', 13 | UserWarning) 14 | warnings.filterwarnings('ignore', 15 | 'Leaking Caffe2 thread-pool after fork', 16 | UserWarning) 17 | 18 | 19 | def init_distributed_mode(args): 20 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 21 | args.rank = int(os.environ["RANK"]) 22 | args.world_size = int(os.environ['WORLD_SIZE']) 23 | args.gpu = int(os.environ['LOCAL_RANK']) 24 | else: 25 | print('Not using distributed mode') 26 | args.distributed = False 27 | return 28 | args.distributed = True 29 | 30 | torch.cuda.set_device(args.gpu) 31 | args.dist_backend = 'nccl' 32 | print('| distributed init (rank {}): {}'.format( 33 | args.rank, args.dist_url), flush=True) 34 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 35 | world_size=args.world_size, rank=args.rank, 36 | timeout=datetime.timedelta(seconds=3600)) 37 | 
torch.distributed.barrier() 38 | 39 | 40 | def is_dist_avail_and_initialized(): 41 | if not dist.is_available(): 42 | return False 43 | if not dist.is_initialized(): 44 | return False 45 | return True 46 | 47 | 48 | def get_rank(): 49 | if not is_dist_avail_and_initialized(): 50 | return 0 51 | return dist.get_rank() 52 | 53 | 54 | def is_main_process(): 55 | return get_rank() == 0 56 | 57 | 58 | def get_args_parser(): 59 | parser = argparse.ArgumentParser('training and evaluation script', add_help=False) 60 | 61 | # distributed training parameters 62 | parser.add_argument('--world_size', default=1, type=int, 63 | help='number of distributed processes') 64 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') 65 | parser.add_argument("--local_rank", type=int, required=True, help='local rank for DistributedDataParallel') 66 | parser.add_argument('--device', default='cuda', help='device to use for training / testing') 67 | parser.add_argument('--seed', default=0, type=int) 68 | return parser 69 | 70 | 71 | def main(args): 72 | os.environ['LOCAL_RANK'] = str(args.local_rank) 73 | init_distributed_mode(args) 74 | device = torch.device(args.device) 75 | print(args) 76 | 77 | # fix the seed for reproducibility 78 | seed = args.seed + get_rank() 79 | torch.manual_seed(seed) 80 | np.random.seed(seed) 81 | random.seed(seed) 82 | cudnn.benchmark = True 83 | 84 | print(" we reach the post init ") 85 | 86 | 87 | if __name__ == '__main__': 88 | parser = argparse.ArgumentParser('training and evaluation script', parents=[get_args_parser()]) 89 | args = parser.parse_args() 90 | main(args) 91 | -------------------------------------------------------------------------------- /models/ade20k/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import * -------------------------------------------------------------------------------- /models/ade20k/color150.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/models/ade20k/color150.mat -------------------------------------------------------------------------------- /models/ade20k/mobilenet.py: -------------------------------------------------------------------------------- 1 | """ 2 | This MobileNetV2 implementation is modified from the following repository: 3 | https://github.com/tonylins/pytorch-mobilenet-v2 4 | """ 5 | 6 | import torch.nn as nn 7 | import math 8 | from .utils import load_url 9 | from .segm_lib.nn import SynchronizedBatchNorm2d 10 | 11 | BatchNorm2d = SynchronizedBatchNorm2d 12 | 13 | 14 | __all__ = ['mobilenetv2'] 15 | 16 | 17 | model_urls = { 18 | 'mobilenetv2': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/mobilenet_v2.pth.tar', 19 | } 20 | 21 | 22 | def conv_bn(inp, oup, stride): 23 | return nn.Sequential( 24 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 25 | BatchNorm2d(oup), 26 | nn.ReLU6(inplace=True) 27 | ) 28 | 29 | 30 | def conv_1x1_bn(inp, oup): 31 | return nn.Sequential( 32 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 33 | BatchNorm2d(oup), 34 | nn.ReLU6(inplace=True) 35 | ) 36 | 37 | 38 | class InvertedResidual(nn.Module): 39 | def __init__(self, inp, oup, stride, expand_ratio): 40 | super(InvertedResidual, self).__init__() 41 | self.stride = stride 42 | assert stride in [1, 2] 43 | 44 | hidden_dim = round(inp * expand_ratio) 45 | self.use_res_connect = self.stride == 1 and 
inp == oup 46 | 47 | if expand_ratio == 1: 48 | self.conv = nn.Sequential( 49 | # dw 50 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), 51 | BatchNorm2d(hidden_dim), 52 | nn.ReLU6(inplace=True), 53 | # pw-linear 54 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 55 | BatchNorm2d(oup), 56 | ) 57 | else: 58 | self.conv = nn.Sequential( 59 | # pw 60 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False), 61 | BatchNorm2d(hidden_dim), 62 | nn.ReLU6(inplace=True), 63 | # dw 64 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), 65 | BatchNorm2d(hidden_dim), 66 | nn.ReLU6(inplace=True), 67 | # pw-linear 68 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 69 | BatchNorm2d(oup), 70 | ) 71 | 72 | def forward(self, x): 73 | if self.use_res_connect: 74 | return x + self.conv(x) 75 | else: 76 | return self.conv(x) 77 | 78 | 79 | class MobileNetV2(nn.Module): 80 | def __init__(self, n_class=1000, input_size=224, width_mult=1.): 81 | super(MobileNetV2, self).__init__() 82 | block = InvertedResidual 83 | input_channel = 32 84 | last_channel = 1280 85 | interverted_residual_setting = [ 86 | # t, c, n, s 87 | [1, 16, 1, 1], 88 | [6, 24, 2, 2], 89 | [6, 32, 3, 2], 90 | [6, 64, 4, 2], 91 | [6, 96, 3, 1], 92 | [6, 160, 3, 2], 93 | [6, 320, 1, 1], 94 | ] 95 | 96 | # building first layer 97 | assert input_size % 32 == 0 98 | input_channel = int(input_channel * width_mult) 99 | self.last_channel = int(last_channel * width_mult) if width_mult > 1.0 else last_channel 100 | self.features = [conv_bn(3, input_channel, 2)] 101 | # building inverted residual blocks 102 | for t, c, n, s in interverted_residual_setting: 103 | output_channel = int(c * width_mult) 104 | for i in range(n): 105 | if i == 0: 106 | self.features.append(block(input_channel, output_channel, s, expand_ratio=t)) 107 | else: 108 | self.features.append(block(input_channel, output_channel, 1, expand_ratio=t)) 109 | input_channel = output_channel 110 | # building last several layers 111 | self.features.append(conv_1x1_bn(input_channel, self.last_channel)) 112 | # make it nn.Sequential 113 | self.features = nn.Sequential(*self.features) 114 | 115 | # building classifier 116 | self.classifier = nn.Sequential( 117 | nn.Dropout(0.2), 118 | nn.Linear(self.last_channel, n_class), 119 | ) 120 | 121 | self._initialize_weights() 122 | 123 | def forward(self, x): 124 | x = self.features(x) 125 | x = x.mean(3).mean(2) 126 | x = self.classifier(x) 127 | return x 128 | 129 | def _initialize_weights(self): 130 | for m in self.modules(): 131 | if isinstance(m, nn.Conv2d): 132 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 133 | m.weight.data.normal_(0, math.sqrt(2. / n)) 134 | if m.bias is not None: 135 | m.bias.data.zero_() 136 | elif isinstance(m, BatchNorm2d): 137 | m.weight.data.fill_(1) 138 | m.bias.data.zero_() 139 | elif isinstance(m, nn.Linear): 140 | n = m.weight.size(1) 141 | m.weight.data.normal_(0, 0.01) 142 | m.bias.data.zero_() 143 | 144 | 145 | def mobilenetv2(pretrained=False, **kwargs): 146 | """Constructs a MobileNet_V2 model. 
147 | 148 | Args: 149 | pretrained (bool): If True, returns a model pre-trained on ImageNet 150 | """ 151 | model = MobileNetV2(n_class=1000, **kwargs) 152 | if pretrained: 153 | model.load_state_dict(load_url(model_urls['mobilenetv2']), strict=False) 154 | return model -------------------------------------------------------------------------------- /models/ade20k/segm_lib/nn/__init__.py: -------------------------------------------------------------------------------- 1 | from .modules import * 2 | from .parallel import UserScatteredDataParallel, user_scattered_collate, async_copy_to 3 | -------------------------------------------------------------------------------- /models/ade20k/segm_lib/nn/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : __init__.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d 12 | from .replicate import DataParallelWithCallback, patch_replication_callback 13 | -------------------------------------------------------------------------------- /models/ade20k/segm_lib/nn/modules/comm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : comm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import queue 12 | import collections 13 | import threading 14 | 15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] 16 | 17 | 18 | class FutureResult(object): 19 | """A thread-safe future implementation. Used only as one-to-one pipe.""" 20 | 21 | def __init__(self): 22 | self._result = None 23 | self._lock = threading.Lock() 24 | self._cond = threading.Condition(self._lock) 25 | 26 | def put(self, result): 27 | with self._lock: 28 | assert self._result is None, 'Previous result has\'t been fetched.' 29 | self._result = result 30 | self._cond.notify() 31 | 32 | def get(self): 33 | with self._lock: 34 | if self._result is None: 35 | self._cond.wait() 36 | 37 | res = self._result 38 | self._result = None 39 | return res 40 | 41 | 42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) 43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) 44 | 45 | 46 | class SlavePipe(_SlavePipeBase): 47 | """Pipe for master-slave communication.""" 48 | 49 | def run_slave(self, msg): 50 | self.queue.put((self.identifier, msg)) 51 | ret = self.result.get() 52 | self.queue.put(True) 53 | return ret 54 | 55 | 56 | class SyncMaster(object): 57 | """An abstract `SyncMaster` object. 58 | 59 | - During the replication, as the data parallel will trigger an callback of each module, all slave devices should 60 | call `register(id)` and obtain an `SlavePipe` to communicate with the master. 61 | - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected, 62 | and passed to a registered callback. 
63 | - After receiving the messages, the master device should gather the information and determine to message passed 64 | back to each slave devices. 65 | """ 66 | 67 | def __init__(self, master_callback): 68 | """ 69 | 70 | Args: 71 | master_callback: a callback to be invoked after having collected messages from slave devices. 72 | """ 73 | self._master_callback = master_callback 74 | self._queue = queue.Queue() 75 | self._registry = collections.OrderedDict() 76 | self._activated = False 77 | 78 | def register_slave(self, identifier): 79 | """ 80 | Register an slave device. 81 | 82 | Args: 83 | identifier: an identifier, usually is the device id. 84 | 85 | Returns: a `SlavePipe` object which can be used to communicate with the master device. 86 | 87 | """ 88 | if self._activated: 89 | assert self._queue.empty(), 'Queue is not clean before next initialization.' 90 | self._activated = False 91 | self._registry.clear() 92 | future = FutureResult() 93 | self._registry[identifier] = _MasterRegistry(future) 94 | return SlavePipe(identifier, self._queue, future) 95 | 96 | def run_master(self, master_msg): 97 | """ 98 | Main entry for the master device in each forward pass. 99 | The messages were first collected from each devices (including the master device), and then 100 | an callback will be invoked to compute the message to be sent back to each devices 101 | (including the master device). 102 | 103 | Args: 104 | master_msg: the message that the master want to send to itself. This will be placed as the first 105 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. 106 | 107 | Returns: the message to be sent back to the master device. 108 | 109 | """ 110 | self._activated = True 111 | 112 | intermediates = [(0, master_msg)] 113 | for i in range(self.nr_slaves): 114 | intermediates.append(self._queue.get()) 115 | 116 | results = self._master_callback(intermediates) 117 | assert results[0][0] == 0, 'The first result should belongs to the master.' 118 | 119 | for i, res in results: 120 | if i == 0: 121 | continue 122 | self._registry[i].result.put(res) 123 | 124 | for i in range(self.nr_slaves): 125 | assert self._queue.get() is True 126 | 127 | return results[0][1] 128 | 129 | @property 130 | def nr_slaves(self): 131 | return len(self._registry) 132 | -------------------------------------------------------------------------------- /models/ade20k/segm_lib/nn/modules/replicate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : replicate.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import functools 12 | 13 | from torch.nn.parallel.data_parallel import DataParallel 14 | 15 | __all__ = [ 16 | 'CallbackContext', 17 | 'execute_replication_callbacks', 18 | 'DataParallelWithCallback', 19 | 'patch_replication_callback' 20 | ] 21 | 22 | 23 | class CallbackContext(object): 24 | pass 25 | 26 | 27 | def execute_replication_callbacks(modules): 28 | """ 29 | Execute an replication callback `__data_parallel_replicate__` on each module created by original replication. 
30 | 31 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 32 | 33 | Note that, as all modules are isomorphism, we assign each sub-module with a context 34 | (shared among multiple copies of this module on different devices). 35 | Through this context, different copies can share some information. 36 | 37 | We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback 38 | of any slave copies. 39 | """ 40 | master_copy = modules[0] 41 | nr_modules = len(list(master_copy.modules())) 42 | ctxs = [CallbackContext() for _ in range(nr_modules)] 43 | 44 | for i, module in enumerate(modules): 45 | for j, m in enumerate(module.modules()): 46 | if hasattr(m, '__data_parallel_replicate__'): 47 | m.__data_parallel_replicate__(ctxs[j], i) 48 | 49 | 50 | class DataParallelWithCallback(DataParallel): 51 | """ 52 | Data Parallel with a replication callback. 53 | 54 | An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by 55 | original `replicate` function. 56 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 57 | 58 | Examples: 59 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 60 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 61 | # sync_bn.__data_parallel_replicate__ will be invoked. 62 | """ 63 | 64 | def replicate(self, module, device_ids): 65 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids) 66 | execute_replication_callbacks(modules) 67 | return modules 68 | 69 | 70 | def patch_replication_callback(data_parallel): 71 | """ 72 | Monkey-patch an existing `DataParallel` object. Add the replication callback. 73 | Useful when you have customized `DataParallel` implementation. 74 | 75 | Examples: 76 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 77 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) 78 | > patch_replication_callback(sync_bn) 79 | # this is equivalent to 80 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 81 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 82 | """ 83 | 84 | assert isinstance(data_parallel, DataParallel) 85 | 86 | old_replicate = data_parallel.replicate 87 | 88 | @functools.wraps(old_replicate) 89 | def new_replicate(module, device_ids): 90 | modules = old_replicate(module, device_ids) 91 | execute_replication_callbacks(modules) 92 | return modules 93 | 94 | data_parallel.replicate = new_replicate 95 | -------------------------------------------------------------------------------- /models/ade20k/segm_lib/nn/modules/tests/test_numeric_batchnorm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : test_numeric_batchnorm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 
8 | 9 | import unittest 10 | 11 | import torch 12 | import torch.nn as nn 13 | from torch.autograd import Variable 14 | 15 | from sync_batchnorm.unittest import TorchTestCase 16 | 17 | 18 | def handy_var(a, unbias=True): 19 | n = a.size(0) 20 | asum = a.sum(dim=0) 21 | as_sum = (a ** 2).sum(dim=0) # a square sum 22 | sumvar = as_sum - asum * asum / n 23 | if unbias: 24 | return sumvar / (n - 1) 25 | else: 26 | return sumvar / n 27 | 28 | 29 | class NumericTestCase(TorchTestCase): 30 | def testNumericBatchNorm(self): 31 | a = torch.rand(16, 10) 32 | bn = nn.BatchNorm2d(10, momentum=1, eps=1e-5, affine=False) 33 | bn.train() 34 | 35 | a_var1 = Variable(a, requires_grad=True) 36 | b_var1 = bn(a_var1) 37 | loss1 = b_var1.sum() 38 | loss1.backward() 39 | 40 | a_var2 = Variable(a, requires_grad=True) 41 | a_mean2 = a_var2.mean(dim=0, keepdim=True) 42 | a_std2 = torch.sqrt(handy_var(a_var2, unbias=False).clamp(min=1e-5)) 43 | # a_std2 = torch.sqrt(a_var2.var(dim=0, keepdim=True, unbiased=False) + 1e-5) 44 | b_var2 = (a_var2 - a_mean2) / a_std2 45 | loss2 = b_var2.sum() 46 | loss2.backward() 47 | 48 | self.assertTensorClose(bn.running_mean, a.mean(dim=0)) 49 | self.assertTensorClose(bn.running_var, handy_var(a)) 50 | self.assertTensorClose(a_var1.data, a_var2.data) 51 | self.assertTensorClose(b_var1.data, b_var2.data) 52 | self.assertTensorClose(a_var1.grad, a_var2.grad) 53 | 54 | 55 | if __name__ == '__main__': 56 | unittest.main() 57 | -------------------------------------------------------------------------------- /models/ade20k/segm_lib/nn/modules/tests/test_sync_batchnorm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : test_sync_batchnorm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 
8 | 9 | import unittest 10 | 11 | import torch 12 | import torch.nn as nn 13 | from torch.autograd import Variable 14 | 15 | from sync_batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, DataParallelWithCallback 16 | from sync_batchnorm.unittest import TorchTestCase 17 | 18 | 19 | def handy_var(a, unbias=True): 20 | n = a.size(0) 21 | asum = a.sum(dim=0) 22 | as_sum = (a ** 2).sum(dim=0) # a square sum 23 | sumvar = as_sum - asum * asum / n 24 | if unbias: 25 | return sumvar / (n - 1) 26 | else: 27 | return sumvar / n 28 | 29 | 30 | def _find_bn(module): 31 | for m in module.modules(): 32 | if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, SynchronizedBatchNorm1d, SynchronizedBatchNorm2d)): 33 | return m 34 | 35 | 36 | class SyncTestCase(TorchTestCase): 37 | def _syncParameters(self, bn1, bn2): 38 | bn1.reset_parameters() 39 | bn2.reset_parameters() 40 | if bn1.affine and bn2.affine: 41 | bn2.weight.data.copy_(bn1.weight.data) 42 | bn2.bias.data.copy_(bn1.bias.data) 43 | 44 | def _checkBatchNormResult(self, bn1, bn2, input, is_train, cuda=False): 45 | """Check the forward and backward for the customized batch normalization.""" 46 | bn1.train(mode=is_train) 47 | bn2.train(mode=is_train) 48 | 49 | if cuda: 50 | input = input.cuda() 51 | 52 | self._syncParameters(_find_bn(bn1), _find_bn(bn2)) 53 | 54 | input1 = Variable(input, requires_grad=True) 55 | output1 = bn1(input1) 56 | output1.sum().backward() 57 | input2 = Variable(input, requires_grad=True) 58 | output2 = bn2(input2) 59 | output2.sum().backward() 60 | 61 | self.assertTensorClose(input1.data, input2.data) 62 | self.assertTensorClose(output1.data, output2.data) 63 | self.assertTensorClose(input1.grad, input2.grad) 64 | self.assertTensorClose(_find_bn(bn1).running_mean, _find_bn(bn2).running_mean) 65 | self.assertTensorClose(_find_bn(bn1).running_var, _find_bn(bn2).running_var) 66 | 67 | def testSyncBatchNormNormalTrain(self): 68 | bn = nn.BatchNorm1d(10) 69 | sync_bn = SynchronizedBatchNorm1d(10) 70 | 71 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), True) 72 | 73 | def testSyncBatchNormNormalEval(self): 74 | bn = nn.BatchNorm1d(10) 75 | sync_bn = SynchronizedBatchNorm1d(10) 76 | 77 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), False) 78 | 79 | def testSyncBatchNormSyncTrain(self): 80 | bn = nn.BatchNorm1d(10, eps=1e-5, affine=False) 81 | sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 82 | sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 83 | 84 | bn.cuda() 85 | sync_bn.cuda() 86 | 87 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), True, cuda=True) 88 | 89 | def testSyncBatchNormSyncEval(self): 90 | bn = nn.BatchNorm1d(10, eps=1e-5, affine=False) 91 | sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 92 | sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 93 | 94 | bn.cuda() 95 | sync_bn.cuda() 96 | 97 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), False, cuda=True) 98 | 99 | def testSyncBatchNorm2DSyncTrain(self): 100 | bn = nn.BatchNorm2d(10) 101 | sync_bn = SynchronizedBatchNorm2d(10) 102 | sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 103 | 104 | bn.cuda() 105 | sync_bn.cuda() 106 | 107 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10, 16, 16), True, cuda=True) 108 | 109 | 110 | if __name__ == '__main__': 111 | unittest.main() 112 | -------------------------------------------------------------------------------- /models/ade20k/segm_lib/nn/modules/unittest.py: 
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : unittest.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import unittest 12 | 13 | import numpy as np 14 | from torch.autograd import Variable 15 | 16 | 17 | def as_numpy(v): 18 | if isinstance(v, Variable): 19 | v = v.data 20 | return v.cpu().numpy() 21 | 22 | 23 | class TorchTestCase(unittest.TestCase): 24 | def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3): 25 | npa, npb = as_numpy(a), as_numpy(b) 26 | self.assertTrue( 27 | np.allclose(npa, npb, atol=atol), 28 | 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max()) 29 | ) 30 | -------------------------------------------------------------------------------- /models/ade20k/segm_lib/nn/parallel/__init__.py: -------------------------------------------------------------------------------- 1 | from .data_parallel import UserScatteredDataParallel, user_scattered_collate, async_copy_to 2 | -------------------------------------------------------------------------------- /models/ade20k/segm_lib/nn/parallel/data_parallel.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf8 -*- 2 | 3 | import torch.cuda as cuda 4 | import torch.nn as nn 5 | import torch 6 | import collections 7 | from torch.nn.parallel._functions import Gather 8 | 9 | 10 | __all__ = ['UserScatteredDataParallel', 'user_scattered_collate', 'async_copy_to'] 11 | 12 | 13 | def async_copy_to(obj, dev, main_stream=None): 14 | if torch.is_tensor(obj): 15 | v = obj.cuda(dev, non_blocking=True) 16 | if main_stream is not None: 17 | v.data.record_stream(main_stream) 18 | return v 19 | elif isinstance(obj, collections.Mapping): 20 | return {k: async_copy_to(o, dev, main_stream) for k, o in obj.items()} 21 | elif isinstance(obj, collections.Sequence): 22 | return [async_copy_to(o, dev, main_stream) for o in obj] 23 | else: 24 | return obj 25 | 26 | 27 | def dict_gather(outputs, target_device, dim=0): 28 | """ 29 | Gathers variables from different GPUs on a specified device 30 | (-1 means the CPU), with dictionary support. 
31 | """ 32 | def gather_map(outputs): 33 | out = outputs[0] 34 | if torch.is_tensor(out): 35 | # MJY(20180330) HACK:: force nr_dims > 0 36 | if out.dim() == 0: 37 | outputs = [o.unsqueeze(0) for o in outputs] 38 | return Gather.apply(target_device, dim, *outputs) 39 | elif out is None: 40 | return None 41 | elif isinstance(out, collections.Mapping): 42 | return {k: gather_map([o[k] for o in outputs]) for k in out} 43 | elif isinstance(out, collections.Sequence): 44 | return type(out)(map(gather_map, zip(*outputs))) 45 | return gather_map(outputs) 46 | 47 | 48 | class DictGatherDataParallel(nn.DataParallel): 49 | def gather(self, outputs, output_device): 50 | return dict_gather(outputs, output_device, dim=self.dim) 51 | 52 | 53 | class UserScatteredDataParallel(DictGatherDataParallel): 54 | def scatter(self, inputs, kwargs, device_ids): 55 | assert len(inputs) == 1 56 | inputs = inputs[0] 57 | inputs = _async_copy_stream(inputs, device_ids) 58 | inputs = [[i] for i in inputs] 59 | assert len(kwargs) == 0 60 | kwargs = [{} for _ in range(len(inputs))] 61 | 62 | return inputs, kwargs 63 | 64 | 65 | def user_scattered_collate(batch): 66 | return batch 67 | 68 | 69 | def _async_copy(inputs, device_ids): 70 | nr_devs = len(device_ids) 71 | assert type(inputs) in (tuple, list) 72 | assert len(inputs) == nr_devs 73 | 74 | outputs = [] 75 | for i, dev in zip(inputs, device_ids): 76 | with cuda.device(dev): 77 | outputs.append(async_copy_to(i, dev)) 78 | 79 | return tuple(outputs) 80 | 81 | 82 | def _async_copy_stream(inputs, device_ids): 83 | nr_devs = len(device_ids) 84 | assert type(inputs) in (tuple, list) 85 | assert len(inputs) == nr_devs 86 | 87 | outputs = [] 88 | streams = [_get_stream(d) for d in device_ids] 89 | for i, dev, stream in zip(inputs, device_ids, streams): 90 | with cuda.device(dev): 91 | main_stream = cuda.current_stream() 92 | with cuda.stream(stream): 93 | outputs.append(async_copy_to(i, dev, main_stream=main_stream)) 94 | main_stream.wait_stream(stream) 95 | 96 | return outputs 97 | 98 | 99 | """Adapted from: torch/nn/parallel/_functions.py""" 100 | # background streams used for copying 101 | _streams = None 102 | 103 | 104 | def _get_stream(device): 105 | """Gets a background stream for copying between CPU and GPU""" 106 | global _streams 107 | if device == -1: 108 | return None 109 | if _streams is None: 110 | _streams = [None] * cuda.device_count() 111 | if _streams[device] is None: _streams[device] = cuda.Stream(device) 112 | return _streams[device] 113 | -------------------------------------------------------------------------------- /models/ade20k/segm_lib/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .th import * 2 | -------------------------------------------------------------------------------- /models/ade20k/segm_lib/utils/data/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .dataset import Dataset, TensorDataset, ConcatDataset 3 | from .dataloader import DataLoader 4 | -------------------------------------------------------------------------------- /models/ade20k/segm_lib/utils/data/dataset.py: -------------------------------------------------------------------------------- 1 | import bisect 2 | import warnings 3 | 4 | from torch._utils import _accumulate 5 | from torch import randperm 6 | 7 | 8 | class Dataset(object): 9 | """An abstract class representing a Dataset. 10 | 11 | All other datasets should subclass it. 
All subclasses should override 12 | ``__len__``, that provides the size of the dataset, and ``__getitem__``, 13 | supporting integer indexing in range from 0 to len(self) exclusive. 14 | """ 15 | 16 | def __getitem__(self, index): 17 | raise NotImplementedError 18 | 19 | def __len__(self): 20 | raise NotImplementedError 21 | 22 | def __add__(self, other): 23 | return ConcatDataset([self, other]) 24 | 25 | 26 | class TensorDataset(Dataset): 27 | """Dataset wrapping data and target tensors. 28 | 29 | Each sample will be retrieved by indexing both tensors along the first 30 | dimension. 31 | 32 | Arguments: 33 | data_tensor (Tensor): contains sample data. 34 | target_tensor (Tensor): contains sample targets (labels). 35 | """ 36 | 37 | def __init__(self, data_tensor, target_tensor): 38 | assert data_tensor.size(0) == target_tensor.size(0) 39 | self.data_tensor = data_tensor 40 | self.target_tensor = target_tensor 41 | 42 | def __getitem__(self, index): 43 | return self.data_tensor[index], self.target_tensor[index] 44 | 45 | def __len__(self): 46 | return self.data_tensor.size(0) 47 | 48 | 49 | class ConcatDataset(Dataset): 50 | """ 51 | Dataset to concatenate multiple datasets. 52 | Purpose: useful to assemble different existing datasets, possibly 53 | large-scale datasets as the concatenation operation is done in an 54 | on-the-fly manner. 55 | 56 | Arguments: 57 | datasets (iterable): List of datasets to be concatenated 58 | """ 59 | 60 | @staticmethod 61 | def cumsum(sequence): 62 | r, s = [], 0 63 | for e in sequence: 64 | l = len(e) 65 | r.append(l + s) 66 | s += l 67 | return r 68 | 69 | def __init__(self, datasets): 70 | super(ConcatDataset, self).__init__() 71 | assert len(datasets) > 0, 'datasets should not be an empty iterable' 72 | self.datasets = list(datasets) 73 | self.cumulative_sizes = self.cumsum(self.datasets) 74 | 75 | def __len__(self): 76 | return self.cumulative_sizes[-1] 77 | 78 | def __getitem__(self, idx): 79 | dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) 80 | if dataset_idx == 0: 81 | sample_idx = idx 82 | else: 83 | sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] 84 | return self.datasets[dataset_idx][sample_idx] 85 | 86 | @property 87 | def cummulative_sizes(self): 88 | warnings.warn("cummulative_sizes attribute is renamed to " 89 | "cumulative_sizes", DeprecationWarning, stacklevel=2) 90 | return self.cumulative_sizes 91 | 92 | 93 | class Subset(Dataset): 94 | def __init__(self, dataset, indices): 95 | self.dataset = dataset 96 | self.indices = indices 97 | 98 | def __getitem__(self, idx): 99 | return self.dataset[self.indices[idx]] 100 | 101 | def __len__(self): 102 | return len(self.indices) 103 | 104 | 105 | def random_split(dataset, lengths): 106 | """ 107 | Randomly split a dataset into non-overlapping new datasets of given lengths 108 | ds 109 | 110 | Arguments: 111 | dataset (Dataset): Dataset to be split 112 | lengths (iterable): lengths of splits to be produced 113 | """ 114 | if sum(lengths) != len(dataset): 115 | raise ValueError("Sum of input lengths does not equal the length of the input dataset!") 116 | 117 | indices = randperm(sum(lengths)) 118 | return [Subset(dataset, indices[offset - length:offset]) for offset, length in zip(_accumulate(lengths), lengths)] 119 | -------------------------------------------------------------------------------- /models/ade20k/segm_lib/utils/data/distributed.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from 
.sampler import Sampler 4 | from torch.distributed import get_world_size, get_rank 5 | 6 | 7 | class DistributedSampler(Sampler): 8 | """Sampler that restricts data loading to a subset of the dataset. 9 | 10 | It is especially useful in conjunction with 11 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each 12 | process can pass a DistributedSampler instance as a DataLoader sampler, 13 | and load a subset of the original dataset that is exclusive to it. 14 | 15 | .. note:: 16 | Dataset is assumed to be of constant size. 17 | 18 | Arguments: 19 | dataset: Dataset used for sampling. 20 | num_replicas (optional): Number of processes participating in 21 | distributed training. 22 | rank (optional): Rank of the current process within num_replicas. 23 | """ 24 | 25 | def __init__(self, dataset, num_replicas=None, rank=None): 26 | if num_replicas is None: 27 | num_replicas = get_world_size() 28 | if rank is None: 29 | rank = get_rank() 30 | self.dataset = dataset 31 | self.num_replicas = num_replicas 32 | self.rank = rank 33 | self.epoch = 0 34 | self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) 35 | self.total_size = self.num_samples * self.num_replicas 36 | 37 | def __iter__(self): 38 | # deterministically shuffle based on epoch 39 | g = torch.Generator() 40 | g.manual_seed(self.epoch) 41 | indices = list(torch.randperm(len(self.dataset), generator=g)) 42 | 43 | # add extra samples to make it evenly divisible 44 | indices += indices[:(self.total_size - len(indices))] 45 | assert len(indices) == self.total_size 46 | 47 | # subsample 48 | offset = self.num_samples * self.rank 49 | indices = indices[offset:offset + self.num_samples] 50 | assert len(indices) == self.num_samples 51 | 52 | return iter(indices) 53 | 54 | def __len__(self): 55 | return self.num_samples 56 | 57 | def set_epoch(self, epoch): 58 | self.epoch = epoch 59 | -------------------------------------------------------------------------------- /models/ade20k/segm_lib/utils/data/sampler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class Sampler(object): 5 | """Base class for all Samplers. 6 | 7 | Every Sampler subclass has to provide an __iter__ method, providing a way 8 | to iterate over indices of dataset elements, and a __len__ method that 9 | returns the length of the returned iterators. 10 | """ 11 | 12 | def __init__(self, data_source): 13 | pass 14 | 15 | def __iter__(self): 16 | raise NotImplementedError 17 | 18 | def __len__(self): 19 | raise NotImplementedError 20 | 21 | 22 | class SequentialSampler(Sampler): 23 | """Samples elements sequentially, always in the same order. 24 | 25 | Arguments: 26 | data_source (Dataset): dataset to sample from 27 | """ 28 | 29 | def __init__(self, data_source): 30 | self.data_source = data_source 31 | 32 | def __iter__(self): 33 | return iter(range(len(self.data_source))) 34 | 35 | def __len__(self): 36 | return len(self.data_source) 37 | 38 | 39 | class RandomSampler(Sampler): 40 | """Samples elements randomly, without replacement. 
41 | 42 | Arguments: 43 | data_source (Dataset): dataset to sample from 44 | """ 45 | 46 | def __init__(self, data_source): 47 | self.data_source = data_source 48 | 49 | def __iter__(self): 50 | return iter(torch.randperm(len(self.data_source)).long()) 51 | 52 | def __len__(self): 53 | return len(self.data_source) 54 | 55 | 56 | class SubsetRandomSampler(Sampler): 57 | """Samples elements randomly from a given list of indices, without replacement. 58 | 59 | Arguments: 60 | indices (list): a list of indices 61 | """ 62 | 63 | def __init__(self, indices): 64 | self.indices = indices 65 | 66 | def __iter__(self): 67 | return (self.indices[i] for i in torch.randperm(len(self.indices))) 68 | 69 | def __len__(self): 70 | return len(self.indices) 71 | 72 | 73 | class WeightedRandomSampler(Sampler): 74 | """Samples elements from [0,..,len(weights)-1] with given probabilities (weights). 75 | 76 | Arguments: 77 | weights (list) : a list of weights, not necessary summing up to one 78 | num_samples (int): number of samples to draw 79 | replacement (bool): if ``True``, samples are drawn with replacement. 80 | If not, they are drawn without replacement, which means that when a 81 | sample index is drawn for a row, it cannot be drawn again for that row. 82 | """ 83 | 84 | def __init__(self, weights, num_samples, replacement=True): 85 | self.weights = torch.DoubleTensor(weights) 86 | self.num_samples = num_samples 87 | self.replacement = replacement 88 | 89 | def __iter__(self): 90 | return iter(torch.multinomial(self.weights, self.num_samples, self.replacement)) 91 | 92 | def __len__(self): 93 | return self.num_samples 94 | 95 | 96 | class BatchSampler(object): 97 | """Wraps another sampler to yield a mini-batch of indices. 98 | 99 | Args: 100 | sampler (Sampler): Base sampler. 101 | batch_size (int): Size of mini-batch. 
102 | drop_last (bool): If ``True``, the sampler will drop the last batch if 103 | its size would be less than ``batch_size`` 104 | 105 | Example: 106 | >>> list(BatchSampler(range(10), batch_size=3, drop_last=False)) 107 | [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]] 108 | >>> list(BatchSampler(range(10), batch_size=3, drop_last=True)) 109 | [[0, 1, 2], [3, 4, 5], [6, 7, 8]] 110 | """ 111 | 112 | def __init__(self, sampler, batch_size, drop_last): 113 | self.sampler = sampler 114 | self.batch_size = batch_size 115 | self.drop_last = drop_last 116 | 117 | def __iter__(self): 118 | batch = [] 119 | for idx in self.sampler: 120 | batch.append(idx) 121 | if len(batch) == self.batch_size: 122 | yield batch 123 | batch = [] 124 | if len(batch) > 0 and not self.drop_last: 125 | yield batch 126 | 127 | def __len__(self): 128 | if self.drop_last: 129 | return len(self.sampler) // self.batch_size 130 | else: 131 | return (len(self.sampler) + self.batch_size - 1) // self.batch_size 132 | -------------------------------------------------------------------------------- /models/ade20k/segm_lib/utils/th.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | import numpy as np 4 | import collections 5 | 6 | __all__ = ['as_variable', 'as_numpy', 'mark_volatile'] 7 | 8 | def as_variable(obj): 9 | if isinstance(obj, Variable): 10 | return obj 11 | if isinstance(obj, collections.Sequence): 12 | return [as_variable(v) for v in obj] 13 | elif isinstance(obj, collections.Mapping): 14 | return {k: as_variable(v) for k, v in obj.items()} 15 | else: 16 | return Variable(obj) 17 | 18 | def as_numpy(obj): 19 | if isinstance(obj, collections.Sequence): 20 | return [as_numpy(v) for v in obj] 21 | elif isinstance(obj, collections.Mapping): 22 | return {k: as_numpy(v) for k, v in obj.items()} 23 | elif isinstance(obj, Variable): 24 | return obj.data.cpu().numpy() 25 | elif torch.is_tensor(obj): 26 | return obj.cpu().numpy() 27 | else: 28 | return np.array(obj) 29 | 30 | def mark_volatile(obj): 31 | if torch.is_tensor(obj): 32 | obj = Variable(obj) 33 | if isinstance(obj, Variable): 34 | obj.no_grad = True 35 | return obj 36 | elif isinstance(obj, collections.Mapping): 37 | return {k: mark_volatile(o) for k, o in obj.items()} 38 | elif isinstance(obj, collections.Sequence): 39 | return [mark_volatile(o) for o in obj] 40 | else: 41 | return obj 42 | -------------------------------------------------------------------------------- /models/ade20k/utils.py: -------------------------------------------------------------------------------- 1 | """Modified from https://github.com/CSAILVision/semantic-segmentation-pytorch""" 2 | 3 | import os 4 | import sys 5 | 6 | import numpy as np 7 | import torch 8 | 9 | try: 10 | from urllib import urlretrieve 11 | except ImportError: 12 | from urllib.request import urlretrieve 13 | 14 | 15 | def load_url(url, model_dir='./pretrained', map_location=None): 16 | if not os.path.exists(model_dir): 17 | os.makedirs(model_dir) 18 | filename = url.split('/')[-1] 19 | cached_file = os.path.join(model_dir, filename) 20 | if not os.path.exists(cached_file): 21 | sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file)) 22 | urlretrieve(url, cached_file) 23 | return torch.load(cached_file, map_location=map_location) 24 | 25 | 26 | def color_encode(labelmap, colors, mode='RGB'): 27 | labelmap = labelmap.astype('int') 28 | labelmap_rgb = np.zeros((labelmap.shape[0], labelmap.shape[1], 3), 29 | dtype=np.uint8) 
30 | for label in np.unique(labelmap): 31 | if label < 0: 32 | continue 33 | labelmap_rgb += (labelmap == label)[:, :, np.newaxis] * \ 34 | np.tile(colors[label], 35 | (labelmap.shape[0], labelmap.shape[1], 1)) 36 | 37 | if mode == 'BGR': 38 | return labelmap_rgb[:, :, ::-1] 39 | else: 40 | return labelmap_rgb 41 | -------------------------------------------------------------------------------- /models/first_stage_models/kl-f16/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: val/rec_loss 6 | embed_dim: 16 7 | lossconfig: 8 | target: ldm.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 1.0e-06 12 | disc_weight: 0.5 13 | ddconfig: 14 | double_z: true 15 | z_channels: 16 16 | resolution: 256 17 | in_channels: 3 18 | out_ch: 3 19 | ch: 128 20 | ch_mult: 21 | - 1 22 | - 1 23 | - 2 24 | - 2 25 | - 4 26 | num_res_blocks: 2 27 | attn_resolutions: 28 | - 16 29 | dropout: 0.0 30 | data: 31 | target: main.DataModuleFromConfig 32 | params: 33 | batch_size: 6 34 | wrap: true 35 | train: 36 | target: ldm.data.openimages.FullOpenImagesTrain 37 | params: 38 | size: 384 39 | crop_size: 256 40 | validation: 41 | target: ldm.data.openimages.FullOpenImagesValidation 42 | params: 43 | size: 384 44 | crop_size: 256 45 | -------------------------------------------------------------------------------- /models/first_stage_models/kl-f32/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: val/rec_loss 6 | embed_dim: 64 7 | lossconfig: 8 | target: ldm.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 1.0e-06 12 | disc_weight: 0.5 13 | ddconfig: 14 | double_z: true 15 | z_channels: 64 16 | resolution: 256 17 | in_channels: 3 18 | out_ch: 3 19 | ch: 128 20 | ch_mult: 21 | - 1 22 | - 1 23 | - 2 24 | - 2 25 | - 4 26 | - 4 27 | num_res_blocks: 2 28 | attn_resolutions: 29 | - 16 30 | - 8 31 | dropout: 0.0 32 | data: 33 | target: main.DataModuleFromConfig 34 | params: 35 | batch_size: 6 36 | wrap: true 37 | train: 38 | target: ldm.data.openimages.FullOpenImagesTrain 39 | params: 40 | size: 384 41 | crop_size: 256 42 | validation: 43 | target: ldm.data.openimages.FullOpenImagesValidation 44 | params: 45 | size: 384 46 | crop_size: 256 47 | -------------------------------------------------------------------------------- /models/first_stage_models/kl-f4/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: val/rec_loss 6 | embed_dim: 3 7 | lossconfig: 8 | target: ldm.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 1.0e-06 12 | disc_weight: 0.5 13 | ddconfig: 14 | double_z: true 15 | z_channels: 3 16 | resolution: 256 17 | in_channels: 3 18 | out_ch: 3 19 | ch: 128 20 | ch_mult: 21 | - 1 22 | - 2 23 | - 4 24 | num_res_blocks: 2 25 | attn_resolutions: [] 26 | dropout: 0.0 27 | data: 28 | target: main.DataModuleFromConfig 29 | params: 30 | batch_size: 10 31 | wrap: true 32 | train: 33 | target: ldm.data.openimages.FullOpenImagesTrain 34 | params: 35 | size: 384 36 | crop_size: 256 37 | validation: 38 | target: 
ldm.data.openimages.FullOpenImagesValidation 39 | params: 40 | size: 384 41 | crop_size: 256 42 | -------------------------------------------------------------------------------- /models/first_stage_models/kl-f8/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: val/rec_loss 6 | embed_dim: 4 7 | lossconfig: 8 | target: ldm.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 1.0e-06 12 | disc_weight: 0.5 13 | ddconfig: 14 | double_z: true 15 | z_channels: 4 16 | resolution: 256 17 | in_channels: 3 18 | out_ch: 3 19 | ch: 128 20 | ch_mult: 21 | - 1 22 | - 2 23 | - 4 24 | - 4 25 | num_res_blocks: 2 26 | attn_resolutions: [] 27 | dropout: 0.0 28 | data: 29 | target: main.DataModuleFromConfig 30 | params: 31 | batch_size: 4 32 | wrap: true 33 | train: 34 | target: ldm.data.openimages.FullOpenImagesTrain 35 | params: 36 | size: 384 37 | crop_size: 256 38 | validation: 39 | target: ldm.data.openimages.FullOpenImagesValidation 40 | params: 41 | size: 384 42 | crop_size: 256 43 | -------------------------------------------------------------------------------- /models/first_stage_models/vq-f16/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: ldm.models.autoencoder.VQModel 4 | params: 5 | embed_dim: 8 6 | n_embed: 16384 7 | ddconfig: 8 | double_z: false 9 | z_channels: 8 10 | resolution: 256 11 | in_channels: 3 12 | out_ch: 3 13 | ch: 128 14 | ch_mult: 15 | - 1 16 | - 1 17 | - 2 18 | - 2 19 | - 4 20 | num_res_blocks: 2 21 | attn_resolutions: 22 | - 16 23 | dropout: 0.0 24 | lossconfig: 25 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator 26 | params: 27 | disc_conditional: false 28 | disc_in_channels: 3 29 | disc_start: 250001 30 | disc_weight: 0.75 31 | disc_num_layers: 2 32 | codebook_weight: 1.0 33 | 34 | data: 35 | target: main.DataModuleFromConfig 36 | params: 37 | batch_size: 14 38 | num_workers: 20 39 | wrap: true 40 | train: 41 | target: ldm.data.openimages.FullOpenImagesTrain 42 | params: 43 | size: 384 44 | crop_size: 256 45 | validation: 46 | target: ldm.data.openimages.FullOpenImagesValidation 47 | params: 48 | size: 384 49 | crop_size: 256 50 | -------------------------------------------------------------------------------- /models/first_stage_models/vq-f4-noattn/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: ldm.models.autoencoder.VQModel 4 | params: 5 | embed_dim: 3 6 | n_embed: 8192 7 | monitor: val/rec_loss 8 | 9 | ddconfig: 10 | attn_type: none 11 | double_z: false 12 | z_channels: 3 13 | resolution: 256 14 | in_channels: 3 15 | out_ch: 3 16 | ch: 128 17 | ch_mult: 18 | - 1 19 | - 2 20 | - 4 21 | num_res_blocks: 2 22 | attn_resolutions: [] 23 | dropout: 0.0 24 | lossconfig: 25 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator 26 | params: 27 | disc_conditional: false 28 | disc_in_channels: 3 29 | disc_start: 11 30 | disc_weight: 0.75 31 | codebook_weight: 1.0 32 | 33 | data: 34 | target: main.DataModuleFromConfig 35 | params: 36 | batch_size: 8 37 | num_workers: 12 38 | wrap: true 39 | train: 40 | target: ldm.data.openimages.FullOpenImagesTrain 41 | params: 42 | crop_size: 256 43 | validation: 44 | target: ldm.data.openimages.FullOpenImagesValidation 
45 | params: 46 | crop_size: 256 47 | -------------------------------------------------------------------------------- /models/first_stage_models/vq-f4/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: ldm.models.autoencoder.VQModel 4 | params: 5 | embed_dim: 3 6 | n_embed: 8192 7 | monitor: val/rec_loss 8 | 9 | ddconfig: 10 | double_z: false 11 | z_channels: 3 12 | resolution: 256 13 | in_channels: 3 14 | out_ch: 3 15 | ch: 128 16 | ch_mult: 17 | - 1 18 | - 2 19 | - 4 20 | num_res_blocks: 2 21 | attn_resolutions: [] 22 | dropout: 0.0 23 | lossconfig: 24 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator 25 | params: 26 | disc_conditional: false 27 | disc_in_channels: 3 28 | disc_start: 0 29 | disc_weight: 0.75 30 | codebook_weight: 1.0 31 | 32 | data: 33 | target: main.DataModuleFromConfig 34 | params: 35 | batch_size: 8 36 | num_workers: 16 37 | wrap: true 38 | train: 39 | target: ldm.data.openimages.FullOpenImagesTrain 40 | params: 41 | crop_size: 256 42 | validation: 43 | target: ldm.data.openimages.FullOpenImagesValidation 44 | params: 45 | crop_size: 256 46 | -------------------------------------------------------------------------------- /models/first_stage_models/vq-f8-n256/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: ldm.models.autoencoder.VQModel 4 | params: 5 | embed_dim: 4 6 | n_embed: 256 7 | monitor: val/rec_loss 8 | ddconfig: 9 | double_z: false 10 | z_channels: 4 11 | resolution: 256 12 | in_channels: 3 13 | out_ch: 3 14 | ch: 128 15 | ch_mult: 16 | - 1 17 | - 2 18 | - 2 19 | - 4 20 | num_res_blocks: 2 21 | attn_resolutions: 22 | - 32 23 | dropout: 0.0 24 | lossconfig: 25 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator 26 | params: 27 | disc_conditional: false 28 | disc_in_channels: 3 29 | disc_start: 250001 30 | disc_weight: 0.75 31 | codebook_weight: 1.0 32 | 33 | data: 34 | target: main.DataModuleFromConfig 35 | params: 36 | batch_size: 10 37 | num_workers: 20 38 | wrap: true 39 | train: 40 | target: ldm.data.openimages.FullOpenImagesTrain 41 | params: 42 | size: 384 43 | crop_size: 256 44 | validation: 45 | target: ldm.data.openimages.FullOpenImagesValidation 46 | params: 47 | size: 384 48 | crop_size: 256 49 | -------------------------------------------------------------------------------- /models/first_stage_models/vq-f8/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: ldm.models.autoencoder.VQModel 4 | params: 5 | embed_dim: 4 6 | n_embed: 16384 7 | monitor: val/rec_loss 8 | ddconfig: 9 | double_z: false 10 | z_channels: 4 11 | resolution: 256 12 | in_channels: 3 13 | out_ch: 3 14 | ch: 128 15 | ch_mult: 16 | - 1 17 | - 2 18 | - 2 19 | - 4 20 | num_res_blocks: 2 21 | attn_resolutions: 22 | - 32 23 | dropout: 0.0 24 | lossconfig: 25 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator 26 | params: 27 | disc_conditional: false 28 | disc_in_channels: 3 29 | disc_num_layers: 2 30 | disc_start: 1 31 | disc_weight: 0.6 32 | codebook_weight: 1.0 33 | data: 34 | target: main.DataModuleFromConfig 35 | params: 36 | batch_size: 10 37 | num_workers: 20 38 | wrap: true 39 | train: 40 | target: ldm.data.openimages.FullOpenImagesTrain 41 | params: 42 | size: 384 43 | crop_size: 256 44 | validation: 45 | target: 
ldm.data.openimages.FullOpenImagesValidation 46 | params: 47 | size: 384 48 | crop_size: 256 49 | -------------------------------------------------------------------------------- /models/ldm/bsr_sr/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0155 7 | log_every_t: 100 8 | timesteps: 1000 9 | loss_type: l2 10 | first_stage_key: image 11 | cond_stage_key: LR_image 12 | image_size: 64 13 | channels: 3 14 | concat_mode: true 15 | cond_stage_trainable: false 16 | unet_config: 17 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 18 | params: 19 | image_size: 64 20 | in_channels: 6 21 | out_channels: 3 22 | model_channels: 160 23 | attention_resolutions: 24 | - 16 25 | - 8 26 | num_res_blocks: 2 27 | channel_mult: 28 | - 1 29 | - 2 30 | - 2 31 | - 4 32 | num_head_channels: 32 33 | first_stage_config: 34 | target: ldm.models.autoencoder.VQModelInterface 35 | params: 36 | embed_dim: 3 37 | n_embed: 8192 38 | monitor: val/rec_loss 39 | ddconfig: 40 | double_z: false 41 | z_channels: 3 42 | resolution: 256 43 | in_channels: 3 44 | out_ch: 3 45 | ch: 128 46 | ch_mult: 47 | - 1 48 | - 2 49 | - 4 50 | num_res_blocks: 2 51 | attn_resolutions: [] 52 | dropout: 0.0 53 | lossconfig: 54 | target: torch.nn.Identity 55 | cond_stage_config: 56 | target: torch.nn.Identity 57 | data: 58 | target: main.DataModuleFromConfig 59 | params: 60 | batch_size: 64 61 | wrap: false 62 | num_workers: 12 63 | train: 64 | target: ldm.data.openimages.SuperresOpenImagesAdvancedTrain 65 | params: 66 | size: 256 67 | degradation: bsrgan_light 68 | downscale_f: 4 69 | min_crop_f: 0.5 70 | max_crop_f: 1.0 71 | random_crop: true 72 | validation: 73 | target: ldm.data.openimages.SuperresOpenImagesAdvancedValidation 74 | params: 75 | size: 256 76 | degradation: bsrgan_light 77 | downscale_f: 4 78 | min_crop_f: 0.5 79 | max_crop_f: 1.0 80 | random_crop: true 81 | -------------------------------------------------------------------------------- /models/ldm/celeba256/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 2.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: class_label 12 | image_size: 64 13 | channels: 3 14 | cond_stage_trainable: false 15 | concat_mode: false 16 | monitor: val/loss 17 | unet_config: 18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 19 | params: 20 | image_size: 64 21 | in_channels: 3 22 | out_channels: 3 23 | model_channels: 224 24 | attention_resolutions: 25 | - 8 26 | - 4 27 | - 2 28 | num_res_blocks: 2 29 | channel_mult: 30 | - 1 31 | - 2 32 | - 3 33 | - 4 34 | num_head_channels: 32 35 | first_stage_config: 36 | target: ldm.models.autoencoder.VQModelInterface 37 | params: 38 | embed_dim: 3 39 | n_embed: 8192 40 | ddconfig: 41 | double_z: false 42 | z_channels: 3 43 | resolution: 256 44 | in_channels: 3 45 | out_ch: 3 46 | ch: 128 47 | ch_mult: 48 | - 1 49 | - 2 50 | - 4 51 | num_res_blocks: 2 52 | attn_resolutions: [] 53 | dropout: 0.0 54 | lossconfig: 55 | target: torch.nn.Identity 56 | cond_stage_config: __is_unconditional__ 57 | data: 58 | target: main.DataModuleFromConfig 59 | params: 60 | batch_size: 48 61 | num_workers: 
5 62 | wrap: false 63 | train: 64 | target: ldm.data.faceshq.CelebAHQTrain 65 | params: 66 | size: 256 67 | validation: 68 | target: ldm.data.faceshq.CelebAHQValidation 69 | params: 70 | size: 256 71 | -------------------------------------------------------------------------------- /models/ldm/cin256/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: class_label 12 | image_size: 32 13 | channels: 4 14 | cond_stage_trainable: true 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | unet_config: 18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 19 | params: 20 | image_size: 32 21 | in_channels: 4 22 | out_channels: 4 23 | model_channels: 256 24 | attention_resolutions: 25 | - 4 26 | - 2 27 | - 1 28 | num_res_blocks: 2 29 | channel_mult: 30 | - 1 31 | - 2 32 | - 4 33 | num_head_channels: 32 34 | use_spatial_transformer: true 35 | transformer_depth: 1 36 | context_dim: 512 37 | first_stage_config: 38 | target: ldm.models.autoencoder.VQModelInterface 39 | params: 40 | embed_dim: 4 41 | n_embed: 16384 42 | ddconfig: 43 | double_z: false 44 | z_channels: 4 45 | resolution: 256 46 | in_channels: 3 47 | out_ch: 3 48 | ch: 128 49 | ch_mult: 50 | - 1 51 | - 2 52 | - 2 53 | - 4 54 | num_res_blocks: 2 55 | attn_resolutions: 56 | - 32 57 | dropout: 0.0 58 | lossconfig: 59 | target: torch.nn.Identity 60 | cond_stage_config: 61 | target: ldm.modules.encoders.modules.ClassEmbedder 62 | params: 63 | embed_dim: 512 64 | key: class_label 65 | data: 66 | target: main.DataModuleFromConfig 67 | params: 68 | batch_size: 64 69 | num_workers: 12 70 | wrap: false 71 | train: 72 | target: ldm.data.imagenet.ImageNetTrain 73 | params: 74 | config: 75 | size: 256 76 | validation: 77 | target: ldm.data.imagenet.ImageNetValidation 78 | params: 79 | config: 80 | size: 256 81 | -------------------------------------------------------------------------------- /models/ldm/ffhq256/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 2.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: class_label 12 | image_size: 64 13 | channels: 3 14 | cond_stage_trainable: false 15 | concat_mode: false 16 | monitor: val/loss 17 | unet_config: 18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 19 | params: 20 | image_size: 64 21 | in_channels: 3 22 | out_channels: 3 23 | model_channels: 224 24 | attention_resolutions: 25 | - 8 26 | - 4 27 | - 2 28 | num_res_blocks: 2 29 | channel_mult: 30 | - 1 31 | - 2 32 | - 3 33 | - 4 34 | num_head_channels: 32 35 | first_stage_config: 36 | target: ldm.models.autoencoder.VQModelInterface 37 | params: 38 | embed_dim: 3 39 | n_embed: 8192 40 | ddconfig: 41 | double_z: false 42 | z_channels: 3 43 | resolution: 256 44 | in_channels: 3 45 | out_ch: 3 46 | ch: 128 47 | ch_mult: 48 | - 1 49 | - 2 50 | - 4 51 | num_res_blocks: 2 52 | attn_resolutions: [] 53 | dropout: 0.0 54 | lossconfig: 55 | target: torch.nn.Identity 56 | cond_stage_config: __is_unconditional__ 57 | data: 58 | target: 
main.DataModuleFromConfig 59 | params: 60 | batch_size: 42 61 | num_workers: 5 62 | wrap: false 63 | train: 64 | target: ldm.data.faceshq.FFHQTrain 65 | params: 66 | size: 256 67 | validation: 68 | target: ldm.data.faceshq.FFHQValidation 69 | params: 70 | size: 256 71 | -------------------------------------------------------------------------------- /models/ldm/inpainting_big/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0205 7 | log_every_t: 100 8 | timesteps: 1000 9 | loss_type: l1 10 | first_stage_key: image 11 | cond_stage_key: masked_image 12 | image_size: 64 13 | channels: 3 14 | concat_mode: true 15 | monitor: val/loss 16 | scheduler_config: 17 | target: ldm.lr_scheduler.LambdaWarmUpCosineScheduler 18 | params: 19 | verbosity_interval: 0 20 | warm_up_steps: 1000 21 | max_decay_steps: 50000 22 | lr_start: 0.001 23 | lr_max: 0.1 24 | lr_min: 0.0001 25 | unet_config: 26 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 27 | params: 28 | image_size: 64 29 | in_channels: 7 30 | out_channels: 3 31 | model_channels: 256 32 | attention_resolutions: 33 | - 8 34 | - 4 35 | - 2 36 | num_res_blocks: 2 37 | channel_mult: 38 | - 1 39 | - 2 40 | - 3 41 | - 4 42 | num_heads: 8 43 | resblock_updown: true 44 | first_stage_config: 45 | target: ldm.models.autoencoder.VQModelInterface 46 | params: 47 | embed_dim: 3 48 | n_embed: 8192 49 | monitor: val/rec_loss 50 | ddconfig: 51 | attn_type: none 52 | double_z: false 53 | z_channels: 3 54 | resolution: 256 55 | in_channels: 3 56 | out_ch: 3 57 | ch: 128 58 | ch_mult: 59 | - 1 60 | - 2 61 | - 4 62 | num_res_blocks: 2 63 | attn_resolutions: [] 64 | dropout: 0.0 65 | lossconfig: 66 | target: ldm.modules.losses.contperceptual.DummyLoss 67 | cond_stage_config: __is_first_stage__ 68 | -------------------------------------------------------------------------------- /models/ldm/layout2img-openimages256/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 2.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0205 7 | log_every_t: 100 8 | timesteps: 1000 9 | loss_type: l1 10 | first_stage_key: image 11 | cond_stage_key: coordinates_bbox 12 | image_size: 64 13 | channels: 3 14 | conditioning_key: crossattn 15 | cond_stage_trainable: true 16 | unet_config: 17 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 18 | params: 19 | image_size: 64 20 | in_channels: 3 21 | out_channels: 3 22 | model_channels: 128 23 | attention_resolutions: 24 | - 8 25 | - 4 26 | - 2 27 | num_res_blocks: 2 28 | channel_mult: 29 | - 1 30 | - 2 31 | - 3 32 | - 4 33 | num_head_channels: 32 34 | use_spatial_transformer: true 35 | transformer_depth: 3 36 | context_dim: 512 37 | first_stage_config: 38 | target: ldm.models.autoencoder.VQModelInterface 39 | params: 40 | embed_dim: 3 41 | n_embed: 8192 42 | monitor: val/rec_loss 43 | ddconfig: 44 | double_z: false 45 | z_channels: 3 46 | resolution: 256 47 | in_channels: 3 48 | out_ch: 3 49 | ch: 128 50 | ch_mult: 51 | - 1 52 | - 2 53 | - 4 54 | num_res_blocks: 2 55 | attn_resolutions: [] 56 | dropout: 0.0 57 | lossconfig: 58 | target: torch.nn.Identity 59 | cond_stage_config: 60 | target: ldm.modules.encoders.modules.BERTEmbedder 61 | params: 62 | n_embed: 512 63 | n_layer: 16 64 | vocab_size: 
8192 65 | max_seq_len: 92 66 | use_tokenizer: false 67 | monitor: val/loss_simple_ema 68 | data: 69 | target: main.DataModuleFromConfig 70 | params: 71 | batch_size: 24 72 | wrap: false 73 | num_workers: 10 74 | train: 75 | target: ldm.data.openimages.OpenImagesBBoxTrain 76 | params: 77 | size: 256 78 | validation: 79 | target: ldm.data.openimages.OpenImagesBBoxValidation 80 | params: 81 | size: 256 82 | -------------------------------------------------------------------------------- /models/ldm/lsun_beds256/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 2.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: class_label 12 | image_size: 64 13 | channels: 3 14 | cond_stage_trainable: false 15 | concat_mode: false 16 | monitor: val/loss 17 | unet_config: 18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 19 | params: 20 | image_size: 64 21 | in_channels: 3 22 | out_channels: 3 23 | model_channels: 224 24 | attention_resolutions: 25 | - 8 26 | - 4 27 | - 2 28 | num_res_blocks: 2 29 | channel_mult: 30 | - 1 31 | - 2 32 | - 3 33 | - 4 34 | num_head_channels: 32 35 | first_stage_config: 36 | target: ldm.models.autoencoder.VQModelInterface 37 | params: 38 | embed_dim: 3 39 | n_embed: 8192 40 | ddconfig: 41 | double_z: false 42 | z_channels: 3 43 | resolution: 256 44 | in_channels: 3 45 | out_ch: 3 46 | ch: 128 47 | ch_mult: 48 | - 1 49 | - 2 50 | - 4 51 | num_res_blocks: 2 52 | attn_resolutions: [] 53 | dropout: 0.0 54 | lossconfig: 55 | target: torch.nn.Identity 56 | cond_stage_config: __is_unconditional__ 57 | data: 58 | target: main.DataModuleFromConfig 59 | params: 60 | batch_size: 48 61 | num_workers: 5 62 | wrap: false 63 | train: 64 | target: ldm.data.lsun.LSUNBedroomsTrain 65 | params: 66 | size: 256 67 | validation: 68 | target: ldm.data.lsun.LSUNBedroomsValidation 69 | params: 70 | size: 256 71 | -------------------------------------------------------------------------------- /models/ldm/lsun_churches256/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 5.0e-05 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0155 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | loss_type: l1 11 | first_stage_key: image 12 | cond_stage_key: image 13 | image_size: 32 14 | channels: 4 15 | cond_stage_trainable: false 16 | concat_mode: false 17 | scale_by_std: true 18 | monitor: val/loss_simple_ema 19 | scheduler_config: 20 | target: ldm.lr_scheduler.LambdaLinearScheduler 21 | params: 22 | warm_up_steps: 23 | - 10000 24 | cycle_lengths: 25 | - 10000000000000 26 | f_start: 27 | - 1.0e-06 28 | f_max: 29 | - 1.0 30 | f_min: 31 | - 1.0 32 | unet_config: 33 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 34 | params: 35 | image_size: 32 36 | in_channels: 4 37 | out_channels: 4 38 | model_channels: 192 39 | attention_resolutions: 40 | - 1 41 | - 2 42 | - 4 43 | - 8 44 | num_res_blocks: 2 45 | channel_mult: 46 | - 1 47 | - 2 48 | - 2 49 | - 4 50 | - 4 51 | num_heads: 8 52 | use_scale_shift_norm: true 53 | resblock_updown: true 54 | first_stage_config: 55 | target: ldm.models.autoencoder.AutoencoderKL 56 | params: 57 | embed_dim: 4 58 | monitor: val/rec_loss 59 | 
ddconfig: 60 | double_z: true 61 | z_channels: 4 62 | resolution: 256 63 | in_channels: 3 64 | out_ch: 3 65 | ch: 128 66 | ch_mult: 67 | - 1 68 | - 2 69 | - 4 70 | - 4 71 | num_res_blocks: 2 72 | attn_resolutions: [] 73 | dropout: 0.0 74 | lossconfig: 75 | target: torch.nn.Identity 76 | 77 | cond_stage_config: '__is_unconditional__' 78 | 79 | data: 80 | target: main.DataModuleFromConfig 81 | params: 82 | batch_size: 96 83 | num_workers: 5 84 | wrap: false 85 | train: 86 | target: ldm.data.lsun.LSUNChurchesTrain 87 | params: 88 | size: 256 89 | validation: 90 | target: ldm.data.lsun.LSUNChurchesValidation 91 | params: 92 | size: 256 93 | -------------------------------------------------------------------------------- /models/ldm/semantic_synthesis256/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0205 7 | log_every_t: 100 8 | timesteps: 1000 9 | loss_type: l1 10 | first_stage_key: image 11 | cond_stage_key: segmentation 12 | image_size: 64 13 | channels: 3 14 | concat_mode: true 15 | cond_stage_trainable: true 16 | unet_config: 17 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 18 | params: 19 | image_size: 64 20 | in_channels: 6 21 | out_channels: 3 22 | model_channels: 128 23 | attention_resolutions: 24 | - 32 25 | - 16 26 | - 8 27 | num_res_blocks: 2 28 | channel_mult: 29 | - 1 30 | - 4 31 | - 8 32 | num_heads: 8 33 | first_stage_config: 34 | target: ldm.models.autoencoder.VQModelInterface 35 | params: 36 | embed_dim: 3 37 | n_embed: 8192 38 | ddconfig: 39 | double_z: false 40 | z_channels: 3 41 | resolution: 256 42 | in_channels: 3 43 | out_ch: 3 44 | ch: 128 45 | ch_mult: 46 | - 1 47 | - 2 48 | - 4 49 | num_res_blocks: 2 50 | attn_resolutions: [] 51 | dropout: 0.0 52 | lossconfig: 53 | target: torch.nn.Identity 54 | cond_stage_config: 55 | target: ldm.modules.encoders.modules.SpatialRescaler 56 | params: 57 | n_stages: 2 58 | in_channels: 182 59 | out_channels: 3 60 | -------------------------------------------------------------------------------- /models/ldm/semantic_synthesis512/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0205 7 | log_every_t: 100 8 | timesteps: 1000 9 | loss_type: l1 10 | first_stage_key: image 11 | cond_stage_key: segmentation 12 | image_size: 128 13 | channels: 3 14 | concat_mode: true 15 | cond_stage_trainable: true 16 | unet_config: 17 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 18 | params: 19 | image_size: 128 20 | in_channels: 6 21 | out_channels: 3 22 | model_channels: 128 23 | attention_resolutions: 24 | - 32 25 | - 16 26 | - 8 27 | num_res_blocks: 2 28 | channel_mult: 29 | - 1 30 | - 4 31 | - 8 32 | num_heads: 8 33 | first_stage_config: 34 | target: ldm.models.autoencoder.VQModelInterface 35 | params: 36 | embed_dim: 3 37 | n_embed: 8192 38 | monitor: val/rec_loss 39 | ddconfig: 40 | double_z: false 41 | z_channels: 3 42 | resolution: 256 43 | in_channels: 3 44 | out_ch: 3 45 | ch: 128 46 | ch_mult: 47 | - 1 48 | - 2 49 | - 4 50 | num_res_blocks: 2 51 | attn_resolutions: [] 52 | dropout: 0.0 53 | lossconfig: 54 | target: torch.nn.Identity 55 | cond_stage_config: 56 | target: ldm.modules.encoders.modules.SpatialRescaler 57 | params: 
58 | n_stages: 2 59 | in_channels: 182 60 | out_channels: 3 61 | data: 62 | target: main.DataModuleFromConfig 63 | params: 64 | batch_size: 8 65 | wrap: false 66 | num_workers: 10 67 | train: 68 | target: ldm.data.landscapes.RFWTrain 69 | params: 70 | size: 768 71 | crop_size: 512 72 | segmentation_to_float32: true 73 | validation: 74 | target: ldm.data.landscapes.RFWValidation 75 | params: 76 | size: 768 77 | crop_size: 512 78 | segmentation_to_float32: true 79 | -------------------------------------------------------------------------------- /models/ldm/text2img256/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 2.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: caption 12 | image_size: 64 13 | channels: 3 14 | cond_stage_trainable: true 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | unet_config: 18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 19 | params: 20 | image_size: 64 21 | in_channels: 3 22 | out_channels: 3 23 | model_channels: 192 24 | attention_resolutions: 25 | - 8 26 | - 4 27 | - 2 28 | num_res_blocks: 2 29 | channel_mult: 30 | - 1 31 | - 2 32 | - 3 33 | - 5 34 | num_head_channels: 32 35 | use_spatial_transformer: true 36 | transformer_depth: 1 37 | context_dim: 640 38 | first_stage_config: 39 | target: ldm.models.autoencoder.VQModelInterface 40 | params: 41 | embed_dim: 3 42 | n_embed: 8192 43 | ddconfig: 44 | double_z: false 45 | z_channels: 3 46 | resolution: 256 47 | in_channels: 3 48 | out_ch: 3 49 | ch: 128 50 | ch_mult: 51 | - 1 52 | - 2 53 | - 4 54 | num_res_blocks: 2 55 | attn_resolutions: [] 56 | dropout: 0.0 57 | lossconfig: 58 | target: torch.nn.Identity 59 | cond_stage_config: 60 | target: ldm.modules.encoders.modules.BERTEmbedder 61 | params: 62 | n_embed: 640 63 | n_layer: 32 64 | data: 65 | target: main.DataModuleFromConfig 66 | params: 67 | batch_size: 28 68 | num_workers: 10 69 | wrap: false 70 | train: 71 | target: ldm.data.previews.pytorch_dataset.PreviewsTrain 72 | params: 73 | size: 256 74 | validation: 75 | target: ldm.data.previews.pytorch_dataset.PreviewsValidation 76 | params: 77 | size: 256 78 | -------------------------------------------------------------------------------- /models/lpips_models/alex.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/models/lpips_models/alex.pth -------------------------------------------------------------------------------- /models/lpips_models/squeeze.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/models/lpips_models/squeeze.pth -------------------------------------------------------------------------------- /models/lpips_models/vgg.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/models/lpips_models/vgg.pth -------------------------------------------------------------------------------- /outputs/inpainting_results/6458524847_2f4c361183_k.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/outputs/inpainting_results/6458524847_2f4c361183_k.png -------------------------------------------------------------------------------- /outputs/inpainting_results/8399166846_f6fb4e4b8e_k.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/outputs/inpainting_results/8399166846_f6fb4e4b8e_k.png -------------------------------------------------------------------------------- /outputs/inpainting_results/alex-iby-G_Pk4D9rMLs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/outputs/inpainting_results/alex-iby-G_Pk4D9rMLs.png -------------------------------------------------------------------------------- /outputs/inpainting_results/bench2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/outputs/inpainting_results/bench2.png -------------------------------------------------------------------------------- /outputs/inpainting_results/bertrand-gabioud-CpuFzIsHYJ0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/outputs/inpainting_results/bertrand-gabioud-CpuFzIsHYJ0.png -------------------------------------------------------------------------------- /outputs/inpainting_results/billow926-12-Wc-Zgx6Y.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/outputs/inpainting_results/billow926-12-Wc-Zgx6Y.png -------------------------------------------------------------------------------- /outputs/inpainting_results/overture-creations-5sI6fQgYIuo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/outputs/inpainting_results/overture-creations-5sI6fQgYIuo.png -------------------------------------------------------------------------------- /outputs/inpainting_results/photo-1583445095369-9c651e7e5d34.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/outputs/inpainting_results/photo-1583445095369-9c651e7e5d34.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | albumentations==0.4.3 2 | easydict==1.10 3 | scikit-learn==1.2.0 4 | opencv-python==4.1.2.30 5 | pudb==2019.2 6 | imageio==2.9.0 7 | imageio-ffmpeg==0.4.2 8 | future>=0.17.1 9 | pytorch-lightning==1.4.2 10 | omegaconf==2.1.1 11 | test-tube>=0.7.5 12 | streamlit>=0.73.1 13 | einops==0.3.0 14 | torch-fidelity==0.3.0 15 | #transformers==4.6.0 16 | transformers==4.19.2 17 | torchmetrics==0.6 18 | protobuf==3.20.3 19 | invisible-watermark 20 | diffusers==0.12.1 21 | 
kornia==0.6 22 | -------------------------------------------------------------------------------- /saicinpainting/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/saicinpainting/__init__.py -------------------------------------------------------------------------------- /saicinpainting/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch 4 | 5 | from saicinpainting.evaluation.evaluator import InpaintingEvaluatorOnline, ssim_fid100_f1, lpips_fid100_f1 6 | from saicinpainting.evaluation.losses.base_loss import SSIMScore, LPIPSScore, FIDScore 7 | 8 | 9 | def make_evaluator(kind='default', ssim=True, lpips=True, fid=True, integral_kind=None, **kwargs): 10 | logging.info(f'Make evaluator {kind}') 11 | device = "cuda" if torch.cuda.is_available() else "cpu" 12 | metrics = {} 13 | if ssim: 14 | metrics['ssim'] = SSIMScore() 15 | if lpips: 16 | metrics['lpips'] = LPIPSScore() 17 | if fid: 18 | metrics['fid'] = FIDScore().to(device) 19 | 20 | if integral_kind is None: 21 | integral_func = None 22 | elif integral_kind == 'ssim_fid100_f1': 23 | integral_func = ssim_fid100_f1 24 | elif integral_kind == 'lpips_fid100_f1': 25 | integral_func = lpips_fid100_f1 26 | else: 27 | raise ValueError(f'Unexpected integral_kind={integral_kind}') 28 | 29 | if kind == 'default': 30 | return InpaintingEvaluatorOnline(scores=metrics, 31 | integral_func=integral_func, 32 | integral_title=integral_kind, 33 | **kwargs) 34 | -------------------------------------------------------------------------------- /saicinpainting/evaluation/losses/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/saicinpainting/evaluation/losses/__init__.py -------------------------------------------------------------------------------- /saicinpainting/evaluation/losses/fid/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/saicinpainting/evaluation/losses/fid/__init__.py -------------------------------------------------------------------------------- /saicinpainting/evaluation/losses/ssim.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | 6 | class SSIM(torch.nn.Module): 7 | """SSIM. 
Modified from: 8 | https://github.com/Po-Hsun-Su/pytorch-ssim/blob/master/pytorch_ssim/__init__.py 9 | """ 10 | 11 | def __init__(self, window_size=11, size_average=True): 12 | super().__init__() 13 | self.window_size = window_size 14 | self.size_average = size_average 15 | self.channel = 1 16 | self.register_buffer('window', self._create_window(window_size, self.channel)) 17 | 18 | def forward(self, img1, img2): 19 | assert len(img1.shape) == 4 20 | 21 | channel = img1.size()[1] 22 | 23 | if channel == self.channel and self.window.data.type() == img1.data.type(): 24 | window = self.window 25 | else: 26 | window = self._create_window(self.window_size, channel) 27 | 28 | # window = window.to(img1.get_device()) 29 | window = window.type_as(img1) 30 | 31 | self.window = window 32 | self.channel = channel 33 | 34 | return self._ssim(img1, img2, window, self.window_size, channel, self.size_average) 35 | 36 | def _gaussian(self, window_size, sigma): 37 | gauss = torch.Tensor([ 38 | np.exp(-(x - (window_size // 2)) ** 2 / float(2 * sigma ** 2)) for x in range(window_size) 39 | ]) 40 | return gauss / gauss.sum() 41 | 42 | def _create_window(self, window_size, channel): 43 | _1D_window = self._gaussian(window_size, 1.5).unsqueeze(1) 44 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 45 | return _2D_window.expand(channel, 1, window_size, window_size).contiguous() 46 | 47 | def _ssim(self, img1, img2, window, window_size, channel, size_average=True): 48 | mu1 = F.conv2d(img1, window, padding=(window_size // 2), groups=channel) 49 | mu2 = F.conv2d(img2, window, padding=(window_size // 2), groups=channel) 50 | 51 | mu1_sq = mu1.pow(2) 52 | mu2_sq = mu2.pow(2) 53 | mu1_mu2 = mu1 * mu2 54 | 55 | sigma1_sq = F.conv2d( 56 | img1 * img1, window, padding=(window_size // 2), groups=channel) - mu1_sq 57 | sigma2_sq = F.conv2d( 58 | img2 * img2, window, padding=(window_size // 2), groups=channel) - mu2_sq 59 | sigma12 = F.conv2d( 60 | img1 * img2, window, padding=(window_size // 2), groups=channel) - mu1_mu2 61 | 62 | C1 = 0.01 ** 2 63 | C2 = 0.03 ** 2 64 | 65 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / \ 66 | ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) 67 | 68 | if size_average: 69 | return ssim_map.mean() 70 | 71 | return ssim_map.mean(1).mean(1).mean(1) 72 | 73 | def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): 74 | return 75 | -------------------------------------------------------------------------------- /saicinpainting/evaluation/masks/README.md: -------------------------------------------------------------------------------- 1 | # Current algorithm 2 | 3 | ## Choice of mask objects 4 | 5 | For identification of the objects which are suitable for mask obtaining, panoptic segmentation model 6 | from [detectron2](https://github.com/facebookresearch/detectron2) trained on COCO. Categories of the detected instances 7 | belong either to "stuff" or "things" types. We consider that instances of objects should have category belong 8 | to "things". Besides, we set upper bound on area which is taken by the object — we consider that too big 9 | area indicates either of the instance being a background or a main object which should not be removed. 10 | 11 | ## Choice of position for mask 12 | 13 | We consider that input image has size 2^n x 2^m. 
We downsample it using 14 | [COUNTLESS](https://github.com/william-silversmith/countless) algorithm so the width is equal to 15 | 64 = 2^8 = 2^{downsample_levels}. 16 | 17 | ### Augmentation 18 | 19 | There are several parameters for augmentation: 20 | - Scaling factor. We limit scaling to the case when a mask after scaling with pivot point in its center fits inside the 21 | image completely. 22 | - 23 | 24 | ### Shift 25 | 26 | 27 | ## Select 28 | -------------------------------------------------------------------------------- /saicinpainting/evaluation/masks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/saicinpainting/evaluation/masks/__init__.py -------------------------------------------------------------------------------- /saicinpainting/evaluation/masks/countless/.gitignore: -------------------------------------------------------------------------------- 1 | results -------------------------------------------------------------------------------- /saicinpainting/evaluation/masks/countless/README.md: -------------------------------------------------------------------------------- 1 | [![Build Status](https://travis-ci.org/william-silversmith/countless.svg?branch=master)](https://travis-ci.org/william-silversmith/countless) 2 | 3 | Python COUNTLESS Downsampling 4 | ============================= 5 | 6 | To install: 7 | 8 | `pip install -r requirements.txt` 9 | 10 | To test: 11 | 12 | `python test.py` 13 | 14 | To benchmark countless2d: 15 | 16 | `python python/countless2d.py python/images/gray_segmentation.png` 17 | 18 | To benchmark countless3d: 19 | 20 | `python python/countless3d.py` 21 | 22 | Adjust N and the list of algorithms inside each script to modify the run parameters. 23 | 24 | 25 | Python3 is slightly faster than Python2. 
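
The two READMEs above describe shrinking masks with COUNTLESS to a width of 64, but only give benchmark commands. As a rough illustration of what 2x2 COUNTLESS downsampling computes (the majority label of each 2x2 block, found without explicit counting), here is a minimal NumPy sketch. It is not the repository's `countless2d.py`; the function name, the `+1` zero-sentinel shift, and the even height/width assumption are illustrative choices.

```python
import numpy as np

def countless_2x2(labels: np.ndarray) -> np.ndarray:
    """Majority-downsample a 2D label image by 2x2 blocks without counting.
    Assumes even height and width; labels are shifted by +1 so that 0 can
    serve as a "no match" sentinel, then shifted back at the end."""
    a = labels[0::2, 0::2].astype(np.uint64) + 1
    b = labels[0::2, 1::2].astype(np.uint64) + 1
    c = labels[1::2, 0::2].astype(np.uint64) + 1
    d = labels[1::2, 1::2].astype(np.uint64) + 1

    ab = a * (a == b)  # a if a == b else 0
    ac = a * (a == c)  # a if a == c else 0
    bc = b * (b == c)  # b if b == c else 0

    pick = ab + (ab == 0) * ac      # first agreeing pair among (a, b), (a, c)
    pick = pick + (pick == 0) * bc  # ... or (b, c)
    pick = pick + (pick == 0) * d   # no pair agrees, so d is a valid mode

    return (pick - 1).astype(labels.dtype)

seg = np.array([[1, 1, 2, 3],
                [1, 4, 3, 3],
                [5, 5, 6, 7],
                [5, 8, 9, 9]])
print(countless_2x2(seg))
# [[1 3]
#  [5 9]]
```

The fallback to `d` is safe because if any label occurs at least twice in a block, either one of the pairs (a, b), (a, c), (b, c) matches, or the repeated label is `d` itself.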
-------------------------------------------------------------------------------- /saicinpainting/evaluation/masks/countless/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/saicinpainting/evaluation/masks/countless/__init__.py -------------------------------------------------------------------------------- /saicinpainting/evaluation/masks/countless/images/gcim.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/saicinpainting/evaluation/masks/countless/images/gcim.jpg -------------------------------------------------------------------------------- /saicinpainting/evaluation/masks/countless/images/gray_segmentation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/saicinpainting/evaluation/masks/countless/images/gray_segmentation.png -------------------------------------------------------------------------------- /saicinpainting/evaluation/masks/countless/images/segmentation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/saicinpainting/evaluation/masks/countless/images/segmentation.png -------------------------------------------------------------------------------- /saicinpainting/evaluation/masks/countless/images/sparse.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/saicinpainting/evaluation/masks/countless/images/sparse.png -------------------------------------------------------------------------------- /saicinpainting/evaluation/masks/countless/memprof/countless2d_gcim_N_1000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/saicinpainting/evaluation/masks/countless/memprof/countless2d_gcim_N_1000.png -------------------------------------------------------------------------------- /saicinpainting/evaluation/masks/countless/memprof/countless2d_quick_gcim_N_1000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/saicinpainting/evaluation/masks/countless/memprof/countless2d_quick_gcim_N_1000.png -------------------------------------------------------------------------------- /saicinpainting/evaluation/masks/countless/memprof/countless3d.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/saicinpainting/evaluation/masks/countless/memprof/countless3d.png -------------------------------------------------------------------------------- /saicinpainting/evaluation/masks/countless/memprof/countless3d_dynamic.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/saicinpainting/evaluation/masks/countless/memprof/countless3d_dynamic.png -------------------------------------------------------------------------------- /saicinpainting/evaluation/masks/countless/memprof/countless3d_dynamic_generalized.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/saicinpainting/evaluation/masks/countless/memprof/countless3d_dynamic_generalized.png -------------------------------------------------------------------------------- /saicinpainting/evaluation/masks/countless/memprof/countless3d_generalized.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/saicinpainting/evaluation/masks/countless/memprof/countless3d_generalized.png -------------------------------------------------------------------------------- /saicinpainting/evaluation/masks/countless/requirements.txt: -------------------------------------------------------------------------------- 1 | Pillow>=6.2.0 2 | numpy>=1.16 3 | scipy 4 | tqdm 5 | memory_profiler 6 | six 7 | pytest -------------------------------------------------------------------------------- /saicinpainting/evaluation/utils.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | import yaml 4 | from easydict import EasyDict as edict 5 | import torch.nn as nn 6 | import torch 7 | 8 | 9 | def load_yaml(path): 10 | with open(path, 'r') as f: 11 | return edict(yaml.safe_load(f)) 12 | 13 | 14 | def move_to_device(obj, device): 15 | if isinstance(obj, nn.Module): 16 | return obj.to(device) 17 | if torch.is_tensor(obj): 18 | return obj.to(device) 19 | if isinstance(obj, (tuple, list)): 20 | return [move_to_device(el, device) for el in obj] 21 | if isinstance(obj, dict): 22 | return {name: move_to_device(val, device) for name, val in obj.items()} 23 | raise ValueError(f'Unexpected type {type(obj)}') 24 | 25 | 26 | class SmallMode(Enum): 27 | DROP = "drop" 28 | UPSCALE = "upscale" 29 | -------------------------------------------------------------------------------- /saicinpainting/evaluation/vis.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from skimage import io 3 | from skimage.segmentation import mark_boundaries 4 | 5 | 6 | def save_item_for_vis(item, out_file): 7 | mask = item['mask'] > 0.5 8 | if mask.ndim == 3: 9 | mask = mask[0] 10 | img = mark_boundaries(np.transpose(item['image'], (1, 2, 0)), 11 | mask, 12 | color=(1., 0., 0.), 13 | outline_color=(1., 1., 1.), 14 | mode='thick') 15 | 16 | if 'inpainted' in item: 17 | inp_img = mark_boundaries(np.transpose(item['inpainted'], (1, 2, 0)), 18 | mask, 19 | color=(1., 0., 0.), 20 | mode='outer') 21 | img = np.concatenate((img, inp_img), axis=1) 22 | 23 | img = np.clip(img * 255, 0, 255).astype('uint8') 24 | io.imsave(out_file, img) 25 | 26 | 27 | def save_mask_for_sidebyside(item, out_file): 28 | mask = item['mask']# > 0.5 29 | if mask.ndim == 3: 30 | mask = mask[0] 31 | mask = np.clip(mask * 255, 0, 255).astype('uint8') 32 | io.imsave(out_file, mask) 33 | 34 | def save_img_for_sidebyside(item, out_file): 35 | img = np.transpose(item['image'], (1, 2, 0)) 36 | img = np.clip(img * 255, 0, 
255).astype('uint8') 37 | io.imsave(out_file, img) -------------------------------------------------------------------------------- /saicinpainting/training/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/saicinpainting/training/__init__.py -------------------------------------------------------------------------------- /saicinpainting/training/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/saicinpainting/training/data/__init__.py -------------------------------------------------------------------------------- /saicinpainting/training/data/aug.py: -------------------------------------------------------------------------------- 1 | from albumentations import DualIAATransform, to_tuple 2 | import imgaug.augmenters as iaa 3 | 4 | class IAAAffine2(DualIAATransform): 5 | """Place a regular grid of points on the input and randomly move the neighbourhood of these point around 6 | via affine transformations. 7 | 8 | Note: This class introduce interpolation artifacts to mask if it has values other than {0;1} 9 | 10 | Args: 11 | p (float): probability of applying the transform. Default: 0.5. 12 | 13 | Targets: 14 | image, mask 15 | """ 16 | 17 | def __init__( 18 | self, 19 | scale=(0.7, 1.3), 20 | translate_percent=None, 21 | translate_px=None, 22 | rotate=0.0, 23 | shear=(-0.1, 0.1), 24 | order=1, 25 | cval=0, 26 | mode="reflect", 27 | always_apply=False, 28 | p=0.5, 29 | ): 30 | super(IAAAffine2, self).__init__(always_apply, p) 31 | self.scale = dict(x=scale, y=scale) 32 | self.translate_percent = to_tuple(translate_percent, 0) 33 | self.translate_px = to_tuple(translate_px, 0) 34 | self.rotate = to_tuple(rotate) 35 | self.shear = dict(x=shear, y=shear) 36 | self.order = order 37 | self.cval = cval 38 | self.mode = mode 39 | 40 | @property 41 | def processor(self): 42 | return iaa.Affine( 43 | self.scale, 44 | self.translate_percent, 45 | self.translate_px, 46 | self.rotate, 47 | self.shear, 48 | self.order, 49 | self.cval, 50 | self.mode, 51 | ) 52 | 53 | def get_transform_init_args_names(self): 54 | return ("scale", "translate_percent", "translate_px", "rotate", "shear", "order", "cval", "mode") 55 | 56 | 57 | class IAAPerspective2(DualIAATransform): 58 | """Perform a random four point perspective transform of the input. 59 | 60 | Note: This class introduce interpolation artifacts to mask if it has values other than {0;1} 61 | 62 | Args: 63 | scale ((float, float): standard deviation of the normal distributions. These are used to sample 64 | the random distances of the subimage's corners from the full image's corners. Default: (0.05, 0.1). 65 | p (float): probability of applying the transform. Default: 0.5. 
66 | 67 | Targets: 68 | image, mask 69 | """ 70 | 71 | def __init__(self, scale=(0.05, 0.1), keep_size=True, always_apply=False, p=0.5, 72 | order=1, cval=0, mode="replicate"): 73 | super(IAAPerspective2, self).__init__(always_apply, p) 74 | self.scale = to_tuple(scale, 1.0) 75 | self.keep_size = keep_size 76 | self.cval = cval 77 | self.mode = mode 78 | 79 | @property 80 | def processor(self): 81 | return iaa.PerspectiveTransform(self.scale, keep_size=self.keep_size, mode=self.mode, cval=self.cval) 82 | 83 | def get_transform_init_args_names(self): 84 | return ("scale", "keep_size") 85 | -------------------------------------------------------------------------------- /saicinpainting/training/losses/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/saicinpainting/training/losses/__init__.py -------------------------------------------------------------------------------- /saicinpainting/training/losses/constants.py: -------------------------------------------------------------------------------- 1 | weights = {"ade20k": 2 | [6.34517766497462, 3 | 9.328358208955224, 4 | 11.389521640091116, 5 | 16.10305958132045, 6 | 20.833333333333332, 7 | 22.22222222222222, 8 | 25.125628140703515, 9 | 43.29004329004329, 10 | 50.5050505050505, 11 | 54.6448087431694, 12 | 55.24861878453038, 13 | 60.24096385542168, 14 | 62.5, 15 | 66.2251655629139, 16 | 84.74576271186442, 17 | 90.90909090909092, 18 | 91.74311926605505, 19 | 96.15384615384616, 20 | 96.15384615384616, 21 | 97.08737864077669, 22 | 102.04081632653062, 23 | 135.13513513513513, 24 | 149.2537313432836, 25 | 153.84615384615384, 26 | 163.93442622950818, 27 | 166.66666666666666, 28 | 188.67924528301887, 29 | 192.30769230769232, 30 | 217.3913043478261, 31 | 227.27272727272725, 32 | 227.27272727272725, 33 | 227.27272727272725, 34 | 303.03030303030306, 35 | 322.5806451612903, 36 | 333.3333333333333, 37 | 370.3703703703703, 38 | 384.61538461538464, 39 | 416.6666666666667, 40 | 416.6666666666667, 41 | 434.7826086956522, 42 | 434.7826086956522, 43 | 454.5454545454545, 44 | 454.5454545454545, 45 | 500.0, 46 | 526.3157894736842, 47 | 526.3157894736842, 48 | 555.5555555555555, 49 | 555.5555555555555, 50 | 555.5555555555555, 51 | 555.5555555555555, 52 | 555.5555555555555, 53 | 555.5555555555555, 54 | 555.5555555555555, 55 | 588.2352941176471, 56 | 588.2352941176471, 57 | 588.2352941176471, 58 | 588.2352941176471, 59 | 588.2352941176471, 60 | 666.6666666666666, 61 | 666.6666666666666, 62 | 666.6666666666666, 63 | 666.6666666666666, 64 | 714.2857142857143, 65 | 714.2857142857143, 66 | 714.2857142857143, 67 | 714.2857142857143, 68 | 714.2857142857143, 69 | 769.2307692307693, 70 | 769.2307692307693, 71 | 769.2307692307693, 72 | 833.3333333333334, 73 | 833.3333333333334, 74 | 833.3333333333334, 75 | 833.3333333333334, 76 | 909.090909090909, 77 | 1000.0, 78 | 1111.111111111111, 79 | 1111.111111111111, 80 | 1111.111111111111, 81 | 1111.111111111111, 82 | 1111.111111111111, 83 | 1250.0, 84 | 1250.0, 85 | 1250.0, 86 | 1250.0, 87 | 1250.0, 88 | 1428.5714285714287, 89 | 1428.5714285714287, 90 | 1428.5714285714287, 91 | 1428.5714285714287, 92 | 1428.5714285714287, 93 | 1428.5714285714287, 94 | 1428.5714285714287, 95 | 1666.6666666666667, 96 | 1666.6666666666667, 97 | 1666.6666666666667, 98 | 1666.6666666666667, 99 | 1666.6666666666667, 100 | 1666.6666666666667, 101 | 1666.6666666666667, 102 | 1666.6666666666667, 103 | 
1666.6666666666667, 104 | 1666.6666666666667, 105 | 1666.6666666666667, 106 | 2000.0, 107 | 2000.0, 108 | 2000.0, 109 | 2000.0, 110 | 2000.0, 111 | 2000.0, 112 | 2000.0, 113 | 2000.0, 114 | 2000.0, 115 | 2000.0, 116 | 2000.0, 117 | 2000.0, 118 | 2000.0, 119 | 2000.0, 120 | 2000.0, 121 | 2000.0, 122 | 2000.0, 123 | 2500.0, 124 | 2500.0, 125 | 2500.0, 126 | 2500.0, 127 | 2500.0, 128 | 2500.0, 129 | 2500.0, 130 | 2500.0, 131 | 2500.0, 132 | 2500.0, 133 | 2500.0, 134 | 2500.0, 135 | 2500.0, 136 | 3333.3333333333335, 137 | 3333.3333333333335, 138 | 3333.3333333333335, 139 | 3333.3333333333335, 140 | 3333.3333333333335, 141 | 3333.3333333333335, 142 | 3333.3333333333335, 143 | 3333.3333333333335, 144 | 3333.3333333333335, 145 | 3333.3333333333335, 146 | 3333.3333333333335, 147 | 3333.3333333333335, 148 | 3333.3333333333335, 149 | 5000.0, 150 | 5000.0, 151 | 5000.0] 152 | } -------------------------------------------------------------------------------- /saicinpainting/training/losses/feature_matching.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | 7 | def masked_l2_loss(pred, target, mask, weight_known, weight_missing): 8 | per_pixel_l2 = F.mse_loss(pred, target, reduction='none') 9 | pixel_weights = mask * weight_missing + (1 - mask) * weight_known 10 | return (pixel_weights * per_pixel_l2).mean() 11 | 12 | 13 | def masked_l1_loss(pred, target, mask, weight_known, weight_missing): 14 | per_pixel_l1 = F.l1_loss(pred, target, reduction='none') 15 | pixel_weights = mask * weight_missing + (1 - mask) * weight_known 16 | return (pixel_weights * per_pixel_l1).mean() 17 | 18 | 19 | def feature_matching_loss(fake_features: List[torch.Tensor], target_features: List[torch.Tensor], mask=None): 20 | if mask is None: 21 | res = torch.stack([F.mse_loss(fake_feat, target_feat) 22 | for fake_feat, target_feat in zip(fake_features, target_features)]).mean() 23 | else: 24 | res = 0 25 | norm = 0 26 | for fake_feat, target_feat in zip(fake_features, target_features): 27 | cur_mask = F.interpolate(mask, size=fake_feat.shape[-2:], mode='bilinear', align_corners=False) 28 | error_weights = 1 - cur_mask 29 | cur_val = ((fake_feat - target_feat).pow(2) * error_weights).mean() 30 | res = res + cur_val 31 | norm += 1 32 | res = res / norm 33 | return res 34 | -------------------------------------------------------------------------------- /saicinpainting/training/losses/perceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torchvision 5 | 6 | from models.ade20k import ModelBuilder 7 | from saicinpainting.utils import check_and_warn_input_range 8 | 9 | 10 | IMAGENET_MEAN = torch.FloatTensor([0.485, 0.456, 0.406])[None, :, None, None] 11 | IMAGENET_STD = torch.FloatTensor([0.229, 0.224, 0.225])[None, :, None, None] 12 | 13 | 14 | class PerceptualLoss(nn.Module): 15 | def __init__(self, normalize_inputs=True): 16 | super(PerceptualLoss, self).__init__() 17 | 18 | self.normalize_inputs = normalize_inputs 19 | self.mean_ = IMAGENET_MEAN 20 | self.std_ = IMAGENET_STD 21 | 22 | vgg = torchvision.models.vgg19(pretrained=True).features 23 | vgg_avg_pooling = [] 24 | 25 | for weights in vgg.parameters(): 26 | weights.requires_grad = False 27 | 28 | for module in vgg.modules(): 29 | if module.__class__.__name__ == 'Sequential': 30 | continue 31 | elif 
module.__class__.__name__ == 'MaxPool2d': 32 | vgg_avg_pooling.append(nn.AvgPool2d(kernel_size=2, stride=2, padding=0)) 33 | else: 34 | vgg_avg_pooling.append(module) 35 | 36 | self.vgg = nn.Sequential(*vgg_avg_pooling) 37 | 38 | def do_normalize_inputs(self, x): 39 | return (x - self.mean_.to(x.device)) / self.std_.to(x.device) 40 | 41 | def partial_losses(self, input, target, mask=None): 42 | check_and_warn_input_range(target, 0, 1, 'PerceptualLoss target in partial_losses') 43 | 44 | # we expect input and target to be in [0, 1] range 45 | losses = [] 46 | 47 | if self.normalize_inputs: 48 | features_input = self.do_normalize_inputs(input) 49 | features_target = self.do_normalize_inputs(target) 50 | else: 51 | features_input = input 52 | features_target = target 53 | 54 | for layer in self.vgg[:30]: 55 | 56 | features_input = layer(features_input) 57 | features_target = layer(features_target) 58 | 59 | if layer.__class__.__name__ == 'ReLU': 60 | loss = F.mse_loss(features_input, features_target, reduction='none') 61 | 62 | if mask is not None: 63 | cur_mask = F.interpolate(mask, size=features_input.shape[-2:], 64 | mode='bilinear', align_corners=False) 65 | loss = loss * (1 - cur_mask) 66 | 67 | loss = loss.mean(dim=tuple(range(1, len(loss.shape)))) 68 | losses.append(loss) 69 | 70 | return losses 71 | 72 | def forward(self, input, target, mask=None): 73 | losses = self.partial_losses(input, target, mask=mask) 74 | return torch.stack(losses).sum(dim=0) 75 | 76 | def get_global_features(self, input): 77 | check_and_warn_input_range(input, 0, 1, 'PerceptualLoss input in get_global_features') 78 | 79 | if self.normalize_inputs: 80 | features_input = self.do_normalize_inputs(input) 81 | else: 82 | features_input = input 83 | 84 | features_input = self.vgg(features_input) 85 | return features_input 86 | 87 | 88 | class ResNetPL(nn.Module): 89 | def __init__(self, weight=1, 90 | weights_path=None, arch_encoder='resnet50dilated', segmentation=True): 91 | super().__init__() 92 | self.impl = ModelBuilder.get_encoder(weights_path=weights_path, 93 | arch_encoder=arch_encoder, 94 | arch_decoder='ppm_deepsup', 95 | fc_dim=2048, 96 | segmentation=segmentation) 97 | self.impl.eval() 98 | for w in self.impl.parameters(): 99 | w.requires_grad_(False) 100 | 101 | self.weight = weight 102 | 103 | def forward(self, pred, target): 104 | pred = (pred - IMAGENET_MEAN.to(pred)) / IMAGENET_STD.to(pred) 105 | target = (target - IMAGENET_MEAN.to(target)) / IMAGENET_STD.to(target) 106 | 107 | pred_feats = self.impl(pred, return_feature_maps=True) 108 | target_feats = self.impl(target, return_feature_maps=True) 109 | 110 | result = torch.stack([F.mse_loss(cur_pred, cur_target) 111 | for cur_pred, cur_target 112 | in zip(pred_feats, target_feats)]).sum() * self.weight 113 | return result 114 | -------------------------------------------------------------------------------- /saicinpainting/training/losses/segmentation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .constants import weights as constant_weights 6 | 7 | 8 | class CrossEntropy2d(nn.Module): 9 | def __init__(self, reduction="mean", ignore_label=255, weights=None, *args, **kwargs): 10 | """ 11 | weight (Tensor, optional): a manual rescaling weight given to each class. 
12 | If given, has to be a Tensor of size "nclasses" 13 | """ 14 | super(CrossEntropy2d, self).__init__() 15 | self.reduction = reduction 16 | self.ignore_label = ignore_label 17 | self.weights = weights 18 | if self.weights is not None: 19 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 20 | self.weights = torch.FloatTensor(constant_weights[weights]).to(device) 21 | 22 | def forward(self, predict, target): 23 | """ 24 | Args: 25 | predict:(n, c, h, w) 26 | target:(n, 1, h, w) 27 | """ 28 | target = target.long() 29 | assert not target.requires_grad 30 | assert predict.dim() == 4, "{0}".format(predict.size()) 31 | assert target.dim() == 4, "{0}".format(target.size()) 32 | assert predict.size(0) == target.size(0), "{0} vs {1} ".format(predict.size(0), target.size(0)) 33 | assert target.size(1) == 1, "{0}".format(target.size(1)) 34 | assert predict.size(2) == target.size(2), "{0} vs {1} ".format(predict.size(2), target.size(2)) 35 | assert predict.size(3) == target.size(3), "{0} vs {1} ".format(predict.size(3), target.size(3)) 36 | target = target.squeeze(1) 37 | n, c, h, w = predict.size() 38 | target_mask = (target >= 0) * (target != self.ignore_label) 39 | target = target[target_mask] 40 | predict = predict.transpose(1, 2).transpose(2, 3).contiguous() 41 | predict = predict[target_mask.view(n, h, w, 1).repeat(1, 1, 1, c)].view(-1, c) 42 | loss = F.cross_entropy(predict, target, weight=self.weights, reduction=self.reduction) 43 | return loss 44 | -------------------------------------------------------------------------------- /saicinpainting/training/losses/style_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision.models as models 4 | 5 | 6 | class PerceptualLoss(nn.Module): 7 | r""" 8 | Perceptual loss, VGG-based 9 | https://arxiv.org/abs/1603.08155 10 | https://github.com/dxyang/StyleTransfer/blob/master/utils.py 11 | """ 12 | 13 | def __init__(self, weights=[1.0, 1.0, 1.0, 1.0, 1.0]): 14 | super(PerceptualLoss, self).__init__() 15 | self.add_module('vgg', VGG19()) 16 | self.criterion = torch.nn.L1Loss() 17 | self.weights = weights 18 | 19 | def __call__(self, x, y): 20 | # Compute features 21 | x_vgg, y_vgg = self.vgg(x), self.vgg(y) 22 | 23 | content_loss = 0.0 24 | content_loss += self.weights[0] * self.criterion(x_vgg['relu1_1'], y_vgg['relu1_1']) 25 | content_loss += self.weights[1] * self.criterion(x_vgg['relu2_1'], y_vgg['relu2_1']) 26 | content_loss += self.weights[2] * self.criterion(x_vgg['relu3_1'], y_vgg['relu3_1']) 27 | content_loss += self.weights[3] * self.criterion(x_vgg['relu4_1'], y_vgg['relu4_1']) 28 | content_loss += self.weights[4] * self.criterion(x_vgg['relu5_1'], y_vgg['relu5_1']) 29 | 30 | 31 | return content_loss 32 | 33 | 34 | class VGG19(torch.nn.Module): 35 | def __init__(self): 36 | super(VGG19, self).__init__() 37 | features = models.vgg19(pretrained=True).features 38 | self.relu1_1 = torch.nn.Sequential() 39 | self.relu1_2 = torch.nn.Sequential() 40 | 41 | self.relu2_1 = torch.nn.Sequential() 42 | self.relu2_2 = torch.nn.Sequential() 43 | 44 | self.relu3_1 = torch.nn.Sequential() 45 | self.relu3_2 = torch.nn.Sequential() 46 | self.relu3_3 = torch.nn.Sequential() 47 | self.relu3_4 = torch.nn.Sequential() 48 | 49 | self.relu4_1 = torch.nn.Sequential() 50 | self.relu4_2 = torch.nn.Sequential() 51 | self.relu4_3 = torch.nn.Sequential() 52 | self.relu4_4 = torch.nn.Sequential() 53 | 54 | self.relu5_1 = torch.nn.Sequential() 55 | 
self.relu5_2 = torch.nn.Sequential() 56 | self.relu5_3 = torch.nn.Sequential() 57 | self.relu5_4 = torch.nn.Sequential() 58 | 59 | for x in range(2): 60 | self.relu1_1.add_module(str(x), features[x]) 61 | 62 | for x in range(2, 4): 63 | self.relu1_2.add_module(str(x), features[x]) 64 | 65 | for x in range(4, 7): 66 | self.relu2_1.add_module(str(x), features[x]) 67 | 68 | for x in range(7, 9): 69 | self.relu2_2.add_module(str(x), features[x]) 70 | 71 | for x in range(9, 12): 72 | self.relu3_1.add_module(str(x), features[x]) 73 | 74 | for x in range(12, 14): 75 | self.relu3_2.add_module(str(x), features[x]) 76 | 77 | for x in range(14, 16): 78 | self.relu3_3.add_module(str(x), features[x]) 79 | 80 | for x in range(16, 18): 81 | self.relu3_4.add_module(str(x), features[x]) 82 | 83 | for x in range(18, 21): 84 | self.relu4_1.add_module(str(x), features[x]) 85 | 86 | for x in range(21, 23): 87 | self.relu4_2.add_module(str(x), features[x]) 88 | 89 | for x in range(23, 25): 90 | self.relu4_3.add_module(str(x), features[x]) 91 | 92 | for x in range(25, 27): 93 | self.relu4_4.add_module(str(x), features[x]) 94 | 95 | for x in range(27, 30): 96 | self.relu5_1.add_module(str(x), features[x]) 97 | 98 | for x in range(30, 32): 99 | self.relu5_2.add_module(str(x), features[x]) 100 | 101 | for x in range(32, 34): 102 | self.relu5_3.add_module(str(x), features[x]) 103 | 104 | for x in range(34, 36): 105 | self.relu5_4.add_module(str(x), features[x]) 106 | 107 | # don't need the gradients, just want the features 108 | for param in self.parameters(): 109 | param.requires_grad = False 110 | 111 | def forward(self, x): 112 | relu1_1 = self.relu1_1(x) 113 | relu1_2 = self.relu1_2(relu1_1) 114 | 115 | relu2_1 = self.relu2_1(relu1_2) 116 | relu2_2 = self.relu2_2(relu2_1) 117 | 118 | relu3_1 = self.relu3_1(relu2_2) 119 | relu3_2 = self.relu3_2(relu3_1) 120 | relu3_3 = self.relu3_3(relu3_2) 121 | relu3_4 = self.relu3_4(relu3_3) 122 | 123 | relu4_1 = self.relu4_1(relu3_4) 124 | relu4_2 = self.relu4_2(relu4_1) 125 | relu4_3 = self.relu4_3(relu4_2) 126 | relu4_4 = self.relu4_4(relu4_3) 127 | 128 | relu5_1 = self.relu5_1(relu4_4) 129 | relu5_2 = self.relu5_2(relu5_1) 130 | relu5_3 = self.relu5_3(relu5_2) 131 | relu5_4 = self.relu5_4(relu5_3) 132 | 133 | out = { 134 | 'relu1_1': relu1_1, 135 | 'relu1_2': relu1_2, 136 | 137 | 'relu2_1': relu2_1, 138 | 'relu2_2': relu2_2, 139 | 140 | 'relu3_1': relu3_1, 141 | 'relu3_2': relu3_2, 142 | 'relu3_3': relu3_3, 143 | 'relu3_4': relu3_4, 144 | 145 | 'relu4_1': relu4_1, 146 | 'relu4_2': relu4_2, 147 | 'relu4_3': relu4_3, 148 | 'relu4_4': relu4_4, 149 | 150 | 'relu5_1': relu5_1, 151 | 'relu5_2': relu5_2, 152 | 'relu5_3': relu5_3, 153 | 'relu5_4': relu5_4, 154 | } 155 | return out 156 | -------------------------------------------------------------------------------- /saicinpainting/training/modules/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from saicinpainting.training.modules.ffc import FFCResNetGenerator 4 | from saicinpainting.training.modules.pix2pixhd import GlobalGenerator, MultiDilatedGlobalGenerator, \ 5 | NLayerDiscriminator, MultidilatedNLayerDiscriminator 6 | 7 | def make_generator(config, kind, **kwargs): 8 | logging.info(f'Make generator {kind}') 9 | 10 | if kind == 'pix2pixhd_multidilated': 11 | return MultiDilatedGlobalGenerator(**kwargs) 12 | 13 | if kind == 'pix2pixhd_global': 14 | return GlobalGenerator(**kwargs) 15 | 16 | if kind == 'ffc_resnet': 17 | return
FFCResNetGenerator(**kwargs) 18 | 19 | raise ValueError(f'Unknown generator kind {kind}') 20 | 21 | 22 | def make_discriminator(kind, **kwargs): 23 | logging.info(f'Make discriminator {kind}') 24 | 25 | if kind == 'pix2pixhd_nlayer_multidilated': 26 | return MultidilatedNLayerDiscriminator(**kwargs) 27 | 28 | if kind == 'pix2pixhd_nlayer': 29 | return NLayerDiscriminator(**kwargs) 30 | 31 | raise ValueError(f'Unknown discriminator kind {kind}') 32 | -------------------------------------------------------------------------------- /saicinpainting/training/modules/base.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from typing import Tuple, List 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | from saicinpainting.training.modules.depthwise_sep_conv import DepthWiseSeperableConv 8 | from saicinpainting.training.modules.multidilated_conv import MultidilatedConv 9 | 10 | 11 | class BaseDiscriminator(nn.Module): 12 | @abc.abstractmethod 13 | def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]: 14 | """ 15 | Predict scores and get intermediate activations. Useful for feature matching loss 16 | :return tuple (scores, list of intermediate activations) 17 | """ 18 | raise NotImplementedError() 19 | 20 | 21 | def get_conv_block_ctor(kind='default'): 22 | if not isinstance(kind, str): 23 | return kind 24 | if kind == 'default': 25 | return nn.Conv2d 26 | if kind == 'depthwise': 27 | return DepthWiseSeperableConv 28 | if kind == 'multidilated': 29 | return MultidilatedConv 30 | raise ValueError(f'Unknown convolutional block kind {kind}') 31 | 32 | 33 | def get_norm_layer(kind='bn'): 34 | if not isinstance(kind, str): 35 | return kind 36 | if kind == 'bn': 37 | return nn.BatchNorm2d 38 | if kind == 'in': 39 | return nn.InstanceNorm2d 40 | raise ValueError(f'Unknown norm block kind {kind}') 41 | 42 | 43 | def get_activation(kind='tanh'): 44 | if kind == 'tanh': 45 | return nn.Tanh() 46 | if kind == 'sigmoid': 47 | return nn.Sigmoid() 48 | if kind is False: 49 | return nn.Identity() 50 | raise ValueError(f'Unknown activation kind {kind}') 51 | 52 | 53 | class SimpleMultiStepGenerator(nn.Module): 54 | def __init__(self, steps: List[nn.Module]): 55 | super().__init__() 56 | self.steps = nn.ModuleList(steps) 57 | 58 | def forward(self, x): 59 | cur_in = x 60 | outs = [] 61 | for step in self.steps: 62 | cur_out = step(cur_in) 63 | outs.append(cur_out) 64 | cur_in = torch.cat((cur_in, cur_out), dim=1) 65 | return torch.cat(outs[::-1], dim=1) 66 | 67 | def deconv_factory(kind, ngf, mult, norm_layer, activation, max_features): 68 | if kind == 'convtranspose': 69 | return [nn.ConvTranspose2d(min(max_features, ngf * mult), 70 | min(max_features, int(ngf * mult / 2)), 71 | kernel_size=3, stride=2, padding=1, output_padding=1), 72 | norm_layer(min(max_features, int(ngf * mult / 2))), activation] 73 | elif kind == 'bilinear': 74 | return [nn.Upsample(scale_factor=2, mode='bilinear'), 75 | DepthWiseSeperableConv(min(max_features, ngf * mult), 76 | min(max_features, int(ngf * mult / 2)), 77 | kernel_size=3, stride=1, padding=1), 78 | norm_layer(min(max_features, int(ngf * mult / 2))), activation] 79 | else: 80 | raise Exception(f"Invalid deconv kind: {kind}") -------------------------------------------------------------------------------- /saicinpainting/training/modules/depthwise_sep_conv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class
DepthWiseSeperableConv(nn.Module): 5 | def __init__(self, in_dim, out_dim, *args, **kwargs): 6 | super().__init__() 7 | if 'groups' in kwargs: 8 | # ignoring groups for Depthwise Sep Conv 9 | del kwargs['groups'] 10 | 11 | self.depthwise = nn.Conv2d(in_dim, in_dim, *args, groups=in_dim, **kwargs) 12 | self.pointwise = nn.Conv2d(in_dim, out_dim, kernel_size=1) 13 | 14 | def forward(self, x): 15 | out = self.depthwise(x) 16 | out = self.pointwise(out) 17 | return out -------------------------------------------------------------------------------- /saicinpainting/training/modules/fake_fakes.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from kornia import SamplePadding 3 | from kornia.augmentation import RandomAffine, CenterCrop 4 | 5 | 6 | class FakeFakesGenerator: 7 | def __init__(self, aug_proba=0.5, img_aug_degree=30, img_aug_translate=0.2): 8 | self.grad_aug = RandomAffine(degrees=360, 9 | translate=0.2, 10 | padding_mode=SamplePadding.REFLECTION, 11 | keepdim=False, 12 | p=1) 13 | self.img_aug = RandomAffine(degrees=img_aug_degree, 14 | translate=img_aug_translate, 15 | padding_mode=SamplePadding.REFLECTION, 16 | keepdim=True, 17 | p=1) 18 | self.aug_proba = aug_proba 19 | 20 | def __call__(self, input_images, masks): 21 | blend_masks = self._fill_masks_with_gradient(masks) 22 | blend_target = self._make_blend_target(input_images) 23 | result = input_images * (1 - blend_masks) + blend_target * blend_masks 24 | return result, blend_masks 25 | 26 | def _make_blend_target(self, input_images): 27 | batch_size = input_images.shape[0] 28 | permuted = input_images[torch.randperm(batch_size)] 29 | augmented = self.img_aug(input_images) 30 | is_aug = (torch.rand(batch_size, device=input_images.device)[:, None, None, None] < self.aug_proba).float() 31 | result = augmented * is_aug + permuted * (1 - is_aug) 32 | return result 33 | 34 | def _fill_masks_with_gradient(self, masks): 35 | batch_size, _, height, width = masks.shape 36 | grad = torch.linspace(0, 1, steps=width * 2, device=masks.device, dtype=masks.dtype) \ 37 | .view(1, 1, 1, -1).expand(batch_size, 1, height * 2, width * 2) 38 | grad = self.grad_aug(grad) 39 | grad = CenterCrop((height, width))(grad) 40 | grad *= masks 41 | 42 | grad_for_min = grad + (1 - masks) * 10 43 | grad -= grad_for_min.view(batch_size, -1).min(-1).values[:, None, None, None] 44 | grad /= grad.view(batch_size, -1).max(-1).values[:, None, None, None] + 1e-6 45 | grad.clamp_(min=0, max=1) 46 | 47 | return grad 48 | -------------------------------------------------------------------------------- /saicinpainting/training/modules/multidilated_conv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import random 4 | from saicinpainting.training.modules.depthwise_sep_conv import DepthWiseSeperableConv 5 | 6 | class MultidilatedConv(nn.Module): 7 | def __init__(self, in_dim, out_dim, kernel_size, dilation_num=3, comb_mode='sum', equal_dim=True, 8 | shared_weights=False, padding=1, min_dilation=1, shuffle_in_channels=False, use_depthwise=False, **kwargs): 9 | super().__init__() 10 | convs = [] 11 | self.equal_dim = equal_dim 12 | assert comb_mode in ('cat_out', 'sum', 'cat_in', 'cat_both'), comb_mode 13 | if comb_mode in ('cat_out', 'cat_both'): 14 | self.cat_out = True 15 | if equal_dim: 16 | assert out_dim % dilation_num == 0 17 | out_dims = [out_dim // dilation_num] * dilation_num 18 | self.index = sum([[i + j * (out_dims[0]) for j 
in range(dilation_num)] for i in range(out_dims[0])], []) 19 | else: 20 | out_dims = [out_dim // 2 ** (i + 1) for i in range(dilation_num - 1)] 21 | out_dims.append(out_dim - sum(out_dims)) 22 | index = [] 23 | starts = [0] + out_dims[:-1] 24 | lengths = [out_dims[i] // out_dims[-1] for i in range(dilation_num)] 25 | for i in range(out_dims[-1]): 26 | for j in range(dilation_num): 27 | index += list(range(starts[j], starts[j] + lengths[j])) 28 | starts[j] += lengths[j] 29 | self.index = index 30 | assert(len(index) == out_dim) 31 | self.out_dims = out_dims 32 | else: 33 | self.cat_out = False 34 | self.out_dims = [out_dim] * dilation_num 35 | 36 | if comb_mode in ('cat_in', 'cat_both'): 37 | if equal_dim: 38 | assert in_dim % dilation_num == 0 39 | in_dims = [in_dim // dilation_num] * dilation_num 40 | else: 41 | in_dims = [in_dim // 2 ** (i + 1) for i in range(dilation_num - 1)] 42 | in_dims.append(in_dim - sum(in_dims)) 43 | self.in_dims = in_dims 44 | self.cat_in = True 45 | else: 46 | self.cat_in = False 47 | self.in_dims = [in_dim] * dilation_num 48 | 49 | conv_type = DepthWiseSeperableConv if use_depthwise else nn.Conv2d 50 | dilation = min_dilation 51 | for i in range(dilation_num): 52 | if isinstance(padding, int): 53 | cur_padding = padding * dilation 54 | else: 55 | cur_padding = padding[i] 56 | convs.append(conv_type( 57 | self.in_dims[i], self.out_dims[i], kernel_size, padding=cur_padding, dilation=dilation, **kwargs 58 | )) 59 | if i > 0 and shared_weights: 60 | convs[-1].weight = convs[0].weight 61 | convs[-1].bias = convs[0].bias 62 | dilation *= 2 63 | self.convs = nn.ModuleList(convs) 64 | 65 | self.shuffle_in_channels = shuffle_in_channels 66 | if self.shuffle_in_channels: 67 | # shuffle list as shuffling of tensors is nondeterministic 68 | in_channels_permute = list(range(in_dim)) 69 | random.shuffle(in_channels_permute) 70 | # save as buffer so it is saved and loaded with checkpoint 71 | self.register_buffer('in_channels_permute', torch.tensor(in_channels_permute)) 72 | 73 | def forward(self, x): 74 | if self.shuffle_in_channels: 75 | x = x[:, self.in_channels_permute] 76 | 77 | outs = [] 78 | if self.cat_in: 79 | if self.equal_dim: 80 | x = x.chunk(len(self.convs), dim=1) 81 | else: 82 | new_x = [] 83 | start = 0 84 | for dim in self.in_dims: 85 | new_x.append(x[:, start:start+dim]) 86 | start += dim 87 | x = new_x 88 | for i, conv in enumerate(self.convs): 89 | if self.cat_in: 90 | input = x[i] 91 | else: 92 | input = x 93 | outs.append(conv(input)) 94 | if self.cat_out: 95 | out = torch.cat(outs, dim=1)[:, self.index] 96 | else: 97 | out = sum(outs) 98 | return out 99 | -------------------------------------------------------------------------------- /saicinpainting/training/modules/spatial_transform.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from kornia.geometry.transform import rotate 5 | 6 | 7 | class LearnableSpatialTransformWrapper(nn.Module): 8 | def __init__(self, impl, pad_coef=0.5, angle_init_range=80, train_angle=True): 9 | super().__init__() 10 | self.impl = impl 11 | self.angle = torch.rand(1) * angle_init_range 12 | if train_angle: 13 | self.angle = nn.Parameter(self.angle, requires_grad=True) 14 | self.pad_coef = pad_coef 15 | 16 | def forward(self, x): 17 | if torch.is_tensor(x): 18 | return self.inverse_transform(self.impl(self.transform(x)), x) 19 | elif isinstance(x, tuple): 20 | x_trans = tuple(self.transform(elem) for elem in 
x) 21 | y_trans = self.impl(x_trans) 22 | return tuple(self.inverse_transform(elem, orig_x) for elem, orig_x in zip(y_trans, x)) 23 | else: 24 | raise ValueError(f'Unexpected input type {type(x)}') 25 | 26 | def transform(self, x): 27 | height, width = x.shape[2:] 28 | pad_h, pad_w = int(height * self.pad_coef), int(width * self.pad_coef) 29 | x_padded = F.pad(x, [pad_w, pad_w, pad_h, pad_h], mode='reflect') 30 | x_padded_rotated = rotate(x_padded, angle=self.angle.to(x_padded)) 31 | return x_padded_rotated 32 | 33 | def inverse_transform(self, y_padded_rotated, orig_x): 34 | height, width = orig_x.shape[2:] 35 | pad_h, pad_w = int(height * self.pad_coef), int(width * self.pad_coef) 36 | 37 | y_padded = rotate(y_padded_rotated, angle=-self.angle.to(y_padded_rotated)) 38 | y_height, y_width = y_padded.shape[2:] 39 | y = y_padded[:, :, pad_h : y_height - pad_h, pad_w : y_width - pad_w] 40 | return y 41 | 42 | 43 | if __name__ == '__main__': 44 | layer = LearnableSpatialTransformWrapper(nn.Identity()) 45 | x = torch.arange(2* 3 * 15 * 15).view(2, 3, 15, 15).float() 46 | y = layer(x) 47 | assert x.shape == y.shape 48 | assert torch.allclose(x[:, :, 1:, 1:][:, :, :-1, :-1], y[:, :, 1:, 1:][:, :, :-1, :-1]) 49 | print('all ok') 50 | -------------------------------------------------------------------------------- /saicinpainting/training/modules/squeeze_excitation.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class SELayer(nn.Module): 5 | def __init__(self, channel, reduction=16): 6 | super(SELayer, self).__init__() 7 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 8 | self.fc = nn.Sequential( 9 | nn.Linear(channel, channel // reduction, bias=False), 10 | nn.ReLU(inplace=True), 11 | nn.Linear(channel // reduction, channel, bias=False), 12 | nn.Sigmoid() 13 | ) 14 | 15 | def forward(self, x): 16 | b, c, _, _ = x.size() 17 | y = self.avg_pool(x).view(b, c) 18 | y = self.fc(y).view(b, c, 1, 1) 19 | res = x * y.expand_as(x) 20 | return res 21 | -------------------------------------------------------------------------------- /saicinpainting/training/trainers/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch 3 | from saicinpainting.training.trainers.default import DefaultInpaintingTrainingModule 4 | 5 | 6 | def get_training_model_class(kind): 7 | if kind == 'default': 8 | return DefaultInpaintingTrainingModule 9 | 10 | raise ValueError(f'Unknown trainer module {kind}') 11 | 12 | 13 | def make_training_model(config): 14 | kind = config.training_model.kind 15 | kwargs = dict(config.training_model) 16 | kwargs.pop('kind') 17 | kwargs['use_ddp'] = config.trainer.kwargs.get('accelerator', None) == 'ddp' 18 | 19 | logging.info(f'Make training model {kind}') 20 | 21 | cls = get_training_model_class(kind) 22 | return cls(config, **kwargs) 23 | 24 | 25 | def load_checkpoint(train_config, path, map_location='cuda', strict=True): 26 | model: torch.nn.Module = make_training_model(train_config) 27 | state = torch.load(path, map_location=map_location) 28 | model.load_state_dict(state['state_dict'], strict=strict) 29 | model.on_load_checkpoint(state) 30 | return model 31 | -------------------------------------------------------------------------------- /saicinpainting/training/visualizers/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from saicinpainting.training.visualizers.directory import 
DirectoryVisualizer 4 | from saicinpainting.training.visualizers.noop import NoopVisualizer 5 | 6 | 7 | def make_visualizer(kind, **kwargs): 8 | logging.info(f'Make visualizer {kind}') 9 | 10 | if kind == 'directory': 11 | return DirectoryVisualizer(**kwargs) 12 | if kind == 'noop': 13 | return NoopVisualizer() 14 | 15 | raise ValueError(f'Unknown visualizer kind {kind}') 16 | -------------------------------------------------------------------------------- /saicinpainting/training/visualizers/base.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from typing import Dict, List 3 | 4 | import numpy as np 5 | import torch 6 | from skimage import color 7 | from skimage.segmentation import mark_boundaries 8 | 9 | from . import colors 10 | 11 | COLORS, _ = colors.generate_colors(151) # 151 - max classes for semantic segmentation 12 | 13 | 14 | class BaseVisualizer: 15 | @abc.abstractmethod 16 | def __call__(self, epoch_i, batch_i, batch, suffix='', rank=None): 17 | """ 18 | Take a batch, make an image from it and visualize 19 | """ 20 | raise NotImplementedError() 21 | 22 | 23 | def visualize_mask_and_images(images_dict: Dict[str, np.ndarray], keys: List[str], 24 | last_without_mask=True, rescale_keys=None, mask_only_first=None, 25 | black_mask=False) -> np.ndarray: 26 | mask = images_dict['mask'] > 0.5 27 | result = [] 28 | for i, k in enumerate(keys): 29 | img = images_dict[k] 30 | img = np.transpose(img, (1, 2, 0)) 31 | 32 | if rescale_keys is not None and k in rescale_keys: 33 | img = img - img.min() 34 | img /= img.max() + 1e-5 35 | if len(img.shape) == 2: 36 | img = np.expand_dims(img, 2) 37 | 38 | if img.shape[2] == 1: 39 | img = np.repeat(img, 3, axis=2) 40 | elif (img.shape[2] > 3): 41 | img_classes = img.argmax(2) 42 | img = color.label2rgb(img_classes, colors=COLORS) 43 | 44 | if mask_only_first: 45 | need_mark_boundaries = i == 0 46 | else: 47 | need_mark_boundaries = i < len(keys) - 1 or not last_without_mask 48 | 49 | if need_mark_boundaries: 50 | if black_mask: 51 | img = img * (1 - mask[0][..., None]) 52 | img = mark_boundaries(img, 53 | mask[0], 54 | color=(1., 0., 0.), 55 | outline_color=(1., 1., 1.), 56 | mode='thick') 57 | result.append(img) 58 | return np.concatenate(result, axis=1) 59 | 60 | 61 | def visualize_mask_and_images_batch(batch: Dict[str, torch.Tensor], keys: List[str], max_items=10, 62 | last_without_mask=True, rescale_keys=None) -> np.ndarray: 63 | batch = {k: tens.detach().cpu().numpy() for k, tens in batch.items() 64 | if k in keys or k == 'mask'} 65 | 66 | batch_size = next(iter(batch.values())).shape[0] 67 | items_to_vis = min(batch_size, max_items) 68 | result = [] 69 | for i in range(items_to_vis): 70 | cur_dct = {k: tens[i] for k, tens in batch.items()} 71 | result.append(visualize_mask_and_images(cur_dct, keys, last_without_mask=last_without_mask, 72 | rescale_keys=rescale_keys)) 73 | return np.concatenate(result, axis=0) 74 | -------------------------------------------------------------------------------- /saicinpainting/training/visualizers/colors.py: -------------------------------------------------------------------------------- 1 | import random 2 | import colorsys 3 | 4 | import numpy as np 5 | import matplotlib 6 | matplotlib.use('agg') 7 | import matplotlib.pyplot as plt 8 | from matplotlib.colors import LinearSegmentedColormap 9 | 10 | 11 | def generate_colors(nlabels, type='bright', first_color_black=False, last_color_black=True, verbose=False): 12 | # 
https://stackoverflow.com/questions/14720331/how-to-generate-random-colors-in-matplotlib 13 | """ 14 | Creates a random colormap to be used together with matplotlib. Useful for segmentation tasks 15 | :param nlabels: Number of labels (size of colormap) 16 | :param type: 'bright' for strong colors, 'soft' for pastel colors 17 | :param first_color_black: Option to use first color as black, True or False 18 | :param last_color_black: Option to use last color as black, True or False 19 | :param verbose: Prints the number of labels and shows the colormap. True or False 20 | :return: colormap for matplotlib 21 | """ 22 | if type not in ('bright', 'soft'): 23 | print ('Please choose "bright" or "soft" for type') 24 | return 25 | 26 | if verbose: 27 | print('Number of labels: ' + str(nlabels)) 28 | 29 | # Generate color map for bright colors, based on hsv 30 | if type == 'bright': 31 | randHSVcolors = [(np.random.uniform(low=0.0, high=1), 32 | np.random.uniform(low=0.2, high=1), 33 | np.random.uniform(low=0.9, high=1)) for i in range(nlabels)] 34 | 35 | # Convert HSV list to RGB 36 | randRGBcolors = [] 37 | for HSVcolor in randHSVcolors: 38 | randRGBcolors.append(colorsys.hsv_to_rgb(HSVcolor[0], HSVcolor[1], HSVcolor[2])) 39 | 40 | if first_color_black: 41 | randRGBcolors[0] = [0, 0, 0] 42 | 43 | if last_color_black: 44 | randRGBcolors[-1] = [0, 0, 0] 45 | 46 | random_colormap = LinearSegmentedColormap.from_list('new_map', randRGBcolors, N=nlabels) 47 | 48 | # Generate soft pastel colors, by limiting the RGB spectrum 49 | if type == 'soft': 50 | low = 0.6 51 | high = 0.95 52 | randRGBcolors = [(np.random.uniform(low=low, high=high), 53 | np.random.uniform(low=low, high=high), 54 | np.random.uniform(low=low, high=high)) for i in range(nlabels)] 55 | 56 | if first_color_black: 57 | randRGBcolors[0] = [0, 0, 0] 58 | 59 | if last_color_black: 60 | randRGBcolors[-1] = [0, 0, 0] 61 | random_colormap = LinearSegmentedColormap.from_list('new_map', randRGBcolors, N=nlabels) 62 | 63 | # Display colorbar 64 | if verbose: 65 | from matplotlib import colors, colorbar 66 | from matplotlib import pyplot as plt 67 | fig, ax = plt.subplots(1, 1, figsize=(15, 0.5)) 68 | 69 | bounds = np.linspace(0, nlabels, nlabels + 1) 70 | norm = colors.BoundaryNorm(bounds, nlabels) 71 | 72 | cb = colorbar.ColorbarBase(ax, cmap=random_colormap, norm=norm, spacing='proportional', ticks=None, 73 | boundaries=bounds, format='%1i', orientation=u'horizontal') 74 | 75 | return randRGBcolors, random_colormap 76 | 77 | -------------------------------------------------------------------------------- /saicinpainting/training/visualizers/directory.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import cv2 4 | import numpy as np 5 | 6 | from saicinpainting.training.visualizers.base import BaseVisualizer, visualize_mask_and_images_batch 7 | from saicinpainting.utils import check_and_warn_input_range 8 | 9 | 10 | class DirectoryVisualizer(BaseVisualizer): 11 | DEFAULT_KEY_ORDER = 'image predicted_image inpainted'.split(' ') 12 | 13 | def __init__(self, outdir, key_order=DEFAULT_KEY_ORDER, max_items_in_batch=10, 14 | last_without_mask=True, rescale_keys=None): 15 | self.outdir = outdir 16 | os.makedirs(self.outdir, exist_ok=True) 17 | self.key_order = key_order 18 | self.max_items_in_batch = max_items_in_batch 19 | self.last_without_mask = last_without_mask 20 | self.rescale_keys = rescale_keys 21 | 22 | def __call__(self, epoch_i, batch_i, batch, suffix='', rank=None): 23 | 
check_and_warn_input_range(batch['image'], 0, 1, 'DirectoryVisualizer target image') 24 | vis_img = visualize_mask_and_images_batch(batch, self.key_order, max_items=self.max_items_in_batch, 25 | last_without_mask=self.last_without_mask, 26 | rescale_keys=self.rescale_keys) 27 | 28 | vis_img = np.clip(vis_img * 255, 0, 255).astype('uint8') 29 | 30 | curoutdir = os.path.join(self.outdir, f'epoch{epoch_i:04d}{suffix}') 31 | os.makedirs(curoutdir, exist_ok=True) 32 | rank_suffix = f'_r{rank}' if rank is not None else '' 33 | out_fname = os.path.join(curoutdir, f'batch{batch_i:07d}{rank_suffix}.jpg') 34 | 35 | vis_img = cv2.cvtColor(vis_img, cv2.COLOR_RGB2BGR) 36 | cv2.imwrite(out_fname, vis_img) 37 | -------------------------------------------------------------------------------- /saicinpainting/training/visualizers/noop.py: -------------------------------------------------------------------------------- 1 | from saicinpainting.training.visualizers.base import BaseVisualizer 2 | 3 | 4 | class NoopVisualizer(BaseVisualizer): 5 | def __init__(self, *args, **kwargs): 6 | pass 7 | 8 | def __call__(self, epoch_i, batch_i, batch, suffix='', rank=None): 9 | pass 10 | -------------------------------------------------------------------------------- /scripts/download_first_stages.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | wget -O models/first_stage_models/kl-f4/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f4.zip 3 | wget -O models/first_stage_models/kl-f8/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f8.zip 4 | wget -O models/first_stage_models/kl-f16/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f16.zip 5 | wget -O models/first_stage_models/kl-f32/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f32.zip 6 | wget -O models/first_stage_models/vq-f4/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f4.zip 7 | wget -O models/first_stage_models/vq-f4-noattn/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f4-noattn.zip 8 | wget -O models/first_stage_models/vq-f8/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f8.zip 9 | wget -O models/first_stage_models/vq-f8-n256/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f8-n256.zip 10 | wget -O models/first_stage_models/vq-f16/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f16.zip 11 | 12 | 13 | 14 | cd models/first_stage_models/kl-f4 15 | unzip -o model.zip 16 | 17 | cd ../kl-f8 18 | unzip -o model.zip 19 | 20 | cd ../kl-f16 21 | unzip -o model.zip 22 | 23 | cd ../kl-f32 24 | unzip -o model.zip 25 | 26 | cd ../vq-f4 27 | unzip -o model.zip 28 | 29 | cd ../vq-f4-noattn 30 | unzip -o model.zip 31 | 32 | cd ../vq-f8 33 | unzip -o model.zip 34 | 35 | cd ../vq-f8-n256 36 | unzip -o model.zip 37 | 38 | cd ../vq-f16 39 | unzip -o model.zip 40 | 41 | cd ../.. 
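# NOTE: this script (and download_models.sh below) appears to assume it is run from the repository root, since the wget -O output paths and the cd commands use paths relative to it.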
-------------------------------------------------------------------------------- /scripts/download_models.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | wget -O models/ldm/celeba256/celeba-256.zip https://ommer-lab.com/files/latent-diffusion/celeba.zip 3 | wget -O models/ldm/ffhq256/ffhq-256.zip https://ommer-lab.com/files/latent-diffusion/ffhq.zip 4 | wget -O models/ldm/lsun_churches256/lsun_churches-256.zip https://ommer-lab.com/files/latent-diffusion/lsun_churches.zip 5 | wget -O models/ldm/lsun_beds256/lsun_beds-256.zip https://ommer-lab.com/files/latent-diffusion/lsun_bedrooms.zip 6 | wget -O models/ldm/text2img256/model.zip https://ommer-lab.com/files/latent-diffusion/text2img.zip 7 | wget -O models/ldm/cin256/model.zip https://ommer-lab.com/files/latent-diffusion/cin.zip 8 | wget -O models/ldm/semantic_synthesis512/model.zip https://ommer-lab.com/files/latent-diffusion/semantic_synthesis.zip 9 | wget -O models/ldm/semantic_synthesis256/model.zip https://ommer-lab.com/files/latent-diffusion/semantic_synthesis256.zip 10 | wget -O models/ldm/bsr_sr/model.zip https://ommer-lab.com/files/latent-diffusion/sr_bsr.zip 11 | wget -O models/ldm/layout2img-openimages256/model.zip https://ommer-lab.com/files/latent-diffusion/layout2img_model.zip 12 | wget -O models/ldm/inpainting_big/model.zip https://ommer-lab.com/files/latent-diffusion/inpainting_big.zip 13 | 14 | 15 | 16 | cd models/ldm/celeba256 17 | unzip -o celeba-256.zip 18 | 19 | cd ../ffhq256 20 | unzip -o ffhq-256.zip 21 | 22 | cd ../lsun_churches256 23 | unzip -o lsun_churches-256.zip 24 | 25 | cd ../lsun_beds256 26 | unzip -o lsun_beds-256.zip 27 | 28 | cd ../text2img256 29 | unzip -o model.zip 30 | 31 | cd ../cin256 32 | unzip -o model.zip 33 | 34 | cd ../semantic_synthesis512 35 | unzip -o model.zip 36 | 37 | cd ../semantic_synthesis256 38 | unzip -o model.zip 39 | 40 | cd ../bsr_sr 41 | unzip -o model.zip 42 | 43 | cd ../layout2img-openimages256 44 | unzip -o model.zip 45 | 46 | cd ../inpainting_big 47 | unzip -o model.zip 48 | 49 | cd ../.. 
50 | -------------------------------------------------------------------------------- /scripts/inpaint.py: -------------------------------------------------------------------------------- 1 | import argparse, os, sys, glob 2 | from omegaconf import OmegaConf 3 | from PIL import Image 4 | from tqdm import tqdm 5 | import numpy as np 6 | import torch 7 | from main import instantiate_from_config 8 | from ldm.models.diffusion.ddim import DDIMSampler 9 | 10 | 11 | def make_batch(image, mask, device): 12 | image = np.array(Image.open(image).convert("RGB")) 13 | image = image.astype(np.float32)/255.0 14 | image = image[None].transpose(0,3,1,2) 15 | image = torch.from_numpy(image) 16 | 17 | mask = np.array(Image.open(mask).convert("L")) 18 | mask = mask.astype(np.float32)/255.0 19 | mask = mask[None,None] 20 | mask[mask < 0.5] = 0 21 | mask[mask >= 0.5] = 1 22 | mask = torch.from_numpy(mask) 23 | 24 | masked_image = (1-mask)*image 25 | 26 | batch = {"image": image, "mask": mask, "masked_image": masked_image} 27 | for k in batch: 28 | batch[k] = batch[k].to(device=device) 29 | batch[k] = batch[k]*2.0-1.0 30 | return batch 31 | 32 | 33 | if __name__ == "__main__": 34 | parser = argparse.ArgumentParser() 35 | parser.add_argument( 36 | "--indir", 37 | type=str, 38 | nargs="?", 39 | help="dir containing image-mask pairs (`example.png` and `example_mask.png`)", 40 | ) 41 | parser.add_argument( 42 | "--outdir", 43 | type=str, 44 | nargs="?", 45 | help="dir to write results to", 46 | ) 47 | parser.add_argument( 48 | "--steps", 49 | type=int, 50 | default=50, 51 | help="number of ddim sampling steps", 52 | ) 53 | parser.add_argument( 54 | "--ckpt", 55 | type=str, 56 | default="/data0/zixin/cluster_results/epoch=000001.ckpt", 57 | help="path to the first-stage model checkpoint to load", 58 | ) 59 | parser.add_argument( 60 | "--config", 61 | type=str, 62 | default="configs/v1_my.yaml", 63 | help="path to the first-stage model config", 64 | ) 65 | opt = parser.parse_args() 66 | 67 | masks = sorted(glob.glob(os.path.join(opt.indir, "*_mask.png"))) 68 | images = [x.replace("_mask.png", ".png") for x in masks] 69 | print(f"Found {len(masks)} inputs.") 70 | 71 | config = OmegaConf.load("models/ldm/inpainting_big/config.yaml") 72 | model = instantiate_from_config(config.model) 73 | 74 | config1 = OmegaConf.load(opt.config) 75 | first_stage_model = load_model_from_config(config1, opt.ckpt) # TODO: load_model_from_config is neither imported nor defined in this script 76 | 77 | model.load_state_dict(torch.load("models/ldm/inpainting_big/last.ckpt")["state_dict"], 78 | strict=False) 79 | 80 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 81 | model = model.to(device) 82 | first_stage_model = first_stage_model.to(device) 83 | sampler = DDIMSampler(model) 84 | 85 | os.makedirs(opt.outdir, exist_ok=True) 86 | with torch.no_grad(): 87 | with model.ema_scope(): 88 | for image, mask in tqdm(zip(images, masks)): 89 | outpath = os.path.join(opt.outdir, os.path.split(image)[1]) 90 | batch = make_batch(image, mask, device=device) 91 | 92 | # encode masked image and concat downsampled mask 93 | c = model.cond_stage_model.encode(batch["masked_image"]) 94 | cc = torch.nn.functional.interpolate(batch["mask"], 95 | size=c.shape[-2:]) 96 | c = torch.cat((c, cc), dim=1) 97 | 98 | shape = (c.shape[1]-1,)+c.shape[2:] 99 | samples_ddim, _ = sampler.sample(S=opt.steps, 100 | conditioning=c, 101 | batch_size=c.shape[0], 102 | shape=shape, 103 | verbose=False) 104 | 105 | mask = torch.clamp((batch["mask"]+1.0)/2.0, 106 | min=0.0, max=1.0) 107 | x_samples_ddim = model.decode_first_stage(samples_ddim, batch["image"], mask)
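# NOTE: unlike the stock latent-diffusion inpainting script, decode_first_stage is called here with the original image and the mask as extra arguments; these feed the conditional branch of the asymmetric first-stage decoder.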
108 | image = torch.clamp((batch["image"]+1.0)/2.0, 109 | min=0.0, max=1.0) 110 | predicted_image = torch.clamp((x_samples_ddim+1.0)/2.0, 111 | min=0.0, max=1.0) 112 | 113 | inpainted = (1-mask)*image+mask*predicted_image 114 | inpainted = inpainted.cpu().numpy().transpose(0,2,3,1)[0]*255 115 | # im = Image.fromarray(inpainted.astype(np.uint8)) 116 | # im.show() 117 | Image.fromarray(inpainted.astype(np.uint8)).save(outpath) 118 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='latent-diffusion', 5 | version='0.0.1', 6 | description='', 7 | packages=find_packages(), 8 | install_requires=[ 9 | 'torch', 10 | 'numpy', 11 | 'tqdm', 12 | ], 13 | ) -------------------------------------------------------------------------------- /taming/modules/autoencoder/lpips/vgg.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/taming/modules/autoencoder/lpips/vgg.pth -------------------------------------------------------------------------------- /teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/teaser.png -------------------------------------------------------------------------------- /text2img_visual.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/text2img_visual.png -------------------------------------------------------------------------------- /visual.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/visual.png --------------------------------------------------------------------------------
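
Usage sketch for the loss helpers in saicinpainting/training/losses/feature_matching.py above. This is a minimal, illustrative example rather than part of the repository: it assumes the repository root is on PYTHONPATH and that PyTorch is installed, and all tensor shapes, thresholds, and weights are arbitrary.

import torch
from saicinpainting.training.losses.feature_matching import masked_l1_loss, feature_matching_loss

# Dummy prediction/target pair and a binary mask (1 = missing region).
pred = torch.rand(2, 3, 64, 64)
target = torch.rand(2, 3, 64, 64)
mask = (torch.rand(2, 1, 64, 64) > 0.7).float()

# Pixels inside the hole are weighted ten times more heavily than known pixels.
l1 = masked_l1_loss(pred, target, mask, weight_known=1.0, weight_missing=10.0)

# Two levels of fake/real "discriminator features"; feature_matching_loss
# bilinearly resizes the mask to each feature resolution internally.
fake_feats = [torch.rand(2, 8, 32, 32), torch.rand(2, 16, 16, 16)]
real_feats = [f + 0.1 for f in fake_feats]
fm = feature_matching_loss(fake_feats, real_feats, mask=mask)

print(l1.item(), fm.item())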