├── .DS_Store
├── README.md
├── __pycache__
│   ├── config.cpython-38.pyc
│   └── tutorial_dataset_ade.cpython-38.pyc
├── color150.mat
├── configs
│   ├── autoencoder
│   │   ├── autoencoder_kl_16x16x16.yaml
│   │   ├── autoencoder_kl_32x32x4.yaml
│   │   ├── autoencoder_kl_64x64x3.yaml
│   │   └── autoencoder_kl_8x8x64.yaml
│   ├── latent-diffusion
│   │   ├── celebahq-ldm-vq-4.yaml
│   │   ├── cin-ldm-vq-f8.yaml
│   │   ├── cin256-v2.yaml
│   │   ├── ffhq-ldm-vq-4.yaml
│   │   ├── lsun_bedrooms-ldm-vq-4.yaml
│   │   ├── lsun_churches-ldm-kl-8.yaml
│   │   └── txt2img-1p4B-eval.yaml
│   ├── retrieval-augmented-diffusion
│   │   └── 768x768.yaml
│   └── stable-diffusion
│       └── PLACE.yaml
├── dataset.py
├── environment.yaml
├── index.html
├── inference.py
├── ldm
│   ├── .DS_Store
│   ├── __pycache__
│   │   └── util.cpython-38.pyc
│   ├── data
│   │   ├── __init__.py
│   │   └── util.py
│   ├── models
│   │   ├── __pycache__
│   │   │   └── autoencoder.cpython-38.pyc
│   │   ├── autoencoder.py
│   │   └── diffusion
│   │       ├── __init__.py
│   │       ├── __pycache__
│   │       │   ├── __init__.cpython-38.pyc
│   │       │   ├── ddim.cpython-38.pyc
│   │       │   ├── ddpm.cpython-38.pyc
│   │       │   ├── myplms.cpython-38.pyc
│   │       │   ├── plms.cpython-38.pyc
│   │       │   ├── sampling_util.cpython-38.pyc
│   │       │   └── testplms.cpython-38.pyc
│   │       ├── ddim.py
│   │       ├── ddpm.py
│   │       ├── dpm_solver
│   │       │   ├── __init__.py
│   │       │   ├── __pycache__
│   │       │   │   ├── __init__.cpython-38.pyc
│   │       │   │   ├── dpm_solver.cpython-38.pyc
│   │       │   │   └── sampler.cpython-38.pyc
│   │       │   ├── dpm_solver.py
│   │       │   └── sampler.py
│   │       ├── myplms.py
│   │       ├── plms.py
│   │       ├── sampling_util.py
│   │       └── testplms.py
│   ├── modules
│   │   ├── .DS_Store
│   │   ├── __pycache__
│   │   │   ├── attention.cpython-38.pyc
│   │   │   └── ema.cpython-38.pyc
│   │   ├── attention.py
│   │   ├── diffusionmodules
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-38.pyc
│   │   │   │   ├── model.cpython-38.pyc
│   │   │   │   ├── openaimodel.cpython-38.pyc
│   │   │   │   └── util.cpython-38.pyc
│   │   │   ├── model.py
│   │   │   ├── openaimodel.py
│   │   │   ├── upscaling.py
│   │   │   └── util.py
│   │   ├── distributions
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-38.pyc
│   │   │   │   └── distributions.cpython-38.pyc
│   │   │   └── distributions.py
│   │   ├── ema.py
│   │   ├── encoders
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-38.pyc
│   │   │   │   └── modules.cpython-38.pyc
│   │   │   └── modules.py
│   │   ├── image_degradation
│   │   │   ├── __init__.py
│   │   │   ├── bsrgan.py
│   │   │   ├── bsrgan_light.py
│   │   │   ├── utils
│   │   │   │   └── test.png
│   │   │   └── utils_image.py
│   │   └── midas
│   │       ├── __init__.py
│   │       ├── api.py
│   │       ├── midas
│   │       │   ├── __init__.py
│   │       │   ├── base_model.py
│   │       │   ├── blocks.py
│   │       │   ├── dpt_depth.py
│   │       │   ├── midas_net.py
│   │       │   ├── midas_net_custom.py
│   │       │   ├── transforms.py
│   │       │   └── vit.py
│   │       └── utils.py
│   └── util.py
├── resources
│   ├── .DS_Store
│   ├── bibtex.txt
│   ├── ind.png
│   ├── method_diagram.png
│   ├── newfig2_final.pdf
│   ├── od.png
│   ├── overview.png
│   ├── paper.png
│   └── teaser.png
├── run_inference_ADE20K.sh
└── run_inference_COCO.sh
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cszy98/PLACE/4594350491df5c7489eb4edf0802f303692dd530/.DS_Store
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # PLACE: Adaptive Layout-Semantic Fusion for Semantic Image Synthesis (CVPR 2024)
2 |
3 | ### Introduction
4 |
5 | This repository contains the source code for our paper "PLACE: Adaptive Layout-Semantic Fusion for Semantic Image Synthesis" (CVPR 2024).
6 |
7 | [**[Project Page]**](https://cszy98.github.io/PLACE/) [**[Code]**](https://github.com/cszy98/PLACE/tree/main) [**[Paper]**](https://arxiv.org/abs/2403.01852)
8 |
9 | ### Overview
10 |
11 | 
12 |
13 | ### Quick Start
14 |
15 | #### Installation
16 |
17 | ```
18 | git clone https://github.com/cszy98/PLACE.git
19 | cd PLACE
20 | conda env create -f environment.yaml
21 | conda activate PLACE
22 | ```
23 |
24 | #### Data Preparation
25 |
26 | Please follow the dataset preparation process in [FreestyleNet](https://github.com/essunny310/FreestyleNet).
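For reference, the loaders in `dataset.py` build file paths that assume roughly the following layout (inferred from the path strings in `ADE20KDataset` and `COCODataset`; the exact folder names come from the FreestyleNet preparation, so adjust if yours differ):

```
<ADE20K data_root>/
├── images/
│   ├── training/          # *.jpg images
│   └── validation/        # *.jpg images
└── annotations/
    ├── training/          # *.png label maps
    └── validation/        # *.png label maps

<COCO-Stuff data_root>/
├── train2017/                         # *.jpg images
├── val2017/                           # *.jpg images
└── stuffthingmaps_trainval2017/
    ├── train2017/                     # *.png label maps
    └── val2017/                       # *.png label maps
```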
27 |
28 | #### Running
29 |
30 | The pre-trained models can be downloaded from [GoogleDrive](https://drive.google.com/drive/folders/1b5pC52hasLwm1gOkc9LmdIyxZjrdlNWC?usp=drive_link) and should be put into the `ckpt` folder.
31 |
32 | After the dataset and pre-trained models are prepared, you may evaluate the model with the following scripts:
33 |
34 | ```
35 | # evaluate on the ADE20K dataset
36 | ./run_inference_ADE20K.sh
37 | # evaluate on the COCO-Stuff dataset
38 | ./run_inference_COCO.sh
39 | ```
40 |
41 | For out-of-distribution synthesis, simply modify the `ADE20K` or `COCO` dictionary in `dataset.py`; see the sketch below.
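For example (an illustrative sketch only, not part of the released code): the class names in these dictionaries are tokenized with the frozen CLIP tokenizer when the dataset is constructed, so swapping the text associated with a label index is enough to change what is synthesized in the regions carrying that label.

```python
# Illustrative edit to the ADE20K dictionary in dataset.py.
# Label 8 is "bed" in the original dictionary; regions annotated with label 8
# will now be synthesized from the replacement prompt instead.
ADE20K = {
    1: "wall",
    2: "building edifice",
    3: "sky",
    # ...
    8: "sofa made of leaves",   # originally: 8 : "bed"
    # ... remaining entries unchanged
}
```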
42 |
43 | ### Citation
44 |
45 | ```
46 | @inproceedings{lv2024place,
47 | title={PLACE: Adaptive Layout-Semantic Fusion for Semantic Image Synthesis},
48 | author={Lv, Zhengyao and Wei, Yuxiang and Zuo, Wangmeng and Wong, Kwan-Yee K.},
49 | booktitle={IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
50 | year={2024}
51 | }
52 | ```
53 |
54 | ### Contact
55 |
56 | For questions, please email cszy98@gmail.com.
57 |
--------------------------------------------------------------------------------
/__pycache__/config.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cszy98/PLACE/4594350491df5c7489eb4edf0802f303692dd530/__pycache__/config.cpython-38.pyc
--------------------------------------------------------------------------------
/__pycache__/tutorial_dataset_ade.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cszy98/PLACE/4594350491df5c7489eb4edf0802f303692dd530/__pycache__/tutorial_dataset_ade.cpython-38.pyc
--------------------------------------------------------------------------------
/color150.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cszy98/PLACE/4594350491df5c7489eb4edf0802f303692dd530/color150.mat
--------------------------------------------------------------------------------
/configs/autoencoder/autoencoder_kl_16x16x16.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-6
3 | target: ldm.models.autoencoder.AutoencoderKL
4 | params:
5 | monitor: "val/rec_loss"
6 | embed_dim: 16
7 | lossconfig:
8 | target: ldm.modules.losses.LPIPSWithDiscriminator
9 | params:
10 | disc_start: 50001
11 | kl_weight: 0.000001
12 | disc_weight: 0.5
13 |
14 | ddconfig:
15 | double_z: True
16 | z_channels: 16
17 | resolution: 256
18 | in_channels: 3
19 | out_ch: 3
20 | ch: 128
21 | ch_mult: [ 1,1,2,2,4] # num_down = len(ch_mult)-1
22 | num_res_blocks: 2
23 | attn_resolutions: [16]
24 | dropout: 0.0
25 |
26 |
27 | data:
28 | target: main.DataModuleFromConfig
29 | params:
30 | batch_size: 12
31 | wrap: True
32 | train:
33 | target: ldm.data.imagenet.ImageNetSRTrain
34 | params:
35 | size: 256
36 | degradation: pil_nearest
37 | validation:
38 | target: ldm.data.imagenet.ImageNetSRValidation
39 | params:
40 | size: 256
41 | degradation: pil_nearest
42 |
43 | lightning:
44 | callbacks:
45 | image_logger:
46 | target: main.ImageLogger
47 | params:
48 | batch_frequency: 1000
49 | max_images: 8
50 | increase_log_steps: True
51 |
52 | trainer:
53 | benchmark: True
54 | accumulate_grad_batches: 2
55 |
--------------------------------------------------------------------------------
/configs/autoencoder/autoencoder_kl_32x32x4.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-6
3 | target: ldm.models.autoencoder.AutoencoderKL
4 | params:
5 | monitor: "val/rec_loss"
6 | embed_dim: 4
7 | lossconfig:
8 | target: ldm.modules.losses.LPIPSWithDiscriminator
9 | params:
10 | disc_start: 50001
11 | kl_weight: 0.000001
12 | disc_weight: 0.5
13 |
14 | ddconfig:
15 | double_z: True
16 | z_channels: 4
17 | resolution: 256
18 | in_channels: 3
19 | out_ch: 3
20 | ch: 128
21 | ch_mult: [ 1,2,4,4 ] # num_down = len(ch_mult)-1
22 | num_res_blocks: 2
23 | attn_resolutions: [ ]
24 | dropout: 0.0
25 |
26 | data:
27 | target: main.DataModuleFromConfig
28 | params:
29 | batch_size: 12
30 | wrap: True
31 | train:
32 | target: ldm.data.imagenet.ImageNetSRTrain
33 | params:
34 | size: 256
35 | degradation: pil_nearest
36 | validation:
37 | target: ldm.data.imagenet.ImageNetSRValidation
38 | params:
39 | size: 256
40 | degradation: pil_nearest
41 |
42 | lightning:
43 | callbacks:
44 | image_logger:
45 | target: main.ImageLogger
46 | params:
47 | batch_frequency: 1000
48 | max_images: 8
49 | increase_log_steps: True
50 |
51 | trainer:
52 | benchmark: True
53 | accumulate_grad_batches: 2
54 |
--------------------------------------------------------------------------------
/configs/autoencoder/autoencoder_kl_64x64x3.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-6
3 | target: ldm.models.autoencoder.AutoencoderKL
4 | params:
5 | monitor: "val/rec_loss"
6 | embed_dim: 3
7 | lossconfig:
8 | target: ldm.modules.losses.LPIPSWithDiscriminator
9 | params:
10 | disc_start: 50001
11 | kl_weight: 0.000001
12 | disc_weight: 0.5
13 |
14 | ddconfig:
15 | double_z: True
16 | z_channels: 3
17 | resolution: 256
18 | in_channels: 3
19 | out_ch: 3
20 | ch: 128
21 | ch_mult: [ 1,2,4 ] # num_down = len(ch_mult)-1
22 | num_res_blocks: 2
23 | attn_resolutions: [ ]
24 | dropout: 0.0
25 |
26 |
27 | data:
28 | target: main.DataModuleFromConfig
29 | params:
30 | batch_size: 12
31 | wrap: True
32 | train:
33 | target: ldm.data.imagenet.ImageNetSRTrain
34 | params:
35 | size: 256
36 | degradation: pil_nearest
37 | validation:
38 | target: ldm.data.imagenet.ImageNetSRValidation
39 | params:
40 | size: 256
41 | degradation: pil_nearest
42 |
43 | lightning:
44 | callbacks:
45 | image_logger:
46 | target: main.ImageLogger
47 | params:
48 | batch_frequency: 1000
49 | max_images: 8
50 | increase_log_steps: True
51 |
52 | trainer:
53 | benchmark: True
54 | accumulate_grad_batches: 2
55 |
--------------------------------------------------------------------------------
/configs/autoencoder/autoencoder_kl_8x8x64.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-6
3 | target: ldm.models.autoencoder.AutoencoderKL
4 | params:
5 | monitor: "val/rec_loss"
6 | embed_dim: 64
7 | lossconfig:
8 | target: ldm.modules.losses.LPIPSWithDiscriminator
9 | params:
10 | disc_start: 50001
11 | kl_weight: 0.000001
12 | disc_weight: 0.5
13 |
14 | ddconfig:
15 | double_z: True
16 | z_channels: 64
17 | resolution: 256
18 | in_channels: 3
19 | out_ch: 3
20 | ch: 128
21 | ch_mult: [ 1,1,2,2,4,4] # num_down = len(ch_mult)-1
22 | num_res_blocks: 2
23 | attn_resolutions: [16,8]
24 | dropout: 0.0
25 |
26 | data:
27 | target: main.DataModuleFromConfig
28 | params:
29 | batch_size: 12
30 | wrap: True
31 | train:
32 | target: ldm.data.imagenet.ImageNetSRTrain
33 | params:
34 | size: 256
35 | degradation: pil_nearest
36 | validation:
37 | target: ldm.data.imagenet.ImageNetSRValidation
38 | params:
39 | size: 256
40 | degradation: pil_nearest
41 |
42 | lightning:
43 | callbacks:
44 | image_logger:
45 | target: main.ImageLogger
46 | params:
47 | batch_frequency: 1000
48 | max_images: 8
49 | increase_log_steps: True
50 |
51 | trainer:
52 | benchmark: True
53 | accumulate_grad_batches: 2
54 |
--------------------------------------------------------------------------------
/configs/latent-diffusion/celebahq-ldm-vq-4.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 2.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0195
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | image_size: 64
12 | channels: 3
13 | monitor: val/loss_simple_ema
14 |
15 | unet_config:
16 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
17 | params:
18 | image_size: 64
19 | in_channels: 3
20 | out_channels: 3
21 | model_channels: 224
22 | attention_resolutions:
23 | # note: this isn\t actually the resolution but
24 | # the downsampling factor, i.e. this corresnponds to
25 | # attention on spatial resolution 8,16,32, as the
26 | # spatial reolution of the latents is 64 for f4
27 | - 8
28 | - 4
29 | - 2
30 | num_res_blocks: 2
31 | channel_mult:
32 | - 1
33 | - 2
34 | - 3
35 | - 4
36 | num_head_channels: 32
37 | first_stage_config:
38 | target: ldm.models.autoencoder.VQModelInterface
39 | params:
40 | embed_dim: 3
41 | n_embed: 8192
42 | ckpt_path: models/first_stage_models/vq-f4/model.ckpt
43 | ddconfig:
44 | double_z: false
45 | z_channels: 3
46 | resolution: 256
47 | in_channels: 3
48 | out_ch: 3
49 | ch: 128
50 | ch_mult:
51 | - 1
52 | - 2
53 | - 4
54 | num_res_blocks: 2
55 | attn_resolutions: []
56 | dropout: 0.0
57 | lossconfig:
58 | target: torch.nn.Identity
59 | cond_stage_config: __is_unconditional__
60 | data:
61 | target: main.DataModuleFromConfig
62 | params:
63 | batch_size: 48
64 | num_workers: 5
65 | wrap: false
66 | train:
67 | target: taming.data.faceshq.CelebAHQTrain
68 | params:
69 | size: 256
70 | validation:
71 | target: taming.data.faceshq.CelebAHQValidation
72 | params:
73 | size: 256
74 |
75 |
76 | lightning:
77 | callbacks:
78 | image_logger:
79 | target: main.ImageLogger
80 | params:
81 | batch_frequency: 5000
82 | max_images: 8
83 | increase_log_steps: False
84 |
85 | trainer:
86 | benchmark: True
--------------------------------------------------------------------------------
/configs/latent-diffusion/cin-ldm-vq-f8.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0195
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | cond_stage_key: class_label
12 | image_size: 32
13 | channels: 4
14 | cond_stage_trainable: true
15 | conditioning_key: crossattn
16 | monitor: val/loss_simple_ema
17 | unet_config:
18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
19 | params:
20 | image_size: 32
21 | in_channels: 4
22 | out_channels: 4
23 | model_channels: 256
24 | attention_resolutions:
25 | #note: this isn\t actually the resolution but
26 | # the downsampling factor, i.e. this corresnponds to
27 | # attention on spatial resolution 8,16,32, as the
28 | # spatial reolution of the latents is 32 for f8
29 | - 4
30 | - 2
31 | - 1
32 | num_res_blocks: 2
33 | channel_mult:
34 | - 1
35 | - 2
36 | - 4
37 | num_head_channels: 32
38 | use_spatial_transformer: true
39 | transformer_depth: 1
40 | context_dim: 512
41 | first_stage_config:
42 | target: ldm.models.autoencoder.VQModelInterface
43 | params:
44 | embed_dim: 4
45 | n_embed: 16384
46 | ckpt_path: configs/first_stage_models/vq-f8/model.yaml
47 | ddconfig:
48 | double_z: false
49 | z_channels: 4
50 | resolution: 256
51 | in_channels: 3
52 | out_ch: 3
53 | ch: 128
54 | ch_mult:
55 | - 1
56 | - 2
57 | - 2
58 | - 4
59 | num_res_blocks: 2
60 | attn_resolutions:
61 | - 32
62 | dropout: 0.0
63 | lossconfig:
64 | target: torch.nn.Identity
65 | cond_stage_config:
66 | target: ldm.modules.encoders.modules.ClassEmbedder
67 | params:
68 | embed_dim: 512
69 | key: class_label
70 | data:
71 | target: main.DataModuleFromConfig
72 | params:
73 | batch_size: 64
74 | num_workers: 12
75 | wrap: false
76 | train:
77 | target: ldm.data.imagenet.ImageNetTrain
78 | params:
79 | config:
80 | size: 256
81 | validation:
82 | target: ldm.data.imagenet.ImageNetValidation
83 | params:
84 | config:
85 | size: 256
86 |
87 |
88 | lightning:
89 | callbacks:
90 | image_logger:
91 | target: main.ImageLogger
92 | params:
93 | batch_frequency: 5000
94 | max_images: 8
95 | increase_log_steps: False
96 |
97 | trainer:
98 | benchmark: True
--------------------------------------------------------------------------------
/configs/latent-diffusion/cin256-v2.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 0.0001
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0195
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | cond_stage_key: class_label
12 | image_size: 64
13 | channels: 3
14 | cond_stage_trainable: true
15 | conditioning_key: crossattn
16 | monitor: val/loss
17 | use_ema: False
18 |
19 | unet_config:
20 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
21 | params:
22 | image_size: 64
23 | in_channels: 3
24 | out_channels: 3
25 | model_channels: 192
26 | attention_resolutions:
27 | - 8
28 | - 4
29 | - 2
30 | num_res_blocks: 2
31 | channel_mult:
32 | - 1
33 | - 2
34 | - 3
35 | - 5
36 | num_heads: 1
37 | use_spatial_transformer: true
38 | transformer_depth: 1
39 | context_dim: 512
40 |
41 | first_stage_config:
42 | target: ldm.models.autoencoder.VQModelInterface
43 | params:
44 | embed_dim: 3
45 | n_embed: 8192
46 | ddconfig:
47 | double_z: false
48 | z_channels: 3
49 | resolution: 256
50 | in_channels: 3
51 | out_ch: 3
52 | ch: 128
53 | ch_mult:
54 | - 1
55 | - 2
56 | - 4
57 | num_res_blocks: 2
58 | attn_resolutions: []
59 | dropout: 0.0
60 | lossconfig:
61 | target: torch.nn.Identity
62 |
63 | cond_stage_config:
64 | target: ldm.modules.encoders.modules.ClassEmbedder
65 | params:
66 | n_classes: 1001
67 | embed_dim: 512
68 | key: class_label
69 |
--------------------------------------------------------------------------------
/configs/latent-diffusion/ffhq-ldm-vq-4.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 2.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0195
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | image_size: 64
12 | channels: 3
13 | monitor: val/loss_simple_ema
14 | unet_config:
15 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
16 | params:
17 | image_size: 64
18 | in_channels: 3
19 | out_channels: 3
20 | model_channels: 224
21 | attention_resolutions:
22 | # note: this isn\t actually the resolution but
23 | # the downsampling factor, i.e. this corresnponds to
24 | # attention on spatial resolution 8,16,32, as the
25 | # spatial reolution of the latents is 64 for f4
26 | - 8
27 | - 4
28 | - 2
29 | num_res_blocks: 2
30 | channel_mult:
31 | - 1
32 | - 2
33 | - 3
34 | - 4
35 | num_head_channels: 32
36 | first_stage_config:
37 | target: ldm.models.autoencoder.VQModelInterface
38 | params:
39 | embed_dim: 3
40 | n_embed: 8192
41 | ckpt_path: configs/first_stage_models/vq-f4/model.yaml
42 | ddconfig:
43 | double_z: false
44 | z_channels: 3
45 | resolution: 256
46 | in_channels: 3
47 | out_ch: 3
48 | ch: 128
49 | ch_mult:
50 | - 1
51 | - 2
52 | - 4
53 | num_res_blocks: 2
54 | attn_resolutions: []
55 | dropout: 0.0
56 | lossconfig:
57 | target: torch.nn.Identity
58 | cond_stage_config: __is_unconditional__
59 | data:
60 | target: main.DataModuleFromConfig
61 | params:
62 | batch_size: 42
63 | num_workers: 5
64 | wrap: false
65 | train:
66 | target: taming.data.faceshq.FFHQTrain
67 | params:
68 | size: 256
69 | validation:
70 | target: taming.data.faceshq.FFHQValidation
71 | params:
72 | size: 256
73 |
74 |
75 | lightning:
76 | callbacks:
77 | image_logger:
78 | target: main.ImageLogger
79 | params:
80 | batch_frequency: 5000
81 | max_images: 8
82 | increase_log_steps: False
83 |
84 | trainer:
85 | benchmark: True
--------------------------------------------------------------------------------
/configs/latent-diffusion/lsun_bedrooms-ldm-vq-4.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 2.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0195
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | image_size: 64
12 | channels: 3
13 | monitor: val/loss_simple_ema
14 | unet_config:
15 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
16 | params:
17 | image_size: 64
18 | in_channels: 3
19 | out_channels: 3
20 | model_channels: 224
21 | attention_resolutions:
22 | # note: this isn\t actually the resolution but
23 | # the downsampling factor, i.e. this corresnponds to
24 | # attention on spatial resolution 8,16,32, as the
25 | # spatial reolution of the latents is 64 for f4
26 | - 8
27 | - 4
28 | - 2
29 | num_res_blocks: 2
30 | channel_mult:
31 | - 1
32 | - 2
33 | - 3
34 | - 4
35 | num_head_channels: 32
36 | first_stage_config:
37 | target: ldm.models.autoencoder.VQModelInterface
38 | params:
39 | ckpt_path: configs/first_stage_models/vq-f4/model.yaml
40 | embed_dim: 3
41 | n_embed: 8192
42 | ddconfig:
43 | double_z: false
44 | z_channels: 3
45 | resolution: 256
46 | in_channels: 3
47 | out_ch: 3
48 | ch: 128
49 | ch_mult:
50 | - 1
51 | - 2
52 | - 4
53 | num_res_blocks: 2
54 | attn_resolutions: []
55 | dropout: 0.0
56 | lossconfig:
57 | target: torch.nn.Identity
58 | cond_stage_config: __is_unconditional__
59 | data:
60 | target: main.DataModuleFromConfig
61 | params:
62 | batch_size: 48
63 | num_workers: 5
64 | wrap: false
65 | train:
66 | target: ldm.data.lsun.LSUNBedroomsTrain
67 | params:
68 | size: 256
69 | validation:
70 | target: ldm.data.lsun.LSUNBedroomsValidation
71 | params:
72 | size: 256
73 |
74 |
75 | lightning:
76 | callbacks:
77 | image_logger:
78 | target: main.ImageLogger
79 | params:
80 | batch_frequency: 5000
81 | max_images: 8
82 | increase_log_steps: False
83 |
84 | trainer:
85 | benchmark: True
--------------------------------------------------------------------------------
/configs/latent-diffusion/lsun_churches-ldm-kl-8.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 5.0e-5 # set to target_lr by starting main.py with '--scale_lr False'
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0155
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | loss_type: l1
11 | first_stage_key: "image"
12 | cond_stage_key: "image"
13 | image_size: 32
14 | channels: 4
15 | cond_stage_trainable: False
16 | concat_mode: False
17 | scale_by_std: True
18 | monitor: 'val/loss_simple_ema'
19 |
20 | scheduler_config: # 10000 warmup steps
21 | target: ldm.lr_scheduler.LambdaLinearScheduler
22 | params:
23 | warm_up_steps: [10000]
24 | cycle_lengths: [10000000000000]
25 | f_start: [1.e-6]
26 | f_max: [1.]
27 | f_min: [ 1.]
28 |
29 | unet_config:
30 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
31 | params:
32 | image_size: 32
33 | in_channels: 4
34 | out_channels: 4
35 | model_channels: 192
36 | attention_resolutions: [ 1, 2, 4, 8 ] # 32, 16, 8, 4
37 | num_res_blocks: 2
38 | channel_mult: [ 1,2,2,4,4 ] # 32, 16, 8, 4, 2
39 | num_heads: 8
40 | use_scale_shift_norm: True
41 | resblock_updown: True
42 |
43 | first_stage_config:
44 | target: ldm.models.autoencoder.AutoencoderKL
45 | params:
46 | embed_dim: 4
47 | monitor: "val/rec_loss"
48 | ckpt_path: "models/first_stage_models/kl-f8/model.ckpt"
49 | ddconfig:
50 | double_z: True
51 | z_channels: 4
52 | resolution: 256
53 | in_channels: 3
54 | out_ch: 3
55 | ch: 128
56 | ch_mult: [ 1,2,4,4 ] # num_down = len(ch_mult)-1
57 | num_res_blocks: 2
58 | attn_resolutions: [ ]
59 | dropout: 0.0
60 | lossconfig:
61 | target: torch.nn.Identity
62 |
63 | cond_stage_config: "__is_unconditional__"
64 |
65 | data:
66 | target: main.DataModuleFromConfig
67 | params:
68 | batch_size: 96
69 | num_workers: 5
70 | wrap: False
71 | train:
72 | target: ldm.data.lsun.LSUNChurchesTrain
73 | params:
74 | size: 256
75 | validation:
76 | target: ldm.data.lsun.LSUNChurchesValidation
77 | params:
78 | size: 256
79 |
80 | lightning:
81 | callbacks:
82 | image_logger:
83 | target: main.ImageLogger
84 | params:
85 | batch_frequency: 5000
86 | max_images: 8
87 | increase_log_steps: False
88 |
89 |
90 | trainer:
91 | benchmark: True
--------------------------------------------------------------------------------
/configs/latent-diffusion/txt2img-1p4B-eval.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 5.0e-05
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.00085
6 | linear_end: 0.012
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | cond_stage_key: caption
12 | image_size: 32
13 | channels: 4
14 | cond_stage_trainable: true
15 | conditioning_key: crossattn
16 | monitor: val/loss_simple_ema
17 | scale_factor: 0.18215
18 | use_ema: False
19 |
20 | unet_config:
21 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
22 | params:
23 | image_size: 32
24 | in_channels: 4
25 | out_channels: 4
26 | model_channels: 320
27 | attention_resolutions:
28 | - 4
29 | - 2
30 | - 1
31 | num_res_blocks: 2
32 | channel_mult:
33 | - 1
34 | - 2
35 | - 4
36 | - 4
37 | num_heads: 8
38 | use_spatial_transformer: true
39 | transformer_depth: 1
40 | context_dim: 1280
41 | use_checkpoint: true
42 | legacy: False
43 |
44 | first_stage_config:
45 | target: ldm.models.autoencoder.AutoencoderKL
46 | params:
47 | embed_dim: 4
48 | monitor: val/rec_loss
49 | ddconfig:
50 | double_z: true
51 | z_channels: 4
52 | resolution: 256
53 | in_channels: 3
54 | out_ch: 3
55 | ch: 128
56 | ch_mult:
57 | - 1
58 | - 2
59 | - 4
60 | - 4
61 | num_res_blocks: 2
62 | attn_resolutions: []
63 | dropout: 0.0
64 | lossconfig:
65 | target: torch.nn.Identity
66 |
67 | cond_stage_config:
68 | target: ldm.modules.encoders.modules.BERTEmbedder
69 | params:
70 | n_embed: 1280
71 | n_layer: 32
72 |
--------------------------------------------------------------------------------
/configs/retrieval-augmented-diffusion/768x768.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 0.0001
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.015
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: jpg
11 | cond_stage_key: nix
12 | image_size: 48
13 | channels: 16
14 | cond_stage_trainable: false
15 | conditioning_key: crossattn
16 | monitor: val/loss_simple_ema
17 | scale_by_std: false
18 | scale_factor: 0.22765929
19 | unet_config:
20 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
21 | params:
22 | image_size: 48
23 | in_channels: 16
24 | out_channels: 16
25 | model_channels: 448
26 | attention_resolutions:
27 | - 4
28 | - 2
29 | - 1
30 | num_res_blocks: 2
31 | channel_mult:
32 | - 1
33 | - 2
34 | - 3
35 | - 4
36 | use_scale_shift_norm: false
37 | resblock_updown: false
38 | num_head_channels: 32
39 | use_spatial_transformer: true
40 | transformer_depth: 1
41 | context_dim: 768
42 | use_checkpoint: true
43 | first_stage_config:
44 | target: ldm.models.autoencoder.AutoencoderKL
45 | params:
46 | monitor: val/rec_loss
47 | embed_dim: 16
48 | ddconfig:
49 | double_z: true
50 | z_channels: 16
51 | resolution: 256
52 | in_channels: 3
53 | out_ch: 3
54 | ch: 128
55 | ch_mult:
56 | - 1
57 | - 1
58 | - 2
59 | - 2
60 | - 4
61 | num_res_blocks: 2
62 | attn_resolutions:
63 | - 16
64 | dropout: 0.0
65 | lossconfig:
66 | target: torch.nn.Identity
67 | cond_stage_config:
68 | target: torch.nn.Identity
--------------------------------------------------------------------------------
/configs/stable-diffusion/PLACE.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-06
3 | target: ldm.models.diffusion.ddpm.PLACE_LDM
4 | params:
5 | linear_start: 0.00085
6 | linear_end: 0.0120
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: "jpg"
11 | cond_stage_key: "tks"
12 | control_key: "hint"
13 | image_size: 64
14 | channels: 4
15 | cond_stage_trainable: false # Note: different from the one we trained before
16 | conditioning_key: crossattn
17 | monitor: val/loss_simple_ema
18 | scale_factor: 0.18215
19 | use_ema: False
20 |
21 | scheduler_config: # 10000 warmup steps
22 | target: ldm.lr_scheduler.LambdaLinearScheduler
23 | params:
24 | warm_up_steps: [ 10000 ]
25 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
26 | f_start: [ 1.e-6 ]
27 | f_max: [ 1. ]
28 | f_min: [ 1. ]
29 |
30 | unet_config:
31 | target: ldm.modules.diffusionmodules.openaimodel.PLACEUnetModel
32 | params:
33 | image_size: 32 # unused
34 | in_channels: 4
35 | out_channels: 4
36 | model_channels: 320
37 | attention_resolutions: [ 4, 2, 1 ]
38 | num_res_blocks: 2
39 | channel_mult: [ 1, 2, 4, 4 ]
40 | num_heads: 8
41 | use_spatial_transformer: True
42 | transformer_depth: 1
43 | context_dim: 768
44 | use_checkpoint: True
45 | legacy: False
46 | catype: 'PLACE'
47 |
48 | first_stage_config:
49 | target: ldm.models.autoencoder.AutoencoderKL
50 | params:
51 | embed_dim: 4
52 | monitor: val/rec_loss
53 | ddconfig:
54 | double_z: true
55 | z_channels: 4
56 | resolution: 256
57 | in_channels: 3
58 | out_ch: 3
59 | ch: 128
60 | ch_mult:
61 | - 1
62 | - 2
63 | - 4
64 | - 4
65 | num_res_blocks: 2
66 | attn_resolutions: []
67 | dropout: 0.0
68 | lossconfig:
69 | target: torch.nn.Identity
70 |
71 | cond_stage_config:
72 | target: ldm.modules.encoders.modules.MYFrozenCLIPEmbedder
73 |
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | import json
2 | import cv2
3 | import numpy as np
4 |
5 | from torch.utils.data import Dataset
6 | import random
7 | from PIL import Image
8 | import open_clip
9 | import torch
10 | import os
11 | from ldm.modules.encoders.modules import FrozenCLIPEmbedder as CLIP
12 |
13 | clip = CLIP()
14 |
15 | ADE20K={
16 | 1 : "wall",
17 | 2 : "building edifice",
18 | 3 : "sky",
19 | 4 : "floor",
20 | 5 : "tree",
21 | 6 : "ceiling",
22 | 7 : "road",
23 | 8 : "bed",
24 | 9 : "windowpane",
25 | 10 : "grass",
26 | 11 : "cabinet",
27 | 12 : "sidewalk",
28 | 13 : "person",
29 | 14 : "earth ground",
30 | 15 : "door",
31 | 16 : "table",
32 | 17 : "mountain",
33 | 18 : "plant flora",
34 | 19 : "curtain drapery mantle pall",
35 | 20 : "chair",
36 | 21 : "car",
37 | 22 : "water",
38 | 23 : "painting picture",
39 | 24 : "sofa lounge",
40 | 25 : "shelf",
41 | 26 : "house",
42 | 27 : "sea",
43 | 28 : "mirror",
44 | 29 : "carpet",
45 | 30 : "field",
46 | 31 : "armchair",
47 | 32 : "seat",
48 | 33 : "fence",
49 | 34 : "desk",
50 | 35 : "rock stone",
51 | 36 : "wardrobe closet",
52 | 37 : "lamp",
53 | 38 : "bathtub",
54 | 39 : "railing",
55 | 40 : "cushion",
56 | 41 : "base pedestal stand",
57 | 42 : "box",
58 | 43 : "pillar",
59 | 44 : "signboard sign",
60 | 45 : "chest bureau dresser",
61 | 46 : "counter",
62 | 47 : "sand",
63 | 48 : "sink",
64 | 49 : "skyscraper",
65 | 50 : "fireplace",
66 | 51 : "refrigerator",
67 | 52 : "grandstand covered stand",
68 | 53 : "path",
69 | 54 : "stairs",
70 | 55 : "runway",
71 | 56 : "case showcase vitrine",
72 | 57 : "pool table billiard table snooker table",
73 | 58 : "pillow",
74 | 59 : "screen door",
75 | 60 : "stairway",
76 | 61 : "river",
77 | 62 : "bridge",
78 | 63 : "bookcase",
79 | 64 : "blind screen",
80 | 65 : "coffee table cocktail table",
81 | 66 : "toilet can commode potty",
82 | 67 : "flower",
83 | 68 : "book",
84 | 69 : "hill",
85 | 70 : "bench",
86 | 71 : "countertop",
87 | 72 : "kitchen range cooking stove",
88 | 73 : "palm tree",
89 | 74 : "kitchen island",
90 | 75 : "computer",
91 | 76 : "swivel chair",
92 | 77 : "boat",
93 | 78 : "bar",
94 | 79 : "arcade machine",
95 | 80 : "hovel shack shanty",
96 | 81 : "autobus motorbus omnibus",
97 | 82 : "towel",
98 | 83 : "light",
99 | 84 : "truck",
100 | 85 : "tower",
101 | 86 : "chandelier pendant",
102 | 87 : "awning sunblind",
103 | 88 : "streetlight",
104 | 89 : "booth cubicle kiosk",
105 | 90 : "television tv set idiot box boob tube telly goggle box",
106 | 91 : "airplane",
107 | 92 : "dirt track",
108 | 93 : "apparel",
109 | 94 : "pole",
110 | 95 : "land ground soil",
111 | 96 : "balustrade handrail",
112 | 97 : "escalator",
113 | 98 : "ottoman pouf hassock",
114 | 99 : "bottle",
115 | 100 : "buffet counter sideboard",
116 | 101 : "poster placard notice card",
117 | 102 : "stage",
118 | 103 : "van",
119 | 104 : "ship",
120 | 105 : "fountain",
121 | 106 : "conveyor belt transporter",
122 | 107 : "canopy",
123 | 108 : "washing machine",
124 | 109 : "toy",
125 | 110 : "swimming pool natatorium",
126 | 111 : "stool",
127 | 112 : "barrel",
128 | 113 : "basket handbasket",
129 | 114 : "waterfall",
130 | 115 : "tent",
131 | 116 : "bag",
132 | 117 : "motorbike",
133 | 118 : "cradle",
134 | 119 : "oven",
135 | 120 : "ball",
136 | 121 : "food",
137 | 122 : "stair",
138 | 123 : "storage tank",
139 | 124 : "brand marque",
140 | 125 : "microwave oven",
141 | 126 : "flowerpot",
142 | 127 : "animal fauna",
143 | 128 : "bicycle",
144 | 129 : "lake",
145 | 130 : "dishwasher",
146 | 131 : "screen silver screen projection screen",
147 | 132 : "blanket",
148 | 133 : "sculpture",
149 | 134 : "exhaust hood",
150 | 135 : "sconce",
151 | 136 : "vase",
152 | 137 : "traffic light",
153 | 138 : "tray",
154 | 139 : "ashcan trash can dustbin",
155 | 140 : "fan",
156 | 141 : "pier wharfage dock",
157 | 142 : "crt screen",
158 | 143 : "plate",
159 | 144 : "monitoring device",
160 | 145 : "notice board",
161 | 146 : "shower",
162 | 147 : "radiator",
163 | 148 : "drinking glass",
164 | 149 : "clock",
165 | 150 : "flag"
166 | }
167 |
168 | COCO={
169 | 1 : "person",
170 | 2 : "bicycle",
171 | 3 : "car",
172 | 4 : "motorcycle",
173 | 5 : "airplane",
174 | 6 : "bus",
175 | 7 : "train",
176 | 8 : "truck",
177 | 9 : "boat",
178 | 10 : "traffic light",
179 | 11 : "fire hydrant",
180 | 12 : "street sign",
181 | 13 : "stop sign",
182 | 14 : "parking meter",
183 | 15 : "bench",
184 | 16 : "bird",
185 | 17 : "cat",
186 | 18 : "dog",
187 | 19 : "horse",
188 | 20 : "sheep",
189 | 21 : "cow",
190 | 22 : "elephant",
191 | 23 : "bear",
192 | 24 : "zebra",
193 | 25 : "giraffe",
194 | 26 : "hat",
195 | 27 : "backpack",
196 | 28 : "umbrella",
197 | 29 : "shoe",
198 | 30 : "eye glasses",
199 | 31 : "handbag",
200 | 32 : "tie",
201 | 33 : "suitcase",
202 | 34 : "frisbee",
203 | 35 : "skis",
204 | 36 : "snowboard",
205 | 37 : "sports ball",
206 | 38 : "kite",
207 | 39 : "baseball bat",
208 | 40 : "baseball glove",
209 | 41 : "skateboard",
210 | 42 : "surfboard",
211 | 43 : "tennis racket",
212 | 44 : "bottle",
213 | 45 : "plate",
214 | 46 : "wine glass",
215 | 47 : "cup",
216 | 48 : "fork",
217 | 49 : "knife",
218 | 50 : "spoon",
219 | 51 : "bowl",
220 | 52 : "banana",
221 | 53 : "apple",
222 | 54 : "sandwich",
223 | 55 : "orange",
224 | 56 : "broccoli",
225 | 57 : "carrot",
226 | 58 : "hot dog",
227 | 59 : "pizza",
228 | 60 : "donut",
229 | 61 : "cake",
230 | 62 : "chair",
231 | 63 : "couch",
232 | 64 : "potted plant",
233 | 65 : "bed",
234 | 66 : "mirror",
235 | 67 : "dining table",
236 | 68 : "window",
237 | 69 : "desk",
238 | 70 : "toilet",
239 | 71 : "door",
240 | 72 : "tv",
241 | 73 : "laptop",
242 | 74 : "mouse",
243 | 75 : "remote",
244 | 76 : "keyboard",
245 | 77 : "cell phone",
246 | 78 : "microwave",
247 | 79 : "oven",
248 | 80 : "toaster",
249 | 81 : "sink",
250 | 82 : "refrigerator",
251 | 83 : "blender",
252 | 84 : "book",
253 | 85 : "clock",
254 | 86 : "vase",
255 | 87 : "scissors",
256 | 88 : "teddy bear",
257 | 89 : "hair drier",
258 | 90 : "toothbrush",
259 | 91 : "hair brush",
260 | 92 : "banner",
261 | 93 : "blanket",
262 | 94 : "branch",
263 | 95 : "bridge",
264 | 96 : "building",
265 | 97 : "bush",
266 | 98 : "cabinet",
267 | 99 : "cage",
268 | 100 : "cardboard",
269 | 101 : "carpet",
270 | 102 : "ceiling",
271 | 103 : "tile ceiling",
272 | 104 : "cloth",
273 | 105 : "clothes",
274 | 106 : "clouds",
275 | 107 : "counter",
276 | 108 : "cupboard",
277 | 109 : "curtain",
278 | 110 : "desk",
279 | 111 : "dirt",
280 | 112 : "door",
281 | 113 : "fence",
282 | 114 : "marble floor",
283 | 115 : "floor",
284 | 116 : "stone floor",
285 | 117 : "tile floor",
286 | 118 : "wood floor",
287 | 119 : "flower",
288 | 120 : "fog",
289 | 121 : "food",
290 | 122 : "fruit",
291 | 123 : "furniture",
292 | 124 : "grass",
293 | 125 : "gravel",
294 | 126 : "ground",
295 | 127 : "hill",
296 | 128 : "house",
297 | 129 : "leaves",
298 | 130 : "light",
299 | 131 : "mat",
300 | 132 : "metal",
301 | 133 : "mirror",
302 | 134 : "moss",
303 | 135 : "mountain",
304 | 136 : "mud",
305 | 137 : "napkin",
306 | 138 : "net",
307 | 139 : "paper",
308 | 140 : "pavement",
309 | 141 : "pillow",
310 | 142 : "plant",
311 | 143 : "plastic",
312 | 144 : "platform",
313 | 145 : "playingfield",
314 | 146 : "railing",
315 | 147 : "railroad",
316 | 148 : "river",
317 | 149 : "road",
318 | 150 : "rock",
319 | 151 : "roof",
320 | 152 : "rug",
321 | 153 : "salad",
322 | 154 : "sand",
323 | 155 : "sea",
324 | 156 : "shelf",
325 | 157 : "sky",
326 | 158 : "skyscraper",
327 | 159 : "snow",
328 | 160 : "solid",
329 | 161 : "stairs",
330 | 162 : "stone",
331 | 163 : "straw",
332 | 164 : "structural",
333 | 165 : "table",
334 | 166 : "tent",
335 | 167 : "textile",
336 | 168 : "towel",
337 | 169 : "tree",
338 | 170 : "vegetable",
339 | 171 : "brick wall",
340 | 172 : "concrete wall",
341 | 173 : "wall",
342 | 174 : "panel wall",
343 | 175 : "stone wall",
344 | 176 : "tile wall",
345 | 177 : "wood wall",
346 | 178 : "water",
347 | 179 : "waterdrops",
348 | 180 : "blind window",
349 | 181 : "window",
350 | 182 : "wood"
351 | }
352 |
353 | def gettks(tkss): # strip CLIP special tokens from a tokenized class name and return the remaining ids
354 | newtks = []
355 | for i in range(77):
356 | if tkss[0,i]==49407: # <|endoftext|> / padding token: stop
357 | break
358 | elif tkss[0,i]==49406: # <|startoftext|> token: skip
359 | continue
360 | elif tkss[0,i]==267: # ',' token in the CLIP vocabulary: skip
361 | continue
362 | else:
363 | newtks.append(int(tkss[0,i]))
364 | return newtks
365 |
366 | class ADE20KDataset(Dataset):
367 | def __init__(self, data_root, phase='Val'):
368 | self.data = []
369 | path = data_root
370 | prefix = 'validation' if phase=='Val' else 'training'
371 | files = os.listdir(path+'/images/'+prefix+'/')
372 | for fn in files:
373 | self.data.append({'target':path+'/images/'+prefix+'/'+fn,'source':path+'/annotations/'+prefix+'/'+fn.replace('.jpg','.png')})
374 |
375 | self.labeldic={}
376 | self.namedic={}
377 | for k in ADE20K.keys():
378 | self.namedic[int(k)-1] = ADE20K[k]
379 |
380 | batch_encoding = clip.tokenizer(ADE20K[k], truncation=True, max_length=clip.max_length, return_length=True,
381 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
382 | tokens = batch_encoding["input_ids"]
383 | corr_tks = gettks(tokens)
384 |
385 | self.labeldic[int(k)-1] = corr_tks
386 |
387 | def __len__(self):
388 | return len(self.data)
389 |
390 | def __getitem__(self, idx):
391 | item = self.data[idx]
392 |
393 | source_filename = item['source']
394 | target_filename = item['target']
395 |
396 | source = np.array(Image.open(source_filename).resize((512,512),Image.NEAREST),dtype=np.float64)
397 | target = np.array(Image.open(target_filename).resize((512,512),Image.BICUBIC).convert('RGB'))
398 |
399 | tokens = np.full((77),49407)
400 | tokens[0] = 49406
401 | token_list = []
402 | tokens_cls = np.full((77),49407)
403 | tokens_cls[0] = 49406
404 | tokens_cls_list = []
405 | source_minus_1 = source - 1
406 | prompt = ''
407 | for lb in np.unique(source_minus_1):
408 | if lb==-1:
409 | continue
410 | token_list+=self.labeldic[lb]
411 | tokens_cls_list += [lb]*len(self.labeldic[lb])
412 | prompt += self.namedic[lb] + ','
413 | tokens[1:len(token_list)+1] = np.array(token_list)
414 | tokens_cls[1:len(token_list)+1] = np.array(tokens_cls_list)
415 |
416 | # Normalize source images to [0, 1].
417 | viewsource = np.zeros((512,512,3))
418 | viewsource[:,:,0] = source
419 | viewsource[:,:,1] = source
420 | viewsource[:,:,2] = source
421 | viewsource = np.array(viewsource,dtype=np.uint8)
422 | viewsource = viewsource.astype(np.float32) / 255.0
423 |
424 | # Normalize target images to [-1, 1].
425 | target = (target.astype(np.float32) / 127.5) - 1.0
426 |
427 | source = source_minus_1[:,:,np.newaxis]
428 |
429 | assert viewsource.shape==target.shape, str(viewsource.shape)+' '+str(target.shape) + ' '+item['target']+' ' + item['source']
430 |
431 | return dict(jpg=target, txt=prompt, tks=tokens, hint=source, targetpath=item['target'], sourcepath=item['source'], viewcontrol=viewsource, tokens_cls=tokens_cls)
432 |
433 | class COCODataset(Dataset):
434 | def __init__(self, data_root, phase='Val'):
435 | self.data = []
436 | path = data_root
437 | prefix = 'val2017' if phase=='Val' else 'train2017'
438 | files = os.listdir(path+'/'+prefix)
439 | for fn in files:
440 | self.data.append({'target':path+'/'+prefix+'/'+fn,'source':path+'/stuffthingmaps_trainval2017/'+prefix+'/'+fn.replace('.jpg','.png')})
441 |
442 | self.labeldic={}
443 | self.namedic={}
444 | for k in COCO.keys():
445 | self.namedic[int(k)-1] = COCO[k]
446 |
447 | batch_encoding = clip.tokenizer(COCO[k], truncation=True, max_length=clip.max_length, return_length=True,
448 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
449 | tokens = batch_encoding["input_ids"]
450 | corr_tks = gettks(tokens)
451 |
452 | self.labeldic[int(k)-1] = corr_tks
453 |
454 | def __len__(self):
455 | return len(self.data)
456 |
457 | def __getitem__(self, idx):
458 | item = self.data[idx]
459 |
460 | source_filename = item['source']
461 | target_filename = item['target']
462 |
463 | source = np.array(Image.open(source_filename).resize((512,512),Image.NEAREST),dtype=np.float64)
464 | source = np.where(source==255,-1,source)
465 | source = source + 1
466 |
467 | target = np.array(Image.open(target_filename).resize((512,512),Image.BILINEAR).convert('RGB'))
468 |
469 | tokens = np.full((77),49407)
470 | tokens[0] = 49406
471 | token_list = []
472 | tokens_cls = np.full((77),49407)
473 | tokens_cls[0] = 49406
474 | tokens_cls_list = []
475 | source_minus_1 = source - 1
476 | prompt = ''
477 | for lb in np.unique(source_minus_1):
478 | if lb==-1:
479 | continue
480 | token_list+=self.labeldic[lb]
481 | tokens_cls_list += [lb]*len(self.labeldic[lb])
482 | prompt += self.namedic[lb] + ','
483 | tokens[1:len(token_list)+1] = np.array(token_list)
484 | tokens_cls[1:len(token_list)+1] = np.array(tokens_cls_list)
485 |
486 | # Normalize source images to [0, 1].
487 | viewsource = np.zeros((512,512,3))
488 | viewsource[:,:,0] = source
489 | viewsource[:,:,1] = source
490 | viewsource[:,:,2] = source
491 | viewsource = np.array(viewsource,dtype=np.uint8)
492 | viewsource = viewsource.astype(np.float32) / 255.0
493 |
494 | # Normalize target images to [-1, 1].
495 | target = (target.astype(np.float32) / 127.5) - 1.0
496 |
497 | source = source_minus_1[:,:,np.newaxis]
498 |
499 | assert viewsource.shape==target.shape, str(viewsource.shape)+' '+str(target.shape) + ' '+item['target']+' ' + item['source']
500 |
501 | return dict(jpg=target, txt=prompt, tks=tokens, hint=source, targetpath=item['target'], sourcepath=item['source'], viewcontrol=viewsource, tokens_cls=tokens_cls)
--------------------------------------------------------------------------------
/environment.yaml:
--------------------------------------------------------------------------------
1 | name: PLACE
2 | channels:
3 | - pytorch
4 | - defaults
5 | dependencies:
6 | - python=3.8.5
7 | - pip=20.3
8 | - cudatoolkit=11.3
9 | - pytorch=1.12.1
10 | - torchvision=0.13.1
11 | - numpy=1.23.1
12 | - pip:
13 | - gradio==3.16.2
14 | - albumentations==1.3.0
15 | - opencv-contrib-python==4.3.0.36
16 | - imageio==2.9.0
17 | - imageio-ffmpeg==0.4.2
18 | - pytorch-lightning==1.5.0
19 | - omegaconf==2.1.1
20 | - test-tube>=0.7.5
21 | - streamlit==1.12.1
22 | - einops==0.3.0
23 | - transformers==4.19.2
24 | - webdataset==0.2.5
25 | - kornia==0.6
26 | - open_clip_torch==2.0.2
27 | - invisible-watermark>=0.1.5
28 | - streamlit-drawable-canvas==0.8.0
29 | - torchmetrics==0.6.0
30 | - timm==0.6.12
31 | - addict==2.4.0
32 | - yapf==0.32.0
33 | - prettytable==3.6.0
34 | - safetensors==0.2.7
35 |
--------------------------------------------------------------------------------
/index.html:
--------------------------------------------------------------------------------
(Text content of the project page; HTML markup, styling, and embedded media omitted.)

PLACE: Adaptive Layout-Semantic Fusion for Semantic Image Synthesis (CVPR 2024)

[Paper] | [GitHub]

Abstract

Recent advancements in large-scale pre-trained text-to-image models have led to remarkable progress in semantic image synthesis. Nevertheless, synthesizing high-quality images with consistent semantics and layout remains a challenge. In this paper, we propose the adaPtive LAyout-semantiC fusion modulE (PLACE) that harnesses pre-trained models to alleviate the aforementioned issues. Specifically, we first employ the layout control map to faithfully represent layouts in the feature space. Subsequently, we combine the layout and semantic features in a timestep-adaptive manner to synthesize images with realistic details. During fine-tuning, we propose the Semantic Alignment (SA) loss to further enhance layout alignment. Additionally, we introduce the Layout-Free Prior Preservation (LFP) loss, which leverages unlabeled data to maintain the priors of pre-trained models, thereby improving the visual quality and semantic consistency of synthesized images. Extensive experiments demonstrate that our approach performs favorably in terms of visual quality, semantic consistency, and layout alignment.

Overview

Overview of our method. (a) We utilize the layout control map, calculated from the semantic map, together with PLACE for layout control. During fine-tuning, we combine the LDM, SA, and LFP losses as the optimization objective. (b) Calculation of the layout control map and details of the adaptive layout-semantic fusion module. Each vector in the layout control map encodes all the semantic components in its receptive field. The adaptive layout-semantic fusion module blends the layout and semantic features in a timestep-adaptive way.

Results

Visual comparisons on ADE20K and COCO-Stuff.

Visual comparisons for out-of-distribution synthesis.

Paper and Supplementary Material

Zhengyao Lv, Yuxiang Wei, Wangmeng Zuo and Kwan-Yee K. Wong.
PLACE: Adaptive Layout-Semantic Fusion for Semantic Image Synthesis.
In CVPR, 2024. (hosted on ArXiv)
[Bibtex]

Acknowledgements

This template was originally made by Phillip Isola and Richard Zhang for a colorful ECCV project; the code can be found here.
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import einops
3 | import numpy as np
4 | import torch
5 | from PIL import Image
6 | import os, argparse
7 |
8 | from ldm.models.diffusion.plms import PLMSSampler
9 | from torch.utils.data import DataLoader
10 | from dataset import ADE20KDataset, COCODataset
11 | from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
12 |
13 | from omegaconf import OmegaConf
14 | from ldm.util import instantiate_from_config
15 |
16 | from pytorch_lightning import seed_everything
17 |
18 |
19 | def main():
20 | parser = argparse.ArgumentParser()
21 |
22 | parser.add_argument(
23 | "--outdir",
24 | type=str,
25 | nargs="?",
26 | help="dir to write results to",
27 | default="output/ade20k"
28 | )
29 |
30 | parser.add_argument(
31 | "--config",
32 | type=str,
33 | default="configs/stable-diffusion/xxx.yaml",
34 | help="path to config which constructs model",
35 | )
36 |
37 | parser.add_argument(
38 | "--seed",
39 | type=int,
40 | default=42,
41 | help="the seed (for reproducible sampling)",
42 | )
43 |
44 | parser.add_argument(
45 | "--data_root",
46 | type=str,
47 | required=True,
48 | help="Path to dataset directory"
49 | )
50 |
51 | parser.add_argument(
52 | "--dataset",
53 | type=str,
54 | help="which dataset to evaluate",
55 | choices=["COCO", "ADE20K"],
56 | default="COCO"
57 | )
58 |
59 | parser.add_argument(
60 | "--ckpt",
61 | type=str,
62 | default="models/ade20k.ckpt",
63 | help="path to checkpoint of model",
64 | )
65 |
66 | opt = parser.parse_args()
67 |
68 | seed_everything(opt.seed)
69 |
70 | def get_state_dict(d):
71 | return d.get('state_dict', d)
72 |
73 | targetpth = opt.outdir
74 | if not os.path.exists(targetpth):
75 | os.makedirs(targetpth) # makedirs so nested default paths like output/ade20k are created
76 |
77 | config = OmegaConf.load(opt.config)
78 | model = instantiate_from_config(config.model).cpu()
79 |
80 | state_dict = get_state_dict(torch.load(opt.ckpt, map_location=torch.device('cuda')))
81 | state_dict = get_state_dict(state_dict)
82 | model.load_state_dict(state_dict)
83 | model = model.cuda()
84 |
85 | sampler = PLMSSampler(model)
86 |
87 | if opt.dataset == 'ADE20K':
88 | dataset = ADE20KDataset(opt.data_root)
89 | elif opt.dataset == 'COCO':
90 | dataset = COCODataset(opt.data_root)
91 |
92 | dataloader = DataLoader(dataset, num_workers=0, batch_size=1, shuffle=False)
93 | for batch in dataloader:
94 | names = batch['sourcepath'][0].split('/')[-1]
95 | print('processing:',targetpth+'/'+names)
96 | N = 1
97 | z, c = model.get_input(batch, model.first_stage_key, bs=1)
98 | c_tkscls = c['tkscls'][0][:N]
99 | c_cat, c, view_ctrol = c["c_concat"][0][:N], c["c_crossattn"][0][:N], c['viewcontrol'][0][:N]
100 |
101 | uc_cross = torch.zeros((c.shape[0],77),dtype=torch.int64) + 49407 # empty prompt: pad with CLIP <|endoftext|> tokens
102 | uc_cross[:,0] = 49406 # CLIP <|startoftext|> token
103 | uc_cross = model.cond_stage_model.encode(uc_cross.to(model.device))
104 | if isinstance(uc_cross, DiagonalGaussianDistribution):
105 | uc_cross = uc_cross.mode()
106 |
107 | uc_cat = c_cat
108 | uc_tkscls = torch.zeros_like(c_tkscls) + 49407
109 | uc_tkscls[:,0] = 49406
110 | uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross], "tkscls": [uc_tkscls]}
111 |
112 | cond = {"c_concat": [c_cat], "c_crossattn": [c], "tkscls": [c_tkscls]}
113 | un_cond ={"c_concat": [uc_cat], "c_crossattn": [uc_cross], "tkscls": [uc_tkscls]}
114 |
115 | H,W=512,512
116 | shape = (4, H // 8, W // 8)
117 | samples_ddim, _ = sampler.sample(50,
118 | conditioning=cond,
119 | batch_size=1,
120 | shape=shape,
121 | verbose=False,
122 | unconditional_guidance_scale=2.0,
123 | unconditional_conditioning=un_cond,
124 | eta=0.0,
125 | x_T=None)
126 | x_samples = model.decode_first_stage(samples_ddim)
127 | x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
128 |
129 | results = [x_samples[i] for i in range(1)]
130 | Image.fromarray(results[0]).save(os.path.join(targetpth,names.replace('png','jpg')))
131 |
132 | if __name__ == "__main__":
133 | main()
134 |
--------------------------------------------------------------------------------
/ldm/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cszy98/PLACE/4594350491df5c7489eb4edf0802f303692dd530/ldm/.DS_Store
--------------------------------------------------------------------------------
/ldm/__pycache__/util.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cszy98/PLACE/4594350491df5c7489eb4edf0802f303692dd530/ldm/__pycache__/util.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cszy98/PLACE/4594350491df5c7489eb4edf0802f303692dd530/ldm/data/__init__.py
--------------------------------------------------------------------------------
/ldm/data/util.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from ldm.modules.midas.api import load_midas_transform
4 |
5 |
6 | class AddMiDaS(object):
7 | def __init__(self, model_type):
8 | super().__init__()
9 | self.transform = load_midas_transform(model_type)
10 |
11 | def pt2np(self, x):
12 | x = ((x + 1.0) * .5).detach().cpu().numpy()
13 | return x
14 |
15 | def np2pt(self, x):
16 | x = torch.from_numpy(x) * 2 - 1.
17 | return x
18 |
19 | def __call__(self, sample):
20 | # sample['jpg'] is tensor hwc in [-1, 1] at this point
21 | x = self.pt2np(sample['jpg'])
22 | x = self.transform({"image": x})["image"]
23 | sample['midas_in'] = x
24 | return sample
--------------------------------------------------------------------------------
/ldm/models/__pycache__/autoencoder.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cszy98/PLACE/4594350491df5c7489eb4edf0802f303692dd530/ldm/models/__pycache__/autoencoder.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/models/autoencoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import pytorch_lightning as pl
3 | import torch.nn.functional as F
4 | from contextlib import contextmanager
5 |
6 | from ldm.modules.diffusionmodules.model import Encoder, Decoder
7 | from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
8 |
9 | from ldm.util import instantiate_from_config
10 | from ldm.modules.ema import LitEma
11 |
12 |
13 | class AutoencoderKL(pl.LightningModule):
14 | def __init__(self,
15 | ddconfig,
16 | lossconfig,
17 | embed_dim,
18 | ckpt_path=None,
19 | ignore_keys=[],
20 | image_key="image",
21 | colorize_nlabels=None,
22 | monitor=None,
23 | ema_decay=None,
24 | learn_logvar=False
25 | ):
26 | super().__init__()
27 | self.learn_logvar = learn_logvar
28 | self.image_key = image_key
29 | self.encoder = Encoder(**ddconfig)
30 | self.decoder = Decoder(**ddconfig)
31 | self.loss = instantiate_from_config(lossconfig)
32 | assert ddconfig["double_z"]
33 | self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
34 | self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
35 | self.embed_dim = embed_dim
36 | if colorize_nlabels is not None:
37 | assert type(colorize_nlabels)==int
38 | self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
39 | if monitor is not None:
40 | self.monitor = monitor
41 |
42 | self.use_ema = ema_decay is not None
43 | if self.use_ema:
44 | self.ema_decay = ema_decay
45 | assert 0. < ema_decay < 1.
46 | self.model_ema = LitEma(self, decay=ema_decay)
47 | print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
48 |
49 | if ckpt_path is not None:
50 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
51 |
52 | def init_from_ckpt(self, path, ignore_keys=list()):
53 | sd = torch.load(path, map_location="cpu")["state_dict"]
54 | keys = list(sd.keys())
55 | for k in keys:
56 | for ik in ignore_keys:
57 | if k.startswith(ik):
58 | print("Deleting key {} from state_dict.".format(k))
59 | del sd[k]
60 | self.load_state_dict(sd, strict=False)
61 | print(f"Restored from {path}")
62 |
63 | @contextmanager
64 | def ema_scope(self, context=None):
65 | if self.use_ema:
66 | self.model_ema.store(self.parameters())
67 | self.model_ema.copy_to(self)
68 | if context is not None:
69 | print(f"{context}: Switched to EMA weights")
70 | try:
71 | yield None
72 | finally:
73 | if self.use_ema:
74 | self.model_ema.restore(self.parameters())
75 | if context is not None:
76 | print(f"{context}: Restored training weights")
77 |
78 | def on_train_batch_end(self, *args, **kwargs):
79 | if self.use_ema:
80 | self.model_ema(self)
81 |
82 | def encode(self, x):
83 | h = self.encoder(x)
84 | moments = self.quant_conv(h)
85 | posterior = DiagonalGaussianDistribution(moments)
86 | return posterior
87 |
88 | def decode(self, z):
89 | z = self.post_quant_conv(z)
90 | dec = self.decoder(z)
91 | return dec
92 |
93 | def forward(self, input, sample_posterior=True):
94 | posterior = self.encode(input)
95 | if sample_posterior:
96 | z = posterior.sample()
97 | else:
98 | z = posterior.mode()
99 | dec = self.decode(z)
100 | return dec, posterior
101 |
102 | def get_input(self, batch, k):
103 | x = batch[k]
104 | if len(x.shape) == 3:
105 | x = x[..., None]
106 | x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
107 | return x
108 |
109 | def training_step(self, batch, batch_idx, optimizer_idx):
110 | inputs = self.get_input(batch, self.image_key)
111 | reconstructions, posterior = self(inputs)
112 |
113 | if optimizer_idx == 0:
114 | # train encoder+decoder+logvar
115 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
116 | last_layer=self.get_last_layer(), split="train")
117 | self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
118 | self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
119 | return aeloss
120 |
121 | if optimizer_idx == 1:
122 | # train the discriminator
123 | discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
124 | last_layer=self.get_last_layer(), split="train")
125 |
126 | self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
127 | self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
128 | return discloss
129 |
130 | def validation_step(self, batch, batch_idx):
131 | log_dict = self._validation_step(batch, batch_idx)
132 | with self.ema_scope():
133 | log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
134 | return log_dict
135 |
136 | def _validation_step(self, batch, batch_idx, postfix=""):
137 | inputs = self.get_input(batch, self.image_key)
138 | reconstructions, posterior = self(inputs)
139 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
140 | last_layer=self.get_last_layer(), split="val"+postfix)
141 |
142 | discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
143 | last_layer=self.get_last_layer(), split="val"+postfix)
144 |
145 | self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"])
146 | self.log_dict(log_dict_ae)
147 | self.log_dict(log_dict_disc)
148 |         return log_dict_ae
149 |
150 | def configure_optimizers(self):
151 | lr = self.learning_rate
152 | ae_params_list = list(self.encoder.parameters()) + list(self.decoder.parameters()) + list(
153 | self.quant_conv.parameters()) + list(self.post_quant_conv.parameters())
154 | if self.learn_logvar:
155 | print(f"{self.__class__.__name__}: Learning logvar")
156 | ae_params_list.append(self.loss.logvar)
157 | opt_ae = torch.optim.Adam(ae_params_list,
158 | lr=lr, betas=(0.5, 0.9))
159 | opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
160 | lr=lr, betas=(0.5, 0.9))
161 | return [opt_ae, opt_disc], []
162 |
163 | def get_last_layer(self):
164 | return self.decoder.conv_out.weight
165 |
166 | @torch.no_grad()
167 | def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs):
168 | log = dict()
169 | x = self.get_input(batch, self.image_key)
170 | x = x.to(self.device)
171 | if not only_inputs:
172 | xrec, posterior = self(x)
173 | if x.shape[1] > 3:
174 | # colorize with random projection
175 | assert xrec.shape[1] > 3
176 | x = self.to_rgb(x)
177 | xrec = self.to_rgb(xrec)
178 | log["samples"] = self.decode(torch.randn_like(posterior.sample()))
179 | log["reconstructions"] = xrec
180 | if log_ema or self.use_ema:
181 | with self.ema_scope():
182 | xrec_ema, posterior_ema = self(x)
183 | if x.shape[1] > 3:
184 | # colorize with random projection
185 | assert xrec_ema.shape[1] > 3
186 | xrec_ema = self.to_rgb(xrec_ema)
187 | log["samples_ema"] = self.decode(torch.randn_like(posterior_ema.sample()))
188 | log["reconstructions_ema"] = xrec_ema
189 | log["inputs"] = x
190 | return log
191 |
192 | def to_rgb(self, x):
193 | assert self.image_key == "segmentation"
194 | if not hasattr(self, "colorize"):
195 | self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
196 | x = F.conv2d(x, weight=self.colorize)
197 | x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
198 | return x
199 |
200 |
201 | class IdentityFirstStage(torch.nn.Module):
202 | def __init__(self, *args, vq_interface=False, **kwargs):
203 | self.vq_interface = vq_interface
204 | super().__init__()
205 |
206 | def encode(self, x, *args, **kwargs):
207 | return x
208 |
209 | def decode(self, x, *args, **kwargs):
210 | return x
211 |
212 | def quantize(self, x, *args, **kwargs):
213 | if self.vq_interface:
214 | return x, None, [None, None, None]
215 | return x
216 |
217 | def forward(self, x, *args, **kwargs):
218 | return x
219 |
220 |
--------------------------------------------------------------------------------
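In AutoencoderKL.encode above, quant_conv produces 2 * embed_dim channels of moments which DiagonalGaussianDistribution (defined in ldm/modules/distributions/distributions.py) splits into a mean and a log-variance; sampling then uses the standard reparameterization trick. A minimal standalone sketch of that step (illustrative only, not the repository's implementation):

    import torch

    def sample_diagonal_gaussian(moments):
        # moments: (b, 2 * embed_dim, h, w) -> split into mean and log-variance
        mean, logvar = torch.chunk(moments, 2, dim=1)
        logvar = torch.clamp(logvar, -30.0, 20.0)   # clamp for numerical stability
        std = torch.exp(0.5 * logvar)
        return mean + std * torch.randn_like(mean)  # reparameterized sample

    moments = torch.randn(1, 2 * 4, 32, 32)         # e.g. embed_dim = 4
    z = sample_diagonal_gaussian(moments)
    print(z.shape)                                  # torch.Size([1, 4, 32, 32])
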
/ldm/models/diffusion/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cszy98/PLACE/4594350491df5c7489eb4edf0802f303692dd530/ldm/models/diffusion/__init__.py
--------------------------------------------------------------------------------
/ldm/models/diffusion/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cszy98/PLACE/4594350491df5c7489eb4edf0802f303692dd530/ldm/models/diffusion/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/models/diffusion/__pycache__/ddim.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cszy98/PLACE/4594350491df5c7489eb4edf0802f303692dd530/ldm/models/diffusion/__pycache__/ddim.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/models/diffusion/__pycache__/ddpm.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cszy98/PLACE/4594350491df5c7489eb4edf0802f303692dd530/ldm/models/diffusion/__pycache__/ddpm.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/models/diffusion/__pycache__/myplms.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cszy98/PLACE/4594350491df5c7489eb4edf0802f303692dd530/ldm/models/diffusion/__pycache__/myplms.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/models/diffusion/__pycache__/plms.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cszy98/PLACE/4594350491df5c7489eb4edf0802f303692dd530/ldm/models/diffusion/__pycache__/plms.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/models/diffusion/__pycache__/sampling_util.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cszy98/PLACE/4594350491df5c7489eb4edf0802f303692dd530/ldm/models/diffusion/__pycache__/sampling_util.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/models/diffusion/__pycache__/testplms.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cszy98/PLACE/4594350491df5c7489eb4edf0802f303692dd530/ldm/models/diffusion/__pycache__/testplms.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/models/diffusion/dpm_solver/__init__.py:
--------------------------------------------------------------------------------
1 | from .sampler import DPMSolverSampler
--------------------------------------------------------------------------------
/ldm/models/diffusion/dpm_solver/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cszy98/PLACE/4594350491df5c7489eb4edf0802f303692dd530/ldm/models/diffusion/dpm_solver/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/models/diffusion/dpm_solver/__pycache__/dpm_solver.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cszy98/PLACE/4594350491df5c7489eb4edf0802f303692dd530/ldm/models/diffusion/dpm_solver/__pycache__/dpm_solver.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/models/diffusion/dpm_solver/__pycache__/sampler.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cszy98/PLACE/4594350491df5c7489eb4edf0802f303692dd530/ldm/models/diffusion/dpm_solver/__pycache__/sampler.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/models/diffusion/dpm_solver/sampler.py:
--------------------------------------------------------------------------------
1 | """SAMPLING ONLY."""
2 | import torch
3 |
4 | from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver
5 |
6 |
7 | MODEL_TYPES = {
8 | "eps": "noise",
9 | "v": "v"
10 | }
11 |
12 |
13 | class DPMSolverSampler(object):
14 | def __init__(self, model, **kwargs):
15 | super().__init__()
16 | self.model = model
17 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device)
18 | self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod))
19 |
20 | def register_buffer(self, name, attr):
21 | if type(attr) == torch.Tensor:
22 | if attr.device != torch.device("cuda"):
23 | attr = attr.to(torch.device("cuda"))
24 | setattr(self, name, attr)
25 |
26 | @torch.no_grad()
27 | def sample(self,
28 | S,
29 | batch_size,
30 | shape,
31 | conditioning=None,
32 | callback=None,
33 | normals_sequence=None,
34 | img_callback=None,
35 | quantize_x0=False,
36 | eta=0.,
37 | mask=None,
38 | x0=None,
39 | temperature=1.,
40 | noise_dropout=0.,
41 | score_corrector=None,
42 | corrector_kwargs=None,
43 | verbose=True,
44 | x_T=None,
45 | log_every_t=100,
46 | unconditional_guidance_scale=1.,
47 | unconditional_conditioning=None,
48 | # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
49 | **kwargs
50 | ):
51 | if conditioning is not None:
52 | if isinstance(conditioning, dict):
53 | cbs = conditioning[list(conditioning.keys())[0]].shape[0]
54 | if cbs != batch_size:
55 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
56 | else:
57 | if conditioning.shape[0] != batch_size:
58 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
59 |
60 | # sampling
61 | C, H, W = shape
62 | size = (batch_size, C, H, W)
63 |
64 | print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}')
65 |
66 | device = self.model.betas.device
67 | if x_T is None:
68 | img = torch.randn(size, device=device)
69 | else:
70 | img = x_T
71 |
72 | ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod)
73 |
74 | model_fn = model_wrapper(
75 | lambda x, t, c: self.model.apply_model(x, t, c),
76 | ns,
77 | model_type=MODEL_TYPES[self.model.parameterization],
78 | guidance_type="classifier-free",
79 | condition=conditioning,
80 | unconditional_condition=unconditional_conditioning,
81 | guidance_scale=unconditional_guidance_scale,
82 | )
83 |
84 | dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False)
85 | x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True)
86 |
87 | return x.to(device), None
--------------------------------------------------------------------------------
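The DPMSolverSampler above wraps the diffusion model with guidance_type="classifier-free", so each solver step evaluates the model on conditional and unconditional inputs and blends the two noise predictions with the guidance scale. A standalone sketch of that blending rule using a toy stand-in for model.apply_model (the denoise function below is an assumption for illustration, not the real model):

    import torch

    def denoise(x, t, c):
        # toy stand-in for model.apply_model(x, t, c): returns a fake noise prediction
        return 0.1 * x + 0.01 * c.mean()

    def guided_eps(x, t, cond, uncond, scale):
        # classifier-free guidance: e = e_uncond + scale * (e_cond - e_uncond)
        e_cond = denoise(x, t, cond)
        e_uncond = denoise(x, t, uncond)
        return e_uncond + scale * (e_cond - e_uncond)

    x = torch.randn(1, 4, 64, 64)
    t = torch.full((1,), 500, dtype=torch.long)
    cond, uncond = torch.randn(1, 77, 768), torch.zeros(1, 77, 768)
    print(guided_eps(x, t, cond, uncond, scale=2.0).shape)
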
/ldm/models/diffusion/myplms.py:
--------------------------------------------------------------------------------
1 | """SAMPLING ONLY."""
2 |
3 | import torch
4 | import numpy as np
5 | from tqdm import tqdm
6 | from functools import partial
7 |
8 | from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
9 | from ldm.models.diffusion.sampling_util import norm_thresholding
10 |
11 |
12 | class MYPLMSSampler(object):
13 | def __init__(self, model, schedule="linear", **kwargs):
14 | super().__init__()
15 | self.model = model
16 | self.ddpm_num_timesteps = model.num_timesteps
17 | self.schedule = schedule
18 |
19 | def register_buffer(self, name, attr):
20 | if type(attr) == torch.Tensor:
21 | if attr.device != torch.device("cuda"):
22 | attr = attr.to(torch.device("cuda"))
23 | setattr(self, name, attr)
24 |
25 | def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
26 | if ddim_eta != 0:
27 | raise ValueError('ddim_eta must be 0 for PLMS')
28 | self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
29 | num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
30 | alphas_cumprod = self.model.alphas_cumprod
31 | assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
32 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
33 |
34 | self.register_buffer('betas', to_torch(self.model.betas))
35 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
36 | self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
37 |
38 | # calculations for diffusion q(x_t | x_{t-1}) and others
39 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
40 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
41 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
42 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
43 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
44 |
45 | # ddim sampling parameters
46 | ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
47 | ddim_timesteps=self.ddim_timesteps,
48 | eta=ddim_eta,verbose=verbose)
49 | self.register_buffer('ddim_sigmas', ddim_sigmas)
50 | self.register_buffer('ddim_alphas', ddim_alphas)
51 | self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
52 | self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
53 | sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
54 | (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
55 | 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
56 | self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
57 |
58 | @torch.no_grad()
59 | def sample(self,
60 | S,
61 | batch_size,
62 | shape,
63 | conditioning=None,
64 | callback=None,
65 | normals_sequence=None,
66 | img_callback=None,
67 | quantize_x0=False,
68 | eta=0.,
69 | mask=None,
70 | x0=None,
71 | temperature=1.,
72 | noise_dropout=0.,
73 | score_corrector=None,
74 | corrector_kwargs=None,
75 | verbose=True,
76 | x_T=None,
77 | log_every_t=100,
78 | unconditional_guidance_scale=1.,
79 | unconditional_conditioning=None,
80 | # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
81 | dynamic_threshold=None,
82 | **kwargs
83 | ):
84 | '''
85 | if conditioning is not None:
86 | if isinstance(conditioning, dict):
87 | cbs = conditioning[list(conditioning.keys())[0]].shape[0]
88 | if cbs != batch_size:
89 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
90 | else:
91 | if conditioning.shape[0] != batch_size:
92 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
93 | '''
94 | if conditioning is not None:
95 | if isinstance(conditioning, dict):
96 | ctmp = conditioning[list(conditioning.keys())[0]]
97 | while isinstance(ctmp, list): ctmp = ctmp[0]
98 | cbs = ctmp.shape[0]
99 | if cbs != batch_size:
100 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
101 |
102 | elif isinstance(conditioning, list):
103 | for ctmp in conditioning:
104 | if ctmp.shape[0] != batch_size:
105 |                         print(f"Warning: Got {ctmp.shape[0]} conditionings but batch-size is {batch_size}")
106 |
107 | else:
108 | if conditioning.shape[0] != batch_size:
109 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
110 |
111 |
112 | self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
113 | # sampling
114 | C, H, W = shape
115 | size = (batch_size, C, H, W)
116 | print(f'Data shape for PLMS sampling is {size}')
117 |
118 | samples, intermediates = self.plms_sampling(conditioning, size,
119 | callback=callback,
120 | img_callback=img_callback,
121 | quantize_denoised=quantize_x0,
122 | mask=mask, x0=x0,
123 | ddim_use_original_steps=False,
124 | noise_dropout=noise_dropout,
125 | temperature=temperature,
126 | score_corrector=score_corrector,
127 | corrector_kwargs=corrector_kwargs,
128 | x_T=x_T,
129 | log_every_t=log_every_t,
130 | unconditional_guidance_scale=unconditional_guidance_scale,
131 | unconditional_conditioning=unconditional_conditioning,
132 | dynamic_threshold=dynamic_threshold,
133 | )
134 | return samples, intermediates
135 |
136 | @torch.no_grad()
137 | def plms_sampling(self, cond, shape,
138 | x_T=None, ddim_use_original_steps=False,
139 | callback=None, timesteps=None, quantize_denoised=False,
140 | mask=None, x0=None, img_callback=None, log_every_t=100,
141 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
142 | unconditional_guidance_scale=1., unconditional_conditioning=None,
143 | dynamic_threshold=None):
144 | device = self.model.betas.device
145 | b = shape[0]
146 | if x_T is None:
147 | img = torch.randn(shape, device=device)
148 |             # re-draw the initial latent from a truncated normal clipped to [-3, 3]
149 |             img = torch.nn.init.trunc_normal_(img, mean=0.0, std=1.0, a=-3.0, b=3.0)
150 | else:
151 | img = x_T
152 |
153 | if timesteps is None:
154 | timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
155 | elif timesteps is not None and not ddim_use_original_steps:
156 | subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
157 | timesteps = self.ddim_timesteps[:subset_end]
158 |
159 | intermediates = {'x_inter': [img], 'pred_x0': [img]}
160 | time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps)
161 | total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
162 | print(f"Running PLMS Sampling with {total_steps} timesteps")
163 |
164 | iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
165 | old_eps = []
166 |
167 | for i, step in enumerate(iterator):
168 | index = total_steps - i - 1
169 | ts = torch.full((b,), step, device=device, dtype=torch.long)
170 | ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)
171 |
172 | if mask is not None:
173 | assert x0 is not None
174 | img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
175 | img = img_orig * mask + (1. - mask) * img
176 |
177 | outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
178 | quantize_denoised=quantize_denoised, temperature=temperature,
179 | noise_dropout=noise_dropout, score_corrector=score_corrector,
180 | corrector_kwargs=corrector_kwargs,
181 | unconditional_guidance_scale=unconditional_guidance_scale,
182 | unconditional_conditioning=unconditional_conditioning,
183 | old_eps=old_eps, t_next=ts_next,
184 | dynamic_threshold=dynamic_threshold)
185 | img, pred_x0, e_t = outs
186 | old_eps.append(e_t)
187 | if len(old_eps) >= 4:
188 | old_eps.pop(0)
189 | if callback: callback(i)
190 | if img_callback: img_callback(pred_x0, i)
191 |
192 | if index % log_every_t == 0 or index == total_steps - 1:
193 | intermediates['x_inter'].append(img)
194 | intermediates['pred_x0'].append(pred_x0)
195 |
196 | return img, intermediates
197 |
198 | @torch.no_grad()
199 | def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
200 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
201 | unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None,
202 | dynamic_threshold=None):
203 | b, *_, device = *x.shape, x.device
204 |
205 | def get_model_output(x, t):
206 | if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
207 | e_t = self.model.apply_model(x, t, c)
208 | else:
209 | x_in = torch.cat([x] * 2)
210 | t_in = torch.cat([t] * 2)
211 | #c_in = torch.cat([unconditional_conditioning, c])
212 |
213 | if isinstance(c, dict):
214 | assert isinstance(unconditional_conditioning, dict)
215 | c_in = dict()
216 | for k in c:
217 | if isinstance(c[k], list):
218 | c_in[k] = [torch.cat([
219 | unconditional_conditioning[k][i],
220 | c[k][i]]) for i in range(len(c[k]))]
221 | else:
222 | c_in[k] = torch.cat([
223 | unconditional_conditioning[k],
224 | c[k]])
225 | elif isinstance(c, list):
226 | c_in = list()
227 | assert isinstance(unconditional_conditioning, list)
228 | for i in range(len(c)):
229 | c_in.append(torch.cat([unconditional_conditioning[i], c[i]]))
230 | else:
231 | c_in = torch.cat([unconditional_conditioning, c])
232 |
233 | e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
234 | e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
235 | if score_corrector is not None:
236 | assert self.model.parameterization == "eps"
237 | e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
238 |
239 | return e_t
240 |
241 | alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
242 | alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
243 | sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
244 | sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
245 |
246 | def get_x_prev_and_pred_x0(e_t, index):
247 | # select parameters corresponding to the currently considered timestep
248 | a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
249 | a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
250 | sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
251 | sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
252 |
253 | # current prediction for x_0
254 | pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
255 | if quantize_denoised:
256 | pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
257 | if dynamic_threshold is not None:
258 | pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
259 | # direction pointing to x_t
260 | dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
261 | noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
262 | if noise_dropout > 0.:
263 | noise = torch.nn.functional.dropout(noise, p=noise_dropout)
264 | x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
265 | return x_prev, pred_x0
266 |
267 | e_t = get_model_output(x, t)
268 | if len(old_eps) == 0:
269 | # Pseudo Improved Euler (2nd order)
270 | x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
271 | e_t_next = get_model_output(x_prev, t_next)
272 | e_t_prime = (e_t + e_t_next) / 2
273 | elif len(old_eps) == 1:
274 | # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
275 | e_t_prime = (3 * e_t - old_eps[-1]) / 2
276 | elif len(old_eps) == 2:
277 |             # 3rd order Pseudo Linear Multistep (Adams-Bashforth)
278 | e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
279 | elif len(old_eps) >= 3:
280 |             # 4th order Pseudo Linear Multistep (Adams-Bashforth)
281 | e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
282 |
283 | x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
284 |
285 | return x_prev, pred_x0, e_t
286 |
--------------------------------------------------------------------------------
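The main difference between MYPLMSSampler above and the stock PLMSSampler in plms.py below is the initialization in plms_sampling: when no x_T is supplied, the starting latent is drawn from a truncated normal clipped to [-3, 3] rather than a plain standard normal. A standalone sketch of that initialization:

    import torch

    shape = (1, 4, 64, 64)
    img = torch.randn(shape)
    # re-draw from a truncated normal, clipping samples to [-3, 3]
    img = torch.nn.init.trunc_normal_(img, mean=0.0, std=1.0, a=-3.0, b=3.0)
    print(img.min().item() >= -3.0, img.max().item() <= 3.0)   # True True
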
/ldm/models/diffusion/plms.py:
--------------------------------------------------------------------------------
1 | """SAMPLING ONLY."""
2 |
3 | import torch
4 | import numpy as np
5 | from tqdm import tqdm
6 | from functools import partial
7 |
8 | from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
9 | from ldm.models.diffusion.sampling_util import norm_thresholding
10 |
11 |
12 | class PLMSSampler(object):
13 | def __init__(self, model, schedule="linear", **kwargs):
14 | super().__init__()
15 | self.model = model
16 | self.ddpm_num_timesteps = model.num_timesteps
17 | self.schedule = schedule
18 |
19 | def register_buffer(self, name, attr):
20 | if type(attr) == torch.Tensor:
21 | if attr.device != torch.device("cuda"):
22 | attr = attr.to(torch.device("cuda"))
23 | setattr(self, name, attr)
24 |
25 | def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
26 | if ddim_eta != 0:
27 | raise ValueError('ddim_eta must be 0 for PLMS')
28 | self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
29 | num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
30 | alphas_cumprod = self.model.alphas_cumprod
31 | assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
32 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
33 |
34 | self.register_buffer('betas', to_torch(self.model.betas))
35 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
36 | self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
37 |
38 | # calculations for diffusion q(x_t | x_{t-1}) and others
39 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
40 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
41 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
42 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
43 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
44 |
45 | # ddim sampling parameters
46 | ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
47 | ddim_timesteps=self.ddim_timesteps,
48 | eta=ddim_eta,verbose=verbose)
49 | self.register_buffer('ddim_sigmas', ddim_sigmas)
50 | self.register_buffer('ddim_alphas', ddim_alphas)
51 | self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
52 | self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
53 | sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
54 | (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
55 | 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
56 | self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
57 |
58 | @torch.no_grad()
59 | def sample(self,
60 | S,
61 | batch_size,
62 | shape,
63 | conditioning=None,
64 | callback=None,
65 | normals_sequence=None,
66 | img_callback=None,
67 | quantize_x0=False,
68 | eta=0.,
69 | mask=None,
70 | x0=None,
71 | temperature=1.,
72 | noise_dropout=0.,
73 | score_corrector=None,
74 | corrector_kwargs=None,
75 | verbose=True,
76 | x_T=None,
77 | log_every_t=100,
78 | unconditional_guidance_scale=1.,
79 | unconditional_conditioning=None,
80 | # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
81 | dynamic_threshold=None,
82 | **kwargs
83 | ):
84 | '''
85 | if conditioning is not None:
86 | if isinstance(conditioning, dict):
87 | cbs = conditioning[list(conditioning.keys())[0]].shape[0]
88 | if cbs != batch_size:
89 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
90 | else:
91 | if conditioning.shape[0] != batch_size:
92 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
93 | '''
94 | if conditioning is not None:
95 | if isinstance(conditioning, dict):
96 | ctmp = conditioning[list(conditioning.keys())[0]]
97 | while isinstance(ctmp, list): ctmp = ctmp[0]
98 | cbs = ctmp.shape[0]
99 | if cbs != batch_size:
100 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
101 |
102 | elif isinstance(conditioning, list):
103 | for ctmp in conditioning:
104 | if ctmp.shape[0] != batch_size:
105 |                         print(f"Warning: Got {ctmp.shape[0]} conditionings but batch-size is {batch_size}")
106 |
107 | else:
108 | if conditioning.shape[0] != batch_size:
109 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
110 |
111 |
112 | self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
113 | # sampling
114 | C, H, W = shape
115 | size = (batch_size, C, H, W)
116 | print(f'Data shape for PLMS sampling is {size}')
117 |
118 | samples, intermediates = self.plms_sampling(conditioning, size,
119 | callback=callback,
120 | img_callback=img_callback,
121 | quantize_denoised=quantize_x0,
122 | mask=mask, x0=x0,
123 | ddim_use_original_steps=False,
124 | noise_dropout=noise_dropout,
125 | temperature=temperature,
126 | score_corrector=score_corrector,
127 | corrector_kwargs=corrector_kwargs,
128 | x_T=x_T,
129 | log_every_t=log_every_t,
130 | unconditional_guidance_scale=unconditional_guidance_scale,
131 | unconditional_conditioning=unconditional_conditioning,
132 | dynamic_threshold=dynamic_threshold,
133 | )
134 | return samples, intermediates
135 |
136 | @torch.no_grad()
137 | def plms_sampling(self, cond, shape,
138 | x_T=None, ddim_use_original_steps=False,
139 | callback=None, timesteps=None, quantize_denoised=False,
140 | mask=None, x0=None, img_callback=None, log_every_t=100,
141 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
142 | unconditional_guidance_scale=1., unconditional_conditioning=None,
143 | dynamic_threshold=None):
144 | device = self.model.betas.device
145 | b = shape[0]
146 | if x_T is None:
147 | img = torch.randn(shape, device=device)
148 | else:
149 | img = x_T
150 |
151 | if timesteps is None:
152 | timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
153 | elif timesteps is not None and not ddim_use_original_steps:
154 | subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
155 | timesteps = self.ddim_timesteps[:subset_end]
156 |
157 | intermediates = {'x_inter': [img], 'pred_x0': [img]}
158 | time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps)
159 | total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
160 | print(f"Running PLMS Sampling with {total_steps} timesteps")
161 |
162 | iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
163 | old_eps = []
164 |
165 | for i, step in enumerate(iterator):
166 | index = total_steps - i - 1
167 | ts = torch.full((b,), step, device=device, dtype=torch.long)
168 | ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)
169 |
170 | if mask is not None:
171 | assert x0 is not None
172 | img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
173 | img = img_orig * mask + (1. - mask) * img
174 |
175 | outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
176 | quantize_denoised=quantize_denoised, temperature=temperature,
177 | noise_dropout=noise_dropout, score_corrector=score_corrector,
178 | corrector_kwargs=corrector_kwargs,
179 | unconditional_guidance_scale=unconditional_guidance_scale,
180 | unconditional_conditioning=unconditional_conditioning,
181 | old_eps=old_eps, t_next=ts_next,
182 | dynamic_threshold=dynamic_threshold)
183 | img, pred_x0, e_t = outs
184 | old_eps.append(e_t)
185 | if len(old_eps) >= 4:
186 | old_eps.pop(0)
187 | if callback: callback(i)
188 | if img_callback: img_callback(pred_x0, i)
189 |
190 | if index % log_every_t == 0 or index == total_steps - 1:
191 | intermediates['x_inter'].append(img)
192 | intermediates['pred_x0'].append(pred_x0)
193 |
194 | return img, intermediates
195 |
196 | @torch.no_grad()
197 | def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
198 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
199 | unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None,
200 | dynamic_threshold=None):
201 | b, *_, device = *x.shape, x.device
202 |
203 | def get_model_output(x, t):
204 | if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
205 | e_t = self.model.apply_model(x, t, c)
206 | else:
207 | x_in = torch.cat([x] * 2)
208 | t_in = torch.cat([t] * 2)
209 | #c_in = torch.cat([unconditional_conditioning, c])
210 |
211 | if isinstance(c, dict):
212 | assert isinstance(unconditional_conditioning, dict)
213 | c_in = dict()
214 | for k in c:
215 | if isinstance(c[k], list):
216 | c_in[k] = [torch.cat([
217 | unconditional_conditioning[k][i],
218 | c[k][i]]) for i in range(len(c[k]))]
219 | else:
220 | c_in[k] = torch.cat([
221 | unconditional_conditioning[k],
222 | c[k]])
223 | elif isinstance(c, list):
224 | c_in = list()
225 | assert isinstance(unconditional_conditioning, list)
226 | for i in range(len(c)):
227 | c_in.append(torch.cat([unconditional_conditioning[i], c[i]]))
228 | else:
229 | c_in = torch.cat([unconditional_conditioning, c])
230 |
231 | e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
232 | e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
233 | if score_corrector is not None:
234 | assert self.model.parameterization == "eps"
235 | e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
236 |
237 | return e_t
238 |
239 | alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
240 | alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
241 | sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
242 | sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
243 |
244 | def get_x_prev_and_pred_x0(e_t, index):
245 | # select parameters corresponding to the currently considered timestep
246 | a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
247 | a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
248 | sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
249 | sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
250 |
251 | # current prediction for x_0
252 | pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
253 | if quantize_denoised:
254 | pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
255 | if dynamic_threshold is not None:
256 | pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
257 | # direction pointing to x_t
258 | dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
259 | noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
260 | if noise_dropout > 0.:
261 | noise = torch.nn.functional.dropout(noise, p=noise_dropout)
262 | x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
263 | return x_prev, pred_x0
264 |
265 | e_t = get_model_output(x, t)
266 | if len(old_eps) == 0:
267 | # Pseudo Improved Euler (2nd order)
268 | x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
269 | e_t_next = get_model_output(x_prev, t_next)
270 | e_t_prime = (e_t + e_t_next) / 2
271 | elif len(old_eps) == 1:
272 | # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
273 | e_t_prime = (3 * e_t - old_eps[-1]) / 2
274 | elif len(old_eps) == 2:
275 |             # 3rd order Pseudo Linear Multistep (Adams-Bashforth)
276 | e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
277 | elif len(old_eps) >= 3:
278 |             # 4th order Pseudo Linear Multistep (Adams-Bashforth)
279 | e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
280 |
281 | x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
282 |
283 |
284 | '''
285 | print(x_prev.shape,pred_x0.shape)
286 | xp = self.model.decode_first_stage(x_prev)
287 | import einops
288 | from PIL import Image
289 | xp = (einops.rearrange(xp, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)[0]
290 | # #print(xp.shape,'@@@@@@@@@@@@@@@')
291 | # #print(t,'############')
292 | Image.fromarray(xp).save('drawview_good225/xp_'+str(float(t[0]))+'.png')
293 | px0 = self.model.decode_first_stage(pred_x0)
294 | px0 = (einops.rearrange(px0, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)[0]
295 | Image.fromarray(px0).save('drawview_good225/px0_'+str(float(t[0]))+'.png')
296 | '''
297 |
298 | return x_prev, pred_x0, e_t
299 |
--------------------------------------------------------------------------------
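p_sample_plms above ramps up the multistep order as the epsilon history fills: a pseudo improved-Euler step first, then 2nd-, 3rd-, and 4th-order Adams-Bashforth combinations of the stored noise predictions. A standalone sketch of just that coefficient logic with toy tensors (no model required):

    import torch

    def plms_eps(e_t, old_eps):
        # combine current and previous noise predictions (Adams-Bashforth coefficients)
        if len(old_eps) >= 3:
            return (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
        if len(old_eps) == 2:
            return (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
        if len(old_eps) == 1:
            return (3 * e_t - old_eps[-1]) / 2
        return e_t  # first step is handled by the improved-Euler branch instead

    history = [torch.randn(1, 4, 64, 64) for _ in range(3)]
    e_t = torch.randn(1, 4, 64, 64)
    print(plms_eps(e_t, history).shape)
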
/ldm/models/diffusion/sampling_util.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | def append_dims(x, target_dims):
6 | """Appends dimensions to the end of a tensor until it has target_dims dimensions.
7 | From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py"""
8 | dims_to_append = target_dims - x.ndim
9 | if dims_to_append < 0:
10 | raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
11 | return x[(...,) + (None,) * dims_to_append]
12 |
13 |
14 | def norm_thresholding(x0, value):
15 | s = append_dims(x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value), x0.ndim)
16 | return x0 * (value / s)
17 |
18 |
19 | def spatial_norm_thresholding(x0, value):
20 | # b c h w
21 | s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value)
22 | return x0 * (value / s)
--------------------------------------------------------------------------------
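A short usage sketch for the thresholding helpers above (assumes the ldm package from this repository is importable):

    import torch
    from ldm.models.diffusion.sampling_util import norm_thresholding, spatial_norm_thresholding

    x0 = 5.0 * torch.randn(2, 4, 64, 64)               # deliberately over-scaled prediction
    x0_per_sample = norm_thresholding(x0, 1.0)         # caps the per-sample RMS norm at 1.0
    x0_per_pixel = spatial_norm_thresholding(x0, 1.0)  # caps the per-position channel RMS at 1.0
    print(x0_per_sample.pow(2).flatten(1).mean(1).sqrt())  # ~1.0 for each sample
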
/ldm/models/diffusion/testplms.py:
--------------------------------------------------------------------------------
1 | """SAMPLING ONLY."""
2 |
3 | import torch
4 | import numpy as np
5 | from tqdm import tqdm
6 | from functools import partial
7 |
8 | from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
9 | from ldm.models.diffusion.sampling_util import norm_thresholding
10 |
11 |
12 | class TestPLMSSampler(object):
13 | def __init__(self, model, schedule="linear", **kwargs):
14 | super().__init__()
15 | self.model = model
16 | self.ddpm_num_timesteps = model.num_timesteps
17 | self.schedule = schedule
18 |
19 | def register_buffer(self, name, attr):
20 | if type(attr) == torch.Tensor:
21 | if attr.device != torch.device("cuda"):
22 | attr = attr.to(torch.device("cuda"))
23 | setattr(self, name, attr)
24 |
25 | def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
26 | if ddim_eta != 0:
27 | raise ValueError('ddim_eta must be 0 for PLMS')
28 | self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
29 | num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
30 | alphas_cumprod = self.model.alphas_cumprod
31 | assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
32 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
33 |
34 | self.register_buffer('betas', to_torch(self.model.betas))
35 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
36 | self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
37 |
38 | # calculations for diffusion q(x_t | x_{t-1}) and others
39 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
40 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
41 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
42 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
43 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
44 |
45 | # ddim sampling parameters
46 | ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
47 | ddim_timesteps=self.ddim_timesteps,
48 | eta=ddim_eta,verbose=verbose)
49 | self.register_buffer('ddim_sigmas', ddim_sigmas)
50 | self.register_buffer('ddim_alphas', ddim_alphas)
51 | self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
52 | self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
53 | sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
54 | (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
55 | 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
56 | self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
57 |
58 | @torch.no_grad()
59 | def sample(self,
60 | S,
61 | batch_size,
62 | shape,
63 | names=None,
64 | conditioning=None,
65 | callback=None,
66 | normals_sequence=None,
67 | img_callback=None,
68 | quantize_x0=False,
69 | eta=0.,
70 | mask=None,
71 | x0=None,
72 | temperature=1.,
73 | noise_dropout=0.,
74 | score_corrector=None,
75 | corrector_kwargs=None,
76 | verbose=True,
77 | x_T=None,
78 | log_every_t=100,
79 | unconditional_guidance_scale=1.,
80 | unconditional_conditioning=None,
81 | # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
82 | dynamic_threshold=None,
83 | **kwargs
84 | ):
85 | '''
86 | if conditioning is not None:
87 | if isinstance(conditioning, dict):
88 | cbs = conditioning[list(conditioning.keys())[0]].shape[0]
89 | if cbs != batch_size:
90 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
91 | else:
92 | if conditioning.shape[0] != batch_size:
93 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
94 | '''
95 | if conditioning is not None:
96 | if isinstance(conditioning, dict):
97 | ctmp = conditioning[list(conditioning.keys())[0]]
98 | while isinstance(ctmp, list): ctmp = ctmp[0]
99 | cbs = ctmp.shape[0]
100 | if cbs != batch_size:
101 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
102 |
103 | elif isinstance(conditioning, list):
104 | for ctmp in conditioning:
105 | if ctmp.shape[0] != batch_size:
106 |                         print(f"Warning: Got {ctmp.shape[0]} conditionings but batch-size is {batch_size}")
107 |
108 | else:
109 | if conditioning.shape[0] != batch_size:
110 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
111 |
112 |
113 | self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
114 | # sampling
115 | C, H, W = shape
116 | size = (batch_size, C, H, W)
117 | print(f'Data shape for PLMS sampling is {size}')
118 |
119 | samples, intermediates = self.plms_sampling(conditioning, size,
120 | names=names,
121 | callback=callback,
122 | img_callback=img_callback,
123 | quantize_denoised=quantize_x0,
124 | mask=mask, x0=x0,
125 | ddim_use_original_steps=False,
126 | noise_dropout=noise_dropout,
127 | temperature=temperature,
128 | score_corrector=score_corrector,
129 | corrector_kwargs=corrector_kwargs,
130 | x_T=x_T,
131 | log_every_t=log_every_t,
132 | unconditional_guidance_scale=unconditional_guidance_scale,
133 | unconditional_conditioning=unconditional_conditioning,
134 | dynamic_threshold=dynamic_threshold,
135 | )
136 | return samples, intermediates
137 |
138 | @torch.no_grad()
139 | def plms_sampling(self, cond, shape, names=None,
140 | x_T=None, ddim_use_original_steps=False,
141 | callback=None, timesteps=None, quantize_denoised=False,
142 | mask=None, x0=None, img_callback=None, log_every_t=100,
143 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
144 | unconditional_guidance_scale=1., unconditional_conditioning=None,
145 | dynamic_threshold=None):
146 | device = self.model.betas.device
147 | b = shape[0]
148 | #if names!='ADE_val_00000027.png':
149 | # return None, None
150 |
151 | if x_T is None:
152 | img = torch.randn(shape, device=device)
153 | else:
154 | img = x_T
155 |
156 |         if names!='ADE_val_00000027.png':  # debug filter: only sample this specific validation image
157 | return None, None
158 |
159 | if timesteps is None:
160 | timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
161 | elif timesteps is not None and not ddim_use_original_steps:
162 | subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
163 | timesteps = self.ddim_timesteps[:subset_end]
164 |
165 | intermediates = {'x_inter': [img], 'pred_x0': [img]}
166 | time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps)
167 | total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
168 | print(f"Running PLMS Sampling with {total_steps} timesteps")
169 |
170 | iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
171 | old_eps = []
172 |
173 | for i, step in enumerate(iterator):
174 | index = total_steps - i - 1
175 | ts = torch.full((b,), step, device=device, dtype=torch.long)
176 | ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)
177 |
178 | if mask is not None:
179 | assert x0 is not None
180 | img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
181 | img = img_orig * mask + (1. - mask) * img
182 |
183 | outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
184 | quantize_denoised=quantize_denoised, temperature=temperature,
185 | noise_dropout=noise_dropout, score_corrector=score_corrector,
186 | corrector_kwargs=corrector_kwargs,
187 | unconditional_guidance_scale=unconditional_guidance_scale,
188 | unconditional_conditioning=unconditional_conditioning,
189 | old_eps=old_eps, t_next=ts_next,
190 | dynamic_threshold=dynamic_threshold)
191 |             print('pred_x0 shape and current timestep:', outs[1].shape, ts)
192 | #views_predx = outs[1].permute((0,2,3,1))[0].detach
193 | views_predx = self.model.decode_first_stage(outs[1])
194 | import einops
195 | from PIL import Image
196 | x_pred0 = (einops.rearrange(views_predx, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
197 | Image.fromarray(x_pred0[0]).save('viewx0/'+str(int(ts[0]))+'_'+names)
198 |
199 | img, pred_x0, e_t = outs
200 | old_eps.append(e_t)
201 | if len(old_eps) >= 4:
202 | old_eps.pop(0)
203 | if callback: callback(i)
204 | if img_callback: img_callback(pred_x0, i)
205 |
206 | if index % log_every_t == 0 or index == total_steps - 1:
207 | intermediates['x_inter'].append(img)
208 | intermediates['pred_x0'].append(pred_x0)
209 |
210 | return img, intermediates
211 |
212 | @torch.no_grad()
213 | def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
214 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
215 | unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None,
216 | dynamic_threshold=None):
217 | b, *_, device = *x.shape, x.device
218 |
219 | def get_model_output(x, t):
220 | if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
221 | e_t = self.model.apply_model(x, t, c)
222 | else:
223 | x_in = torch.cat([x] * 2)
224 | t_in = torch.cat([t] * 2)
225 | #c_in = torch.cat([unconditional_conditioning, c])
226 |
227 | if isinstance(c, dict):
228 | assert isinstance(unconditional_conditioning, dict)
229 | c_in = dict()
230 | for k in c:
231 | if isinstance(c[k], list):
232 | c_in[k] = [torch.cat([
233 | unconditional_conditioning[k][i],
234 | c[k][i]]) for i in range(len(c[k]))]
235 | else:
236 | c_in[k] = torch.cat([
237 | unconditional_conditioning[k],
238 | c[k]])
239 | elif isinstance(c, list):
240 | c_in = list()
241 | assert isinstance(unconditional_conditioning, list)
242 | for i in range(len(c)):
243 | c_in.append(torch.cat([unconditional_conditioning[i], c[i]]))
244 | else:
245 | c_in = torch.cat([unconditional_conditioning, c])
246 |
247 | e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
248 | e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
249 | if score_corrector is not None:
250 | assert self.model.parameterization == "eps"
251 | e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
252 |
253 | return e_t
254 |
255 | alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
256 | alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
257 | sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
258 | sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
259 |
260 | def get_x_prev_and_pred_x0(e_t, index):
261 | # select parameters corresponding to the currently considered timestep
262 | a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
263 | a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
264 | sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
265 | sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
266 |
267 | # current prediction for x_0
268 | pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
269 | if quantize_denoised:
270 | pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
271 | if dynamic_threshold is not None:
272 | pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
273 | # direction pointing to x_t
274 | dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
275 | noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
276 | if noise_dropout > 0.:
277 | noise = torch.nn.functional.dropout(noise, p=noise_dropout)
278 | x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
279 | return x_prev, pred_x0
280 |
281 | e_t = get_model_output(x, t)
282 | if len(old_eps) == 0:
283 | # Pseudo Improved Euler (2nd order)
284 | x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
285 | e_t_next = get_model_output(x_prev, t_next)
286 | e_t_prime = (e_t + e_t_next) / 2
287 | elif len(old_eps) == 1:
288 | # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
289 | e_t_prime = (3 * e_t - old_eps[-1]) / 2
290 | elif len(old_eps) == 2:
291 | # 3nd order Pseudo Linear Multistep (Adams-Bashforth)
292 | e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
293 | elif len(old_eps) >= 3:
294 | # 4nd order Pseudo Linear Multistep (Adams-Bashforth)
295 | e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
296 |
297 | x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
298 |
299 | return x_prev, pred_x0, e_t
300 |
--------------------------------------------------------------------------------
/ldm/modules/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cszy98/PLACE/4594350491df5c7489eb4edf0802f303692dd530/ldm/modules/.DS_Store
--------------------------------------------------------------------------------
/ldm/modules/__pycache__/attention.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cszy98/PLACE/4594350491df5c7489eb4edf0802f303692dd530/ldm/modules/__pycache__/attention.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/modules/__pycache__/ema.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cszy98/PLACE/4594350491df5c7489eb4edf0802f303692dd530/ldm/modules/__pycache__/ema.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/modules/diffusionmodules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cszy98/PLACE/4594350491df5c7489eb4edf0802f303692dd530/ldm/modules/diffusionmodules/__init__.py
--------------------------------------------------------------------------------
/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cszy98/PLACE/4594350491df5c7489eb4edf0802f303692dd530/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/modules/diffusionmodules/__pycache__/model.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cszy98/PLACE/4594350491df5c7489eb4edf0802f303692dd530/ldm/modules/diffusionmodules/__pycache__/model.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cszy98/PLACE/4594350491df5c7489eb4edf0802f303692dd530/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/modules/diffusionmodules/__pycache__/util.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cszy98/PLACE/4594350491df5c7489eb4edf0802f303692dd530/ldm/modules/diffusionmodules/__pycache__/util.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/modules/diffusionmodules/upscaling.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | from functools import partial
5 |
6 | from ldm.modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule
7 | from ldm.util import default
8 |
9 |
10 | class AbstractLowScaleModel(nn.Module):
11 | # for concatenating a downsampled image to the latent representation
12 | def __init__(self, noise_schedule_config=None):
13 | super(AbstractLowScaleModel, self).__init__()
14 | if noise_schedule_config is not None:
15 | self.register_schedule(**noise_schedule_config)
16 |
17 | def register_schedule(self, beta_schedule="linear", timesteps=1000,
18 | linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
19 | betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
20 | cosine_s=cosine_s)
21 | alphas = 1. - betas
22 | alphas_cumprod = np.cumprod(alphas, axis=0)
23 | alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
24 |
25 | timesteps, = betas.shape
26 | self.num_timesteps = int(timesteps)
27 | self.linear_start = linear_start
28 | self.linear_end = linear_end
29 | assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
30 |
31 | to_torch = partial(torch.tensor, dtype=torch.float32)
32 |
33 | self.register_buffer('betas', to_torch(betas))
34 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
35 | self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
36 |
37 | # calculations for diffusion q(x_t | x_{t-1}) and others
38 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
39 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
40 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
41 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
42 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
43 |
44 | def q_sample(self, x_start, t, noise=None):
45 | noise = default(noise, lambda: torch.randn_like(x_start))
46 | return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
47 | extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
48 |
49 | def forward(self, x):
50 | return x, None
51 |
52 | def decode(self, x):
53 | return x
54 |
55 |
56 | class SimpleImageConcat(AbstractLowScaleModel):
57 | # no noise level conditioning
58 | def __init__(self):
59 | super(SimpleImageConcat, self).__init__(noise_schedule_config=None)
60 | self.max_noise_level = 0
61 |
62 | def forward(self, x):
63 | # fix to constant noise level
64 | return x, torch.zeros(x.shape[0], device=x.device).long()
65 |
66 |
67 | class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel):
68 | def __init__(self, noise_schedule_config, max_noise_level=1000, to_cuda=False):
69 | super().__init__(noise_schedule_config=noise_schedule_config)
70 | self.max_noise_level = max_noise_level
71 |
72 | def forward(self, x, noise_level=None):
73 | if noise_level is None:
74 | noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
75 | else:
76 | assert isinstance(noise_level, torch.Tensor)
77 | z = self.q_sample(x, noise_level)
78 | return z, noise_level
79 |
80 |
81 |
82 |
--------------------------------------------------------------------------------
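A minimal usage sketch for the low-scale noise-augmentation wrapper in upscaling.py above (illustrative only, not repository code): the noise_schedule_config keys mirror register_schedule(), and max_noise_level and the tensor shapes are assumed values.

```python
import torch
from ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation

# build the wrapper with an (assumed) linear schedule and a capped noise level
low_scale = ImageConcatWithNoiseAugmentation(
    noise_schedule_config={"beta_schedule": "linear", "timesteps": 1000,
                           "linear_start": 1e-4, "linear_end": 2e-2},
    max_noise_level=250,
)
x_lr = torch.randn(2, 3, 64, 64)      # stand-in for a downsampled conditioning image
z, noise_level = low_scale(x_lr)      # noised copy + per-sample noise levels
assert z.shape == x_lr.shape and noise_level.shape == (2,)
```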
/ldm/modules/diffusionmodules/util.py:
--------------------------------------------------------------------------------
1 | # adapted from
2 | # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
3 | # and
4 | # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
5 | # and
6 | # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
7 | #
8 | # thanks!
9 |
10 |
11 | import os
12 | import math
13 | import torch
14 | import torch.nn as nn
15 | import numpy as np
16 | from einops import repeat
17 |
18 | from ldm.util import instantiate_from_config
19 |
20 |
21 | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
22 | if schedule == "linear":
23 | betas = (
24 | torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
25 | )
26 |
27 | elif schedule == "cosine":
28 | timesteps = (
29 | torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
30 | )
31 | alphas = timesteps / (1 + cosine_s) * np.pi / 2
32 | alphas = torch.cos(alphas).pow(2)
33 | alphas = alphas / alphas[0]
34 | betas = 1 - alphas[1:] / alphas[:-1]
35 |         betas = torch.clamp(betas, min=0, max=0.999)
36 |
37 | elif schedule == "sqrt_linear":
38 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
39 | elif schedule == "sqrt":
40 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
41 | else:
42 | raise ValueError(f"schedule '{schedule}' unknown.")
43 | return betas.numpy()
44 |
45 |
46 | def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
47 | if ddim_discr_method == 'uniform':
48 | c = num_ddpm_timesteps // num_ddim_timesteps
49 | ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
50 | elif ddim_discr_method == 'quad':
51 | ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
52 | else:
53 | raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
54 |
55 | # assert ddim_timesteps.shape[0] == num_ddim_timesteps
56 | # add one to get the final alpha values right (the ones from first scale to data during sampling)
57 | steps_out = ddim_timesteps + 1
58 | if verbose:
59 | print(f'Selected timesteps for ddim sampler: {steps_out}')
60 | return steps_out
61 |
62 |
63 | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
64 | # select alphas for computing the variance schedule
65 | alphas = alphacums[ddim_timesteps]
66 | alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
67 |
68 |     # according to the formula provided in https://arxiv.org/abs/2010.02502
69 | sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
70 | if verbose:
71 | print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
72 | print(f'For the chosen value of eta, which is {eta}, '
73 | f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
74 | return sigmas, alphas, alphas_prev
75 |
76 |
77 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
78 | """
79 | Create a beta schedule that discretizes the given alpha_t_bar function,
80 | which defines the cumulative product of (1-beta) over time from t = [0,1].
81 | :param num_diffusion_timesteps: the number of betas to produce.
82 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
83 | produces the cumulative product of (1-beta) up to that
84 | part of the diffusion process.
85 | :param max_beta: the maximum beta to use; use values lower than 1 to
86 | prevent singularities.
87 | """
88 | betas = []
89 | for i in range(num_diffusion_timesteps):
90 | t1 = i / num_diffusion_timesteps
91 | t2 = (i + 1) / num_diffusion_timesteps
92 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
93 | return np.array(betas)
94 |
95 |
96 | def extract_into_tensor(a, t, x_shape):
97 | b, *_ = t.shape
98 | out = a.gather(-1, t)
99 | return out.reshape(b, *((1,) * (len(x_shape) - 1)))
100 |
101 |
102 | def checkpoint(func, inputs, params, flag):
103 | """
104 | Evaluate a function without caching intermediate activations, allowing for
105 | reduced memory at the expense of extra compute in the backward pass.
106 | :param func: the function to evaluate.
107 | :param inputs: the argument sequence to pass to `func`.
108 | :param params: a sequence of parameters `func` depends on but does not
109 | explicitly take as arguments.
110 | :param flag: if False, disable gradient checkpointing.
111 | """
112 | if flag:
113 | args = tuple(inputs) + tuple(params)
114 | return CheckpointFunction.apply(func, len(inputs), *args)
115 | else:
116 | return func(*inputs)
117 |
118 |
119 | class CheckpointFunction(torch.autograd.Function):
120 | @staticmethod
121 | def forward(ctx, run_function, length, *args):
122 | ctx.run_function = run_function
123 | ctx.input_tensors = list(args[:length])
124 | ctx.input_params = list(args[length:])
125 | ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(),
126 | "dtype": torch.get_autocast_gpu_dtype(),
127 | "cache_enabled": torch.is_autocast_cache_enabled()}
128 | with torch.no_grad():
129 | output_tensors = ctx.run_function(*ctx.input_tensors)
130 | return output_tensors
131 |
132 | @staticmethod
133 | def backward(ctx, *output_grads):
134 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
135 | with torch.enable_grad(), \
136 | torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
137 | # Fixes a bug where the first op in run_function modifies the
138 | # Tensor storage in place, which is not allowed for detach()'d
139 | # Tensors.
140 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
141 | output_tensors = ctx.run_function(*shallow_copies)
142 |
143 | #print(len(ctx.input_tensors),len(ctx.input_params),ctx.run_function,'################')
144 | input_grads = torch.autograd.grad(
145 | output_tensors,
146 | ctx.input_tensors + ctx.input_params,
147 | output_grads,
148 | allow_unused=True,
149 | )
150 | del ctx.input_tensors
151 | del ctx.input_params
152 | del output_tensors
153 | return (None, None) + input_grads
154 |
155 |
156 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
157 | """
158 | Create sinusoidal timestep embeddings.
159 | :param timesteps: a 1-D Tensor of N indices, one per batch element.
160 | These may be fractional.
161 | :param dim: the dimension of the output.
162 | :param max_period: controls the minimum frequency of the embeddings.
163 | :return: an [N x dim] Tensor of positional embeddings.
164 | """
165 | if not repeat_only:
166 | half = dim // 2
167 | freqs = torch.exp(
168 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
169 | ).to(device=timesteps.device)
170 | args = timesteps[:, None].float() * freqs[None]
171 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
172 | if dim % 2:
173 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
174 | else:
175 | embedding = repeat(timesteps, 'b -> b d', d=dim)
176 | return embedding
177 |
178 |
179 | def zero_module(module):
180 | """
181 | Zero out the parameters of a module and return it.
182 | """
183 | for p in module.parameters():
184 | p.detach().zero_()
185 | return module
186 |
187 |
188 | def scale_module(module, scale):
189 | """
190 | Scale the parameters of a module and return it.
191 | """
192 | for p in module.parameters():
193 | p.detach().mul_(scale)
194 | return module
195 |
196 |
197 | def mean_flat(tensor):
198 | """
199 | Take the mean over all non-batch dimensions.
200 | """
201 | return tensor.mean(dim=list(range(1, len(tensor.shape))))
202 |
203 |
204 | def normalization(channels):
205 | """
206 | Make a standard normalization layer.
207 | :param channels: number of input channels.
208 | :return: an nn.Module for normalization.
209 | """
210 | return GroupNorm32(32, channels)
211 |
212 |
213 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
214 | class SiLU(nn.Module):
215 | def forward(self, x):
216 | return x * torch.sigmoid(x)
217 |
218 |
219 | class GroupNorm32(nn.GroupNorm):
220 | def forward(self, x):
221 | return super().forward(x.float()).type(x.dtype)
222 |
223 | def conv_nd(dims, *args, **kwargs):
224 | """
225 | Create a 1D, 2D, or 3D convolution module.
226 | """
227 | if dims == 1:
228 | return nn.Conv1d(*args, **kwargs)
229 | elif dims == 2:
230 | return nn.Conv2d(*args, **kwargs)
231 | elif dims == 3:
232 | return nn.Conv3d(*args, **kwargs)
233 | raise ValueError(f"unsupported dimensions: {dims}")
234 |
235 |
236 | def linear(*args, **kwargs):
237 | """
238 | Create a linear module.
239 | """
240 | return nn.Linear(*args, **kwargs)
241 |
242 |
243 | def avg_pool_nd(dims, *args, **kwargs):
244 | """
245 | Create a 1D, 2D, or 3D average pooling module.
246 | """
247 | if dims == 1:
248 | return nn.AvgPool1d(*args, **kwargs)
249 | elif dims == 2:
250 | return nn.AvgPool2d(*args, **kwargs)
251 | elif dims == 3:
252 | return nn.AvgPool3d(*args, **kwargs)
253 | raise ValueError(f"unsupported dimensions: {dims}")
254 |
255 |
256 | class HybridConditioner(nn.Module):
257 |
258 | def __init__(self, c_concat_config, c_crossattn_config):
259 | super().__init__()
260 | self.concat_conditioner = instantiate_from_config(c_concat_config)
261 | self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
262 |
263 | def forward(self, c_concat, c_crossattn):
264 | c_concat = self.concat_conditioner(c_concat)
265 | c_crossattn = self.crossattn_conditioner(c_crossattn)
266 | return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
267 |
268 |
269 | def noise_like(shape, device, repeat=False):
270 | repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
271 | noise = lambda: torch.randn(shape, device=device)
272 | return repeat_noise() if repeat else noise()
273 |
--------------------------------------------------------------------------------
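An illustrative sketch (assumed usage, not repository code) of two helpers defined in util.py above: building a linear beta schedule and computing sinusoidal timestep embeddings.

```python
import torch
from ldm.modules.diffusionmodules.util import make_beta_schedule, timestep_embedding

betas = make_beta_schedule("linear", n_timestep=1000)   # numpy array of shape (1000,)
t = torch.tensor([0, 250, 999])                         # example timestep indices
emb = timestep_embedding(t, dim=320)                    # (3, 320) sinusoidal features
print(betas.shape, emb.shape)
```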
/ldm/modules/distributions/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cszy98/PLACE/4594350491df5c7489eb4edf0802f303692dd530/ldm/modules/distributions/__init__.py
--------------------------------------------------------------------------------
/ldm/modules/distributions/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cszy98/PLACE/4594350491df5c7489eb4edf0802f303692dd530/ldm/modules/distributions/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/modules/distributions/__pycache__/distributions.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cszy98/PLACE/4594350491df5c7489eb4edf0802f303692dd530/ldm/modules/distributions/__pycache__/distributions.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/modules/distributions/distributions.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | class AbstractDistribution:
6 | def sample(self):
7 | raise NotImplementedError()
8 |
9 | def mode(self):
10 | raise NotImplementedError()
11 |
12 |
13 | class DiracDistribution(AbstractDistribution):
14 | def __init__(self, value):
15 | self.value = value
16 |
17 | def sample(self):
18 | return self.value
19 |
20 | def mode(self):
21 | return self.value
22 |
23 |
24 | class DiagonalGaussianDistribution(object):
25 | def __init__(self, parameters, deterministic=False):
26 | self.parameters = parameters
27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
29 | self.deterministic = deterministic
30 | self.std = torch.exp(0.5 * self.logvar)
31 | self.var = torch.exp(self.logvar)
32 | if self.deterministic:
33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
34 |
35 | def sample(self):
36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
37 | return x
38 |
39 | def kl(self, other=None):
40 | if self.deterministic:
41 | return torch.Tensor([0.])
42 | else:
43 | if other is None:
44 | return 0.5 * torch.sum(torch.pow(self.mean, 2)
45 | + self.var - 1.0 - self.logvar,
46 | dim=[1, 2, 3])
47 | else:
48 | return 0.5 * torch.sum(
49 | torch.pow(self.mean - other.mean, 2) / other.var
50 | + self.var / other.var - 1.0 - self.logvar + other.logvar,
51 | dim=[1, 2, 3])
52 |
53 | def nll(self, sample, dims=[1,2,3]):
54 | if self.deterministic:
55 | return torch.Tensor([0.])
56 | logtwopi = np.log(2.0 * np.pi)
57 | return 0.5 * torch.sum(
58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
59 | dim=dims)
60 |
61 | def mode(self):
62 | return self.mean
63 |
64 |
65 | def normal_kl(mean1, logvar1, mean2, logvar2):
66 | """
67 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
68 | Compute the KL divergence between two gaussians.
69 | Shapes are automatically broadcasted, so batches can be compared to
70 | scalars, among other use cases.
71 | """
72 | tensor = None
73 | for obj in (mean1, logvar1, mean2, logvar2):
74 | if isinstance(obj, torch.Tensor):
75 | tensor = obj
76 | break
77 | assert tensor is not None, "at least one argument must be a Tensor"
78 |
79 | # Force variances to be Tensors. Broadcasting helps convert scalars to
80 | # Tensors, but it does not work for torch.exp().
81 | logvar1, logvar2 = [
82 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
83 | for x in (logvar1, logvar2)
84 | ]
85 |
86 | return 0.5 * (
87 | -1.0
88 | + logvar2
89 | - logvar1
90 | + torch.exp(logvar1 - logvar2)
91 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
92 | )
93 |
--------------------------------------------------------------------------------
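A small sketch of how the DiagonalGaussianDistribution above is typically used: an encoder output with 2*C channels is interpreted as concatenated mean and log-variance, then sampled and regularized. The shapes are illustrative assumptions.

```python
import torch
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution

moments = torch.randn(2, 8, 32, 32)               # [mean | logvar] stacked along dim=1
posterior = DiagonalGaussianDistribution(moments)
z = posterior.sample()                            # (2, 4, 32, 32) latent sample
kl = posterior.kl()                               # per-sample KL to N(0, I), shape (2,)
print(z.shape, kl.shape)
```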
/ldm/modules/ema.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | class LitEma(nn.Module):
6 | def __init__(self, model, decay=0.9999, use_num_upates=True):
7 | super().__init__()
8 | if decay < 0.0 or decay > 1.0:
9 | raise ValueError('Decay must be between 0 and 1')
10 |
11 | self.m_name2s_name = {}
12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
13 | self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_upates
14 | else torch.tensor(-1, dtype=torch.int))
15 |
16 | for name, p in model.named_parameters():
17 | if p.requires_grad:
18 | # remove as '.'-character is not allowed in buffers
19 | s_name = name.replace('.', '')
20 | self.m_name2s_name.update({name: s_name})
21 | self.register_buffer(s_name, p.clone().detach().data)
22 |
23 | self.collected_params = []
24 |
25 | def reset_num_updates(self):
26 | del self.num_updates
27 | self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int))
28 |
29 | def forward(self, model):
30 | decay = self.decay
31 |
32 | if self.num_updates >= 0:
33 | self.num_updates += 1
34 | decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
35 |
36 | one_minus_decay = 1.0 - decay
37 |
38 | with torch.no_grad():
39 | m_param = dict(model.named_parameters())
40 | shadow_params = dict(self.named_buffers())
41 |
42 | for key in m_param:
43 | if m_param[key].requires_grad:
44 | sname = self.m_name2s_name[key]
45 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
46 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
47 | else:
48 | assert not key in self.m_name2s_name
49 |
50 | def copy_to(self, model):
51 | m_param = dict(model.named_parameters())
52 | shadow_params = dict(self.named_buffers())
53 | for key in m_param:
54 | if m_param[key].requires_grad:
55 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
56 | else:
57 | assert not key in self.m_name2s_name
58 |
59 | def store(self, parameters):
60 | """
61 | Save the current parameters for restoring later.
62 | Args:
63 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
64 | temporarily stored.
65 | """
66 | self.collected_params = [param.clone() for param in parameters]
67 |
68 | def restore(self, parameters):
69 | """
70 | Restore the parameters stored with the `store` method.
71 | Useful to validate the model with EMA parameters without affecting the
72 | original optimization process. Store the parameters before the
73 | `copy_to` method. After validation (or model saving), use this to
74 | restore the former parameters.
75 | Args:
76 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
77 | updated with the stored parameters.
78 | """
79 | for c_param, param in zip(self.collected_params, parameters):
80 | param.data.copy_(c_param.data)
81 |
--------------------------------------------------------------------------------
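A usage sketch (assumed pattern, not repository code) for LitEma above: update the shadow weights after each optimizer step, then temporarily swap them in for evaluation and restore the live weights afterwards.

```python
import torch.nn as nn
from ldm.modules.ema import LitEma

model = nn.Linear(4, 4)                 # stand-in for the diffusion model
ema = LitEma(model, decay=0.999)

ema(model)                              # call after every optimizer.step()

ema.store(model.parameters())           # stash the live weights
ema.copy_to(model)                      # evaluate with EMA weights
# ... run validation / save checkpoint ...
ema.restore(model.parameters())         # switch back to the live weights
```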
/ldm/modules/encoders/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cszy98/PLACE/4594350491df5c7489eb4edf0802f303692dd530/ldm/modules/encoders/__init__.py
--------------------------------------------------------------------------------
/ldm/modules/encoders/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cszy98/PLACE/4594350491df5c7489eb4edf0802f303692dd530/ldm/modules/encoders/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/modules/encoders/__pycache__/modules.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cszy98/PLACE/4594350491df5c7489eb4edf0802f303692dd530/ldm/modules/encoders/__pycache__/modules.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/modules/encoders/modules.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.utils.checkpoint import checkpoint
4 |
5 | from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel
6 |
7 | import open_clip
8 | from ldm.util import default, count_params
9 |
10 |
11 | class AbstractEncoder(nn.Module):
12 | def __init__(self):
13 | super().__init__()
14 |
15 | def encode(self, *args, **kwargs):
16 | raise NotImplementedError
17 |
18 |
19 | class IdentityEncoder(AbstractEncoder):
20 |
21 | def encode(self, x):
22 | return x
23 |
24 |
25 | class ClassEmbedder(nn.Module):
26 | def __init__(self, embed_dim, n_classes=1000, key='class', ucg_rate=0.1):
27 | super().__init__()
28 | self.key = key
29 | self.embedding = nn.Embedding(n_classes, embed_dim)
30 | self.n_classes = n_classes
31 | self.ucg_rate = ucg_rate
32 |
33 | def forward(self, batch, key=None, disable_dropout=False):
34 | if key is None:
35 | key = self.key
36 | # this is for use in crossattn
37 | c = batch[key][:, None]
38 | if self.ucg_rate > 0. and not disable_dropout:
39 | mask = 1. - torch.bernoulli(torch.ones_like(c) * self.ucg_rate)
40 | c = mask * c + (1-mask) * torch.ones_like(c)*(self.n_classes-1)
41 | c = c.long()
42 | c = self.embedding(c)
43 | return c
44 |
45 | def get_unconditional_conditioning(self, bs, device="cuda"):
46 | uc_class = self.n_classes - 1 # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000)
47 | uc = torch.ones((bs,), device=device) * uc_class
48 | uc = {self.key: uc}
49 | return uc
50 |
51 |
52 | def disabled_train(self, mode=True):
53 | """Overwrite model.train with this function to make sure train/eval mode
54 | does not change anymore."""
55 | return self
56 |
57 |
58 | class FrozenT5Embedder(AbstractEncoder):
59 | """Uses the T5 transformer encoder for text"""
60 | def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
61 | super().__init__()
62 | self.tokenizer = T5Tokenizer.from_pretrained(version)
63 | self.transformer = T5EncoderModel.from_pretrained(version)
64 | self.device = device
65 | self.max_length = max_length # TODO: typical value?
66 | if freeze:
67 | self.freeze()
68 |
69 | def freeze(self):
70 | self.transformer = self.transformer.eval()
71 | #self.train = disabled_train
72 | for param in self.parameters():
73 | param.requires_grad = False
74 |
75 | def forward(self, text):
76 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
77 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
78 | tokens = batch_encoding["input_ids"].to(self.device)
79 | outputs = self.transformer(input_ids=tokens)
80 |
81 | z = outputs.last_hidden_state
82 | return z
83 |
84 | def encode(self, text):
85 | return self(text)
86 |
87 |
88 | class FrozenCLIPEmbedder(AbstractEncoder):
89 | """Uses the CLIP transformer encoder for text (from huggingface)"""
90 | LAYERS = [
91 | "last",
92 | "pooled",
93 | "hidden"
94 | ]
95 | def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77,
96 | freeze=True, layer="last", layer_idx=None): # clip-vit-base-patch32
97 | super().__init__()
98 | assert layer in self.LAYERS
99 | self.tokenizer = CLIPTokenizer.from_pretrained(version)
100 | self.transformer = CLIPTextModel.from_pretrained(version)
101 | self.device = device
102 | self.max_length = max_length
103 | if freeze:
104 | self.freeze()
105 | self.layer = layer
106 | self.layer_idx = layer_idx
107 | if layer == "hidden":
108 | assert layer_idx is not None
109 | assert 0 <= abs(layer_idx) <= 12
110 |
111 | def freeze(self):
112 | self.transformer = self.transformer.eval()
113 | #self.train = disabled_train
114 | for param in self.parameters():
115 | param.requires_grad = False
116 |
117 | def forward(self, text):
118 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
119 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
120 | tokens = batch_encoding["input_ids"].to(self.device)
121 | outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden")
122 | if self.layer == "last":
123 | z = outputs.last_hidden_state
124 | elif self.layer == "pooled":
125 | z = outputs.pooler_output[:, None, :]
126 | else:
127 | z = outputs.hidden_states[self.layer_idx]
128 | return z
129 |
130 | def encode(self, text):
131 | return self(text)
132 |
133 |
134 | class MYFrozenCLIPEmbedder(AbstractEncoder):
135 | """Uses the CLIP transformer encoder for text (from huggingface)"""
136 | LAYERS = [
137 | "last",
138 | "pooled",
139 | "hidden"
140 | ]
141 | def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77,
142 | freeze=True, layer="last", layer_idx=None): # clip-vit-base-patch32
143 | super().__init__()
144 | assert layer in self.LAYERS
145 | self.tokenizer = CLIPTokenizer.from_pretrained(version)
146 | self.transformer = CLIPTextModel.from_pretrained(version)
147 | self.device = device
148 | self.max_length = max_length
149 | if freeze:
150 | self.freeze()
151 | self.layer = layer
152 | self.layer_idx = layer_idx
153 | if layer == "hidden":
154 | assert layer_idx is not None
155 | assert 0 <= abs(layer_idx) <= 12
156 |
157 | def freeze(self):
158 | self.transformer = self.transformer.eval()
159 | #self.train = disabled_train
160 | for param in self.parameters():
161 | param.requires_grad = False
162 |
163 | def forward(self, tokens):
164 | #batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
165 | # return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
166 | #print('realb',tokens.shape,tokens.dtype)
167 | tokens = tokens.to(self.device)
168 | outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden")
169 | if self.layer == "last":
170 | z = outputs.last_hidden_state
171 | elif self.layer == "pooled":
172 | z = outputs.pooler_output[:, None, :]
173 | else:
174 | z = outputs.hidden_states[self.layer_idx]
175 | #print('reala',z.shape,z.dtype)
176 | return z
177 |
178 | def encode(self, tokens):
179 | return self(tokens)
180 |
181 |
182 |
183 |
184 | class FrozenOpenCLIPEmbedder(AbstractEncoder):
185 | """
186 | Uses the OpenCLIP transformer encoder for text
187 | """
188 | LAYERS = [
189 | #"pooled",
190 | "last",
191 | "penultimate"
192 | ]
193 | def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77,
194 | freeze=True, layer="last"):
195 | super().__init__()
196 | assert layer in self.LAYERS
197 | model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=version)
198 | del model.visual
199 | self.model = model
200 |
201 | self.device = device
202 | self.max_length = max_length
203 | if freeze:
204 | self.freeze()
205 | self.layer = layer
206 | if self.layer == "last":
207 | self.layer_idx = 0
208 | elif self.layer == "penultimate":
209 | self.layer_idx = 1
210 | else:
211 | raise NotImplementedError()
212 |
213 | def freeze(self):
214 | self.model = self.model.eval()
215 | for param in self.parameters():
216 | param.requires_grad = False
217 |
218 | def forward(self, text):
219 | tokens = open_clip.tokenize(text)
220 | z = self.encode_with_transformer(tokens.to(self.device))
221 | return z
222 |
223 | def encode_with_transformer(self, text):
224 | x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model]
225 | x = x + self.model.positional_embedding
226 | x = x.permute(1, 0, 2) # NLD -> LND
227 | x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
228 | x = x.permute(1, 0, 2) # LND -> NLD
229 | x = self.model.ln_final(x)
230 | return x
231 |
232 | def text_transformer_forward(self, x: torch.Tensor, attn_mask = None):
233 | for i, r in enumerate(self.model.transformer.resblocks):
234 | if i == len(self.model.transformer.resblocks) - self.layer_idx:
235 | break
236 | if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting():
237 | x = checkpoint(r, x, attn_mask)
238 | else:
239 | x = r(x, attn_mask=attn_mask)
240 | return x
241 |
242 | def encode(self, text):
243 | return self(text)
244 |
245 |
246 | class FrozenCLIPT5Encoder(AbstractEncoder):
247 | def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device="cuda",
248 | clip_max_length=77, t5_max_length=77):
249 | super().__init__()
250 | self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length)
251 | self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length)
252 | print(f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder)*1.e-6:.2f} M parameters, "
253 | f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder)*1.e-6:.2f} M params.")
254 |
255 | def encode(self, text):
256 | return self(text)
257 |
258 | def forward(self, text):
259 | clip_z = self.clip_encoder.encode(text)
260 | t5_z = self.t5_encoder.encode(text)
261 | return [clip_z, t5_z]
262 |
263 |
264 |
265 |
266 |
267 | class MYFrozenOpenCLIPEmbedder(AbstractEncoder):
268 | """
269 | Uses the OpenCLIP transformer encoder for text
270 | """
271 | LAYERS = [
272 | #"pooled",
273 | "last",
274 | "penultimate"
275 | ]
276 | def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77,
277 | freeze=True, layer="penultimate"):
278 | super().__init__()
279 | assert layer in self.LAYERS
280 | model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cuda'), pretrained=version)
281 | del model.visual
282 | self.model = model
283 |
284 | self.device = device
285 | self.max_length = max_length
286 | if freeze:
287 | self.freeze()
288 | self.layer = layer
289 | if self.layer == "last":
290 | self.layer_idx = 0
291 | elif self.layer == "penultimate":
292 | self.layer_idx = 1
293 | else:
294 | raise NotImplementedError()
295 |
296 | def freeze(self):
297 | self.model = self.model.eval()
298 | for param in self.parameters():
299 | param.requires_grad = False
300 |
301 | def forward(self, text):
302 | tokens = open_clip.tokenize(text)
303 | z = self.encode_with_transformer(tokens.to(self.device))
304 | return (z[torch.arange(z.shape[0]), 0], z[torch.arange(z.shape[0]), tokens.argmax(dim=-1)])
305 |
306 | def encode_with_transformer(self, text):
307 | #print(text.device,self.device, self.model.device)
308 | x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model]
309 | x = x + self.model.positional_embedding
310 | x = x.permute(1, 0, 2) # NLD -> LND
311 | x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
312 | x = x.permute(1, 0, 2) # LND -> NLD
313 | x = self.model.ln_final(x)
314 | return x
315 |
316 | def text_transformer_forward(self, x: torch.Tensor, attn_mask = None):
317 | for i, r in enumerate(self.model.transformer.resblocks):
318 | if i == len(self.model.transformer.resblocks) - self.layer_idx:
319 | break
320 | if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting():
321 | x = checkpoint(r, x, attn_mask)
322 | else:
323 | x = r(x, attn_mask=attn_mask)
324 | return x
325 |
326 | def encode(self, text):
327 | return self(text)
328 |
--------------------------------------------------------------------------------
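A sketch of the assumed text-conditioning path through FrozenCLIPEmbedder above (the checkpoint is fetched from Hugging Face on first use; device and prompts are illustrative).

```python
import torch
from ldm.modules.encoders.modules import FrozenCLIPEmbedder

encoder = FrozenCLIPEmbedder(device="cpu")        # the actual pipeline uses "cuda"
with torch.no_grad():
    z = encoder.encode(["a photo of a living room", "a snowy street at night"])
print(z.shape)                                    # (2, 77, 768) token embeddings for cross-attention
```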
/ldm/modules/image_degradation/__init__.py:
--------------------------------------------------------------------------------
1 | from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr
2 | from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light
3 |
--------------------------------------------------------------------------------
/ldm/modules/image_degradation/utils/test.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cszy98/PLACE/4594350491df5c7489eb4edf0802f303692dd530/ldm/modules/image_degradation/utils/test.png
--------------------------------------------------------------------------------
/ldm/modules/midas/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cszy98/PLACE/4594350491df5c7489eb4edf0802f303692dd530/ldm/modules/midas/__init__.py
--------------------------------------------------------------------------------
/ldm/modules/midas/api.py:
--------------------------------------------------------------------------------
1 | # based on https://github.com/isl-org/MiDaS
2 |
3 | import cv2
4 | import torch
5 | import torch.nn as nn
6 | from torchvision.transforms import Compose
7 |
8 | from ldm.modules.midas.midas.dpt_depth import DPTDepthModel
9 | from ldm.modules.midas.midas.midas_net import MidasNet
10 | from ldm.modules.midas.midas.midas_net_custom import MidasNet_small
11 | from ldm.modules.midas.midas.transforms import Resize, NormalizeImage, PrepareForNet
12 |
13 |
14 | ISL_PATHS = {
15 | "dpt_large": "midas_models/dpt_large-midas-2f21e586.pt",
16 | "dpt_hybrid": "midas_models/dpt_hybrid-midas-501f0c75.pt",
17 | "midas_v21": "",
18 | "midas_v21_small": "",
19 | }
20 |
21 |
22 | def disabled_train(self, mode=True):
23 | """Overwrite model.train with this function to make sure train/eval mode
24 | does not change anymore."""
25 | return self
26 |
27 |
28 | def load_midas_transform(model_type):
29 | # https://github.com/isl-org/MiDaS/blob/master/run.py
30 | # load transform only
31 | if model_type == "dpt_large": # DPT-Large
32 | net_w, net_h = 384, 384
33 | resize_mode = "minimal"
34 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
35 |
36 | elif model_type == "dpt_hybrid": # DPT-Hybrid
37 | net_w, net_h = 384, 384
38 | resize_mode = "minimal"
39 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
40 |
41 | elif model_type == "midas_v21":
42 | net_w, net_h = 384, 384
43 | resize_mode = "upper_bound"
44 | normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
45 |
46 | elif model_type == "midas_v21_small":
47 | net_w, net_h = 256, 256
48 | resize_mode = "upper_bound"
49 | normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
50 |
51 | else:
52 | assert False, f"model_type '{model_type}' not implemented, use: --model_type large"
53 |
54 | transform = Compose(
55 | [
56 | Resize(
57 | net_w,
58 | net_h,
59 | resize_target=None,
60 | keep_aspect_ratio=True,
61 | ensure_multiple_of=32,
62 | resize_method=resize_mode,
63 | image_interpolation_method=cv2.INTER_CUBIC,
64 | ),
65 | normalization,
66 | PrepareForNet(),
67 | ]
68 | )
69 |
70 | return transform
71 |
72 |
73 | def load_model(model_type):
74 | # https://github.com/isl-org/MiDaS/blob/master/run.py
75 | # load network
76 | model_path = ISL_PATHS[model_type]
77 | if model_type == "dpt_large": # DPT-Large
78 | model = DPTDepthModel(
79 | path=model_path,
80 | backbone="vitl16_384",
81 | non_negative=True,
82 | )
83 | net_w, net_h = 384, 384
84 | resize_mode = "minimal"
85 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
86 |
87 | elif model_type == "dpt_hybrid": # DPT-Hybrid
88 | model = DPTDepthModel(
89 | path=model_path,
90 | backbone="vitb_rn50_384",
91 | non_negative=True,
92 | )
93 | net_w, net_h = 384, 384
94 | resize_mode = "minimal"
95 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
96 |
97 | elif model_type == "midas_v21":
98 | model = MidasNet(model_path, non_negative=True)
99 | net_w, net_h = 384, 384
100 | resize_mode = "upper_bound"
101 | normalization = NormalizeImage(
102 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
103 | )
104 |
105 | elif model_type == "midas_v21_small":
106 | model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True,
107 | non_negative=True, blocks={'expand': True})
108 | net_w, net_h = 256, 256
109 | resize_mode = "upper_bound"
110 | normalization = NormalizeImage(
111 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
112 | )
113 |
114 | else:
115 | print(f"model_type '{model_type}' not implemented, use: --model_type large")
116 | assert False
117 |
118 | transform = Compose(
119 | [
120 | Resize(
121 | net_w,
122 | net_h,
123 | resize_target=None,
124 | keep_aspect_ratio=True,
125 | ensure_multiple_of=32,
126 | resize_method=resize_mode,
127 | image_interpolation_method=cv2.INTER_CUBIC,
128 | ),
129 | normalization,
130 | PrepareForNet(),
131 | ]
132 | )
133 |
134 | return model.eval(), transform
135 |
136 |
137 | class MiDaSInference(nn.Module):
138 | MODEL_TYPES_TORCH_HUB = [
139 | "DPT_Large",
140 | "DPT_Hybrid",
141 | "MiDaS_small"
142 | ]
143 | MODEL_TYPES_ISL = [
144 | "dpt_large",
145 | "dpt_hybrid",
146 | "midas_v21",
147 | "midas_v21_small",
148 | ]
149 |
150 | def __init__(self, model_type):
151 | super().__init__()
152 | assert (model_type in self.MODEL_TYPES_ISL)
153 | model, _ = load_model(model_type)
154 | self.model = model
155 | self.model.train = disabled_train
156 |
157 | def forward(self, x):
158 | # x in 0..1 as produced by calling self.transform on a 0..1 float64 numpy array
159 | # NOTE: we expect that the correct transform has been called during dataloading.
160 | with torch.no_grad():
161 | prediction = self.model(x)
162 | prediction = torch.nn.functional.interpolate(
163 | prediction.unsqueeze(1),
164 | size=x.shape[2:],
165 | mode="bicubic",
166 | align_corners=False,
167 | )
168 | assert prediction.shape == (x.shape[0], 1, x.shape[2], x.shape[3])
169 | return prediction
170 |
171 |
--------------------------------------------------------------------------------
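A hypothetical sketch of running MiDaSInference above on an already-transformed batch; it assumes the MiDaS dependencies (e.g. timm) are installed and the dpt_hybrid checkpoint exists under midas_models/ as listed in ISL_PATHS.

```python
import torch
from ldm.modules.midas.api import MiDaSInference, load_midas_transform

midas = MiDaSInference(model_type="dpt_hybrid")    # loads midas_models/dpt_hybrid-midas-501f0c75.pt
transform = load_midas_transform("dpt_hybrid")     # resize/normalize for raw float images
x = torch.randn(1, 3, 384, 384)                    # stand-in for an already-transformed batch
depth = midas(x)                                   # (1, 1, 384, 384) relative depth map
```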
/ldm/modules/midas/midas/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cszy98/PLACE/4594350491df5c7489eb4edf0802f303692dd530/ldm/modules/midas/midas/__init__.py
--------------------------------------------------------------------------------
/ldm/modules/midas/midas/base_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class BaseModel(torch.nn.Module):
5 | def load(self, path):
6 | """Load model from file.
7 |
8 | Args:
9 | path (str): file path
10 | """
11 | parameters = torch.load(path, map_location=torch.device('cpu'))
12 |
13 | if "optimizer" in parameters:
14 | parameters = parameters["model"]
15 |
16 | self.load_state_dict(parameters)
17 |
--------------------------------------------------------------------------------
/ldm/modules/midas/midas/blocks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from .vit import (
5 | _make_pretrained_vitb_rn50_384,
6 | _make_pretrained_vitl16_384,
7 | _make_pretrained_vitb16_384,
8 | forward_vit,
9 | )
10 |
11 | def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",):
12 | if backbone == "vitl16_384":
13 | pretrained = _make_pretrained_vitl16_384(
14 | use_pretrained, hooks=hooks, use_readout=use_readout
15 | )
16 | scratch = _make_scratch(
17 | [256, 512, 1024, 1024], features, groups=groups, expand=expand
18 | ) # ViT-L/16 - 85.0% Top1 (backbone)
19 | elif backbone == "vitb_rn50_384":
20 | pretrained = _make_pretrained_vitb_rn50_384(
21 | use_pretrained,
22 | hooks=hooks,
23 | use_vit_only=use_vit_only,
24 | use_readout=use_readout,
25 | )
26 | scratch = _make_scratch(
27 | [256, 512, 768, 768], features, groups=groups, expand=expand
28 |         ) # ViT-Hybrid (ResNet-50 + ViT-B/16) backbone
29 | elif backbone == "vitb16_384":
30 | pretrained = _make_pretrained_vitb16_384(
31 | use_pretrained, hooks=hooks, use_readout=use_readout
32 | )
33 | scratch = _make_scratch(
34 | [96, 192, 384, 768], features, groups=groups, expand=expand
35 | ) # ViT-B/16 - 84.6% Top1 (backbone)
36 | elif backbone == "resnext101_wsl":
37 | pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
38 |         scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # resnext101_wsl
39 | elif backbone == "efficientnet_lite3":
40 | pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable)
41 | scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3
42 | else:
43 | print(f"Backbone '{backbone}' not implemented")
44 | assert False
45 |
46 | return pretrained, scratch
47 |
48 |
49 | def _make_scratch(in_shape, out_shape, groups=1, expand=False):
50 | scratch = nn.Module()
51 |
52 | out_shape1 = out_shape
53 | out_shape2 = out_shape
54 | out_shape3 = out_shape
55 | out_shape4 = out_shape
56 | if expand==True:
57 | out_shape1 = out_shape
58 | out_shape2 = out_shape*2
59 | out_shape3 = out_shape*4
60 | out_shape4 = out_shape*8
61 |
62 | scratch.layer1_rn = nn.Conv2d(
63 | in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
64 | )
65 | scratch.layer2_rn = nn.Conv2d(
66 | in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
67 | )
68 | scratch.layer3_rn = nn.Conv2d(
69 | in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
70 | )
71 | scratch.layer4_rn = nn.Conv2d(
72 | in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
73 | )
74 |
75 | return scratch
76 |
77 |
78 | def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False):
79 | efficientnet = torch.hub.load(
80 | "rwightman/gen-efficientnet-pytorch",
81 | "tf_efficientnet_lite3",
82 | pretrained=use_pretrained,
83 | exportable=exportable
84 | )
85 | return _make_efficientnet_backbone(efficientnet)
86 |
87 |
88 | def _make_efficientnet_backbone(effnet):
89 | pretrained = nn.Module()
90 |
91 | pretrained.layer1 = nn.Sequential(
92 | effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2]
93 | )
94 | pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3])
95 | pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5])
96 | pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9])
97 |
98 | return pretrained
99 |
100 |
101 | def _make_resnet_backbone(resnet):
102 | pretrained = nn.Module()
103 | pretrained.layer1 = nn.Sequential(
104 | resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1
105 | )
106 |
107 | pretrained.layer2 = resnet.layer2
108 | pretrained.layer3 = resnet.layer3
109 | pretrained.layer4 = resnet.layer4
110 |
111 | return pretrained
112 |
113 |
114 | def _make_pretrained_resnext101_wsl(use_pretrained):
115 | resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl")
116 | return _make_resnet_backbone(resnet)
117 |
118 |
119 |
120 | class Interpolate(nn.Module):
121 | """Interpolation module.
122 | """
123 |
124 | def __init__(self, scale_factor, mode, align_corners=False):
125 | """Init.
126 |
127 | Args:
128 | scale_factor (float): scaling
129 | mode (str): interpolation mode
130 | """
131 | super(Interpolate, self).__init__()
132 |
133 | self.interp = nn.functional.interpolate
134 | self.scale_factor = scale_factor
135 | self.mode = mode
136 | self.align_corners = align_corners
137 |
138 | def forward(self, x):
139 | """Forward pass.
140 |
141 | Args:
142 | x (tensor): input
143 |
144 | Returns:
145 | tensor: interpolated data
146 | """
147 |
148 | x = self.interp(
149 | x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners
150 | )
151 |
152 | return x
153 |
154 |
155 | class ResidualConvUnit(nn.Module):
156 | """Residual convolution module.
157 | """
158 |
159 | def __init__(self, features):
160 | """Init.
161 |
162 | Args:
163 | features (int): number of features
164 | """
165 | super().__init__()
166 |
167 | self.conv1 = nn.Conv2d(
168 | features, features, kernel_size=3, stride=1, padding=1, bias=True
169 | )
170 |
171 | self.conv2 = nn.Conv2d(
172 | features, features, kernel_size=3, stride=1, padding=1, bias=True
173 | )
174 |
175 | self.relu = nn.ReLU(inplace=True)
176 |
177 | def forward(self, x):
178 | """Forward pass.
179 |
180 | Args:
181 | x (tensor): input
182 |
183 | Returns:
184 | tensor: output
185 | """
186 | out = self.relu(x)
187 | out = self.conv1(out)
188 | out = self.relu(out)
189 | out = self.conv2(out)
190 |
191 | return out + x
192 |
193 |
194 | class FeatureFusionBlock(nn.Module):
195 | """Feature fusion block.
196 | """
197 |
198 | def __init__(self, features):
199 | """Init.
200 |
201 | Args:
202 | features (int): number of features
203 | """
204 | super(FeatureFusionBlock, self).__init__()
205 |
206 | self.resConfUnit1 = ResidualConvUnit(features)
207 | self.resConfUnit2 = ResidualConvUnit(features)
208 |
209 | def forward(self, *xs):
210 | """Forward pass.
211 |
212 | Returns:
213 | tensor: output
214 | """
215 | output = xs[0]
216 |
217 | if len(xs) == 2:
218 | output += self.resConfUnit1(xs[1])
219 |
220 | output = self.resConfUnit2(output)
221 |
222 | output = nn.functional.interpolate(
223 | output, scale_factor=2, mode="bilinear", align_corners=True
224 | )
225 |
226 | return output
227 |
228 |
229 |
230 |
231 | class ResidualConvUnit_custom(nn.Module):
232 | """Residual convolution module.
233 | """
234 |
235 | def __init__(self, features, activation, bn):
236 | """Init.
237 |
238 | Args:
239 | features (int): number of features
240 | """
241 | super().__init__()
242 |
243 | self.bn = bn
244 |
245 | self.groups=1
246 |
247 | self.conv1 = nn.Conv2d(
248 | features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
249 | )
250 |
251 | self.conv2 = nn.Conv2d(
252 | features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
253 | )
254 |
255 | if self.bn==True:
256 | self.bn1 = nn.BatchNorm2d(features)
257 | self.bn2 = nn.BatchNorm2d(features)
258 |
259 | self.activation = activation
260 |
261 | self.skip_add = nn.quantized.FloatFunctional()
262 |
263 | def forward(self, x):
264 | """Forward pass.
265 |
266 | Args:
267 | x (tensor): input
268 |
269 | Returns:
270 | tensor: output
271 | """
272 |
273 | out = self.activation(x)
274 | out = self.conv1(out)
275 | if self.bn==True:
276 | out = self.bn1(out)
277 |
278 | out = self.activation(out)
279 | out = self.conv2(out)
280 | if self.bn==True:
281 | out = self.bn2(out)
282 |
283 | if self.groups > 1:
284 | out = self.conv_merge(out)
285 |
286 | return self.skip_add.add(out, x)
287 |
288 | # return out + x
289 |
290 |
291 | class FeatureFusionBlock_custom(nn.Module):
292 | """Feature fusion block.
293 | """
294 |
295 | def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True):
296 | """Init.
297 |
298 | Args:
299 | features (int): number of features
300 | """
301 | super(FeatureFusionBlock_custom, self).__init__()
302 |
303 | self.deconv = deconv
304 | self.align_corners = align_corners
305 |
306 | self.groups=1
307 |
308 | self.expand = expand
309 | out_features = features
310 | if self.expand==True:
311 | out_features = features//2
312 |
313 | self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
314 |
315 | self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
316 | self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
317 |
318 | self.skip_add = nn.quantized.FloatFunctional()
319 |
320 | def forward(self, *xs):
321 | """Forward pass.
322 |
323 | Returns:
324 | tensor: output
325 | """
326 | output = xs[0]
327 |
328 | if len(xs) == 2:
329 | res = self.resConfUnit1(xs[1])
330 | output = self.skip_add.add(output, res)
331 | # output += res
332 |
333 | output = self.resConfUnit2(output)
334 |
335 | output = nn.functional.interpolate(
336 | output, scale_factor=2, mode="bilinear", align_corners=self.align_corners
337 | )
338 |
339 | output = self.out_conv(output)
340 |
341 | return output
342 |
343 |
--------------------------------------------------------------------------------
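A shape-level sketch (assumed inputs) of FeatureFusionBlock_custom above: it refines and adds an encoder skip feature to the decoder path, then upsamples the result by a factor of two.

```python
import torch
import torch.nn as nn
from ldm.modules.midas.midas.blocks import FeatureFusionBlock_custom

fuse = FeatureFusionBlock_custom(features=256, activation=nn.ReLU(False), bn=False)
deeper = torch.randn(1, 256, 12, 12)   # feature from the coarser decoder stage
skip = torch.randn(1, 256, 12, 12)     # reassembled encoder feature at this stage
out = fuse(deeper, skip)               # (1, 256, 24, 24) after the 2x upsampling
print(out.shape)
```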
/ldm/modules/midas/midas/dpt_depth.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from .base_model import BaseModel
6 | from .blocks import (
7 | FeatureFusionBlock,
8 | FeatureFusionBlock_custom,
9 | Interpolate,
10 | _make_encoder,
11 | forward_vit,
12 | )
13 |
14 |
15 | def _make_fusion_block(features, use_bn):
16 | return FeatureFusionBlock_custom(
17 | features,
18 | nn.ReLU(False),
19 | deconv=False,
20 | bn=use_bn,
21 | expand=False,
22 | align_corners=True,
23 | )
24 |
25 |
26 | class DPT(BaseModel):
27 | def __init__(
28 | self,
29 | head,
30 | features=256,
31 | backbone="vitb_rn50_384",
32 | readout="project",
33 | channels_last=False,
34 | use_bn=False,
35 | ):
36 |
37 | super(DPT, self).__init__()
38 |
39 | self.channels_last = channels_last
40 |
41 | hooks = {
42 | "vitb_rn50_384": [0, 1, 8, 11],
43 | "vitb16_384": [2, 5, 8, 11],
44 | "vitl16_384": [5, 11, 17, 23],
45 | }
46 |
47 | # Instantiate backbone and reassemble blocks
48 | self.pretrained, self.scratch = _make_encoder(
49 | backbone,
50 | features,
51 |             False, # Set to True if you want to train from scratch; uses ImageNet weights
52 | groups=1,
53 | expand=False,
54 | exportable=False,
55 | hooks=hooks[backbone],
56 | use_readout=readout,
57 | )
58 |
59 | self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
60 | self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
61 | self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
62 | self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
63 |
64 | self.scratch.output_conv = head
65 |
66 |
67 | def forward(self, x):
68 | if self.channels_last == True:
69 | x.contiguous(memory_format=torch.channels_last)
70 |
71 | layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)
72 |
73 | layer_1_rn = self.scratch.layer1_rn(layer_1)
74 | layer_2_rn = self.scratch.layer2_rn(layer_2)
75 | layer_3_rn = self.scratch.layer3_rn(layer_3)
76 | layer_4_rn = self.scratch.layer4_rn(layer_4)
77 |
78 | path_4 = self.scratch.refinenet4(layer_4_rn)
79 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
80 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
81 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
82 |
83 | out = self.scratch.output_conv(path_1)
84 |
85 | return out
86 |
87 |
88 | class DPTDepthModel(DPT):
89 | def __init__(self, path=None, non_negative=True, **kwargs):
90 | features = kwargs["features"] if "features" in kwargs else 256
91 |
92 | head = nn.Sequential(
93 | nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
94 | Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
95 | nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
96 | nn.ReLU(True),
97 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
98 | nn.ReLU(True) if non_negative else nn.Identity(),
99 | nn.Identity(),
100 | )
101 |
102 | super().__init__(head, **kwargs)
103 |
104 | if path is not None:
105 | self.load(path)
106 |
107 | def forward(self, x):
108 | return super().forward(x).squeeze(dim=1)
109 |
110 |
--------------------------------------------------------------------------------
/ldm/modules/midas/midas/midas_net.py:
--------------------------------------------------------------------------------
1 | """MidashNet: Network for monocular depth estimation trained by mixing several datasets.
2 | This file contains code that is adapted from
3 | https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
4 | """
5 | import torch
6 | import torch.nn as nn
7 |
8 | from .base_model import BaseModel
9 | from .blocks import FeatureFusionBlock, Interpolate, _make_encoder
10 |
11 |
12 | class MidasNet(BaseModel):
13 | """Network for monocular depth estimation.
14 | """
15 |
16 | def __init__(self, path=None, features=256, non_negative=True):
17 | """Init.
18 |
19 | Args:
20 | path (str, optional): Path to saved model. Defaults to None.
21 | features (int, optional): Number of features. Defaults to 256.
22 |             non_negative (bool, optional): Clamp the predicted depth to be non-negative. Defaults to True.
23 | """
24 | print("Loading weights: ", path)
25 |
26 | super(MidasNet, self).__init__()
27 |
28 | use_pretrained = False if path is None else True
29 |
30 | self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained)
31 |
32 | self.scratch.refinenet4 = FeatureFusionBlock(features)
33 | self.scratch.refinenet3 = FeatureFusionBlock(features)
34 | self.scratch.refinenet2 = FeatureFusionBlock(features)
35 | self.scratch.refinenet1 = FeatureFusionBlock(features)
36 |
37 | self.scratch.output_conv = nn.Sequential(
38 | nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),
39 | Interpolate(scale_factor=2, mode="bilinear"),
40 | nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),
41 | nn.ReLU(True),
42 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
43 | nn.ReLU(True) if non_negative else nn.Identity(),
44 | )
45 |
46 | if path:
47 | self.load(path)
48 |
49 | def forward(self, x):
50 | """Forward pass.
51 |
52 | Args:
53 | x (tensor): input data (image)
54 |
55 | Returns:
56 | tensor: depth
57 | """
58 |
59 | layer_1 = self.pretrained.layer1(x)
60 | layer_2 = self.pretrained.layer2(layer_1)
61 | layer_3 = self.pretrained.layer3(layer_2)
62 | layer_4 = self.pretrained.layer4(layer_3)
63 |
64 | layer_1_rn = self.scratch.layer1_rn(layer_1)
65 | layer_2_rn = self.scratch.layer2_rn(layer_2)
66 | layer_3_rn = self.scratch.layer3_rn(layer_3)
67 | layer_4_rn = self.scratch.layer4_rn(layer_4)
68 |
69 | path_4 = self.scratch.refinenet4(layer_4_rn)
70 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
71 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
72 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
73 |
74 | out = self.scratch.output_conv(path_1)
75 |
76 | return torch.squeeze(out, dim=1)
77 |
--------------------------------------------------------------------------------
/ldm/modules/midas/midas/midas_net_custom.py:
--------------------------------------------------------------------------------
1 | """MidashNet: Network for monocular depth estimation trained by mixing several datasets.
2 | This file contains code that is adapted from
3 | https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
4 | """
5 | import torch
6 | import torch.nn as nn
7 |
8 | from .base_model import BaseModel
9 | from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder
10 |
11 |
12 | class MidasNet_small(BaseModel):
13 | """Network for monocular depth estimation.
14 | """
15 |
16 | def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True,
17 | blocks={'expand': True}):
18 | """Init.
19 |
20 | Args:
21 | path (str, optional): Path to saved model. Defaults to None.
22 |             features (int, optional): Number of features. Defaults to 64.
23 |             backbone (str, optional): Backbone network for encoder. Defaults to efficientnet_lite3.
24 | """
25 | print("Loading weights: ", path)
26 |
27 | super(MidasNet_small, self).__init__()
28 |
29 | use_pretrained = False if path else True
30 |
31 | self.channels_last = channels_last
32 | self.blocks = blocks
33 | self.backbone = backbone
34 |
35 | self.groups = 1
36 |
37 | features1=features
38 | features2=features
39 | features3=features
40 | features4=features
41 | self.expand = False
42 | if "expand" in self.blocks and self.blocks['expand'] == True:
43 | self.expand = True
44 | features1=features
45 | features2=features*2
46 | features3=features*4
47 | features4=features*8
48 |
49 | self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable)
50 |
51 | self.scratch.activation = nn.ReLU(False)
52 |
53 | self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
54 | self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
55 | self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
56 | self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners)
57 |
58 |
59 | self.scratch.output_conv = nn.Sequential(
60 | nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups),
61 | Interpolate(scale_factor=2, mode="bilinear"),
62 | nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1),
63 | self.scratch.activation,
64 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
65 | nn.ReLU(True) if non_negative else nn.Identity(),
66 | nn.Identity(),
67 | )
68 |
69 | if path:
70 | self.load(path)
71 |
72 |
73 | def forward(self, x):
74 | """Forward pass.
75 |
76 | Args:
77 | x (tensor): input data (image)
78 |
79 | Returns:
80 | tensor: depth
81 | """
82 | if self.channels_last==True:
83 | print("self.channels_last = ", self.channels_last)
84 |             x = x.contiguous(memory_format=torch.channels_last)  # contiguous() is not in-place; re-bind the result
85 |
86 |
87 | layer_1 = self.pretrained.layer1(x)
88 | layer_2 = self.pretrained.layer2(layer_1)
89 | layer_3 = self.pretrained.layer3(layer_2)
90 | layer_4 = self.pretrained.layer4(layer_3)
91 |
92 | layer_1_rn = self.scratch.layer1_rn(layer_1)
93 | layer_2_rn = self.scratch.layer2_rn(layer_2)
94 | layer_3_rn = self.scratch.layer3_rn(layer_3)
95 | layer_4_rn = self.scratch.layer4_rn(layer_4)
96 |
97 |
98 | path_4 = self.scratch.refinenet4(layer_4_rn)
99 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
100 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
101 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
102 |
103 | out = self.scratch.output_conv(path_1)
104 |
105 | return torch.squeeze(out, dim=1)
106 |
107 |
108 |
109 | def fuse_model(m):
110 | prev_previous_type = nn.Identity()
111 | prev_previous_name = ''
112 | previous_type = nn.Identity()
113 | previous_name = ''
114 | for name, module in m.named_modules():
115 | if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU:
116 | # print("FUSED ", prev_previous_name, previous_name, name)
117 | torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True)
118 | elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d:
119 | # print("FUSED ", prev_previous_name, previous_name)
120 | torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True)
121 | # elif previous_type == nn.Conv2d and type(module) == nn.ReLU:
122 | # print("FUSED ", previous_name, name)
123 | # torch.quantization.fuse_modules(m, [previous_name, name], inplace=True)
124 |
125 | prev_previous_type = previous_type
126 | prev_previous_name = previous_name
127 | previous_type = type(module)
128 | previous_name = name
--------------------------------------------------------------------------------
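A minimal sketch of how `MidasNet_small` above can be instantiated and run; it is not part of the repository. It assumes `timm` is installed so `_make_encoder` can build the `efficientnet_lite3` backbone (pretrained encoder weights are downloaded when no checkpoint path is given).

```python
import torch
from ldm.modules.midas.midas.midas_net_custom import MidasNet_small, fuse_model

model = MidasNet_small(path=None, features=64, backbone="efficientnet_lite3",
                       non_negative=True, blocks={'expand': True})
model.eval()

with torch.no_grad():
    depth = model(torch.randn(1, 3, 256, 256))   # (B, H, W) depth map
print(depth.shape)

# Optional, e.g. before static quantization: fuse eligible Conv+BN(+ReLU) sequences in place.
fuse_model(model)
```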
/ldm/modules/midas/midas/transforms.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 | import math
4 |
5 |
6 | def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
7 |     """Resize the sample to ensure the given size. Keeps aspect ratio.
8 |
9 | Args:
10 | sample (dict): sample
11 | size (tuple): image size
12 |
13 | Returns:
14 | tuple: new size
15 | """
16 | shape = list(sample["disparity"].shape)
17 |
18 | if shape[0] >= size[0] and shape[1] >= size[1]:
19 | return sample
20 |
21 | scale = [0, 0]
22 | scale[0] = size[0] / shape[0]
23 | scale[1] = size[1] / shape[1]
24 |
25 | scale = max(scale)
26 |
27 | shape[0] = math.ceil(scale * shape[0])
28 | shape[1] = math.ceil(scale * shape[1])
29 |
30 | # resize
31 | sample["image"] = cv2.resize(
32 | sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method
33 | )
34 |
35 | sample["disparity"] = cv2.resize(
36 | sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST
37 | )
38 | sample["mask"] = cv2.resize(
39 | sample["mask"].astype(np.float32),
40 | tuple(shape[::-1]),
41 | interpolation=cv2.INTER_NEAREST,
42 | )
43 | sample["mask"] = sample["mask"].astype(bool)
44 |
45 | return tuple(shape)
46 |
47 |
48 | class Resize(object):
49 | """Resize sample to given size (width, height).
50 | """
51 |
52 | def __init__(
53 | self,
54 | width,
55 | height,
56 | resize_target=True,
57 | keep_aspect_ratio=False,
58 | ensure_multiple_of=1,
59 | resize_method="lower_bound",
60 | image_interpolation_method=cv2.INTER_AREA,
61 | ):
62 | """Init.
63 |
64 | Args:
65 | width (int): desired output width
66 | height (int): desired output height
67 | resize_target (bool, optional):
68 | True: Resize the full sample (image, mask, target).
69 | False: Resize image only.
70 | Defaults to True.
71 | keep_aspect_ratio (bool, optional):
72 | True: Keep the aspect ratio of the input sample.
73 | Output sample might not have the given width and height, and
74 | resize behaviour depends on the parameter 'resize_method'.
75 | Defaults to False.
76 | ensure_multiple_of (int, optional):
77 | Output width and height is constrained to be multiple of this parameter.
78 | Defaults to 1.
79 | resize_method (str, optional):
80 | "lower_bound": Output will be at least as large as the given size.
81 | "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
82 |                 "minimal": Scale as little as possible. (Output size might be smaller than given size.)
83 | Defaults to "lower_bound".
84 | """
85 | self.__width = width
86 | self.__height = height
87 |
88 | self.__resize_target = resize_target
89 | self.__keep_aspect_ratio = keep_aspect_ratio
90 | self.__multiple_of = ensure_multiple_of
91 | self.__resize_method = resize_method
92 | self.__image_interpolation_method = image_interpolation_method
93 |
94 | def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
95 | y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
96 |
97 | if max_val is not None and y > max_val:
98 | y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
99 |
100 | if y < min_val:
101 | y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
102 |
103 | return y
104 |
105 | def get_size(self, width, height):
106 | # determine new height and width
107 | scale_height = self.__height / height
108 | scale_width = self.__width / width
109 |
110 | if self.__keep_aspect_ratio:
111 | if self.__resize_method == "lower_bound":
112 | # scale such that output size is lower bound
113 | if scale_width > scale_height:
114 | # fit width
115 | scale_height = scale_width
116 | else:
117 | # fit height
118 | scale_width = scale_height
119 | elif self.__resize_method == "upper_bound":
120 | # scale such that output size is upper bound
121 | if scale_width < scale_height:
122 | # fit width
123 | scale_height = scale_width
124 | else:
125 | # fit height
126 | scale_width = scale_height
127 | elif self.__resize_method == "minimal":
128 |                 # scale as little as possible
129 | if abs(1 - scale_width) < abs(1 - scale_height):
130 | # fit width
131 | scale_height = scale_width
132 | else:
133 | # fit height
134 | scale_width = scale_height
135 | else:
136 | raise ValueError(
137 | f"resize_method {self.__resize_method} not implemented"
138 | )
139 |
140 | if self.__resize_method == "lower_bound":
141 | new_height = self.constrain_to_multiple_of(
142 | scale_height * height, min_val=self.__height
143 | )
144 | new_width = self.constrain_to_multiple_of(
145 | scale_width * width, min_val=self.__width
146 | )
147 | elif self.__resize_method == "upper_bound":
148 | new_height = self.constrain_to_multiple_of(
149 | scale_height * height, max_val=self.__height
150 | )
151 | new_width = self.constrain_to_multiple_of(
152 | scale_width * width, max_val=self.__width
153 | )
154 | elif self.__resize_method == "minimal":
155 | new_height = self.constrain_to_multiple_of(scale_height * height)
156 | new_width = self.constrain_to_multiple_of(scale_width * width)
157 | else:
158 | raise ValueError(f"resize_method {self.__resize_method} not implemented")
159 |
160 | return (new_width, new_height)
161 |
162 | def __call__(self, sample):
163 | width, height = self.get_size(
164 | sample["image"].shape[1], sample["image"].shape[0]
165 | )
166 |
167 | # resize sample
168 | sample["image"] = cv2.resize(
169 | sample["image"],
170 | (width, height),
171 | interpolation=self.__image_interpolation_method,
172 | )
173 |
174 | if self.__resize_target:
175 | if "disparity" in sample:
176 | sample["disparity"] = cv2.resize(
177 | sample["disparity"],
178 | (width, height),
179 | interpolation=cv2.INTER_NEAREST,
180 | )
181 |
182 | if "depth" in sample:
183 | sample["depth"] = cv2.resize(
184 | sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST
185 | )
186 |
187 | sample["mask"] = cv2.resize(
188 | sample["mask"].astype(np.float32),
189 | (width, height),
190 | interpolation=cv2.INTER_NEAREST,
191 | )
192 | sample["mask"] = sample["mask"].astype(bool)
193 |
194 | return sample
195 |
196 |
197 | class NormalizeImage(object):
198 |     """Normalize image by given mean and std.
199 | """
200 |
201 | def __init__(self, mean, std):
202 | self.__mean = mean
203 | self.__std = std
204 |
205 | def __call__(self, sample):
206 | sample["image"] = (sample["image"] - self.__mean) / self.__std
207 |
208 | return sample
209 |
210 |
211 | class PrepareForNet(object):
212 | """Prepare sample for usage as network input.
213 | """
214 |
215 | def __init__(self):
216 | pass
217 |
218 | def __call__(self, sample):
219 | image = np.transpose(sample["image"], (2, 0, 1))
220 | sample["image"] = np.ascontiguousarray(image).astype(np.float32)
221 |
222 | if "mask" in sample:
223 | sample["mask"] = sample["mask"].astype(np.float32)
224 | sample["mask"] = np.ascontiguousarray(sample["mask"])
225 |
226 | if "disparity" in sample:
227 | disparity = sample["disparity"].astype(np.float32)
228 | sample["disparity"] = np.ascontiguousarray(disparity)
229 |
230 | if "depth" in sample:
231 | depth = sample["depth"].astype(np.float32)
232 | sample["depth"] = np.ascontiguousarray(depth)
233 |
234 | return sample
235 |
--------------------------------------------------------------------------------
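These transforms operate on a `dict` sample and are meant to be chained. Below is a minimal sketch of a typical MiDaS-style preprocessing pipeline; the use of `torchvision.transforms.Compose` and the ImageNet mean/std values are assumptions, not taken from this file.

```python
import numpy as np
from torchvision.transforms import Compose
from ldm.modules.midas.midas.transforms import Resize, NormalizeImage, PrepareForNet

transform = Compose([
    Resize(384, 384, resize_target=False, keep_aspect_ratio=True,
           ensure_multiple_of=32, resize_method="minimal"),
    NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    PrepareForNet(),
])

img = np.random.rand(480, 640, 3).astype(np.float32)   # HxWx3 RGB image in [0, 1]
sample = transform({"image": img})
print(sample["image"].shape)    # (3, 384, 512): CHW float32, sides rounded to multiples of 32
```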
/ldm/modules/midas/utils.py:
--------------------------------------------------------------------------------
1 | """Utils for monoDepth."""
2 | import sys
3 | import re
4 | import numpy as np
5 | import cv2
6 | import torch
7 |
8 |
9 | def read_pfm(path):
10 | """Read pfm file.
11 |
12 | Args:
13 | path (str): path to file
14 |
15 | Returns:
16 | tuple: (data, scale)
17 | """
18 | with open(path, "rb") as file:
19 |
20 | color = None
21 | width = None
22 | height = None
23 | scale = None
24 | endian = None
25 |
26 | header = file.readline().rstrip()
27 | if header.decode("ascii") == "PF":
28 | color = True
29 | elif header.decode("ascii") == "Pf":
30 | color = False
31 | else:
32 | raise Exception("Not a PFM file: " + path)
33 |
34 | dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii"))
35 | if dim_match:
36 | width, height = list(map(int, dim_match.groups()))
37 | else:
38 | raise Exception("Malformed PFM header.")
39 |
40 | scale = float(file.readline().decode("ascii").rstrip())
41 | if scale < 0:
42 | # little-endian
43 | endian = "<"
44 | scale = -scale
45 | else:
46 | # big-endian
47 | endian = ">"
48 |
49 | data = np.fromfile(file, endian + "f")
50 | shape = (height, width, 3) if color else (height, width)
51 |
52 | data = np.reshape(data, shape)
53 | data = np.flipud(data)
54 |
55 | return data, scale
56 |
57 |
58 | def write_pfm(path, image, scale=1):
59 | """Write pfm file.
60 |
61 | Args:
62 |         path (str): path to file
63 | image (array): data
64 | scale (int, optional): Scale. Defaults to 1.
65 | """
66 |
67 | with open(path, "wb") as file:
68 | color = None
69 |
70 | if image.dtype.name != "float32":
71 | raise Exception("Image dtype must be float32.")
72 |
73 | image = np.flipud(image)
74 |
75 | if len(image.shape) == 3 and image.shape[2] == 3: # color image
76 | color = True
77 | elif (
78 | len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1
79 | ): # greyscale
80 | color = False
81 | else:
82 | raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.")
83 |
84 |         file.write(("PF\n" if color else "Pf\n").encode())  # encode the chosen header, not only the "Pf" branch
85 | file.write("%d %d\n".encode() % (image.shape[1], image.shape[0]))
86 |
87 | endian = image.dtype.byteorder
88 |
89 | if endian == "<" or endian == "=" and sys.byteorder == "little":
90 | scale = -scale
91 |
92 | file.write("%f\n".encode() % scale)
93 |
94 | image.tofile(file)
95 |
96 |
97 | def read_image(path):
98 | """Read image and output RGB image (0-1).
99 |
100 | Args:
101 | path (str): path to file
102 |
103 | Returns:
104 | array: RGB image (0-1)
105 | """
106 | img = cv2.imread(path)
107 |
108 | if img.ndim == 2:
109 | img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
110 |
111 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0
112 |
113 | return img
114 |
115 |
116 | def resize_image(img):
117 | """Resize image and make it fit for network.
118 |
119 | Args:
120 | img (array): image
121 |
122 | Returns:
123 | tensor: data ready for network
124 | """
125 | height_orig = img.shape[0]
126 | width_orig = img.shape[1]
127 |
128 | if width_orig > height_orig:
129 | scale = width_orig / 384
130 | else:
131 | scale = height_orig / 384
132 |
133 | height = (np.ceil(height_orig / scale / 32) * 32).astype(int)
134 | width = (np.ceil(width_orig / scale / 32) * 32).astype(int)
135 |
136 | img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)
137 |
138 | img_resized = (
139 | torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float()
140 | )
141 | img_resized = img_resized.unsqueeze(0)
142 |
143 | return img_resized
144 |
145 |
146 | def resize_depth(depth, width, height):
147 | """Resize depth map and bring to CPU (numpy).
148 |
149 | Args:
150 | depth (tensor): depth
151 | width (int): image width
152 | height (int): image height
153 |
154 | Returns:
155 | array: processed depth
156 | """
157 | depth = torch.squeeze(depth[0, :, :, :]).to("cpu")
158 |
159 | depth_resized = cv2.resize(
160 | depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC
161 | )
162 |
163 | return depth_resized
164 |
165 | def write_depth(path, depth, bits=1):
166 | """Write depth map to pfm and png file.
167 |
168 | Args:
169 | path (str): filepath without extension
170 | depth (array): depth
171 | """
172 | write_pfm(path + ".pfm", depth.astype(np.float32))
173 |
174 | depth_min = depth.min()
175 | depth_max = depth.max()
176 |
177 | max_val = (2**(8*bits))-1
178 |
179 | if depth_max - depth_min > np.finfo("float").eps:
180 | out = max_val * (depth - depth_min) / (depth_max - depth_min)
181 | else:
182 |         out = np.zeros(depth.shape, dtype=depth.dtype)
183 |
184 | if bits == 1:
185 | cv2.imwrite(path + ".png", out.astype("uint8"))
186 | elif bits == 2:
187 | cv2.imwrite(path + ".png", out.astype("uint16"))
188 |
189 | return
190 |
--------------------------------------------------------------------------------
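The helpers above chain into a read → resize → predict → resize back → write pipeline. A minimal sketch follows; the image path is a placeholder and the lambda stands in for a real MiDaS-style network.

```python
import torch
from ldm.modules.midas.utils import read_image, resize_image, resize_depth, write_depth

model = lambda x: x.mean(dim=1)                 # stand-in: maps (B, 3, H, W) -> (B, H, W)

img = read_image("example.jpg")                 # placeholder path; HxWx3 RGB in [0, 1]
batch = resize_image(img)                       # 1x3xH'xW', longer side scaled to ~384, multiples of 32

with torch.no_grad():
    prediction = model(batch).unsqueeze(1)      # resize_depth expects a 4D tensor

depth = resize_depth(prediction, width=img.shape[1], height=img.shape[0])
write_depth("example_depth", depth, bits=2)     # writes example_depth.pfm and example_depth.png
```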
/ldm/util.py:
--------------------------------------------------------------------------------
1 | import importlib
2 |
3 | import torch
4 | from torch import optim
5 | import numpy as np
6 |
7 | from inspect import isfunction
8 | from PIL import Image, ImageDraw, ImageFont
9 |
10 |
11 | def log_txt_as_img(wh, xc, size=10):
12 | # wh a tuple of (width, height)
13 | # xc a list of captions to plot
14 | b = len(xc)
15 | txts = list()
16 | for bi in range(b):
17 | txt = Image.new("RGB", wh, color="white")
18 | draw = ImageDraw.Draw(txt)
19 | font = ImageFont.truetype('font/DejaVuSans.ttf', size=size)
20 | nc = int(40 * (wh[0] / 256))
21 | lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))
22 |
23 | try:
24 | draw.text((0, 0), lines, fill="black", font=font)
25 | except UnicodeEncodeError:
26 |             print("Can't encode string for logging. Skipping.")
27 |
28 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
29 | txts.append(txt)
30 | txts = np.stack(txts)
31 | txts = torch.tensor(txts)
32 | return txts
33 |
34 |
35 | def ismap(x):
36 | if not isinstance(x, torch.Tensor):
37 | return False
38 | return (len(x.shape) == 4) and (x.shape[1] > 3)
39 |
40 |
41 | def isimage(x):
42 | if not isinstance(x,torch.Tensor):
43 | return False
44 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
45 |
46 |
47 | def exists(x):
48 | return x is not None
49 |
50 |
51 | def default(val, d):
52 | if exists(val):
53 | return val
54 | return d() if isfunction(d) else d
55 |
56 |
57 | def mean_flat(tensor):
58 | """
59 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
60 | Take the mean over all non-batch dimensions.
61 | """
62 | return tensor.mean(dim=list(range(1, len(tensor.shape))))
63 |
64 |
65 | def count_params(model, verbose=False):
66 | total_params = sum(p.numel() for p in model.parameters())
67 | if verbose:
68 | print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
69 | return total_params
70 |
71 |
72 | def instantiate_from_config(config):
73 | if not "target" in config:
74 | if config == '__is_first_stage__':
75 | return None
76 | elif config == "__is_unconditional__":
77 | return None
78 | raise KeyError("Expected key `target` to instantiate.")
79 | return get_obj_from_str(config["target"])(**config.get("params", dict()))
80 |
81 |
82 | def get_obj_from_str(string, reload=False):
83 | module, cls = string.rsplit(".", 1)
84 | if reload:
85 | module_imp = importlib.import_module(module)
86 | importlib.reload(module_imp)
87 | return getattr(importlib.import_module(module, package=None), cls)
88 |
89 |
90 | class AdamWwithEMAandWings(optim.Optimizer):
91 | # credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298
92 | def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8, # TODO: check hyperparameters before using
93 | weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999, # ema decay to match previous code
94 | ema_power=1., param_names=()):
95 | """AdamW that saves EMA versions of the parameters."""
96 | if not 0.0 <= lr:
97 | raise ValueError("Invalid learning rate: {}".format(lr))
98 | if not 0.0 <= eps:
99 | raise ValueError("Invalid epsilon value: {}".format(eps))
100 | if not 0.0 <= betas[0] < 1.0:
101 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
102 | if not 0.0 <= betas[1] < 1.0:
103 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
104 | if not 0.0 <= weight_decay:
105 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
106 | if not 0.0 <= ema_decay <= 1.0:
107 | raise ValueError("Invalid ema_decay value: {}".format(ema_decay))
108 | defaults = dict(lr=lr, betas=betas, eps=eps,
109 | weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay,
110 | ema_power=ema_power, param_names=param_names)
111 | super().__init__(params, defaults)
112 |
113 | def __setstate__(self, state):
114 | super().__setstate__(state)
115 | for group in self.param_groups:
116 | group.setdefault('amsgrad', False)
117 |
118 | @torch.no_grad()
119 | def step(self, closure=None):
120 | """Performs a single optimization step.
121 | Args:
122 | closure (callable, optional): A closure that reevaluates the model
123 | and returns the loss.
124 | """
125 | loss = None
126 | if closure is not None:
127 | with torch.enable_grad():
128 | loss = closure()
129 |
130 | for group in self.param_groups:
131 | params_with_grad = []
132 | grads = []
133 | exp_avgs = []
134 | exp_avg_sqs = []
135 | ema_params_with_grad = []
136 | state_sums = []
137 | max_exp_avg_sqs = []
138 | state_steps = []
139 | amsgrad = group['amsgrad']
140 | beta1, beta2 = group['betas']
141 | ema_decay = group['ema_decay']
142 | ema_power = group['ema_power']
143 |
144 | for p in group['params']:
145 | if p.grad is None:
146 | continue
147 | params_with_grad.append(p)
148 | if p.grad.is_sparse:
149 | raise RuntimeError('AdamW does not support sparse gradients')
150 | grads.append(p.grad)
151 |
152 | state = self.state[p]
153 |
154 | # State initialization
155 | if len(state) == 0:
156 | state['step'] = 0
157 | # Exponential moving average of gradient values
158 | state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
159 | # Exponential moving average of squared gradient values
160 | state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
161 | if amsgrad:
162 | # Maintains max of all exp. moving avg. of sq. grad. values
163 | state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
164 | # Exponential moving average of parameter values
165 | state['param_exp_avg'] = p.detach().float().clone()
166 |
167 | exp_avgs.append(state['exp_avg'])
168 | exp_avg_sqs.append(state['exp_avg_sq'])
169 | ema_params_with_grad.append(state['param_exp_avg'])
170 |
171 | if amsgrad:
172 | max_exp_avg_sqs.append(state['max_exp_avg_sq'])
173 |
174 | # update the steps for each param group update
175 | state['step'] += 1
176 | # record the step after step update
177 | state_steps.append(state['step'])
178 |
179 | optim._functional.adamw(params_with_grad,
180 | grads,
181 | exp_avgs,
182 | exp_avg_sqs,
183 | max_exp_avg_sqs,
184 | state_steps,
185 | amsgrad=amsgrad,
186 | beta1=beta1,
187 | beta2=beta2,
188 | lr=group['lr'],
189 | weight_decay=group['weight_decay'],
190 | eps=group['eps'],
191 | maximize=False)
192 |
193 | cur_ema_decay = min(ema_decay, 1 - state['step'] ** -ema_power)
194 | for param, ema_param in zip(params_with_grad, ema_params_with_grad):
195 | ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay)
196 |
197 | return loss
--------------------------------------------------------------------------------
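`instantiate_from_config` is the glue used throughout the codebase to build objects from YAML configs: any mapping with a dotted `target` import path and an optional `params` block. A minimal sketch; the target class and parameter values here are arbitrary illustrations, not a config from this repository.

```python
from ldm.util import instantiate_from_config, count_params

config = {
    "target": "torch.nn.Conv2d",
    "params": {"in_channels": 3, "out_channels": 8, "kernel_size": 3},
}
layer = instantiate_from_config(config)    # equivalent to torch.nn.Conv2d(3, 8, kernel_size=3)
count_params(layer, verbose=True)          # prints the parameter count in millions
```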
/resources/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cszy98/PLACE/4594350491df5c7489eb4edf0802f303692dd530/resources/.DS_Store
--------------------------------------------------------------------------------
/resources/bibtex.txt:
--------------------------------------------------------------------------------
1 | @inproceedings{lv24cvpr,
2 | author = "Lv, Z. and Wei, Y. and Zuo, W. and Wong, K.-Y.~K.",
3 | title = "PLACE: Adaptive Layout-Semantic Fusion for Semantic Image Synthesis",
4 | booktitle = "Proc. IEEE/CVF Conference on Computer Vision and Pattern Recognition",
5 | volume = "",
6 | pages = "",
7 | address = "Seattle, Washington",
8 | month = "June",
9 | year = "2024"
10 | }
--------------------------------------------------------------------------------
/resources/ind.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cszy98/PLACE/4594350491df5c7489eb4edf0802f303692dd530/resources/ind.png
--------------------------------------------------------------------------------
/resources/method_diagram.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cszy98/PLACE/4594350491df5c7489eb4edf0802f303692dd530/resources/method_diagram.png
--------------------------------------------------------------------------------
/resources/newfig2_final.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cszy98/PLACE/4594350491df5c7489eb4edf0802f303692dd530/resources/newfig2_final.pdf
--------------------------------------------------------------------------------
/resources/od.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cszy98/PLACE/4594350491df5c7489eb4edf0802f303692dd530/resources/od.png
--------------------------------------------------------------------------------
/resources/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cszy98/PLACE/4594350491df5c7489eb4edf0802f303692dd530/resources/overview.png
--------------------------------------------------------------------------------
/resources/paper.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cszy98/PLACE/4594350491df5c7489eb4edf0802f303692dd530/resources/paper.png
--------------------------------------------------------------------------------
/resources/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cszy98/PLACE/4594350491df5c7489eb4edf0802f303692dd530/resources/teaser.png
--------------------------------------------------------------------------------
/run_inference_ADE20K.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 python inference.py --outdir output/ade20k \
2 | --config configs/stable-diffusion/PLACE.yaml \
3 | --ckpt ckpt/ade20k_best.ckpt \
4 | --dataset ADE20K \
5 | --data_root path_to_ADE20K
6 |
--------------------------------------------------------------------------------
/run_inference_COCO.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 python inference.py --outdir output/COCO \
2 | --config configs/stable-diffusion/PLACE.yaml \
3 | --ckpt ckpt/coco_best.ckpt \
4 | --dataset COCO \
5 | --data_root path_to_COCO
6 |
--------------------------------------------------------------------------------