├── .gitignore
├── LICENSE
├── README.md
├── multi-scale-blended-diffusion
│   ├── InteractiveEditing.ipynb
│   ├── OrignalLatentDiffusionLICENSE
│   ├── configs
│   │   ├── autoencoder
│   │   │   ├── autoencoder_kl_16x16x16.yaml
│   │   │   ├── autoencoder_kl_32x32x4.yaml
│   │   │   ├── autoencoder_kl_64x64x3.yaml
│   │   │   └── autoencoder_kl_8x8x64.yaml
│   │   ├── latent-diffusion
│   │   │   ├── celebahq-ldm-vq-4.yaml
│   │   │   ├── cin-ldm-vq-f8.yaml
│   │   │   ├── cin256-v2.yaml
│   │   │   ├── ffhq-ldm-vq-4.yaml
│   │   │   ├── lsun_bedrooms-ldm-vq-4.yaml
│   │   │   ├── lsun_churches-ldm-kl-8.yaml
│   │   │   └── txt2img-1p4B-eval.yaml
│   │   ├── retrieval-augmented-diffusion
│   │   │   └── 768x768.yaml
│   │   └── stable-diffusion
│   │       └── v1-inference.yaml
│   ├── environment.yaml
│   ├── inputs
│   │   ├── bedroom_painting.jpg
│   │   ├── clarissa_strozzi.jpg
│   │   ├── inputs.txt
│   │   ├── lofi_revoy.jpg
│   │   ├── prof_pic.jpg
│   │   └── readme.md
│   ├── ldm
│   │   ├── __init__.py
│   │   ├── data
│   │   │   ├── __init__.py
│   │   │   ├── base.py
│   │   │   ├── imagenet.py
│   │   │   └── lsun.py
│   │   ├── lr_scheduler.py
│   │   ├── models
│   │   │   ├── autoencoder.py
│   │   │   └── diffusion
│   │   │       ├── __init__.py
│   │   │       ├── classifier.py
│   │   │       ├── ddim.py
│   │   │       ├── ddpm.py
│   │   │       └── plms.py
│   │   ├── modules
│   │   │   ├── attention.py
│   │   │   ├── diffusionmodules
│   │   │   │   ├── __init__.py
│   │   │   │   ├── model.py
│   │   │   │   ├── openaimodel.py
│   │   │   │   └── util.py
│   │   │   ├── distributions
│   │   │   │   ├── __init__.py
│   │   │   │   └── distributions.py
│   │   │   ├── ema.py
│   │   │   ├── encoders
│   │   │   │   ├── __init__.py
│   │   │   │   └── modules.py
│   │   │   ├── image_degradation
│   │   │   │   ├── __init__.py
│   │   │   │   ├── bsrgan.py
│   │   │   │   ├── bsrgan_light.py
│   │   │   │   └── utils_image.py
│   │   │   ├── losses
│   │   │   │   ├── __init__.py
│   │   │   │   ├── contperceptual.py
│   │   │   │   └── vqperceptual.py
│   │   │   └── x_transformer.py
│   │   └── util.py
│   ├── msbd
│   │   ├── BLDSampler.py
│   │   ├── MSBDGenerator.py
│   │   ├── __init__.py
│   │   └── msbd_utils.py
│   └── multi_scale_blended_diffusion.py
└── overview.jpg
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | # C extensions
6 | *.so
7 |
8 | # Distribution / packaging
9 | .Python
10 | build/
11 | develop-eggs/
12 | dist/
13 | downloads/
14 | eggs/
15 | .eggs/
16 | lib/
17 | lib64/
18 | parts/
19 | sdist/
20 | var/
21 | wheels/
22 | share/python-wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .nox/
42 | .coverage
43 | .coverage.*
44 | .cache
45 | nosetests.xml
46 | coverage.xml
47 | *.cover
48 | *.py,cover
49 | .hypothesis/
50 | .pytest_cache/
51 | cover/
52 |
53 | # Translations
54 | *.mo
55 | *.pot
56 |
57 | # Django stuff:
58 | *.log
59 | local_settings.py
60 | db.sqlite3
61 | db.sqlite3-journal
62 |
63 | # Flask stuff:
64 | instance/
65 | .webassets-cache
66 |
67 | # Scrapy stuff:
68 | .scrapy
69 |
70 | # Sphinx documentation
71 | docs/_build/
72 |
73 | # PyBuilder
74 | .pybuilder/
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | latent-diffusion-main/models
85 | src
86 | latent-diffusion-main/src
87 | latent-diffusion-main/outputs
88 | latent-diffusion-main/output
89 | latent-diffusion-main/data
90 | *.pt
91 | *.png
92 | RealESRGAN_x4plus.pth
93 | *.sh
94 |
95 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Preferred Networks, Inc.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | ## High-Resolution Image Editing via Multi-Stage Blended Diffusion
5 |
6 | 
7 |
8 | This repository includes our implementation of Multi-Stage Blended Diffusion, as described in our paper [High-Resolution Image Editing via Multi-Stage Blended Diffusion](https://arxiv.org/abs/2210.12965).
9 |
10 | Our implementation builds on the original implementation of Latent Diffusion, available at https://github.com/CompVis/latent-diffusion, which is licensed under the MIT license.
11 | Specifically, `multi-scale-blended-diffusion/ldm`, `multi-scale-blended-diffusion/configs`, and `multi-scale-blended-diffusion/models` are taken entirely from Latent Diffusion.
12 |
13 | ### Setup:
14 |
15 | * Install the environment specified in `multi-scale-blended-diffusion/environment.yaml`:
16 | ```
17 | conda env create -f environment.yaml
18 | ```
19 | * Download the Stable Diffusion v1.4 checkpoint from the [huggingface space](https://huggingface.co/spaces/stabilityai/stable-diffusion) and copy it to `multi-scale-blended-diffusion/models/ldm/stable-diffusion-v1/model.ckpt`.
20 | This requires a login and has to be done manually.
21 | * Also download RealESRGAN x4plus from [here](https://github.com/xinntao/Real-ESRGAN#inference-general-images) and place it at `multi-scale-blended-diffusion/RealESRGAN_x4plus.pth`:
22 | ```
23 | wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth -P multi-scale-blended-diffusion/
24 | ```
25 | ### Usage
26 | To try our approach for interactive editing, use the [`multi-scale-blended-diffusion/InteractiveEditing.ipynb`](multi-scale-blended-diffusion/InteractiveEditing.ipynb) notebook locally, or use our [colab demo](https://colab.research.google.com/gist/JohannesAck/2c4561a8a4d1522f752b1a86f3e24c12/multiscaleblendeddiffusioncolab.ipynb).
27 |
28 | To validate our approach on the examples used in our paper, run [`multi-scale-blended-diffusion/multi_scale_blended_diffusion.py`](multi-scale-blended-diffusion/multi_scale_blended_diffusion.py).
29 |
30 |
--------------------------------------------------------------------------------
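A quick way to confirm the setup steps from the README above succeeded is to check that the two manually downloaded weight files sit at the expected paths. This is a small sanity-check sketch, not part of the repository; the paths simply mirror the ones given in the README:

```
from pathlib import Path

# Paths taken from the setup instructions in the README above,
# relative to the repository root.
required_files = [
    Path("multi-scale-blended-diffusion/models/ldm/stable-diffusion-v1/model.ckpt"),
    Path("multi-scale-blended-diffusion/RealESRGAN_x4plus.pth"),
]

for path in required_files:
    status = "found" if path.is_file() else "MISSING"
    print(f"{status}: {path}")
```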
/multi-scale-blended-diffusion/OrignalLatentDiffusionLICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Machine Vision and Learning Group, LMU Munich
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/multi-scale-blended-diffusion/configs/autoencoder/autoencoder_kl_16x16x16.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-6
3 | target: ldm.models.autoencoder.AutoencoderKL
4 | params:
5 | monitor: "val/rec_loss"
6 | embed_dim: 16
7 | lossconfig:
8 | target: ldm.modules.losses.LPIPSWithDiscriminator
9 | params:
10 | disc_start: 50001
11 | kl_weight: 0.000001
12 | disc_weight: 0.5
13 |
14 | ddconfig:
15 | double_z: True
16 | z_channels: 16
17 | resolution: 256
18 | in_channels: 3
19 | out_ch: 3
20 | ch: 128
21 | ch_mult: [ 1,1,2,2,4] # num_down = len(ch_mult)-1
22 | num_res_blocks: 2
23 | attn_resolutions: [16]
24 | dropout: 0.0
25 |
26 |
27 | data:
28 | target: main.DataModuleFromConfig
29 | params:
30 | batch_size: 12
31 | wrap: True
32 | train:
33 | target: ldm.data.imagenet.ImageNetSRTrain
34 | params:
35 | size: 256
36 | degradation: pil_nearest
37 | validation:
38 | target: ldm.data.imagenet.ImageNetSRValidation
39 | params:
40 | size: 256
41 | degradation: pil_nearest
42 |
43 | lightning:
44 | callbacks:
45 | image_logger:
46 | target: main.ImageLogger
47 | params:
48 | batch_frequency: 1000
49 | max_images: 8
50 | increase_log_steps: True
51 |
52 | trainer:
53 | benchmark: True
54 | accumulate_grad_batches: 2
55 |
--------------------------------------------------------------------------------
/multi-scale-blended-diffusion/configs/autoencoder/autoencoder_kl_32x32x4.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-6
3 | target: ldm.models.autoencoder.AutoencoderKL
4 | params:
5 | monitor: "val/rec_loss"
6 | embed_dim: 4
7 | lossconfig:
8 | target: ldm.modules.losses.LPIPSWithDiscriminator
9 | params:
10 | disc_start: 50001
11 | kl_weight: 0.000001
12 | disc_weight: 0.5
13 |
14 | ddconfig:
15 | double_z: True
16 | z_channels: 4
17 | resolution: 256
18 | in_channels: 3
19 | out_ch: 3
20 | ch: 128
21 | ch_mult: [ 1,2,4,4 ] # num_down = len(ch_mult)-1
22 | num_res_blocks: 2
23 | attn_resolutions: [ ]
24 | dropout: 0.0
25 |
26 | data:
27 | target: main.DataModuleFromConfig
28 | params:
29 | batch_size: 12
30 | wrap: True
31 | train:
32 | target: ldm.data.imagenet.ImageNetSRTrain
33 | params:
34 | size: 256
35 | degradation: pil_nearest
36 | validation:
37 | target: ldm.data.imagenet.ImageNetSRValidation
38 | params:
39 | size: 256
40 | degradation: pil_nearest
41 |
42 | lightning:
43 | callbacks:
44 | image_logger:
45 | target: main.ImageLogger
46 | params:
47 | batch_frequency: 1000
48 | max_images: 8
49 | increase_log_steps: True
50 |
51 | trainer:
52 | benchmark: True
53 | accumulate_grad_batches: 2
54 |
--------------------------------------------------------------------------------
/multi-scale-blended-diffusion/configs/autoencoder/autoencoder_kl_64x64x3.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-6
3 | target: ldm.models.autoencoder.AutoencoderKL
4 | params:
5 | monitor: "val/rec_loss"
6 | embed_dim: 3
7 | lossconfig:
8 | target: ldm.modules.losses.LPIPSWithDiscriminator
9 | params:
10 | disc_start: 50001
11 | kl_weight: 0.000001
12 | disc_weight: 0.5
13 |
14 | ddconfig:
15 | double_z: True
16 | z_channels: 3
17 | resolution: 256
18 | in_channels: 3
19 | out_ch: 3
20 | ch: 128
21 | ch_mult: [ 1,2,4 ] # num_down = len(ch_mult)-1
22 | num_res_blocks: 2
23 | attn_resolutions: [ ]
24 | dropout: 0.0
25 |
26 |
27 | data:
28 | target: main.DataModuleFromConfig
29 | params:
30 | batch_size: 12
31 | wrap: True
32 | train:
33 | target: ldm.data.imagenet.ImageNetSRTrain
34 | params:
35 | size: 256
36 | degradation: pil_nearest
37 | validation:
38 | target: ldm.data.imagenet.ImageNetSRValidation
39 | params:
40 | size: 256
41 | degradation: pil_nearest
42 |
43 | lightning:
44 | callbacks:
45 | image_logger:
46 | target: main.ImageLogger
47 | params:
48 | batch_frequency: 1000
49 | max_images: 8
50 | increase_log_steps: True
51 |
52 | trainer:
53 | benchmark: True
54 | accumulate_grad_batches: 2
55 |
--------------------------------------------------------------------------------
/multi-scale-blended-diffusion/configs/autoencoder/autoencoder_kl_8x8x64.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-6
3 | target: ldm.models.autoencoder.AutoencoderKL
4 | params:
5 | monitor: "val/rec_loss"
6 | embed_dim: 64
7 | lossconfig:
8 | target: ldm.modules.losses.LPIPSWithDiscriminator
9 | params:
10 | disc_start: 50001
11 | kl_weight: 0.000001
12 | disc_weight: 0.5
13 |
14 | ddconfig:
15 | double_z: True
16 | z_channels: 64
17 | resolution: 256
18 | in_channels: 3
19 | out_ch: 3
20 | ch: 128
21 | ch_mult: [ 1,1,2,2,4,4] # num_down = len(ch_mult)-1
22 | num_res_blocks: 2
23 | attn_resolutions: [16,8]
24 | dropout: 0.0
25 |
26 | data:
27 | target: main.DataModuleFromConfig
28 | params:
29 | batch_size: 12
30 | wrap: True
31 | train:
32 | target: ldm.data.imagenet.ImageNetSRTrain
33 | params:
34 | size: 256
35 | degradation: pil_nearest
36 | validation:
37 | target: ldm.data.imagenet.ImageNetSRValidation
38 | params:
39 | size: 256
40 | degradation: pil_nearest
41 |
42 | lightning:
43 | callbacks:
44 | image_logger:
45 | target: main.ImageLogger
46 | params:
47 | batch_frequency: 1000
48 | max_images: 8
49 | increase_log_steps: True
50 |
51 | trainer:
52 | benchmark: True
53 | accumulate_grad_batches: 2
54 |
--------------------------------------------------------------------------------
/multi-scale-blended-diffusion/configs/latent-diffusion/celebahq-ldm-vq-4.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 2.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0195
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | image_size: 64
12 | channels: 3
13 | monitor: val/loss_simple_ema
14 |
15 | unet_config:
16 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
17 | params:
18 | image_size: 64
19 | in_channels: 3
20 | out_channels: 3
21 | model_channels: 224
22 | attention_resolutions:
23 | # note: this isn't actually the resolution but
24 | # the downsampling factor, i.e. this corresponds to
25 | # attention on spatial resolution 8,16,32, as the
26 | # spatial resolution of the latents is 64 for f4
27 | - 8
28 | - 4
29 | - 2
30 | num_res_blocks: 2
31 | channel_mult:
32 | - 1
33 | - 2
34 | - 3
35 | - 4
36 | num_head_channels: 32
37 | first_stage_config:
38 | target: ldm.models.autoencoder.VQModelInterface
39 | params:
40 | embed_dim: 3
41 | n_embed: 8192
42 | ckpt_path: models/first_stage_models/vq-f4/model.ckpt
43 | ddconfig:
44 | double_z: false
45 | z_channels: 3
46 | resolution: 256
47 | in_channels: 3
48 | out_ch: 3
49 | ch: 128
50 | ch_mult:
51 | - 1
52 | - 2
53 | - 4
54 | num_res_blocks: 2
55 | attn_resolutions: []
56 | dropout: 0.0
57 | lossconfig:
58 | target: torch.nn.Identity
59 | cond_stage_config: __is_unconditional__
60 | data:
61 | target: main.DataModuleFromConfig
62 | params:
63 | batch_size: 48
64 | num_workers: 5
65 | wrap: false
66 | train:
67 | target: taming.data.faceshq.CelebAHQTrain
68 | params:
69 | size: 256
70 | validation:
71 | target: taming.data.faceshq.CelebAHQValidation
72 | params:
73 | size: 256
74 |
75 |
76 | lightning:
77 | callbacks:
78 | image_logger:
79 | target: main.ImageLogger
80 | params:
81 | batch_frequency: 5000
82 | max_images: 8
83 | increase_log_steps: False
84 |
85 | trainer:
86 | benchmark: True
--------------------------------------------------------------------------------
/multi-scale-blended-diffusion/configs/latent-diffusion/cin-ldm-vq-f8.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0195
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | cond_stage_key: class_label
12 | image_size: 32
13 | channels: 4
14 | cond_stage_trainable: true
15 | conditioning_key: crossattn
16 | monitor: val/loss_simple_ema
17 | unet_config:
18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
19 | params:
20 | image_size: 32
21 | in_channels: 4
22 | out_channels: 4
23 | model_channels: 256
24 | attention_resolutions:
25 | #note: this isn't actually the resolution but
26 | # the downsampling factor, i.e. this corresponds to
27 | # attention on spatial resolution 8,16,32, as the
28 | # spatial resolution of the latents is 32 for f8
29 | - 4
30 | - 2
31 | - 1
32 | num_res_blocks: 2
33 | channel_mult:
34 | - 1
35 | - 2
36 | - 4
37 | num_head_channels: 32
38 | use_spatial_transformer: true
39 | transformer_depth: 1
40 | context_dim: 512
41 | first_stage_config:
42 | target: ldm.models.autoencoder.VQModelInterface
43 | params:
44 | embed_dim: 4
45 | n_embed: 16384
46 | ckpt_path: configs/first_stage_models/vq-f8/model.yaml
47 | ddconfig:
48 | double_z: false
49 | z_channels: 4
50 | resolution: 256
51 | in_channels: 3
52 | out_ch: 3
53 | ch: 128
54 | ch_mult:
55 | - 1
56 | - 2
57 | - 2
58 | - 4
59 | num_res_blocks: 2
60 | attn_resolutions:
61 | - 32
62 | dropout: 0.0
63 | lossconfig:
64 | target: torch.nn.Identity
65 | cond_stage_config:
66 | target: ldm.modules.encoders.modules.ClassEmbedder
67 | params:
68 | embed_dim: 512
69 | key: class_label
70 | data:
71 | target: main.DataModuleFromConfig
72 | params:
73 | batch_size: 64
74 | num_workers: 12
75 | wrap: false
76 | train:
77 | target: ldm.data.imagenet.ImageNetTrain
78 | params:
79 | config:
80 | size: 256
81 | validation:
82 | target: ldm.data.imagenet.ImageNetValidation
83 | params:
84 | config:
85 | size: 256
86 |
87 |
88 | lightning:
89 | callbacks:
90 | image_logger:
91 | target: main.ImageLogger
92 | params:
93 | batch_frequency: 5000
94 | max_images: 8
95 | increase_log_steps: False
96 |
97 | trainer:
98 | benchmark: True
--------------------------------------------------------------------------------
/multi-scale-blended-diffusion/configs/latent-diffusion/cin256-v2.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 0.0001
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0195
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | cond_stage_key: class_label
12 | image_size: 64
13 | channels: 3
14 | cond_stage_trainable: true
15 | conditioning_key: crossattn
16 | monitor: val/loss
17 | use_ema: False
18 |
19 | unet_config:
20 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
21 | params:
22 | image_size: 64
23 | in_channels: 3
24 | out_channels: 3
25 | model_channels: 192
26 | attention_resolutions:
27 | - 8
28 | - 4
29 | - 2
30 | num_res_blocks: 2
31 | channel_mult:
32 | - 1
33 | - 2
34 | - 3
35 | - 5
36 | num_heads: 1
37 | use_spatial_transformer: true
38 | transformer_depth: 1
39 | context_dim: 512
40 |
41 | first_stage_config:
42 | target: ldm.models.autoencoder.VQModelInterface
43 | params:
44 | embed_dim: 3
45 | n_embed: 8192
46 | ddconfig:
47 | double_z: false
48 | z_channels: 3
49 | resolution: 256
50 | in_channels: 3
51 | out_ch: 3
52 | ch: 128
53 | ch_mult:
54 | - 1
55 | - 2
56 | - 4
57 | num_res_blocks: 2
58 | attn_resolutions: []
59 | dropout: 0.0
60 | lossconfig:
61 | target: torch.nn.Identity
62 |
63 | cond_stage_config:
64 | target: ldm.modules.encoders.modules.ClassEmbedder
65 | params:
66 | n_classes: 1001
67 | embed_dim: 512
68 | key: class_label
69 |
--------------------------------------------------------------------------------
/multi-scale-blended-diffusion/configs/latent-diffusion/ffhq-ldm-vq-4.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 2.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0195
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | image_size: 64
12 | channels: 3
13 | monitor: val/loss_simple_ema
14 | unet_config:
15 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
16 | params:
17 | image_size: 64
18 | in_channels: 3
19 | out_channels: 3
20 | model_channels: 224
21 | attention_resolutions:
22 | # note: this isn't actually the resolution but
23 | # the downsampling factor, i.e. this corresponds to
24 | # attention on spatial resolution 8,16,32, as the
25 | # spatial resolution of the latents is 64 for f4
26 | - 8
27 | - 4
28 | - 2
29 | num_res_blocks: 2
30 | channel_mult:
31 | - 1
32 | - 2
33 | - 3
34 | - 4
35 | num_head_channels: 32
36 | first_stage_config:
37 | target: ldm.models.autoencoder.VQModelInterface
38 | params:
39 | embed_dim: 3
40 | n_embed: 8192
41 | ckpt_path: configs/first_stage_models/vq-f4/model.yaml
42 | ddconfig:
43 | double_z: false
44 | z_channels: 3
45 | resolution: 256
46 | in_channels: 3
47 | out_ch: 3
48 | ch: 128
49 | ch_mult:
50 | - 1
51 | - 2
52 | - 4
53 | num_res_blocks: 2
54 | attn_resolutions: []
55 | dropout: 0.0
56 | lossconfig:
57 | target: torch.nn.Identity
58 | cond_stage_config: __is_unconditional__
59 | data:
60 | target: main.DataModuleFromConfig
61 | params:
62 | batch_size: 42
63 | num_workers: 5
64 | wrap: false
65 | train:
66 | target: taming.data.faceshq.FFHQTrain
67 | params:
68 | size: 256
69 | validation:
70 | target: taming.data.faceshq.FFHQValidation
71 | params:
72 | size: 256
73 |
74 |
75 | lightning:
76 | callbacks:
77 | image_logger:
78 | target: main.ImageLogger
79 | params:
80 | batch_frequency: 5000
81 | max_images: 8
82 | increase_log_steps: False
83 |
84 | trainer:
85 | benchmark: True
--------------------------------------------------------------------------------
/multi-scale-blended-diffusion/configs/latent-diffusion/lsun_bedrooms-ldm-vq-4.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 2.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0195
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | image_size: 64
12 | channels: 3
13 | monitor: val/loss_simple_ema
14 | unet_config:
15 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
16 | params:
17 | image_size: 64
18 | in_channels: 3
19 | out_channels: 3
20 | model_channels: 224
21 | attention_resolutions:
22 | # note: this isn't actually the resolution but
23 | # the downsampling factor, i.e. this corresponds to
24 | # attention on spatial resolution 8,16,32, as the
25 | # spatial resolution of the latents is 64 for f4
26 | - 8
27 | - 4
28 | - 2
29 | num_res_blocks: 2
30 | channel_mult:
31 | - 1
32 | - 2
33 | - 3
34 | - 4
35 | num_head_channels: 32
36 | first_stage_config:
37 | target: ldm.models.autoencoder.VQModelInterface
38 | params:
39 | ckpt_path: configs/first_stage_models/vq-f4/model.yaml
40 | embed_dim: 3
41 | n_embed: 8192
42 | ddconfig:
43 | double_z: false
44 | z_channels: 3
45 | resolution: 256
46 | in_channels: 3
47 | out_ch: 3
48 | ch: 128
49 | ch_mult:
50 | - 1
51 | - 2
52 | - 4
53 | num_res_blocks: 2
54 | attn_resolutions: []
55 | dropout: 0.0
56 | lossconfig:
57 | target: torch.nn.Identity
58 | cond_stage_config: __is_unconditional__
59 | data:
60 | target: main.DataModuleFromConfig
61 | params:
62 | batch_size: 48
63 | num_workers: 5
64 | wrap: false
65 | train:
66 | target: ldm.data.lsun.LSUNBedroomsTrain
67 | params:
68 | size: 256
69 | validation:
70 | target: ldm.data.lsun.LSUNBedroomsValidation
71 | params:
72 | size: 256
73 |
74 |
75 | lightning:
76 | callbacks:
77 | image_logger:
78 | target: main.ImageLogger
79 | params:
80 | batch_frequency: 5000
81 | max_images: 8
82 | increase_log_steps: False
83 |
84 | trainer:
85 | benchmark: True
--------------------------------------------------------------------------------
/multi-scale-blended-diffusion/configs/latent-diffusion/lsun_churches-ldm-kl-8.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 5.0e-5 # set to target_lr by starting main.py with '--scale_lr False'
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0155
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | loss_type: l1
11 | first_stage_key: "image"
12 | cond_stage_key: "image"
13 | image_size: 32
14 | channels: 4
15 | cond_stage_trainable: False
16 | concat_mode: False
17 | scale_by_std: True
18 | monitor: 'val/loss_simple_ema'
19 |
20 | scheduler_config: # 10000 warmup steps
21 | target: ldm.lr_scheduler.LambdaLinearScheduler
22 | params:
23 | warm_up_steps: [10000]
24 | cycle_lengths: [10000000000000]
25 | f_start: [1.e-6]
26 | f_max: [1.]
27 | f_min: [ 1.]
28 |
29 | unet_config:
30 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
31 | params:
32 | image_size: 32
33 | in_channels: 4
34 | out_channels: 4
35 | model_channels: 192
36 | attention_resolutions: [ 1, 2, 4, 8 ] # 32, 16, 8, 4
37 | num_res_blocks: 2
38 | channel_mult: [ 1,2,2,4,4 ] # 32, 16, 8, 4, 2
39 | num_heads: 8
40 | use_scale_shift_norm: True
41 | resblock_updown: True
42 |
43 | first_stage_config:
44 | target: ldm.models.autoencoder.AutoencoderKL
45 | params:
46 | embed_dim: 4
47 | monitor: "val/rec_loss"
48 | ckpt_path: "models/first_stage_models/kl-f8/model.ckpt"
49 | ddconfig:
50 | double_z: True
51 | z_channels: 4
52 | resolution: 256
53 | in_channels: 3
54 | out_ch: 3
55 | ch: 128
56 | ch_mult: [ 1,2,4,4 ] # num_down = len(ch_mult)-1
57 | num_res_blocks: 2
58 | attn_resolutions: [ ]
59 | dropout: 0.0
60 | lossconfig:
61 | target: torch.nn.Identity
62 |
63 | cond_stage_config: "__is_unconditional__"
64 |
65 | data:
66 | target: main.DataModuleFromConfig
67 | params:
68 | batch_size: 96
69 | num_workers: 5
70 | wrap: False
71 | train:
72 | target: ldm.data.lsun.LSUNChurchesTrain
73 | params:
74 | size: 256
75 | validation:
76 | target: ldm.data.lsun.LSUNChurchesValidation
77 | params:
78 | size: 256
79 |
80 | lightning:
81 | callbacks:
82 | image_logger:
83 | target: main.ImageLogger
84 | params:
85 | batch_frequency: 5000
86 | max_images: 8
87 | increase_log_steps: False
88 |
89 |
90 | trainer:
91 | benchmark: True
--------------------------------------------------------------------------------
/multi-scale-blended-diffusion/configs/latent-diffusion/txt2img-1p4B-eval.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 5.0e-05
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.00085
6 | linear_end: 0.012
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | cond_stage_key: caption
12 | image_size: 32
13 | channels: 4
14 | cond_stage_trainable: true
15 | conditioning_key: crossattn
16 | monitor: val/loss_simple_ema
17 | scale_factor: 0.18215
18 | use_ema: False
19 |
20 | unet_config:
21 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
22 | params:
23 | image_size: 32
24 | in_channels: 4
25 | out_channels: 4
26 | model_channels: 320
27 | attention_resolutions:
28 | - 4
29 | - 2
30 | - 1
31 | num_res_blocks: 2
32 | channel_mult:
33 | - 1
34 | - 2
35 | - 4
36 | - 4
37 | num_heads: 8
38 | use_spatial_transformer: true
39 | transformer_depth: 1
40 | context_dim: 1280
41 | use_checkpoint: true
42 | legacy: False
43 |
44 | first_stage_config:
45 | target: ldm.models.autoencoder.AutoencoderKL
46 | params:
47 | embed_dim: 4
48 | monitor: val/rec_loss
49 | ddconfig:
50 | double_z: true
51 | z_channels: 4
52 | resolution: 256
53 | in_channels: 3
54 | out_ch: 3
55 | ch: 128
56 | ch_mult:
57 | - 1
58 | - 2
59 | - 4
60 | - 4
61 | num_res_blocks: 2
62 | attn_resolutions: []
63 | dropout: 0.0
64 | lossconfig:
65 | target: torch.nn.Identity
66 |
67 | cond_stage_config:
68 | target: ldm.modules.encoders.modules.BERTEmbedder
69 | params:
70 | n_embed: 1280
71 | n_layer: 32
72 |
--------------------------------------------------------------------------------
/multi-scale-blended-diffusion/configs/retrieval-augmented-diffusion/768x768.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 0.0001
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.015
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: jpg
11 | cond_stage_key: nix
12 | image_size: 48
13 | channels: 16
14 | cond_stage_trainable: false
15 | conditioning_key: crossattn
16 | monitor: val/loss_simple_ema
17 | scale_by_std: false
18 | scale_factor: 0.22765929
19 | unet_config:
20 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
21 | params:
22 | image_size: 48
23 | in_channels: 16
24 | out_channels: 16
25 | model_channels: 448
26 | attention_resolutions:
27 | - 4
28 | - 2
29 | - 1
30 | num_res_blocks: 2
31 | channel_mult:
32 | - 1
33 | - 2
34 | - 3
35 | - 4
36 | use_scale_shift_norm: false
37 | resblock_updown: false
38 | num_head_channels: 32
39 | use_spatial_transformer: true
40 | transformer_depth: 1
41 | context_dim: 768
42 | use_checkpoint: true
43 | first_stage_config:
44 | target: ldm.models.autoencoder.AutoencoderKL
45 | params:
46 | monitor: val/rec_loss
47 | embed_dim: 16
48 | ddconfig:
49 | double_z: true
50 | z_channels: 16
51 | resolution: 256
52 | in_channels: 3
53 | out_ch: 3
54 | ch: 128
55 | ch_mult:
56 | - 1
57 | - 1
58 | - 2
59 | - 2
60 | - 4
61 | num_res_blocks: 2
62 | attn_resolutions:
63 | - 16
64 | dropout: 0.0
65 | lossconfig:
66 | target: torch.nn.Identity
67 | cond_stage_config:
68 | target: torch.nn.Identity
--------------------------------------------------------------------------------
/multi-scale-blended-diffusion/configs/stable-diffusion/v1-inference.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-04
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.00085
6 | linear_end: 0.0120
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: "jpg"
11 | cond_stage_key: "txt"
12 | image_size: 64
13 | channels: 4
14 | cond_stage_trainable: false # Note: different from the one we trained before
15 | conditioning_key: crossattn
16 | monitor: val/loss_simple_ema
17 | scale_factor: 0.18215
18 | use_ema: False
19 |
20 | scheduler_config: # 10000 warmup steps
21 | target: ldm.lr_scheduler.LambdaLinearScheduler
22 | params:
23 | warm_up_steps: [ 10000 ]
24 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
25 | f_start: [ 1.e-6 ]
26 | f_max: [ 1. ]
27 | f_min: [ 1. ]
28 |
29 | unet_config:
30 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
31 | params:
32 | image_size: 32 # unused
33 | in_channels: 4
34 | out_channels: 4
35 | model_channels: 320
36 | attention_resolutions: [ 4, 2, 1 ]
37 | num_res_blocks: 2
38 | channel_mult: [ 1, 2, 4, 4 ]
39 | num_heads: 8
40 | use_spatial_transformer: True
41 | transformer_depth: 1
42 | context_dim: 768
43 | use_checkpoint: True
44 | legacy: False
45 |
46 | first_stage_config:
47 | target: ldm.models.autoencoder.AutoencoderKL
48 | params:
49 | embed_dim: 4
50 | monitor: val/rec_loss
51 | ddconfig:
52 | double_z: true
53 | z_channels: 4
54 | resolution: 256
55 | in_channels: 3
56 | out_ch: 3
57 | ch: 128
58 | ch_mult:
59 | - 1
60 | - 2
61 | - 4
62 | - 4
63 | num_res_blocks: 2
64 | attn_resolutions: []
65 | dropout: 0.0
66 | lossconfig:
67 | target: torch.nn.Identity
68 |
69 | cond_stage_config:
70 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
71 |
--------------------------------------------------------------------------------
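The `v1-inference.yaml` config above follows the usual Latent Diffusion convention: each `target:` names a class and `params:` its constructor arguments. Below is a minimal loading sketch, assuming the standard `instantiate_from_config` helper from `ldm/util.py` and the checkpoint location from the README (run from inside `multi-scale-blended-diffusion/`); the repository's own scripts may wire this up differently:

```
import torch
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config

# Build the LatentDiffusion model described by the config's target/params tree.
config = OmegaConf.load("configs/stable-diffusion/v1-inference.yaml")
model = instantiate_from_config(config.model)

# Load the manually downloaded Stable Diffusion v1.4 weights (see README).
checkpoint = torch.load("models/ldm/stable-diffusion-v1/model.ckpt", map_location="cpu")
model.load_state_dict(checkpoint["state_dict"], strict=False)
model.eval()
```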
/multi-scale-blended-diffusion/environment.yaml:
--------------------------------------------------------------------------------
1 | name: msbd
2 | channels:
3 | - pytorch
4 | - defaults
5 | - conda-forge
6 | dependencies:
7 | - python=3.8.5
8 | - pip
9 | - cudatoolkit=11.3
10 | - pytorch::pytorch
11 | - pytorch::torchvision
12 | - pytorch::torchaudio
13 | - numpy
14 | - jupyterlab
15 | - ipycanvas
16 | - pip:
17 | - albumentations
18 | - pudb==2019.2
19 | - imageio
20 | - imageio-ffmpeg
21 | - pytorch-lightning
22 | - kornia
23 | - omegaconf==2.1.1
24 | - test-tube>=0.7.5
25 | - einops==0.3.0
26 | - transformers
27 | - basicsr
28 | - realesrgan
29 | - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
30 | - -e git+https://github.com/openai/CLIP.git@main#egg=clip
--------------------------------------------------------------------------------
/multi-scale-blended-diffusion/inputs/bedroom_painting.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pfnet-research/multi-stage-blended-diffusion/b933b3d395edcbe266dbb51abf8f45d3b4b66f81/multi-scale-blended-diffusion/inputs/bedroom_painting.jpg
--------------------------------------------------------------------------------
/multi-scale-blended-diffusion/inputs/clarissa_strozzi.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pfnet-research/multi-stage-blended-diffusion/b933b3d395edcbe266dbb51abf8f45d3b4b66f81/multi-scale-blended-diffusion/inputs/clarissa_strozzi.jpg
--------------------------------------------------------------------------------
/multi-scale-blended-diffusion/inputs/inputs.txt:
--------------------------------------------------------------------------------
1 | # prompt; filename (assuming mask follows FOO_mask.png pattern); margin_multiplier
2 | An Oil painting of a girl hugging a corgi on a pedestal; clarissa_strozzi.jpg; 1.4
3 | Statue of Roman Emperor, Canon 5D Mark 3, 35mm, flickr; marunouchi.png; 1.2
4 | Oil painting of Mt. Fuji, by Paul Sandby; river_severn.png; 1.3
5 | red hair; prof_pic.jpg; 1.1
6 | cyberpunk neon cityscape, digital painting, trending on artstation, David Revoy; lofi_revoy.jpg; 1.3
7 | a painting is hanging on the wall; bedroom_painting.jpg; 1.4
8 | fishing boat, lofi, dreamy, moody, very colorful, anime inspiration, ghibli vibe; stable_anime.png; 1.4
--------------------------------------------------------------------------------
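Each non-comment line of `inputs.txt` above is a semicolon-separated triple of prompt, image filename, and margin multiplier, with the mask expected next to the image as `FOO_mask.png`. A minimal parsing sketch (for illustration only, not taken from the repository's own code):

```
from pathlib import Path

INPUT_DIR = Path("multi-scale-blended-diffusion/inputs")

def parse_inputs(path=INPUT_DIR / "inputs.txt"):
    """Yield (prompt, image_path, mask_path, margin_multiplier) tuples."""
    for line in Path(path).read_text().splitlines():
        line = line.strip()
        if not line or line.startswith("#"):
            continue  # skip the header comment and blank lines
        prompt, filename, margin = (field.strip() for field in line.split(";"))
        image = INPUT_DIR / filename
        mask = image.with_name(image.stem + "_mask.png")  # FOO_mask.png convention
        yield prompt, image, mask, float(margin)

for prompt, image, mask, margin in parse_inputs():
    print(f"{image.name}: '{prompt}' (margin x{margin}, mask {mask.name})")
```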
/multi-scale-blended-diffusion/inputs/lofi_revoy.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pfnet-research/multi-stage-blended-diffusion/b933b3d395edcbe266dbb51abf8f45d3b4b66f81/multi-scale-blended-diffusion/inputs/lofi_revoy.jpg
--------------------------------------------------------------------------------
/multi-scale-blended-diffusion/inputs/prof_pic.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pfnet-research/multi-stage-blended-diffusion/b933b3d395edcbe266dbb51abf8f45d3b4b66f81/multi-scale-blended-diffusion/inputs/prof_pic.jpg
--------------------------------------------------------------------------------
/multi-scale-blended-diffusion/inputs/readme.md:
--------------------------------------------------------------------------------
1 | All images used in our paper are used in accordance with their licenses and are attributed below, in the order of the variations figure in our paper:
2 | 1. `bedroom_painting.jpg` : "white wooden dresser with mirror photo", by Minh Pham, via https://unsplash.com/photos/7pCFUybP_P8 (3902x5853)
3 | 2. `clarissa_strozzi.jpg` : "Portrait of Clarissa Strozzi", by Titian Vecelli, (1803x2117)
4 | 3. `marunouchi.png` : "people walking on sidewalk near high rise buildings during daytime", by Nat Weearwong, via https://unsplash.com/photos/0cZgvYHirBg (4896x3264)
5 | 4. `river_severn.png` : "The River Severn at Shrewsbury, Shropshire", by Paul Sandby, via https://unsplash.com/photos/HEEvYhNzpEo (3999x3041)
6 | 5. `prof_pic.jpg` : Selfie by author (3456x4608)
7 | 6. `lofi_revoy.jpg` : "Lofi Cyberpunk" by David Revoy https://www.davidrevoy.com/article867/lofi-cyberpunk (2431x1930)
8 | 7. `stable_anime.png` : Anime-style image of river generated with stable diffusion by the authors (2048x2048)
9 |
--------------------------------------------------------------------------------
/multi-scale-blended-diffusion/ldm/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pfnet-research/multi-stage-blended-diffusion/b933b3d395edcbe266dbb51abf8f45d3b4b66f81/multi-scale-blended-diffusion/ldm/__init__.py
--------------------------------------------------------------------------------
/multi-scale-blended-diffusion/ldm/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pfnet-research/multi-stage-blended-diffusion/b933b3d395edcbe266dbb51abf8f45d3b4b66f81/multi-scale-blended-diffusion/ldm/data/__init__.py
--------------------------------------------------------------------------------
/multi-scale-blended-diffusion/ldm/data/base.py:
--------------------------------------------------------------------------------
1 | from abc import abstractmethod
2 | from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset
3 |
4 |
5 | class Txt2ImgIterableBaseDataset(IterableDataset):
6 | '''
7 | Define an interface to make the IterableDatasets for text2img data chainable
8 | '''
9 | def __init__(self, num_records=0, valid_ids=None, size=256):
10 | super().__init__()
11 | self.num_records = num_records
12 | self.valid_ids = valid_ids
13 | self.sample_ids = valid_ids
14 | self.size = size
15 |
16 | print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.')
17 |
18 | def __len__(self):
19 | return self.num_records
20 |
21 | @abstractmethod
22 | def __iter__(self):
23 | pass
--------------------------------------------------------------------------------
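`Txt2ImgIterableBaseDataset` above only fixes the record count, the valid-id bookkeeping, and the iteration interface; concrete datasets supply `__iter__`. A toy subclass to illustrate the contract (a sketch only; the yielded keys here are an assumption, real datasets define their own schema):

```
import numpy as np
from ldm.data.base import Txt2ImgIterableBaseDataset

class RandomTxt2ImgDataset(Txt2ImgIterableBaseDataset):
    """Yields random images with a fixed caption, just to exercise the interface."""

    def __iter__(self):
        for _ in range(self.num_records):
            yield {
                "image": np.random.uniform(-1.0, 1.0, (self.size, self.size, 3)).astype(np.float32),
                "caption": "a placeholder caption",
            }

dataset = RandomTxt2ImgDataset(num_records=4, size=64)
for example in dataset:
    print(example["image"].shape, example["caption"])
```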
/multi-scale-blended-diffusion/ldm/data/imagenet.py:
--------------------------------------------------------------------------------
1 | import os, yaml, pickle, shutil, tarfile, glob
2 | import cv2
3 | import albumentations
4 | import PIL
5 | import numpy as np
6 | import torchvision.transforms.functional as TF
7 | from omegaconf import OmegaConf
8 | from functools import partial
9 | from PIL import Image
10 | from tqdm import tqdm
11 | from torch.utils.data import Dataset, Subset
12 |
13 | import taming.data.utils as tdu
14 | from taming.data.imagenet import str_to_indices, give_synsets_from_indices, download, retrieve
15 | from taming.data.imagenet import ImagePaths
16 |
17 | from ldm.modules.image_degradation import degradation_fn_bsr, degradation_fn_bsr_light
18 |
19 |
20 | def synset2idx(path_to_yaml="data/index_synset.yaml"):
21 | with open(path_to_yaml) as f:
22 | di2s = yaml.safe_load(f)
23 | return dict((v,k) for k,v in di2s.items())
24 |
25 |
26 | class ImageNetBase(Dataset):
27 | def __init__(self, config=None):
28 | self.config = config or OmegaConf.create()
29 | if not type(self.config)==dict:
30 | self.config = OmegaConf.to_container(self.config)
31 | self.keep_orig_class_label = self.config.get("keep_orig_class_label", False)
32 | self.process_images = True # if False we skip loading & processing images and self.data contains filepaths
33 | self._prepare()
34 | self._prepare_synset_to_human()
35 | self._prepare_idx_to_synset()
36 | self._prepare_human_to_integer_label()
37 | self._load()
38 |
39 | def __len__(self):
40 | return len(self.data)
41 |
42 | def __getitem__(self, i):
43 | return self.data[i]
44 |
45 | def _prepare(self):
46 | raise NotImplementedError()
47 |
48 | def _filter_relpaths(self, relpaths):
49 | ignore = set([
50 | "n06596364_9591.JPEG",
51 | ])
52 | relpaths = [rpath for rpath in relpaths if not rpath.split("/")[-1] in ignore]
53 | if "sub_indices" in self.config:
54 | indices = str_to_indices(self.config["sub_indices"])
55 | synsets = give_synsets_from_indices(indices, path_to_yaml=self.idx2syn) # returns a list of strings
56 | self.synset2idx = synset2idx(path_to_yaml=self.idx2syn)
57 | files = []
58 | for rpath in relpaths:
59 | syn = rpath.split("/")[0]
60 | if syn in synsets:
61 | files.append(rpath)
62 | return files
63 | else:
64 | return relpaths
65 |
66 | def _prepare_synset_to_human(self):
67 | SIZE = 2655750
68 | URL = "https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1"
69 | self.human_dict = os.path.join(self.root, "synset_human.txt")
70 | if (not os.path.exists(self.human_dict) or
71 | not os.path.getsize(self.human_dict)==SIZE):
72 | download(URL, self.human_dict)
73 |
74 | def _prepare_idx_to_synset(self):
75 | URL = "https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1"
76 | self.idx2syn = os.path.join(self.root, "index_synset.yaml")
77 | if (not os.path.exists(self.idx2syn)):
78 | download(URL, self.idx2syn)
79 |
80 | def _prepare_human_to_integer_label(self):
81 | URL = "https://heibox.uni-heidelberg.de/f/2362b797d5be43b883f6/?dl=1"
82 | self.human2integer = os.path.join(self.root, "imagenet1000_clsidx_to_labels.txt")
83 | if (not os.path.exists(self.human2integer)):
84 | download(URL, self.human2integer)
85 | with open(self.human2integer, "r") as f:
86 | lines = f.read().splitlines()
87 | assert len(lines) == 1000
88 | self.human2integer_dict = dict()
89 | for line in lines:
90 | value, key = line.split(":")
91 | self.human2integer_dict[key] = int(value)
92 |
93 | def _load(self):
94 | with open(self.txt_filelist, "r") as f:
95 | self.relpaths = f.read().splitlines()
96 | l1 = len(self.relpaths)
97 | self.relpaths = self._filter_relpaths(self.relpaths)
98 | print("Removed {} files from filelist during filtering.".format(l1 - len(self.relpaths)))
99 |
100 | self.synsets = [p.split("/")[0] for p in self.relpaths]
101 | self.abspaths = [os.path.join(self.datadir, p) for p in self.relpaths]
102 |
103 | unique_synsets = np.unique(self.synsets)
104 | class_dict = dict((synset, i) for i, synset in enumerate(unique_synsets))
105 | if not self.keep_orig_class_label:
106 | self.class_labels = [class_dict[s] for s in self.synsets]
107 | else:
108 | self.class_labels = [self.synset2idx[s] for s in self.synsets]
109 |
110 | with open(self.human_dict, "r") as f:
111 | human_dict = f.read().splitlines()
112 | human_dict = dict(line.split(maxsplit=1) for line in human_dict)
113 |
114 | self.human_labels = [human_dict[s] for s in self.synsets]
115 |
116 | labels = {
117 | "relpath": np.array(self.relpaths),
118 | "synsets": np.array(self.synsets),
119 | "class_label": np.array(self.class_labels),
120 | "human_label": np.array(self.human_labels),
121 | }
122 |
123 | if self.process_images:
124 | self.size = retrieve(self.config, "size", default=256)
125 | self.data = ImagePaths(self.abspaths,
126 | labels=labels,
127 | size=self.size,
128 | random_crop=self.random_crop,
129 | )
130 | else:
131 | self.data = self.abspaths
132 |
133 |
134 | class ImageNetTrain(ImageNetBase):
135 | NAME = "ILSVRC2012_train"
136 | URL = "http://www.image-net.org/challenges/LSVRC/2012/"
137 | AT_HASH = "a306397ccf9c2ead27155983c254227c0fd938e2"
138 | FILES = [
139 | "ILSVRC2012_img_train.tar",
140 | ]
141 | SIZES = [
142 | 147897477120,
143 | ]
144 |
145 | def __init__(self, process_images=True, data_root=None, **kwargs):
146 | self.process_images = process_images
147 | self.data_root = data_root
148 | super().__init__(**kwargs)
149 |
150 | def _prepare(self):
151 | if self.data_root:
152 | self.root = os.path.join(self.data_root, self.NAME)
153 | else:
154 | cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
155 | self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
156 |
157 | self.datadir = os.path.join(self.root, "data")
158 | self.txt_filelist = os.path.join(self.root, "filelist.txt")
159 | self.expected_length = 1281167
160 | self.random_crop = retrieve(self.config, "ImageNetTrain/random_crop",
161 | default=True)
162 | if not tdu.is_prepared(self.root):
163 | # prep
164 | print("Preparing dataset {} in {}".format(self.NAME, self.root))
165 |
166 | datadir = self.datadir
167 | if not os.path.exists(datadir):
168 | path = os.path.join(self.root, self.FILES[0])
169 | if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
170 | import academictorrents as at
171 | atpath = at.get(self.AT_HASH, datastore=self.root)
172 | assert atpath == path
173 |
174 | print("Extracting {} to {}".format(path, datadir))
175 | os.makedirs(datadir, exist_ok=True)
176 | with tarfile.open(path, "r:") as tar:
177 | tar.extractall(path=datadir)
178 |
179 | print("Extracting sub-tars.")
180 | subpaths = sorted(glob.glob(os.path.join(datadir, "*.tar")))
181 | for subpath in tqdm(subpaths):
182 | subdir = subpath[:-len(".tar")]
183 | os.makedirs(subdir, exist_ok=True)
184 | with tarfile.open(subpath, "r:") as tar:
185 | tar.extractall(path=subdir)
186 |
187 | filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
188 | filelist = [os.path.relpath(p, start=datadir) for p in filelist]
189 | filelist = sorted(filelist)
190 | filelist = "\n".join(filelist)+"\n"
191 | with open(self.txt_filelist, "w") as f:
192 | f.write(filelist)
193 |
194 | tdu.mark_prepared(self.root)
195 |
196 |
197 | class ImageNetValidation(ImageNetBase):
198 | NAME = "ILSVRC2012_validation"
199 | URL = "http://www.image-net.org/challenges/LSVRC/2012/"
200 | AT_HASH = "5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5"
201 | VS_URL = "https://heibox.uni-heidelberg.de/f/3e0f6e9c624e45f2bd73/?dl=1"
202 | FILES = [
203 | "ILSVRC2012_img_val.tar",
204 | "validation_synset.txt",
205 | ]
206 | SIZES = [
207 | 6744924160,
208 | 1950000,
209 | ]
210 |
211 | def __init__(self, process_images=True, data_root=None, **kwargs):
212 | self.data_root = data_root
213 | self.process_images = process_images
214 | super().__init__(**kwargs)
215 |
216 | def _prepare(self):
217 | if self.data_root:
218 | self.root = os.path.join(self.data_root, self.NAME)
219 | else:
220 | cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
221 | self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
222 | self.datadir = os.path.join(self.root, "data")
223 | self.txt_filelist = os.path.join(self.root, "filelist.txt")
224 | self.expected_length = 50000
225 | self.random_crop = retrieve(self.config, "ImageNetValidation/random_crop",
226 | default=False)
227 | if not tdu.is_prepared(self.root):
228 | # prep
229 | print("Preparing dataset {} in {}".format(self.NAME, self.root))
230 |
231 | datadir = self.datadir
232 | if not os.path.exists(datadir):
233 | path = os.path.join(self.root, self.FILES[0])
234 | if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
235 | import academictorrents as at
236 | atpath = at.get(self.AT_HASH, datastore=self.root)
237 | assert atpath == path
238 |
239 | print("Extracting {} to {}".format(path, datadir))
240 | os.makedirs(datadir, exist_ok=True)
241 | with tarfile.open(path, "r:") as tar:
242 | tar.extractall(path=datadir)
243 |
244 | vspath = os.path.join(self.root, self.FILES[1])
245 | if not os.path.exists(vspath) or not os.path.getsize(vspath)==self.SIZES[1]:
246 | download(self.VS_URL, vspath)
247 |
248 | with open(vspath, "r") as f:
249 | synset_dict = f.read().splitlines()
250 | synset_dict = dict(line.split() for line in synset_dict)
251 |
252 | print("Reorganizing into synset folders")
253 | synsets = np.unique(list(synset_dict.values()))
254 | for s in synsets:
255 | os.makedirs(os.path.join(datadir, s), exist_ok=True)
256 | for k, v in synset_dict.items():
257 | src = os.path.join(datadir, k)
258 | dst = os.path.join(datadir, v)
259 | shutil.move(src, dst)
260 |
261 | filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
262 | filelist = [os.path.relpath(p, start=datadir) for p in filelist]
263 | filelist = sorted(filelist)
264 | filelist = "\n".join(filelist)+"\n"
265 | with open(self.txt_filelist, "w") as f:
266 | f.write(filelist)
267 |
268 | tdu.mark_prepared(self.root)
269 |
270 |
271 |
272 | class ImageNetSR(Dataset):
273 | def __init__(self, size=None,
274 | degradation=None, downscale_f=4, min_crop_f=0.5, max_crop_f=1.,
275 | random_crop=True):
276 | """
277 | Imagenet Superresolution Dataloader
278 | Performs following ops in order:
279 | 1. crops a crop of size s from image either as random or center crop
280 | 2. resizes crop to size with cv2.area_interpolation
281 | 3. degrades resized crop with degradation_fn
282 |
283 | :param size: resizing to size after cropping
284 | :param degradation: degradation_fn, e.g. cv_bicubic or bsrgan_light
285 | :param downscale_f: Low Resolution Downsample factor
286 | :param min_crop_f: determines crop size s,
287 | where s = c * min_img_side_len with c sampled from interval (min_crop_f, max_crop_f)
288 | :param max_crop_f: ""
289 | :param data_root:
290 | :param random_crop:
291 | """
292 | self.base = self.get_base()
293 | assert size
294 | assert (size / downscale_f).is_integer()
295 | self.size = size
296 | self.LR_size = int(size / downscale_f)
297 | self.min_crop_f = min_crop_f
298 | self.max_crop_f = max_crop_f
299 | assert(max_crop_f <= 1.)
300 | self.center_crop = not random_crop
301 |
302 | self.image_rescaler = albumentations.SmallestMaxSize(max_size=size, interpolation=cv2.INTER_AREA)
303 |
304 | self.pil_interpolation = False # gets reset later in case interp_op is from pillow
305 |
306 | if degradation == "bsrgan":
307 | self.degradation_process = partial(degradation_fn_bsr, sf=downscale_f)
308 |
309 | elif degradation == "bsrgan_light":
310 | self.degradation_process = partial(degradation_fn_bsr_light, sf=downscale_f)
311 |
312 | else:
313 | interpolation_fn = {
314 | "cv_nearest": cv2.INTER_NEAREST,
315 | "cv_bilinear": cv2.INTER_LINEAR,
316 | "cv_bicubic": cv2.INTER_CUBIC,
317 | "cv_area": cv2.INTER_AREA,
318 | "cv_lanczos": cv2.INTER_LANCZOS4,
319 | "pil_nearest": PIL.Image.NEAREST,
320 | "pil_bilinear": PIL.Image.BILINEAR,
321 | "pil_bicubic": PIL.Image.BICUBIC,
322 | "pil_box": PIL.Image.BOX,
323 | "pil_hamming": PIL.Image.HAMMING,
324 | "pil_lanczos": PIL.Image.LANCZOS,
325 | }[degradation]
326 |
327 | self.pil_interpolation = degradation.startswith("pil_")
328 |
329 | if self.pil_interpolation:
330 | self.degradation_process = partial(TF.resize, size=self.LR_size, interpolation=interpolation_fn)
331 |
332 | else:
333 | self.degradation_process = albumentations.SmallestMaxSize(max_size=self.LR_size,
334 | interpolation=interpolation_fn)
335 |
336 | def __len__(self):
337 | return len(self.base)
338 |
339 | def __getitem__(self, i):
340 | example = self.base[i]
341 | image = Image.open(example["file_path_"])
342 |
343 | if not image.mode == "RGB":
344 | image = image.convert("RGB")
345 |
346 | image = np.array(image).astype(np.uint8)
347 |
348 | min_side_len = min(image.shape[:2])
349 | crop_side_len = min_side_len * np.random.uniform(self.min_crop_f, self.max_crop_f, size=None)
350 | crop_side_len = int(crop_side_len)
351 |
352 | if self.center_crop:
353 | self.cropper = albumentations.CenterCrop(height=crop_side_len, width=crop_side_len)
354 |
355 | else:
356 | self.cropper = albumentations.RandomCrop(height=crop_side_len, width=crop_side_len)
357 |
358 | image = self.cropper(image=image)["image"]
359 | image = self.image_rescaler(image=image)["image"]
360 |
361 | if self.pil_interpolation:
362 | image_pil = PIL.Image.fromarray(image)
363 | LR_image = self.degradation_process(image_pil)
364 | LR_image = np.array(LR_image).astype(np.uint8)
365 |
366 | else:
367 | LR_image = self.degradation_process(image=image)["image"]
368 |
369 | example["image"] = (image/127.5 - 1.0).astype(np.float32)
370 | example["LR_image"] = (LR_image/127.5 - 1.0).astype(np.float32)
371 |
372 | return example
373 |
374 |
375 | class ImageNetSRTrain(ImageNetSR):
376 | def __init__(self, **kwargs):
377 | super().__init__(**kwargs)
378 |
379 | def get_base(self):
380 | with open("data/imagenet_train_hr_indices.p", "rb") as f:
381 | indices = pickle.load(f)
382 | dset = ImageNetTrain(process_images=False,)
383 | return Subset(dset, indices)
384 |
385 |
386 | class ImageNetSRValidation(ImageNetSR):
387 | def __init__(self, **kwargs):
388 | super().__init__(**kwargs)
389 |
390 | def get_base(self):
391 | with open("data/imagenet_val_hr_indices.p", "rb") as f:
392 | indices = pickle.load(f)
393 | dset = ImageNetValidation(process_images=False,)
394 | return Subset(dset, indices)
395 |
--------------------------------------------------------------------------------
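`ImageNetSRTrain` and `ImageNetSRValidation` above are the classes the autoencoder configs point at (with `size: 256` and `degradation: pil_nearest`). A minimal usage sketch, assuming ImageNet and the `data/imagenet_*_hr_indices.p` index files have already been prepared as in the original Latent Diffusion setup:

```
from torch.utils.data import DataLoader
from ldm.data.imagenet import ImageNetSRValidation

# Mirrors the data.params.validation block of the autoencoder configs above.
dataset = ImageNetSRValidation(size=256, degradation="pil_nearest")
loader = DataLoader(dataset, batch_size=12, shuffle=False, num_workers=2)

batch = next(iter(loader))
print(batch["image"].shape)     # full-resolution crops, e.g. (12, 256, 256, 3)
print(batch["LR_image"].shape)  # 4x downscaled counterparts, e.g. (12, 64, 64, 3)
```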
/multi-scale-blended-diffusion/ldm/data/lsun.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import PIL
4 | from PIL import Image
5 | from torch.utils.data import Dataset
6 | from torchvision import transforms
7 |
8 |
9 | class LSUNBase(Dataset):
10 | def __init__(self,
11 | txt_file,
12 | data_root,
13 | size=None,
14 | interpolation="bicubic",
15 | flip_p=0.5
16 | ):
17 | self.data_paths = txt_file
18 | self.data_root = data_root
19 | with open(self.data_paths, "r") as f:
20 | self.image_paths = f.read().splitlines()
21 | self._length = len(self.image_paths)
22 | self.labels = {
23 | "relative_file_path_": [l for l in self.image_paths],
24 | "file_path_": [os.path.join(self.data_root, l)
25 | for l in self.image_paths],
26 | }
27 |
28 | self.size = size
29 | self.interpolation = {"linear": PIL.Image.LINEAR,
30 | "bilinear": PIL.Image.BILINEAR,
31 | "bicubic": PIL.Image.BICUBIC,
32 | "lanczos": PIL.Image.LANCZOS,
33 | }[interpolation]
34 | self.flip = transforms.RandomHorizontalFlip(p=flip_p)
35 |
36 | def __len__(self):
37 | return self._length
38 |
39 | def __getitem__(self, i):
40 | example = dict((k, self.labels[k][i]) for k in self.labels)
41 | image = Image.open(example["file_path_"])
42 | if not image.mode == "RGB":
43 | image = image.convert("RGB")
44 |
45 | # default to score-sde preprocessing
46 | img = np.array(image).astype(np.uint8)
47 | crop = min(img.shape[0], img.shape[1])
48 | h, w, = img.shape[0], img.shape[1]
49 | img = img[(h - crop) // 2:(h + crop) // 2,
50 | (w - crop) // 2:(w + crop) // 2]
51 |
52 | image = Image.fromarray(img)
53 | if self.size is not None:
54 | image = image.resize((self.size, self.size), resample=self.interpolation)
55 |
56 | image = self.flip(image)
57 | image = np.array(image).astype(np.uint8)
58 | example["image"] = (image / 127.5 - 1.0).astype(np.float32)
59 | return example
60 |
61 |
62 | class LSUNChurchesTrain(LSUNBase):
63 | def __init__(self, **kwargs):
64 | super().__init__(txt_file="data/lsun/church_outdoor_train.txt", data_root="data/lsun/churches", **kwargs)
65 |
66 |
67 | class LSUNChurchesValidation(LSUNBase):
68 | def __init__(self, flip_p=0., **kwargs):
69 | super().__init__(txt_file="data/lsun/church_outdoor_val.txt", data_root="data/lsun/churches",
70 | flip_p=flip_p, **kwargs)
71 |
72 |
73 | class LSUNBedroomsTrain(LSUNBase):
74 | def __init__(self, **kwargs):
75 | super().__init__(txt_file="data/lsun/bedrooms_train.txt", data_root="data/lsun/bedrooms", **kwargs)
76 |
77 |
78 | class LSUNBedroomsValidation(LSUNBase):
79 | def __init__(self, flip_p=0.0, **kwargs):
80 | super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms",
81 | flip_p=flip_p, **kwargs)
82 |
83 |
84 | class LSUNCatsTrain(LSUNBase):
85 | def __init__(self, **kwargs):
86 | super().__init__(txt_file="data/lsun/cat_train.txt", data_root="data/lsun/cats", **kwargs)
87 |
88 |
89 | class LSUNCatsValidation(LSUNBase):
90 | def __init__(self, flip_p=0., **kwargs):
91 | super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats",
92 | flip_p=flip_p, **kwargs)
93 |
--------------------------------------------------------------------------------
/multi-scale-blended-diffusion/ldm/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | class LambdaWarmUpCosineScheduler:
5 | """
6 | note: use with a base_lr of 1.0
7 | """
8 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
9 | self.lr_warm_up_steps = warm_up_steps
10 | self.lr_start = lr_start
11 | self.lr_min = lr_min
12 | self.lr_max = lr_max
13 | self.lr_max_decay_steps = max_decay_steps
14 | self.last_lr = 0.
15 | self.verbosity_interval = verbosity_interval
16 |
17 | def schedule(self, n, **kwargs):
18 | if self.verbosity_interval > 0:
19 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
20 | if n < self.lr_warm_up_steps:
21 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
22 | self.last_lr = lr
23 | return lr
24 | else:
25 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
26 | t = min(t, 1.0)
27 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
28 | 1 + np.cos(t * np.pi))
29 | self.last_lr = lr
30 | return lr
31 |
32 | def __call__(self, n, **kwargs):
33 | return self.schedule(n,**kwargs)
34 |
35 |
36 | class LambdaWarmUpCosineScheduler2:
37 | """
38 | supports repeated iterations, configurable via lists
39 | note: use with a base_lr of 1.0.
40 | """
41 | def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):
42 | assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
43 | self.lr_warm_up_steps = warm_up_steps
44 | self.f_start = f_start
45 | self.f_min = f_min
46 | self.f_max = f_max
47 | self.cycle_lengths = cycle_lengths
48 | self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
49 | self.last_f = 0.
50 | self.verbosity_interval = verbosity_interval
51 |
52 | def find_in_interval(self, n):
53 | interval = 0
54 | for cl in self.cum_cycles[1:]:
55 | if n <= cl:
56 | return interval
57 | interval += 1
58 |
59 | def schedule(self, n, **kwargs):
60 | cycle = self.find_in_interval(n)
61 | n = n - self.cum_cycles[cycle]
62 | if self.verbosity_interval > 0:
63 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
64 | f"current cycle {cycle}")
65 | if n < self.lr_warm_up_steps[cycle]:
66 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
67 | self.last_f = f
68 | return f
69 | else:
70 | t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])
71 | t = min(t, 1.0)
72 | f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
73 | 1 + np.cos(t * np.pi))
74 | self.last_f = f
75 | return f
76 |
77 | def __call__(self, n, **kwargs):
78 | return self.schedule(n, **kwargs)
79 |
80 |
81 | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
82 |
83 | def schedule(self, n, **kwargs):
84 | cycle = self.find_in_interval(n)
85 | n = n - self.cum_cycles[cycle]
86 | if self.verbosity_interval > 0:
87 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
88 | f"current cycle {cycle}")
89 |
90 | if n < self.lr_warm_up_steps[cycle]:
91 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
92 | self.last_f = f
93 | return f
94 | else:
95 | f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle])
96 | self.last_f = f
97 | return f
98 |
99 |
--------------------------------------------------------------------------------
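The classes above are learning-rate lambdas rather than schedulers in their own right. A minimal sketch of wiring LambdaWarmUpCosineScheduler into torch.optim.lr_scheduler.LambdaLR, with the optimizer's base lr set to 1.0 so the returned value becomes the effective learning rate (all step counts and rates below are illustrative):

    import torch
    from ldm.lr_scheduler import LambdaWarmUpCosineScheduler

    param = torch.nn.Parameter(torch.zeros(1))
    opt = torch.optim.Adam([param], lr=1.0)  # base_lr of 1.0, as noted in the docstring
    warmup_cosine = LambdaWarmUpCosineScheduler(warm_up_steps=1000, lr_min=1e-6,
                                                lr_max=1e-4, lr_start=1e-6,
                                                max_decay_steps=100000)
    lr_schedule = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=warmup_cosine.schedule)
    for step in range(5):
        opt.step()
        lr_schedule.step()  # the optimizer's lr now equals warmup_cosine.schedule(step + 1)
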
/multi-scale-blended-diffusion/ldm/models/autoencoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import pytorch_lightning as pl
4 | import torch.nn.functional as F
5 | from contextlib import contextmanager
6 | from packaging import version
7 | from torch.optim.lr_scheduler import LambdaLR
8 | from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
9 | from ldm.modules.diffusionmodules.model import Encoder, Decoder
10 | from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
11 | from ldm.modules.ema import LitEma
12 | from ldm.util import instantiate_from_config
13 | 
14 | class VQModel(pl.LightningModule):
15 | def __init__(self,
16 | ddconfig,
17 | lossconfig,
18 | n_embed,
19 | embed_dim,
20 | ckpt_path=None,
21 | ignore_keys=[],
22 | image_key="image",
23 | colorize_nlabels=None,
24 | monitor=None,
25 | batch_resize_range=None,
26 | scheduler_config=None,
27 | lr_g_factor=1.0,
28 | remap=None,
29 | sane_index_shape=False, # tell vector quantizer to return indices as bhw
30 | use_ema=False
31 | ):
32 | super().__init__()
33 | self.embed_dim = embed_dim
34 | self.n_embed = n_embed
35 | self.image_key = image_key
36 | self.encoder = Encoder(**ddconfig)
37 | self.decoder = Decoder(**ddconfig)
38 | self.loss = instantiate_from_config(lossconfig)
39 | self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
40 | remap=remap,
41 | sane_index_shape=sane_index_shape)
42 | self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
43 | self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
44 | if colorize_nlabels is not None:
45 | assert type(colorize_nlabels)==int
46 | self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
47 | if monitor is not None:
48 | self.monitor = monitor
49 | self.batch_resize_range = batch_resize_range
50 | if self.batch_resize_range is not None:
51 | print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.")
52 |
53 | self.use_ema = use_ema
54 | if self.use_ema:
55 | self.model_ema = LitEma(self)
56 | print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
57 |
58 | if ckpt_path is not None:
59 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
60 | self.scheduler_config = scheduler_config
61 | self.lr_g_factor = lr_g_factor
62 |
63 | @contextmanager
64 | def ema_scope(self, context=None):
65 | if self.use_ema:
66 | self.model_ema.store(self.parameters())
67 | self.model_ema.copy_to(self)
68 | if context is not None:
69 | print(f"{context}: Switched to EMA weights")
70 | try:
71 | yield None
72 | finally:
73 | if self.use_ema:
74 | self.model_ema.restore(self.parameters())
75 | if context is not None:
76 | print(f"{context}: Restored training weights")
77 |
78 | def init_from_ckpt(self, path, ignore_keys=list()):
79 | sd = torch.load(path, map_location="cpu")["state_dict"]
80 | keys = list(sd.keys())
81 | for k in keys:
82 | for ik in ignore_keys:
83 | if k.startswith(ik):
84 | print("Deleting key {} from state_dict.".format(k))
85 | del sd[k]
86 | missing, unexpected = self.load_state_dict(sd, strict=False)
87 | print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
88 | if len(missing) > 0:
89 | print(f"Missing Keys: {missing}")
90 | print(f"Unexpected Keys: {unexpected}")
91 |
92 | def on_train_batch_end(self, *args, **kwargs):
93 | if self.use_ema:
94 | self.model_ema(self)
95 |
96 | def encode(self, x):
97 | h = self.encoder(x)
98 | h = self.quant_conv(h)
99 | quant, emb_loss, info = self.quantize(h)
100 | return quant, emb_loss, info
101 |
102 | def encode_to_prequant(self, x):
103 | h = self.encoder(x)
104 | h = self.quant_conv(h)
105 | return h
106 |
107 | def decode(self, quant):
108 | quant = self.post_quant_conv(quant)
109 | dec = self.decoder(quant)
110 | return dec
111 |
112 | def decode_code(self, code_b):
113 | quant_b = self.quantize.embed_code(code_b)
114 | dec = self.decode(quant_b)
115 | return dec
116 |
117 | def forward(self, input, return_pred_indices=False):
118 | quant, diff, (_,_,ind) = self.encode(input)
119 | dec = self.decode(quant)
120 | if return_pred_indices:
121 | return dec, diff, ind
122 | return dec, diff
123 |
124 | def get_input(self, batch, k):
125 | x = batch[k]
126 | if len(x.shape) == 3:
127 | x = x[..., None]
128 | x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
129 | if self.batch_resize_range is not None:
130 | lower_size = self.batch_resize_range[0]
131 | upper_size = self.batch_resize_range[1]
132 | if self.global_step <= 4:
133 | # do the first few batches with max size to avoid later oom
134 | new_resize = upper_size
135 | else:
136 | new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16))
137 | if new_resize != x.shape[2]:
138 | x = F.interpolate(x, size=new_resize, mode="bicubic")
139 | x = x.detach()
140 | return x
141 |
142 | def training_step(self, batch, batch_idx, optimizer_idx):
143 | # https://github.com/pytorch/pytorch/issues/37142
144 | # try not to fool the heuristics
145 | x = self.get_input(batch, self.image_key)
146 | xrec, qloss, ind = self(x, return_pred_indices=True)
147 |
148 | if optimizer_idx == 0:
149 | # autoencode
150 | aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
151 | last_layer=self.get_last_layer(), split="train",
152 | predicted_indices=ind)
153 |
154 | self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
155 | return aeloss
156 |
157 | if optimizer_idx == 1:
158 | # discriminator
159 | discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
160 | last_layer=self.get_last_layer(), split="train")
161 | self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
162 | return discloss
163 |
164 | def validation_step(self, batch, batch_idx):
165 | log_dict = self._validation_step(batch, batch_idx)
166 | with self.ema_scope():
167 | log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema")
168 | return log_dict
169 |
170 | def _validation_step(self, batch, batch_idx, suffix=""):
171 | x = self.get_input(batch, self.image_key)
172 | xrec, qloss, ind = self(x, return_pred_indices=True)
173 | aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0,
174 | self.global_step,
175 | last_layer=self.get_last_layer(),
176 | split="val"+suffix,
177 | predicted_indices=ind
178 | )
179 |
180 | discloss, log_dict_disc = self.loss(qloss, x, xrec, 1,
181 | self.global_step,
182 | last_layer=self.get_last_layer(),
183 | split="val"+suffix,
184 | predicted_indices=ind
185 | )
186 | rec_loss = log_dict_ae[f"val{suffix}/rec_loss"]
187 | self.log(f"val{suffix}/rec_loss", rec_loss,
188 | prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
189 | self.log(f"val{suffix}/aeloss", aeloss,
190 | prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
191 | if version.parse(pl.__version__) >= version.parse('1.4.0'):
192 | del log_dict_ae[f"val{suffix}/rec_loss"]
193 | self.log_dict(log_dict_ae)
194 | self.log_dict(log_dict_disc)
195 | return self.log_dict
196 |
197 | def configure_optimizers(self):
198 | lr_d = self.learning_rate
199 | lr_g = self.lr_g_factor*self.learning_rate
200 | print("lr_d", lr_d)
201 | print("lr_g", lr_g)
202 | opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
203 | list(self.decoder.parameters())+
204 | list(self.quantize.parameters())+
205 | list(self.quant_conv.parameters())+
206 | list(self.post_quant_conv.parameters()),
207 | lr=lr_g, betas=(0.5, 0.9))
208 | opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
209 | lr=lr_d, betas=(0.5, 0.9))
210 |
211 | if self.scheduler_config is not None:
212 | scheduler = instantiate_from_config(self.scheduler_config)
213 |
214 | print("Setting up LambdaLR scheduler...")
215 | scheduler = [
216 | {
217 | 'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule),
218 | 'interval': 'step',
219 | 'frequency': 1
220 | },
221 | {
222 | 'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule),
223 | 'interval': 'step',
224 | 'frequency': 1
225 | },
226 | ]
227 | return [opt_ae, opt_disc], scheduler
228 | return [opt_ae, opt_disc], []
229 |
230 | def get_last_layer(self):
231 | return self.decoder.conv_out.weight
232 |
233 | def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
234 | log = dict()
235 | x = self.get_input(batch, self.image_key)
236 | x = x.to(self.device)
237 | if only_inputs:
238 | log["inputs"] = x
239 | return log
240 | xrec, _ = self(x)
241 | if x.shape[1] > 3:
242 | # colorize with random projection
243 | assert xrec.shape[1] > 3
244 | x = self.to_rgb(x)
245 | xrec = self.to_rgb(xrec)
246 | log["inputs"] = x
247 | log["reconstructions"] = xrec
248 | if plot_ema:
249 | with self.ema_scope():
250 | xrec_ema, _ = self(x)
251 | if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema)
252 | log["reconstructions_ema"] = xrec_ema
253 | return log
254 |
255 | def to_rgb(self, x):
256 | assert self.image_key == "segmentation"
257 | if not hasattr(self, "colorize"):
258 | self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
259 | x = F.conv2d(x, weight=self.colorize)
260 | x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
261 | return x
262 |
263 |
264 | class VQModelInterface(VQModel):
265 | def __init__(self, embed_dim, *args, **kwargs):
266 | super().__init__(embed_dim=embed_dim, *args, **kwargs)
267 | self.embed_dim = embed_dim
268 |
269 | def encode(self, x):
270 | h = self.encoder(x)
271 | h = self.quant_conv(h)
272 | return h
273 |
274 | def decode(self, h, force_not_quantize=False):
275 | # also go through quantization layer
276 | if not force_not_quantize:
277 | quant, emb_loss, info = self.quantize(h)
278 | else:
279 | quant = h
280 | quant = self.post_quant_conv(quant)
281 | dec = self.decoder(quant)
282 | return dec
283 |
284 |
285 | class AutoencoderKL(pl.LightningModule):
286 | def __init__(self,
287 | ddconfig,
288 | lossconfig,
289 | embed_dim,
290 | ckpt_path=None,
291 | ignore_keys=[],
292 | image_key="image",
293 | colorize_nlabels=None,
294 | monitor=None,
295 | ):
296 | super().__init__()
297 | self.image_key = image_key
298 | self.encoder = Encoder(**ddconfig)
299 | self.decoder = Decoder(**ddconfig)
300 | self.loss = instantiate_from_config(lossconfig)
301 | assert ddconfig["double_z"]
302 | self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
303 | self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
304 | self.embed_dim = embed_dim
305 | if colorize_nlabels is not None:
306 | assert type(colorize_nlabels)==int
307 | self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
308 | if monitor is not None:
309 | self.monitor = monitor
310 | if ckpt_path is not None:
311 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
312 |
313 | def init_from_ckpt(self, path, ignore_keys=list()):
314 | sd = torch.load(path, map_location="cpu")["state_dict"]
315 | keys = list(sd.keys())
316 | for k in keys:
317 | for ik in ignore_keys:
318 | if k.startswith(ik):
319 | print("Deleting key {} from state_dict.".format(k))
320 | del sd[k]
321 | self.load_state_dict(sd, strict=False)
322 | print(f"Restored from {path}")
323 |
324 | def encode(self, x):
325 | h = self.encoder(x)
326 | moments = self.quant_conv(h)
327 | posterior = DiagonalGaussianDistribution(moments)
328 | return posterior
329 |
330 | def decode(self, z):
331 | z = self.post_quant_conv(z)
332 | dec = self.decoder(z)
333 | return dec
334 |
335 | def forward(self, input, sample_posterior=True):
336 | posterior = self.encode(input)
337 | if sample_posterior:
338 | z = posterior.sample()
339 | else:
340 | z = posterior.mode()
341 | dec = self.decode(z)
342 | return dec, posterior
343 |
344 | def get_input(self, batch, k):
345 | x = batch[k]
346 | if len(x.shape) == 3:
347 | x = x[..., None]
348 | x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
349 | return x
350 |
351 | def training_step(self, batch, batch_idx, optimizer_idx):
352 | inputs = self.get_input(batch, self.image_key)
353 | reconstructions, posterior = self(inputs)
354 |
355 | if optimizer_idx == 0:
356 | # train encoder+decoder+logvar
357 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
358 | last_layer=self.get_last_layer(), split="train")
359 | self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
360 | self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
361 | return aeloss
362 |
363 | if optimizer_idx == 1:
364 | # train the discriminator
365 | discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
366 | last_layer=self.get_last_layer(), split="train")
367 |
368 | self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
369 | self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
370 | return discloss
371 |
372 | def validation_step(self, batch, batch_idx):
373 | inputs = self.get_input(batch, self.image_key)
374 | reconstructions, posterior = self(inputs)
375 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
376 | last_layer=self.get_last_layer(), split="val")
377 |
378 | discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
379 | last_layer=self.get_last_layer(), split="val")
380 |
381 | self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
382 | self.log_dict(log_dict_ae)
383 | self.log_dict(log_dict_disc)
384 | return self.log_dict
385 |
386 | def configure_optimizers(self):
387 | lr = self.learning_rate
388 | opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
389 | list(self.decoder.parameters())+
390 | list(self.quant_conv.parameters())+
391 | list(self.post_quant_conv.parameters()),
392 | lr=lr, betas=(0.5, 0.9))
393 | opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
394 | lr=lr, betas=(0.5, 0.9))
395 | return [opt_ae, opt_disc], []
396 |
397 | def get_last_layer(self):
398 | return self.decoder.conv_out.weight
399 |
400 | @torch.no_grad()
401 | def log_images(self, batch, only_inputs=False, **kwargs):
402 | log = dict()
403 | x = self.get_input(batch, self.image_key)
404 | x = x.to(self.device)
405 | if not only_inputs:
406 | xrec, posterior = self(x)
407 | if x.shape[1] > 3:
408 | # colorize with random projection
409 | assert xrec.shape[1] > 3
410 | x = self.to_rgb(x)
411 | xrec = self.to_rgb(xrec)
412 | log["samples"] = self.decode(torch.randn_like(posterior.sample()))
413 | log["reconstructions"] = xrec
414 | log["inputs"] = x
415 | return log
416 |
417 | def to_rgb(self, x):
418 | assert self.image_key == "segmentation"
419 | if not hasattr(self, "colorize"):
420 | self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
421 | x = F.conv2d(x, weight=self.colorize)
422 | x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
423 | return x
424 |
425 |
426 | class IdentityFirstStage(torch.nn.Module):
427 | def __init__(self, *args, vq_interface=False, **kwargs):
428 | self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff
429 | super().__init__()
430 |
431 | def encode(self, x, *args, **kwargs):
432 | return x
433 |
434 | def decode(self, x, *args, **kwargs):
435 | return x
436 |
437 | def quantize(self, x, *args, **kwargs):
438 | if self.vq_interface:
439 | return x, None, [None, None, None]
440 | return x
441 |
442 | def forward(self, x, *args, **kwargs):
443 | return x
444 |
--------------------------------------------------------------------------------
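A minimal sketch of building the KL autoencoder from one of the shipped configs and running an encode/decode round trip; the config path is relative to the multi-scale-blended-diffusion directory, and the checkpoint path is purely illustrative (matching first-stage weights have to be downloaded separately):

    import torch
    from omegaconf import OmegaConf
    from ldm.util import instantiate_from_config

    cfg = OmegaConf.load("configs/autoencoder/autoencoder_kl_64x64x3.yaml")
    ae = instantiate_from_config(cfg.model)  # builds the model named by the config's target
    ae.init_from_ckpt("models/first_stage_models/kl-f4/model.ckpt")  # illustrative path
    ae.eval()

    x = torch.randn(1, 3, 256, 256)  # stand-in for an image scaled to [-1, 1]
    with torch.no_grad():
        posterior = ae.encode(x)     # DiagonalGaussianDistribution over the latent
        z = posterior.mode()         # deterministic latent code
        x_rec = ae.decode(z)
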
/multi-scale-blended-diffusion/ldm/models/diffusion/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pfnet-research/multi-stage-blended-diffusion/b933b3d395edcbe266dbb51abf8f45d3b4b66f81/multi-scale-blended-diffusion/ldm/models/diffusion/__init__.py
--------------------------------------------------------------------------------
/multi-scale-blended-diffusion/ldm/models/diffusion/classifier.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import pytorch_lightning as pl
4 | from omegaconf import OmegaConf
5 | from torch.nn import functional as F
6 | from torch.optim import AdamW
7 | from torch.optim.lr_scheduler import LambdaLR
8 | from copy import deepcopy
9 | from einops import rearrange
10 | from glob import glob
11 | from natsort import natsorted
12 |
13 | from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel
14 | from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config
15 |
16 | __models__ = {
17 | 'class_label': EncoderUNetModel,
18 | 'segmentation': UNetModel
19 | }
20 |
21 |
22 | def disabled_train(self, mode=True):
23 | """Overwrite model.train with this function to make sure train/eval mode
24 | does not change anymore."""
25 | return self
26 |
27 |
28 | class NoisyLatentImageClassifier(pl.LightningModule):
29 |
30 | def __init__(self,
31 | diffusion_path,
32 | num_classes,
33 | ckpt_path=None,
34 | pool='attention',
35 | label_key=None,
36 | diffusion_ckpt_path=None,
37 | scheduler_config=None,
38 | weight_decay=1.e-2,
39 | log_steps=10,
40 | monitor='val/loss',
41 | *args,
42 | **kwargs):
43 | super().__init__(*args, **kwargs)
44 | self.num_classes = num_classes
45 | # get latest config of diffusion model
46 | diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1]
47 | self.diffusion_config = OmegaConf.load(diffusion_config).model
48 | self.diffusion_config.params.ckpt_path = diffusion_ckpt_path
49 | self.load_diffusion()
50 |
51 | self.monitor = monitor
52 | self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1
53 | self.log_time_interval = self.diffusion_model.num_timesteps // log_steps
54 | self.log_steps = log_steps
55 |
56 | self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \
57 | else self.diffusion_model.cond_stage_key
58 |
59 | assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params'
60 |
61 | if self.label_key not in __models__:
62 | raise NotImplementedError()
63 |
64 | self.load_classifier(ckpt_path, pool)
65 |
66 | self.scheduler_config = scheduler_config
67 | self.use_scheduler = self.scheduler_config is not None
68 | self.weight_decay = weight_decay
69 |
70 | def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
71 | sd = torch.load(path, map_location="cpu")
72 | if "state_dict" in list(sd.keys()):
73 | sd = sd["state_dict"]
74 | keys = list(sd.keys())
75 | for k in keys:
76 | for ik in ignore_keys:
77 | if k.startswith(ik):
78 | print("Deleting key {} from state_dict.".format(k))
79 | del sd[k]
80 | missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
81 | sd, strict=False)
82 | print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
83 | if len(missing) > 0:
84 | print(f"Missing Keys: {missing}")
85 | if len(unexpected) > 0:
86 | print(f"Unexpected Keys: {unexpected}")
87 |
88 | def load_diffusion(self):
89 | model = instantiate_from_config(self.diffusion_config)
90 | self.diffusion_model = model.eval()
91 | self.diffusion_model.train = disabled_train
92 | for param in self.diffusion_model.parameters():
93 | param.requires_grad = False
94 |
95 | def load_classifier(self, ckpt_path, pool):
96 | model_config = deepcopy(self.diffusion_config.params.unet_config.params)
97 | model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels
98 | model_config.out_channels = self.num_classes
99 | if self.label_key == 'class_label':
100 | model_config.pool = pool
101 |
102 | self.model = __models__[self.label_key](**model_config)
103 | if ckpt_path is not None:
104 | print('#####################################################################')
105 | print(f'load from ckpt "{ckpt_path}"')
106 | print('#####################################################################')
107 | self.init_from_ckpt(ckpt_path)
108 |
109 | @torch.no_grad()
110 | def get_x_noisy(self, x, t, noise=None):
111 | noise = default(noise, lambda: torch.randn_like(x))
112 | continuous_sqrt_alpha_cumprod = None
113 | if self.diffusion_model.use_continuous_noise:
114 | continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1)
115 | # todo: make sure t+1 is correct here
116 |
117 | return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise,
118 | continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod)
119 |
120 | def forward(self, x_noisy, t, *args, **kwargs):
121 | return self.model(x_noisy, t)
122 |
123 | @torch.no_grad()
124 | def get_input(self, batch, k):
125 | x = batch[k]
126 | if len(x.shape) == 3:
127 | x = x[..., None]
128 | x = rearrange(x, 'b h w c -> b c h w')
129 | x = x.to(memory_format=torch.contiguous_format).float()
130 | return x
131 |
132 | @torch.no_grad()
133 | def get_conditioning(self, batch, k=None):
134 | if k is None:
135 | k = self.label_key
136 | assert k is not None, 'Needs to provide label key'
137 |
138 | targets = batch[k].to(self.device)
139 |
140 | if self.label_key == 'segmentation':
141 | targets = rearrange(targets, 'b h w c -> b c h w')
142 | for down in range(self.numd):
143 | h, w = targets.shape[-2:]
144 | targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest')
145 |
146 | # targets = rearrange(targets,'b c h w -> b h w c')
147 |
148 | return targets
149 |
150 | def compute_top_k(self, logits, labels, k, reduction="mean"):
151 | _, top_ks = torch.topk(logits, k, dim=1)
152 | if reduction == "mean":
153 | return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item()
154 | elif reduction == "none":
155 | return (top_ks == labels[:, None]).float().sum(dim=-1)
156 |
157 | def on_train_epoch_start(self):
158 | # save some memory
159 | self.diffusion_model.model.to('cpu')
160 |
161 | @torch.no_grad()
162 | def write_logs(self, loss, logits, targets):
163 | log_prefix = 'train' if self.training else 'val'
164 | log = {}
165 | log[f"{log_prefix}/loss"] = loss.mean()
166 | log[f"{log_prefix}/acc@1"] = self.compute_top_k(
167 | logits, targets, k=1, reduction="mean"
168 | )
169 | log[f"{log_prefix}/acc@5"] = self.compute_top_k(
170 | logits, targets, k=5, reduction="mean"
171 | )
172 |
173 | self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True)
174 | self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False)
175 | self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True)
176 | lr = self.optimizers().param_groups[0]['lr']
177 | self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True)
178 |
179 | def shared_step(self, batch, t=None):
180 | x, *_ = self.diffusion_model.get_input(batch, k=self.diffusion_model.first_stage_key)
181 | targets = self.get_conditioning(batch)
182 | if targets.dim() == 4:
183 | targets = targets.argmax(dim=1)
184 | if t is None:
185 | t = torch.randint(0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device).long()
186 | else:
187 | t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long()
188 | x_noisy = self.get_x_noisy(x, t)
189 | logits = self(x_noisy, t)
190 |
191 | loss = F.cross_entropy(logits, targets, reduction='none')
192 |
193 | self.write_logs(loss.detach(), logits.detach(), targets.detach())
194 |
195 | loss = loss.mean()
196 | return loss, logits, x_noisy, targets
197 |
198 | def training_step(self, batch, batch_idx):
199 | loss, *_ = self.shared_step(batch)
200 | return loss
201 |
202 | def reset_noise_accs(self):
203 | self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in
204 | range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)}
205 |
206 | def on_validation_start(self):
207 | self.reset_noise_accs()
208 |
209 | @torch.no_grad()
210 | def validation_step(self, batch, batch_idx):
211 | loss, *_ = self.shared_step(batch)
212 |
213 | for t in self.noisy_acc:
214 | _, logits, _, targets = self.shared_step(batch, t)
215 | self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean'))
216 | self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean'))
217 |
218 | return loss
219 |
220 | def configure_optimizers(self):
221 | optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
222 |
223 | if self.use_scheduler:
224 | scheduler = instantiate_from_config(self.scheduler_config)
225 |
226 | print("Setting up LambdaLR scheduler...")
227 | scheduler = [
228 | {
229 | 'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule),
230 | 'interval': 'step',
231 | 'frequency': 1
232 | }]
233 | return [optimizer], scheduler
234 |
235 | return optimizer
236 |
237 | @torch.no_grad()
238 | def log_images(self, batch, N=8, *args, **kwargs):
239 | log = dict()
240 | x = self.get_input(batch, self.diffusion_model.first_stage_key)
241 | log['inputs'] = x
242 |
243 | y = self.get_conditioning(batch)
244 |
245 | if self.label_key == 'class_label':
246 | y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
247 | log['labels'] = y
248 |
249 | if ismap(y):
250 | log['labels'] = self.diffusion_model.to_rgb(y)
251 |
252 | for step in range(self.log_steps):
253 | current_time = step * self.log_time_interval
254 |
255 | _, logits, x_noisy, _ = self.shared_step(batch, t=current_time)
256 |
257 | log[f'inputs@t{current_time}'] = x_noisy
258 |
259 | pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes)
260 | pred = rearrange(pred, 'b h w c -> b c h w')
261 |
262 | log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred)
263 |
264 | for key in log:
265 | log[key] = log[key][:N]
266 |
267 | return log
268 |
--------------------------------------------------------------------------------
/multi-scale-blended-diffusion/ldm/models/diffusion/ddim.py:
--------------------------------------------------------------------------------
1 | """SAMPLING ONLY."""
2 |
3 | import torch
4 | import numpy as np
5 | from tqdm import tqdm
6 |
7 | from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
8 |
9 |
10 | class DDIMSampler(object):
11 | def __init__(self, model, schedule="linear", **kwargs):
12 | super().__init__()
13 | self.model = model
14 | self.ddpm_num_timesteps = model.num_timesteps
15 | self.schedule = schedule
16 |
17 | def register_buffer(self, name, attr):
18 | if type(attr) == torch.Tensor:
19 | if attr.device != torch.device("cuda"):
20 | attr = attr.to(torch.device("cuda"))
21 | setattr(self, name, attr)
22 |
23 | def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
24 | self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
25 | num_ddpm_timesteps=self.ddpm_num_timesteps, verbose=verbose)
26 | alphas_cumprod = self.model.alphas_cumprod
27 | assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
28 | def to_torch(x): return x.clone().detach().to(torch.float32).to(self.model.device)
29 |
30 | self.register_buffer('betas', to_torch(self.model.betas))
31 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
32 | self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
33 |
34 | # calculations for diffusion q(x_t | x_{t-1}) and others
35 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
36 | self.register_buffer('sqrt_one_minus_alphas_cumprod',
37 | to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
38 | self.register_buffer('log_one_minus_alphas_cumprod',
39 | to_torch(np.log(1. - alphas_cumprod.cpu())))
40 | self.register_buffer('sqrt_recip_alphas_cumprod',
41 | to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
42 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(
43 | np.sqrt(1. / alphas_cumprod.cpu() - 1)))
44 |
45 | # ddim sampling parameters
46 | ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
47 | ddim_timesteps=self.ddim_timesteps,
48 | eta=ddim_eta, verbose=verbose)
49 | self.register_buffer('ddim_sigmas', ddim_sigmas)
50 | self.register_buffer('ddim_alphas', ddim_alphas)
51 | self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
52 | self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
53 | sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
54 | (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
55 | 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
56 | self.register_buffer('ddim_sigmas_for_original_num_steps',
57 | sigmas_for_original_sampling_steps)
58 |
59 | @torch.no_grad()
60 | def sample(self,
61 | S,
62 | batch_size,
63 | shape,
64 | conditioning=None,
65 | callback=None,
66 | normals_sequence=None,
67 | img_callback=None,
68 | quantize_x0=False,
69 | eta=0.,
70 | mask=None,
71 | x0=None,
72 | temperature=1.,
73 | noise_dropout=0.,
74 | score_corrector=None,
75 | corrector_kwargs=None,
76 | verbose=True,
77 | x_T=None,
78 | log_every_t=100,
79 | unconditional_guidance_scale=1.,
80 | unconditional_conditioning=None,
81 | # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
82 | **kwargs
83 | ):
84 | if conditioning is not None:
85 | if isinstance(conditioning, dict):
86 | cbs = conditioning[list(conditioning.keys())[0]].shape[0]
87 | if cbs != batch_size:
88 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
89 | else:
90 | if conditioning.shape[0] != batch_size:
91 | print(
92 | f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
93 |
94 | self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
95 | # sampling
96 | C, H, W = shape
97 | size = (batch_size, C, H, W)
98 | print(f'Data shape for DDIM sampling is {size}, eta {eta}')
99 |
100 | samples, intermediates = self.ddim_sampling(conditioning, size,
101 | callback=callback,
102 | img_callback=img_callback,
103 | quantize_denoised=quantize_x0,
104 | mask=mask, x0=x0,
105 | ddim_use_original_steps=False,
106 | noise_dropout=noise_dropout,
107 | temperature=temperature,
108 | score_corrector=score_corrector,
109 | corrector_kwargs=corrector_kwargs,
110 | x_T=x_T,
111 | log_every_t=log_every_t,
112 | unconditional_guidance_scale=unconditional_guidance_scale,
113 | unconditional_conditioning=unconditional_conditioning,
114 | )
115 | return samples, intermediates
116 |
117 |
118 | @torch.no_grad()
119 | def ddim_sampling(self, cond, shape,
120 | x_T=None, ddim_use_original_steps=False,
121 | callback=None, timesteps=None, quantize_denoised=False,
122 | mask=None, x0=None, img_callback=None, log_every_t=100,
123 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
124 | unconditional_guidance_scale=1., unconditional_conditioning=None,):
125 | device = self.model.betas.device
126 | b = shape[0]
127 | if x_T is None:
128 | img = torch.randn(shape, device=device)
129 | else:
130 | img = x_T
131 |
132 | if timesteps is None:
133 | timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
134 | elif timesteps is not None and not ddim_use_original_steps:
135 | subset_end = int(
136 | min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
137 | timesteps = self.ddim_timesteps[:subset_end]
138 |
139 | intermediates = {'x_inter': [img], 'pred_x0': [img]}
140 | time_range = reversed(
141 | range(0, timesteps)) if ddim_use_original_steps else np.flip(timesteps)
142 | total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
143 | print(f"Running DDIM Sampling with {total_steps} timesteps")
144 |
145 | iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
146 |
147 | for i, step in enumerate(iterator):
148 | index = total_steps - i - 1
149 | ts = torch.full((b,), step, device=device, dtype=torch.long)
150 |
151 | if mask is not None:
152 | assert x0 is not None
153 | img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
154 | img = img_orig * mask + (1. - mask) * img
155 |
156 | outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
157 | quantize_denoised=quantize_denoised, temperature=temperature,
158 | noise_dropout=noise_dropout, score_corrector=score_corrector,
159 | corrector_kwargs=corrector_kwargs,
160 | unconditional_guidance_scale=unconditional_guidance_scale,
161 | unconditional_conditioning=unconditional_conditioning)
162 | img, pred_x0 = outs
163 | if callback:
164 | callback(i)
165 | if img_callback:
166 | img_callback(pred_x0, i)
167 |
168 | if index % log_every_t == 0 or index == total_steps - 1:
169 | intermediates['x_inter'].append(img)
170 | intermediates['pred_x0'].append(pred_x0)
171 |
172 | return img, intermediates
173 |
174 | @torch.no_grad()
175 | def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
176 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
177 | unconditional_guidance_scale=1., unconditional_conditioning=None):
178 | """
179 | Performs a single denoising step on x_t given conditioning c. Plus many options which seem to be unused in the code?
180 | Can also accept batches I think.
181 | """
182 | b, *_, device = *x.shape, x.device
183 |
184 | if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
185 | e_t = self.model.apply_model(x, t, c)
186 | else:
187 | x_in = torch.cat([x] * 2)
188 | t_in = torch.cat([t] * 2)
189 | c_in = torch.cat([unconditional_conditioning, c])
190 | e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
191 | e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
192 |
193 | if score_corrector is not None:
194 | assert self.model.parameterization == "eps"
195 | e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
196 |
197 | alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
198 | alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
199 | sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
200 | sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
201 | # select parameters corresponding to the currently considered timestep
202 | a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
203 | a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
204 | sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
205 | sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device)
206 |
207 | # current prediction for x_0
208 | pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
209 | if quantize_denoised:
210 | pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
211 | # direction pointing to x_t
212 | dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
213 | noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
214 |         if noise_dropout > 0.:  # randomly drop (and rescale) part of the added noise
215 | noise = torch.nn.functional.dropout(noise, p=noise_dropout)
216 | x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
217 | return x_prev, pred_x0
218 |
219 |
220 |
221 | @torch.no_grad()
222 | def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
223 | # fast, but does not allow for exact reconstruction
224 | # t serves as an index to gather the correct alphas
225 | if use_original_steps:
226 | sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
227 | sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
228 | else:
229 | sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
230 | sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
231 |
232 | if noise is None:
233 | noise = torch.randn_like(x0)
234 | return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
235 | extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)
236 |
237 | @torch.no_grad()
238 | def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
239 | use_original_steps=False):
240 |
241 | timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
242 | timesteps = timesteps[:t_start]
243 |
244 | time_range = np.flip(timesteps)
245 | total_steps = timesteps.shape[0]
246 | print(f"Running DDIM Sampling with {total_steps} timesteps")
247 |
248 | iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
249 | x_dec = x_latent
250 | for i, step in enumerate(iterator):
251 | index = total_steps - i - 1
252 | ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
253 | x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
254 | unconditional_guidance_scale=unconditional_guidance_scale,
255 | unconditional_conditioning=unconditional_conditioning)
256 | return x_dec
257 |
258 | def extract_into_tensor(a, t, x_shape):
259 | b, *_ = t.shape
260 | out = a.gather(-1, t)
261 | return out.reshape(b, *((1,) * (len(x_shape) - 1)))
262 |
--------------------------------------------------------------------------------
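A minimal sketch of driving DDIMSampler, assuming model is an already-loaded LatentDiffusion instance on a CUDA device (make_schedule moves the sampler's buffers to cuda) and that a 4x64x64 latent matches the model's first stage; the prompt, step count and guidance scale are illustrative:

    import torch
    from ldm.models.diffusion.ddim import DDIMSampler

    # assumes `model` is a LatentDiffusion built from a config and moved to CUDA (omitted here)
    sampler = DDIMSampler(model)
    with torch.no_grad():
        c = model.get_learned_conditioning(["a photograph of a church"])
        uc = model.get_learned_conditioning([""])
        samples, _ = sampler.sample(S=50, batch_size=1, shape=(4, 64, 64),
                                    conditioning=c, eta=0.0, verbose=False,
                                    unconditional_guidance_scale=7.5,
                                    unconditional_conditioning=uc)
        images = model.decode_first_stage(samples)  # pixel space, roughly in [-1, 1]
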
/multi-scale-blended-diffusion/ldm/models/diffusion/plms.py:
--------------------------------------------------------------------------------
1 | """SAMPLING ONLY."""
2 |
3 | import torch
4 | import numpy as np
5 | from tqdm import tqdm
6 | from functools import partial
7 |
8 | from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
9 |
10 |
11 | class PLMSSampler(object):
12 | def __init__(self, model, schedule="linear", **kwargs):
13 | super().__init__()
14 | self.model = model
15 | self.ddpm_num_timesteps = model.num_timesteps
16 | self.schedule = schedule
17 |
18 | def register_buffer(self, name, attr):
19 | if type(attr) == torch.Tensor:
20 | if attr.device != torch.device("cuda"):
21 | attr = attr.to(torch.device("cuda"))
22 | setattr(self, name, attr)
23 |
24 | def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
25 | if ddim_eta != 0:
26 | raise ValueError('ddim_eta must be 0 for PLMS')
27 | self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
28 | num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
29 | alphas_cumprod = self.model.alphas_cumprod
30 | assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
31 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
32 |
33 | self.register_buffer('betas', to_torch(self.model.betas))
34 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
35 | self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
36 |
37 | # calculations for diffusion q(x_t | x_{t-1}) and others
38 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
39 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
40 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
41 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
42 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
43 |
44 | # ddim sampling parameters
45 | ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
46 | ddim_timesteps=self.ddim_timesteps,
47 | eta=ddim_eta,verbose=verbose)
48 | self.register_buffer('ddim_sigmas', ddim_sigmas)
49 | self.register_buffer('ddim_alphas', ddim_alphas)
50 | self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
51 | self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
52 | sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
53 | (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
54 | 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
55 | self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
56 |
57 | @torch.no_grad()
58 | def sample(self,
59 | S,
60 | batch_size,
61 | shape,
62 | conditioning=None,
63 | callback=None,
64 | normals_sequence=None,
65 | img_callback=None,
66 | quantize_x0=False,
67 | eta=0.,
68 | mask=None,
69 | x0=None,
70 | temperature=1.,
71 | noise_dropout=0.,
72 | score_corrector=None,
73 | corrector_kwargs=None,
74 | verbose=True,
75 | x_T=None,
76 | log_every_t=100,
77 | unconditional_guidance_scale=1.,
78 | unconditional_conditioning=None,
79 | # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
80 | **kwargs
81 | ):
82 | if conditioning is not None:
83 | if isinstance(conditioning, dict):
84 | cbs = conditioning[list(conditioning.keys())[0]].shape[0]
85 | if cbs != batch_size:
86 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
87 | else:
88 | if conditioning.shape[0] != batch_size:
89 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
90 |
91 | self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
92 | # sampling
93 | C, H, W = shape
94 | size = (batch_size, C, H, W)
95 | print(f'Data shape for PLMS sampling is {size}')
96 |
97 | samples, intermediates = self.plms_sampling(conditioning, size,
98 | callback=callback,
99 | img_callback=img_callback,
100 | quantize_denoised=quantize_x0,
101 | mask=mask, x0=x0,
102 | ddim_use_original_steps=False,
103 | noise_dropout=noise_dropout,
104 | temperature=temperature,
105 | score_corrector=score_corrector,
106 | corrector_kwargs=corrector_kwargs,
107 | x_T=x_T,
108 | log_every_t=log_every_t,
109 | unconditional_guidance_scale=unconditional_guidance_scale,
110 | unconditional_conditioning=unconditional_conditioning,
111 | )
112 | return samples, intermediates
113 |
114 | @torch.no_grad()
115 | def plms_sampling(self, cond, shape,
116 | x_T=None, ddim_use_original_steps=False,
117 | callback=None, timesteps=None, quantize_denoised=False,
118 | mask=None, x0=None, img_callback=None, log_every_t=100,
119 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
120 | unconditional_guidance_scale=1., unconditional_conditioning=None,):
121 | device = self.model.betas.device
122 | b = shape[0]
123 | if x_T is None:
124 | img = torch.randn(shape, device=device)
125 | else:
126 | img = x_T
127 |
128 | if timesteps is None:
129 | timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
130 | elif timesteps is not None and not ddim_use_original_steps:
131 | subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
132 | timesteps = self.ddim_timesteps[:subset_end]
133 |
134 | intermediates = {'x_inter': [img], 'pred_x0': [img]}
135 | time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps)
136 | total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
137 | print(f"Running PLMS Sampling with {total_steps} timesteps")
138 |
139 | iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
140 | old_eps = []
141 |
142 | for i, step in enumerate(iterator):
143 | index = total_steps - i - 1
144 | ts = torch.full((b,), step, device=device, dtype=torch.long)
145 | ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)
146 |
147 | if mask is not None:
148 | assert x0 is not None
149 | img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
150 | img = img_orig * mask + (1. - mask) * img
151 |
152 | outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
153 | quantize_denoised=quantize_denoised, temperature=temperature,
154 | noise_dropout=noise_dropout, score_corrector=score_corrector,
155 | corrector_kwargs=corrector_kwargs,
156 | unconditional_guidance_scale=unconditional_guidance_scale,
157 | unconditional_conditioning=unconditional_conditioning,
158 | old_eps=old_eps, t_next=ts_next)
159 | img, pred_x0, e_t = outs
160 | old_eps.append(e_t)
161 | if len(old_eps) >= 4:
162 | old_eps.pop(0)
163 | if callback: callback(i)
164 | if img_callback: img_callback(pred_x0, i)
165 |
166 | if index % log_every_t == 0 or index == total_steps - 1:
167 | intermediates['x_inter'].append(img)
168 | intermediates['pred_x0'].append(pred_x0)
169 |
170 | return img, intermediates
171 |
172 | @torch.no_grad()
173 | def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
174 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
175 | unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None):
176 | b, *_, device = *x.shape, x.device
177 |
178 | def get_model_output(x, t):
179 | if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
180 | e_t = self.model.apply_model(x, t, c)
181 | else:
182 | x_in = torch.cat([x] * 2)
183 | t_in = torch.cat([t] * 2)
184 | c_in = torch.cat([unconditional_conditioning, c])
185 | e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
186 | e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
187 |
188 | if score_corrector is not None:
189 | assert self.model.parameterization == "eps"
190 | e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
191 |
192 | return e_t
193 |
194 | alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
195 | alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
196 | sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
197 | sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
198 |
199 | def get_x_prev_and_pred_x0(e_t, index):
200 | # select parameters corresponding to the currently considered timestep
201 | a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
202 | a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
203 | sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
204 | sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
205 |
206 | # current prediction for x_0
207 | pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
208 | if quantize_denoised:
209 | pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
210 | # direction pointing to x_t
211 | dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
212 | noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
213 | if noise_dropout > 0.:
214 | noise = torch.nn.functional.dropout(noise, p=noise_dropout)
215 | x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
216 | return x_prev, pred_x0
217 |
218 | e_t = get_model_output(x, t)
219 | if len(old_eps) == 0:
220 | # Pseudo Improved Euler (2nd order)
221 | x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
222 | e_t_next = get_model_output(x_prev, t_next)
223 | e_t_prime = (e_t + e_t_next) / 2
224 | elif len(old_eps) == 1:
225 | # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
226 | e_t_prime = (3 * e_t - old_eps[-1]) / 2
227 | elif len(old_eps) == 2:
228 |             # 3rd order Pseudo Linear Multistep (Adams-Bashforth)
229 | e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
230 | elif len(old_eps) >= 3:
231 |             # 4th order Pseudo Linear Multistep (Adams-Bashforth)
232 | e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
233 |
234 | x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
235 |
236 | return x_prev, pred_x0, e_t
237 |
--------------------------------------------------------------------------------
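The old_eps buffer above implements the Adams-Bashforth combinations named in the comments: once three previous noise estimates are stored, the 4th-order update is

    e_t' = (55*e_t - 59*e_{t-1} + 37*e_{t-2} - 9*e_{t-3}) / 24

and e_t' is then passed through the same get_x_prev_and_pred_x0 step as a plain DDIM prediction would be.
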
/multi-scale-blended-diffusion/ldm/modules/attention.py:
--------------------------------------------------------------------------------
1 | from inspect import isfunction
2 | import math
3 | import torch
4 | import torch.nn.functional as F
5 | from torch import nn, einsum
6 | from einops import rearrange, repeat
7 |
8 | from ldm.modules.diffusionmodules.util import checkpoint
9 |
10 |
11 | def exists(val):
12 | return val is not None
13 |
14 |
15 | def uniq(arr):
16 |     return {el: True for el in arr}.keys()
17 |
18 |
19 | def default(val, d):
20 | if exists(val):
21 | return val
22 | return d() if isfunction(d) else d
23 |
24 |
25 | def max_neg_value(t):
26 | return -torch.finfo(t.dtype).max
27 |
28 |
29 | def init_(tensor):
30 | dim = tensor.shape[-1]
31 | std = 1 / math.sqrt(dim)
32 | tensor.uniform_(-std, std)
33 | return tensor
34 |
35 |
36 | # feedforward
37 | class GEGLU(nn.Module):
38 | def __init__(self, dim_in, dim_out):
39 | super().__init__()
40 | self.proj = nn.Linear(dim_in, dim_out * 2)
41 |
42 | def forward(self, x):
43 | x, gate = self.proj(x).chunk(2, dim=-1)
44 | return x * F.gelu(gate)
45 |
46 |
47 | class FeedForward(nn.Module):
48 | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
49 | super().__init__()
50 | inner_dim = int(dim * mult)
51 | dim_out = default(dim_out, dim)
52 | project_in = nn.Sequential(
53 | nn.Linear(dim, inner_dim),
54 | nn.GELU()
55 | ) if not glu else GEGLU(dim, inner_dim)
56 |
57 | self.net = nn.Sequential(
58 | project_in,
59 | nn.Dropout(dropout),
60 | nn.Linear(inner_dim, dim_out)
61 | )
62 |
63 | def forward(self, x):
64 | return self.net(x)
65 |
66 |
67 | def zero_module(module):
68 | """
69 | Zero out the parameters of a module and return it.
70 | """
71 | for p in module.parameters():
72 | p.detach().zero_()
73 | return module
74 |
75 |
76 | def Normalize(in_channels):
77 | return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
78 |
79 |
80 | class LinearAttention(nn.Module):
81 | def __init__(self, dim, heads=4, dim_head=32):
82 | super().__init__()
83 | self.heads = heads
84 | hidden_dim = dim_head * heads
85 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
86 | self.to_out = nn.Conv2d(hidden_dim, dim, 1)
87 |
88 | def forward(self, x):
89 | b, c, h, w = x.shape
90 | qkv = self.to_qkv(x)
91 | q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
92 | k = k.softmax(dim=-1)
93 | context = torch.einsum('bhdn,bhen->bhde', k, v)
94 | out = torch.einsum('bhde,bhdn->bhen', context, q)
95 | out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
96 | return self.to_out(out)
97 |
98 |
99 | class SpatialSelfAttention(nn.Module):
100 | def __init__(self, in_channels):
101 | super().__init__()
102 | self.in_channels = in_channels
103 |
104 | self.norm = Normalize(in_channels)
105 | self.q = torch.nn.Conv2d(in_channels,
106 | in_channels,
107 | kernel_size=1,
108 | stride=1,
109 | padding=0)
110 | self.k = torch.nn.Conv2d(in_channels,
111 | in_channels,
112 | kernel_size=1,
113 | stride=1,
114 | padding=0)
115 | self.v = torch.nn.Conv2d(in_channels,
116 | in_channels,
117 | kernel_size=1,
118 | stride=1,
119 | padding=0)
120 | self.proj_out = torch.nn.Conv2d(in_channels,
121 | in_channels,
122 | kernel_size=1,
123 | stride=1,
124 | padding=0)
125 |
126 | def forward(self, x):
127 | h_ = x
128 | h_ = self.norm(h_)
129 | q = self.q(h_)
130 | k = self.k(h_)
131 | v = self.v(h_)
132 |
133 | # compute attention
134 | b,c,h,w = q.shape
135 | q = rearrange(q, 'b c h w -> b (h w) c')
136 | k = rearrange(k, 'b c h w -> b c (h w)')
137 | w_ = torch.einsum('bij,bjk->bik', q, k)
138 |
139 | w_ = w_ * (int(c)**(-0.5))
140 | w_ = torch.nn.functional.softmax(w_, dim=2)
141 |
142 | # attend to values
143 | v = rearrange(v, 'b c h w -> b c (h w)')
144 | w_ = rearrange(w_, 'b i j -> b j i')
145 | h_ = torch.einsum('bij,bjk->bik', v, w_)
146 | h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
147 | h_ = self.proj_out(h_)
148 |
149 | return x+h_
150 |
151 |
152 | class CrossAttention(nn.Module):
153 | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
154 | super().__init__()
155 | inner_dim = dim_head * heads
156 | context_dim = default(context_dim, query_dim)
157 |
158 | self.scale = dim_head ** -0.5
159 | self.heads = heads
160 |
161 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
162 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
163 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
164 |
165 | self.to_out = nn.Sequential(
166 | nn.Linear(inner_dim, query_dim),
167 | nn.Dropout(dropout)
168 | )
169 |
170 | def forward(self, x, context=None, mask=None):
171 | h = self.heads
172 |
173 | q = self.to_q(x)
174 | context = default(context, x)
175 | k = self.to_k(context)
176 | v = self.to_v(context)
177 |
178 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
179 |
180 | sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
181 |
182 | if exists(mask):
183 | mask = rearrange(mask, 'b ... -> b (...)')
184 | max_neg_value = -torch.finfo(sim.dtype).max
185 | mask = repeat(mask, 'b j -> (b h) () j', h=h)
186 | sim.masked_fill_(~mask, max_neg_value)
187 |
188 | # attention, what we cannot get enough of
189 | attn = sim.softmax(dim=-1)
190 |
191 | out = einsum('b i j, b j d -> b i d', attn, v)
192 | out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
193 | return self.to_out(out)
194 |
195 |
196 | class BasicTransformerBlock(nn.Module):
197 | def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True):
198 | super().__init__()
199 | self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention
200 | self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
201 | self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
202 | heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
203 | self.norm1 = nn.LayerNorm(dim)
204 | self.norm2 = nn.LayerNorm(dim)
205 | self.norm3 = nn.LayerNorm(dim)
206 | self.checkpoint = checkpoint
207 |
208 | def forward(self, x, context=None):
209 | return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
210 |
211 | def _forward(self, x, context=None):
212 | x = self.attn1(self.norm1(x)) + x
213 | x = self.attn2(self.norm2(x), context=context) + x
214 | x = self.ff(self.norm3(x)) + x
215 | return x
216 |
217 |
218 | class SpatialTransformer(nn.Module):
219 | """
220 | Transformer block for image-like data.
221 | First, project the input (aka embedding)
222 | and reshape to b, t, d.
223 | Then apply the standard transformer blocks.
224 | Finally, reshape back to an image.
225 | """
226 | def __init__(self, in_channels, n_heads, d_head,
227 | depth=1, dropout=0., context_dim=None):
228 | super().__init__()
229 | self.in_channels = in_channels
230 | inner_dim = n_heads * d_head
231 | self.norm = Normalize(in_channels)
232 |
233 | self.proj_in = nn.Conv2d(in_channels,
234 | inner_dim,
235 | kernel_size=1,
236 | stride=1,
237 | padding=0)
238 |
239 | self.transformer_blocks = nn.ModuleList(
240 | [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
241 | for d in range(depth)]
242 | )
243 |
244 | self.proj_out = zero_module(nn.Conv2d(inner_dim,
245 | in_channels,
246 | kernel_size=1,
247 | stride=1,
248 | padding=0))
249 |
250 | def forward(self, x, context=None):
251 | # note: if no context is given, cross-attention defaults to self-attention
252 | b, c, h, w = x.shape
253 | x_in = x
254 | x = self.norm(x)
255 | x = self.proj_in(x)
256 | x = rearrange(x, 'b c h w -> b (h w) c')
257 | for block in self.transformer_blocks:
258 | x = block(x, context=context)
259 | x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
260 | x = self.proj_out(x)
261 | return x + x_in
--------------------------------------------------------------------------------
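
A minimal shape-check sketch for the `SpatialTransformer` defined above (assumes the `ldm` package from this repo is importable; the channel, head, and context sizes are illustrative and not taken from any shipped config):

import torch
from ldm.modules.attention import SpatialTransformer

# 64-channel feature map, 4 heads of size 16, conditioned on 77 "token" embeddings of width 768
block = SpatialTransformer(in_channels=64, n_heads=4, d_head=16, depth=1, context_dim=768)
x = torch.randn(2, 64, 8, 8)        # b c h w, as produced inside the UNet
ctx = torch.randn(2, 77, 768)       # b t d, e.g. a text-encoder output
out = block(x, context=ctx)         # projected to b (h w) c, attended, reshaped back
assert out.shape == x.shape         # the residual connection preserves the input shape
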
/multi-scale-blended-diffusion/ldm/modules/diffusionmodules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pfnet-research/multi-stage-blended-diffusion/b933b3d395edcbe266dbb51abf8f45d3b4b66f81/multi-scale-blended-diffusion/ldm/modules/diffusionmodules/__init__.py
--------------------------------------------------------------------------------
/multi-scale-blended-diffusion/ldm/modules/diffusionmodules/util.py:
--------------------------------------------------------------------------------
1 | # adapted from
2 | # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
3 | # and
4 | # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
5 | # and
6 | # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
7 | #
8 | # thanks!
9 |
10 |
11 | import os
12 | import math
13 | import torch
14 | import torch.nn as nn
15 | import numpy as np
16 | from einops import repeat
17 |
18 | from ldm.util import instantiate_from_config
19 |
20 |
21 | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
22 | if schedule == "linear":
23 | betas = (
24 | torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
25 | )
26 |
27 | elif schedule == "cosine":
28 | timesteps = (
29 | torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
30 | )
31 | alphas = timesteps / (1 + cosine_s) * np.pi / 2
32 | alphas = torch.cos(alphas).pow(2)
33 | alphas = alphas / alphas[0]
34 | betas = 1 - alphas[1:] / alphas[:-1]
35 | betas = np.clip(betas, a_min=0, a_max=0.999)
36 |
37 | elif schedule == "sqrt_linear":
38 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
39 | elif schedule == "sqrt":
40 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
41 | else:
42 | raise ValueError(f"schedule '{schedule}' unknown.")
43 | return betas.numpy()
44 |
45 |
46 | def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
47 | if ddim_discr_method == 'uniform':
48 | c = num_ddpm_timesteps // num_ddim_timesteps
49 | ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
50 | elif ddim_discr_method == 'quad':
51 | ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
52 | else:
53 | raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
54 |
55 | # assert ddim_timesteps.shape[0] == num_ddim_timesteps
56 | # add one to get the final alpha values right (the ones from first scale to data during sampling)
57 | steps_out = ddim_timesteps + 1
58 | if verbose:
59 | print(f'Selected timesteps for ddim sampler: {steps_out}')
60 | return steps_out
61 |
62 |
63 | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
64 | # select alphas for computing the variance schedule
65 | alphas = alphacums[ddim_timesteps]
66 | alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
67 |
68 | # according to the formula provided in https://arxiv.org/abs/2010.02502
69 | sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
70 | if verbose:
71 | print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
72 | print(f'For the chosen value of eta, which is {eta}, '
73 | f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
74 | return sigmas, alphas, alphas_prev
75 |
76 |
77 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
78 | """
79 | Create a beta schedule that discretizes the given alpha_t_bar function,
80 | which defines the cumulative product of (1-beta) over time from t = [0,1].
81 | :param num_diffusion_timesteps: the number of betas to produce.
82 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
83 | produces the cumulative product of (1-beta) up to that
84 | part of the diffusion process.
85 | :param max_beta: the maximum beta to use; use values lower than 1 to
86 | prevent singularities.
87 | """
88 | betas = []
89 | for i in range(num_diffusion_timesteps):
90 | t1 = i / num_diffusion_timesteps
91 | t2 = (i + 1) / num_diffusion_timesteps
92 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
93 | return np.array(betas)
94 |
95 |
96 | def extract_into_tensor(a, t, x_shape):
97 | b, *_ = t.shape
98 | out = a.gather(-1, t)
99 | return out.reshape(b, *((1,) * (len(x_shape) - 1)))
100 |
101 |
102 | def checkpoint(func, inputs, params, flag):
103 | """
104 | Evaluate a function without caching intermediate activations, allowing for
105 | reduced memory at the expense of extra compute in the backward pass.
106 | :param func: the function to evaluate.
107 | :param inputs: the argument sequence to pass to `func`.
108 | :param params: a sequence of parameters `func` depends on but does not
109 | explicitly take as arguments.
110 | :param flag: if False, disable gradient checkpointing.
111 | """
112 | if flag:
113 | args = tuple(inputs) + tuple(params)
114 | return CheckpointFunction.apply(func, len(inputs), *args)
115 | else:
116 | return func(*inputs)
117 |
118 |
119 | class CheckpointFunction(torch.autograd.Function):
120 | @staticmethod
121 | def forward(ctx, run_function, length, *args):
122 | ctx.run_function = run_function
123 | ctx.input_tensors = list(args[:length])
124 | ctx.input_params = list(args[length:])
125 |
126 | with torch.no_grad():
127 | output_tensors = ctx.run_function(*ctx.input_tensors)
128 | return output_tensors
129 |
130 | @staticmethod
131 | def backward(ctx, *output_grads):
132 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
133 | with torch.enable_grad():
134 | # Fixes a bug where the first op in run_function modifies the
135 | # Tensor storage in place, which is not allowed for detach()'d
136 | # Tensors.
137 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
138 | output_tensors = ctx.run_function(*shallow_copies)
139 | input_grads = torch.autograd.grad(
140 | output_tensors,
141 | ctx.input_tensors + ctx.input_params,
142 | output_grads,
143 | allow_unused=True,
144 | )
145 | del ctx.input_tensors
146 | del ctx.input_params
147 | del output_tensors
148 | return (None, None) + input_grads
149 |
150 |
151 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
152 | """
153 | Create sinusoidal timestep embeddings.
154 | :param timesteps: a 1-D Tensor of N indices, one per batch element.
155 | These may be fractional.
156 | :param dim: the dimension of the output.
157 | :param max_period: controls the minimum frequency of the embeddings.
158 | :return: an [N x dim] Tensor of positional embeddings.
159 | """
160 | if not repeat_only:
161 | half = dim // 2
162 | freqs = torch.exp(
163 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
164 | ).to(device=timesteps.device)
165 | args = timesteps[:, None].float() * freqs[None]
166 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
167 | if dim % 2:
168 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
169 | else:
170 | embedding = repeat(timesteps, 'b -> b d', d=dim)
171 | return embedding
172 |
173 |
174 | def zero_module(module):
175 | """
176 | Zero out the parameters of a module and return it.
177 | """
178 | for p in module.parameters():
179 | p.detach().zero_()
180 | return module
181 |
182 |
183 | def scale_module(module, scale):
184 | """
185 | Scale the parameters of a module and return it.
186 | """
187 | for p in module.parameters():
188 | p.detach().mul_(scale)
189 | return module
190 |
191 |
192 | def mean_flat(tensor):
193 | """
194 | Take the mean over all non-batch dimensions.
195 | """
196 | return tensor.mean(dim=list(range(1, len(tensor.shape))))
197 |
198 |
199 | def normalization(channels):
200 | """
201 | Make a standard normalization layer.
202 | :param channels: number of input channels.
203 | :return: an nn.Module for normalization.
204 | """
205 | return GroupNorm32(32, channels)
206 |
207 |
208 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
209 | class SiLU(nn.Module):
210 | def forward(self, x):
211 | return x * torch.sigmoid(x)
212 |
213 |
214 | class GroupNorm32(nn.GroupNorm):
215 | def forward(self, x):
216 | return super().forward(x.float()).type(x.dtype)
217 |
218 | def conv_nd(dims, *args, **kwargs):
219 | """
220 | Create a 1D, 2D, or 3D convolution module.
221 | """
222 | if dims == 1:
223 | return nn.Conv1d(*args, **kwargs)
224 | elif dims == 2:
225 | return nn.Conv2d(*args, **kwargs)
226 | elif dims == 3:
227 | return nn.Conv3d(*args, **kwargs)
228 | raise ValueError(f"unsupported dimensions: {dims}")
229 |
230 |
231 | def linear(*args, **kwargs):
232 | """
233 | Create a linear module.
234 | """
235 | return nn.Linear(*args, **kwargs)
236 |
237 |
238 | def avg_pool_nd(dims, *args, **kwargs):
239 | """
240 | Create a 1D, 2D, or 3D average pooling module.
241 | """
242 | if dims == 1:
243 | return nn.AvgPool1d(*args, **kwargs)
244 | elif dims == 2:
245 | return nn.AvgPool2d(*args, **kwargs)
246 | elif dims == 3:
247 | return nn.AvgPool3d(*args, **kwargs)
248 | raise ValueError(f"unsupported dimensions: {dims}")
249 |
250 |
251 | class HybridConditioner(nn.Module):
252 |
253 | def __init__(self, c_concat_config, c_crossattn_config):
254 | super().__init__()
255 | self.concat_conditioner = instantiate_from_config(c_concat_config)
256 | self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
257 |
258 | def forward(self, c_concat, c_crossattn):
259 | c_concat = self.concat_conditioner(c_concat)
260 | c_crossattn = self.crossattn_conditioner(c_crossattn)
261 | return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
262 |
263 |
264 | def noise_like(shape, device, repeat=False):
265 | repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
266 | noise = lambda: torch.randn(shape, device=device)
267 | return repeat_noise() if repeat else noise()
--------------------------------------------------------------------------------
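
A small sketch of how the schedule helpers above fit together (step counts and eta chosen only for illustration):

import numpy as np
import torch
from ldm.modules.diffusionmodules.util import (
    make_beta_schedule, make_ddim_timesteps, make_ddim_sampling_parameters, timestep_embedding)

betas = make_beta_schedule("linear", n_timestep=1000)                 # ndarray, shape (1000,)
alphas_cumprod = np.cumprod(1.0 - betas)                              # cumulative alpha-bar per timestep
ddim_steps = make_ddim_timesteps("uniform", 50, 1000, verbose=False)  # 50 DDPM indices (shifted by +1)
sigmas, a_t, a_prev = make_ddim_sampling_parameters(alphas_cumprod, ddim_steps, eta=0.0, verbose=False)
# eta=0 gives a deterministic DDIM sampler (all sigmas are zero)

emb = timestep_embedding(torch.tensor([0, 250, 999]), dim=128)        # sinusoidal, shape (3, 128)
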
/multi-scale-blended-diffusion/ldm/modules/distributions/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pfnet-research/multi-stage-blended-diffusion/b933b3d395edcbe266dbb51abf8f45d3b4b66f81/multi-scale-blended-diffusion/ldm/modules/distributions/__init__.py
--------------------------------------------------------------------------------
/multi-scale-blended-diffusion/ldm/modules/distributions/distributions.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | class AbstractDistribution:
6 | def sample(self):
7 | raise NotImplementedError()
8 |
9 | def mode(self):
10 | raise NotImplementedError()
11 |
12 |
13 | class DiracDistribution(AbstractDistribution):
14 | def __init__(self, value):
15 | self.value = value
16 |
17 | def sample(self):
18 | return self.value
19 |
20 | def mode(self):
21 | return self.value
22 |
23 |
24 | class DiagonalGaussianDistribution(object):
25 | def __init__(self, parameters, deterministic=False):
26 | self.parameters = parameters
27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
29 | self.deterministic = deterministic
30 | self.std = torch.exp(0.5 * self.logvar)
31 | self.var = torch.exp(self.logvar)
32 | if self.deterministic:
33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
34 |
35 | def sample(self):
36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
37 | return x
38 |
39 | def kl(self, other=None):
40 | if self.deterministic:
41 | return torch.Tensor([0.])
42 | else:
43 | if other is None:
44 | return 0.5 * torch.sum(torch.pow(self.mean, 2)
45 | + self.var - 1.0 - self.logvar,
46 | dim=[1, 2, 3])
47 | else:
48 | return 0.5 * torch.sum(
49 | torch.pow(self.mean - other.mean, 2) / other.var
50 | + self.var / other.var - 1.0 - self.logvar + other.logvar,
51 | dim=[1, 2, 3])
52 |
53 | def nll(self, sample, dims=[1,2,3]):
54 | if self.deterministic:
55 | return torch.Tensor([0.])
56 | logtwopi = np.log(2.0 * np.pi)
57 | return 0.5 * torch.sum(
58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
59 | dim=dims)
60 |
61 | def mode(self):
62 | return self.mean
63 |
64 |
65 | def normal_kl(mean1, logvar1, mean2, logvar2):
66 | """
67 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
68 | Compute the KL divergence between two gaussians.
69 | Shapes are automatically broadcasted, so batches can be compared to
70 | scalars, among other use cases.
71 | """
72 | tensor = None
73 | for obj in (mean1, logvar1, mean2, logvar2):
74 | if isinstance(obj, torch.Tensor):
75 | tensor = obj
76 | break
77 | assert tensor is not None, "at least one argument must be a Tensor"
78 |
79 | # Force variances to be Tensors. Broadcasting helps convert scalars to
80 | # Tensors, but it does not work for torch.exp().
81 | logvar1, logvar2 = [
82 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
83 | for x in (logvar1, logvar2)
84 | ]
85 |
86 | return 0.5 * (
87 | -1.0
88 | + logvar2
89 | - logvar1
90 | + torch.exp(logvar1 - logvar2)
91 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
92 | )
93 |
--------------------------------------------------------------------------------
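
For reference, a short sketch of how the diagonal Gaussian posterior above is typically used (the tensor is a random stand-in for the encoder output, which concatenates mean and log-variance along the channel dimension):

import torch
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution

params = torch.randn(4, 8, 16, 16)      # 2 * 4 latent channels: first half mean, second half logvar
posterior = DiagonalGaussianDistribution(params)
z = posterior.sample()                  # (4, 4, 16, 16), reparameterized sample mean + std * eps
kl = posterior.kl()                     # (4,), KL to a standard normal, summed over C, H, W
nll = posterior.nll(z)                  # (4,), negative log-likelihood of z under the posterior
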
/multi-scale-blended-diffusion/ldm/modules/ema.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | class LitEma(nn.Module):
6 | def __init__(self, model, decay=0.9999, use_num_upates=True):
7 | super().__init__()
8 | if decay < 0.0 or decay > 1.0:
9 | raise ValueError('Decay must be between 0 and 1')
10 |
11 | self.m_name2s_name = {}
12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
13 | self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates
14 | else torch.tensor(-1,dtype=torch.int))
15 |
16 | for name, p in model.named_parameters():
17 | if p.requires_grad:
18 | #remove as '.'-character is not allowed in buffers
19 | s_name = name.replace('.','')
20 | self.m_name2s_name.update({name:s_name})
21 | self.register_buffer(s_name,p.clone().detach().data)
22 |
23 | self.collected_params = []
24 |
25 | def forward(self,model):
26 | decay = self.decay
27 |
28 | if self.num_updates >= 0:
29 | self.num_updates += 1
30 | decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates))
31 |
32 | one_minus_decay = 1.0 - decay
33 |
34 | with torch.no_grad():
35 | m_param = dict(model.named_parameters())
36 | shadow_params = dict(self.named_buffers())
37 |
38 | for key in m_param:
39 | if m_param[key].requires_grad:
40 | sname = self.m_name2s_name[key]
41 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
42 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
43 | else:
44 | assert not key in self.m_name2s_name
45 |
46 | def copy_to(self, model):
47 | m_param = dict(model.named_parameters())
48 | shadow_params = dict(self.named_buffers())
49 | for key in m_param:
50 | if m_param[key].requires_grad:
51 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
52 | else:
53 | assert not key in self.m_name2s_name
54 |
55 | def store(self, parameters):
56 | """
57 | Save the current parameters for restoring later.
58 | Args:
59 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
60 | temporarily stored.
61 | """
62 | self.collected_params = [param.clone() for param in parameters]
63 |
64 | def restore(self, parameters):
65 | """
66 | Restore the parameters stored with the `store` method.
67 | Useful to validate the model with EMA parameters without affecting the
68 | original optimization process. Store the parameters before the
69 | `copy_to` method. After validation (or model saving), use this to
70 | restore the former parameters.
71 | Args:
72 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
73 | updated with the stored parameters.
74 | """
75 | for c_param, param in zip(self.collected_params, parameters):
76 | param.data.copy_(c_param.data)
77 |
--------------------------------------------------------------------------------
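
A sketch of the EMA workflow the class above supports (the linear model, decay value, and training loop are placeholders, not taken from this repo's training code):

import torch
import torch.nn as nn
from ldm.modules.ema import LitEma

model = nn.Linear(4, 4)
model_ema = LitEma(model, decay=0.999)
opt = torch.optim.SGD(model.parameters(), lr=1e-2)

for _ in range(10):
    loss = model(torch.randn(8, 4)).pow(2).mean()
    opt.zero_grad(); loss.backward(); opt.step()
    model_ema(model)                        # update the shadow weights after every optimizer step

model_ema.store(model.parameters())         # stash the live weights
model_ema.copy_to(model)                    # evaluate / sample with the EMA weights
model_ema.restore(model.parameters())       # put the live weights back for further training
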
/multi-scale-blended-diffusion/ldm/modules/encoders/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pfnet-research/multi-stage-blended-diffusion/b933b3d395edcbe266dbb51abf8f45d3b4b66f81/multi-scale-blended-diffusion/ldm/modules/encoders/__init__.py
--------------------------------------------------------------------------------
/multi-scale-blended-diffusion/ldm/modules/encoders/modules.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from functools import partial
4 | import clip
5 | from einops import rearrange, repeat
6 | import kornia
7 | from transformers import CLIPTokenizer, CLIPTextModel
8 |
9 |
10 | from ldm.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a requirement? --> test
11 |
12 |
13 | class AbstractEncoder(nn.Module):
14 | def __init__(self):
15 | super().__init__()
16 |
17 | def encode(self, *args, **kwargs):
18 | raise NotImplementedError
19 |
20 |
21 |
22 | class ClassEmbedder(nn.Module):
23 | def __init__(self, embed_dim, n_classes=1000, key='class'):
24 | super().__init__()
25 | self.key = key
26 | self.embedding = nn.Embedding(n_classes, embed_dim)
27 |
28 | def forward(self, batch, key=None):
29 | if key is None:
30 | key = self.key
31 | # this is for use in crossattn
32 | c = batch[key][:, None]
33 | c = self.embedding(c)
34 | return c
35 |
36 |
37 | class TransformerEmbedder(AbstractEncoder):
38 | """Some transformer encoder layers"""
39 | def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"):
40 | super().__init__()
41 | self.device = device
42 | self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
43 | attn_layers=Encoder(dim=n_embed, depth=n_layer))
44 |
45 | def forward(self, tokens):
46 | tokens = tokens.to(self.device) # meh
47 | z = self.transformer(tokens, return_embeddings=True)
48 | return z
49 |
50 | def encode(self, x):
51 | return self(x)
52 |
53 |
54 | class BERTTokenizer(AbstractEncoder):
55 | """ Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)"""
56 | def __init__(self, device="cuda", vq_interface=True, max_length=77):
57 | super().__init__()
58 | from transformers import BertTokenizerFast # TODO: add to requirements
59 | self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
60 | self.device = device
61 | self.vq_interface = vq_interface
62 | self.max_length = max_length
63 |
64 | def forward(self, text):
65 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
66 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
67 | tokens = batch_encoding["input_ids"].to(self.device)
68 | return tokens
69 |
70 | @torch.no_grad()
71 | def encode(self, text):
72 | tokens = self(text)
73 | if not self.vq_interface:
74 | return tokens
75 | return None, None, [None, None, tokens]
76 |
77 | def decode(self, text):
78 | return text
79 |
80 |
81 | class BERTEmbedder(AbstractEncoder):
82 | """Uses the BERT tokenizr model and add some transformer encoder layers"""
83 | def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77,
84 | device="cuda",use_tokenizer=True, embedding_dropout=0.0):
85 | super().__init__()
86 | self.use_tknz_fn = use_tokenizer
87 | if self.use_tknz_fn:
88 | self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len)
89 | self.device = device
90 | self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
91 | attn_layers=Encoder(dim=n_embed, depth=n_layer),
92 | emb_dropout=embedding_dropout)
93 |
94 | def forward(self, text):
95 | if self.use_tknz_fn:
96 | tokens = self.tknz_fn(text)#.to(self.device)
97 | else:
98 | tokens = text
99 | z = self.transformer(tokens, return_embeddings=True)
100 | return z
101 |
102 | def encode(self, text):
103 | # output of length 77
104 | return self(text)
105 |
106 |
107 | class SpatialRescaler(nn.Module):
108 | def __init__(self,
109 | n_stages=1,
110 | method='bilinear',
111 | multiplier=0.5,
112 | in_channels=3,
113 | out_channels=None,
114 | bias=False):
115 | super().__init__()
116 | self.n_stages = n_stages
117 | assert self.n_stages >= 0
118 | assert method in ['nearest','linear','bilinear','trilinear','bicubic','area']
119 | self.multiplier = multiplier
120 | self.interpolator = partial(torch.nn.functional.interpolate, mode=method)
121 | self.remap_output = out_channels is not None
122 | if self.remap_output:
123 | print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.')
124 | self.channel_mapper = nn.Conv2d(in_channels,out_channels,1,bias=bias)
125 |
126 | def forward(self,x):
127 | for stage in range(self.n_stages):
128 | x = self.interpolator(x, scale_factor=self.multiplier)
129 |
130 |
131 | if self.remap_output:
132 | x = self.channel_mapper(x)
133 | return x
134 |
135 | def encode(self, x):
136 | return self(x)
137 |
138 |
139 | class FrozenCLIPEmbedder(AbstractEncoder):
140 | """Uses the CLIP transformer encoder for text (from Hugging Face)"""
141 | def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77):
142 | super().__init__()
143 | self.tokenizer = CLIPTokenizer.from_pretrained(version)
144 | self.transformer = CLIPTextModel.from_pretrained(version)
145 | self.device = device
146 | self.max_length = max_length
147 | self.freeze()
148 |
149 | def freeze(self):
150 | self.transformer = self.transformer.eval()
151 | for param in self.parameters():
152 | param.requires_grad = False
153 |
154 | def forward(self, text):
155 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
156 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
157 | tokens = batch_encoding["input_ids"].to(self.device)
158 | outputs = self.transformer(input_ids=tokens)
159 |
160 | z = outputs.last_hidden_state
161 | return z
162 |
163 | def encode(self, text):
164 | return self(text)
165 |
166 |
167 | class FrozenCLIPTextEmbedder(nn.Module):
168 | """
169 | Uses the CLIP transformer encoder for text.
170 | """
171 | def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n_repeat=1, normalize=True):
172 | super().__init__()
173 | self.model, _ = clip.load(version, jit=False, device="cpu")
174 | self.device = device
175 | self.max_length = max_length
176 | self.n_repeat = n_repeat
177 | self.normalize = normalize
178 |
179 | def freeze(self):
180 | self.model = self.model.eval()
181 | for param in self.parameters():
182 | param.requires_grad = False
183 |
184 | def forward(self, text):
185 | tokens = clip.tokenize(text).to(self.device)
186 | z = self.model.encode_text(tokens)
187 | if self.normalize:
188 | z = z / torch.linalg.norm(z, dim=1, keepdim=True)
189 | return z
190 |
191 | def encode(self, text):
192 | z = self(text)
193 | if z.ndim==2:
194 | z = z[:, None, :]
195 | z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat)
196 | return z
197 |
198 |
199 | class FrozenClipImageEmbedder(nn.Module):
200 | """
201 | Uses the CLIP image encoder.
202 | """
203 | def __init__(
204 | self,
205 | model,
206 | jit=False,
207 | device='cuda' if torch.cuda.is_available() else 'cpu',
208 | antialias=False,
209 | ):
210 | super().__init__()
211 | self.model, _ = clip.load(name=model, device=device, jit=jit)
212 |
213 | self.antialias = antialias
214 |
215 | self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
216 | self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
217 |
218 | def preprocess(self, x):
219 | # normalize to [0,1]
220 | x = kornia.geometry.resize(x, (224, 224),
221 | interpolation='bicubic',align_corners=True,
222 | antialias=self.antialias)
223 | x = (x + 1.) / 2.
224 | # renormalize according to clip
225 | x = kornia.enhance.normalize(x, self.mean, self.std)
226 | return x
227 |
228 | def forward(self, x):
229 | # x is assumed to be in range [-1,1]
230 | return self.model.encode_image(self.preprocess(x))
231 |
232 |
--------------------------------------------------------------------------------
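
A sketch of the CLIP text-conditioning path above (assumes `transformers`, `clip`, and `kornia` from environment.yaml are installed; the CLIP weights are downloaded from Hugging Face on first use, and device="cpu" keeps the example self-contained):

from ldm.modules.encoders.modules import FrozenCLIPEmbedder

embedder = FrozenCLIPEmbedder(device="cpu")
z = embedder.encode(["a photograph of an astronaut riding a horse"])
print(z.shape)   # torch.Size([1, 77, 768]): one embedding per (padded) token
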
/multi-scale-blended-diffusion/ldm/modules/image_degradation/__init__.py:
--------------------------------------------------------------------------------
1 | from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr
2 | from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light
3 |
--------------------------------------------------------------------------------
/multi-scale-blended-diffusion/ldm/modules/losses/__init__.py:
--------------------------------------------------------------------------------
1 | from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator
--------------------------------------------------------------------------------
/multi-scale-blended-diffusion/ldm/modules/losses/contperceptual.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no?
5 |
6 |
7 | class LPIPSWithDiscriminator(nn.Module):
8 | def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0,
9 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
10 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
11 | disc_loss="hinge"):
12 |
13 | super().__init__()
14 | assert disc_loss in ["hinge", "vanilla"]
15 | self.kl_weight = kl_weight
16 | self.pixel_weight = pixelloss_weight
17 | self.perceptual_loss = LPIPS().eval()
18 | self.perceptual_weight = perceptual_weight
19 | # output log variance
20 | self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
21 |
22 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
23 | n_layers=disc_num_layers,
24 | use_actnorm=use_actnorm
25 | ).apply(weights_init)
26 | self.discriminator_iter_start = disc_start
27 | self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
28 | self.disc_factor = disc_factor
29 | self.discriminator_weight = disc_weight
30 | self.disc_conditional = disc_conditional
31 |
32 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
33 | if last_layer is not None:
34 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
35 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
36 | else:
37 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
38 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
39 |
40 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
41 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
42 | d_weight = d_weight * self.discriminator_weight
43 | return d_weight
44 |
45 | def forward(self, inputs, reconstructions, posteriors, optimizer_idx,
46 | global_step, last_layer=None, cond=None, split="train",
47 | weights=None):
48 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
49 | if self.perceptual_weight > 0:
50 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
51 | rec_loss = rec_loss + self.perceptual_weight * p_loss
52 |
53 | nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
54 | weighted_nll_loss = nll_loss
55 | if weights is not None:
56 | weighted_nll_loss = weights*nll_loss
57 | weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
58 | nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
59 | kl_loss = posteriors.kl()
60 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
61 |
62 | # now the GAN part
63 | if optimizer_idx == 0:
64 | # generator update
65 | if cond is None:
66 | assert not self.disc_conditional
67 | logits_fake = self.discriminator(reconstructions.contiguous())
68 | else:
69 | assert self.disc_conditional
70 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
71 | g_loss = -torch.mean(logits_fake)
72 |
73 | if self.disc_factor > 0.0:
74 | try:
75 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
76 | except RuntimeError:
77 | assert not self.training
78 | d_weight = torch.tensor(0.0)
79 | else:
80 | d_weight = torch.tensor(0.0)
81 |
82 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
83 | loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss
84 |
85 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(),
86 | "{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(),
87 | "{}/rec_loss".format(split): rec_loss.detach().mean(),
88 | "{}/d_weight".format(split): d_weight.detach(),
89 | "{}/disc_factor".format(split): torch.tensor(disc_factor),
90 | "{}/g_loss".format(split): g_loss.detach().mean(),
91 | }
92 | return loss, log
93 |
94 | if optimizer_idx == 1:
95 | # second pass for discriminator update
96 | if cond is None:
97 | logits_real = self.discriminator(inputs.contiguous().detach())
98 | logits_fake = self.discriminator(reconstructions.contiguous().detach())
99 | else:
100 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
101 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
102 |
103 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
104 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
105 |
106 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
107 | "{}/logits_real".format(split): logits_real.detach().mean(),
108 | "{}/logits_fake".format(split): logits_fake.detach().mean()
109 | }
110 | return d_loss, log
111 |
112 |
--------------------------------------------------------------------------------
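
A rough sketch of the two-optimizer training step this loss expects (tensors and hyperparameter values are illustrative stand-ins; in actual training the inputs come from the first-stage autoencoder, `last_layer` is the decoder's final conv weight, and the LPIPS/VGG weights are downloaded by the taming-transformers package on first use):

import torch
from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution

loss_fn = LPIPSWithDiscriminator(disc_start=50001, kl_weight=1e-6, disc_weight=0.5)
inputs = torch.randn(2, 3, 64, 64)
recons = torch.randn(2, 3, 64, 64, requires_grad=True)
posterior = DiagonalGaussianDistribution(torch.randn(2, 8, 8, 8))

# optimizer_idx=0: autoencoder ("generator") update
ae_loss, log = loss_fn(inputs, recons, posterior, optimizer_idx=0,
                       global_step=0, last_layer=recons, split="train")
# optimizer_idx=1: discriminator update (inactive until global_step >= disc_start)
d_loss, d_log = loss_fn(inputs, recons, posterior, optimizer_idx=1,
                        global_step=0, split="train")
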
/multi-scale-blended-diffusion/ldm/modules/losses/vqperceptual.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | from einops import repeat
5 |
6 | from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
7 | from taming.modules.losses.lpips import LPIPS
8 | from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss
9 |
10 |
11 | def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights):
12 | assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0]
13 | loss_real = torch.mean(F.relu(1. - logits_real), dim=[1,2,3])
14 | loss_fake = torch.mean(F.relu(1. + logits_fake), dim=[1,2,3])
15 | loss_real = (weights * loss_real).sum() / weights.sum()
16 | loss_fake = (weights * loss_fake).sum() / weights.sum()
17 | d_loss = 0.5 * (loss_real + loss_fake)
18 | return d_loss
19 |
20 | def adopt_weight(weight, global_step, threshold=0, value=0.):
21 | if global_step < threshold:
22 | weight = value
23 | return weight
24 |
25 |
26 | def measure_perplexity(predicted_indices, n_embed):
27 | # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
28 | # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
29 | encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed)
30 | avg_probs = encodings.mean(0)
31 | perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
32 | cluster_use = torch.sum(avg_probs > 0)
33 | return perplexity, cluster_use
34 |
35 | def l1(x, y):
36 | return torch.abs(x-y)
37 |
38 |
39 | def l2(x, y):
40 | return torch.pow((x-y), 2)
41 |
42 |
43 | class VQLPIPSWithDiscriminator(nn.Module):
44 | def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,
45 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
46 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
47 | disc_ndf=64, disc_loss="hinge", n_classes=None, perceptual_loss="lpips",
48 | pixel_loss="l1"):
49 | super().__init__()
50 | assert disc_loss in ["hinge", "vanilla"]
51 | assert perceptual_loss in ["lpips", "clips", "dists"]
52 | assert pixel_loss in ["l1", "l2"]
53 | self.codebook_weight = codebook_weight
54 | self.pixel_weight = pixelloss_weight
55 | if perceptual_loss == "lpips":
56 | print(f"{self.__class__.__name__}: Running with LPIPS.")
57 | self.perceptual_loss = LPIPS().eval()
58 | else:
59 | raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<")
60 | self.perceptual_weight = perceptual_weight
61 |
62 | if pixel_loss == "l1":
63 | self.pixel_loss = l1
64 | else:
65 | self.pixel_loss = l2
66 |
67 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
68 | n_layers=disc_num_layers,
69 | use_actnorm=use_actnorm,
70 | ndf=disc_ndf
71 | ).apply(weights_init)
72 | self.discriminator_iter_start = disc_start
73 | if disc_loss == "hinge":
74 | self.disc_loss = hinge_d_loss
75 | elif disc_loss == "vanilla":
76 | self.disc_loss = vanilla_d_loss
77 | else:
78 | raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
79 | print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
80 | self.disc_factor = disc_factor
81 | self.discriminator_weight = disc_weight
82 | self.disc_conditional = disc_conditional
83 | self.n_classes = n_classes
84 |
85 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
86 | if last_layer is not None:
87 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
88 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
89 | else:
90 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
91 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
92 |
93 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
94 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
95 | d_weight = d_weight * self.discriminator_weight
96 | return d_weight
97 |
98 | def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx,
99 | global_step, last_layer=None, cond=None, split="train", predicted_indices=None):
100 | if codebook_loss is None:
101 | codebook_loss = torch.tensor([0.]).to(inputs.device)
102 | #rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
103 | rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous())
104 | if self.perceptual_weight > 0:
105 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
106 | rec_loss = rec_loss + self.perceptual_weight * p_loss
107 | else:
108 | p_loss = torch.tensor([0.0])
109 |
110 | nll_loss = rec_loss
111 | #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
112 | nll_loss = torch.mean(nll_loss)
113 |
114 | # now the GAN part
115 | if optimizer_idx == 0:
116 | # generator update
117 | if cond is None:
118 | assert not self.disc_conditional
119 | logits_fake = self.discriminator(reconstructions.contiguous())
120 | else:
121 | assert self.disc_conditional
122 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
123 | g_loss = -torch.mean(logits_fake)
124 |
125 | try:
126 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
127 | except RuntimeError:
128 | assert not self.training
129 | d_weight = torch.tensor(0.0)
130 |
131 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
132 | loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean()
133 |
134 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(),
135 | "{}/quant_loss".format(split): codebook_loss.detach().mean(),
136 | "{}/nll_loss".format(split): nll_loss.detach().mean(),
137 | "{}/rec_loss".format(split): rec_loss.detach().mean(),
138 | "{}/p_loss".format(split): p_loss.detach().mean(),
139 | "{}/d_weight".format(split): d_weight.detach(),
140 | "{}/disc_factor".format(split): torch.tensor(disc_factor),
141 | "{}/g_loss".format(split): g_loss.detach().mean(),
142 | }
143 | if predicted_indices is not None:
144 | assert self.n_classes is not None
145 | with torch.no_grad():
146 | perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes)
147 | log[f"{split}/perplexity"] = perplexity
148 | log[f"{split}/cluster_usage"] = cluster_usage
149 | return loss, log
150 |
151 | if optimizer_idx == 1:
152 | # second pass for discriminator update
153 | if cond is None:
154 | logits_real = self.discriminator(inputs.contiguous().detach())
155 | logits_fake = self.discriminator(reconstructions.contiguous().detach())
156 | else:
157 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
158 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
159 |
160 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
161 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
162 |
163 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
164 | "{}/logits_real".format(split): logits_real.detach().mean(),
165 | "{}/logits_fake".format(split): logits_fake.detach().mean()
166 | }
167 | return d_loss, log
168 |
--------------------------------------------------------------------------------
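
A quick sanity check of the standalone helpers above (importing this module pulls in taming-transformers; the code indices are random placeholders):

import torch
from ldm.modules.losses.vqperceptual import adopt_weight, measure_perplexity

print(adopt_weight(1.0, global_step=100, threshold=500))    # 0.0: discriminator still switched off
print(adopt_weight(1.0, global_step=1000, threshold=500))   # 1.0: past the warm-up threshold

codes = torch.randint(0, 16, (4, 32 * 32))                  # fake codebook indices for a 4-image batch
perplexity, n_used = measure_perplexity(codes, n_embed=16)  # ~16 when all entries are used equally
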
/multi-scale-blended-diffusion/ldm/util.py:
--------------------------------------------------------------------------------
1 | import importlib
2 |
3 | import torch
4 | import numpy as np
5 | from collections import abc
6 | from einops import rearrange
7 | from functools import partial
8 |
9 | import multiprocessing as mp
10 | from threading import Thread
11 | from queue import Queue
12 |
13 | from inspect import isfunction
14 | from PIL import Image, ImageDraw, ImageFont
15 |
16 |
17 | def log_txt_as_img(wh, xc, size=10):
18 | # wh a tuple of (width, height)
19 | # xc a list of captions to plot
20 | b = len(xc)
21 | txts = list()
22 | for bi in range(b):
23 | txt = Image.new("RGB", wh, color="white")
24 | draw = ImageDraw.Draw(txt)
25 | font = ImageFont.truetype('data/DejaVuSans.ttf', size=size)
26 | nc = int(40 * (wh[0] / 256))
27 | lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))
28 |
29 | try:
30 | draw.text((0, 0), lines, fill="black", font=font)
31 | except UnicodeEncodeError:
32 | print("Cant encode string for logging. Skipping.")
33 |
34 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
35 | txts.append(txt)
36 | txts = np.stack(txts)
37 | txts = torch.tensor(txts)
38 | return txts
39 |
40 |
41 | def ismap(x):
42 | if not isinstance(x, torch.Tensor):
43 | return False
44 | return (len(x.shape) == 4) and (x.shape[1] > 3)
45 |
46 |
47 | def isimage(x):
48 | if not isinstance(x, torch.Tensor):
49 | return False
50 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
51 |
52 |
53 | def exists(x):
54 | return x is not None
55 |
56 |
57 | def default(val, d):
58 | if exists(val):
59 | return val
60 | return d() if isfunction(d) else d
61 |
62 |
63 | def mean_flat(tensor):
64 | """
65 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
66 | Take the mean over all non-batch dimensions.
67 | """
68 | return tensor.mean(dim=list(range(1, len(tensor.shape))))
69 |
70 |
71 | def count_params(model, verbose=False):
72 | total_params = sum(p.numel() for p in model.parameters())
73 | if verbose:
74 | print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
75 | return total_params
76 |
77 |
78 | def instantiate_from_config(config):
79 | if not "target" in config:
80 | if config == '__is_first_stage__':
81 | return None
82 | elif config == "__is_unconditional__":
83 | return None
84 | raise KeyError("Expected key `target` to instantiate.")
85 | return get_obj_from_str(config["target"])(**config.get("params", dict()))
86 |
87 |
88 | def get_obj_from_str(string, reload=False):
89 | module, cls = string.rsplit(".", 1)
90 | if reload:
91 | module_imp = importlib.import_module(module)
92 | importlib.reload(module_imp)
93 | return getattr(importlib.import_module(module, package=None), cls)
94 |
95 |
96 | def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False):
97 | # create dummy dataset instance
98 |
99 | # run prefetching
100 | if idx_to_fn:
101 | res = func(data, worker_id=idx)
102 | else:
103 | res = func(data)
104 | Q.put([idx, res])
105 | Q.put("Done")
106 |
107 |
108 | def parallel_data_prefetch(
109 | func: callable, data, n_proc, target_data_type="ndarray", cpu_intensive=True, use_worker_id=False
110 | ):
111 | # if target_data_type not in ["ndarray", "list"]:
112 | # raise ValueError(
113 | # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray."
114 | # )
115 | if isinstance(data, np.ndarray) and target_data_type == "list":
116 | raise ValueError("list expected but function got ndarray.")
117 | elif isinstance(data, abc.Iterable):
118 | if isinstance(data, dict):
119 | print(
120 | f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'
121 | )
122 | data = list(data.values())
123 | if target_data_type == "ndarray":
124 | data = np.asarray(data)
125 | else:
126 | data = list(data)
127 | else:
128 | raise TypeError(
129 | f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}."
130 | )
131 |
132 | if cpu_intensive:
133 | Q = mp.Queue(1000)
134 | proc = mp.Process
135 | else:
136 | Q = Queue(1000)
137 | proc = Thread
138 | # spawn processes
139 | if target_data_type == "ndarray":
140 | arguments = [
141 | [func, Q, part, i, use_worker_id]
142 | for i, part in enumerate(np.array_split(data, n_proc))
143 | ]
144 | else:
145 | step = (
146 | int(len(data) / n_proc + 1)
147 | if len(data) % n_proc != 0
148 | else int(len(data) / n_proc)
149 | )
150 | arguments = [
151 | [func, Q, part, i, use_worker_id]
152 | for i, part in enumerate(
153 | [data[i: i + step] for i in range(0, len(data), step)]
154 | )
155 | ]
156 | processes = []
157 | for i in range(n_proc):
158 | p = proc(target=_do_parallel_data_prefetch, args=arguments[i])
159 | processes += [p]
160 |
161 | # start processes
162 | print(f"Start prefetching...")
163 | import time
164 |
165 | start = time.time()
166 | gather_res = [[] for _ in range(n_proc)]
167 | try:
168 | for p in processes:
169 | p.start()
170 |
171 | k = 0
172 | while k < n_proc:
173 | # get result
174 | res = Q.get()
175 | if res == "Done":
176 | k += 1
177 | else:
178 | gather_res[res[0]] = res[1]
179 |
180 | except Exception as e:
181 | print("Exception: ", e)
182 | for p in processes:
183 | p.terminate()
184 |
185 | raise e
186 | finally:
187 | for p in processes:
188 | p.join()
189 | print(f"Prefetching complete. [{time.time() - start} sec.]")
190 |
191 | if target_data_type == 'ndarray':
192 | if not isinstance(gather_res[0], np.ndarray):
193 | return np.concatenate([np.asarray(r) for r in gather_res], axis=0)
194 |
195 | # order outputs
196 | return np.concatenate(gather_res, axis=0)
197 | elif target_data_type == 'list':
198 | out = []
199 | for r in gather_res:
200 | out.extend(r)
201 | return out
202 | else:
203 | return gather_res
204 |
--------------------------------------------------------------------------------
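
For reference, this is how the `target`/`params` blocks in the yaml configs are resolved; the config dict below is a hand-written stand-in, not one of the shipped configs:

from ldm.util import instantiate_from_config

cfg = {
    "target": "ldm.modules.attention.LinearAttention",  # any importable dotted path works
    "params": {"dim": 64, "heads": 4, "dim_head": 32},
}
attn = instantiate_from_config(cfg)   # equivalent to LinearAttention(dim=64, heads=4, dim_head=32)
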
/multi-scale-blended-diffusion/msbd/BLDSampler.py:
--------------------------------------------------------------------------------
1 | """
2 | This module is based on the original latent diffusion code base, specifically on `ldm/models/diffusion/ddim.py`.
3 | It mainly implements the sampler for blended latent diffusion: `blended_ddim_sampling` modifies the sampling loop to incorporate the mask and
4 | to replace the unmasked area with a noisy version of the original image after each timestep.
5 | It also implements SDEdit (starting at an intermediate timestep) and RePaint (repeating diffusion steps multiple times).
6 | """
7 | import logging
8 |
9 | import numpy as np
10 | import torch
11 | from einops import rearrange
12 | from ldm.models.diffusion.ddim import DDIMSampler
13 | from ldm.models.diffusion.ddpm import LatentDiffusion
14 | from PIL import Image
15 | from tqdm.auto import tqdm
16 |
17 | from msbd.msbd_utils import get_dilated_mask, get_repaint_schedule
18 |
19 | logger = logging.getLogger()
20 |
21 |
22 | class BlendedDiffusionSampler(DDIMSampler):
23 | def __init__(self, model: LatentDiffusion, schedule="linear"): # this just adds the typehint for easier coding
24 | super().__init__(model, schedule)
25 |
26 | @torch.no_grad()
27 | def blended_diffusion_sampling(
28 | self,
29 | source_img, # LATENT, not pixels
30 | mask,
31 | num_ddim_steps,
32 | batch_size,
33 | shape,
34 | conditioning=None,
35 | callback=None,
36 | img_callback=None,
37 | quantize_x0=False,
38 | eta=0.,
39 | x0=None,
40 | temperature=1.,
41 | noise_dropout=0.,
42 | score_corrector=None,
43 | corrector_kwargs=None,
44 | verbose=True,
45 | x_T=None,
46 | log_every_t=100,
47 | unconditional_guidance_scale=1.,
48 | unconditional_conditioning=None,
49 | dilate_mask=False,
50 | repaint_steps=4,
51 | repaint_jump=1,
52 | start_timestep=1.0,
53 | # `unconditional_conditioning` has to come in the same format as the conditioning, e.g. as encoded tokens
54 | ):
55 | if conditioning is not None:
56 | if isinstance(conditioning, dict):
57 | cbs = conditioning[list(conditioning.keys())[0]].shape[0]
58 | if cbs != batch_size:
59 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
60 | else:
61 | if conditioning.shape[0] != batch_size:
62 | print(
63 | f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
64 |
65 | self.make_schedule(ddim_num_steps=num_ddim_steps, ddim_eta=eta, verbose=verbose)
66 | # sampling
67 | C, H, W = shape
68 | size = (batch_size, C, H, W)
69 | print(f'Data shape for DDIM sampling is {size}, eta {eta}')
70 |
71 | samples, intermediates = self.blended_ddim_sampling(
72 | conditioning, size,
73 | source_img=source_img,
74 | callback=callback,
75 | img_callback=img_callback,
76 | quantize_denoised=quantize_x0,
77 | mask=mask, x0=x0,
78 | ddim_use_original_steps=False,
79 | noise_dropout=noise_dropout,
80 | temperature=temperature,
81 | score_corrector=score_corrector,
82 | corrector_kwargs=corrector_kwargs,
83 | x_T=x_T,
84 | log_every_t=log_every_t,
85 | unconditional_guidance_scale=unconditional_guidance_scale,
86 | unconditional_conditioning=unconditional_conditioning,
87 | dilate_mask=dilate_mask,
88 | repaint_steps=repaint_steps,
89 | repaint_jump=repaint_jump,
90 | start_timestep=start_timestep
91 | )
92 | return samples, intermediates
93 |
94 | @torch.no_grad()
95 | def decode_and_save_latent(self, latent: torch.Tensor, fp: str):
96 | """
97 | Function for debugging to decode and save any input image latent tensor.
98 | """
99 | if len(latent.shape) == 3:
100 | latent = latent[None]
101 | if latent.device != self.model.device:
102 | latent = latent.to(self.model.device)
103 | decoded = self.model.decode_first_stage(latent)[0]
104 |
105 | decoded = torch.clamp((decoded+1.0)/2.0, min=0.0, max=1.0).cpu().numpy()
106 | decoded = 255. * rearrange(decoded, 'c h w -> h w c')
107 | Image.fromarray(decoded.astype(np.uint8)).save(fp)
108 | logging.info(f'Saved sample to {fp}')
109 |
110 | def q_sample_start_end(self, x_t_start: torch.Tensor, t_end: int, t_start: int = 0, noise: torch.Tensor = None) -> torch.Tensor:
111 | """
112 | Samples the forward diffusion process from any intermediate starting point, i.e. x_t_end ~ q(x_t_end | x_t_start).
113 |
114 | Uses the forward process as defined in DDPM, which should also be compatible with DDIM sampling.
115 | """
116 | if noise is None:
117 | noise = torch.randn_like(x_t_start)
118 |
119 | alpha_cumprod_start_end = self.model.alphas_cumprod[t_end] / \
120 | self.model.alphas_cumprod[t_start]
121 | return torch.sqrt(alpha_cumprod_start_end) * x_t_start + torch.sqrt(1 - alpha_cumprod_start_end) * noise
122 |
123 | @torch.no_grad()
124 | def blended_ddim_sampling(
125 | self, cond, shape, source_img,
126 | x_T=None, ddim_use_original_steps=False,
127 | callback=None, timesteps=None, quantize_denoised=False,
128 | mask=None, x0=None, img_callback=None, log_every_t=100,
129 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
130 | unconditional_guidance_scale=1., unconditional_conditioning=None,
131 | dilate_mask=False, repaint_steps=0, repaint_jump=1, start_timestep=1.0):
132 | """
133 | `source_img` is the image that we will edit so that the content inside `mask` matches `cond`.
134 | This is done as follows:
135 |
136 | 1. start from Gaussian noise: img_T ~ N(0, I)
137 | 2. edit_img = denoise(img_t, cond, t)
138 | 3. noised_source_img = forward_process(source_img, t)
139 | 4. img_{t-1} = mask * edit_img + (1 - mask) * noised_source_img
140 | 5. goto 2 until t=0
141 |
142 | `start_timestep` is the ratio of where to start in the diffusion process as in SDEdit.
143 | 1.0 means doing the full diffusion starting from Gaussian noise.
144 | """
145 | device = self.model.betas.device
146 | b = shape[0]
147 |
148 | mask = mask.to(device)
149 |
150 | if timesteps is None:
151 | timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
152 | elif timesteps is not None and not ddim_use_original_steps:
153 | subset_end = int(
154 | min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
155 | timesteps = self.ddim_timesteps[:subset_end]
156 |
157 | if x_T is not None:
158 | edit_img = x_T
159 | logger.warning('starting diffusion from preset x_T, ignoring sdedit')
160 | else:
161 | if 0.0 < start_timestep < 1.0 and int(len(self.ddim_timesteps) * start_timestep): # second condition ensures at least one timestep
162 | #SDEdit
163 | if source_img.shape[0] == 1:
164 | # if a single image is given this function will make b different samples
165 | source_img = source_img.repeat(b, 1, 1, 1)
166 |
167 | timesteps = timesteps[:int(len(timesteps) * start_timestep)]
168 | logger.info(f'Using SDEdit ratio {start_timestep}, starting at {timesteps[-1]}/1000')
169 | edit_img = self.q_sample_start_end(source_img, t_end=timesteps[-1], t_start=0)
170 | elif start_timestep == 1.0:
171 | edit_img = torch.randn(shape, device=device)
172 | elif start_timestep == 0.0 or not int(len(self.ddim_timesteps) * start_timestep):
173 | logger.warning('Start timestep is 0.0, or rounded down to zero, returning original image')
174 | return source_img, [source_img]
175 |
176 |
177 | intermediates = {'x_inter': [edit_img], 'pred_x0': [edit_img]}
178 | time_range = reversed(
179 | range(0, timesteps)) if ddim_use_original_steps else np.flip(timesteps)
180 |
181 | total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
182 | print(
183 | f"Running blended DDIM Sampling with {total_steps} timesteps with {repaint_steps} repaint steps and {repaint_jump} repaint jumps")
184 |
185 | if repaint_steps:
186 | # overwrites timerange with a zigzag line
187 | index_schedule = get_repaint_schedule(
188 | np.arange(len(time_range) - 1, -1, -1), repaint_steps, repaint_jump)
189 | time_range = get_repaint_schedule(time_range, repaint_steps, repaint_jump)
190 | logger.info(f'repaint schedule: {time_range}')
191 | else:
192 | index_schedule = np.arange(len(time_range) - 1, -1, -1)
193 |
194 | for i, step in enumerate(tqdm(time_range, desc='DDIM Sampler', total=len(time_range))):
195 | index = index_schedule[i]
196 | ts = torch.full((b,), step, device=device, dtype=torch.long)
197 |
198 | # ts is the timestep in the original 1000 step DDPM model
199 | # index is the timestep for the DDIM sampler using only as many timesteps as specified
200 | edit_img, pred_x0 = self.p_sample_ddim(edit_img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
201 | quantize_denoised=quantize_denoised, temperature=temperature,
202 | noise_dropout=noise_dropout, score_corrector=score_corrector,
203 | corrector_kwargs=corrector_kwargs,
204 | unconditional_guidance_scale=unconditional_guidance_scale,
205 | unconditional_conditioning=unconditional_conditioning)
206 |
207 | source_img_noised = self.model.q_sample(source_img, ts)
208 | # maybe this should not be ts, but the next step?
209 |
210 | mask_expanded = self.get_mask(mask, dilate_mask, shape, total_steps, i)
211 | edit_img = (1 - mask_expanded) * source_img_noised + mask_expanded * edit_img
212 |
213 | if repaint_steps and i > 0 and i < len(time_range) - 1:
214 | # if doing repaint, jump back if the next element in the schedule is noisier than the current one
215 | # this is the noise level/timestep AFTER the most recent denoising step
216 | t_current = timesteps[index_schedule[i] - 1]
217 | # this is the noise level/timestep we need for the next denoising step
218 | t_next_step = timesteps[index_schedule[i + 1]]
219 | if t_next_step > t_current:
220 | edit_img = self.q_sample_start_end(
221 | edit_img, t_end=t_next_step, t_start=t_current)
222 |
223 | if callback:
224 | callback(i)
225 | if img_callback:
226 | img_callback(pred_x0, i)
227 |
228 | # logger.warn('doing a lot of logging of intermediates!')
229 | if index % log_every_t == 0 or index == total_steps - 1:
230 | intermediates['x_inter'].append(edit_img)
231 | intermediates['pred_x0'].append(pred_x0)
232 |
233 | return edit_img, intermediates
234 |
235 | def get_mask(self, mask, dilate_mask, shape, total_steps, current_step):
236 | """Implements mask dilation: the mask is extended at the start of the diffusion process and gradually shrunk back to its original size.
237 | Not really used for our multi-stage blended diffusion.
238 | """
239 | if dilate_mask:
240 | sampling_progress_ratio = current_step / total_steps
241 | if sampling_progress_ratio < 0.25:
242 | kernel_size = 7
243 | elif sampling_progress_ratio < 0.5:
244 | kernel_size = 5
245 | elif sampling_progress_ratio < 0.75:
246 | kernel_size = 3
247 | else:
248 | kernel_size = 1
249 |
250 | mask_dilated = get_dilated_mask(mask, kernel_size)
251 | mask_expanded = mask_dilated.expand(*shape)
252 | else:
253 | mask_expanded = mask.expand(*shape)
254 | return mask_expanded
255 |
--------------------------------------------------------------------------------
/multi-scale-blended-diffusion/msbd/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pfnet-research/multi-stage-blended-diffusion/b933b3d395edcbe266dbb51abf8f45d3b4b66f81/multi-scale-blended-diffusion/msbd/__init__.py
--------------------------------------------------------------------------------
/multi-scale-blended-diffusion/msbd/msbd_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Some utils for blended latent diffusion.
3 |
4 | I just wanted to keep them separate from the utils.py in ldm and such.
5 |
6 | @author Johannes Ackermann
7 | """
8 |
9 | import logging
10 | import os
11 | from typing import Tuple
12 |
13 | import numpy as np
14 | import torch
15 | from einops import rearrange
16 | from PIL import Image
17 | from torchvision.utils import make_grid
18 | import kornia
19 |
20 | logger = logging.getLogger()
21 |
22 |
23 | def get_dilated_mask(
24 | mask: torch.Tensor,
25 | dilation_kernel_size: int,
26 | ) -> torch.Tensor:
27 | """
28 | Dilates (extends) the mask by convolution with a kernel of size `dilation_kernel_size` x `dilation_kernel_size`.
29 | If applied to the downsampled mask (as is done in the paper), this should be done with kernel sizes 7->5->3->1 for
30 | equal quarters of the diffusion process (dilation with size 1 leaves the mask unchanged).
31 | """
32 | kernel = torch.ones([dilation_kernel_size, dilation_kernel_size],
33 | dtype=mask.dtype).to(mask.device)
34 | if len(mask.shape) == 2:
35 | mask = mask[None][None]
36 | dilated_mask = torch.nn.functional.conv2d(mask, kernel[None][None], padding='same')
37 | dilated_mask[dilated_mask >= 1.0] = 1.0
38 | return dilated_mask
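# Usage sketch (illustrative values, not from the original code): a single "on" pixel
# grows into a 3x3 block when dilated with kernel size 3.
#   m = torch.zeros(1, 1, 5, 5)
#   m[0, 0, 2, 2] = 1.0
#   get_dilated_mask(m, 3)[0, 0]  # ones in rows/cols 1..3, zeros elsewhere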
39 |
40 |
41 | def get_repaint_schedule(original_schedule: np.ndarray, repaint_steps: int, repaint_jump: int):
42 | """
43 | Generates a schedule as in RePaint: takes a given denoising schedule and repeats it
44 | in such a way that each set of `repaint_jump` denoising steps is repeated `repaint_steps` times.
45 | For example, with 10 steps in total, repaint_steps = 1 and repaint_jump = 2:
46 | 10-9-8-10-9-8-7-6-5-7-6-5-4-3-5-4-3-5-4-3...
47 |
48 | This code became pretty ugly, but essentially it's just zigzagging through `original_schedule`.
49 | """
50 | # i thought this could be done with repeat(reshape()), but not quite :'(
51 |
52 | if repaint_jump == 0:
53 | schedule = np.repeat(original_schedule, repaint_steps + 1)
54 | else:
55 | n_step_orig = len(original_schedule)
56 | schedule = [original_schedule[:repaint_jump - 1]]
57 | for idx_jump_level in range(1, n_step_orig // repaint_jump):
58 | for idx_rep in range(repaint_steps):
59 | if idx_rep == repaint_steps - 1:
60 | schedule.append(
61 | original_schedule[idx_jump_level * (repaint_jump) - 1:(idx_jump_level + 1) * (repaint_jump) - 1])
62 | else:
63 | schedule.append(
64 | original_schedule[idx_jump_level * (repaint_jump) - 1:(idx_jump_level + 1) * (repaint_jump)])
65 | schedule.append(original_schedule[repaint_jump * (n_step_orig // repaint_jump) - 1:])
66 | schedule = np.concatenate(schedule)
67 | return schedule
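# Usage sketch (illustrative, matching the repaint_jump == 0 branch above):
#   get_repaint_schedule(np.array([9, 8, 7]), repaint_steps=1, repaint_jump=0)
#   # -> array([9, 9, 8, 8, 7, 7]); each step is visited repaint_steps + 1 times.
# With repaint_jump > 0 the schedule instead zigzags through blocks of `repaint_jump`
# steps, as in the docstring example.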
68 |
69 |
70 | def tensor_to_pil(tensor: torch.Tensor):
71 | # if not tensor.min() < 0.0:
72 | # logger.warn('Image should be scaled to [-1.0, 1.0]')
73 | tensor = tensor.cpu().numpy()[0]
74 | tensor = (tensor + 1.0) / 2.0
75 | tensor = 255. * rearrange(tensor, 'c h w -> h w c')
76 | return Image.fromarray(tensor.astype(np.uint8))
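# Added note: expects a (B, C, H, W) tensor scaled to [-1.0, 1.0]; only the first
# batch element is converted to a PIL image.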
77 |
78 |
79 | def get_alpha_masks(crops_x: Tuple[Tuple[int, int], ...], crops_y: Tuple[Tuple[int, int], ...], target_imagesize: Tuple[int, int], overlap: int):
80 | """
81 | Generates the alpha masks later used for blending.
82 | """
83 | alpha_mask_full = np.zeros([len(crops_x)] + list(target_imagesize))
84 | for idx, (crop_x, crop_y) in enumerate(zip(crops_x, crops_y)):
85 | alpha_mask_full[idx, crop_x[0]:crop_x[1], crop_y[0]:crop_y[1]] = 1.0
86 | alpha_func = 1.0 / (1 + np.exp(-np.linspace(-5.0, 5.0, overlap // 2))) # sigmoid blending
87 | alpha_mask_x = np.tile(alpha_func, [crop_y[1] - crop_y[0], 1]).T
88 | alpha_mask_y = np.tile(alpha_func, [crop_x[1] - crop_x[0], 1])
89 | if not crop_x[0] == 0:
90 | alpha_mask_full[idx, crop_x[0]:crop_x[0] + overlap // 2, crop_y[0]:crop_y[1]
91 | ] = alpha_mask_full[idx, crop_x[0]:crop_x[0] + overlap // 2, crop_y[0]:crop_y[1]] * alpha_mask_x
92 | if not crop_y[0] == 0:
93 | alpha_mask_full[idx, crop_x[0]:crop_x[1], crop_y[0]:crop_y[0] + overlap //
94 | 2] = alpha_mask_full[idx, crop_x[0]:crop_x[1], crop_y[0]:crop_y[0] + overlap // 2] * alpha_mask_y
95 | if not crop_x[1] == target_imagesize[0]:
96 | alpha_mask_full[idx, crop_x[1] - overlap // 2:crop_x[1], crop_y[0]:crop_y[1]] = alpha_mask_full[idx,
97 | crop_x[1] - overlap // 2:crop_x[1], crop_y[0]:crop_y[1]] * alpha_mask_x[::-1, :]
98 | if not crop_y[1] == target_imagesize[1]:
99 | alpha_mask_full[idx, crop_x[0]:crop_x[1], crop_y[1] - overlap // 2:crop_y[1]] = alpha_mask_full[idx,
100 | crop_x[0]:crop_x[1], crop_y[1]-overlap // 2:crop_y[1]] * alpha_mask_y[:, ::-1]
101 |
102 | alpha_mask_seg = []
103 | for idx, (crop_x, crop_y) in enumerate(zip(crops_x, crops_y)):
104 | alpha_mask_seg.append(alpha_mask_full[idx, crop_x[0]:crop_x[1], crop_y[0]:crop_y[1]])
105 |
106 | return alpha_mask_seg
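# Added sketch of the blending profile: inside each crop, alpha ramps from 0 to 1 over
# overlap // 2 pixels following sigmoid(linspace(-5, 5, overlap // 2)), but only at edges
# that border another crop; edges touching the image boundary keep alpha = 1. A
# hypothetical call for two vertically overlapping 512px crops could look like:
#   get_alpha_masks(crops_x=[(0, 512), (0, 512)], crops_y=[(0, 512), (384, 896)],
#                   target_imagesize=(512, 896), overlap=128)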
107 |
108 |
109 | def get_result_grid(source_img, all_samples, mask) -> Image.Image:
110 | """
111 | Saves a grid of images visualizing the original image, the mask and multiple samples
112 | """
113 | mask = mask.cpu()
114 | h = all_samples[0].shape[2]
115 | w = all_samples[0].shape[3]
116 | source_img_mask_vis = (0.8 * source_img.clone().cpu() + 1.0) / 2.0
117 | source_img_mask_vis[0, 1] += mask[0, 0] * 0.2
118 | source_img_mask_vis = source_img_mask_vis.clamp(-1.0, 1.0)
119 | source_img = torch.nn.functional.interpolate(source_img, size=[h, w])
120 | source_img_mask_vis = torch.nn.functional.interpolate(source_img_mask_vis, size=[h, w])
121 | # additionally, save as grid
122 | vis_rows = []
123 | for sample_row in all_samples:
124 | vis_rows.append(torch.cat([(source_img.cpu() + 1.0) / 2.0,
125 | source_img_mask_vis, sample_row.cpu()], 0))
126 |
127 | grid = torch.stack(vis_rows, 0)
128 |
129 | grid = rearrange(grid, 'n b c h w -> (n b) c h w')
130 | grid = make_grid(grid, nrow=(len(all_samples[0]) + 2))
131 |
132 | # to image
133 | grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
134 |
135 | pil_result = Image.fromarray(grid.astype(np.uint8))
136 | return pil_result
137 |
138 |
139 | def debug_save_img(image: torch.Tensor, name='debug'):
140 | assert len(image.shape) == 4
141 | # if not image.min() < 0.0:
142 | # logger.warn('Image should be scaled to [-1.0, 1.0]')
143 | image = torch.clamp((image[0] + 1.0) / 2.0, min=0.0, max=1.0)
144 | image = 255.0 * rearrange(image.cpu().numpy(), 'c h w -> h w c')
145 | Image.fromarray(image.astype(np.uint8)).save(f'{name}.png') # debugging
146 | logger.info(f'saved debug output to {name}.png')
147 |
148 |
149 | def is_notebook() -> bool:
150 | """
151 | Based on
152 | https://stackoverflow.com/questions/15411967/how-can-i-check-if-code-is-executed-in-the-ipython-notebook
153 | """
154 | try:
155 | from IPython import get_ipython
156 | shell = get_ipython().__class__.__name__
157 | if shell == 'ZMQInteractiveShell':
158 | return True # Jupyter notebook or qtconsole
159 | elif shell == 'TerminalInteractiveShell':
160 | return False # Terminal running IPython
161 | else:
162 | return False # Other type (?)
163 | except NameError:
164 | return False # Probably standard Python interpreter
165 |
166 |
167 | def display_or_save(pil_img: Image, folder: str, name: str) -> None:
168 | """
169 | If running in a Jupyter notebook, uses `display` to show the image; otherwise
170 | saves the image to the given folder under the given name.
171 | """
172 | if is_notebook():
173 | from IPython.display import display
174 | print(name)
175 | display(pil_img)
176 | else:
177 | if folder is not None:
178 | fp = os.path.join(folder, name + '.png')
179 | pil_img.save(fp)
180 | logger.info(f'saved image to {fp}')
181 |
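# Added note: unsharp masking sharpens by adding back the difference between the image
# and a Gaussian-blurred copy; `kernel_size` and `sigma` control that blur (kornia's
# `unsharp_mask` is used below and the result is clamped back to [-1, 1]).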
182 | def sharpen(image: torch.Tensor, unsharpen_masking: bool = True, kernel_size: int = 11, sigma: float = 7.0) -> torch.Tensor:
183 | if not unsharpen_masking:
184 | raise NotImplementedError('Only unsharp masking is supported currently.')
185 | else:
186 | logger.info(f'sharpening with sigma {sigma}')
187 | return kornia.filters.unsharp_mask(image, (kernel_size, kernel_size), (sigma, sigma)).clamp(-1.0, 1.0)
188 |
189 |
--------------------------------------------------------------------------------
/multi-scale-blended-diffusion/multi_scale_blended_diffusion.py:
--------------------------------------------------------------------------------
1 | """
2 | This module can be used to batch process the files in a given input folder, along with a txt file containing the prompts and filenames.
3 | See inputs/inputs.txt for an example of the samples we use in our publication.
4 |
5 | For interactive use we recommend to use the notebook `InteractiveEditing.ipynb`.
6 |
7 | @author: Johannes Ackermann
8 | """
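# Format of the input list (as parsed in main() below): one entry per line,
#   <prompt>;<image file relative to the list>;<margin multiplier>
# Lines starting with '#' are skipped, and each image is expected to have a
# corresponding mask named '<image>_mask.png' next to it.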
9 |
10 | import argparse
11 | import sys
12 | import os
13 | import logging
14 |
15 | from PIL import Image
16 |
17 | sys.path.append(os.getcwd())
18 | from msbd.MSBDGenerator import MSBDGenerator
19 |
20 | logger = logging.getLogger()
21 | logging.getLogger().setLevel(logging.INFO)
22 |
23 |
24 | def main(args):
25 |
26 | if args.input_list is not None and not args.input_list == 'None':
27 | prompts = []
28 | image_fps = []
29 | margin_mults = []
30 | with open(args.input_list, 'r') as f:
31 | lines = f.readlines()
32 | print('Reading from ', args.input_list)
33 | for line in lines:
34 | line = line.strip() # remove newline
35 | if line.startswith('#'):
36 | print('skipping line ', line)
37 | continue
38 | prompts.append(line.split(';')[0])
39 | margin_mults.append(float((line.split(';')[2])))
40 | image_fp = os.path.join(
41 | os.path.dirname(args.input_list),
42 | line.split(';')[1].replace(' ', '')
43 | )
44 | assert os.path.exists(image_fp), f'could not find image {image_fp}'
45 | mask_fp = os.path.splitext(image_fp)[0] + '_mask.png'
46 | assert os.path.exists(mask_fp), f'could not find mask {mask_fp}'
47 | image_fps.append(image_fp)
48 | print(image_fps[-1], prompts[-1], margin_mults[-1])
49 | else:
50 | image_fps = ['inputs/marunouchi.png']
51 | assert os.path.exists(image_fps[0])
52 | prompts = ['Statue of Roman Emperor, Canon 5D Mark 3, 35mm, flickr']
53 | margin_mults = [1.2]
54 |
55 | generator = MSBDGenerator(
56 | use_fp16=args.fp16,
57 | stable_diffusion=args.use_stablediffusion,
58 | max_edgelen=args.max_edgelen,
59 | first_stage_batchsize=args.first_stage_batch
60 | )
61 |
62 | for prompt, image_fp, margin_mult in zip(prompts, image_fps, margin_mults):
63 | result = generator.multi_scale_generation(
64 | pil_img = Image.open(image_fp).convert('RGB'),
65 | pil_mask = Image.open(os.path.splitext(image_fp)[0] + '_mask.png'),
66 | prompt=prompt,
67 | ddim_steps=args.ddim_steps,
68 | decoder_optimization=args.decoder_optimization,
69 | clip_reranking=args.clip_reranking,
70 | margin=margin_mult,
71 | seed=args.seed,
72 | repaint_steps=args.repaint_steps,
73 | start_timestep=args.start_timestep,
74 | upscaling_start_step=args.upscale_startstep,
75 | upscaling_mode=args.upscaling_mode,
76 | straight_to_grid=args.straight_to_grid,
77 | grid_upscaling_start_step=args.grid_startstep,
78 | log_folder=args.outdir,
79 | lowpass_reference=args.lowpass_reference,
80 | blended_upscale=args.blended_upscale,
81 | conditional_upscale=args.conditional_upscale,
82 | grid_overlap=args.grid_overlap,
83 | first_stage_size=args.first_stage_size
84 | )
85 | out_fp = os.path.splitext(image_fp)[0] + '_output.jpg'
86 | result.save(out_fp)
87 | print(f'Output saved to {out_fp}.')
88 |
89 |
90 | if __name__ == "__main__":
91 | parser = argparse.ArgumentParser()
92 | parser.add_argument("--input-list", type=str, nargs="?",
93 | default='inputs/inputs.txt', help="path to a list of prompts and file_paths ")
94 | def str2bool(v):
95 | # from https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
96 | if isinstance(v, bool):
97 | return v
98 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
99 | return True
100 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
101 | return False
102 | else:
103 | raise argparse.ArgumentTypeError('Boolean value expected.')
104 |
105 |
106 | parser.add_argument("--prompt", type=str, nargs="?",
107 | default='Oil painting of Mt. Fuji, by Paul Sandby', help="the prompt to render")
108 | parser.add_argument("--outdir", type=str, nargs="?",
109 | help="dir to write results to", default="outputs")
110 | parser.add_argument("--ddim_steps", type=int, default=50,
111 | help="number of ddim sampling steps",)
112 |
113 | # Added to original LDM
114 | parser.add_argument("--fp16", type=str2bool, default=True,
115 | help="run inference in mixed precision",)
116 | parser.add_argument("--use-stablediffusion", type=str2bool, default=True,
117 | help="load the stable diffusion model, if False uses the LDM `text2img-large` model instead",)
118 |
119 | # decoder optimization
120 | parser.add_argument("--decoder-optimization", type=str2bool, default=True,
121 | help="optimize the weights of the decoder for each image",)
122 | parser.add_argument("--decoderopt-it", type=int, default=100,
123 | help="iterations of decoder finetuning, ignored if not using decoder optimization",)
124 |
125 | parser.add_argument("--dilate-mask", type=str2bool, default=False,
126 | help="dilates the mask and shrinks it back to its original size over the diffusion timesteps; use for small masks or masks with fine details",)
127 |
128 | parser.add_argument("--seed", type=int, default=-1,
129 | help="seed for everything, set to -1 to seed randomly",)
130 |
131 | # repaint
132 | parser.add_argument("--repaint-steps", type=int, default=5,
133 | help="repetitions of each denoising step for repainting, set to 0 to disable repaint",)
134 | parser.add_argument("--repaint-jump", type=int, default=0,
135 | help="jump size for the repaint schedule, in DDIM steps (not DDPM steps); default=0, i.e. the current step is simply repeated `repaint-steps` times",)
136 |
137 | parser.add_argument("--start-timestep", type=float, default=1.0,
138 | help="SDEdit-like relative timestep to start the diffusion process from, i.e. 1.0 to start from pure noise, 0.5 = T/2",)
139 | parser.add_argument("--upscale-startstep", type=float, default=0.4,
140 | help="start step for upscaling stages except final gridlike upscaling stage",)
141 | parser.add_argument("--grid-startstep", type=float, default=0.25,
142 | help="start step for upscaling in the grid stage",)
143 | parser.add_argument("--clip-reranking", type=str2bool, default=True,
144 | help="re-rank first-stage outputs by CLIP similarity.",)
145 | parser.add_argument("--upscaling-mode", type=str, default='esrgan',
146 | help="upscaling mode: 'esrgan', 'sharpen', 'bilinear', or 'bicubic' (or any mode supported by torch.nn.functional.interpolate)",)
147 | parser.add_argument("--straight-to-grid", type=str2bool, default=False,
148 | help="after the first stage immediately go to the grid stage without intermediate steps",)
149 | parser.add_argument("--lowpass-reference", type=str, default='matching',
150 | help="low-pass reference mode: 'matching', 'half', or 'no'.",)
151 | parser.add_argument("--conditional-upscale", type=str2bool, default=True,
152 | help="do the upscaling with text conditioning",)
153 | parser.add_argument("--blended-upscale", type=str2bool, default=True,
154 | help="do the upscaling with a reference image",)
155 | parser.add_argument("--grid-overlap", type=int, default=128,
156 | help="overlap between different grid regions in pixels. Must be multiple of 64.",)
157 | parser.add_argument("--max-edgelen", type=int, default=12 * 64,
158 | help="Maximum edge length of square images processable by the used GPU. Default value requires 25GB of VRAM.",)
159 | parser.add_argument("--first-stage-size", type=int, default=512,
160 | help="Resolution to be used in the first stage. Default is 512 for stable diffusion, largest possible on V100 is 960",)
161 | parser.add_argument("--first-stage-batch", type=int, default=5,
162 | help="Batch size for the first stage. If clip-reranking is enabled, the image with the highest CLIP similarity is chosen; otherwise the first one is used in subsequent stages",)
163 |
164 | arglist = parser.parse_args()
165 | main(arglist)
166 |
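# Example invocation (a sketch; the flags are defined above, the values are illustrative).
# Run from the directory containing msbd/, since the script appends the CWD to sys.path:
#   python multi_scale_blended_diffusion.py --input-list inputs/inputs.txt \
#       --outdir outputs --repaint-steps 5 --start-timestep 1.0 --upscaling-mode esrgan
# Results are written next to each input image as '<image>_output.jpg'.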
--------------------------------------------------------------------------------
/overview.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pfnet-research/multi-stage-blended-diffusion/b933b3d395edcbe266dbb51abf8f45d3b4b66f81/overview.jpg
--------------------------------------------------------------------------------