├── .idea
│   ├── .gitignore
│   ├── Asymmetric_VQGAN.iml
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── modules.xml
│   └── vcs.xml
├── LICENSE
├── README.md
├── assets
│   ├── a-painting-of-a-fire.png
│   ├── a-photograph-of-a-fire.png
│   ├── a-shirt-with-a-fire-printed-on-it.png
│   ├── a-shirt-with-the-inscription-'fire'.png
│   ├── a-watercolor-painting-of-a-fire.png
│   ├── birdhouse.png
│   ├── fire.png
│   ├── inpainting.png
│   ├── modelfigure.png
│   ├── rdm-preview.jpg
│   ├── reconstruction1.png
│   ├── reconstruction2.png
│   ├── results.gif
│   ├── the-earth-is-on-fire,-oil-on-canvas.png
│   ├── txt2img-convsample.png
│   └── txt2img-preview.png
├── configs
│   ├── autoencoder
│   │   ├── autoencoder_kl_32x32x4.yaml
│   │   ├── autoencoder_kl_32x32x4_large.yaml
│   │   ├── autoencoder_kl_32x32x4_large2.yaml
│   │   ├── autoencoder_kl_32x32x4_large2_train.yaml
│   │   ├── autoencoder_kl_32x32x4_large_train.yaml
│   │   ├── autoencoder_kl_32x32x4_train.yaml
│   │   ├── autoencoder_kl_woc_32x32x4.yaml
│   │   ├── autoencoder_kl_woc_32x32x4_large.yaml
│   │   ├── autoencoder_kl_woc_32x32x4_large2.yaml
│   │   ├── eval2_gpu.yaml
│   │   ├── random_thick_256.yaml
│   │   ├── v1-inpainting-inference.yaml
│   │   └── v1-t2i-inference.yaml
│   ├── latent-diffusion
│   │   ├── celebahq-ldm-vq-4.yaml
│   │   ├── cin-ldm-vq-f8.yaml
│   │   ├── cin256-v2.yaml
│   │   ├── ffhq-ldm-vq-4.yaml
│   │   ├── lsun_bedrooms-ldm-vq-4.yaml
│   │   ├── lsun_churches-ldm-kl-8.yaml
│   │   └── txt2img-1p4B-eval.yaml
│   └── retrieval-augmented-diffusion
│       └── 768x768.yaml
├── data
│   ├── DejaVuSans.ttf
│   ├── example_conditioning
│   │   ├── superresolution
│   │   │   └── sample_0.jpg
│   │   └── text_conditional
│   │       └── sample_0.txt
│   ├── imagenet_clsidx_to_label.txt
│   ├── imagenet_train_hr_indices.p
│   ├── imagenet_val_hr_indices.p
│   ├── index_synset.yaml
│   └── inpainting_examples
│       ├── 6458524847_2f4c361183_k.png
│       ├── 6458524847_2f4c361183_k_mask.png
│       ├── 8399166846_f6fb4e4b8e_k.png
│       ├── 8399166846_f6fb4e4b8e_k_mask.png
│       ├── alex-iby-G_Pk4D9rMLs.png
│       ├── alex-iby-G_Pk4D9rMLs_mask.png
│       ├── bench2.png
│       ├── bench2_mask.png
│       ├── bertrand-gabioud-CpuFzIsHYJ0.png
│       ├── bertrand-gabioud-CpuFzIsHYJ0_mask.png
│       ├── billow926-12-Wc-Zgx6Y.png
│       ├── billow926-12-Wc-Zgx6Y_mask.png
│       ├── overture-creations-5sI6fQgYIuo.png
│       ├── overture-creations-5sI6fQgYIuo_mask.png
│       ├── photo-1583445095369-9c651e7e5d34.png
│       └── photo-1583445095369-9c651e7e5d34_mask.png
├── demo.pkl
├── environment.yaml
├── inpaint_st.py
├── ldm
│   ├── data
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── imagenet.py
│   │   └── lsun.py
│   ├── lr_scheduler.py
│   ├── models
│   │   ├── autoencoder.py
│   │   ├── autoencoder_large.py
│   │   ├── autoencoder_large2.py
│   │   └── diffusion
│   │       ├── __init__.py
│   │       ├── classifier.py
│   │       ├── ddim.py
│   │       ├── ddpm.py
│   │       └── plms.py
│   ├── modules
│   │   ├── attention.py
│   │   ├── diffusionmodules
│   │   │   ├── __init__.py
│   │   │   ├── model.py
│   │   │   ├── my_decoder.py
│   │   │   ├── my_decoder_large.py
│   │   │   ├── my_decoder_large2.py
│   │   │   ├── openaimodel.py
│   │   │   └── util.py
│   │   ├── distributions
│   │   │   ├── __init__.py
│   │   │   └── distributions.py
│   │   ├── ema.py
│   │   ├── encoders
│   │   │   ├── __init__.py
│   │   │   └── modules.py
│   │   ├── image_degradation
│   │   │   ├── __init__.py
│   │   │   ├── bsrgan.py
│   │   │   ├── bsrgan_light.py
│   │   │   ├── utils
│   │   │   │   └── test.png
│   │   │   └── utils_image.py
│   │   ├── losses
│   │   │   ├── __init__.py
│   │   │   ├── contperceptual.py
│   │   │   └── vqperceptual.py
│   │   └── x_transformer.py
│   └── util.py
├── main.py
├── mini.py
├── models
│   ├── ade20k
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── color150.mat
│   │   ├── mobilenet.py
│   │   ├── object150_info.csv
│   │   ├── resnet.py
│   │   ├── segm_lib
│   │   │   ├── nn
│   │   │   │   ├── __init__.py
│   │   │   │   ├── modules
│   │   │   │   │   ├── __init__.py
│   │   │   │   │   ├── batchnorm.py
│   │   │   │   │   ├── comm.py
│   │   │   │   │   ├── replicate.py
│   │   │   │   │   ├── tests
│   │   │   │   │   │   ├── test_numeric_batchnorm.py
│   │   │   │   │   │   └── test_sync_batchnorm.py
│   │   │   │   │   └── unittest.py
│   │   │   │   └── parallel
│   │   │   │       ├── __init__.py
│   │   │   │       └── data_parallel.py
│   │   │   └── utils
│   │   │       ├── __init__.py
│   │   │       ├── data
│   │   │       │   ├── __init__.py
│   │   │       │   ├── dataloader.py
│   │   │       │   ├── dataset.py
│   │   │       │   ├── distributed.py
│   │   │       │   └── sampler.py
│   │   │       └── th.py
│   │   └── utils.py
│   ├── first_stage_models
│   │   ├── kl-f16
│   │   │   └── config.yaml
│   │   ├── kl-f32
│   │   │   └── config.yaml
│   │   ├── kl-f4
│   │   │   └── config.yaml
│   │   ├── kl-f8
│   │   │   └── config.yaml
│   │   ├── vq-f16
│   │   │   └── config.yaml
│   │   ├── vq-f4-noattn
│   │   │   └── config.yaml
│   │   ├── vq-f4
│   │   │   └── config.yaml
│   │   ├── vq-f8-n256
│   │   │   └── config.yaml
│   │   └── vq-f8
│   │       └── config.yaml
│   ├── ldm
│   │   ├── bsr_sr
│   │   │   └── config.yaml
│   │   ├── celeba256
│   │   │   └── config.yaml
│   │   ├── cin256
│   │   │   └── config.yaml
│   │   ├── ffhq256
│   │   │   └── config.yaml
│   │   ├── inpainting_big
│   │   │   └── config.yaml
│   │   ├── layout2img-openimages256
│   │   │   └── config.yaml
│   │   ├── lsun_beds256
│   │   │   └── config.yaml
│   │   ├── lsun_churches256
│   │   │   └── config.yaml
│   │   ├── semantic_synthesis256
│   │   │   └── config.yaml
│   │   ├── semantic_synthesis512
│   │   │   └── config.yaml
│   │   └── text2img256
│   │       └── config.yaml
│   └── lpips_models
│       ├── alex.pth
│       ├── squeeze.pth
│       └── vgg.pth
├── notebook_helpers.py
├── outputs
│   └── inpainting_results
│       ├── 6458524847_2f4c361183_k.png
│       ├── 8399166846_f6fb4e4b8e_k.png
│       ├── alex-iby-G_Pk4D9rMLs.png
│       ├── bench2.png
│       ├── bertrand-gabioud-CpuFzIsHYJ0.png
│       ├── billow926-12-Wc-Zgx6Y.png
│       ├── overture-creations-5sI6fQgYIuo.png
│       └── photo-1583445095369-9c651e7e5d34.png
├── requirements.txt
├── saicinpainting
│   ├── __init__.py
│   ├── evaluation
│   │   ├── __init__.py
│   │   ├── data.py
│   │   ├── evaluator.py
│   │   ├── losses
│   │   │   ├── __init__.py
│   │   │   ├── base_loss.py
│   │   │   ├── fid
│   │   │   │   ├── __init__.py
│   │   │   │   ├── fid_score.py
│   │   │   │   └── inception.py
│   │   │   ├── lpips.py
│   │   │   └── ssim.py
│   │   ├── masks
│   │   │   ├── README.md
│   │   │   ├── __init__.py
│   │   │   ├── countless
│   │   │   │   ├── .gitignore
│   │   │   │   ├── README.md
│   │   │   │   ├── __init__.py
│   │   │   │   ├── countless2d.py
│   │   │   │   ├── countless3d.py
│   │   │   │   ├── images
│   │   │   │   │   ├── gcim.jpg
│   │   │   │   │   ├── gray_segmentation.png
│   │   │   │   │   ├── segmentation.png
│   │   │   │   │   └── sparse.png
│   │   │   │   ├── memprof
│   │   │   │   │   ├── countless2d_gcim_N_1000.png
│   │   │   │   │   ├── countless2d_quick_gcim_N_1000.png
│   │   │   │   │   ├── countless3d.png
│   │   │   │   │   ├── countless3d_dynamic.png
│   │   │   │   │   ├── countless3d_dynamic_generalized.png
│   │   │   │   │   └── countless3d_generalized.png
│   │   │   │   ├── requirements.txt
│   │   │   │   └── test.py
│   │   │   └── mask.py
│   │   ├── refinement.py
│   │   ├── utils.py
│   │   └── vis.py
│   ├── training
│   │   ├── __init__.py
│   │   ├── data
│   │   │   ├── __init__.py
│   │   │   ├── aug.py
│   │   │   ├── datasets.py
│   │   │   └── masks.py
│   │   ├── losses
│   │   │   ├── __init__.py
│   │   │   ├── adversarial.py
│   │   │   ├── constants.py
│   │   │   ├── distance_weighting.py
│   │   │   ├── feature_matching.py
│   │   │   ├── perceptual.py
│   │   │   ├── segmentation.py
│   │   │   └── style_loss.py
│   │   ├── modules
│   │   │   ├── __init__.py
│   │   │   ├── base.py
│   │   │   ├── depthwise_sep_conv.py
│   │   │   ├── fake_fakes.py
│   │   │   ├── ffc.py
│   │   │   ├── multidilated_conv.py
│   │   │   ├── multiscale.py
│   │   │   ├── pix2pixhd.py
│   │   │   ├── spatial_transform.py
│   │   │   └── squeeze_excitation.py
│   │   ├── trainers
│   │   │   ├── __init__.py
│   │   │   ├── base.py
│   │   │   └── default.py
│   │   └── visualizers
│   │       ├── __init__.py
│   │       ├── base.py
│   │       ├── colors.py
│   │       ├── directory.py
│   │       └── noop.py
│   └── utils.py
├── scripts
│   ├── download_first_stages.sh
│   ├── download_models.sh
│   ├── inpaint.py
│   ├── inpaint_st.py
│   ├── knn2img.py
│   ├── latent_imagenet_diffusion.ipynb
│   ├── sample_diffusion.py
│   └── train_searcher.py
├── setup.py
├── taming
│   └── modules
│       └── autoencoder
│           └── lpips
│               └── vgg.pth
├── teaser.png
├── text2img_visual.png
├── txt2img.py
└── visual.png
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 | # Editor-based HTTP Client requests
5 | /httpRequests/
6 | # Datasource local storage ignored files
7 | /dataSources/
8 | /dataSources.local.xml
9 |
--------------------------------------------------------------------------------
/.idea/Asymmetric_VQGAN.iml:
--------------------------------------------------------------------------------
(XML content not preserved in this export)
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
(XML content not preserved in this export)
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
(XML content not preserved in this export)
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
(XML content not preserved in this export)
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Machine Vision and Learning Group, LMU Munich
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/assets/a-painting-of-a-fire.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/assets/a-painting-of-a-fire.png
--------------------------------------------------------------------------------
/assets/a-photograph-of-a-fire.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/assets/a-photograph-of-a-fire.png
--------------------------------------------------------------------------------
/assets/a-shirt-with-a-fire-printed-on-it.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/assets/a-shirt-with-a-fire-printed-on-it.png
--------------------------------------------------------------------------------
/assets/a-shirt-with-the-inscription-'fire'.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/assets/a-shirt-with-the-inscription-'fire'.png
--------------------------------------------------------------------------------
/assets/a-watercolor-painting-of-a-fire.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/assets/a-watercolor-painting-of-a-fire.png
--------------------------------------------------------------------------------
/assets/birdhouse.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/assets/birdhouse.png
--------------------------------------------------------------------------------
/assets/fire.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/assets/fire.png
--------------------------------------------------------------------------------
/assets/inpainting.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/assets/inpainting.png
--------------------------------------------------------------------------------
/assets/modelfigure.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/assets/modelfigure.png
--------------------------------------------------------------------------------
/assets/rdm-preview.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/assets/rdm-preview.jpg
--------------------------------------------------------------------------------
/assets/reconstruction1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/assets/reconstruction1.png
--------------------------------------------------------------------------------
/assets/reconstruction2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/assets/reconstruction2.png
--------------------------------------------------------------------------------
/assets/results.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/assets/results.gif
--------------------------------------------------------------------------------
/assets/the-earth-is-on-fire,-oil-on-canvas.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/assets/the-earth-is-on-fire,-oil-on-canvas.png
--------------------------------------------------------------------------------
/assets/txt2img-convsample.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/assets/txt2img-convsample.png
--------------------------------------------------------------------------------
/assets/txt2img-preview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/assets/txt2img-preview.png
--------------------------------------------------------------------------------
/configs/autoencoder/autoencoder_kl_32x32x4.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-6
3 | target: ldm.models.autoencoder.AutoencoderKL
4 | params:
5 | # monitor: val/rec_loss
6 | # ckpt_path: /mnt/output/pre_models/stable_diff_vqauto_text.ckpt
7 | embed_dim: 4
8 | num_gpus: 8
9 |
10 | scheduler_config:
11 | target: ldm.lr_scheduler.LambdaWarmUpCosineScheduler
12 | params:
13 | verbosity_interval: 0
14 | warm_up_steps: 5000
15 | lr_start: 0.01
16 | lr_max: 1
17 |
18 | scheduler_config_d:
19 | target: ldm.lr_scheduler.LambdaWarmUpCosineScheduler_D
20 | params:
21 | verbosity_interval: 0
22 | warm_up_steps: 5000
23 | lr_start: 0.01
24 | lr_max: 1
25 |
26 | lossconfig:
27 | target: ldm.modules.losses.LPIPSWithDiscriminator
28 | params:
29 | disc_start: 50001
30 | kl_weight: 0.000001
31 | disc_weight: 0.8
32 |
33 | ddconfig:
34 | double_z: True
35 | z_channels: 4
36 | resolution: 256
37 | in_channels: 3
38 | out_ch: 3
39 | ch: 128
40 | ch_mult: [ 1,2,4,4 ] # num_down = len(ch_mult)-1
41 | num_res_blocks: 2
42 | attn_resolutions: [ ]
43 | dropout: 0.0
44 |
45 | data:
46 | target: main.DataModuleFromConfig
47 | params:
48 | batch_size: 10
49 | wrap: True
50 | train:
51 | target: ldm.data.imagenet.ImageNetSRTrain
52 | params:
53 | size: 256
54 | degradation: pil_nearest
55 | Cf_r: 0.5
56 | validation:
57 | target: ldm.data.imagenet.ImageNetsv
58 | params:
59 | indir: /mnt/output/myvalsam
60 |
61 | lightning:
62 | callbacks:
63 | image_logger:
64 | target: main.ImageLogger
65 | params:
66 | batch_frequency: 1000
67 | max_images: 8
68 | increase_log_steps: True
69 |
70 | trainer:
71 | benchmark: True
72 | accumulate_grad_batches: 1
73 |
--------------------------------------------------------------------------------
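Every "target"/"params" pair in the config above is resolved to a Python class, which is then constructed with those keyword arguments; this is the convention used throughout this codebase family (the helper usually lives in ldm/util.py). A minimal sketch of that mechanism, assuming the usual latent-diffusion helper names rather than quoting this repo's code:

import importlib
from omegaconf import OmegaConf

def get_obj_from_str(path):
    # "ldm.models.autoencoder.AutoencoderKL" -> the class object
    module, cls = path.rsplit(".", 1)
    return getattr(importlib.import_module(module), cls)

def instantiate_from_config(config):
    # config is a dict-like node with "target" and optional "params"
    return get_obj_from_str(config["target"])(**config.get("params", dict()))

cfg = OmegaConf.load("configs/autoencoder/autoencoder_kl_32x32x4.yaml")
model = instantiate_from_config(cfg.model)  # builds ldm.models.autoencoder.AutoencoderKL
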
/configs/autoencoder/autoencoder_kl_32x32x4_large.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-6
3 | target: ldm.models.autoencoder_large.AutoencoderKL
4 | params:
5 | # monitor: val/rec_loss
6 | # ckpt_path: /mnt/output/pre_models/stable_diff_vqauto_text.ckpt
7 | embed_dim: 4
8 | num_gpus: 8
9 |
10 | scheduler_config:
11 | target: ldm.lr_scheduler.LambdaWarmUpCosineScheduler
12 | params:
13 | verbosity_interval: 0
14 | warm_up_steps: 5000
15 | lr_start: 0.01
16 | lr_max: 1
17 |
18 | scheduler_config_d:
19 | target: ldm.lr_scheduler.LambdaWarmUpCosineScheduler_D
20 | params:
21 | verbosity_interval: 0
22 | warm_up_steps: 5000
23 | lr_start: 0.01
24 | lr_max: 1
25 |
26 | lossconfig:
27 | target: ldm.modules.losses.LPIPSWithDiscriminator
28 | params:
29 | disc_start: 50001
30 | kl_weight: 0.000001
31 | disc_weight: 0.8
32 |
33 | ddconfig:
34 | double_z: True
35 | z_channels: 4
36 | resolution: 256
37 | in_channels: 3
38 | out_ch: 3
39 | ch: 128
40 | ch_mult: [ 1,2,4,4 ] # num_down = len(ch_mult)-1
41 | num_res_blocks: 2
42 | attn_resolutions: [ ]
43 | dropout: 0.0
44 |
45 | data:
46 | target: main.DataModuleFromConfig
47 | params:
48 | batch_size: 4
49 | wrap: True
50 | train:
51 | target: ldm.data.imagenet.ImageNetSRTrain
52 | params:
53 | size: 256
54 | degradation: pil_nearest
55 | Cf_r: 0.5
56 | validation:
57 | target: ldm.data.imagenet.ImageNetsv
58 | params:
59 | indir: /mnt/output/myvalsam
60 |
61 | lightning:
62 | callbacks:
63 | image_logger:
64 | target: main.ImageLogger
65 | params:
66 | batch_frequency: 1000
67 | max_images: 8
68 | increase_log_steps: True
69 |
70 | trainer:
71 | benchmark: True
72 | accumulate_grad_batches: 1
73 |
--------------------------------------------------------------------------------
/configs/autoencoder/autoencoder_kl_32x32x4_large2.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-6
3 | target: ldm.models.autoencoder_large2.AutoencoderKL
4 | params:
5 | # monitor: val/rec_loss
6 | # ckpt_path: /mnt/output/pre_models/stable_diff_vqauto_text.ckpt
7 | embed_dim: 4
8 | num_gpus: 8
9 |
10 | scheduler_config:
11 | target: ldm.lr_scheduler.LambdaWarmUpCosineScheduler
12 | params:
13 | verbosity_interval: 0
14 | warm_up_steps: 5000
15 | lr_start: 0.01
16 | lr_max: 1
17 |
18 | scheduler_config_d:
19 | target: ldm.lr_scheduler.LambdaWarmUpCosineScheduler_D
20 | params:
21 | verbosity_interval: 0
22 | warm_up_steps: 5000
23 | lr_start: 0.01
24 | lr_max: 1
25 |
26 | lossconfig:
27 | target: ldm.modules.losses.LPIPSWithDiscriminator
28 | params:
29 | disc_start: 50001
30 | kl_weight: 0.000001
31 | disc_weight: 0.8
32 |
33 | ddconfig:
34 | double_z: True
35 | z_channels: 4
36 | resolution: 256
37 | in_channels: 3
38 | out_ch: 3
39 | ch: 128
40 | ch_mult: [ 1,2,4,4 ] # num_down = len(ch_mult)-1
41 | num_res_blocks: 2
42 | attn_resolutions: [ ]
43 | dropout: 0.0
44 |
45 | data:
46 | target: main.DataModuleFromConfig
47 | params:
48 | batch_size: 2
49 | wrap: True
50 | train:
51 | target: ldm.data.imagenet.ImageNetSRTrain
52 | params:
53 | size: 256
54 | degradation: pil_nearest
55 | Cf_r: 0.5
56 | validation:
57 | target: ldm.data.imagenet.ImageNetsv
58 | params:
59 | indir: /mnt/output/myvalsam
60 |
61 | lightning:
62 | callbacks:
63 | image_logger:
64 | target: main.ImageLogger
65 | params:
66 | batch_frequency: 1000
67 | max_images: 8
68 | increase_log_steps: True
69 |
70 | trainer:
71 | benchmark: True
72 | accumulate_grad_batches: 1
73 |
--------------------------------------------------------------------------------
/configs/autoencoder/autoencoder_kl_32x32x4_large2_train.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-6
3 | target: ldm.models.autoencoder_large2.AutoencoderKL
4 | params:
5 | # monitor: val/rec_loss
6 | ckpt_path: stable_vqgan.ckpt
7 | ignore_keys: ["decoder"]
8 | embed_dim: 4
9 | num_gpus: 8
10 |
11 | scheduler_config:
12 | target: ldm.lr_scheduler.LambdaWarmUpCosineScheduler
13 | params:
14 | verbosity_interval: 0
15 | warm_up_steps: 5000
16 | lr_start: 0.01
17 | lr_max: 1
18 |
19 | scheduler_config_d:
20 | target: ldm.lr_scheduler.LambdaWarmUpCosineScheduler_D
21 | params:
22 | verbosity_interval: 0
23 | warm_up_steps: 5000
24 | lr_start: 0.01
25 | lr_max: 1
26 |
27 | lossconfig:
28 | target: ldm.modules.losses.LPIPSWithDiscriminator
29 | params:
30 | disc_start: 50001
31 | kl_weight: 0.000001
32 | disc_weight: 0.8
33 |
34 | ddconfig:
35 | double_z: True
36 | z_channels: 4
37 | resolution: 256
38 | in_channels: 3
39 | out_ch: 3
40 | ch: 128
41 | ch_mult: [ 1,2,4,4 ] # num_down = len(ch_mult)-1
42 | num_res_blocks: 2
43 | attn_resolutions: [ ]
44 | dropout: 0.0
45 |
46 | data:
47 | target: main.DataModuleFromConfig
48 | params:
49 | batch_size: 2
50 | wrap: True
51 | train:
52 | target: ldm.data.imagenet.ImageNetSRTrain
53 | params:
54 | size: 256
55 | degradation: pil_nearest
56 | Cf_r: 0.5
57 | validation:
58 | target: ldm.data.imagenet.ImageNetsv
59 | params:
60 | indir: /mnt/output/myvalsam
61 |
62 | lightning:
63 | callbacks:
64 | image_logger:
65 | target: main.ImageLogger
66 | params:
67 | batch_frequency: 1000
68 | max_images: 8
69 | increase_log_steps: True
70 |
71 | trainer:
72 | benchmark: True
73 | accumulate_grad_batches: 1
74 |
--------------------------------------------------------------------------------
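The autoencoder_kl_32x32x4_large_train.yaml and _large2_train.yaml configs fine-tune from stable_vqgan.ckpt while dropping the pre-trained decoder weights via ignore_keys: ["decoder"], so the new decoder starts from fresh weights. A minimal sketch of how such prefix-based filtering is commonly implemented when loading a checkpoint (illustrative only; the real loading logic sits in the model classes under ldm/models/):

import torch

def load_filtered_state_dict(model, ckpt_path, ignore_keys=()):
    ckpt = torch.load(ckpt_path, map_location="cpu")
    sd = ckpt.get("state_dict", ckpt)
    # drop every tensor whose name starts with one of the ignored prefixes
    kept = {k: v for k, v in sd.items()
            if not any(k.startswith(prefix) for prefix in ignore_keys)}
    missing, unexpected = model.load_state_dict(kept, strict=False)
    print(f"kept {len(kept)} tensors; {len(missing)} missing, {len(unexpected)} unexpected")

# e.g. load_filtered_state_dict(autoencoder, "stable_vqgan.ckpt", ignore_keys=["decoder"])
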
/configs/autoencoder/autoencoder_kl_32x32x4_large_train.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-6
3 | target: ldm.models.autoencoder_large.AutoencoderKL
4 | params:
5 | # monitor: val/rec_loss
6 | ckpt_path: stable_vqgan.ckpt
7 | ignore_keys: ["decoder"]
8 | embed_dim: 4
9 | num_gpus: 8
10 |
11 | scheduler_config:
12 | target: ldm.lr_scheduler.LambdaWarmUpCosineScheduler
13 | params:
14 | verbosity_interval: 0
15 | warm_up_steps: 5000
16 | lr_start: 0.01
17 | lr_max: 1
18 |
19 | scheduler_config_d:
20 | target: ldm.lr_scheduler.LambdaWarmUpCosineScheduler_D
21 | params:
22 | verbosity_interval: 0
23 | warm_up_steps: 5000
24 | lr_start: 0.01
25 | lr_max: 1
26 |
27 | lossconfig:
28 | target: ldm.modules.losses.LPIPSWithDiscriminator
29 | params:
30 | disc_start: 50001
31 | kl_weight: 0.000001
32 | disc_weight: 0.8
33 |
34 | ddconfig:
35 | double_z: True
36 | z_channels: 4
37 | resolution: 256
38 | in_channels: 3
39 | out_ch: 3
40 | ch: 128
41 | ch_mult: [ 1,2,4,4 ] # num_down = len(ch_mult)-1
42 | num_res_blocks: 2
43 | attn_resolutions: [ ]
44 | dropout: 0.0
45 |
46 | data:
47 | target: main.DataModuleFromConfig
48 | params:
49 | batch_size: 4
50 | wrap: True
51 | train:
52 | target: ldm.data.imagenet.ImageNetSRTrain
53 | params:
54 | size: 256
55 | degradation: pil_nearest
56 | Cf_r: 0.5
57 | validation:
58 | target: ldm.data.imagenet.ImageNetsv
59 | params:
60 | indir: /mnt/output/myvalsam
61 |
62 | lightning:
63 | callbacks:
64 | image_logger:
65 | target: main.ImageLogger
66 | params:
67 | batch_frequency: 1000
68 | max_images: 8
69 | increase_log_steps: True
70 |
71 | trainer:
72 | benchmark: True
73 | accumulate_grad_batches: 1
74 |
--------------------------------------------------------------------------------
/configs/autoencoder/autoencoder_kl_32x32x4_train.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-6
3 | target: ldm.models.autoencoder.AutoencoderKL
4 | params:
5 | # monitor: val/rec_loss
6 | ckpt_path: stable_vqgan.ckpt
7 | embed_dim: 4
8 | num_gpus: 8
9 |
10 | scheduler_config:
11 | target: ldm.lr_scheduler.LambdaWarmUpCosineScheduler
12 | params:
13 | verbosity_interval: 0
14 | warm_up_steps: 5000
15 | lr_start: 0.01
16 | lr_max: 1
17 |
18 | scheduler_config_d:
19 | target: ldm.lr_scheduler.LambdaWarmUpCosineScheduler_D
20 | params:
21 | verbosity_interval: 0
22 | warm_up_steps: 5000
23 | lr_start: 0.01
24 | lr_max: 1
25 |
26 | lossconfig:
27 | target: ldm.modules.losses.LPIPSWithDiscriminator
28 | params:
29 | disc_start: 50001
30 | kl_weight: 0.000001
31 | disc_weight: 0.8
32 |
33 | ddconfig:
34 | double_z: True
35 | z_channels: 4
36 | resolution: 256
37 | in_channels: 3
38 | out_ch: 3
39 | ch: 128
40 | ch_mult: [ 1,2,4,4 ] # num_down = len(ch_mult)-1
41 | num_res_blocks: 2
42 | attn_resolutions: [ ]
43 | dropout: 0.0
44 |
45 | data:
46 | target: main.DataModuleFromConfig
47 | params:
48 | batch_size: 10
49 | wrap: True
50 | train:
51 | target: ldm.data.imagenet.ImageNetSRTrain
52 | params:
53 | size: 256
54 | degradation: pil_nearest
55 | Cf_r: 0.5
56 | validation:
57 | target: ldm.data.imagenet.ImageNetsv
58 | params:
59 | indir: /mnt/output/myvalsam
60 |
61 | lightning:
62 | callbacks:
63 | image_logger:
64 | target: main.ImageLogger
65 | params:
66 | batch_frequency: 1000
67 | max_images: 8
68 | increase_log_steps: True
69 |
70 | trainer:
71 | benchmark: True
72 | accumulate_grad_batches: 1
73 |
--------------------------------------------------------------------------------
/configs/autoencoder/autoencoder_kl_woc_32x32x4.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-6
3 | target: ldm.models.autoencoder.AutoencoderKL_wocondition
4 | params:
5 | # monitor: val/rec_loss
6 | # ckpt_path: /mnt/output/pre_models/stable_diff_vqauto_text.ckpt
7 | embed_dim: 4
8 | num_gpus: 8
9 |
10 | scheduler_config:
11 | target: ldm.lr_scheduler.LambdaWarmUpCosineScheduler
12 | params:
13 | verbosity_interval: 0
14 | warm_up_steps: 5000
15 | lr_start: 0.01
16 | lr_max: 1
17 |
18 | scheduler_config_d:
19 | target: ldm.lr_scheduler.LambdaWarmUpCosineScheduler_D
20 | params:
21 | verbosity_interval: 0
22 | warm_up_steps: 5000
23 | lr_start: 0.01
24 | lr_max: 1
25 |
26 | lossconfig:
27 | target: ldm.modules.losses.LPIPSWithDiscriminator
28 | params:
29 | disc_start: 50001
30 | kl_weight: 0.000001
31 | disc_weight: 0.8
32 |
33 | ddconfig:
34 | double_z: True
35 | z_channels: 4
36 | resolution: 256
37 | in_channels: 3
38 | out_ch: 3
39 | ch: 128
40 | ch_mult: [ 1,2,4,4 ] # num_down = len(ch_mult)-1
41 | num_res_blocks: 2
42 | attn_resolutions: [ ]
43 | dropout: 0.0
44 |
45 | data:
46 | target: main.DataModuleFromConfig
47 | params:
48 | batch_size: 10
49 | wrap: True
50 | train:
51 | target: ldm.data.imagenet.ImageNetSRTrain
52 | params:
53 | size: 256
54 | degradation: pil_nearest
55 | Cf_r: 0.5
56 | validation:
57 | target: ldm.data.imagenet.ImageNetsv
58 | params:
59 | indir: /mnt/output/myvalsam
60 |
61 | lightning:
62 | callbacks:
63 | image_logger:
64 | target: main.ImageLogger
65 | params:
66 | batch_frequency: 1000
67 | max_images: 8
68 | increase_log_steps: True
69 |
70 | trainer:
71 | benchmark: True
72 | accumulate_grad_batches: 1
73 |
--------------------------------------------------------------------------------
/configs/autoencoder/autoencoder_kl_woc_32x32x4_large.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-6
3 | target: ldm.models.autoencoder_large.AutoencoderKL_wocondition
4 | params:
5 | # monitor: val/rec_loss
6 | # ckpt_path: /mnt/output/pre_models/stable_diff_vqauto_text.ckpt
7 | embed_dim: 4
8 | num_gpus: 8
9 |
10 | scheduler_config:
11 | target: ldm.lr_scheduler.LambdaWarmUpCosineScheduler
12 | params:
13 | verbosity_interval: 0
14 | warm_up_steps: 5000
15 | lr_start: 0.01
16 | lr_max: 1
17 |
18 | scheduler_config_d:
19 | target: ldm.lr_scheduler.LambdaWarmUpCosineScheduler_D
20 | params:
21 | verbosity_interval: 0
22 | warm_up_steps: 5000
23 | lr_start: 0.01
24 | lr_max: 1
25 |
26 | lossconfig:
27 | target: ldm.modules.losses.LPIPSWithDiscriminator
28 | params:
29 | disc_start: 50001
30 | kl_weight: 0.000001
31 | disc_weight: 0.8
32 |
33 | ddconfig:
34 | double_z: True
35 | z_channels: 4
36 | resolution: 256
37 | in_channels: 3
38 | out_ch: 3
39 | ch: 128
40 | ch_mult: [ 1,2,4,4 ] # num_down = len(ch_mult)-1
41 | num_res_blocks: 2
42 | attn_resolutions: [ ]
43 | dropout: 0.0
44 |
45 | data:
46 | target: main.DataModuleFromConfig
47 | params:
48 | batch_size: 4
49 | wrap: True
50 | train:
51 | target: ldm.data.imagenet.ImageNetSRTrain
52 | params:
53 | size: 256
54 | degradation: pil_nearest
55 | Cf_r: 0.5
56 | validation:
57 | target: ldm.data.imagenet.ImageNetsv
58 | params:
59 | indir: /mnt/output/myvalsam
60 |
61 | lightning:
62 | callbacks:
63 | image_logger:
64 | target: main.ImageLogger
65 | params:
66 | batch_frequency: 1000
67 | max_images: 8
68 | increase_log_steps: True
69 |
70 | trainer:
71 | benchmark: True
72 | accumulate_grad_batches: 1
73 |
--------------------------------------------------------------------------------
/configs/autoencoder/autoencoder_kl_woc_32x32x4_large2.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-6
3 | target: ldm.models.autoencoder_large2.AutoencoderKL_wocondition
4 | params:
5 | # monitor: val/rec_loss
6 | # ckpt_path: /mnt/output/pre_models/stable_diff_vqauto_text.ckpt
7 | embed_dim: 4
8 | num_gpus: 8
9 |
10 | scheduler_config:
11 | target: ldm.lr_scheduler.LambdaWarmUpCosineScheduler
12 | params:
13 | verbosity_interval: 0
14 | warm_up_steps: 5000
15 | lr_start: 0.01
16 | lr_max: 1
17 |
18 | scheduler_config_d:
19 | target: ldm.lr_scheduler.LambdaWarmUpCosineScheduler_D
20 | params:
21 | verbosity_interval: 0
22 | warm_up_steps: 5000
23 | lr_start: 0.01
24 | lr_max: 1
25 |
26 | lossconfig:
27 | target: ldm.modules.losses.LPIPSWithDiscriminator
28 | params:
29 | disc_start: 50001
30 | kl_weight: 0.000001
31 | disc_weight: 0.8
32 |
33 | ddconfig:
34 | double_z: True
35 | z_channels: 4
36 | resolution: 256
37 | in_channels: 3
38 | out_ch: 3
39 | ch: 128
40 | ch_mult: [ 1,2,4,4 ] # num_down = len(ch_mult)-1
41 | num_res_blocks: 2
42 | attn_resolutions: [ ]
43 | dropout: 0.0
44 |
45 | data:
46 | target: main.DataModuleFromConfig
47 | params:
48 | batch_size: 2
49 | wrap: True
50 | train:
51 | target: ldm.data.imagenet.ImageNetSRTrain
52 | params:
53 | size: 256
54 | degradation: pil_nearest
55 | Cf_r: 0.5
56 | validation:
57 | target: ldm.data.imagenet.ImageNetsv
58 | params:
59 | indir: /mnt/output/myvalsam
60 |
61 | lightning:
62 | callbacks:
63 | image_logger:
64 | target: main.ImageLogger
65 | params:
66 | batch_frequency: 1000
67 | max_images: 8
68 | increase_log_steps: True
69 |
70 | trainer:
71 | benchmark: True
72 | accumulate_grad_batches: 1
73 |
--------------------------------------------------------------------------------
/configs/autoencoder/eval2_gpu.yaml:
--------------------------------------------------------------------------------
1 | evaluator_kwargs:
2 | batch_size: 8
3 |
4 | dataset_kwargs:
5 | img_suffix: .png
6 | inpainted_suffix: .png
--------------------------------------------------------------------------------
/configs/autoencoder/random_thick_256.yaml:
--------------------------------------------------------------------------------
1 | generator_kind: random
2 |
3 | mask_generator_kwargs:
4 | irregular_proba: 1
5 | irregular_kwargs:
6 | min_times: 1
7 | max_times: 5
8 | max_width: 100
9 | max_angle: 4
10 | max_len: 200
11 |
12 | box_proba: 0.3
13 | box_kwargs:
14 | margin: 10
15 | bbox_min_size: 30
16 | bbox_max_size: 150
17 | max_times: 3
18 | min_times: 1
19 |
20 | segm_proba: 0
21 | squares_proba: 0
22 |
23 | variants_n: 5
24 |
25 | max_masks_per_image: 1
26 |
27 | cropping:
28 | out_min_size: 256
29 | handle_small_mode: upscale
30 | out_square_crop: True
31 | crop_min_overlap: 1
32 |
33 | max_tamper_area: 0.5
34 |
--------------------------------------------------------------------------------
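random_thick_256.yaml above parameterizes a LaMa-style random mask generator (random irregular strokes plus occasional boxes) used to produce inpainting masks; the project's implementation lives under saicinpainting (e.g. training/data/masks.py). As a rough, simplified sketch of what the irregular_kwargs control, not the repo's actual algorithm:

import math, random
import numpy as np
from PIL import Image, ImageDraw

def random_irregular_mask(size=256, min_times=1, max_times=5,
                          max_width=100, max_angle=4, max_len=200):
    # draw a few random polylines with bounded width, segment length and turn angle
    img = Image.new("L", (size, size), 0)
    draw = ImageDraw.Draw(img)
    for _ in range(random.randint(min_times, max_times)):
        x, y = random.uniform(0, size), random.uniform(0, size)
        angle = random.uniform(0, 2 * math.pi)
        for _ in range(random.randint(1, 10)):
            angle += random.uniform(-max_angle, max_angle)
            length = random.uniform(10, max_len)
            width = random.randint(5, max_width)
            nx, ny = x + length * math.cos(angle), y + length * math.sin(angle)
            draw.line([(x, y), (nx, ny)], fill=255, width=width)
            x, y = nx, ny
    return np.array(img) > 0  # boolean mask, True = pixels to inpaint

# mask = random_irregular_mask()
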
/configs/autoencoder/v1-inpainting-inference.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 7.5e-05
3 | target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion
4 | params:
5 | linear_start: 0.00085
6 | linear_end: 0.0120
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: "jpg"
11 | cond_stage_key: "txt"
12 | image_size: 64
13 | channels: 4
14 | cond_stage_trainable: false # Note: different from the one we trained before
15 | conditioning_key: hybrid # important
16 | # conditioning_key: concat
17 | monitor: val/loss_simple_ema
18 | scale_factor: 0.18215
19 | finetune_keys: null
20 |
21 | scheduler_config: # 10000 warmup steps
22 | target: ldm.lr_scheduler.LambdaLinearScheduler
23 | params:
24 | warm_up_steps: [ 2500 ] # NOTE for resuming. use 10000 if starting from scratch
25 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
26 | f_start: [ 1.e-6 ]
27 | f_max: [ 1. ]
28 | f_min: [ 1. ]
29 |
30 | unet_config:
31 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
32 | params:
33 | image_size: 32 # unused
34 | in_channels: 9 # 4 data + 4 downscaled image + 1 mask
35 | out_channels: 4
36 | model_channels: 320
37 | attention_resolutions: [ 4, 2, 1 ]
38 | num_res_blocks: 2
39 | channel_mult: [ 1, 2, 4, 4 ]
40 | num_heads: 8
41 | use_spatial_transformer: True
42 | transformer_depth: 1
43 | context_dim: 768
44 | use_checkpoint: True
45 | legacy: False
46 |
47 | first_stage_config:
48 | target: ldm.models.autoencoder.AutoencoderKL
49 | params:
50 | embed_dim: 4
51 | monitor: val/rec_loss
52 | ddconfig:
53 | double_z: true
54 | z_channels: 4
55 | resolution: 256
56 | in_channels: 3
57 | out_ch: 3
58 | ch: 128
59 | ch_mult:
60 | - 1
61 | - 2
62 | - 4
63 | - 4
64 | num_res_blocks: 2
65 | attn_resolutions: []
66 | dropout: 0.0
67 | lossconfig:
68 | target: torch.nn.Identity
69 |
70 | cond_stage_config:
71 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
72 |
73 |
--------------------------------------------------------------------------------
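The in_channels: 9 comment above ("4 data + 4 downscaled image + 1 mask") means the inpainting UNet sees a channel-wise concatenation of the noisy latent, the encoded masked image, and the mask resized to latent resolution (the "hybrid" conditioning key). A minimal, illustrative sketch of that concatenation; the tensor names and sizes are assumptions for a 512x512 input with an f8 autoencoder:

import torch
import torch.nn.functional as F

def build_unet_input(noisy_latent, masked_image_latent, mask):
    # noisy_latent:        (B, 4, 64, 64)   latent being denoised
    # masked_image_latent: (B, 4, 64, 64)   encoding of the masked input image
    # mask:                (B, 1, 512, 512) binary mask at pixel resolution
    mask_lr = F.interpolate(mask, size=noisy_latent.shape[-2:], mode="nearest")
    return torch.cat([noisy_latent, masked_image_latent, mask_lr], dim=1)  # (B, 9, 64, 64)

x = build_unet_input(torch.randn(1, 4, 64, 64), torch.randn(1, 4, 64, 64),
                     torch.ones(1, 1, 512, 512))
assert x.shape[1] == 9
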
/configs/autoencoder/v1-t2i-inference.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-04
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.00085
6 | linear_end: 0.0120
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: "jpg"
11 | cond_stage_key: "txt"
12 | image_size: 64
13 | channels: 4
14 | cond_stage_trainable: false # Note: different from the one we trained before
15 | conditioning_key: crossattn
16 | monitor: val/loss_simple_ema
17 | scale_factor: 0.18215
18 | use_ema: False
19 |
20 | scheduler_config: # 10000 warmup steps
21 | target: ldm.lr_scheduler.LambdaLinearScheduler
22 | params:
23 | warm_up_steps: [ 10000 ]
24 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
25 | f_start: [ 1.e-6 ]
26 | f_max: [ 1. ]
27 | f_min: [ 1. ]
28 |
29 | unet_config:
30 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
31 | params:
32 | image_size: 32 # unused
33 | in_channels: 4
34 | out_channels: 4
35 | model_channels: 320
36 | attention_resolutions: [ 4, 2, 1 ]
37 | num_res_blocks: 2
38 | channel_mult: [ 1, 2, 4, 4 ]
39 | num_heads: 8
40 | use_spatial_transformer: True
41 | transformer_depth: 1
42 | context_dim: 768
43 | use_checkpoint: True
44 | legacy: False
45 |
46 | first_stage_config:
47 | target: ldm.models.autoencoder.AutoencoderKL
48 | params:
49 | embed_dim: 4
50 | monitor: val/rec_loss
51 | ddconfig:
52 | double_z: true
53 | z_channels: 4
54 | resolution: 256
55 | in_channels: 3
56 | out_ch: 3
57 | ch: 128
58 | ch_mult:
59 | - 1
60 | - 2
61 | - 4
62 | - 4
63 | num_res_blocks: 2
64 | attn_resolutions: []
65 | dropout: 0.0
66 | lossconfig:
67 | target: torch.nn.Identity
68 |
69 | cond_stage_config:
70 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
71 |
--------------------------------------------------------------------------------
/configs/latent-diffusion/celebahq-ldm-vq-4.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 2.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0195
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | image_size: 64
12 | channels: 3
13 | monitor: val/loss_simple_ema
14 |
15 | unet_config:
16 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
17 | params:
18 | image_size: 64
19 | in_channels: 3
20 | out_channels: 3
21 | model_channels: 224
22 | attention_resolutions:
23 | # note: this isn't actually the resolution but
24 | # the downsampling factor, i.e. this corresponds to
25 | # attention on spatial resolution 8,16,32, as the
26 | # spatial resolution of the latents is 64 for f4
27 | - 8
28 | - 4
29 | - 2
30 | num_res_blocks: 2
31 | channel_mult:
32 | - 1
33 | - 2
34 | - 3
35 | - 4
36 | num_head_channels: 32
37 | first_stage_config:
38 | target: ldm.models.autoencoder.VQModelInterface
39 | params:
40 | embed_dim: 3
41 | n_embed: 8192
42 | ckpt_path: models/first_stage_models/vq-f4/model.ckpt
43 | ddconfig:
44 | double_z: false
45 | z_channels: 3
46 | resolution: 256
47 | in_channels: 3
48 | out_ch: 3
49 | ch: 128
50 | ch_mult:
51 | - 1
52 | - 2
53 | - 4
54 | num_res_blocks: 2
55 | attn_resolutions: []
56 | dropout: 0.0
57 | lossconfig:
58 | target: torch.nn.Identity
59 | cond_stage_config: __is_unconditional__
60 | data:
61 | target: main.DataModuleFromConfig
62 | params:
63 | batch_size: 48
64 | num_workers: 5
65 | wrap: false
66 | train:
67 | target: taming.data.faceshq.CelebAHQTrain
68 | params:
69 | size: 256
70 | validation:
71 | target: taming.data.faceshq.CelebAHQValidation
72 | params:
73 | size: 256
74 |
75 |
76 | lightning:
77 | callbacks:
78 | image_logger:
79 | target: main.ImageLogger
80 | params:
81 | batch_frequency: 5000
82 | max_images: 8
83 | increase_log_steps: False
84 |
85 | trainer:
86 | benchmark: True
--------------------------------------------------------------------------------
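As the comment in the unet_config above says, attention_resolutions are downsampling factors relative to the latent, not absolute resolutions: for this f4 model the latent is 64x64, so factors 8, 4 and 2 place attention at spatial sizes 8, 16 and 32. A quick check of that arithmetic:

# attention_resolutions are downsampling factors of the 64x64 latent (f4 model)
latent_size = 64
for factor in (8, 4, 2):
    side = latent_size // factor
    print(f"factor {factor} -> attention at {side}x{side}")  # 8x8, 16x16, 32x32
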
/configs/latent-diffusion/cin-ldm-vq-f8.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0195
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | cond_stage_key: class_label
12 | image_size: 32
13 | channels: 4
14 | cond_stage_trainable: true
15 | conditioning_key: crossattn
16 | monitor: val/loss_simple_ema
17 | unet_config:
18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
19 | params:
20 | image_size: 32
21 | in_channels: 4
22 | out_channels: 4
23 | model_channels: 256
24 | attention_resolutions:
25 | # note: this isn't actually the resolution but
26 | # the downsampling factor, i.e. this corresponds to
27 | # attention on spatial resolution 8,16,32, as the
28 | # spatial resolution of the latents is 32 for f8
29 | - 4
30 | - 2
31 | - 1
32 | num_res_blocks: 2
33 | channel_mult:
34 | - 1
35 | - 2
36 | - 4
37 | num_head_channels: 32
38 | use_spatial_transformer: true
39 | transformer_depth: 1
40 | context_dim: 512
41 | first_stage_config:
42 | target: ldm.models.autoencoder.VQModelInterface
43 | params:
44 | embed_dim: 4
45 | n_embed: 16384
46 | ckpt_path: configs/first_stage_models/vq-f8/model.yaml
47 | ddconfig:
48 | double_z: false
49 | z_channels: 4
50 | resolution: 256
51 | in_channels: 3
52 | out_ch: 3
53 | ch: 128
54 | ch_mult:
55 | - 1
56 | - 2
57 | - 2
58 | - 4
59 | num_res_blocks: 2
60 | attn_resolutions:
61 | - 32
62 | dropout: 0.0
63 | lossconfig:
64 | target: torch.nn.Identity
65 | cond_stage_config:
66 | target: ldm.modules.encoders.modules.ClassEmbedder
67 | params:
68 | embed_dim: 512
69 | key: class_label
70 | data:
71 | target: main.DataModuleFromConfig
72 | params:
73 | batch_size: 64
74 | num_workers: 12
75 | wrap: false
76 | train:
77 | target: ldm.data.imagenet.ImageNetTrain
78 | params:
79 | config:
80 | size: 256
81 | validation:
82 | target: ldm.data.imagenet.ImageNetValidation
83 | params:
84 | config:
85 | size: 256
86 |
87 |
88 | lightning:
89 | callbacks:
90 | image_logger:
91 | target: main.ImageLogger
92 | params:
93 | batch_frequency: 5000
94 | max_images: 8
95 | increase_log_steps: False
96 |
97 | trainer:
98 | benchmark: True
--------------------------------------------------------------------------------
/configs/latent-diffusion/cin256-v2.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 0.0001
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0195
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | cond_stage_key: class_label
12 | image_size: 64
13 | channels: 3
14 | cond_stage_trainable: true
15 | conditioning_key: crossattn
16 | monitor: val/loss
17 | use_ema: False
18 |
19 | unet_config:
20 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
21 | params:
22 | image_size: 64
23 | in_channels: 3
24 | out_channels: 3
25 | model_channels: 192
26 | attention_resolutions:
27 | - 8
28 | - 4
29 | - 2
30 | num_res_blocks: 2
31 | channel_mult:
32 | - 1
33 | - 2
34 | - 3
35 | - 5
36 | num_heads: 1
37 | use_spatial_transformer: true
38 | transformer_depth: 1
39 | context_dim: 512
40 |
41 | first_stage_config:
42 | target: ldm.models.autoencoder.VQModelInterface
43 | params:
44 | embed_dim: 3
45 | n_embed: 8192
46 | ddconfig:
47 | double_z: false
48 | z_channels: 3
49 | resolution: 256
50 | in_channels: 3
51 | out_ch: 3
52 | ch: 128
53 | ch_mult:
54 | - 1
55 | - 2
56 | - 4
57 | num_res_blocks: 2
58 | attn_resolutions: []
59 | dropout: 0.0
60 | lossconfig:
61 | target: torch.nn.Identity
62 |
63 | cond_stage_config:
64 | target: ldm.modules.encoders.modules.ClassEmbedder
65 | params:
66 | n_classes: 1001
67 | embed_dim: 512
68 | key: class_label
69 |
--------------------------------------------------------------------------------
/configs/latent-diffusion/ffhq-ldm-vq-4.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 2.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0195
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | image_size: 64
12 | channels: 3
13 | monitor: val/loss_simple_ema
14 | unet_config:
15 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
16 | params:
17 | image_size: 64
18 | in_channels: 3
19 | out_channels: 3
20 | model_channels: 224
21 | attention_resolutions:
22 | # note: this isn't actually the resolution but
23 | # the downsampling factor, i.e. this corresponds to
24 | # attention on spatial resolution 8,16,32, as the
25 | # spatial resolution of the latents is 64 for f4
26 | - 8
27 | - 4
28 | - 2
29 | num_res_blocks: 2
30 | channel_mult:
31 | - 1
32 | - 2
33 | - 3
34 | - 4
35 | num_head_channels: 32
36 | first_stage_config:
37 | target: ldm.models.autoencoder.VQModelInterface
38 | params:
39 | embed_dim: 3
40 | n_embed: 8192
41 | ckpt_path: configs/first_stage_models/vq-f4/model.yaml
42 | ddconfig:
43 | double_z: false
44 | z_channels: 3
45 | resolution: 256
46 | in_channels: 3
47 | out_ch: 3
48 | ch: 128
49 | ch_mult:
50 | - 1
51 | - 2
52 | - 4
53 | num_res_blocks: 2
54 | attn_resolutions: []
55 | dropout: 0.0
56 | lossconfig:
57 | target: torch.nn.Identity
58 | cond_stage_config: __is_unconditional__
59 | data:
60 | target: main.DataModuleFromConfig
61 | params:
62 | batch_size: 42
63 | num_workers: 5
64 | wrap: false
65 | train:
66 | target: taming.data.faceshq.FFHQTrain
67 | params:
68 | size: 256
69 | validation:
70 | target: taming.data.faceshq.FFHQValidation
71 | params:
72 | size: 256
73 |
74 |
75 | lightning:
76 | callbacks:
77 | image_logger:
78 | target: main.ImageLogger
79 | params:
80 | batch_frequency: 5000
81 | max_images: 8
82 | increase_log_steps: False
83 |
84 | trainer:
85 | benchmark: True
--------------------------------------------------------------------------------
/configs/latent-diffusion/lsun_bedrooms-ldm-vq-4.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 2.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0195
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | image_size: 64
12 | channels: 3
13 | monitor: val/loss_simple_ema
14 | unet_config:
15 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
16 | params:
17 | image_size: 64
18 | in_channels: 3
19 | out_channels: 3
20 | model_channels: 224
21 | attention_resolutions:
22 | # note: this isn't actually the resolution but
23 | # the downsampling factor, i.e. this corresponds to
24 | # attention on spatial resolution 8,16,32, as the
25 | # spatial resolution of the latents is 64 for f4
26 | - 8
27 | - 4
28 | - 2
29 | num_res_blocks: 2
30 | channel_mult:
31 | - 1
32 | - 2
33 | - 3
34 | - 4
35 | num_head_channels: 32
36 | first_stage_config:
37 | target: ldm.models.autoencoder.VQModelInterface
38 | params:
39 | ckpt_path: configs/first_stage_models/vq-f4/model.yaml
40 | embed_dim: 3
41 | n_embed: 8192
42 | ddconfig:
43 | double_z: false
44 | z_channels: 3
45 | resolution: 256
46 | in_channels: 3
47 | out_ch: 3
48 | ch: 128
49 | ch_mult:
50 | - 1
51 | - 2
52 | - 4
53 | num_res_blocks: 2
54 | attn_resolutions: []
55 | dropout: 0.0
56 | lossconfig:
57 | target: torch.nn.Identity
58 | cond_stage_config: __is_unconditional__
59 | data:
60 | target: main.DataModuleFromConfig
61 | params:
62 | batch_size: 48
63 | num_workers: 5
64 | wrap: false
65 | train:
66 | target: ldm.data.lsun.LSUNBedroomsTrain
67 | params:
68 | size: 256
69 | validation:
70 | target: ldm.data.lsun.LSUNBedroomsValidation
71 | params:
72 | size: 256
73 |
74 |
75 | lightning:
76 | callbacks:
77 | image_logger:
78 | target: main.ImageLogger
79 | params:
80 | batch_frequency: 5000
81 | max_images: 8
82 | increase_log_steps: False
83 |
84 | trainer:
85 | benchmark: True
--------------------------------------------------------------------------------
/configs/latent-diffusion/lsun_churches-ldm-kl-8.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 5.0e-5 # set to target_lr by starting main.py with '--scale_lr False'
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0155
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | loss_type: l1
11 | first_stage_key: "image"
12 | cond_stage_key: "image"
13 | image_size: 32
14 | channels: 4
15 | cond_stage_trainable: False
16 | concat_mode: False
17 | scale_by_std: True
18 | monitor: 'val/loss_simple_ema'
19 |
20 | scheduler_config: # 10000 warmup steps
21 | target: ldm.lr_scheduler.LambdaLinearScheduler
22 | params:
23 | warm_up_steps: [10000]
24 | cycle_lengths: [10000000000000]
25 | f_start: [1.e-6]
26 | f_max: [1.]
27 | f_min: [ 1.]
28 |
29 | unet_config:
30 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
31 | params:
32 | image_size: 32
33 | in_channels: 4
34 | out_channels: 4
35 | model_channels: 192
36 | attention_resolutions: [ 1, 2, 4, 8 ] # 32, 16, 8, 4
37 | num_res_blocks: 2
38 | channel_mult: [ 1,2,2,4,4 ] # 32, 16, 8, 4, 2
39 | num_heads: 8
40 | use_scale_shift_norm: True
41 | resblock_updown: True
42 |
43 | first_stage_config:
44 | target: ldm.models.autoencoder.AutoencoderKL
45 | params:
46 | embed_dim: 4
47 | monitor: "val/rec_loss"
48 | ckpt_path: "models/first_stage_models/kl-f8/model.ckpt"
49 | ddconfig:
50 | double_z: True
51 | z_channels: 4
52 | resolution: 256
53 | in_channels: 3
54 | out_ch: 3
55 | ch: 128
56 | ch_mult: [ 1,2,4,4 ] # num_down = len(ch_mult)-1
57 | num_res_blocks: 2
58 | attn_resolutions: [ ]
59 | dropout: 0.0
60 | lossconfig:
61 | target: torch.nn.Identity
62 |
63 | cond_stage_config: "__is_unconditional__"
64 |
65 | data:
66 | target: main.DataModuleFromConfig
67 | params:
68 | batch_size: 96
69 | num_workers: 5
70 | wrap: False
71 | train:
72 | target: ldm.data.lsun.LSUNChurchesTrain
73 | params:
74 | size: 256
75 | validation:
76 | target: ldm.data.lsun.LSUNChurchesValidation
77 | params:
78 | size: 256
79 |
80 | lightning:
81 | callbacks:
82 | image_logger:
83 | target: main.ImageLogger
84 | params:
85 | batch_frequency: 5000
86 | max_images: 8
87 | increase_log_steps: False
88 |
89 |
90 | trainer:
91 | benchmark: True
--------------------------------------------------------------------------------
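The comment on base_learning_rate above refers to the usual latent-diffusion convention in main.py: unless --scale_lr False is passed, the effective learning rate is scaled by the data-parallel batch size. A hedged sketch of that rule, assuming the standard formula rather than quoting this repo's main.py:

# Standard LDM learning-rate scaling (assumed convention):
# lr = accumulate_grad_batches * n_gpus * batch_size * base_learning_rate
base_lr = 5.0e-5            # from lsun_churches-ldm-kl-8.yaml
batch_size = 96
n_gpus = 8                  # example value
accumulate_grad_batches = 1
lr = accumulate_grad_batches * n_gpus * batch_size * base_lr
print(f"effective lr = {lr:.2e}")  # 3.84e-02 for these numbers
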
/configs/latent-diffusion/txt2img-1p4B-eval.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 5.0e-05
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.00085
6 | linear_end: 0.012
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | cond_stage_key: caption
12 | image_size: 32
13 | channels: 4
14 | cond_stage_trainable: true
15 | conditioning_key: crossattn
16 | monitor: val/loss_simple_ema
17 | scale_factor: 0.18215
18 | use_ema: False
19 |
20 | unet_config:
21 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
22 | params:
23 | image_size: 32
24 | in_channels: 4
25 | out_channels: 4
26 | model_channels: 320
27 | attention_resolutions:
28 | - 4
29 | - 2
30 | - 1
31 | num_res_blocks: 2
32 | channel_mult:
33 | - 1
34 | - 2
35 | - 4
36 | - 4
37 | num_heads: 8
38 | use_spatial_transformer: true
39 | transformer_depth: 1
40 | context_dim: 1280
41 | use_checkpoint: true
42 | legacy: False
43 |
44 | first_stage_config:
45 | target: ldm.models.autoencoder.AutoencoderKL
46 | params:
47 | embed_dim: 4
48 | monitor: val/rec_loss
49 | ddconfig:
50 | double_z: true
51 | z_channels: 4
52 | resolution: 256
53 | in_channels: 3
54 | out_ch: 3
55 | ch: 128
56 | ch_mult:
57 | - 1
58 | - 2
59 | - 4
60 | - 4
61 | num_res_blocks: 2
62 | attn_resolutions: []
63 | dropout: 0.0
64 | lossconfig:
65 | target: torch.nn.Identity
66 |
67 | cond_stage_config:
68 | target: ldm.modules.encoders.modules.BERTEmbedder
69 | params:
70 | n_embed: 1280
71 | n_layer: 32
72 |
--------------------------------------------------------------------------------
/configs/retrieval-augmented-diffusion/768x768.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 0.0001
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.015
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: jpg
11 | cond_stage_key: nix
12 | image_size: 48
13 | channels: 16
14 | cond_stage_trainable: false
15 | conditioning_key: crossattn
16 | monitor: val/loss_simple_ema
17 | scale_by_std: false
18 | scale_factor: 0.22765929
19 | unet_config:
20 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
21 | params:
22 | image_size: 48
23 | in_channels: 16
24 | out_channels: 16
25 | model_channels: 448
26 | attention_resolutions:
27 | - 4
28 | - 2
29 | - 1
30 | num_res_blocks: 2
31 | channel_mult:
32 | - 1
33 | - 2
34 | - 3
35 | - 4
36 | use_scale_shift_norm: false
37 | resblock_updown: false
38 | num_head_channels: 32
39 | use_spatial_transformer: true
40 | transformer_depth: 1
41 | context_dim: 768
42 | use_checkpoint: true
43 | first_stage_config:
44 | target: ldm.models.autoencoder.AutoencoderKL
45 | params:
46 | monitor: val/rec_loss
47 | embed_dim: 16
48 | ddconfig:
49 | double_z: true
50 | z_channels: 16
51 | resolution: 256
52 | in_channels: 3
53 | out_ch: 3
54 | ch: 128
55 | ch_mult:
56 | - 1
57 | - 1
58 | - 2
59 | - 2
60 | - 4
61 | num_res_blocks: 2
62 | attn_resolutions:
63 | - 16
64 | dropout: 0.0
65 | lossconfig:
66 | target: torch.nn.Identity
67 | cond_stage_config:
68 | target: torch.nn.Identity
--------------------------------------------------------------------------------
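In the retrieval-augmented config above, the first stage has a ch_mult of length 5, i.e. four downsampling steps (f=16), which is why image_size: 48 corresponds to a 768x768 input: 768 / 2^4 = 48. A quick check of that relation:

ch_mult = [1, 1, 2, 2, 4]
downsample_factor = 2 ** (len(ch_mult) - 1)   # 16
print(768 // downsample_factor)               # 48, matching image_size: 48
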
/data/DejaVuSans.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/data/DejaVuSans.ttf
--------------------------------------------------------------------------------
/data/example_conditioning/superresolution/sample_0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/data/example_conditioning/superresolution/sample_0.jpg
--------------------------------------------------------------------------------
/data/example_conditioning/text_conditional/sample_0.txt:
--------------------------------------------------------------------------------
1 | A basket of cherries
2 |
--------------------------------------------------------------------------------
/data/imagenet_train_hr_indices.p:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/data/imagenet_train_hr_indices.p
--------------------------------------------------------------------------------
/data/imagenet_val_hr_indices.p:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/data/imagenet_val_hr_indices.p
--------------------------------------------------------------------------------
/data/inpainting_examples/6458524847_2f4c361183_k.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/data/inpainting_examples/6458524847_2f4c361183_k.png
--------------------------------------------------------------------------------
/data/inpainting_examples/6458524847_2f4c361183_k_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/data/inpainting_examples/6458524847_2f4c361183_k_mask.png
--------------------------------------------------------------------------------
/data/inpainting_examples/8399166846_f6fb4e4b8e_k.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/data/inpainting_examples/8399166846_f6fb4e4b8e_k.png
--------------------------------------------------------------------------------
/data/inpainting_examples/8399166846_f6fb4e4b8e_k_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/data/inpainting_examples/8399166846_f6fb4e4b8e_k_mask.png
--------------------------------------------------------------------------------
/data/inpainting_examples/alex-iby-G_Pk4D9rMLs.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/data/inpainting_examples/alex-iby-G_Pk4D9rMLs.png
--------------------------------------------------------------------------------
/data/inpainting_examples/alex-iby-G_Pk4D9rMLs_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/data/inpainting_examples/alex-iby-G_Pk4D9rMLs_mask.png
--------------------------------------------------------------------------------
/data/inpainting_examples/bench2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/data/inpainting_examples/bench2.png
--------------------------------------------------------------------------------
/data/inpainting_examples/bench2_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/data/inpainting_examples/bench2_mask.png
--------------------------------------------------------------------------------
/data/inpainting_examples/bertrand-gabioud-CpuFzIsHYJ0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/data/inpainting_examples/bertrand-gabioud-CpuFzIsHYJ0.png
--------------------------------------------------------------------------------
/data/inpainting_examples/bertrand-gabioud-CpuFzIsHYJ0_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/data/inpainting_examples/bertrand-gabioud-CpuFzIsHYJ0_mask.png
--------------------------------------------------------------------------------
/data/inpainting_examples/billow926-12-Wc-Zgx6Y.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/data/inpainting_examples/billow926-12-Wc-Zgx6Y.png
--------------------------------------------------------------------------------
/data/inpainting_examples/billow926-12-Wc-Zgx6Y_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/data/inpainting_examples/billow926-12-Wc-Zgx6Y_mask.png
--------------------------------------------------------------------------------
/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png
--------------------------------------------------------------------------------
/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png
--------------------------------------------------------------------------------
/data/inpainting_examples/photo-1583445095369-9c651e7e5d34.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/data/inpainting_examples/photo-1583445095369-9c651e7e5d34.png
--------------------------------------------------------------------------------
/data/inpainting_examples/photo-1583445095369-9c651e7e5d34_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/data/inpainting_examples/photo-1583445095369-9c651e7e5d34_mask.png
--------------------------------------------------------------------------------
/demo.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/demo.pkl
--------------------------------------------------------------------------------
/environment.yaml:
--------------------------------------------------------------------------------
1 | #name: ldm
2 | channels:
3 | - pytorch
4 | - defaults
5 | dependencies:
6 | - python=3.8.5
7 | - pip=20.3
8 | - cudatoolkit=11.0
9 | - pytorch=1.7.0
10 | - torchvision=0.8.1
11 | - numpy=1.19.2
12 | - pip:
13 | - albumentations==0.4.3
14 | - easydict==1.10
15 | - scikit-learn==1.2.0
16 | - opencv-python==4.1.2.30
17 | - pudb==2019.2
18 | - imageio==2.9.0
19 | - imageio-ffmpeg==0.4.2
20 | - pytorch-lightning==1.4.2
21 | - omegaconf==2.1.1
22 | - test-tube>=0.7.5
23 | - streamlit>=0.73.1
24 | - einops==0.3.0
25 | - torch-fidelity==0.3.0
26 | - transformers==4.6.0
27 | - torchmetrics==0.6
28 | - academictorrents==2.3.3
29 | # - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
30 | # - -e git+https://github.com/openai/CLIP.git@main#egg=clip
31 | # - -e .
--------------------------------------------------------------------------------
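The environment above can be created with conda; because the `name:` entry is commented out, a name has to be supplied on the command line, for example `conda env create -f environment.yaml -n ldm`. The commented-out pip entries (taming-transformers, CLIP, and the editable install of the repository itself) are presumably meant to be installed separately.
--------------------------------------------------------------------------------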
/ldm/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/ldm/data/__init__.py
--------------------------------------------------------------------------------
/ldm/data/base.py:
--------------------------------------------------------------------------------
1 | from abc import abstractmethod
2 | from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset
3 |
4 |
5 | class Txt2ImgIterableBaseDataset(IterableDataset):
6 | '''
7 | Define an interface to make the IterableDatasets for text2img data chainable
8 | '''
9 | def __init__(self, num_records=0, valid_ids=None, size=256):
10 | super().__init__()
11 | self.num_records = num_records
12 | self.valid_ids = valid_ids
13 | self.sample_ids = valid_ids
14 | self.size = size
15 |
16 | print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.')
17 |
18 | def __len__(self):
19 | return self.num_records
20 |
21 | @abstractmethod
22 | def __iter__(self):
23 | pass
--------------------------------------------------------------------------------
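A concrete subclass only has to implement `__iter__`; the length reporting and the `sample_ids` bookkeeping come from `Txt2ImgIterableBaseDataset`. A minimal sketch, not part of the repository (the class name and record layout are made up for illustration):

    from ldm.data.base import Txt2ImgIterableBaseDataset


    class ListTxt2ImgDataset(Txt2ImgIterableBaseDataset):
        """Hypothetical subclass that iterates over an in-memory list of records."""

        def __init__(self, records, size=256):
            # records: list of dicts, e.g. {"image": ..., "caption": ...}
            super().__init__(num_records=len(records),
                             valid_ids=list(range(len(records))),
                             size=size)
            self.records = records

        def __iter__(self):
            for idx in self.sample_ids:
                yield self.records[idx]

Like any `IterableDataset`, such a subclass can be handed directly to a `torch.utils.data.DataLoader`.
--------------------------------------------------------------------------------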
/ldm/data/lsun.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import PIL
4 | from PIL import Image
5 | from torch.utils.data import Dataset
6 | from torchvision import transforms
7 |
8 |
9 | class LSUNBase(Dataset):
10 | def __init__(self,
11 | txt_file,
12 | data_root,
13 | size=None,
14 | interpolation="bicubic",
15 | flip_p=0.5
16 | ):
17 | self.data_paths = txt_file
18 | self.data_root = data_root
19 | with open(self.data_paths, "r") as f:
20 | self.image_paths = f.read().splitlines()
21 | self._length = len(self.image_paths)
22 | self.labels = {
23 | "relative_file_path_": [l for l in self.image_paths],
24 | "file_path_": [os.path.join(self.data_root, l)
25 | for l in self.image_paths],
26 | }
27 |
28 | self.size = size
29 | self.interpolation = {"linear": PIL.Image.LINEAR,
30 | "bilinear": PIL.Image.BILINEAR,
31 | "bicubic": PIL.Image.BICUBIC,
32 | "lanczos": PIL.Image.LANCZOS,
33 | }[interpolation]
34 | self.flip = transforms.RandomHorizontalFlip(p=flip_p)
35 |
36 | def __len__(self):
37 | return self._length
38 |
39 | def __getitem__(self, i):
40 | example = dict((k, self.labels[k][i]) for k in self.labels)
41 | image = Image.open(example["file_path_"])
42 | if not image.mode == "RGB":
43 | image = image.convert("RGB")
44 |
45 | # default to score-sde preprocessing
46 | img = np.array(image).astype(np.uint8)
47 | crop = min(img.shape[0], img.shape[1])
48 | h, w = img.shape[0], img.shape[1]
49 | img = img[(h - crop) // 2:(h + crop) // 2,
50 | (w - crop) // 2:(w + crop) // 2]
51 |
52 | image = Image.fromarray(img)
53 | if self.size is not None:
54 | image = image.resize((self.size, self.size), resample=self.interpolation)
55 |
56 | image = self.flip(image)
57 | image = np.array(image).astype(np.uint8)
58 | example["image"] = (image / 127.5 - 1.0).astype(np.float32)
59 | return example
60 |
61 |
62 | class LSUNChurchesTrain(LSUNBase):
63 | def __init__(self, **kwargs):
64 | super().__init__(txt_file="data/lsun/church_outdoor_train.txt", data_root="data/lsun/churches", **kwargs)
65 |
66 |
67 | class LSUNChurchesValidation(LSUNBase):
68 | def __init__(self, flip_p=0., **kwargs):
69 | super().__init__(txt_file="data/lsun/church_outdoor_val.txt", data_root="data/lsun/churches",
70 | flip_p=flip_p, **kwargs)
71 |
72 |
73 | class LSUNBedroomsTrain(LSUNBase):
74 | def __init__(self, **kwargs):
75 | super().__init__(txt_file="data/lsun/bedrooms_train.txt", data_root="data/lsun/bedrooms", **kwargs)
76 |
77 |
78 | class LSUNBedroomsValidation(LSUNBase):
79 | def __init__(self, flip_p=0.0, **kwargs):
80 | super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms",
81 | flip_p=flip_p, **kwargs)
82 |
83 |
84 | class LSUNCatsTrain(LSUNBase):
85 | def __init__(self, **kwargs):
86 | super().__init__(txt_file="data/lsun/cat_train.txt", data_root="data/lsun/cats", **kwargs)
87 |
88 |
89 | class LSUNCatsValidation(LSUNBase):
90 | def __init__(self, flip_p=0., **kwargs):
91 | super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats",
92 | flip_p=flip_p, **kwargs)
93 |
--------------------------------------------------------------------------------
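Assuming the path-list files and image folders referenced by `txt_file`/`data_root` have been prepared under `data/lsun/`, the classes behave like ordinary map-style datasets. A short usage sketch, not part of the repository:

    from torch.utils.data import DataLoader
    from ldm.data.lsun import LSUNChurchesTrain

    dataset = LSUNChurchesTrain(size=256)   # center-crop, resize to 256x256, random h-flip
    loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)

    batch = next(iter(loader))
    print(batch["image"].shape)             # (4, 256, 256, 3), float32 in [-1, 1]
--------------------------------------------------------------------------------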
/ldm/models/diffusion/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/ldm/models/diffusion/__init__.py
--------------------------------------------------------------------------------
/ldm/modules/diffusionmodules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/ldm/modules/diffusionmodules/__init__.py
--------------------------------------------------------------------------------
/ldm/modules/distributions/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/ldm/modules/distributions/__init__.py
--------------------------------------------------------------------------------
/ldm/modules/distributions/distributions.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | class AbstractDistribution:
6 | def sample(self):
7 | raise NotImplementedError()
8 |
9 | def mode(self):
10 | raise NotImplementedError()
11 |
12 |
13 | class DiracDistribution(AbstractDistribution):
14 | def __init__(self, value):
15 | self.value = value
16 |
17 | def sample(self):
18 | return self.value
19 |
20 | def mode(self):
21 | return self.value
22 |
23 |
24 | class DiagonalGaussianDistribution(object):
25 | def __init__(self, parameters, deterministic=False):
26 | self.parameters = parameters
27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
29 | self.deterministic = deterministic
30 | self.std = torch.exp(0.5 * self.logvar)
31 | self.var = torch.exp(self.logvar)
32 | if self.deterministic:
33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
34 |
35 | def sample(self):
36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
37 | return x
38 |
39 | def kl(self, other=None):
40 | if self.deterministic:
41 | return torch.Tensor([0.])
42 | else:
43 | if other is None:
44 | return 0.5 * torch.sum(torch.pow(self.mean, 2)
45 | + self.var - 1.0 - self.logvar,
46 | dim=[1, 2, 3])
47 | else:
48 | return 0.5 * torch.sum(
49 | torch.pow(self.mean - other.mean, 2) / other.var
50 | + self.var / other.var - 1.0 - self.logvar + other.logvar,
51 | dim=[1, 2, 3])
52 |
53 | def nll(self, sample, dims=[1,2,3]):
54 | if self.deterministic:
55 | return torch.Tensor([0.])
56 | logtwopi = np.log(2.0 * np.pi)
57 | return 0.5 * torch.sum(
58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
59 | dim=dims)
60 |
61 | def mode(self):
62 | return self.mean
63 |
64 |
65 | def normal_kl(mean1, logvar1, mean2, logvar2):
66 | """
67 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
68 | Compute the KL divergence between two gaussians.
69 | Shapes are automatically broadcasted, so batches can be compared to
70 | scalars, among other use cases.
71 | """
72 | tensor = None
73 | for obj in (mean1, logvar1, mean2, logvar2):
74 | if isinstance(obj, torch.Tensor):
75 | tensor = obj
76 | break
77 | assert tensor is not None, "at least one argument must be a Tensor"
78 |
79 | # Force variances to be Tensors. Broadcasting helps convert scalars to
80 | # Tensors, but it does not work for torch.exp().
81 | logvar1, logvar2 = [
82 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
83 | for x in (logvar1, logvar2)
84 | ]
85 |
86 | return 0.5 * (
87 | -1.0
88 | + logvar2
89 | - logvar1
90 | + torch.exp(logvar1 - logvar2)
91 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
92 | )
93 |
--------------------------------------------------------------------------------
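`DiagonalGaussianDistribution` expects the mean and log-variance stacked along the channel dimension, which is why it chunks `parameters` in two along `dim=1`. A small sketch with random numbers, not part of the repository:

    import torch
    from ldm.modules.distributions.distributions import DiagonalGaussianDistribution

    params = torch.randn(2, 8, 32, 32)   # 2 * C channels -> C = 4 latent channels
    posterior = DiagonalGaussianDistribution(params)

    z = posterior.sample()               # (2, 4, 32, 32), reparameterized sample
    kl = posterior.kl()                  # (2,), KL against a standard normal
    nll = posterior.nll(z)               # (2,), negative log-likelihood of z
--------------------------------------------------------------------------------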
/ldm/modules/ema.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | class LitEma(nn.Module):
6 | def __init__(self, model, decay=0.9999, use_num_upates=True):
7 | super().__init__()
8 | if decay < 0.0 or decay > 1.0:
9 | raise ValueError('Decay must be between 0 and 1')
10 |
11 | self.m_name2s_name = {}
12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
13 | self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates
14 | else torch.tensor(-1,dtype=torch.int))
15 |
16 | for name, p in model.named_parameters():
17 | if p.requires_grad:
18 | # remove '.' since the character is not allowed in buffer names
19 | s_name = name.replace('.','')
20 | self.m_name2s_name.update({name:s_name})
21 | self.register_buffer(s_name,p.clone().detach().data)
22 |
23 | self.collected_params = []
24 |
25 | def forward(self,model):
26 | decay = self.decay
27 |
28 | if self.num_updates >= 0:
29 | self.num_updates += 1
30 | decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates))
31 |
32 | one_minus_decay = 1.0 - decay
33 |
34 | with torch.no_grad():
35 | m_param = dict(model.named_parameters())
36 | shadow_params = dict(self.named_buffers())
37 |
38 | for key in m_param:
39 | if m_param[key].requires_grad:
40 | sname = self.m_name2s_name[key]
41 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
42 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
43 | else:
44 | assert not key in self.m_name2s_name
45 |
46 | def copy_to(self, model):
47 | m_param = dict(model.named_parameters())
48 | shadow_params = dict(self.named_buffers())
49 | for key in m_param:
50 | if m_param[key].requires_grad:
51 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
52 | else:
53 | assert not key in self.m_name2s_name
54 |
55 | def store(self, parameters):
56 | """
57 | Save the current parameters for restoring later.
58 | Args:
59 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
60 | temporarily stored.
61 | """
62 | self.collected_params = [param.clone() for param in parameters]
63 |
64 | def restore(self, parameters):
65 | """
66 | Restore the parameters stored with the `store` method.
67 | Useful to validate the model with EMA parameters without affecting the
68 | original optimization process. Store the parameters before the
69 | `copy_to` method. After validation (or model saving), use this to
70 | restore the former parameters.
71 | Args:
72 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
73 | updated with the stored parameters.
74 | """
75 | for c_param, param in zip(self.collected_params, parameters):
76 | param.data.copy_(c_param.data)
77 |
--------------------------------------------------------------------------------
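The intended lifecycle follows the docstrings above: update the shadow weights after every optimizer step, then use `store`/`copy_to`/`restore` to evaluate with the EMA weights without losing the training weights. A sketch with a placeholder model, not part of the repository:

    import torch
    from ldm.modules.ema import LitEma

    model = torch.nn.Linear(4, 4)
    ema = LitEma(model)
    opt = torch.optim.SGD(model.parameters(), lr=0.1)

    for _ in range(3):
        loss = model(torch.randn(2, 4)).sum()
        opt.zero_grad()
        loss.backward()
        opt.step()
        ema(model)                       # update shadow parameters after each step

    ema.store(model.parameters())        # stash the current training weights
    ema.copy_to(model)                   # evaluate with the EMA weights
    ema.restore(model.parameters())      # put the training weights back
--------------------------------------------------------------------------------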
/ldm/modules/encoders/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/ldm/modules/encoders/__init__.py
--------------------------------------------------------------------------------
/ldm/modules/image_degradation/__init__.py:
--------------------------------------------------------------------------------
1 | from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr
2 | from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light
3 |
--------------------------------------------------------------------------------
/ldm/modules/image_degradation/utils/test.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/ldm/modules/image_degradation/utils/test.png
--------------------------------------------------------------------------------
/ldm/modules/losses/__init__.py:
--------------------------------------------------------------------------------
1 | from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator
--------------------------------------------------------------------------------
/mini.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.distributed as dist
3 | import torch.backends.cudnn as cudnn
4 | import os
5 | import random
6 | import argparse
7 | import numpy as np
8 | import datetime
9 | import warnings
10 |
11 | warnings.filterwarnings('ignore',
12 | 'Argument interpolation should be of type InterpolationMode instead of int',
13 | UserWarning)
14 | warnings.filterwarnings('ignore',
15 | 'Leaking Caffe2 thread-pool after fork',
16 | UserWarning)
17 |
18 |
19 | def init_distributed_mode(args):
20 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
21 | args.rank = int(os.environ["RANK"])
22 | args.world_size = int(os.environ['WORLD_SIZE'])
23 | args.gpu = int(os.environ['LOCAL_RANK'])
24 | else:
25 | print('Not using distributed mode')
26 | args.distributed = False
27 | return
28 | args.distributed = True
29 |
30 | torch.cuda.set_device(args.gpu)
31 | args.dist_backend = 'nccl'
32 | print('| distributed init (rank {}): {}'.format(
33 | args.rank, args.dist_url), flush=True)
34 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
35 | world_size=args.world_size, rank=args.rank,
36 | timeout=datetime.timedelta(seconds=3600))
37 | torch.distributed.barrier()
38 |
39 |
40 | def is_dist_avail_and_initialized():
41 | if not dist.is_available():
42 | return False
43 | if not dist.is_initialized():
44 | return False
45 | return True
46 |
47 |
48 | def get_rank():
49 | if not is_dist_avail_and_initialized():
50 | return 0
51 | return dist.get_rank()
52 |
53 |
54 | def is_main_process():
55 | return get_rank() == 0
56 |
57 |
58 | def get_args_parser():
59 | parser = argparse.ArgumentParser('training and evaluation script', add_help=False)
60 |
61 | # distributed training parameters
62 | parser.add_argument('--world_size', default=1, type=int,
63 | help='number of distributed processes')
64 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
65 | parser.add_argument("--local_rank", type=int, required=True, help='local rank for DistributedDataParallel')
66 | parser.add_argument('--device', default='cuda', help='device to use for training / testing')
67 | parser.add_argument('--seed', default=0, type=int)
68 | return parser
69 |
70 |
71 | def main(args):
72 | os.environ['LOCAL_RANK'] = str(args.local_rank)
73 | init_distributed_mode(args)
74 | device = torch.device(args.device)
75 | print(args)
76 |
77 | # fix the seed for reproducibility
78 | seed = args.seed + get_rank()
79 | torch.manual_seed(seed)
80 | np.random.seed(seed)
81 | random.seed(seed)
82 | cudnn.benchmark = True
83 |
84 | print(" we reach the post init ")
85 |
86 |
87 | if __name__ == '__main__':
88 | parser = argparse.ArgumentParser('training and evaluation script', parents=[get_args_parser()])
89 | args = parser.parse_args()
90 | main(args)
91 |
--------------------------------------------------------------------------------
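`init_distributed_mode` reads `RANK` and `WORLD_SIZE` from the environment and the script requires a `--local_rank` argument, so it matches the legacy launcher of the pinned PyTorch 1.7, e.g. `python -m torch.distributed.launch --nproc_per_node=2 mini.py`. Run without a launcher (passing `--local_rank 0` by hand), it takes the 'Not using distributed mode' branch and simply prints the arguments and the final message.
--------------------------------------------------------------------------------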
/models/ade20k/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import *
--------------------------------------------------------------------------------
/models/ade20k/color150.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/models/ade20k/color150.mat
--------------------------------------------------------------------------------
/models/ade20k/mobilenet.py:
--------------------------------------------------------------------------------
1 | """
2 | This MobileNetV2 implementation is modified from the following repository:
3 | https://github.com/tonylins/pytorch-mobilenet-v2
4 | """
5 |
6 | import torch.nn as nn
7 | import math
8 | from .utils import load_url
9 | from .segm_lib.nn import SynchronizedBatchNorm2d
10 |
11 | BatchNorm2d = SynchronizedBatchNorm2d
12 |
13 |
14 | __all__ = ['mobilenetv2']
15 |
16 |
17 | model_urls = {
18 | 'mobilenetv2': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/mobilenet_v2.pth.tar',
19 | }
20 |
21 |
22 | def conv_bn(inp, oup, stride):
23 | return nn.Sequential(
24 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
25 | BatchNorm2d(oup),
26 | nn.ReLU6(inplace=True)
27 | )
28 |
29 |
30 | def conv_1x1_bn(inp, oup):
31 | return nn.Sequential(
32 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
33 | BatchNorm2d(oup),
34 | nn.ReLU6(inplace=True)
35 | )
36 |
37 |
38 | class InvertedResidual(nn.Module):
39 | def __init__(self, inp, oup, stride, expand_ratio):
40 | super(InvertedResidual, self).__init__()
41 | self.stride = stride
42 | assert stride in [1, 2]
43 |
44 | hidden_dim = round(inp * expand_ratio)
45 | self.use_res_connect = self.stride == 1 and inp == oup
46 |
47 | if expand_ratio == 1:
48 | self.conv = nn.Sequential(
49 | # dw
50 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
51 | BatchNorm2d(hidden_dim),
52 | nn.ReLU6(inplace=True),
53 | # pw-linear
54 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
55 | BatchNorm2d(oup),
56 | )
57 | else:
58 | self.conv = nn.Sequential(
59 | # pw
60 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
61 | BatchNorm2d(hidden_dim),
62 | nn.ReLU6(inplace=True),
63 | # dw
64 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
65 | BatchNorm2d(hidden_dim),
66 | nn.ReLU6(inplace=True),
67 | # pw-linear
68 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
69 | BatchNorm2d(oup),
70 | )
71 |
72 | def forward(self, x):
73 | if self.use_res_connect:
74 | return x + self.conv(x)
75 | else:
76 | return self.conv(x)
77 |
78 |
79 | class MobileNetV2(nn.Module):
80 | def __init__(self, n_class=1000, input_size=224, width_mult=1.):
81 | super(MobileNetV2, self).__init__()
82 | block = InvertedResidual
83 | input_channel = 32
84 | last_channel = 1280
85 | interverted_residual_setting = [
86 | # t, c, n, s
87 | [1, 16, 1, 1],
88 | [6, 24, 2, 2],
89 | [6, 32, 3, 2],
90 | [6, 64, 4, 2],
91 | [6, 96, 3, 1],
92 | [6, 160, 3, 2],
93 | [6, 320, 1, 1],
94 | ]
95 |
96 | # building first layer
97 | assert input_size % 32 == 0
98 | input_channel = int(input_channel * width_mult)
99 | self.last_channel = int(last_channel * width_mult) if width_mult > 1.0 else last_channel
100 | self.features = [conv_bn(3, input_channel, 2)]
101 | # building inverted residual blocks
102 | for t, c, n, s in interverted_residual_setting:
103 | output_channel = int(c * width_mult)
104 | for i in range(n):
105 | if i == 0:
106 | self.features.append(block(input_channel, output_channel, s, expand_ratio=t))
107 | else:
108 | self.features.append(block(input_channel, output_channel, 1, expand_ratio=t))
109 | input_channel = output_channel
110 | # building last several layers
111 | self.features.append(conv_1x1_bn(input_channel, self.last_channel))
112 | # make it nn.Sequential
113 | self.features = nn.Sequential(*self.features)
114 |
115 | # building classifier
116 | self.classifier = nn.Sequential(
117 | nn.Dropout(0.2),
118 | nn.Linear(self.last_channel, n_class),
119 | )
120 |
121 | self._initialize_weights()
122 |
123 | def forward(self, x):
124 | x = self.features(x)
125 | x = x.mean(3).mean(2)
126 | x = self.classifier(x)
127 | return x
128 |
129 | def _initialize_weights(self):
130 | for m in self.modules():
131 | if isinstance(m, nn.Conv2d):
132 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
133 | m.weight.data.normal_(0, math.sqrt(2. / n))
134 | if m.bias is not None:
135 | m.bias.data.zero_()
136 | elif isinstance(m, BatchNorm2d):
137 | m.weight.data.fill_(1)
138 | m.bias.data.zero_()
139 | elif isinstance(m, nn.Linear):
140 | n = m.weight.size(1)
141 | m.weight.data.normal_(0, 0.01)
142 | m.bias.data.zero_()
143 |
144 |
145 | def mobilenetv2(pretrained=False, **kwargs):
146 | """Constructs a MobileNet_V2 model.
147 |
148 | Args:
149 | pretrained (bool): If True, returns a model pre-trained on ImageNet
150 | """
151 | model = MobileNetV2(n_class=1000, **kwargs)
152 | if pretrained:
153 | model.load_state_dict(load_url(model_urls['mobilenetv2']), strict=False)
154 | return model
--------------------------------------------------------------------------------
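A quick shape check, not part of the repository; `pretrained=False` avoids the weight download, the synchronized batch-norm layers behave like ordinary batch norm when no data-parallel replication is involved, and the import assumes the `models.ade20k` package and its dependencies load cleanly:

    import torch
    from models.ade20k.mobilenet import mobilenetv2

    net = mobilenetv2(pretrained=False).eval()
    with torch.no_grad():
        logits = net(torch.randn(1, 3, 224, 224))
    print(logits.shape)   # torch.Size([1, 1000])
--------------------------------------------------------------------------------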
/models/ade20k/segm_lib/nn/__init__.py:
--------------------------------------------------------------------------------
1 | from .modules import *
2 | from .parallel import UserScatteredDataParallel, user_scattered_collate, async_copy_to
3 |
--------------------------------------------------------------------------------
/models/ade20k/segm_lib/nn/modules/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : __init__.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
12 | from .replicate import DataParallelWithCallback, patch_replication_callback
13 |
--------------------------------------------------------------------------------
/models/ade20k/segm_lib/nn/modules/comm.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : comm.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import queue
12 | import collections
13 | import threading
14 |
15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']
16 |
17 |
18 | class FutureResult(object):
19 | """A thread-safe future implementation. Used only as one-to-one pipe."""
20 |
21 | def __init__(self):
22 | self._result = None
23 | self._lock = threading.Lock()
24 | self._cond = threading.Condition(self._lock)
25 |
26 | def put(self, result):
27 | with self._lock:
28 | assert self._result is None, 'Previous result hasn\'t been fetched.'
29 | self._result = result
30 | self._cond.notify()
31 |
32 | def get(self):
33 | with self._lock:
34 | if self._result is None:
35 | self._cond.wait()
36 |
37 | res = self._result
38 | self._result = None
39 | return res
40 |
41 |
42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])
44 |
45 |
46 | class SlavePipe(_SlavePipeBase):
47 | """Pipe for master-slave communication."""
48 |
49 | def run_slave(self, msg):
50 | self.queue.put((self.identifier, msg))
51 | ret = self.result.get()
52 | self.queue.put(True)
53 | return ret
54 |
55 |
56 | class SyncMaster(object):
57 | """An abstract `SyncMaster` object.
58 |
59 | - During replication, as data parallel triggers a callback on each module, every slave device should
60 | call `register_slave(identifier)` and obtain a `SlavePipe` to communicate with the master.
61 | - During the forward pass, the master device invokes `run_master`; all messages from the slave devices
62 | are collected and passed to a registered callback.
63 | - After receiving the messages, the master device gathers the information and determines the message to be
64 | passed back to each slave device.
65 | """
66 |
67 | def __init__(self, master_callback):
68 | """
69 |
70 | Args:
71 | master_callback: a callback to be invoked after having collected messages from slave devices.
72 | """
73 | self._master_callback = master_callback
74 | self._queue = queue.Queue()
75 | self._registry = collections.OrderedDict()
76 | self._activated = False
77 |
78 | def register_slave(self, identifier):
79 | """
80 | Register a slave device.
81 |
82 | Args:
83 | identifier: an identifier, usually the device id.
84 |
85 | Returns: a `SlavePipe` object which can be used to communicate with the master device.
86 |
87 | """
88 | if self._activated:
89 | assert self._queue.empty(), 'Queue is not clean before next initialization.'
90 | self._activated = False
91 | self._registry.clear()
92 | future = FutureResult()
93 | self._registry[identifier] = _MasterRegistry(future)
94 | return SlavePipe(identifier, self._queue, future)
95 |
96 | def run_master(self, master_msg):
97 | """
98 | Main entry for the master device in each forward pass.
99 | The messages are first collected from each device (including the master device), and then
100 | a callback is invoked to compute the message to be sent back to each device
101 | (including the master device).
102 |
103 | Args:
104 | master_msg: the message that the master wants to send to itself. This will be placed as the first
105 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
106 |
107 | Returns: the message to be sent back to the master device.
108 |
109 | """
110 | self._activated = True
111 |
112 | intermediates = [(0, master_msg)]
113 | for i in range(self.nr_slaves):
114 | intermediates.append(self._queue.get())
115 |
116 | results = self._master_callback(intermediates)
117 | assert results[0][0] == 0, 'The first result should belong to the master.'
118 |
119 | for i, res in results:
120 | if i == 0:
121 | continue
122 | self._registry[i].result.put(res)
123 |
124 | for i in range(self.nr_slaves):
125 | assert self._queue.get() is True
126 |
127 | return results[0][1]
128 |
129 | @property
130 | def nr_slaves(self):
131 | return len(self._registry)
132 |
--------------------------------------------------------------------------------
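With no slaves registered, `run_master` feeds `[(0, master_msg)]` to the callback and returns the first result, which makes the callback contract easy to see in isolation. A toy sketch, not part of the repository (it assumes the `models.ade20k` package imports cleanly):

    from models.ade20k.segm_lib.nn.modules.comm import SyncMaster

    # the callback receives [(device_id, msg), ...] and must return results in the
    # same (device_id, result) form, with the master's entry first
    master = SyncMaster(lambda intermediates: [(i, msg * 2) for i, msg in intermediates])
    print(master.run_master(21))   # 42
--------------------------------------------------------------------------------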
/models/ade20k/segm_lib/nn/modules/replicate.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : replicate.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import functools
12 |
13 | from torch.nn.parallel.data_parallel import DataParallel
14 |
15 | __all__ = [
16 | 'CallbackContext',
17 | 'execute_replication_callbacks',
18 | 'DataParallelWithCallback',
19 | 'patch_replication_callback'
20 | ]
21 |
22 |
23 | class CallbackContext(object):
24 | pass
25 |
26 |
27 | def execute_replication_callbacks(modules):
28 | """
29 | Execute a replication callback `__data_parallel_replicate__` on each module created by the original replication.
30 |
31 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
32 |
33 | Note that, as all modules are isomorphic, we assign each sub-module a context
34 | (shared among multiple copies of this module on different devices).
35 | Through this context, different copies can share some information.
36 |
37 | We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback
38 | of any slave copies.
39 | """
40 | master_copy = modules[0]
41 | nr_modules = len(list(master_copy.modules()))
42 | ctxs = [CallbackContext() for _ in range(nr_modules)]
43 |
44 | for i, module in enumerate(modules):
45 | for j, m in enumerate(module.modules()):
46 | if hasattr(m, '__data_parallel_replicate__'):
47 | m.__data_parallel_replicate__(ctxs[j], i)
48 |
49 |
50 | class DataParallelWithCallback(DataParallel):
51 | """
52 | Data Parallel with a replication callback.
53 |
54 | A replication callback `__data_parallel_replicate__` of each module will be invoked after it is created by the
55 | original `replicate` function.
56 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
57 |
58 | Examples:
59 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
60 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
61 | # sync_bn.__data_parallel_replicate__ will be invoked.
62 | """
63 |
64 | def replicate(self, module, device_ids):
65 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
66 | execute_replication_callbacks(modules)
67 | return modules
68 |
69 |
70 | def patch_replication_callback(data_parallel):
71 | """
72 | Monkey-patch an existing `DataParallel` object. Add the replication callback.
73 | Useful when you have customized `DataParallel` implementation.
74 |
75 | Examples:
76 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
77 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
78 | > patch_replication_callback(sync_bn)
79 | # this is equivalent to
80 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
81 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
82 | """
83 |
84 | assert isinstance(data_parallel, DataParallel)
85 |
86 | old_replicate = data_parallel.replicate
87 |
88 | @functools.wraps(old_replicate)
89 | def new_replicate(module, device_ids):
90 | modules = old_replicate(module, device_ids)
91 | execute_replication_callbacks(modules)
92 | return modules
93 |
94 | data_parallel.replicate = new_replicate
95 |
--------------------------------------------------------------------------------
/models/ade20k/segm_lib/nn/modules/tests/test_numeric_batchnorm.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : test_numeric_batchnorm.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 |
9 | import unittest
10 |
11 | import torch
12 | import torch.nn as nn
13 | from torch.autograd import Variable
14 |
15 | from sync_batchnorm.unittest import TorchTestCase
16 |
17 |
18 | def handy_var(a, unbias=True):
19 | n = a.size(0)
20 | asum = a.sum(dim=0)
21 | as_sum = (a ** 2).sum(dim=0) # a square sum
22 | sumvar = as_sum - asum * asum / n
23 | if unbias:
24 | return sumvar / (n - 1)
25 | else:
26 | return sumvar / n
27 |
28 |
29 | class NumericTestCase(TorchTestCase):
30 | def testNumericBatchNorm(self):
31 | a = torch.rand(16, 10)
32 | bn = nn.BatchNorm2d(10, momentum=1, eps=1e-5, affine=False)
33 | bn.train()
34 |
35 | a_var1 = Variable(a, requires_grad=True)
36 | b_var1 = bn(a_var1)
37 | loss1 = b_var1.sum()
38 | loss1.backward()
39 |
40 | a_var2 = Variable(a, requires_grad=True)
41 | a_mean2 = a_var2.mean(dim=0, keepdim=True)
42 | a_std2 = torch.sqrt(handy_var(a_var2, unbias=False).clamp(min=1e-5))
43 | # a_std2 = torch.sqrt(a_var2.var(dim=0, keepdim=True, unbiased=False) + 1e-5)
44 | b_var2 = (a_var2 - a_mean2) / a_std2
45 | loss2 = b_var2.sum()
46 | loss2.backward()
47 |
48 | self.assertTensorClose(bn.running_mean, a.mean(dim=0))
49 | self.assertTensorClose(bn.running_var, handy_var(a))
50 | self.assertTensorClose(a_var1.data, a_var2.data)
51 | self.assertTensorClose(b_var1.data, b_var2.data)
52 | self.assertTensorClose(a_var1.grad, a_var2.grad)
53 |
54 |
55 | if __name__ == '__main__':
56 | unittest.main()
57 |
--------------------------------------------------------------------------------
/models/ade20k/segm_lib/nn/modules/tests/test_sync_batchnorm.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : test_sync_batchnorm.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 |
9 | import unittest
10 |
11 | import torch
12 | import torch.nn as nn
13 | from torch.autograd import Variable
14 |
15 | from sync_batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, DataParallelWithCallback
16 | from sync_batchnorm.unittest import TorchTestCase
17 |
18 |
19 | def handy_var(a, unbias=True):
20 | n = a.size(0)
21 | asum = a.sum(dim=0)
22 | as_sum = (a ** 2).sum(dim=0) # a square sum
23 | sumvar = as_sum - asum * asum / n
24 | if unbias:
25 | return sumvar / (n - 1)
26 | else:
27 | return sumvar / n
28 |
29 |
30 | def _find_bn(module):
31 | for m in module.modules():
32 | if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, SynchronizedBatchNorm1d, SynchronizedBatchNorm2d)):
33 | return m
34 |
35 |
36 | class SyncTestCase(TorchTestCase):
37 | def _syncParameters(self, bn1, bn2):
38 | bn1.reset_parameters()
39 | bn2.reset_parameters()
40 | if bn1.affine and bn2.affine:
41 | bn2.weight.data.copy_(bn1.weight.data)
42 | bn2.bias.data.copy_(bn1.bias.data)
43 |
44 | def _checkBatchNormResult(self, bn1, bn2, input, is_train, cuda=False):
45 | """Check the forward and backward for the customized batch normalization."""
46 | bn1.train(mode=is_train)
47 | bn2.train(mode=is_train)
48 |
49 | if cuda:
50 | input = input.cuda()
51 |
52 | self._syncParameters(_find_bn(bn1), _find_bn(bn2))
53 |
54 | input1 = Variable(input, requires_grad=True)
55 | output1 = bn1(input1)
56 | output1.sum().backward()
57 | input2 = Variable(input, requires_grad=True)
58 | output2 = bn2(input2)
59 | output2.sum().backward()
60 |
61 | self.assertTensorClose(input1.data, input2.data)
62 | self.assertTensorClose(output1.data, output2.data)
63 | self.assertTensorClose(input1.grad, input2.grad)
64 | self.assertTensorClose(_find_bn(bn1).running_mean, _find_bn(bn2).running_mean)
65 | self.assertTensorClose(_find_bn(bn1).running_var, _find_bn(bn2).running_var)
66 |
67 | def testSyncBatchNormNormalTrain(self):
68 | bn = nn.BatchNorm1d(10)
69 | sync_bn = SynchronizedBatchNorm1d(10)
70 |
71 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), True)
72 |
73 | def testSyncBatchNormNormalEval(self):
74 | bn = nn.BatchNorm1d(10)
75 | sync_bn = SynchronizedBatchNorm1d(10)
76 |
77 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), False)
78 |
79 | def testSyncBatchNormSyncTrain(self):
80 | bn = nn.BatchNorm1d(10, eps=1e-5, affine=False)
81 | sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
82 | sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
83 |
84 | bn.cuda()
85 | sync_bn.cuda()
86 |
87 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), True, cuda=True)
88 |
89 | def testSyncBatchNormSyncEval(self):
90 | bn = nn.BatchNorm1d(10, eps=1e-5, affine=False)
91 | sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
92 | sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
93 |
94 | bn.cuda()
95 | sync_bn.cuda()
96 |
97 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), False, cuda=True)
98 |
99 | def testSyncBatchNorm2DSyncTrain(self):
100 | bn = nn.BatchNorm2d(10)
101 | sync_bn = SynchronizedBatchNorm2d(10)
102 | sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
103 |
104 | bn.cuda()
105 | sync_bn.cuda()
106 |
107 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10, 16, 16), True, cuda=True)
108 |
109 |
110 | if __name__ == '__main__':
111 | unittest.main()
112 |
--------------------------------------------------------------------------------
/models/ade20k/segm_lib/nn/modules/unittest.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : unittest.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import unittest
12 |
13 | import numpy as np
14 | from torch.autograd import Variable
15 |
16 |
17 | def as_numpy(v):
18 | if isinstance(v, Variable):
19 | v = v.data
20 | return v.cpu().numpy()
21 |
22 |
23 | class TorchTestCase(unittest.TestCase):
24 | def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3):
25 | npa, npb = as_numpy(a), as_numpy(b)
26 | self.assertTrue(
27 | np.allclose(npa, npb, atol=atol),
28 | 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max())
29 | )
30 |
--------------------------------------------------------------------------------
/models/ade20k/segm_lib/nn/parallel/__init__.py:
--------------------------------------------------------------------------------
1 | from .data_parallel import UserScatteredDataParallel, user_scattered_collate, async_copy_to
2 |
--------------------------------------------------------------------------------
/models/ade20k/segm_lib/nn/parallel/data_parallel.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf8 -*-
2 |
3 | import torch.cuda as cuda
4 | import torch.nn as nn
5 | import torch
6 | import collections
7 | from torch.nn.parallel._functions import Gather
8 |
9 |
10 | __all__ = ['UserScatteredDataParallel', 'user_scattered_collate', 'async_copy_to']
11 |
12 |
13 | def async_copy_to(obj, dev, main_stream=None):
14 | if torch.is_tensor(obj):
15 | v = obj.cuda(dev, non_blocking=True)
16 | if main_stream is not None:
17 | v.data.record_stream(main_stream)
18 | return v
19 | elif isinstance(obj, collections.Mapping):
20 | return {k: async_copy_to(o, dev, main_stream) for k, o in obj.items()}
21 | elif isinstance(obj, collections.Sequence):
22 | return [async_copy_to(o, dev, main_stream) for o in obj]
23 | else:
24 | return obj
25 |
26 |
27 | def dict_gather(outputs, target_device, dim=0):
28 | """
29 | Gathers variables from different GPUs on a specified device
30 | (-1 means the CPU), with dictionary support.
31 | """
32 | def gather_map(outputs):
33 | out = outputs[0]
34 | if torch.is_tensor(out):
35 | # MJY(20180330) HACK:: force nr_dims > 0
36 | if out.dim() == 0:
37 | outputs = [o.unsqueeze(0) for o in outputs]
38 | return Gather.apply(target_device, dim, *outputs)
39 | elif out is None:
40 | return None
41 | elif isinstance(out, collections.Mapping):
42 | return {k: gather_map([o[k] for o in outputs]) for k in out}
43 | elif isinstance(out, collections.Sequence):
44 | return type(out)(map(gather_map, zip(*outputs)))
45 | return gather_map(outputs)
46 |
47 |
48 | class DictGatherDataParallel(nn.DataParallel):
49 | def gather(self, outputs, output_device):
50 | return dict_gather(outputs, output_device, dim=self.dim)
51 |
52 |
53 | class UserScatteredDataParallel(DictGatherDataParallel):
54 | def scatter(self, inputs, kwargs, device_ids):
55 | assert len(inputs) == 1
56 | inputs = inputs[0]
57 | inputs = _async_copy_stream(inputs, device_ids)
58 | inputs = [[i] for i in inputs]
59 | assert len(kwargs) == 0
60 | kwargs = [{} for _ in range(len(inputs))]
61 |
62 | return inputs, kwargs
63 |
64 |
65 | def user_scattered_collate(batch):
66 | return batch
67 |
68 |
69 | def _async_copy(inputs, device_ids):
70 | nr_devs = len(device_ids)
71 | assert type(inputs) in (tuple, list)
72 | assert len(inputs) == nr_devs
73 |
74 | outputs = []
75 | for i, dev in zip(inputs, device_ids):
76 | with cuda.device(dev):
77 | outputs.append(async_copy_to(i, dev))
78 |
79 | return tuple(outputs)
80 |
81 |
82 | def _async_copy_stream(inputs, device_ids):
83 | nr_devs = len(device_ids)
84 | assert type(inputs) in (tuple, list)
85 | assert len(inputs) == nr_devs
86 |
87 | outputs = []
88 | streams = [_get_stream(d) for d in device_ids]
89 | for i, dev, stream in zip(inputs, device_ids, streams):
90 | with cuda.device(dev):
91 | main_stream = cuda.current_stream()
92 | with cuda.stream(stream):
93 | outputs.append(async_copy_to(i, dev, main_stream=main_stream))
94 | main_stream.wait_stream(stream)
95 |
96 | return outputs
97 |
98 |
99 | """Adapted from: torch/nn/parallel/_functions.py"""
100 | # background streams used for copying
101 | _streams = None
102 |
103 |
104 | def _get_stream(device):
105 | """Gets a background stream for copying between CPU and GPU"""
106 | global _streams
107 | if device == -1:
108 | return None
109 | if _streams is None:
110 | _streams = [None] * cuda.device_count()
111 | if _streams[device] is None: _streams[device] = cuda.Stream(device)
112 | return _streams[device]
113 |
--------------------------------------------------------------------------------
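`user_scattered_collate` is the piece that can be exercised without GPUs: it hands the DataLoader batch through as a plain Python list, leaving the per-device scatter to `UserScatteredDataParallel`. A CPU-only sketch, not part of the repository (assumes the package imports cleanly):

    from torch.utils.data import DataLoader
    from models.ade20k.segm_lib.nn.parallel.data_parallel import user_scattered_collate

    samples = [{"img_data": i} for i in range(6)]
    loader = DataLoader(samples, batch_size=3, collate_fn=user_scattered_collate)
    print(next(iter(loader)))   # [{'img_data': 0}, {'img_data': 1}, {'img_data': 2}]
--------------------------------------------------------------------------------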
/models/ade20k/segm_lib/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .th import *
2 |
--------------------------------------------------------------------------------
/models/ade20k/segm_lib/utils/data/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | from .dataset import Dataset, TensorDataset, ConcatDataset
3 | from .dataloader import DataLoader
4 |
--------------------------------------------------------------------------------
/models/ade20k/segm_lib/utils/data/dataset.py:
--------------------------------------------------------------------------------
1 | import bisect
2 | import warnings
3 |
4 | from torch._utils import _accumulate
5 | from torch import randperm
6 |
7 |
8 | class Dataset(object):
9 | """An abstract class representing a Dataset.
10 |
11 | All other datasets should subclass it. All subclasses should override
12 | ``__len__``, that provides the size of the dataset, and ``__getitem__``,
13 | supporting integer indexing in range from 0 to len(self) exclusive.
14 | """
15 |
16 | def __getitem__(self, index):
17 | raise NotImplementedError
18 |
19 | def __len__(self):
20 | raise NotImplementedError
21 |
22 | def __add__(self, other):
23 | return ConcatDataset([self, other])
24 |
25 |
26 | class TensorDataset(Dataset):
27 | """Dataset wrapping data and target tensors.
28 |
29 | Each sample will be retrieved by indexing both tensors along the first
30 | dimension.
31 |
32 | Arguments:
33 | data_tensor (Tensor): contains sample data.
34 | target_tensor (Tensor): contains sample targets (labels).
35 | """
36 |
37 | def __init__(self, data_tensor, target_tensor):
38 | assert data_tensor.size(0) == target_tensor.size(0)
39 | self.data_tensor = data_tensor
40 | self.target_tensor = target_tensor
41 |
42 | def __getitem__(self, index):
43 | return self.data_tensor[index], self.target_tensor[index]
44 |
45 | def __len__(self):
46 | return self.data_tensor.size(0)
47 |
48 |
49 | class ConcatDataset(Dataset):
50 | """
51 | Dataset to concatenate multiple datasets.
52 | Purpose: useful to assemble different existing datasets, possibly
53 | large-scale datasets, since the concatenation operation is done
54 | on the fly.
55 |
56 | Arguments:
57 | datasets (iterable): List of datasets to be concatenated
58 | """
59 |
60 | @staticmethod
61 | def cumsum(sequence):
62 | r, s = [], 0
63 | for e in sequence:
64 | l = len(e)
65 | r.append(l + s)
66 | s += l
67 | return r
68 |
69 | def __init__(self, datasets):
70 | super(ConcatDataset, self).__init__()
71 | assert len(datasets) > 0, 'datasets should not be an empty iterable'
72 | self.datasets = list(datasets)
73 | self.cumulative_sizes = self.cumsum(self.datasets)
74 |
75 | def __len__(self):
76 | return self.cumulative_sizes[-1]
77 |
78 | def __getitem__(self, idx):
79 | dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
80 | if dataset_idx == 0:
81 | sample_idx = idx
82 | else:
83 | sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
84 | return self.datasets[dataset_idx][sample_idx]
85 |
86 | @property
87 | def cummulative_sizes(self):
88 | warnings.warn("cummulative_sizes attribute is renamed to "
89 | "cumulative_sizes", DeprecationWarning, stacklevel=2)
90 | return self.cumulative_sizes
91 |
92 |
93 | class Subset(Dataset):
94 | def __init__(self, dataset, indices):
95 | self.dataset = dataset
96 | self.indices = indices
97 |
98 | def __getitem__(self, idx):
99 | return self.dataset[self.indices[idx]]
100 |
101 | def __len__(self):
102 | return len(self.indices)
103 |
104 |
105 | def random_split(dataset, lengths):
106 | """
107 | Randomly split a dataset into non-overlapping new datasets of given lengths.
108 |
109 |
110 | Arguments:
111 | dataset (Dataset): Dataset to be split
112 | lengths (iterable): lengths of splits to be produced
113 | """
114 | if sum(lengths) != len(dataset):
115 | raise ValueError("Sum of input lengths does not equal the length of the input dataset!")
116 |
117 | indices = randperm(sum(lengths))
118 | return [Subset(dataset, indices[offset - length:offset]) for offset, length in zip(_accumulate(lengths), lengths)]
119 |
--------------------------------------------------------------------------------
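These classes are vendored copies of early `torch.utils.data` code; the composition behaviour is the same as in current torch, which the sketch below uses (same calls, not part of the repository):

    import torch
    from torch.utils.data import TensorDataset, ConcatDataset, random_split

    a = TensorDataset(torch.arange(6).float().unsqueeze(1), torch.zeros(6))
    b = TensorDataset(torch.arange(4).float().unsqueeze(1), torch.ones(4))

    both = ConcatDataset([a, b])             # cumulative sizes [6, 10]
    print(len(both), both[7])                # index 7 resolves to b[1]

    train, val = random_split(both, [8, 2])  # non-overlapping 8/2 split
    print(len(train), len(val))
--------------------------------------------------------------------------------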
/models/ade20k/segm_lib/utils/data/distributed.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from .sampler import Sampler
4 | from torch.distributed import get_world_size, get_rank
5 |
6 |
7 | class DistributedSampler(Sampler):
8 | """Sampler that restricts data loading to a subset of the dataset.
9 |
10 | It is especially useful in conjunction with
11 | :class:`torch.nn.parallel.DistributedDataParallel`. In such a case, each
12 | process can pass a DistributedSampler instance as a DataLoader sampler,
13 | and load a subset of the original dataset that is exclusive to it.
14 |
15 | .. note::
16 | Dataset is assumed to be of constant size.
17 |
18 | Arguments:
19 | dataset: Dataset used for sampling.
20 | num_replicas (optional): Number of processes participating in
21 | distributed training.
22 | rank (optional): Rank of the current process within num_replicas.
23 | """
24 |
25 | def __init__(self, dataset, num_replicas=None, rank=None):
26 | if num_replicas is None:
27 | num_replicas = get_world_size()
28 | if rank is None:
29 | rank = get_rank()
30 | self.dataset = dataset
31 | self.num_replicas = num_replicas
32 | self.rank = rank
33 | self.epoch = 0
34 | self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
35 | self.total_size = self.num_samples * self.num_replicas
36 |
37 | def __iter__(self):
38 | # deterministically shuffle based on epoch
39 | g = torch.Generator()
40 | g.manual_seed(self.epoch)
41 | indices = list(torch.randperm(len(self.dataset), generator=g))
42 |
43 | # add extra samples to make it evenly divisible
44 | indices += indices[:(self.total_size - len(indices))]
45 | assert len(indices) == self.total_size
46 |
47 | # subsample
48 | offset = self.num_samples * self.rank
49 | indices = indices[offset:offset + self.num_samples]
50 | assert len(indices) == self.num_samples
51 |
52 | return iter(indices)
53 |
54 | def __len__(self):
55 | return self.num_samples
56 |
57 | def set_epoch(self, epoch):
58 | self.epoch = epoch
59 |
--------------------------------------------------------------------------------
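This is an old copy of `torch.utils.data.distributed.DistributedSampler`; the usage pattern is the same, in particular calling `set_epoch` before every epoch so the deterministic shuffle changes. A sketch using torch's own version (same interface for these calls, not part of the repository):

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from torch.utils.data.distributed import DistributedSampler

    dataset = TensorDataset(torch.arange(10).float())
    # num_replicas/rank are given explicitly so no process group is needed here
    sampler = DistributedSampler(dataset, num_replicas=2, rank=0)
    loader = DataLoader(dataset, batch_size=2, sampler=sampler)

    for epoch in range(2):
        sampler.set_epoch(epoch)   # reshuffle deterministically per epoch
        for batch in loader:
            pass
--------------------------------------------------------------------------------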
/models/ade20k/segm_lib/utils/data/sampler.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class Sampler(object):
5 | """Base class for all Samplers.
6 |
7 | Every Sampler subclass has to provide an __iter__ method, providing a way
8 | to iterate over indices of dataset elements, and a __len__ method that
9 | returns the length of the returned iterator.
10 | """
11 |
12 | def __init__(self, data_source):
13 | pass
14 |
15 | def __iter__(self):
16 | raise NotImplementedError
17 |
18 | def __len__(self):
19 | raise NotImplementedError
20 |
21 |
22 | class SequentialSampler(Sampler):
23 | """Samples elements sequentially, always in the same order.
24 |
25 | Arguments:
26 | data_source (Dataset): dataset to sample from
27 | """
28 |
29 | def __init__(self, data_source):
30 | self.data_source = data_source
31 |
32 | def __iter__(self):
33 | return iter(range(len(self.data_source)))
34 |
35 | def __len__(self):
36 | return len(self.data_source)
37 |
38 |
39 | class RandomSampler(Sampler):
40 | """Samples elements randomly, without replacement.
41 |
42 | Arguments:
43 | data_source (Dataset): dataset to sample from
44 | """
45 |
46 | def __init__(self, data_source):
47 | self.data_source = data_source
48 |
49 | def __iter__(self):
50 | return iter(torch.randperm(len(self.data_source)).long())
51 |
52 | def __len__(self):
53 | return len(self.data_source)
54 |
55 |
56 | class SubsetRandomSampler(Sampler):
57 | """Samples elements randomly from a given list of indices, without replacement.
58 |
59 | Arguments:
60 | indices (list): a list of indices
61 | """
62 |
63 | def __init__(self, indices):
64 | self.indices = indices
65 |
66 | def __iter__(self):
67 | return (self.indices[i] for i in torch.randperm(len(self.indices)))
68 |
69 | def __len__(self):
70 | return len(self.indices)
71 |
72 |
73 | class WeightedRandomSampler(Sampler):
74 | """Samples elements from [0,..,len(weights)-1] with given probabilities (weights).
75 |
76 | Arguments:
77 | weights (list): a list of weights, not necessarily summing to one
78 | num_samples (int): number of samples to draw
79 | replacement (bool): if ``True``, samples are drawn with replacement.
80 | If not, they are drawn without replacement, which means that when a
81 | sample index is drawn for a row, it cannot be drawn again for that row.
82 | """
83 |
84 | def __init__(self, weights, num_samples, replacement=True):
85 | self.weights = torch.DoubleTensor(weights)
86 | self.num_samples = num_samples
87 | self.replacement = replacement
88 |
89 | def __iter__(self):
90 | return iter(torch.multinomial(self.weights, self.num_samples, self.replacement))
91 |
92 | def __len__(self):
93 | return self.num_samples
94 |
95 |
96 | class BatchSampler(object):
97 | """Wraps another sampler to yield a mini-batch of indices.
98 |
99 | Args:
100 | sampler (Sampler): Base sampler.
101 | batch_size (int): Size of mini-batch.
102 | drop_last (bool): If ``True``, the sampler will drop the last batch if
103 | its size would be less than ``batch_size``
104 |
105 | Example:
106 | >>> list(BatchSampler(range(10), batch_size=3, drop_last=False))
107 | [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
108 | >>> list(BatchSampler(range(10), batch_size=3, drop_last=True))
109 | [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
110 | """
111 |
112 | def __init__(self, sampler, batch_size, drop_last):
113 | self.sampler = sampler
114 | self.batch_size = batch_size
115 | self.drop_last = drop_last
116 |
117 | def __iter__(self):
118 | batch = []
119 | for idx in self.sampler:
120 | batch.append(idx)
121 | if len(batch) == self.batch_size:
122 | yield batch
123 | batch = []
124 | if len(batch) > 0 and not self.drop_last:
125 | yield batch
126 |
127 | def __len__(self):
128 | if self.drop_last:
129 | return len(self.sampler) // self.batch_size
130 | else:
131 | return (len(self.sampler) + self.batch_size - 1) // self.batch_size
132 |
--------------------------------------------------------------------------------
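The samplers above only produce dataset indices; `BatchSampler` groups them and the DataLoader (or caller) fetches the actual items. Since they are vendored copies of the classic `torch.utils.data` samplers, the sketch below uses torch's own equivalents (same interface for these calls, not part of the repository):

    from torch.utils.data import RandomSampler, BatchSampler, WeightedRandomSampler

    data = list("abcdefghij")
    batches = list(BatchSampler(RandomSampler(data), batch_size=4, drop_last=False))
    print([len(b) for b in batches])   # [4, 4, 2] -- the last partial batch is kept

    # indices 0-2 are ten times as likely to be drawn as the rest
    weighted = WeightedRandomSampler([10, 10, 10, 1, 1, 1, 1, 1, 1, 1], num_samples=5)
    print(list(weighted))              # e.g. [0, 2, 2, 1, 7]
--------------------------------------------------------------------------------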
/models/ade20k/segm_lib/utils/th.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.autograd import Variable
3 | import numpy as np
4 | import collections
5 |
6 | __all__ = ['as_variable', 'as_numpy', 'mark_volatile']
7 |
8 | def as_variable(obj):
9 | if isinstance(obj, Variable):
10 | return obj
11 | if isinstance(obj, collections.Sequence):
12 | return [as_variable(v) for v in obj]
13 | elif isinstance(obj, collections.Mapping):
14 | return {k: as_variable(v) for k, v in obj.items()}
15 | else:
16 | return Variable(obj)
17 |
18 | def as_numpy(obj):
19 | if isinstance(obj, collections.Sequence):
20 | return [as_numpy(v) for v in obj]
21 | elif isinstance(obj, collections.Mapping):
22 | return {k: as_numpy(v) for k, v in obj.items()}
23 | elif isinstance(obj, Variable):
24 | return obj.data.cpu().numpy()
25 | elif torch.is_tensor(obj):
26 | return obj.cpu().numpy()
27 | else:
28 | return np.array(obj)
29 |
30 | def mark_volatile(obj):
31 | if torch.is_tensor(obj):
32 | obj = Variable(obj)
33 | if isinstance(obj, Variable):
34 | obj.no_grad = True
35 | return obj
36 | elif isinstance(obj, collections.Mapping):
37 | return {k: mark_volatile(o) for k, o in obj.items()}
38 | elif isinstance(obj, collections.Sequence):
39 | return [mark_volatile(o) for o in obj]
40 | else:
41 | return obj
42 |
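Editorial note: these helpers target the pre-0.4 Variable API (Variable has been a thin no-op wrapper since PyTorch 0.4, and the old "volatile" flag is gone). The hedged sketch below shows modern equivalents; it is a reference under those assumptions, not a drop-in replacement for the file above.

```python
# Editorial sketch of modern-PyTorch equivalents for the helpers above:
# torch.no_grad() replaces mark_volatile, and collections.abc hosts Sequence/Mapping.
import collections.abc

import numpy as np
import torch


def as_numpy_modern(obj):
    if isinstance(obj, torch.Tensor):
        return obj.detach().cpu().numpy()
    if isinstance(obj, collections.abc.Mapping):
        return {k: as_numpy_modern(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):      # avoid abc.Sequence here: it also matches str
        return [as_numpy_modern(v) for v in obj]
    return np.array(obj)


with torch.no_grad():                       # replaces mark_volatile
    print(as_numpy_modern({"x": torch.randn(2, 3), "y": [torch.tensor(1.0)]}))
```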
--------------------------------------------------------------------------------
/models/ade20k/utils.py:
--------------------------------------------------------------------------------
1 | """Modified from https://github.com/CSAILVision/semantic-segmentation-pytorch"""
2 |
3 | import os
4 | import sys
5 |
6 | import numpy as np
7 | import torch
8 |
9 | try:
10 | from urllib import urlretrieve
11 | except ImportError:
12 | from urllib.request import urlretrieve
13 |
14 |
15 | def load_url(url, model_dir='./pretrained', map_location=None):
16 | if not os.path.exists(model_dir):
17 | os.makedirs(model_dir)
18 | filename = url.split('/')[-1]
19 | cached_file = os.path.join(model_dir, filename)
20 | if not os.path.exists(cached_file):
21 | sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
22 | urlretrieve(url, cached_file)
23 | return torch.load(cached_file, map_location=map_location)
24 |
25 |
26 | def color_encode(labelmap, colors, mode='RGB'):
27 | labelmap = labelmap.astype('int')
28 | labelmap_rgb = np.zeros((labelmap.shape[0], labelmap.shape[1], 3),
29 | dtype=np.uint8)
30 | for label in np.unique(labelmap):
31 | if label < 0:
32 | continue
33 | labelmap_rgb += (labelmap == label)[:, :, np.newaxis] * \
34 | np.tile(colors[label],
35 | (labelmap.shape[0], labelmap.shape[1], 1))
36 |
37 | if mode == 'BGR':
38 | return labelmap_rgb[:, :, ::-1]
39 | else:
40 | return labelmap_rgb
41 |
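A short editorial sketch of what color_encode produces; in the repository the palette comes from models/ade20k/color150.mat, and the hypothetical 3-color uint8 palette below only stands in for it. It assumes it runs alongside the function above.

```python
# Editorial sketch: colorizing a toy label map with color_encode.
import numpy as np

labelmap = np.array([[0, 0, 1],
                     [1, 2, 2]])                    # per-pixel class ids
colors = np.array([[255, 0, 0],                     # class 0 -> red
                   [0, 255, 0],                     # class 1 -> green
                   [0, 0, 255]], dtype=np.uint8)    # class 2 -> blue

rgb = color_encode(labelmap, colors)                # (2, 3, 3) uint8 image
print(rgb[0, 0], rgb[1, 2])                         # [255 0 0] [0 0 255]
```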
--------------------------------------------------------------------------------
/models/first_stage_models/kl-f16/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-06
3 | target: ldm.models.autoencoder.AutoencoderKL
4 | params:
5 | monitor: val/rec_loss
6 | embed_dim: 16
7 | lossconfig:
8 | target: ldm.modules.losses.LPIPSWithDiscriminator
9 | params:
10 | disc_start: 50001
11 | kl_weight: 1.0e-06
12 | disc_weight: 0.5
13 | ddconfig:
14 | double_z: true
15 | z_channels: 16
16 | resolution: 256
17 | in_channels: 3
18 | out_ch: 3
19 | ch: 128
20 | ch_mult:
21 | - 1
22 | - 1
23 | - 2
24 | - 2
25 | - 4
26 | num_res_blocks: 2
27 | attn_resolutions:
28 | - 16
29 | dropout: 0.0
30 | data:
31 | target: main.DataModuleFromConfig
32 | params:
33 | batch_size: 6
34 | wrap: true
35 | train:
36 | target: ldm.data.openimages.FullOpenImagesTrain
37 | params:
38 | size: 384
39 | crop_size: 256
40 | validation:
41 | target: ldm.data.openimages.FullOpenImagesValidation
42 | params:
43 | size: 384
44 | crop_size: 256
45 |
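Editorial note: this config (and the similar ones below) is plain OmegaConf-style YAML, where `target` names a class and `params` its constructor arguments. The hedged sketch below shows the usual loading pattern; it assumes the instantiate_from_config helper that latent-diffusion-style code bases keep in ldm/util.py, which should be verified against the local file before relying on it.

```python
# Editorial sketch: turning a first-stage config into a model.
# Assumes ldm.util.instantiate_from_config exists in this repository's ldm/util.py.
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config

config = OmegaConf.load("models/first_stage_models/kl-f16/config.yaml")
model = instantiate_from_config(config.model)   # builds ldm.models.autoencoder.AutoencoderKL
print(type(model).__name__)
```

The folder names also follow from ddconfig: each ch_mult entry adds one resolution level and the encoder halves the spatial size between consecutive levels, so the five entries here give a downsampling factor of 2^4 = 16 (kl-f16), the six entries in the next config give 2^5 = 32 (kl-f32), and likewise for kl-f4 and kl-f8.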
--------------------------------------------------------------------------------
/models/first_stage_models/kl-f32/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-06
3 | target: ldm.models.autoencoder.AutoencoderKL
4 | params:
5 | monitor: val/rec_loss
6 | embed_dim: 64
7 | lossconfig:
8 | target: ldm.modules.losses.LPIPSWithDiscriminator
9 | params:
10 | disc_start: 50001
11 | kl_weight: 1.0e-06
12 | disc_weight: 0.5
13 | ddconfig:
14 | double_z: true
15 | z_channels: 64
16 | resolution: 256
17 | in_channels: 3
18 | out_ch: 3
19 | ch: 128
20 | ch_mult:
21 | - 1
22 | - 1
23 | - 2
24 | - 2
25 | - 4
26 | - 4
27 | num_res_blocks: 2
28 | attn_resolutions:
29 | - 16
30 | - 8
31 | dropout: 0.0
32 | data:
33 | target: main.DataModuleFromConfig
34 | params:
35 | batch_size: 6
36 | wrap: true
37 | train:
38 | target: ldm.data.openimages.FullOpenImagesTrain
39 | params:
40 | size: 384
41 | crop_size: 256
42 | validation:
43 | target: ldm.data.openimages.FullOpenImagesValidation
44 | params:
45 | size: 384
46 | crop_size: 256
47 |
--------------------------------------------------------------------------------
/models/first_stage_models/kl-f4/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-06
3 | target: ldm.models.autoencoder.AutoencoderKL
4 | params:
5 | monitor: val/rec_loss
6 | embed_dim: 3
7 | lossconfig:
8 | target: ldm.modules.losses.LPIPSWithDiscriminator
9 | params:
10 | disc_start: 50001
11 | kl_weight: 1.0e-06
12 | disc_weight: 0.5
13 | ddconfig:
14 | double_z: true
15 | z_channels: 3
16 | resolution: 256
17 | in_channels: 3
18 | out_ch: 3
19 | ch: 128
20 | ch_mult:
21 | - 1
22 | - 2
23 | - 4
24 | num_res_blocks: 2
25 | attn_resolutions: []
26 | dropout: 0.0
27 | data:
28 | target: main.DataModuleFromConfig
29 | params:
30 | batch_size: 10
31 | wrap: true
32 | train:
33 | target: ldm.data.openimages.FullOpenImagesTrain
34 | params:
35 | size: 384
36 | crop_size: 256
37 | validation:
38 | target: ldm.data.openimages.FullOpenImagesValidation
39 | params:
40 | size: 384
41 | crop_size: 256
42 |
--------------------------------------------------------------------------------
/models/first_stage_models/kl-f8/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-06
3 | target: ldm.models.autoencoder.AutoencoderKL
4 | params:
5 | monitor: val/rec_loss
6 | embed_dim: 4
7 | lossconfig:
8 | target: ldm.modules.losses.LPIPSWithDiscriminator
9 | params:
10 | disc_start: 50001
11 | kl_weight: 1.0e-06
12 | disc_weight: 0.5
13 | ddconfig:
14 | double_z: true
15 | z_channels: 4
16 | resolution: 256
17 | in_channels: 3
18 | out_ch: 3
19 | ch: 128
20 | ch_mult:
21 | - 1
22 | - 2
23 | - 4
24 | - 4
25 | num_res_blocks: 2
26 | attn_resolutions: []
27 | dropout: 0.0
28 | data:
29 | target: main.DataModuleFromConfig
30 | params:
31 | batch_size: 4
32 | wrap: true
33 | train:
34 | target: ldm.data.openimages.FullOpenImagesTrain
35 | params:
36 | size: 384
37 | crop_size: 256
38 | validation:
39 | target: ldm.data.openimages.FullOpenImagesValidation
40 | params:
41 | size: 384
42 | crop_size: 256
43 |
--------------------------------------------------------------------------------
/models/first_stage_models/vq-f16/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-06
3 | target: ldm.models.autoencoder.VQModel
4 | params:
5 | embed_dim: 8
6 | n_embed: 16384
7 | ddconfig:
8 | double_z: false
9 | z_channels: 8
10 | resolution: 256
11 | in_channels: 3
12 | out_ch: 3
13 | ch: 128
14 | ch_mult:
15 | - 1
16 | - 1
17 | - 2
18 | - 2
19 | - 4
20 | num_res_blocks: 2
21 | attn_resolutions:
22 | - 16
23 | dropout: 0.0
24 | lossconfig:
25 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
26 | params:
27 | disc_conditional: false
28 | disc_in_channels: 3
29 | disc_start: 250001
30 | disc_weight: 0.75
31 | disc_num_layers: 2
32 | codebook_weight: 1.0
33 |
34 | data:
35 | target: main.DataModuleFromConfig
36 | params:
37 | batch_size: 14
38 | num_workers: 20
39 | wrap: true
40 | train:
41 | target: ldm.data.openimages.FullOpenImagesTrain
42 | params:
43 | size: 384
44 | crop_size: 256
45 | validation:
46 | target: ldm.data.openimages.FullOpenImagesValidation
47 | params:
48 | size: 384
49 | crop_size: 256
50 |
--------------------------------------------------------------------------------
/models/first_stage_models/vq-f4-noattn/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-06
3 | target: ldm.models.autoencoder.VQModel
4 | params:
5 | embed_dim: 3
6 | n_embed: 8192
7 | monitor: val/rec_loss
8 |
9 | ddconfig:
10 | attn_type: none
11 | double_z: false
12 | z_channels: 3
13 | resolution: 256
14 | in_channels: 3
15 | out_ch: 3
16 | ch: 128
17 | ch_mult:
18 | - 1
19 | - 2
20 | - 4
21 | num_res_blocks: 2
22 | attn_resolutions: []
23 | dropout: 0.0
24 | lossconfig:
25 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
26 | params:
27 | disc_conditional: false
28 | disc_in_channels: 3
29 | disc_start: 11
30 | disc_weight: 0.75
31 | codebook_weight: 1.0
32 |
33 | data:
34 | target: main.DataModuleFromConfig
35 | params:
36 | batch_size: 8
37 | num_workers: 12
38 | wrap: true
39 | train:
40 | target: ldm.data.openimages.FullOpenImagesTrain
41 | params:
42 | crop_size: 256
43 | validation:
44 | target: ldm.data.openimages.FullOpenImagesValidation
45 | params:
46 | crop_size: 256
47 |
--------------------------------------------------------------------------------
/models/first_stage_models/vq-f4/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-06
3 | target: ldm.models.autoencoder.VQModel
4 | params:
5 | embed_dim: 3
6 | n_embed: 8192
7 | monitor: val/rec_loss
8 |
9 | ddconfig:
10 | double_z: false
11 | z_channels: 3
12 | resolution: 256
13 | in_channels: 3
14 | out_ch: 3
15 | ch: 128
16 | ch_mult:
17 | - 1
18 | - 2
19 | - 4
20 | num_res_blocks: 2
21 | attn_resolutions: []
22 | dropout: 0.0
23 | lossconfig:
24 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
25 | params:
26 | disc_conditional: false
27 | disc_in_channels: 3
28 | disc_start: 0
29 | disc_weight: 0.75
30 | codebook_weight: 1.0
31 |
32 | data:
33 | target: main.DataModuleFromConfig
34 | params:
35 | batch_size: 8
36 | num_workers: 16
37 | wrap: true
38 | train:
39 | target: ldm.data.openimages.FullOpenImagesTrain
40 | params:
41 | crop_size: 256
42 | validation:
43 | target: ldm.data.openimages.FullOpenImagesValidation
44 | params:
45 | crop_size: 256
46 |
--------------------------------------------------------------------------------
/models/first_stage_models/vq-f8-n256/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-06
3 | target: ldm.models.autoencoder.VQModel
4 | params:
5 | embed_dim: 4
6 | n_embed: 256
7 | monitor: val/rec_loss
8 | ddconfig:
9 | double_z: false
10 | z_channels: 4
11 | resolution: 256
12 | in_channels: 3
13 | out_ch: 3
14 | ch: 128
15 | ch_mult:
16 | - 1
17 | - 2
18 | - 2
19 | - 4
20 | num_res_blocks: 2
21 | attn_resolutions:
22 | - 32
23 | dropout: 0.0
24 | lossconfig:
25 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
26 | params:
27 | disc_conditional: false
28 | disc_in_channels: 3
29 | disc_start: 250001
30 | disc_weight: 0.75
31 | codebook_weight: 1.0
32 |
33 | data:
34 | target: main.DataModuleFromConfig
35 | params:
36 | batch_size: 10
37 | num_workers: 20
38 | wrap: true
39 | train:
40 | target: ldm.data.openimages.FullOpenImagesTrain
41 | params:
42 | size: 384
43 | crop_size: 256
44 | validation:
45 | target: ldm.data.openimages.FullOpenImagesValidation
46 | params:
47 | size: 384
48 | crop_size: 256
49 |
--------------------------------------------------------------------------------
/models/first_stage_models/vq-f8/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-06
3 | target: ldm.models.autoencoder.VQModel
4 | params:
5 | embed_dim: 4
6 | n_embed: 16384
7 | monitor: val/rec_loss
8 | ddconfig:
9 | double_z: false
10 | z_channels: 4
11 | resolution: 256
12 | in_channels: 3
13 | out_ch: 3
14 | ch: 128
15 | ch_mult:
16 | - 1
17 | - 2
18 | - 2
19 | - 4
20 | num_res_blocks: 2
21 | attn_resolutions:
22 | - 32
23 | dropout: 0.0
24 | lossconfig:
25 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
26 | params:
27 | disc_conditional: false
28 | disc_in_channels: 3
29 | disc_num_layers: 2
30 | disc_start: 1
31 | disc_weight: 0.6
32 | codebook_weight: 1.0
33 | data:
34 | target: main.DataModuleFromConfig
35 | params:
36 | batch_size: 10
37 | num_workers: 20
38 | wrap: true
39 | train:
40 | target: ldm.data.openimages.FullOpenImagesTrain
41 | params:
42 | size: 384
43 | crop_size: 256
44 | validation:
45 | target: ldm.data.openimages.FullOpenImagesValidation
46 | params:
47 | size: 384
48 | crop_size: 256
49 |
--------------------------------------------------------------------------------
/models/ldm/bsr_sr/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0155
7 | log_every_t: 100
8 | timesteps: 1000
9 | loss_type: l2
10 | first_stage_key: image
11 | cond_stage_key: LR_image
12 | image_size: 64
13 | channels: 3
14 | concat_mode: true
15 | cond_stage_trainable: false
16 | unet_config:
17 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
18 | params:
19 | image_size: 64
20 | in_channels: 6
21 | out_channels: 3
22 | model_channels: 160
23 | attention_resolutions:
24 | - 16
25 | - 8
26 | num_res_blocks: 2
27 | channel_mult:
28 | - 1
29 | - 2
30 | - 2
31 | - 4
32 | num_head_channels: 32
33 | first_stage_config:
34 | target: ldm.models.autoencoder.VQModelInterface
35 | params:
36 | embed_dim: 3
37 | n_embed: 8192
38 | monitor: val/rec_loss
39 | ddconfig:
40 | double_z: false
41 | z_channels: 3
42 | resolution: 256
43 | in_channels: 3
44 | out_ch: 3
45 | ch: 128
46 | ch_mult:
47 | - 1
48 | - 2
49 | - 4
50 | num_res_blocks: 2
51 | attn_resolutions: []
52 | dropout: 0.0
53 | lossconfig:
54 | target: torch.nn.Identity
55 | cond_stage_config:
56 | target: torch.nn.Identity
57 | data:
58 | target: main.DataModuleFromConfig
59 | params:
60 | batch_size: 64
61 | wrap: false
62 | num_workers: 12
63 | train:
64 | target: ldm.data.openimages.SuperresOpenImagesAdvancedTrain
65 | params:
66 | size: 256
67 | degradation: bsrgan_light
68 | downscale_f: 4
69 | min_crop_f: 0.5
70 | max_crop_f: 1.0
71 | random_crop: true
72 | validation:
73 | target: ldm.data.openimages.SuperresOpenImagesAdvancedValidation
74 | params:
75 | size: 256
76 | degradation: bsrgan_light
77 | downscale_f: 4
78 | min_crop_f: 0.5
79 | max_crop_f: 1.0
80 | random_crop: true
81 |
--------------------------------------------------------------------------------
/models/ldm/celeba256/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 2.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0195
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | cond_stage_key: class_label
12 | image_size: 64
13 | channels: 3
14 | cond_stage_trainable: false
15 | concat_mode: false
16 | monitor: val/loss
17 | unet_config:
18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
19 | params:
20 | image_size: 64
21 | in_channels: 3
22 | out_channels: 3
23 | model_channels: 224
24 | attention_resolutions:
25 | - 8
26 | - 4
27 | - 2
28 | num_res_blocks: 2
29 | channel_mult:
30 | - 1
31 | - 2
32 | - 3
33 | - 4
34 | num_head_channels: 32
35 | first_stage_config:
36 | target: ldm.models.autoencoder.VQModelInterface
37 | params:
38 | embed_dim: 3
39 | n_embed: 8192
40 | ddconfig:
41 | double_z: false
42 | z_channels: 3
43 | resolution: 256
44 | in_channels: 3
45 | out_ch: 3
46 | ch: 128
47 | ch_mult:
48 | - 1
49 | - 2
50 | - 4
51 | num_res_blocks: 2
52 | attn_resolutions: []
53 | dropout: 0.0
54 | lossconfig:
55 | target: torch.nn.Identity
56 | cond_stage_config: __is_unconditional__
57 | data:
58 | target: main.DataModuleFromConfig
59 | params:
60 | batch_size: 48
61 | num_workers: 5
62 | wrap: false
63 | train:
64 | target: ldm.data.faceshq.CelebAHQTrain
65 | params:
66 | size: 256
67 | validation:
68 | target: ldm.data.faceshq.CelebAHQValidation
69 | params:
70 | size: 256
71 |
--------------------------------------------------------------------------------
/models/ldm/cin256/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0195
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | cond_stage_key: class_label
12 | image_size: 32
13 | channels: 4
14 | cond_stage_trainable: true
15 | conditioning_key: crossattn
16 | monitor: val/loss_simple_ema
17 | unet_config:
18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
19 | params:
20 | image_size: 32
21 | in_channels: 4
22 | out_channels: 4
23 | model_channels: 256
24 | attention_resolutions:
25 | - 4
26 | - 2
27 | - 1
28 | num_res_blocks: 2
29 | channel_mult:
30 | - 1
31 | - 2
32 | - 4
33 | num_head_channels: 32
34 | use_spatial_transformer: true
35 | transformer_depth: 1
36 | context_dim: 512
37 | first_stage_config:
38 | target: ldm.models.autoencoder.VQModelInterface
39 | params:
40 | embed_dim: 4
41 | n_embed: 16384
42 | ddconfig:
43 | double_z: false
44 | z_channels: 4
45 | resolution: 256
46 | in_channels: 3
47 | out_ch: 3
48 | ch: 128
49 | ch_mult:
50 | - 1
51 | - 2
52 | - 2
53 | - 4
54 | num_res_blocks: 2
55 | attn_resolutions:
56 | - 32
57 | dropout: 0.0
58 | lossconfig:
59 | target: torch.nn.Identity
60 | cond_stage_config:
61 | target: ldm.modules.encoders.modules.ClassEmbedder
62 | params:
63 | embed_dim: 512
64 | key: class_label
65 | data:
66 | target: main.DataModuleFromConfig
67 | params:
68 | batch_size: 64
69 | num_workers: 12
70 | wrap: false
71 | train:
72 | target: ldm.data.imagenet.ImageNetTrain
73 | params:
74 | config:
75 | size: 256
76 | validation:
77 | target: ldm.data.imagenet.ImageNetValidation
78 | params:
79 | config:
80 | size: 256
81 |
--------------------------------------------------------------------------------
/models/ldm/ffhq256/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 2.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0195
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | cond_stage_key: class_label
12 | image_size: 64
13 | channels: 3
14 | cond_stage_trainable: false
15 | concat_mode: false
16 | monitor: val/loss
17 | unet_config:
18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
19 | params:
20 | image_size: 64
21 | in_channels: 3
22 | out_channels: 3
23 | model_channels: 224
24 | attention_resolutions:
25 | - 8
26 | - 4
27 | - 2
28 | num_res_blocks: 2
29 | channel_mult:
30 | - 1
31 | - 2
32 | - 3
33 | - 4
34 | num_head_channels: 32
35 | first_stage_config:
36 | target: ldm.models.autoencoder.VQModelInterface
37 | params:
38 | embed_dim: 3
39 | n_embed: 8192
40 | ddconfig:
41 | double_z: false
42 | z_channels: 3
43 | resolution: 256
44 | in_channels: 3
45 | out_ch: 3
46 | ch: 128
47 | ch_mult:
48 | - 1
49 | - 2
50 | - 4
51 | num_res_blocks: 2
52 | attn_resolutions: []
53 | dropout: 0.0
54 | lossconfig:
55 | target: torch.nn.Identity
56 | cond_stage_config: __is_unconditional__
57 | data:
58 | target: main.DataModuleFromConfig
59 | params:
60 | batch_size: 42
61 | num_workers: 5
62 | wrap: false
63 | train:
64 | target: ldm.data.faceshq.FFHQTrain
65 | params:
66 | size: 256
67 | validation:
68 | target: ldm.data.faceshq.FFHQValidation
69 | params:
70 | size: 256
71 |
--------------------------------------------------------------------------------
/models/ldm/inpainting_big/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0205
7 | log_every_t: 100
8 | timesteps: 1000
9 | loss_type: l1
10 | first_stage_key: image
11 | cond_stage_key: masked_image
12 | image_size: 64
13 | channels: 3
14 | concat_mode: true
15 | monitor: val/loss
16 | scheduler_config:
17 | target: ldm.lr_scheduler.LambdaWarmUpCosineScheduler
18 | params:
19 | verbosity_interval: 0
20 | warm_up_steps: 1000
21 | max_decay_steps: 50000
22 | lr_start: 0.001
23 | lr_max: 0.1
24 | lr_min: 0.0001
25 | unet_config:
26 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
27 | params:
28 | image_size: 64
29 | in_channels: 7
30 | out_channels: 3
31 | model_channels: 256
32 | attention_resolutions:
33 | - 8
34 | - 4
35 | - 2
36 | num_res_blocks: 2
37 | channel_mult:
38 | - 1
39 | - 2
40 | - 3
41 | - 4
42 | num_heads: 8
43 | resblock_updown: true
44 | first_stage_config:
45 | target: ldm.models.autoencoder.VQModelInterface
46 | params:
47 | embed_dim: 3
48 | n_embed: 8192
49 | monitor: val/rec_loss
50 | ddconfig:
51 | attn_type: none
52 | double_z: false
53 | z_channels: 3
54 | resolution: 256
55 | in_channels: 3
56 | out_ch: 3
57 | ch: 128
58 | ch_mult:
59 | - 1
60 | - 2
61 | - 4
62 | num_res_blocks: 2
63 | attn_resolutions: []
64 | dropout: 0.0
65 | lossconfig:
66 | target: ldm.modules.losses.contperceptual.DummyLoss
67 | cond_stage_config: __is_first_stage__
68 |
--------------------------------------------------------------------------------
/models/ldm/layout2img-openimages256/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 2.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0205
7 | log_every_t: 100
8 | timesteps: 1000
9 | loss_type: l1
10 | first_stage_key: image
11 | cond_stage_key: coordinates_bbox
12 | image_size: 64
13 | channels: 3
14 | conditioning_key: crossattn
15 | cond_stage_trainable: true
16 | unet_config:
17 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
18 | params:
19 | image_size: 64
20 | in_channels: 3
21 | out_channels: 3
22 | model_channels: 128
23 | attention_resolutions:
24 | - 8
25 | - 4
26 | - 2
27 | num_res_blocks: 2
28 | channel_mult:
29 | - 1
30 | - 2
31 | - 3
32 | - 4
33 | num_head_channels: 32
34 | use_spatial_transformer: true
35 | transformer_depth: 3
36 | context_dim: 512
37 | first_stage_config:
38 | target: ldm.models.autoencoder.VQModelInterface
39 | params:
40 | embed_dim: 3
41 | n_embed: 8192
42 | monitor: val/rec_loss
43 | ddconfig:
44 | double_z: false
45 | z_channels: 3
46 | resolution: 256
47 | in_channels: 3
48 | out_ch: 3
49 | ch: 128
50 | ch_mult:
51 | - 1
52 | - 2
53 | - 4
54 | num_res_blocks: 2
55 | attn_resolutions: []
56 | dropout: 0.0
57 | lossconfig:
58 | target: torch.nn.Identity
59 | cond_stage_config:
60 | target: ldm.modules.encoders.modules.BERTEmbedder
61 | params:
62 | n_embed: 512
63 | n_layer: 16
64 | vocab_size: 8192
65 | max_seq_len: 92
66 | use_tokenizer: false
67 | monitor: val/loss_simple_ema
68 | data:
69 | target: main.DataModuleFromConfig
70 | params:
71 | batch_size: 24
72 | wrap: false
73 | num_workers: 10
74 | train:
75 | target: ldm.data.openimages.OpenImagesBBoxTrain
76 | params:
77 | size: 256
78 | validation:
79 | target: ldm.data.openimages.OpenImagesBBoxValidation
80 | params:
81 | size: 256
82 |
--------------------------------------------------------------------------------
/models/ldm/lsun_beds256/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 2.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0195
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | cond_stage_key: class_label
12 | image_size: 64
13 | channels: 3
14 | cond_stage_trainable: false
15 | concat_mode: false
16 | monitor: val/loss
17 | unet_config:
18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
19 | params:
20 | image_size: 64
21 | in_channels: 3
22 | out_channels: 3
23 | model_channels: 224
24 | attention_resolutions:
25 | - 8
26 | - 4
27 | - 2
28 | num_res_blocks: 2
29 | channel_mult:
30 | - 1
31 | - 2
32 | - 3
33 | - 4
34 | num_head_channels: 32
35 | first_stage_config:
36 | target: ldm.models.autoencoder.VQModelInterface
37 | params:
38 | embed_dim: 3
39 | n_embed: 8192
40 | ddconfig:
41 | double_z: false
42 | z_channels: 3
43 | resolution: 256
44 | in_channels: 3
45 | out_ch: 3
46 | ch: 128
47 | ch_mult:
48 | - 1
49 | - 2
50 | - 4
51 | num_res_blocks: 2
52 | attn_resolutions: []
53 | dropout: 0.0
54 | lossconfig:
55 | target: torch.nn.Identity
56 | cond_stage_config: __is_unconditional__
57 | data:
58 | target: main.DataModuleFromConfig
59 | params:
60 | batch_size: 48
61 | num_workers: 5
62 | wrap: false
63 | train:
64 | target: ldm.data.lsun.LSUNBedroomsTrain
65 | params:
66 | size: 256
67 | validation:
68 | target: ldm.data.lsun.LSUNBedroomsValidation
69 | params:
70 | size: 256
71 |
--------------------------------------------------------------------------------
/models/ldm/lsun_churches256/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 5.0e-05
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0155
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | loss_type: l1
11 | first_stage_key: image
12 | cond_stage_key: image
13 | image_size: 32
14 | channels: 4
15 | cond_stage_trainable: false
16 | concat_mode: false
17 | scale_by_std: true
18 | monitor: val/loss_simple_ema
19 | scheduler_config:
20 | target: ldm.lr_scheduler.LambdaLinearScheduler
21 | params:
22 | warm_up_steps:
23 | - 10000
24 | cycle_lengths:
25 | - 10000000000000
26 | f_start:
27 | - 1.0e-06
28 | f_max:
29 | - 1.0
30 | f_min:
31 | - 1.0
32 | unet_config:
33 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
34 | params:
35 | image_size: 32
36 | in_channels: 4
37 | out_channels: 4
38 | model_channels: 192
39 | attention_resolutions:
40 | - 1
41 | - 2
42 | - 4
43 | - 8
44 | num_res_blocks: 2
45 | channel_mult:
46 | - 1
47 | - 2
48 | - 2
49 | - 4
50 | - 4
51 | num_heads: 8
52 | use_scale_shift_norm: true
53 | resblock_updown: true
54 | first_stage_config:
55 | target: ldm.models.autoencoder.AutoencoderKL
56 | params:
57 | embed_dim: 4
58 | monitor: val/rec_loss
59 | ddconfig:
60 | double_z: true
61 | z_channels: 4
62 | resolution: 256
63 | in_channels: 3
64 | out_ch: 3
65 | ch: 128
66 | ch_mult:
67 | - 1
68 | - 2
69 | - 4
70 | - 4
71 | num_res_blocks: 2
72 | attn_resolutions: []
73 | dropout: 0.0
74 | lossconfig:
75 | target: torch.nn.Identity
76 |
77 | cond_stage_config: '__is_unconditional__'
78 |
79 | data:
80 | target: main.DataModuleFromConfig
81 | params:
82 | batch_size: 96
83 | num_workers: 5
84 | wrap: false
85 | train:
86 | target: ldm.data.lsun.LSUNChurchesTrain
87 | params:
88 | size: 256
89 | validation:
90 | target: ldm.data.lsun.LSUNChurchesValidation
91 | params:
92 | size: 256
93 |
--------------------------------------------------------------------------------
/models/ldm/semantic_synthesis256/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0205
7 | log_every_t: 100
8 | timesteps: 1000
9 | loss_type: l1
10 | first_stage_key: image
11 | cond_stage_key: segmentation
12 | image_size: 64
13 | channels: 3
14 | concat_mode: true
15 | cond_stage_trainable: true
16 | unet_config:
17 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
18 | params:
19 | image_size: 64
20 | in_channels: 6
21 | out_channels: 3
22 | model_channels: 128
23 | attention_resolutions:
24 | - 32
25 | - 16
26 | - 8
27 | num_res_blocks: 2
28 | channel_mult:
29 | - 1
30 | - 4
31 | - 8
32 | num_heads: 8
33 | first_stage_config:
34 | target: ldm.models.autoencoder.VQModelInterface
35 | params:
36 | embed_dim: 3
37 | n_embed: 8192
38 | ddconfig:
39 | double_z: false
40 | z_channels: 3
41 | resolution: 256
42 | in_channels: 3
43 | out_ch: 3
44 | ch: 128
45 | ch_mult:
46 | - 1
47 | - 2
48 | - 4
49 | num_res_blocks: 2
50 | attn_resolutions: []
51 | dropout: 0.0
52 | lossconfig:
53 | target: torch.nn.Identity
54 | cond_stage_config:
55 | target: ldm.modules.encoders.modules.SpatialRescaler
56 | params:
57 | n_stages: 2
58 | in_channels: 182
59 | out_channels: 3
60 |
--------------------------------------------------------------------------------
/models/ldm/semantic_synthesis512/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0205
7 | log_every_t: 100
8 | timesteps: 1000
9 | loss_type: l1
10 | first_stage_key: image
11 | cond_stage_key: segmentation
12 | image_size: 128
13 | channels: 3
14 | concat_mode: true
15 | cond_stage_trainable: true
16 | unet_config:
17 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
18 | params:
19 | image_size: 128
20 | in_channels: 6
21 | out_channels: 3
22 | model_channels: 128
23 | attention_resolutions:
24 | - 32
25 | - 16
26 | - 8
27 | num_res_blocks: 2
28 | channel_mult:
29 | - 1
30 | - 4
31 | - 8
32 | num_heads: 8
33 | first_stage_config:
34 | target: ldm.models.autoencoder.VQModelInterface
35 | params:
36 | embed_dim: 3
37 | n_embed: 8192
38 | monitor: val/rec_loss
39 | ddconfig:
40 | double_z: false
41 | z_channels: 3
42 | resolution: 256
43 | in_channels: 3
44 | out_ch: 3
45 | ch: 128
46 | ch_mult:
47 | - 1
48 | - 2
49 | - 4
50 | num_res_blocks: 2
51 | attn_resolutions: []
52 | dropout: 0.0
53 | lossconfig:
54 | target: torch.nn.Identity
55 | cond_stage_config:
56 | target: ldm.modules.encoders.modules.SpatialRescaler
57 | params:
58 | n_stages: 2
59 | in_channels: 182
60 | out_channels: 3
61 | data:
62 | target: main.DataModuleFromConfig
63 | params:
64 | batch_size: 8
65 | wrap: false
66 | num_workers: 10
67 | train:
68 | target: ldm.data.landscapes.RFWTrain
69 | params:
70 | size: 768
71 | crop_size: 512
72 | segmentation_to_float32: true
73 | validation:
74 | target: ldm.data.landscapes.RFWValidation
75 | params:
76 | size: 768
77 | crop_size: 512
78 | segmentation_to_float32: true
79 |
--------------------------------------------------------------------------------
/models/ldm/text2img256/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 2.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0195
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | cond_stage_key: caption
12 | image_size: 64
13 | channels: 3
14 | cond_stage_trainable: true
15 | conditioning_key: crossattn
16 | monitor: val/loss_simple_ema
17 | unet_config:
18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
19 | params:
20 | image_size: 64
21 | in_channels: 3
22 | out_channels: 3
23 | model_channels: 192
24 | attention_resolutions:
25 | - 8
26 | - 4
27 | - 2
28 | num_res_blocks: 2
29 | channel_mult:
30 | - 1
31 | - 2
32 | - 3
33 | - 5
34 | num_head_channels: 32
35 | use_spatial_transformer: true
36 | transformer_depth: 1
37 | context_dim: 640
38 | first_stage_config:
39 | target: ldm.models.autoencoder.VQModelInterface
40 | params:
41 | embed_dim: 3
42 | n_embed: 8192
43 | ddconfig:
44 | double_z: false
45 | z_channels: 3
46 | resolution: 256
47 | in_channels: 3
48 | out_ch: 3
49 | ch: 128
50 | ch_mult:
51 | - 1
52 | - 2
53 | - 4
54 | num_res_blocks: 2
55 | attn_resolutions: []
56 | dropout: 0.0
57 | lossconfig:
58 | target: torch.nn.Identity
59 | cond_stage_config:
60 | target: ldm.modules.encoders.modules.BERTEmbedder
61 | params:
62 | n_embed: 640
63 | n_layer: 32
64 | data:
65 | target: main.DataModuleFromConfig
66 | params:
67 | batch_size: 28
68 | num_workers: 10
69 | wrap: false
70 | train:
71 | target: ldm.data.previews.pytorch_dataset.PreviewsTrain
72 | params:
73 | size: 256
74 | validation:
75 | target: ldm.data.previews.pytorch_dataset.PreviewsValidation
76 | params:
77 | size: 256
78 |
--------------------------------------------------------------------------------
/models/lpips_models/alex.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/models/lpips_models/alex.pth
--------------------------------------------------------------------------------
/models/lpips_models/squeeze.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/models/lpips_models/squeeze.pth
--------------------------------------------------------------------------------
/models/lpips_models/vgg.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/models/lpips_models/vgg.pth
--------------------------------------------------------------------------------
/outputs/inpainting_results/6458524847_2f4c361183_k.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/outputs/inpainting_results/6458524847_2f4c361183_k.png
--------------------------------------------------------------------------------
/outputs/inpainting_results/8399166846_f6fb4e4b8e_k.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/outputs/inpainting_results/8399166846_f6fb4e4b8e_k.png
--------------------------------------------------------------------------------
/outputs/inpainting_results/alex-iby-G_Pk4D9rMLs.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/outputs/inpainting_results/alex-iby-G_Pk4D9rMLs.png
--------------------------------------------------------------------------------
/outputs/inpainting_results/bench2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/outputs/inpainting_results/bench2.png
--------------------------------------------------------------------------------
/outputs/inpainting_results/bertrand-gabioud-CpuFzIsHYJ0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/outputs/inpainting_results/bertrand-gabioud-CpuFzIsHYJ0.png
--------------------------------------------------------------------------------
/outputs/inpainting_results/billow926-12-Wc-Zgx6Y.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/outputs/inpainting_results/billow926-12-Wc-Zgx6Y.png
--------------------------------------------------------------------------------
/outputs/inpainting_results/overture-creations-5sI6fQgYIuo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/outputs/inpainting_results/overture-creations-5sI6fQgYIuo.png
--------------------------------------------------------------------------------
/outputs/inpainting_results/photo-1583445095369-9c651e7e5d34.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/outputs/inpainting_results/photo-1583445095369-9c651e7e5d34.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | albumentations==0.4.3
2 | easydict==1.10
3 | scikit-learn==1.2.0
4 | opencv-python==4.1.2.30
5 | pudb==2019.2
6 | imageio==2.9.0
7 | imageio-ffmpeg==0.4.2
8 | future>=0.17.1
9 | pytorch-lightning==1.4.2
10 | omegaconf==2.1.1
11 | test-tube>=0.7.5
12 | streamlit>=0.73.1
13 | einops==0.3.0
14 | torch-fidelity==0.3.0
15 | #transformers==4.6.0
16 | transformers==4.19.2
17 | torchmetrics==0.6
18 | protobuf==3.20.3
19 | invisible-watermark
20 | diffusers==0.12.1
21 | kornia==0.6
22 |
--------------------------------------------------------------------------------
/saicinpainting/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/saicinpainting/__init__.py
--------------------------------------------------------------------------------
/saicinpainting/evaluation/__init__.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import torch
4 |
5 | from saicinpainting.evaluation.evaluator import InpaintingEvaluatorOnline, ssim_fid100_f1, lpips_fid100_f1
6 | from saicinpainting.evaluation.losses.base_loss import SSIMScore, LPIPSScore, FIDScore
7 |
8 |
9 | def make_evaluator(kind='default', ssim=True, lpips=True, fid=True, integral_kind=None, **kwargs):
10 | logging.info(f'Make evaluator {kind}')
11 | device = "cuda" if torch.cuda.is_available() else "cpu"
12 | metrics = {}
13 | if ssim:
14 | metrics['ssim'] = SSIMScore()
15 | if lpips:
16 | metrics['lpips'] = LPIPSScore()
17 | if fid:
18 | metrics['fid'] = FIDScore().to(device)
19 |
20 | if integral_kind is None:
21 | integral_func = None
22 | elif integral_kind == 'ssim_fid100_f1':
23 | integral_func = ssim_fid100_f1
24 | elif integral_kind == 'lpips_fid100_f1':
25 | integral_func = lpips_fid100_f1
26 | else:
27 | raise ValueError(f'Unexpected integral_kind={integral_kind}')
28 |
29 | if kind == 'default':
30 | return InpaintingEvaluatorOnline(scores=metrics,
31 | integral_func=integral_func,
32 | integral_title=integral_kind,
33 | **kwargs)
34 |
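A hedged editorial sketch of calling make_evaluator as defined above; only the arguments visible in its signature are used, and the InpaintingEvaluatorOnline keyword arguments are left at their defaults, which is an assumption worth checking against evaluator.py.

```python
# Editorial sketch: building the default online evaluator with SSIM and LPIPS only.
evaluator = make_evaluator(kind='default', ssim=True, lpips=True, fid=False,
                           integral_kind=None)
# -> InpaintingEvaluatorOnline wired with SSIMScore and LPIPSScore; FID is skipped,
#    so no FID feature extractor is constructed for this call.
```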
--------------------------------------------------------------------------------
/saicinpainting/evaluation/losses/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/saicinpainting/evaluation/losses/__init__.py
--------------------------------------------------------------------------------
/saicinpainting/evaluation/losses/fid/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/saicinpainting/evaluation/losses/fid/__init__.py
--------------------------------------------------------------------------------
/saicinpainting/evaluation/losses/ssim.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn.functional as F
4 |
5 |
6 | class SSIM(torch.nn.Module):
7 | """SSIM. Modified from:
8 | https://github.com/Po-Hsun-Su/pytorch-ssim/blob/master/pytorch_ssim/__init__.py
9 | """
10 |
11 | def __init__(self, window_size=11, size_average=True):
12 | super().__init__()
13 | self.window_size = window_size
14 | self.size_average = size_average
15 | self.channel = 1
16 | self.register_buffer('window', self._create_window(window_size, self.channel))
17 |
18 | def forward(self, img1, img2):
19 | assert len(img1.shape) == 4
20 |
21 | channel = img1.size()[1]
22 |
23 | if channel == self.channel and self.window.data.type() == img1.data.type():
24 | window = self.window
25 | else:
26 | window = self._create_window(self.window_size, channel)
27 |
28 | # window = window.to(img1.get_device())
29 | window = window.type_as(img1)
30 |
31 | self.window = window
32 | self.channel = channel
33 |
34 | return self._ssim(img1, img2, window, self.window_size, channel, self.size_average)
35 |
36 | def _gaussian(self, window_size, sigma):
37 | gauss = torch.Tensor([
38 | np.exp(-(x - (window_size // 2)) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)
39 | ])
40 | return gauss / gauss.sum()
41 |
42 | def _create_window(self, window_size, channel):
43 | _1D_window = self._gaussian(window_size, 1.5).unsqueeze(1)
44 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
45 | return _2D_window.expand(channel, 1, window_size, window_size).contiguous()
46 |
47 | def _ssim(self, img1, img2, window, window_size, channel, size_average=True):
48 | mu1 = F.conv2d(img1, window, padding=(window_size // 2), groups=channel)
49 | mu2 = F.conv2d(img2, window, padding=(window_size // 2), groups=channel)
50 |
51 | mu1_sq = mu1.pow(2)
52 | mu2_sq = mu2.pow(2)
53 | mu1_mu2 = mu1 * mu2
54 |
55 | sigma1_sq = F.conv2d(
56 | img1 * img1, window, padding=(window_size // 2), groups=channel) - mu1_sq
57 | sigma2_sq = F.conv2d(
58 | img2 * img2, window, padding=(window_size // 2), groups=channel) - mu2_sq
59 | sigma12 = F.conv2d(
60 | img1 * img2, window, padding=(window_size // 2), groups=channel) - mu1_mu2
61 |
62 | C1 = 0.01 ** 2
63 | C2 = 0.03 ** 2
64 |
65 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / \
66 | ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
67 |
68 | if size_average:
69 | return ssim_map.mean()
70 |
71 | return ssim_map.mean(1).mean(1).mean(1)
72 |
73 | def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
74 | return
75 |
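A brief editorial usage sketch for the SSIM module above: it expects NCHW float tensors, and adding noise to an image lowers the score relative to the perfect-match value of 1.0. It assumes it runs in the same module as the class above.

```python
# Editorial sketch: scoring a noisy image against the original with SSIM.
import torch

ssim = SSIM(window_size=11, size_average=True)
img = torch.rand(1, 3, 64, 64)
noisy = (img + 0.05 * torch.randn_like(img)).clamp(0, 1)

print(float(ssim(img, img)))     # ~1.0 for identical inputs
print(float(ssim(img, noisy)))   # lower: noise reduces structural similarity
```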
--------------------------------------------------------------------------------
/saicinpainting/evaluation/masks/README.md:
--------------------------------------------------------------------------------
1 | # Current algorithm
2 |
3 | ## Choice of mask objects
4 |
5 | To identify objects that are suitable for mask generation, we use a panoptic segmentation model
6 | from [detectron2](https://github.com/facebookresearch/detectron2) trained on COCO. Categories of the detected instances
7 | belong to either the "stuff" or the "things" type. We require that object instances have a category belonging
8 | to "things". In addition, we set an upper bound on the area occupied by an object: a very large
9 | area indicates that the instance is either the background or a main object that should not be removed.
10 |
11 | ## Choice of position for mask
12 |
13 | We assume that the input image has size 2^n x 2^m. We downsample it with the
14 | [COUNTLESS](https://github.com/william-silversmith/countless) algorithm until the width equals
15 | 64 = 2^6 = 2^{downsample_levels}.
16 |
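As an editorial aside, one 2x2 downsampling level on a label map amounts to picking the dominant label in each 2x2 block. The sketch below is a naive mode-pooling reference under that reading, not the optimized bitwise COUNTLESS kernel, and its tie handling differs.

```python
# Naive 2x2 mode pooling (editorial reference only).
import numpy as np

def downsample_2x2_mode(labels: np.ndarray) -> np.ndarray:
    h, w = labels.shape
    blocks = labels[:h - h % 2, :w - w % 2].reshape(h // 2, 2, w // 2, 2)
    blocks = blocks.transpose(0, 2, 1, 3).reshape(h // 2, w // 2, 4)
    out = np.empty((h // 2, w // 2), dtype=labels.dtype)
    for i in range(h // 2):
        for j in range(w // 2):
            out[i, j] = np.bincount(blocks[i, j]).argmax()   # most frequent label
    return out

seg = np.array([[1, 1, 2, 2],
                [1, 3, 2, 0],
                [0, 0, 5, 5],
                [0, 4, 5, 5]])
print(downsample_2x2_mode(seg))   # [[1 2]
                                  #  [0 5]]
```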
17 | ### Augmentation
18 |
19 | There are several parameters for augmentation:
20 | - Scaling factor. We limit scaling to the case where the mask, after scaling about a pivot at its center, still fits
21 | completely inside the image.
22 | -
23 |
24 | ### Shift
25 |
26 |
27 | ## Select
28 |
--------------------------------------------------------------------------------
/saicinpainting/evaluation/masks/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/saicinpainting/evaluation/masks/__init__.py
--------------------------------------------------------------------------------
/saicinpainting/evaluation/masks/countless/.gitignore:
--------------------------------------------------------------------------------
1 | results
--------------------------------------------------------------------------------
/saicinpainting/evaluation/masks/countless/README.md:
--------------------------------------------------------------------------------
1 | [](https://travis-ci.org/william-silversmith/countless)
2 |
3 | Python COUNTLESS Downsampling
4 | =============================
5 |
6 | To install:
7 |
8 | `pip install -r requirements.txt`
9 |
10 | To test:
11 |
12 | `python test.py`
13 |
14 | To benchmark countless2d:
15 |
16 | `python python/countless2d.py python/images/gray_segmentation.png`
17 |
18 | To benchmark countless3d:
19 |
20 | `python python/countless3d.py`
21 |
22 | Adjust N and the list of algorithms inside each script to modify the run parameters.
23 |
24 |
25 | Python3 is slightly faster than Python2.
--------------------------------------------------------------------------------
/saicinpainting/evaluation/masks/countless/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/saicinpainting/evaluation/masks/countless/__init__.py
--------------------------------------------------------------------------------
/saicinpainting/evaluation/masks/countless/images/gcim.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/saicinpainting/evaluation/masks/countless/images/gcim.jpg
--------------------------------------------------------------------------------
/saicinpainting/evaluation/masks/countless/images/gray_segmentation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/saicinpainting/evaluation/masks/countless/images/gray_segmentation.png
--------------------------------------------------------------------------------
/saicinpainting/evaluation/masks/countless/images/segmentation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/saicinpainting/evaluation/masks/countless/images/segmentation.png
--------------------------------------------------------------------------------
/saicinpainting/evaluation/masks/countless/images/sparse.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/saicinpainting/evaluation/masks/countless/images/sparse.png
--------------------------------------------------------------------------------
/saicinpainting/evaluation/masks/countless/memprof/countless2d_gcim_N_1000.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/saicinpainting/evaluation/masks/countless/memprof/countless2d_gcim_N_1000.png
--------------------------------------------------------------------------------
/saicinpainting/evaluation/masks/countless/memprof/countless2d_quick_gcim_N_1000.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/saicinpainting/evaluation/masks/countless/memprof/countless2d_quick_gcim_N_1000.png
--------------------------------------------------------------------------------
/saicinpainting/evaluation/masks/countless/memprof/countless3d.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/saicinpainting/evaluation/masks/countless/memprof/countless3d.png
--------------------------------------------------------------------------------
/saicinpainting/evaluation/masks/countless/memprof/countless3d_dynamic.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/saicinpainting/evaluation/masks/countless/memprof/countless3d_dynamic.png
--------------------------------------------------------------------------------
/saicinpainting/evaluation/masks/countless/memprof/countless3d_dynamic_generalized.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/saicinpainting/evaluation/masks/countless/memprof/countless3d_dynamic_generalized.png
--------------------------------------------------------------------------------
/saicinpainting/evaluation/masks/countless/memprof/countless3d_generalized.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/saicinpainting/evaluation/masks/countless/memprof/countless3d_generalized.png
--------------------------------------------------------------------------------
/saicinpainting/evaluation/masks/countless/requirements.txt:
--------------------------------------------------------------------------------
1 | Pillow>=6.2.0
2 | numpy>=1.16
3 | scipy
4 | tqdm
5 | memory_profiler
6 | six
7 | pytest
--------------------------------------------------------------------------------
/saicinpainting/evaluation/utils.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 |
3 | import yaml
4 | from easydict import EasyDict as edict
5 | import torch.nn as nn
6 | import torch
7 |
8 |
9 | def load_yaml(path):
10 | with open(path, 'r') as f:
11 | return edict(yaml.safe_load(f))
12 |
13 |
14 | def move_to_device(obj, device):
15 | if isinstance(obj, nn.Module):
16 | return obj.to(device)
17 | if torch.is_tensor(obj):
18 | return obj.to(device)
19 | if isinstance(obj, (tuple, list)):
20 | return [move_to_device(el, device) for el in obj]
21 | if isinstance(obj, dict):
22 | return {name: move_to_device(val, device) for name, val in obj.items()}
23 | raise ValueError(f'Unexpected type {type(obj)}')
24 |
25 |
26 | class SmallMode(Enum):
27 | DROP = "drop"
28 | UPSCALE = "upscale"
29 |
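A tiny editorial sketch of move_to_device, which recursively walks nested lists and dicts of tensors or modules; it assumes it runs in the same module as the function above.

```python
# Editorial sketch: moving a nested batch onto the current device.
import torch

batch = {'image': torch.rand(1, 3, 8, 8),
         'masks': [torch.zeros(1, 1, 8, 8), torch.ones(1, 1, 8, 8)]}
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
batch = move_to_device(batch, device)
print(batch['image'].device)
```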
--------------------------------------------------------------------------------
/saicinpainting/evaluation/vis.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from skimage import io
3 | from skimage.segmentation import mark_boundaries
4 |
5 |
6 | def save_item_for_vis(item, out_file):
7 | mask = item['mask'] > 0.5
8 | if mask.ndim == 3:
9 | mask = mask[0]
10 | img = mark_boundaries(np.transpose(item['image'], (1, 2, 0)),
11 | mask,
12 | color=(1., 0., 0.),
13 | outline_color=(1., 1., 1.),
14 | mode='thick')
15 |
16 | if 'inpainted' in item:
17 | inp_img = mark_boundaries(np.transpose(item['inpainted'], (1, 2, 0)),
18 | mask,
19 | color=(1., 0., 0.),
20 | mode='outer')
21 | img = np.concatenate((img, inp_img), axis=1)
22 |
23 | img = np.clip(img * 255, 0, 255).astype('uint8')
24 | io.imsave(out_file, img)
25 |
26 |
27 | def save_mask_for_sidebyside(item, out_file):
28 | mask = item['mask']# > 0.5
29 | if mask.ndim == 3:
30 | mask = mask[0]
31 | mask = np.clip(mask * 255, 0, 255).astype('uint8')
32 | io.imsave(out_file, mask)
33 |
34 | def save_img_for_sidebyside(item, out_file):
35 | img = np.transpose(item['image'], (1, 2, 0))
36 | img = np.clip(img * 255, 0, 255).astype('uint8')
37 | io.imsave(out_file, img)
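An editorial sketch of the item layout save_item_for_vis expects: CHW float arrays in [0, 1], a (1, H, W) or (H, W) mask, and an optional 'inpainted' image that gets concatenated side by side with the input. The file name and random data below are illustrative, and the call assumes it runs alongside the functions above.

```python
# Editorial sketch: the dict layout consumed by save_item_for_vis.
import numpy as np

item = {
    'image': np.random.rand(3, 64, 64).astype('float32'),
    'mask': (np.random.rand(1, 64, 64) > 0.7).astype('float32'),
    'inpainted': np.random.rand(3, 64, 64).astype('float32'),
}
save_item_for_vis(item, 'vis_example.png')   # writes image | inpainted with mask boundaries drawn
```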
--------------------------------------------------------------------------------
/saicinpainting/training/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/saicinpainting/training/__init__.py
--------------------------------------------------------------------------------
/saicinpainting/training/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/saicinpainting/training/data/__init__.py
--------------------------------------------------------------------------------
/saicinpainting/training/data/aug.py:
--------------------------------------------------------------------------------
1 | from albumentations import DualIAATransform, to_tuple
2 | import imgaug.augmenters as iaa
3 |
4 | class IAAAffine2(DualIAATransform):
5 | """Place a regular grid of points on the input and randomly move the neighbourhood of these point around
6 | via affine transformations.
7 |
8 | Note: This class introduces interpolation artifacts into the mask if it has values other than {0, 1}
9 |
10 | Args:
11 | p (float): probability of applying the transform. Default: 0.5.
12 |
13 | Targets:
14 | image, mask
15 | """
16 |
17 | def __init__(
18 | self,
19 | scale=(0.7, 1.3),
20 | translate_percent=None,
21 | translate_px=None,
22 | rotate=0.0,
23 | shear=(-0.1, 0.1),
24 | order=1,
25 | cval=0,
26 | mode="reflect",
27 | always_apply=False,
28 | p=0.5,
29 | ):
30 | super(IAAAffine2, self).__init__(always_apply, p)
31 | self.scale = dict(x=scale, y=scale)
32 | self.translate_percent = to_tuple(translate_percent, 0)
33 | self.translate_px = to_tuple(translate_px, 0)
34 | self.rotate = to_tuple(rotate)
35 | self.shear = dict(x=shear, y=shear)
36 | self.order = order
37 | self.cval = cval
38 | self.mode = mode
39 |
40 | @property
41 | def processor(self):
42 | return iaa.Affine(
43 | self.scale,
44 | self.translate_percent,
45 | self.translate_px,
46 | self.rotate,
47 | self.shear,
48 | self.order,
49 | self.cval,
50 | self.mode,
51 | )
52 |
53 | def get_transform_init_args_names(self):
54 | return ("scale", "translate_percent", "translate_px", "rotate", "shear", "order", "cval", "mode")
55 |
56 |
57 | class IAAPerspective2(DualIAATransform):
58 | """Perform a random four point perspective transform of the input.
59 |
60 | Note: This class introduces interpolation artifacts into the mask if it has values other than {0, 1}
61 |
62 | Args:
63 | scale ((float, float)): standard deviations of the normal distributions. These are used to sample
64 | the random distances of the subimage's corners from the full image's corners. Default: (0.05, 0.1).
65 | p (float): probability of applying the transform. Default: 0.5.
66 |
67 | Targets:
68 | image, mask
69 | """
70 |
71 | def __init__(self, scale=(0.05, 0.1), keep_size=True, always_apply=False, p=0.5,
72 | order=1, cval=0, mode="replicate"):
73 | super(IAAPerspective2, self).__init__(always_apply, p)
74 | self.scale = to_tuple(scale, 1.0)
75 | self.keep_size = keep_size
76 | self.cval = cval
77 | self.mode = mode
78 |
79 | @property
80 | def processor(self):
81 | return iaa.PerspectiveTransform(self.scale, keep_size=self.keep_size, mode=self.mode, cval=self.cval)
82 |
83 | def get_transform_init_args_names(self):
84 | return ("scale", "keep_size")
85 |
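An editorial sketch of plugging the wrappers above into an albumentations pipeline. It assumes the albumentations==0.4.3 pin from requirements.txt, whose imgaug-based DualIAATransform base class these wrappers extend (newer albumentations releases removed that class); the specific parameter values are illustrative only.

```python
# Editorial sketch: using IAAPerspective2 / IAAAffine2 in an albumentations Compose.
import albumentations as A
import numpy as np

aug = A.Compose([
    IAAPerspective2(scale=(0.0, 0.06), p=0.5),
    IAAAffine2(rotate=(-20, 20), scale=(0.8, 1.2), p=0.5),
])
sample = aug(image=np.zeros((256, 256, 3), dtype=np.uint8),
             mask=np.zeros((256, 256), dtype=np.uint8))
image, mask = sample['image'], sample['mask']   # both targets are transformed together
```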
--------------------------------------------------------------------------------
/saicinpainting/training/losses/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/saicinpainting/training/losses/__init__.py
--------------------------------------------------------------------------------
/saicinpainting/training/losses/constants.py:
--------------------------------------------------------------------------------
1 | weights = {"ade20k":
2 | [6.34517766497462,
3 | 9.328358208955224,
4 | 11.389521640091116,
5 | 16.10305958132045,
6 | 20.833333333333332,
7 | 22.22222222222222,
8 | 25.125628140703515,
9 | 43.29004329004329,
10 | 50.5050505050505,
11 | 54.6448087431694,
12 | 55.24861878453038,
13 | 60.24096385542168,
14 | 62.5,
15 | 66.2251655629139,
16 | 84.74576271186442,
17 | 90.90909090909092,
18 | 91.74311926605505,
19 | 96.15384615384616,
20 | 96.15384615384616,
21 | 97.08737864077669,
22 | 102.04081632653062,
23 | 135.13513513513513,
24 | 149.2537313432836,
25 | 153.84615384615384,
26 | 163.93442622950818,
27 | 166.66666666666666,
28 | 188.67924528301887,
29 | 192.30769230769232,
30 | 217.3913043478261,
31 | 227.27272727272725,
32 | 227.27272727272725,
33 | 227.27272727272725,
34 | 303.03030303030306,
35 | 322.5806451612903,
36 | 333.3333333333333,
37 | 370.3703703703703,
38 | 384.61538461538464,
39 | 416.6666666666667,
40 | 416.6666666666667,
41 | 434.7826086956522,
42 | 434.7826086956522,
43 | 454.5454545454545,
44 | 454.5454545454545,
45 | 500.0,
46 | 526.3157894736842,
47 | 526.3157894736842,
48 | 555.5555555555555,
49 | 555.5555555555555,
50 | 555.5555555555555,
51 | 555.5555555555555,
52 | 555.5555555555555,
53 | 555.5555555555555,
54 | 555.5555555555555,
55 | 588.2352941176471,
56 | 588.2352941176471,
57 | 588.2352941176471,
58 | 588.2352941176471,
59 | 588.2352941176471,
60 | 666.6666666666666,
61 | 666.6666666666666,
62 | 666.6666666666666,
63 | 666.6666666666666,
64 | 714.2857142857143,
65 | 714.2857142857143,
66 | 714.2857142857143,
67 | 714.2857142857143,
68 | 714.2857142857143,
69 | 769.2307692307693,
70 | 769.2307692307693,
71 | 769.2307692307693,
72 | 833.3333333333334,
73 | 833.3333333333334,
74 | 833.3333333333334,
75 | 833.3333333333334,
76 | 909.090909090909,
77 | 1000.0,
78 | 1111.111111111111,
79 | 1111.111111111111,
80 | 1111.111111111111,
81 | 1111.111111111111,
82 | 1111.111111111111,
83 | 1250.0,
84 | 1250.0,
85 | 1250.0,
86 | 1250.0,
87 | 1250.0,
88 | 1428.5714285714287,
89 | 1428.5714285714287,
90 | 1428.5714285714287,
91 | 1428.5714285714287,
92 | 1428.5714285714287,
93 | 1428.5714285714287,
94 | 1428.5714285714287,
95 | 1666.6666666666667,
96 | 1666.6666666666667,
97 | 1666.6666666666667,
98 | 1666.6666666666667,
99 | 1666.6666666666667,
100 | 1666.6666666666667,
101 | 1666.6666666666667,
102 | 1666.6666666666667,
103 | 1666.6666666666667,
104 | 1666.6666666666667,
105 | 1666.6666666666667,
106 | 2000.0,
107 | 2000.0,
108 | 2000.0,
109 | 2000.0,
110 | 2000.0,
111 | 2000.0,
112 | 2000.0,
113 | 2000.0,
114 | 2000.0,
115 | 2000.0,
116 | 2000.0,
117 | 2000.0,
118 | 2000.0,
119 | 2000.0,
120 | 2000.0,
121 | 2000.0,
122 | 2000.0,
123 | 2500.0,
124 | 2500.0,
125 | 2500.0,
126 | 2500.0,
127 | 2500.0,
128 | 2500.0,
129 | 2500.0,
130 | 2500.0,
131 | 2500.0,
132 | 2500.0,
133 | 2500.0,
134 | 2500.0,
135 | 2500.0,
136 | 3333.3333333333335,
137 | 3333.3333333333335,
138 | 3333.3333333333335,
139 | 3333.3333333333335,
140 | 3333.3333333333335,
141 | 3333.3333333333335,
142 | 3333.3333333333335,
143 | 3333.3333333333335,
144 | 3333.3333333333335,
145 | 3333.3333333333335,
146 | 3333.3333333333335,
147 | 3333.3333333333335,
148 | 3333.3333333333335,
149 | 5000.0,
150 | 5000.0,
151 | 5000.0]
152 | }
--------------------------------------------------------------------------------
/saicinpainting/training/losses/feature_matching.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | import torch
4 | import torch.nn.functional as F
5 |
6 |
7 | def masked_l2_loss(pred, target, mask, weight_known, weight_missing):
8 | per_pixel_l2 = F.mse_loss(pred, target, reduction='none')
9 | pixel_weights = mask * weight_missing + (1 - mask) * weight_known
10 | return (pixel_weights * per_pixel_l2).mean()
11 |
12 |
13 | def masked_l1_loss(pred, target, mask, weight_known, weight_missing):
14 | per_pixel_l1 = F.l1_loss(pred, target, reduction='none')
15 | pixel_weights = mask * weight_missing + (1 - mask) * weight_known
16 | return (pixel_weights * per_pixel_l1).mean()
17 |
18 |
19 | def feature_matching_loss(fake_features: List[torch.Tensor], target_features: List[torch.Tensor], mask=None):
20 | if mask is None:
21 | res = torch.stack([F.mse_loss(fake_feat, target_feat)
22 | for fake_feat, target_feat in zip(fake_features, target_features)]).mean()
23 | else:
24 | res = 0
25 | norm = 0
26 | for fake_feat, target_feat in zip(fake_features, target_features):
27 | cur_mask = F.interpolate(mask, size=fake_feat.shape[-2:], mode='bilinear', align_corners=False)
28 | error_weights = 1 - cur_mask
29 | cur_val = ((fake_feat - target_feat).pow(2) * error_weights).mean()
30 | res = res + cur_val
31 | norm += 1
32 | res = res / norm
33 | return res
34 |
--------------------------------------------------------------------------------
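A minimal call pattern for the losses above, using random placeholder tensors in the (batch, channels, height, width) layout the functions expect; the weighting in masked_l1_loss/masked_l2_loss implies a mask convention of 1 for missing pixels and 0 for known ones:

import torch

from saicinpainting.training.losses.feature_matching import masked_l1_loss, feature_matching_loss

pred = torch.rand(2, 3, 64, 64)
target = torch.rand(2, 3, 64, 64)
mask = (torch.rand(2, 1, 64, 64) > 0.5).float()   # 1 = hole, 0 = known region

# Rescale the per-pixel L1 error separately inside and outside the hole.
l1 = masked_l1_loss(pred, target, mask, weight_known=10.0, weight_missing=0.0)

# Feature matching over a list of per-layer activations (random stand-ins here).
fake_feats = [torch.rand(2, 8, 32, 32), torch.rand(2, 16, 16, 16)]
real_feats = [torch.rand(2, 8, 32, 32), torch.rand(2, 16, 16, 16)]
fm = feature_matching_loss(fake_feats, real_feats, mask=mask)
--------------------------------------------------------------------------------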
/saicinpainting/training/losses/perceptual.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torchvision
5 |
6 | from models.ade20k import ModelBuilder
7 | from saicinpainting.utils import check_and_warn_input_range
8 |
9 |
10 | IMAGENET_MEAN = torch.FloatTensor([0.485, 0.456, 0.406])[None, :, None, None]
11 | IMAGENET_STD = torch.FloatTensor([0.229, 0.224, 0.225])[None, :, None, None]
12 |
13 |
14 | class PerceptualLoss(nn.Module):
15 | def __init__(self, normalize_inputs=True):
16 | super(PerceptualLoss, self).__init__()
17 |
18 | self.normalize_inputs = normalize_inputs
19 | self.mean_ = IMAGENET_MEAN
20 | self.std_ = IMAGENET_STD
21 |
22 | vgg = torchvision.models.vgg19(pretrained=True).features
23 | vgg_avg_pooling = []
24 |
25 | for weights in vgg.parameters():
26 | weights.requires_grad = False
27 |
28 | for module in vgg.modules():
29 | if module.__class__.__name__ == 'Sequential':
30 | continue
31 | elif module.__class__.__name__ == 'MaxPool2d':
32 | vgg_avg_pooling.append(nn.AvgPool2d(kernel_size=2, stride=2, padding=0))
33 | else:
34 | vgg_avg_pooling.append(module)
35 |
36 | self.vgg = nn.Sequential(*vgg_avg_pooling)
37 |
38 | def do_normalize_inputs(self, x):
39 | return (x - self.mean_.to(x.device)) / self.std_.to(x.device)
40 |
41 | def partial_losses(self, input, target, mask=None):
42 | check_and_warn_input_range(target, 0, 1, 'PerceptualLoss target in partial_losses')
43 |
44 | # we expect input and target to be in [0, 1] range
45 | losses = []
46 |
47 | if self.normalize_inputs:
48 | features_input = self.do_normalize_inputs(input)
49 | features_target = self.do_normalize_inputs(target)
50 | else:
51 | features_input = input
52 | features_target = target
53 |
54 | for layer in self.vgg[:30]:
55 |
56 | features_input = layer(features_input)
57 | features_target = layer(features_target)
58 |
59 | if layer.__class__.__name__ == 'ReLU':
60 | loss = F.mse_loss(features_input, features_target, reduction='none')
61 |
62 | if mask is not None:
63 | cur_mask = F.interpolate(mask, size=features_input.shape[-2:],
64 | mode='bilinear', align_corners=False)
65 | loss = loss * (1 - cur_mask)
66 |
67 | loss = loss.mean(dim=tuple(range(1, len(loss.shape))))
68 | losses.append(loss)
69 |
70 | return losses
71 |
72 | def forward(self, input, target, mask=None):
73 | losses = self.partial_losses(input, target, mask=mask)
74 | return torch.stack(losses).sum(dim=0)
75 |
76 | def get_global_features(self, input):
77 | check_and_warn_input_range(input, 0, 1, 'PerceptualLoss input in get_global_features')
78 |
79 | if self.normalize_inputs:
80 | features_input = self.do_normalize_inputs(input)
81 | else:
82 | features_input = input
83 |
84 | features_input = self.vgg(features_input)
85 | return features_input
86 |
87 |
88 | class ResNetPL(nn.Module):
89 | def __init__(self, weight=1,
90 | weights_path=None, arch_encoder='resnet50dilated', segmentation=True):
91 | super().__init__()
92 | self.impl = ModelBuilder.get_encoder(weights_path=weights_path,
93 | arch_encoder=arch_encoder,
94 | arch_decoder='ppm_deepsup',
95 | fc_dim=2048,
96 | segmentation=segmentation)
97 | self.impl.eval()
98 | for w in self.impl.parameters():
99 | w.requires_grad_(False)
100 |
101 | self.weight = weight
102 |
103 | def forward(self, pred, target):
104 | pred = (pred - IMAGENET_MEAN.to(pred)) / IMAGENET_STD.to(pred)
105 | target = (target - IMAGENET_MEAN.to(target)) / IMAGENET_STD.to(target)
106 |
107 | pred_feats = self.impl(pred, return_feature_maps=True)
108 | target_feats = self.impl(target, return_feature_maps=True)
109 |
110 | result = torch.stack([F.mse_loss(cur_pred, cur_target)
111 | for cur_pred, cur_target
112 | in zip(pred_feats, target_feats)]).sum() * self.weight
113 | return result
114 |
--------------------------------------------------------------------------------
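A minimal sketch of PerceptualLoss: it pulls torchvision's pretrained VGG19 on first construction, while ResNetPL additionally needs the ADE20k segmentation weights via ModelBuilder, so only PerceptualLoss is exercised here; inputs are random placeholders in the [0, 1] range that partial_losses expects:

import torch

from saicinpainting.training.losses.perceptual import PerceptualLoss

loss_fn = PerceptualLoss()           # downloads torchvision's pretrained VGG19 if needed

pred = torch.rand(1, 3, 128, 128)    # values in [0, 1]
target = torch.rand(1, 3, 128, 128)
mask = torch.zeros(1, 1, 128, 128)   # optional; masked (==1) regions are zeroed out of the loss

loss = loss_fn(pred, target, mask=mask)   # one value per batch element, summed over ReLU layers
--------------------------------------------------------------------------------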
/saicinpainting/training/losses/segmentation.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from .constants import weights as constant_weights
6 |
7 |
8 | class CrossEntropy2d(nn.Module):
9 | def __init__(self, reduction="mean", ignore_label=255, weights=None, *args, **kwargs):
10 | """
11 | weights (str, optional): key into constants.weights selecting a manual per-class rescaling
12 | weight vector (e.g. "ade20k"); if given, it is converted to a FloatTensor of size "nclasses"
13 | """
14 | super(CrossEntropy2d, self).__init__()
15 | self.reduction = reduction
16 | self.ignore_label = ignore_label
17 | self.weights = weights
18 | if self.weights is not None:
19 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
20 | self.weights = torch.FloatTensor(constant_weights[weights]).to(device)
21 |
22 | def forward(self, predict, target):
23 | """
24 | Args:
25 | predict:(n, c, h, w)
26 | target:(n, 1, h, w)
27 | """
28 | target = target.long()
29 | assert not target.requires_grad
30 | assert predict.dim() == 4, "{0}".format(predict.size())
31 | assert target.dim() == 4, "{0}".format(target.size())
32 | assert predict.size(0) == target.size(0), "{0} vs {1} ".format(predict.size(0), target.size(0))
33 | assert target.size(1) == 1, "{0}".format(target.size(1))
34 | assert predict.size(2) == target.size(2), "{0} vs {1} ".format(predict.size(2), target.size(2))
35 | assert predict.size(3) == target.size(3), "{0} vs {1} ".format(predict.size(3), target.size(3))
36 | target = target.squeeze(1)
37 | n, c, h, w = predict.size()
38 | target_mask = (target >= 0) * (target != self.ignore_label)
39 | target = target[target_mask]
40 | predict = predict.transpose(1, 2).transpose(2, 3).contiguous()
41 | predict = predict[target_mask.view(n, h, w, 1).repeat(1, 1, 1, c)].view(-1, c)
42 | loss = F.cross_entropy(predict, target, weight=self.weights, reduction=self.reduction)
43 | return loss
44 |
--------------------------------------------------------------------------------
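A minimal sketch of CrossEntropy2d: it expects (n, c, h, w) logits and an (n, 1, h, w) integer label map, and weights="ade20k" selects the per-class weights from constants.py; since the class moves those weights to CUDA when available, the placeholder tensors below are created on the same device:

import torch

from saicinpainting.training.losses.segmentation import CrossEntropy2d

criterion = CrossEntropy2d(weights="ade20k")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logits = torch.randn(2, 150, 64, 64, device=device)            # (n, c, h, w) class scores
target = torch.randint(0, 150, (2, 1, 64, 64), device=device)  # (n, 1, h, w) class indices
loss = criterion(logits, target)
--------------------------------------------------------------------------------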
/saicinpainting/training/losses/style_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torchvision.models as models
4 |
5 |
6 | class PerceptualLoss(nn.Module):
7 | r"""
8 | Perceptual loss, VGG-based
9 | https://arxiv.org/abs/1603.08155
10 | https://github.com/dxyang/StyleTransfer/blob/master/utils.py
11 | """
12 |
13 | def __init__(self, weights=[1.0, 1.0, 1.0, 1.0, 1.0]):
14 | super(PerceptualLoss, self).__init__()
15 | self.add_module('vgg', VGG19())
16 | self.criterion = torch.nn.L1Loss()
17 | self.weights = weights
18 |
19 | def __call__(self, x, y):
20 | # Compute features
21 | x_vgg, y_vgg = self.vgg(x), self.vgg(y)
22 |
23 | content_loss = 0.0
24 | content_loss += self.weights[0] * self.criterion(x_vgg['relu1_1'], y_vgg['relu1_1'])
25 | content_loss += self.weights[1] * self.criterion(x_vgg['relu2_1'], y_vgg['relu2_1'])
26 | content_loss += self.weights[2] * self.criterion(x_vgg['relu3_1'], y_vgg['relu3_1'])
27 | content_loss += self.weights[3] * self.criterion(x_vgg['relu4_1'], y_vgg['relu4_1'])
28 | content_loss += self.weights[4] * self.criterion(x_vgg['relu5_1'], y_vgg['relu5_1'])
29 |
30 |
31 | return content_loss
32 |
33 |
34 | class VGG19(torch.nn.Module):
35 | def __init__(self):
36 | super(VGG19, self).__init__()
37 | features = models.vgg19(pretrained=True).features
38 | self.relu1_1 = torch.nn.Sequential()
39 | self.relu1_2 = torch.nn.Sequential()
40 |
41 | self.relu2_1 = torch.nn.Sequential()
42 | self.relu2_2 = torch.nn.Sequential()
43 |
44 | self.relu3_1 = torch.nn.Sequential()
45 | self.relu3_2 = torch.nn.Sequential()
46 | self.relu3_3 = torch.nn.Sequential()
47 | self.relu3_4 = torch.nn.Sequential()
48 |
49 | self.relu4_1 = torch.nn.Sequential()
50 | self.relu4_2 = torch.nn.Sequential()
51 | self.relu4_3 = torch.nn.Sequential()
52 | self.relu4_4 = torch.nn.Sequential()
53 |
54 | self.relu5_1 = torch.nn.Sequential()
55 | self.relu5_2 = torch.nn.Sequential()
56 | self.relu5_3 = torch.nn.Sequential()
57 | self.relu5_4 = torch.nn.Sequential()
58 |
59 | for x in range(2):
60 | self.relu1_1.add_module(str(x), features[x])
61 |
62 | for x in range(2, 4):
63 | self.relu1_2.add_module(str(x), features[x])
64 |
65 | for x in range(4, 7):
66 | self.relu2_1.add_module(str(x), features[x])
67 |
68 | for x in range(7, 9):
69 | self.relu2_2.add_module(str(x), features[x])
70 |
71 | for x in range(9, 12):
72 | self.relu3_1.add_module(str(x), features[x])
73 |
74 | for x in range(12, 14):
75 | self.relu3_2.add_module(str(x), features[x])
76 |
77 | for x in range(14, 16):
78 | self.relu3_3.add_module(str(x), features[x])
79 |
80 | for x in range(16, 18):
81 | self.relu3_4.add_module(str(x), features[x])
82 |
83 | for x in range(18, 21):
84 | self.relu4_1.add_module(str(x), features[x])
85 |
86 | for x in range(21, 23):
87 | self.relu4_2.add_module(str(x), features[x])
88 |
89 | for x in range(23, 25):
90 | self.relu4_3.add_module(str(x), features[x])
91 |
92 | for x in range(25, 27):
93 | self.relu4_4.add_module(str(x), features[x])
94 |
95 | for x in range(27, 30):
96 | self.relu5_1.add_module(str(x), features[x])
97 |
98 | for x in range(30, 32):
99 | self.relu5_2.add_module(str(x), features[x])
100 |
101 | for x in range(32, 34):
102 | self.relu5_3.add_module(str(x), features[x])
103 |
104 | for x in range(34, 36):
105 | self.relu5_4.add_module(str(x), features[x])
106 |
107 | # don't need the gradients, just want the features
108 | for param in self.parameters():
109 | param.requires_grad = False
110 |
111 | def forward(self, x):
112 | relu1_1 = self.relu1_1(x)
113 | relu1_2 = self.relu1_2(relu1_1)
114 |
115 | relu2_1 = self.relu2_1(relu1_2)
116 | relu2_2 = self.relu2_2(relu2_1)
117 |
118 | relu3_1 = self.relu3_1(relu2_2)
119 | relu3_2 = self.relu3_2(relu3_1)
120 | relu3_3 = self.relu3_3(relu3_2)
121 | relu3_4 = self.relu3_4(relu3_3)
122 |
123 | relu4_1 = self.relu4_1(relu3_4)
124 | relu4_2 = self.relu4_2(relu4_1)
125 | relu4_3 = self.relu4_3(relu4_2)
126 | relu4_4 = self.relu4_4(relu4_3)
127 |
128 | relu5_1 = self.relu5_1(relu4_4)
129 | relu5_2 = self.relu5_2(relu5_1)
130 | relu5_3 = self.relu5_3(relu5_2)
131 | relu5_4 = self.relu5_4(relu5_3)
132 |
133 | out = {
134 | 'relu1_1': relu1_1,
135 | 'relu1_2': relu1_2,
136 |
137 | 'relu2_1': relu2_1,
138 | 'relu2_2': relu2_2,
139 |
140 | 'relu3_1': relu3_1,
141 | 'relu3_2': relu3_2,
142 | 'relu3_3': relu3_3,
143 | 'relu3_4': relu3_4,
144 |
145 | 'relu4_1': relu4_1,
146 | 'relu4_2': relu4_2,
147 | 'relu4_3': relu4_3,
148 | 'relu4_4': relu4_4,
149 |
150 | 'relu5_1': relu5_1,
151 | 'relu5_2': relu5_2,
152 | 'relu5_3': relu5_3,
153 | 'relu5_4': relu5_4,
154 | }
155 | return out
156 |
--------------------------------------------------------------------------------
/saicinpainting/training/modules/__init__.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | from saicinpainting.training.modules.ffc import FFCResNetGenerator
4 | from saicinpainting.training.modules.pix2pixhd import GlobalGenerator, MultiDilatedGlobalGenerator, \
5 | NLayerDiscriminator, MultidilatedNLayerDiscriminator
6 |
7 | def make_generator(config, kind, **kwargs):
8 | logging.info(f'Make generator {kind}')
9 |
10 | if kind == 'pix2pixhd_multidilated':
11 | return MultiDilatedGlobalGenerator(**kwargs)
12 |
13 | if kind == 'pix2pixhd_global':
14 | return GlobalGenerator(**kwargs)
15 |
16 | if kind == 'ffc_resnet':
17 | return FFCResNetGenerator(**kwargs)
18 |
19 | raise ValueError(f'Unknown generator kind {kind}')
20 |
21 |
22 | def make_discriminator(kind, **kwargs):
23 | logging.info(f'Make discriminator {kind}')
24 |
25 | if kind == 'pix2pixhd_nlayer_multidilated':
26 | return MultidilatedNLayerDiscriminator(**kwargs)
27 |
28 | if kind == 'pix2pixhd_nlayer':
29 | return NLayerDiscriminator(**kwargs)
30 |
31 | raise ValueError(f'Unknown discriminator kind {kind}')
32 |
--------------------------------------------------------------------------------
/saicinpainting/training/modules/base.py:
--------------------------------------------------------------------------------
1 | import abc
2 | from typing import Tuple, List
3 |
4 | import torch
5 | import torch.nn as nn
6 |
7 | from saicinpainting.training.modules.depthwise_sep_conv import DepthWiseSeperableConv
8 | from saicinpainting.training.modules.multidilated_conv import MultidilatedConv
9 |
10 |
11 | class BaseDiscriminator(nn.Module):
12 | @abc.abstractmethod
13 | def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
14 | """
15 | Predict scores and get intermediate activations. Useful for feature matching loss
16 | :return tuple (scores, list of intermediate activations)
17 | """
18 | raise NotImplementedError()
19 |
20 |
21 | def get_conv_block_ctor(kind='default'):
22 | if not isinstance(kind, str):
23 | return kind
24 | if kind == 'default':
25 | return nn.Conv2d
26 | if kind == 'depthwise':
27 | return DepthWiseSeperableConv
28 | if kind == 'multidilated':
29 | return MultidilatedConv
30 | raise ValueError(f'Unknown convolutional block kind {kind}')
31 |
32 |
33 | def get_norm_layer(kind='bn'):
34 | if not isinstance(kind, str):
35 | return kind
36 | if kind == 'bn':
37 | return nn.BatchNorm2d
38 | if kind == 'in':
39 | return nn.InstanceNorm2d
40 | raise ValueError(f'Unknown norm block kind {kind}')
41 |
42 |
43 | def get_activation(kind='tanh'):
44 | if kind == 'tanh':
45 | return nn.Tanh()
46 | if kind == 'sigmoid':
47 | return nn.Sigmoid()
48 | if kind is False:
49 | return nn.Identity()
50 | raise ValueError(f'Unknown activation kind {kind}')
51 |
52 |
53 | class SimpleMultiStepGenerator(nn.Module):
54 | def __init__(self, steps: List[nn.Module]):
55 | super().__init__()
56 | self.steps = nn.ModuleList(steps)
57 |
58 | def forward(self, x):
59 | cur_in = x
60 | outs = []
61 | for step in self.steps:
62 | cur_out = step(cur_in)
63 | outs.append(cur_out)
64 | cur_in = torch.cat((cur_in, cur_out), dim=1)
65 | return torch.cat(outs[::-1], dim=1)
66 |
67 | def deconv_factory(kind, ngf, mult, norm_layer, activation, max_features):
68 | if kind == 'convtranspose':
69 | return [nn.ConvTranspose2d(min(max_features, ngf * mult),
70 | min(max_features, int(ngf * mult / 2)),
71 | kernel_size=3, stride=2, padding=1, output_padding=1),
72 | norm_layer(min(max_features, int(ngf * mult / 2))), activation]
73 | elif kind == 'bilinear':
74 | return [nn.Upsample(scale_factor=2, mode='bilinear'),
75 | DepthWiseSeperableConv(min(max_features, ngf * mult),
76 | min(max_features, int(ngf * mult / 2)),
77 | kernel_size=3, stride=1, padding=1),
78 | norm_layer(min(max_features, int(ngf * mult / 2))), activation]
79 | else:
80 | raise Exception(f"Invalid deconv kind: {kind}")
--------------------------------------------------------------------------------
/saicinpainting/training/modules/depthwise_sep_conv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | class DepthWiseSeperableConv(nn.Module):
5 | def __init__(self, in_dim, out_dim, *args, **kwargs):
6 | super().__init__()
7 | if 'groups' in kwargs:
8 | # ignoring groups for Depthwise Sep Conv
9 | del kwargs['groups']
10 |
11 | self.depthwise = nn.Conv2d(in_dim, in_dim, *args, groups=in_dim, **kwargs)
12 | self.pointwise = nn.Conv2d(in_dim, out_dim, kernel_size=1)
13 |
14 | def forward(self, x):
15 | out = self.depthwise(x)
16 | out = self.pointwise(out)
17 | return out
--------------------------------------------------------------------------------
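A quick shape check for the depthwise-separable convolution above, using a random placeholder tensor:

import torch

from saicinpainting.training.modules.depthwise_sep_conv import DepthWiseSeperableConv

conv = DepthWiseSeperableConv(16, 32, kernel_size=3, padding=1)
x = torch.rand(1, 16, 64, 64)
y = conv(x)                       # per-channel 3x3 depthwise conv, then 1x1 pointwise projection
assert y.shape == (1, 32, 64, 64)
--------------------------------------------------------------------------------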
/saicinpainting/training/modules/fake_fakes.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from kornia import SamplePadding
3 | from kornia.augmentation import RandomAffine, CenterCrop
4 |
5 |
6 | class FakeFakesGenerator:
7 | def __init__(self, aug_proba=0.5, img_aug_degree=30, img_aug_translate=0.2):
8 | self.grad_aug = RandomAffine(degrees=360,
9 | translate=0.2,
10 | padding_mode=SamplePadding.REFLECTION,
11 | keepdim=False,
12 | p=1)
13 | self.img_aug = RandomAffine(degrees=img_aug_degree,
14 | translate=img_aug_translate,
15 | padding_mode=SamplePadding.REFLECTION,
16 | keepdim=True,
17 | p=1)
18 | self.aug_proba = aug_proba
19 |
20 | def __call__(self, input_images, masks):
21 | blend_masks = self._fill_masks_with_gradient(masks)
22 | blend_target = self._make_blend_target(input_images)
23 | result = input_images * (1 - blend_masks) + blend_target * blend_masks
24 | return result, blend_masks
25 |
26 | def _make_blend_target(self, input_images):
27 | batch_size = input_images.shape[0]
28 | permuted = input_images[torch.randperm(batch_size)]
29 | augmented = self.img_aug(input_images)
30 | is_aug = (torch.rand(batch_size, device=input_images.device)[:, None, None, None] < self.aug_proba).float()
31 | result = augmented * is_aug + permuted * (1 - is_aug)
32 | return result
33 |
34 | def _fill_masks_with_gradient(self, masks):
35 | batch_size, _, height, width = masks.shape
36 | grad = torch.linspace(0, 1, steps=width * 2, device=masks.device, dtype=masks.dtype) \
37 | .view(1, 1, 1, -1).expand(batch_size, 1, height * 2, width * 2)
38 | grad = self.grad_aug(grad)
39 | grad = CenterCrop((height, width))(grad)
40 | grad *= masks
41 |
42 | grad_for_min = grad + (1 - masks) * 10
43 | grad -= grad_for_min.view(batch_size, -1).min(-1).values[:, None, None, None]
44 | grad /= grad.view(batch_size, -1).max(-1).values[:, None, None, None] + 1e-6
45 | grad.clamp_(min=0, max=1)
46 |
47 | return grad
48 |
--------------------------------------------------------------------------------
/saicinpainting/training/modules/multidilated_conv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import random
4 | from saicinpainting.training.modules.depthwise_sep_conv import DepthWiseSeperableConv
5 |
6 | class MultidilatedConv(nn.Module):
7 | def __init__(self, in_dim, out_dim, kernel_size, dilation_num=3, comb_mode='sum', equal_dim=True,
8 | shared_weights=False, padding=1, min_dilation=1, shuffle_in_channels=False, use_depthwise=False, **kwargs):
9 | super().__init__()
10 | convs = []
11 | self.equal_dim = equal_dim
12 | assert comb_mode in ('cat_out', 'sum', 'cat_in', 'cat_both'), comb_mode
13 | if comb_mode in ('cat_out', 'cat_both'):
14 | self.cat_out = True
15 | if equal_dim:
16 | assert out_dim % dilation_num == 0
17 | out_dims = [out_dim // dilation_num] * dilation_num
18 | self.index = sum([[i + j * (out_dims[0]) for j in range(dilation_num)] for i in range(out_dims[0])], [])
19 | else:
20 | out_dims = [out_dim // 2 ** (i + 1) for i in range(dilation_num - 1)]
21 | out_dims.append(out_dim - sum(out_dims))
22 | index = []
23 | starts = [0] + out_dims[:-1]
24 | lengths = [out_dims[i] // out_dims[-1] for i in range(dilation_num)]
25 | for i in range(out_dims[-1]):
26 | for j in range(dilation_num):
27 | index += list(range(starts[j], starts[j] + lengths[j]))
28 | starts[j] += lengths[j]
29 | self.index = index
30 | assert(len(index) == out_dim)
31 | self.out_dims = out_dims
32 | else:
33 | self.cat_out = False
34 | self.out_dims = [out_dim] * dilation_num
35 |
36 | if comb_mode in ('cat_in', 'cat_both'):
37 | if equal_dim:
38 | assert in_dim % dilation_num == 0
39 | in_dims = [in_dim // dilation_num] * dilation_num
40 | else:
41 | in_dims = [in_dim // 2 ** (i + 1) for i in range(dilation_num - 1)]
42 | in_dims.append(in_dim - sum(in_dims))
43 | self.in_dims = in_dims
44 | self.cat_in = True
45 | else:
46 | self.cat_in = False
47 | self.in_dims = [in_dim] * dilation_num
48 |
49 | conv_type = DepthWiseSeperableConv if use_depthwise else nn.Conv2d
50 | dilation = min_dilation
51 | for i in range(dilation_num):
52 | if isinstance(padding, int):
53 | cur_padding = padding * dilation
54 | else:
55 | cur_padding = padding[i]
56 | convs.append(conv_type(
57 | self.in_dims[i], self.out_dims[i], kernel_size, padding=cur_padding, dilation=dilation, **kwargs
58 | ))
59 | if i > 0 and shared_weights:
60 | convs[-1].weight = convs[0].weight
61 | convs[-1].bias = convs[0].bias
62 | dilation *= 2
63 | self.convs = nn.ModuleList(convs)
64 |
65 | self.shuffle_in_channels = shuffle_in_channels
66 | if self.shuffle_in_channels:
67 | # shuffle list as shuffling of tensors is nondeterministic
68 | in_channels_permute = list(range(in_dim))
69 | random.shuffle(in_channels_permute)
70 | # save as buffer so it is saved and loaded with checkpoint
71 | self.register_buffer('in_channels_permute', torch.tensor(in_channels_permute))
72 |
73 | def forward(self, x):
74 | if self.shuffle_in_channels:
75 | x = x[:, self.in_channels_permute]
76 |
77 | outs = []
78 | if self.cat_in:
79 | if self.equal_dim:
80 | x = x.chunk(len(self.convs), dim=1)
81 | else:
82 | new_x = []
83 | start = 0
84 | for dim in self.in_dims:
85 | new_x.append(x[:, start:start+dim])
86 | start += dim
87 | x = new_x
88 | for i, conv in enumerate(self.convs):
89 | if self.cat_in:
90 | input = x[i]
91 | else:
92 | input = x
93 | outs.append(conv(input))
94 | if self.cat_out:
95 | out = torch.cat(outs, dim=1)[:, self.index]
96 | else:
97 | out = sum(outs)
98 | return out
99 |
--------------------------------------------------------------------------------
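With the default comb_mode='sum' the parallel branches use dilations 1, 2 and 4 with the padding scaled to match, so the spatial size is preserved and the branch outputs are summed; a quick shape check with a placeholder tensor:

import torch

from saicinpainting.training.modules.multidilated_conv import MultidilatedConv

conv = MultidilatedConv(16, 32, kernel_size=3)   # dilation_num=3 -> dilations 1, 2, 4
x = torch.rand(1, 16, 64, 64)
y = conv(x)                                      # outputs of the three branches are summed
assert y.shape == (1, 32, 64, 64)
--------------------------------------------------------------------------------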
/saicinpainting/training/modules/spatial_transform.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from kornia.geometry.transform import rotate
5 |
6 |
7 | class LearnableSpatialTransformWrapper(nn.Module):
8 | def __init__(self, impl, pad_coef=0.5, angle_init_range=80, train_angle=True):
9 | super().__init__()
10 | self.impl = impl
11 | self.angle = torch.rand(1) * angle_init_range
12 | if train_angle:
13 | self.angle = nn.Parameter(self.angle, requires_grad=True)
14 | self.pad_coef = pad_coef
15 |
16 | def forward(self, x):
17 | if torch.is_tensor(x):
18 | return self.inverse_transform(self.impl(self.transform(x)), x)
19 | elif isinstance(x, tuple):
20 | x_trans = tuple(self.transform(elem) for elem in x)
21 | y_trans = self.impl(x_trans)
22 | return tuple(self.inverse_transform(elem, orig_x) for elem, orig_x in zip(y_trans, x))
23 | else:
24 | raise ValueError(f'Unexpected input type {type(x)}')
25 |
26 | def transform(self, x):
27 | height, width = x.shape[2:]
28 | pad_h, pad_w = int(height * self.pad_coef), int(width * self.pad_coef)
29 | x_padded = F.pad(x, [pad_w, pad_w, pad_h, pad_h], mode='reflect')
30 | x_padded_rotated = rotate(x_padded, angle=self.angle.to(x_padded))
31 | return x_padded_rotated
32 |
33 | def inverse_transform(self, y_padded_rotated, orig_x):
34 | height, width = orig_x.shape[2:]
35 | pad_h, pad_w = int(height * self.pad_coef), int(width * self.pad_coef)
36 |
37 | y_padded = rotate(y_padded_rotated, angle=-self.angle.to(y_padded_rotated))
38 | y_height, y_width = y_padded.shape[2:]
39 | y = y_padded[:, :, pad_h : y_height - pad_h, pad_w : y_width - pad_w]
40 | return y
41 |
42 |
43 | if __name__ == '__main__':
44 | layer = LearnableSpatialTransformWrapper(nn.Identity())
45 | x = torch.arange(2* 3 * 15 * 15).view(2, 3, 15, 15).float()
46 | y = layer(x)
47 | assert x.shape == y.shape
48 | assert torch.allclose(x[:, :, 1:, 1:][:, :, :-1, :-1], y[:, :, 1:, 1:][:, :, :-1, :-1])
49 | print('all ok')
50 |
--------------------------------------------------------------------------------
/saicinpainting/training/modules/squeeze_excitation.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
4 | class SELayer(nn.Module):
5 | def __init__(self, channel, reduction=16):
6 | super(SELayer, self).__init__()
7 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
8 | self.fc = nn.Sequential(
9 | nn.Linear(channel, channel // reduction, bias=False),
10 | nn.ReLU(inplace=True),
11 | nn.Linear(channel // reduction, channel, bias=False),
12 | nn.Sigmoid()
13 | )
14 |
15 | def forward(self, x):
16 | b, c, _, _ = x.size()
17 | y = self.avg_pool(x).view(b, c)
18 | y = self.fc(y).view(b, c, 1, 1)
19 | res = x * y.expand_as(x)
20 | return res
21 |
--------------------------------------------------------------------------------
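The squeeze-and-excitation layer reweights channels and leaves the tensor shape unchanged; a quick check with a placeholder tensor:

import torch

from saicinpainting.training.modules.squeeze_excitation import SELayer

se = SELayer(channel=64, reduction=16)
x = torch.rand(2, 64, 32, 32)
y = se(x)                   # global-average-pooled gate, applied channel-wise
assert y.shape == x.shape
--------------------------------------------------------------------------------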
/saicinpainting/training/trainers/__init__.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import torch
3 | from saicinpainting.training.trainers.default import DefaultInpaintingTrainingModule
4 |
5 |
6 | def get_training_model_class(kind):
7 | if kind == 'default':
8 | return DefaultInpaintingTrainingModule
9 |
10 | raise ValueError(f'Unknown trainer module {kind}')
11 |
12 |
13 | def make_training_model(config):
14 | kind = config.training_model.kind
15 | kwargs = dict(config.training_model)
16 | kwargs.pop('kind')
17 | kwargs['use_ddp'] = config.trainer.kwargs.get('accelerator', None) == 'ddp'
18 |
19 | logging.info(f'Make training model {kind}')
20 |
21 | cls = get_training_model_class(kind)
22 | return cls(config, **kwargs)
23 |
24 |
25 | def load_checkpoint(train_config, path, map_location='cuda', strict=True):
26 | model: torch.nn.Module = make_training_model(train_config)
27 | state = torch.load(path, map_location=map_location)
28 | model.load_state_dict(state['state_dict'], strict=strict)
29 | model.on_load_checkpoint(state)
30 | return model
31 |
--------------------------------------------------------------------------------
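A minimal sketch of load_checkpoint, which builds the training module from a config and then loads a Lightning-style checkpoint; the paths below are placeholders, and the config is assumed to define training_model.kind ('default' is the only kind handled here) plus the trainer and model kwargs the module expects:

from omegaconf import OmegaConf

from saicinpainting.training.trainers import load_checkpoint

train_config = OmegaConf.load('path/to/training_config.yaml')          # placeholder path
model = load_checkpoint(train_config, 'path/to/last.ckpt',             # placeholder path
                        map_location='cpu', strict=False)
model.eval()
--------------------------------------------------------------------------------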
/saicinpainting/training/visualizers/__init__.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | from saicinpainting.training.visualizers.directory import DirectoryVisualizer
4 | from saicinpainting.training.visualizers.noop import NoopVisualizer
5 |
6 |
7 | def make_visualizer(kind, **kwargs):
8 | logging.info(f'Make visualizer {kind}')
9 |
10 | if kind == 'directory':
11 | return DirectoryVisualizer(**kwargs)
12 | if kind == 'noop':
13 | return NoopVisualizer()
14 |
15 | raise ValueError(f'Unknown visualizer kind {kind}')
16 |
--------------------------------------------------------------------------------
/saicinpainting/training/visualizers/base.py:
--------------------------------------------------------------------------------
1 | import abc
2 | from typing import Dict, List
3 |
4 | import numpy as np
5 | import torch
6 | from skimage import color
7 | from skimage.segmentation import mark_boundaries
8 |
9 | from . import colors
10 |
11 | COLORS, _ = colors.generate_colors(151) # 151 - max classes for semantic segmentation
12 |
13 |
14 | class BaseVisualizer:
15 | @abc.abstractmethod
16 | def __call__(self, epoch_i, batch_i, batch, suffix='', rank=None):
17 | """
18 | Take a batch, make an image from it and visualize
19 | """
20 | raise NotImplementedError()
21 |
22 |
23 | def visualize_mask_and_images(images_dict: Dict[str, np.ndarray], keys: List[str],
24 | last_without_mask=True, rescale_keys=None, mask_only_first=None,
25 | black_mask=False) -> np.ndarray:
26 | mask = images_dict['mask'] > 0.5
27 | result = []
28 | for i, k in enumerate(keys):
29 | img = images_dict[k]
30 | img = np.transpose(img, (1, 2, 0))
31 |
32 | if rescale_keys is not None and k in rescale_keys:
33 | img = img - img.min()
34 | img /= img.max() + 1e-5
35 | if len(img.shape) == 2:
36 | img = np.expand_dims(img, 2)
37 |
38 | if img.shape[2] == 1:
39 | img = np.repeat(img, 3, axis=2)
40 | elif (img.shape[2] > 3):
41 | img_classes = img.argmax(2)
42 | img = color.label2rgb(img_classes, colors=COLORS)
43 |
44 | if mask_only_first:
45 | need_mark_boundaries = i == 0
46 | else:
47 | need_mark_boundaries = i < len(keys) - 1 or not last_without_mask
48 |
49 | if need_mark_boundaries:
50 | if black_mask:
51 | img = img * (1 - mask[0][..., None])
52 | img = mark_boundaries(img,
53 | mask[0],
54 | color=(1., 0., 0.),
55 | outline_color=(1., 1., 1.),
56 | mode='thick')
57 | result.append(img)
58 | return np.concatenate(result, axis=1)
59 |
60 |
61 | def visualize_mask_and_images_batch(batch: Dict[str, torch.Tensor], keys: List[str], max_items=10,
62 | last_without_mask=True, rescale_keys=None) -> np.ndarray:
63 | batch = {k: tens.detach().cpu().numpy() for k, tens in batch.items()
64 | if k in keys or k == 'mask'}
65 |
66 | batch_size = next(iter(batch.values())).shape[0]
67 | items_to_vis = min(batch_size, max_items)
68 | result = []
69 | for i in range(items_to_vis):
70 | cur_dct = {k: tens[i] for k, tens in batch.items()}
71 | result.append(visualize_mask_and_images(cur_dct, keys, last_without_mask=last_without_mask,
72 | rescale_keys=rescale_keys))
73 | return np.concatenate(result, axis=0)
74 |
--------------------------------------------------------------------------------
/saicinpainting/training/visualizers/colors.py:
--------------------------------------------------------------------------------
1 | import random
2 | import colorsys
3 |
4 | import numpy as np
5 | import matplotlib
6 | matplotlib.use('agg')
7 | import matplotlib.pyplot as plt
8 | from matplotlib.colors import LinearSegmentedColormap
9 |
10 |
11 | def generate_colors(nlabels, type='bright', first_color_black=False, last_color_black=True, verbose=False):
12 | # https://stackoverflow.com/questions/14720331/how-to-generate-random-colors-in-matplotlib
13 | """
14 | Creates a random colormap to be used together with matplotlib. Useful for segmentation tasks
15 | :param nlabels: Number of labels (size of colormap)
16 | :param type: 'bright' for strong colors, 'soft' for pastel colors
17 | :param first_color_black: Option to use first color as black, True or False
18 | :param last_color_black: Option to use last color as black, True or False
19 | :param verbose: Prints the number of labels and shows the colormap. True or False
20 | :return: colormap for matplotlib
21 | """
22 | if type not in ('bright', 'soft'):
23 | print ('Please choose "bright" or "soft" for type')
24 | return
25 |
26 | if verbose:
27 | print('Number of labels: ' + str(nlabels))
28 |
29 | # Generate color map for bright colors, based on hsv
30 | if type == 'bright':
31 | randHSVcolors = [(np.random.uniform(low=0.0, high=1),
32 | np.random.uniform(low=0.2, high=1),
33 | np.random.uniform(low=0.9, high=1)) for i in range(nlabels)]
34 |
35 | # Convert HSV list to RGB
36 | randRGBcolors = []
37 | for HSVcolor in randHSVcolors:
38 | randRGBcolors.append(colorsys.hsv_to_rgb(HSVcolor[0], HSVcolor[1], HSVcolor[2]))
39 |
40 | if first_color_black:
41 | randRGBcolors[0] = [0, 0, 0]
42 |
43 | if last_color_black:
44 | randRGBcolors[-1] = [0, 0, 0]
45 |
46 | random_colormap = LinearSegmentedColormap.from_list('new_map', randRGBcolors, N=nlabels)
47 |
48 | # Generate soft pastel colors, by limiting the RGB spectrum
49 | if type == 'soft':
50 | low = 0.6
51 | high = 0.95
52 | randRGBcolors = [(np.random.uniform(low=low, high=high),
53 | np.random.uniform(low=low, high=high),
54 | np.random.uniform(low=low, high=high)) for i in range(nlabels)]
55 |
56 | if first_color_black:
57 | randRGBcolors[0] = [0, 0, 0]
58 |
59 | if last_color_black:
60 | randRGBcolors[-1] = [0, 0, 0]
61 | random_colormap = LinearSegmentedColormap.from_list('new_map', randRGBcolors, N=nlabels)
62 |
63 | # Display colorbar
64 | if verbose:
65 | from matplotlib import colors, colorbar
66 | from matplotlib import pyplot as plt
67 | fig, ax = plt.subplots(1, 1, figsize=(15, 0.5))
68 |
69 | bounds = np.linspace(0, nlabels, nlabels + 1)
70 | norm = colors.BoundaryNorm(bounds, nlabels)
71 |
72 | cb = colorbar.ColorbarBase(ax, cmap=random_colormap, norm=norm, spacing='proportional', ticks=None,
73 | boundaries=bounds, format='%1i', orientation=u'horizontal')
74 |
75 | return randRGBcolors, random_colormap
76 |
77 |
--------------------------------------------------------------------------------
/saicinpainting/training/visualizers/directory.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import cv2
4 | import numpy as np
5 |
6 | from saicinpainting.training.visualizers.base import BaseVisualizer, visualize_mask_and_images_batch
7 | from saicinpainting.utils import check_and_warn_input_range
8 |
9 |
10 | class DirectoryVisualizer(BaseVisualizer):
11 | DEFAULT_KEY_ORDER = 'image predicted_image inpainted'.split(' ')
12 |
13 | def __init__(self, outdir, key_order=DEFAULT_KEY_ORDER, max_items_in_batch=10,
14 | last_without_mask=True, rescale_keys=None):
15 | self.outdir = outdir
16 | os.makedirs(self.outdir, exist_ok=True)
17 | self.key_order = key_order
18 | self.max_items_in_batch = max_items_in_batch
19 | self.last_without_mask = last_without_mask
20 | self.rescale_keys = rescale_keys
21 |
22 | def __call__(self, epoch_i, batch_i, batch, suffix='', rank=None):
23 | check_and_warn_input_range(batch['image'], 0, 1, 'DirectoryVisualizer target image')
24 | vis_img = visualize_mask_and_images_batch(batch, self.key_order, max_items=self.max_items_in_batch,
25 | last_without_mask=self.last_without_mask,
26 | rescale_keys=self.rescale_keys)
27 |
28 | vis_img = np.clip(vis_img * 255, 0, 255).astype('uint8')
29 |
30 | curoutdir = os.path.join(self.outdir, f'epoch{epoch_i:04d}{suffix}')
31 | os.makedirs(curoutdir, exist_ok=True)
32 | rank_suffix = f'_r{rank}' if rank is not None else ''
33 | out_fname = os.path.join(curoutdir, f'batch{batch_i:07d}{rank_suffix}.jpg')
34 |
35 | vis_img = cv2.cvtColor(vis_img, cv2.COLOR_RGB2BGR)
36 | cv2.imwrite(out_fname, vis_img)
37 |
--------------------------------------------------------------------------------
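DirectoryVisualizer writes one JPEG grid per call, pulling the keys in DEFAULT_KEY_ORDER plus 'mask' from the batch; a minimal sketch with random placeholder tensors in [0, 1] and a hypothetical output directory:

import torch

from saicinpainting.training.visualizers.directory import DirectoryVisualizer

vis = DirectoryVisualizer(outdir='outputs/vis_debug')     # placeholder directory

batch = {
    'image': torch.rand(2, 3, 64, 64),
    'predicted_image': torch.rand(2, 3, 64, 64),
    'inpainted': torch.rand(2, 3, 64, 64),
    'mask': (torch.rand(2, 1, 64, 64) > 0.5).float(),
}
vis(epoch_i=0, batch_i=0, batch=batch)   # writes outputs/vis_debug/epoch0000/batch0000000.jpg
--------------------------------------------------------------------------------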
/saicinpainting/training/visualizers/noop.py:
--------------------------------------------------------------------------------
1 | from saicinpainting.training.visualizers.base import BaseVisualizer
2 |
3 |
4 | class NoopVisualizer(BaseVisualizer):
5 | def __init__(self, *args, **kwargs):
6 | pass
7 |
8 | def __call__(self, epoch_i, batch_i, batch, suffix='', rank=None):
9 | pass
10 |
--------------------------------------------------------------------------------
/scripts/download_first_stages.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | wget -O models/first_stage_models/kl-f4/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f4.zip
3 | wget -O models/first_stage_models/kl-f8/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f8.zip
4 | wget -O models/first_stage_models/kl-f16/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f16.zip
5 | wget -O models/first_stage_models/kl-f32/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f32.zip
6 | wget -O models/first_stage_models/vq-f4/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f4.zip
7 | wget -O models/first_stage_models/vq-f4-noattn/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f4-noattn.zip
8 | wget -O models/first_stage_models/vq-f8/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f8.zip
9 | wget -O models/first_stage_models/vq-f8-n256/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f8-n256.zip
10 | wget -O models/first_stage_models/vq-f16/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f16.zip
11 |
12 |
13 |
14 | cd models/first_stage_models/kl-f4
15 | unzip -o model.zip
16 |
17 | cd ../kl-f8
18 | unzip -o model.zip
19 |
20 | cd ../kl-f16
21 | unzip -o model.zip
22 |
23 | cd ../kl-f32
24 | unzip -o model.zip
25 |
26 | cd ../vq-f4
27 | unzip -o model.zip
28 |
29 | cd ../vq-f4-noattn
30 | unzip -o model.zip
31 |
32 | cd ../vq-f8
33 | unzip -o model.zip
34 |
35 | cd ../vq-f8-n256
36 | unzip -o model.zip
37 |
38 | cd ../vq-f16
39 | unzip -o model.zip
40 |
41 | cd ../..
--------------------------------------------------------------------------------
/scripts/download_models.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | wget -O models/ldm/celeba256/celeba-256.zip https://ommer-lab.com/files/latent-diffusion/celeba.zip
3 | wget -O models/ldm/ffhq256/ffhq-256.zip https://ommer-lab.com/files/latent-diffusion/ffhq.zip
4 | wget -O models/ldm/lsun_churches256/lsun_churches-256.zip https://ommer-lab.com/files/latent-diffusion/lsun_churches.zip
5 | wget -O models/ldm/lsun_beds256/lsun_beds-256.zip https://ommer-lab.com/files/latent-diffusion/lsun_bedrooms.zip
6 | wget -O models/ldm/text2img256/model.zip https://ommer-lab.com/files/latent-diffusion/text2img.zip
7 | wget -O models/ldm/cin256/model.zip https://ommer-lab.com/files/latent-diffusion/cin.zip
8 | wget -O models/ldm/semantic_synthesis512/model.zip https://ommer-lab.com/files/latent-diffusion/semantic_synthesis.zip
9 | wget -O models/ldm/semantic_synthesis256/model.zip https://ommer-lab.com/files/latent-diffusion/semantic_synthesis256.zip
10 | wget -O models/ldm/bsr_sr/model.zip https://ommer-lab.com/files/latent-diffusion/sr_bsr.zip
11 | wget -O models/ldm/layout2img-openimages256/model.zip https://ommer-lab.com/files/latent-diffusion/layout2img_model.zip
12 | wget -O models/ldm/inpainting_big/model.zip https://ommer-lab.com/files/latent-diffusion/inpainting_big.zip
13 |
14 |
15 |
16 | cd models/ldm/celeba256
17 | unzip -o celeba-256.zip
18 |
19 | cd ../ffhq256
20 | unzip -o ffhq-256.zip
21 |
22 | cd ../lsun_churches256
23 | unzip -o lsun_churches-256.zip
24 |
25 | cd ../lsun_beds256
26 | unzip -o lsun_beds-256.zip
27 |
28 | cd ../text2img256
29 | unzip -o model.zip
30 |
31 | cd ../cin256
32 | unzip -o model.zip
33 |
34 | cd ../semantic_synthesis512
35 | unzip -o model.zip
36 |
37 | cd ../semantic_synthesis256
38 | unzip -o model.zip
39 |
40 | cd ../bsr_sr
41 | unzip -o model.zip
42 |
43 | cd ../layout2img-openimages256
44 | unzip -o model.zip
45 |
46 | cd ../inpainting_big
47 | unzip -o model.zip
48 |
49 | cd ../..
50 |
--------------------------------------------------------------------------------
/scripts/inpaint.py:
--------------------------------------------------------------------------------
1 | import argparse, os, sys, glob
2 | from omegaconf import OmegaConf
3 | from PIL import Image
4 | from tqdm import tqdm
5 | import numpy as np
6 | import torch
7 | from main import instantiate_from_config
8 | from ldm.models.diffusion.ddim import DDIMSampler
9 |
10 |
11 | def make_batch(image, mask, device):
12 | image = np.array(Image.open(image).convert("RGB"))
13 | image = image.astype(np.float32)/255.0
14 | image = image[None].transpose(0,3,1,2)
15 | image = torch.from_numpy(image)
16 |
17 | mask = np.array(Image.open(mask).convert("L"))
18 | mask = mask.astype(np.float32)/255.0
19 | mask = mask[None,None]
20 | mask[mask < 0.5] = 0
21 | mask[mask >= 0.5] = 1
22 | mask = torch.from_numpy(mask)
23 |
24 | masked_image = (1-mask)*image
25 |
26 | batch = {"image": image, "mask": mask, "masked_image": masked_image}
27 | for k in batch:
28 | batch[k] = batch[k].to(device=device)
29 | batch[k] = batch[k]*2.0-1.0
30 | return batch
31 |
32 |
33 | if __name__ == "__main__":
34 | parser = argparse.ArgumentParser()
35 | parser.add_argument(
36 | "--indir",
37 | type=str,
38 | nargs="?",
39 | help="dir containing image-mask pairs (`example.png` and `example_mask.png`)",
40 | )
41 | parser.add_argument(
42 | "--outdir",
43 | type=str,
44 | nargs="?",
45 | help="dir to write results to",
46 | )
47 | parser.add_argument(
48 | "--steps",
49 | type=int,
50 | default=50,
51 | help="number of ddim sampling steps",
52 | )
53 | parser.add_argument(
54 | "--ckpt",
55 | type=str,
56 | default="/data0/zixin/cluster_results/epoch=000001.ckpt",
57 | help="dir to save checkpoint",
58 | )
59 | parser.add_argument(
60 | "--config",
61 | type=str,
62 | default="configs/v1_my.yaml",
63 | help="dir to obtain config",
64 | )
65 | opt = parser.parse_args()
66 |
67 | masks = sorted(glob.glob(os.path.join(opt.indir, "*_mask.png")))
68 | images = [x.replace("_mask.png", ".png") for x in masks]
69 | print(f"Found {len(masks)} inputs.")
70 |
71 | config = OmegaConf.load("models/ldm/inpainting_big/config.yaml")
72 | model = instantiate_from_config(config.model)
73 |
74 | config1 = OmegaConf.load(opt.config)
75 | first_stage_model = load_model_from_config(config1, opt.ckpt)
76 |
77 | model.load_state_dict(torch.load("models/ldm/inpainting_big/last.ckpt")["state_dict"],
78 | strict=False)
79 |
80 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
81 | model = model.to(device)
82 | first_stage_model = first_stage_model.to(device)
83 | sampler = DDIMSampler(model)
84 |
85 | os.makedirs(opt.outdir, exist_ok=True)
86 | with torch.no_grad():
87 | with model.ema_scope():
88 | for image, mask in tqdm(zip(images, masks)):
89 | outpath = os.path.join(opt.outdir, os.path.split(image)[1])
90 | batch = make_batch(image, mask, device=device)
91 |
92 | # encode masked image and concat downsampled mask
93 | c = model.cond_stage_model.encode(batch["masked_image"])
94 | cc = torch.nn.functional.interpolate(batch["mask"],
95 | size=c.shape[-2:])
96 | c = torch.cat((c, cc), dim=1)
97 |
98 | shape = (c.shape[1]-1,)+c.shape[2:]
99 | samples_ddim, _ = sampler.sample(S=opt.steps,
100 | conditioning=c,
101 | batch_size=c.shape[0],
102 | shape=shape,
103 | verbose=False)
104 |
105 | mask = torch.clamp((batch["mask"]+1.0)/2.0,
106 | min=0.0, max=1.0)
107 | x_samples_ddim = model.decode_first_stage(samples_ddim, batch["image"], mask)
108 | image = torch.clamp((batch["image"]+1.0)/2.0,
109 | min=0.0, max=1.0)
110 | predicted_image = torch.clamp((x_samples_ddim+1.0)/2.0,
111 | min=0.0, max=1.0)
112 |
113 | inpainted = (1-mask)*image+mask*predicted_image
114 | inpainted = inpainted.cpu().numpy().transpose(0,2,3,1)[0]*255
115 | # im = Image.fromarray(inpainted.astype(np.uint8))
116 | # im.show()
117 | Image.fromarray(inpainted.astype(np.uint8)).save(outpath)
118 |
--------------------------------------------------------------------------------
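Note: scripts/inpaint.py calls load_model_from_config (line 75) without defining or importing it. A minimal sketch in the style of the other latent-diffusion scripts, assuming a PyTorch-Lightning checkpoint whose weights sit under the 'state_dict' key:

import torch

from main import instantiate_from_config

def load_model_from_config(config, ckpt):
    """Instantiate the first-stage model from `config` and load weights from `ckpt`."""
    print(f"Loading model from {ckpt}")
    pl_sd = torch.load(ckpt, map_location="cpu")
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)
    model.load_state_dict(sd, strict=False)
    model.eval()
    return model
--------------------------------------------------------------------------------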
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name='latent-diffusion',
5 | version='0.0.1',
6 | description='',
7 | packages=find_packages(),
8 | install_requires=[
9 | 'torch',
10 | 'numpy',
11 | 'tqdm',
12 | ],
13 | )
--------------------------------------------------------------------------------
/taming/modules/autoencoder/lpips/vgg.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/taming/modules/autoencoder/lpips/vgg.pth
--------------------------------------------------------------------------------
/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/teaser.png
--------------------------------------------------------------------------------
/text2img_visual.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/text2img_visual.png
--------------------------------------------------------------------------------
/visual.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/buxiangzhiren/Asymmetric_VQGAN/406c5cd2e86b3e9565795cc5cd1e3aab304b1a44/visual.png
--------------------------------------------------------------------------------