├── .DS_Store ├── README.md ├── __pycache__ ├── config.cpython-38.pyc └── tutorial_dataset_ade.cpython-38.pyc ├── color150.mat ├── configs ├── autoencoder │ ├── autoencoder_kl_16x16x16.yaml │ ├── autoencoder_kl_32x32x4.yaml │ ├── autoencoder_kl_64x64x3.yaml │ └── autoencoder_kl_8x8x64.yaml ├── latent-diffusion │ ├── celebahq-ldm-vq-4.yaml │ ├── cin-ldm-vq-f8.yaml │ ├── cin256-v2.yaml │ ├── ffhq-ldm-vq-4.yaml │ ├── lsun_bedrooms-ldm-vq-4.yaml │ ├── lsun_churches-ldm-kl-8.yaml │ └── txt2img-1p4B-eval.yaml ├── retrieval-augmented-diffusion │ └── 768x768.yaml └── stable-diffusion │ └── PLACE.yaml ├── dataset.py ├── environment.yaml ├── index.html ├── inference.py ├── ldm ├── .DS_Store ├── __pycache__ │ └── util.cpython-38.pyc ├── data │ ├── __init__.py │ └── util.py ├── models │ ├── __pycache__ │ │ └── autoencoder.cpython-38.pyc │ ├── autoencoder.py │ └── diffusion │ │ ├── __init__.py │ │ ├── __pycache__ │ │ ├── __init__.cpython-38.pyc │ │ ├── ddim.cpython-38.pyc │ │ ├── ddpm.cpython-38.pyc │ │ ├── myplms.cpython-38.pyc │ │ ├── plms.cpython-38.pyc │ │ ├── sampling_util.cpython-38.pyc │ │ └── testplms.cpython-38.pyc │ │ ├── ddim.py │ │ ├── ddpm.py │ │ ├── dpm_solver │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── dpm_solver.cpython-38.pyc │ │ │ └── sampler.cpython-38.pyc │ │ ├── dpm_solver.py │ │ └── sampler.py │ │ ├── myplms.py │ │ ├── plms.py │ │ ├── sampling_util.py │ │ └── testplms.py ├── modules │ ├── .DS_Store │ ├── __pycache__ │ │ ├── attention.cpython-38.pyc │ │ └── ema.cpython-38.pyc │ ├── attention.py │ ├── diffusionmodules │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── model.cpython-38.pyc │ │ │ ├── openaimodel.cpython-38.pyc │ │ │ └── util.cpython-38.pyc │ │ ├── model.py │ │ ├── openaimodel.py │ │ ├── upscaling.py │ │ └── util.py │ ├── distributions │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-38.pyc │ │ │ └── distributions.cpython-38.pyc │ │ └── distributions.py │ ├── ema.py │ ├── encoders │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-38.pyc │ │ │ └── modules.cpython-38.pyc │ │ └── modules.py │ ├── image_degradation │ │ ├── __init__.py │ │ ├── bsrgan.py │ │ ├── bsrgan_light.py │ │ ├── utils │ │ │ └── test.png │ │ └── utils_image.py │ └── midas │ │ ├── __init__.py │ │ ├── api.py │ │ ├── midas │ │ ├── __init__.py │ │ ├── base_model.py │ │ ├── blocks.py │ │ ├── dpt_depth.py │ │ ├── midas_net.py │ │ ├── midas_net_custom.py │ │ ├── transforms.py │ │ └── vit.py │ │ └── utils.py └── util.py ├── resources ├── .DS_Store ├── bibtex.txt ├── ind.png ├── method_diagram.png ├── newfig2_final.pdf ├── od.png ├── overview.png ├── paper.png └── teaser.png ├── run_inference_ADE20K.sh └── run_inference_COCO.sh /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszy98/PLACE/4594350491df5c7489eb4edf0802f303692dd530/.DS_Store -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PLACE: Adaptive Layout-Semantic Fusion for Semantic Image Synthesis (CVPR 2024) 2 | 3 | ### Introduction 4 | 5 | The source code for our paper "PLACE: Adaptive Layout-Semantic Fusion for Semantic Image Synthesis" (CVPR 2024) 6 | 7 | [**[Project Page]**](https://cszy98.github.io/PLACE/) [**[Code]**](https://github.com/cszy98/PLACE/tree/main) [**[Paper]**](https://arxiv.org/abs/2403.01852) 8 | 9 | ### Overview 10 | 11 | 
![overview](resources/overview.png) 12 | 13 | ### Quick Start 14 | 15 | #### Installation 16 | 17 | ``` 18 | git clone 19 | cd PLACE 20 | conda env create -f environment.yaml 21 | conda activate PLACE 22 | ``` 23 | 24 | #### Data Preparation 25 | 26 | Please follow the dataset preparation process in [FreestyleNet](https://github.com/essunny310/FreestyleNet). 27 | 28 | #### Running 29 | 30 | The pre-trained models can be downloaded from [GoogleDrive](https://drive.google.com/drive/folders/1b5pC52hasLwm1gOkc9LmdIyxZjrdlNWC?usp=drive_link) and should be put into the `ckpt` folder. 31 | 32 | After the dataset and pre-trained models are prepared, you may evaluate the model with the following scripts: 33 | 34 | ``` 35 | # evaluate on the ADE20K dataset 36 | ./run_inference_ADE20K.sh 37 | # evaluate on the COCO-Stuff dataset 38 | ./run_inference_COCO.sh 39 | ``` 40 | 41 | For out-of-distribution synthesis, you just need to modify the `ADE20K` or `COCO` dictionary in the `dataset.py` 42 | 43 | ### Citation 44 | 45 | ``` 46 | @article{lv2024place, 47 | title={PLACE: Adaptive Layout-Semantic Fusion for Semantic Image Synthesis}, 48 | author={Lv, Zhengyao and Wei, Yuxiang and Zuo, Wangmeng and Kwan-Yee K. Wong}, 49 | journal={IEEE Conference on Computer Vision and Pattern Recognition}, 50 | year={2024} 51 | } 52 | ``` 53 | 54 | ### Contact 55 | 56 | Please send mail to cszy98@gmail.com 57 | -------------------------------------------------------------------------------- /__pycache__/config.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszy98/PLACE/4594350491df5c7489eb4edf0802f303692dd530/__pycache__/config.cpython-38.pyc -------------------------------------------------------------------------------- /__pycache__/tutorial_dataset_ade.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszy98/PLACE/4594350491df5c7489eb4edf0802f303692dd530/__pycache__/tutorial_dataset_ade.cpython-38.pyc -------------------------------------------------------------------------------- /color150.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszy98/PLACE/4594350491df5c7489eb4edf0802f303692dd530/color150.mat -------------------------------------------------------------------------------- /configs/autoencoder/autoencoder_kl_16x16x16.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-6 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: "val/rec_loss" 6 | embed_dim: 16 7 | lossconfig: 8 | target: ldm.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 0.000001 12 | disc_weight: 0.5 13 | 14 | ddconfig: 15 | double_z: True 16 | z_channels: 16 17 | resolution: 256 18 | in_channels: 3 19 | out_ch: 3 20 | ch: 128 21 | ch_mult: [ 1,1,2,2,4] # num_down = len(ch_mult)-1 22 | num_res_blocks: 2 23 | attn_resolutions: [16] 24 | dropout: 0.0 25 | 26 | 27 | data: 28 | target: main.DataModuleFromConfig 29 | params: 30 | batch_size: 12 31 | wrap: True 32 | train: 33 | target: ldm.data.imagenet.ImageNetSRTrain 34 | params: 35 | size: 256 36 | degradation: pil_nearest 37 | validation: 38 | target: ldm.data.imagenet.ImageNetSRValidation 39 | params: 40 | size: 256 41 | degradation: pil_nearest 42 | 43 | lightning: 44 | callbacks: 45 | image_logger: 46 | target: main.ImageLogger 
47 | params: 48 | batch_frequency: 1000 49 | max_images: 8 50 | increase_log_steps: True 51 | 52 | trainer: 53 | benchmark: True 54 | accumulate_grad_batches: 2 55 | -------------------------------------------------------------------------------- /configs/autoencoder/autoencoder_kl_32x32x4.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-6 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: "val/rec_loss" 6 | embed_dim: 4 7 | lossconfig: 8 | target: ldm.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 0.000001 12 | disc_weight: 0.5 13 | 14 | ddconfig: 15 | double_z: True 16 | z_channels: 4 17 | resolution: 256 18 | in_channels: 3 19 | out_ch: 3 20 | ch: 128 21 | ch_mult: [ 1,2,4,4 ] # num_down = len(ch_mult)-1 22 | num_res_blocks: 2 23 | attn_resolutions: [ ] 24 | dropout: 0.0 25 | 26 | data: 27 | target: main.DataModuleFromConfig 28 | params: 29 | batch_size: 12 30 | wrap: True 31 | train: 32 | target: ldm.data.imagenet.ImageNetSRTrain 33 | params: 34 | size: 256 35 | degradation: pil_nearest 36 | validation: 37 | target: ldm.data.imagenet.ImageNetSRValidation 38 | params: 39 | size: 256 40 | degradation: pil_nearest 41 | 42 | lightning: 43 | callbacks: 44 | image_logger: 45 | target: main.ImageLogger 46 | params: 47 | batch_frequency: 1000 48 | max_images: 8 49 | increase_log_steps: True 50 | 51 | trainer: 52 | benchmark: True 53 | accumulate_grad_batches: 2 54 | -------------------------------------------------------------------------------- /configs/autoencoder/autoencoder_kl_64x64x3.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-6 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: "val/rec_loss" 6 | embed_dim: 3 7 | lossconfig: 8 | target: ldm.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 0.000001 12 | disc_weight: 0.5 13 | 14 | ddconfig: 15 | double_z: True 16 | z_channels: 3 17 | resolution: 256 18 | in_channels: 3 19 | out_ch: 3 20 | ch: 128 21 | ch_mult: [ 1,2,4 ] # num_down = len(ch_mult)-1 22 | num_res_blocks: 2 23 | attn_resolutions: [ ] 24 | dropout: 0.0 25 | 26 | 27 | data: 28 | target: main.DataModuleFromConfig 29 | params: 30 | batch_size: 12 31 | wrap: True 32 | train: 33 | target: ldm.data.imagenet.ImageNetSRTrain 34 | params: 35 | size: 256 36 | degradation: pil_nearest 37 | validation: 38 | target: ldm.data.imagenet.ImageNetSRValidation 39 | params: 40 | size: 256 41 | degradation: pil_nearest 42 | 43 | lightning: 44 | callbacks: 45 | image_logger: 46 | target: main.ImageLogger 47 | params: 48 | batch_frequency: 1000 49 | max_images: 8 50 | increase_log_steps: True 51 | 52 | trainer: 53 | benchmark: True 54 | accumulate_grad_batches: 2 55 | -------------------------------------------------------------------------------- /configs/autoencoder/autoencoder_kl_8x8x64.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-6 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: "val/rec_loss" 6 | embed_dim: 64 7 | lossconfig: 8 | target: ldm.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 0.000001 12 | disc_weight: 0.5 13 | 14 | ddconfig: 15 | double_z: True 16 | z_channels: 64 17 | resolution: 256 18 | in_channels: 3 19 | out_ch: 3 20 | ch: 128 21 | ch_mult: [ 
1,1,2,2,4,4] # num_down = len(ch_mult)-1 22 | num_res_blocks: 2 23 | attn_resolutions: [16,8] 24 | dropout: 0.0 25 | 26 | data: 27 | target: main.DataModuleFromConfig 28 | params: 29 | batch_size: 12 30 | wrap: True 31 | train: 32 | target: ldm.data.imagenet.ImageNetSRTrain 33 | params: 34 | size: 256 35 | degradation: pil_nearest 36 | validation: 37 | target: ldm.data.imagenet.ImageNetSRValidation 38 | params: 39 | size: 256 40 | degradation: pil_nearest 41 | 42 | lightning: 43 | callbacks: 44 | image_logger: 45 | target: main.ImageLogger 46 | params: 47 | batch_frequency: 1000 48 | max_images: 8 49 | increase_log_steps: True 50 | 51 | trainer: 52 | benchmark: True 53 | accumulate_grad_batches: 2 54 | -------------------------------------------------------------------------------- /configs/latent-diffusion/celebahq-ldm-vq-4.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 2.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | image_size: 64 12 | channels: 3 13 | monitor: val/loss_simple_ema 14 | 15 | unet_config: 16 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 17 | params: 18 | image_size: 64 19 | in_channels: 3 20 | out_channels: 3 21 | model_channels: 224 22 | attention_resolutions: 23 | # note: this isn\t actually the resolution but 24 | # the downsampling factor, i.e. this corresnponds to 25 | # attention on spatial resolution 8,16,32, as the 26 | # spatial reolution of the latents is 64 for f4 27 | - 8 28 | - 4 29 | - 2 30 | num_res_blocks: 2 31 | channel_mult: 32 | - 1 33 | - 2 34 | - 3 35 | - 4 36 | num_head_channels: 32 37 | first_stage_config: 38 | target: ldm.models.autoencoder.VQModelInterface 39 | params: 40 | embed_dim: 3 41 | n_embed: 8192 42 | ckpt_path: models/first_stage_models/vq-f4/model.ckpt 43 | ddconfig: 44 | double_z: false 45 | z_channels: 3 46 | resolution: 256 47 | in_channels: 3 48 | out_ch: 3 49 | ch: 128 50 | ch_mult: 51 | - 1 52 | - 2 53 | - 4 54 | num_res_blocks: 2 55 | attn_resolutions: [] 56 | dropout: 0.0 57 | lossconfig: 58 | target: torch.nn.Identity 59 | cond_stage_config: __is_unconditional__ 60 | data: 61 | target: main.DataModuleFromConfig 62 | params: 63 | batch_size: 48 64 | num_workers: 5 65 | wrap: false 66 | train: 67 | target: taming.data.faceshq.CelebAHQTrain 68 | params: 69 | size: 256 70 | validation: 71 | target: taming.data.faceshq.CelebAHQValidation 72 | params: 73 | size: 256 74 | 75 | 76 | lightning: 77 | callbacks: 78 | image_logger: 79 | target: main.ImageLogger 80 | params: 81 | batch_frequency: 5000 82 | max_images: 8 83 | increase_log_steps: False 84 | 85 | trainer: 86 | benchmark: True -------------------------------------------------------------------------------- /configs/latent-diffusion/cin-ldm-vq-f8.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: class_label 12 | image_size: 32 13 | channels: 4 14 | cond_stage_trainable: true 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | unet_config: 18 | target: 
ldm.modules.diffusionmodules.openaimodel.UNetModel 19 | params: 20 | image_size: 32 21 | in_channels: 4 22 | out_channels: 4 23 | model_channels: 256 24 | attention_resolutions: 25 | #note: this isn\t actually the resolution but 26 | # the downsampling factor, i.e. this corresnponds to 27 | # attention on spatial resolution 8,16,32, as the 28 | # spatial reolution of the latents is 32 for f8 29 | - 4 30 | - 2 31 | - 1 32 | num_res_blocks: 2 33 | channel_mult: 34 | - 1 35 | - 2 36 | - 4 37 | num_head_channels: 32 38 | use_spatial_transformer: true 39 | transformer_depth: 1 40 | context_dim: 512 41 | first_stage_config: 42 | target: ldm.models.autoencoder.VQModelInterface 43 | params: 44 | embed_dim: 4 45 | n_embed: 16384 46 | ckpt_path: configs/first_stage_models/vq-f8/model.yaml 47 | ddconfig: 48 | double_z: false 49 | z_channels: 4 50 | resolution: 256 51 | in_channels: 3 52 | out_ch: 3 53 | ch: 128 54 | ch_mult: 55 | - 1 56 | - 2 57 | - 2 58 | - 4 59 | num_res_blocks: 2 60 | attn_resolutions: 61 | - 32 62 | dropout: 0.0 63 | lossconfig: 64 | target: torch.nn.Identity 65 | cond_stage_config: 66 | target: ldm.modules.encoders.modules.ClassEmbedder 67 | params: 68 | embed_dim: 512 69 | key: class_label 70 | data: 71 | target: main.DataModuleFromConfig 72 | params: 73 | batch_size: 64 74 | num_workers: 12 75 | wrap: false 76 | train: 77 | target: ldm.data.imagenet.ImageNetTrain 78 | params: 79 | config: 80 | size: 256 81 | validation: 82 | target: ldm.data.imagenet.ImageNetValidation 83 | params: 84 | config: 85 | size: 256 86 | 87 | 88 | lightning: 89 | callbacks: 90 | image_logger: 91 | target: main.ImageLogger 92 | params: 93 | batch_frequency: 5000 94 | max_images: 8 95 | increase_log_steps: False 96 | 97 | trainer: 98 | benchmark: True -------------------------------------------------------------------------------- /configs/latent-diffusion/cin256-v2.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 0.0001 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: class_label 12 | image_size: 64 13 | channels: 3 14 | cond_stage_trainable: true 15 | conditioning_key: crossattn 16 | monitor: val/loss 17 | use_ema: False 18 | 19 | unet_config: 20 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 21 | params: 22 | image_size: 64 23 | in_channels: 3 24 | out_channels: 3 25 | model_channels: 192 26 | attention_resolutions: 27 | - 8 28 | - 4 29 | - 2 30 | num_res_blocks: 2 31 | channel_mult: 32 | - 1 33 | - 2 34 | - 3 35 | - 5 36 | num_heads: 1 37 | use_spatial_transformer: true 38 | transformer_depth: 1 39 | context_dim: 512 40 | 41 | first_stage_config: 42 | target: ldm.models.autoencoder.VQModelInterface 43 | params: 44 | embed_dim: 3 45 | n_embed: 8192 46 | ddconfig: 47 | double_z: false 48 | z_channels: 3 49 | resolution: 256 50 | in_channels: 3 51 | out_ch: 3 52 | ch: 128 53 | ch_mult: 54 | - 1 55 | - 2 56 | - 4 57 | num_res_blocks: 2 58 | attn_resolutions: [] 59 | dropout: 0.0 60 | lossconfig: 61 | target: torch.nn.Identity 62 | 63 | cond_stage_config: 64 | target: ldm.modules.encoders.modules.ClassEmbedder 65 | params: 66 | n_classes: 1001 67 | embed_dim: 512 68 | key: class_label 69 | -------------------------------------------------------------------------------- /configs/latent-diffusion/ffhq-ldm-vq-4.yaml: 
-------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 2.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | image_size: 64 12 | channels: 3 13 | monitor: val/loss_simple_ema 14 | unet_config: 15 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 16 | params: 17 | image_size: 64 18 | in_channels: 3 19 | out_channels: 3 20 | model_channels: 224 21 | attention_resolutions: 22 | # note: this isn\t actually the resolution but 23 | # the downsampling factor, i.e. this corresnponds to 24 | # attention on spatial resolution 8,16,32, as the 25 | # spatial reolution of the latents is 64 for f4 26 | - 8 27 | - 4 28 | - 2 29 | num_res_blocks: 2 30 | channel_mult: 31 | - 1 32 | - 2 33 | - 3 34 | - 4 35 | num_head_channels: 32 36 | first_stage_config: 37 | target: ldm.models.autoencoder.VQModelInterface 38 | params: 39 | embed_dim: 3 40 | n_embed: 8192 41 | ckpt_path: configs/first_stage_models/vq-f4/model.yaml 42 | ddconfig: 43 | double_z: false 44 | z_channels: 3 45 | resolution: 256 46 | in_channels: 3 47 | out_ch: 3 48 | ch: 128 49 | ch_mult: 50 | - 1 51 | - 2 52 | - 4 53 | num_res_blocks: 2 54 | attn_resolutions: [] 55 | dropout: 0.0 56 | lossconfig: 57 | target: torch.nn.Identity 58 | cond_stage_config: __is_unconditional__ 59 | data: 60 | target: main.DataModuleFromConfig 61 | params: 62 | batch_size: 42 63 | num_workers: 5 64 | wrap: false 65 | train: 66 | target: taming.data.faceshq.FFHQTrain 67 | params: 68 | size: 256 69 | validation: 70 | target: taming.data.faceshq.FFHQValidation 71 | params: 72 | size: 256 73 | 74 | 75 | lightning: 76 | callbacks: 77 | image_logger: 78 | target: main.ImageLogger 79 | params: 80 | batch_frequency: 5000 81 | max_images: 8 82 | increase_log_steps: False 83 | 84 | trainer: 85 | benchmark: True -------------------------------------------------------------------------------- /configs/latent-diffusion/lsun_bedrooms-ldm-vq-4.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 2.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | image_size: 64 12 | channels: 3 13 | monitor: val/loss_simple_ema 14 | unet_config: 15 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 16 | params: 17 | image_size: 64 18 | in_channels: 3 19 | out_channels: 3 20 | model_channels: 224 21 | attention_resolutions: 22 | # note: this isn\t actually the resolution but 23 | # the downsampling factor, i.e. 
this corresnponds to 24 | # attention on spatial resolution 8,16,32, as the 25 | # spatial reolution of the latents is 64 for f4 26 | - 8 27 | - 4 28 | - 2 29 | num_res_blocks: 2 30 | channel_mult: 31 | - 1 32 | - 2 33 | - 3 34 | - 4 35 | num_head_channels: 32 36 | first_stage_config: 37 | target: ldm.models.autoencoder.VQModelInterface 38 | params: 39 | ckpt_path: configs/first_stage_models/vq-f4/model.yaml 40 | embed_dim: 3 41 | n_embed: 8192 42 | ddconfig: 43 | double_z: false 44 | z_channels: 3 45 | resolution: 256 46 | in_channels: 3 47 | out_ch: 3 48 | ch: 128 49 | ch_mult: 50 | - 1 51 | - 2 52 | - 4 53 | num_res_blocks: 2 54 | attn_resolutions: [] 55 | dropout: 0.0 56 | lossconfig: 57 | target: torch.nn.Identity 58 | cond_stage_config: __is_unconditional__ 59 | data: 60 | target: main.DataModuleFromConfig 61 | params: 62 | batch_size: 48 63 | num_workers: 5 64 | wrap: false 65 | train: 66 | target: ldm.data.lsun.LSUNBedroomsTrain 67 | params: 68 | size: 256 69 | validation: 70 | target: ldm.data.lsun.LSUNBedroomsValidation 71 | params: 72 | size: 256 73 | 74 | 75 | lightning: 76 | callbacks: 77 | image_logger: 78 | target: main.ImageLogger 79 | params: 80 | batch_frequency: 5000 81 | max_images: 8 82 | increase_log_steps: False 83 | 84 | trainer: 85 | benchmark: True -------------------------------------------------------------------------------- /configs/latent-diffusion/lsun_churches-ldm-kl-8.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 5.0e-5 # set to target_lr by starting main.py with '--scale_lr False' 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0155 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | loss_type: l1 11 | first_stage_key: "image" 12 | cond_stage_key: "image" 13 | image_size: 32 14 | channels: 4 15 | cond_stage_trainable: False 16 | concat_mode: False 17 | scale_by_std: True 18 | monitor: 'val/loss_simple_ema' 19 | 20 | scheduler_config: # 10000 warmup steps 21 | target: ldm.lr_scheduler.LambdaLinearScheduler 22 | params: 23 | warm_up_steps: [10000] 24 | cycle_lengths: [10000000000000] 25 | f_start: [1.e-6] 26 | f_max: [1.] 27 | f_min: [ 1.] 
28 | 29 | unet_config: 30 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 31 | params: 32 | image_size: 32 33 | in_channels: 4 34 | out_channels: 4 35 | model_channels: 192 36 | attention_resolutions: [ 1, 2, 4, 8 ] # 32, 16, 8, 4 37 | num_res_blocks: 2 38 | channel_mult: [ 1,2,2,4,4 ] # 32, 16, 8, 4, 2 39 | num_heads: 8 40 | use_scale_shift_norm: True 41 | resblock_updown: True 42 | 43 | first_stage_config: 44 | target: ldm.models.autoencoder.AutoencoderKL 45 | params: 46 | embed_dim: 4 47 | monitor: "val/rec_loss" 48 | ckpt_path: "models/first_stage_models/kl-f8/model.ckpt" 49 | ddconfig: 50 | double_z: True 51 | z_channels: 4 52 | resolution: 256 53 | in_channels: 3 54 | out_ch: 3 55 | ch: 128 56 | ch_mult: [ 1,2,4,4 ] # num_down = len(ch_mult)-1 57 | num_res_blocks: 2 58 | attn_resolutions: [ ] 59 | dropout: 0.0 60 | lossconfig: 61 | target: torch.nn.Identity 62 | 63 | cond_stage_config: "__is_unconditional__" 64 | 65 | data: 66 | target: main.DataModuleFromConfig 67 | params: 68 | batch_size: 96 69 | num_workers: 5 70 | wrap: False 71 | train: 72 | target: ldm.data.lsun.LSUNChurchesTrain 73 | params: 74 | size: 256 75 | validation: 76 | target: ldm.data.lsun.LSUNChurchesValidation 77 | params: 78 | size: 256 79 | 80 | lightning: 81 | callbacks: 82 | image_logger: 83 | target: main.ImageLogger 84 | params: 85 | batch_frequency: 5000 86 | max_images: 8 87 | increase_log_steps: False 88 | 89 | 90 | trainer: 91 | benchmark: True -------------------------------------------------------------------------------- /configs/latent-diffusion/txt2img-1p4B-eval.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 5.0e-05 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.012 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: caption 12 | image_size: 32 13 | channels: 4 14 | cond_stage_trainable: true 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | scale_factor: 0.18215 18 | use_ema: False 19 | 20 | unet_config: 21 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 22 | params: 23 | image_size: 32 24 | in_channels: 4 25 | out_channels: 4 26 | model_channels: 320 27 | attention_resolutions: 28 | - 4 29 | - 2 30 | - 1 31 | num_res_blocks: 2 32 | channel_mult: 33 | - 1 34 | - 2 35 | - 4 36 | - 4 37 | num_heads: 8 38 | use_spatial_transformer: true 39 | transformer_depth: 1 40 | context_dim: 1280 41 | use_checkpoint: true 42 | legacy: False 43 | 44 | first_stage_config: 45 | target: ldm.models.autoencoder.AutoencoderKL 46 | params: 47 | embed_dim: 4 48 | monitor: val/rec_loss 49 | ddconfig: 50 | double_z: true 51 | z_channels: 4 52 | resolution: 256 53 | in_channels: 3 54 | out_ch: 3 55 | ch: 128 56 | ch_mult: 57 | - 1 58 | - 2 59 | - 4 60 | - 4 61 | num_res_blocks: 2 62 | attn_resolutions: [] 63 | dropout: 0.0 64 | lossconfig: 65 | target: torch.nn.Identity 66 | 67 | cond_stage_config: 68 | target: ldm.modules.encoders.modules.BERTEmbedder 69 | params: 70 | n_embed: 1280 71 | n_layer: 32 72 | -------------------------------------------------------------------------------- /configs/retrieval-augmented-diffusion/768x768.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 0.0001 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.015 7 | 
num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: jpg 11 | cond_stage_key: nix 12 | image_size: 48 13 | channels: 16 14 | cond_stage_trainable: false 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | scale_by_std: false 18 | scale_factor: 0.22765929 19 | unet_config: 20 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 21 | params: 22 | image_size: 48 23 | in_channels: 16 24 | out_channels: 16 25 | model_channels: 448 26 | attention_resolutions: 27 | - 4 28 | - 2 29 | - 1 30 | num_res_blocks: 2 31 | channel_mult: 32 | - 1 33 | - 2 34 | - 3 35 | - 4 36 | use_scale_shift_norm: false 37 | resblock_updown: false 38 | num_head_channels: 32 39 | use_spatial_transformer: true 40 | transformer_depth: 1 41 | context_dim: 768 42 | use_checkpoint: true 43 | first_stage_config: 44 | target: ldm.models.autoencoder.AutoencoderKL 45 | params: 46 | monitor: val/rec_loss 47 | embed_dim: 16 48 | ddconfig: 49 | double_z: true 50 | z_channels: 16 51 | resolution: 256 52 | in_channels: 3 53 | out_ch: 3 54 | ch: 128 55 | ch_mult: 56 | - 1 57 | - 1 58 | - 2 59 | - 2 60 | - 4 61 | num_res_blocks: 2 62 | attn_resolutions: 63 | - 16 64 | dropout: 0.0 65 | lossconfig: 66 | target: torch.nn.Identity 67 | cond_stage_config: 68 | target: torch.nn.Identity -------------------------------------------------------------------------------- /configs/stable-diffusion/PLACE.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-06 3 | target: ldm.models.diffusion.ddpm.PLACE_LDM 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.0120 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: "jpg" 11 | cond_stage_key: "tks" 12 | control_key: "hint" 13 | image_size: 64 14 | channels: 4 15 | cond_stage_trainable: false # Note: different from the one we trained before 16 | conditioning_key: crossattn 17 | monitor: val/loss_simple_ema 18 | scale_factor: 0.18215 19 | use_ema: False 20 | 21 | scheduler_config: # 10000 warmup steps 22 | target: ldm.lr_scheduler.LambdaLinearScheduler 23 | params: 24 | warm_up_steps: [ 10000 ] 25 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases 26 | f_start: [ 1.e-6 ] 27 | f_max: [ 1. ] 28 | f_min: [ 1. 
] 29 | 30 | unet_config: 31 | target: ldm.modules.diffusionmodules.openaimodel.PLACEUnetModel 32 | params: 33 | image_size: 32 # unused 34 | in_channels: 4 35 | out_channels: 4 36 | model_channels: 320 37 | attention_resolutions: [ 4, 2, 1 ] 38 | num_res_blocks: 2 39 | channel_mult: [ 1, 2, 4, 4 ] 40 | num_heads: 8 41 | use_spatial_transformer: True 42 | transformer_depth: 1 43 | context_dim: 768 44 | use_checkpoint: True 45 | legacy: False 46 | catype: 'PLACE' 47 | 48 | first_stage_config: 49 | target: ldm.models.autoencoder.AutoencoderKL 50 | params: 51 | embed_dim: 4 52 | monitor: val/rec_loss 53 | ddconfig: 54 | double_z: true 55 | z_channels: 4 56 | resolution: 256 57 | in_channels: 3 58 | out_ch: 3 59 | ch: 128 60 | ch_mult: 61 | - 1 62 | - 2 63 | - 4 64 | - 4 65 | num_res_blocks: 2 66 | attn_resolutions: [] 67 | dropout: 0.0 68 | lossconfig: 69 | target: torch.nn.Identity 70 | 71 | cond_stage_config: 72 | target: ldm.modules.encoders.modules.MYFrozenCLIPEmbedder 73 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import cv2 3 | import numpy as np 4 | 5 | from torch.utils.data import Dataset 6 | import random 7 | from PIL import Image 8 | import open_clip 9 | import torch 10 | import os 11 | from ldm.modules.encoders.modules import FrozenCLIPEmbedder as CLIP 12 | 13 | clip = CLIP() 14 | 15 | ADE20K={ 16 | 1 : "wall", 17 | 2 : "building edifice", 18 | 3 : "sky", 19 | 4 : "floor", 20 | 5 : "tree", 21 | 6 : "ceiling", 22 | 7 : "road", 23 | 8 : "bed", 24 | 9 : "windowpane", 25 | 10 : "grass", 26 | 11 : "cabinet", 27 | 12 : "sidewalk", 28 | 13 : "person", 29 | 14 : "earth ground", 30 | 15 : "door", 31 | 16 : "table", 32 | 17 : "mountain", 33 | 18 : "plant flora", 34 | 19 : "curtain drapery mantle pall", 35 | 20 : "chair", 36 | 21 : "car", 37 | 22 : "water", 38 | 23 : "painting picture", 39 | 24 : "sofa lounge", 40 | 25 : "shelf", 41 | 26 : "house", 42 | 27 : "sea", 43 | 28 : "mirror", 44 | 29 : "carpet", 45 | 30 : "field", 46 | 31 : "armchair", 47 | 32 : "seat", 48 | 33 : "fence", 49 | 34 : "desk", 50 | 35 : "rock stone", 51 | 36 : "wardrobe closet", 52 | 37 : "lamp", 53 | 38 : "bathtub", 54 | 39 : "railing", 55 | 40 : "cushion", 56 | 41 : "base pedestal stand", 57 | 42 : "box", 58 | 43 : "pillar", 59 | 44 : "signboard sign", 60 | 45 : "chest bureau dresser", 61 | 46 : "counter", 62 | 47 : "sand", 63 | 48 : "sink", 64 | 49 : "skyscraper", 65 | 50 : "fireplace", 66 | 51 : "refrigerator", 67 | 52 : "grandstand covered stand", 68 | 53 : "path", 69 | 54 : "stairs", 70 | 55 : "runway", 71 | 56 : "case showcase vitrine", 72 | 57 : "pool table billiard table snooker table", 73 | 58 : "pillow", 74 | 59 : "screen door", 75 | 60 : "stairway", 76 | 61 : "river", 77 | 62 : "bridge", 78 | 63 : "bookcase", 79 | 64 : "blind screen", 80 | 65 : "coffee table cocktail table", 81 | 66 : "toilet can commode potty", 82 | 67 : "flower", 83 | 68 : "book", 84 | 69 : "hill", 85 | 70 : "bench", 86 | 71 : "countertop", 87 | 72 : "kitchen range cooking stove", 88 | 73 : "palm tree", 89 | 74 : "kitchen island", 90 | 75 : "computer", 91 | 76 : "swivel chair", 92 | 77 : "boat", 93 | 78 : "bar", 94 | 79 : "arcade machine", 95 | 80 : "hovel shack shanty", 96 | 81 : "autobus motorbus omnibus", 97 | 82 : "towel", 98 | 83 : "light", 99 | 84 : "truck", 100 | 85 : "tower", 101 | 86 : "chandelier pendant", 102 | 87 : "awning sunblind", 103 | 88 : "streetlight", 104 | 89 : "booth 
cubicle kiosk", 105 | 90 : "television tv set idiot box boob tube telly goggle box", 106 | 91 : "airplane", 107 | 92 : "dirt track", 108 | 93 : "apparel", 109 | 94 : "pole", 110 | 95 : "land ground soil", 111 | 96 : "balustrade handrail", 112 | 97 : "escalator", 113 | 98 : "ottoman pouf hassock", 114 | 99 : "bottle", 115 | 100 : "buffet counter sideboard", 116 | 101 : "poster placard notice card", 117 | 102 : "stage", 118 | 103 : "van", 119 | 104 : "ship", 120 | 105 : "fountain", 121 | 106 : "conveyor belt transporter", 122 | 107 : "canopy", 123 | 108 : "washing machine", 124 | 109 : "toy", 125 | 110 : "swimming pool natatorium", 126 | 111 : "stool", 127 | 112 : "barrel", 128 | 113 : "basket handbasket", 129 | 114 : "waterfall", 130 | 115 : "tent", 131 | 116 : "bag", 132 | 117 : "motorbike", 133 | 118 : "cradle", 134 | 119 : "oven", 135 | 120 : "ball", 136 | 121 : "food", 137 | 122 : "stair", 138 | 123 : "storage tank", 139 | 124 : "brand marque", 140 | 125 : "microwave oven", 141 | 126 : "flowerpot", 142 | 127 : "animal fauna", 143 | 128 : "bicycle", 144 | 129 : "lake", 145 | 130 : "dishwasher", 146 | 131 : "screen silver screen projection screen", 147 | 132 : "blanket", 148 | 133 : "sculpture", 149 | 134 : "exhaust hood", 150 | 135 : "sconce", 151 | 136 : "vase", 152 | 137 : "traffic light", 153 | 138 : "tray", 154 | 139 : "ashcan trash can dustbin", 155 | 140 : "fan", 156 | 141 : "pier wharfage dock", 157 | 142 : "crt screen", 158 | 143 : "plate", 159 | 144 : "monitoring device", 160 | 145 : "notice board", 161 | 146 : "shower", 162 | 147 : "radiator", 163 | 148 : "drinking glass", 164 | 149 : "clock", 165 | 150 : "flag" 166 | } 167 | 168 | COCO={ 169 | 1 : "person", 170 | 2 : "bicycle", 171 | 3 : "car", 172 | 4 : "motorcycle", 173 | 5 : "airplane", 174 | 6 : "bus", 175 | 7 : "train", 176 | 8 : "truck", 177 | 9 : "boat", 178 | 10 : "traffic light", 179 | 11 : "fire hydrant", 180 | 12 : "street sign", 181 | 13 : "stop sign", 182 | 14 : "parking meter", 183 | 15 : "bench", 184 | 16 : "bird", 185 | 17 : "cat", 186 | 18 : "dog", 187 | 19 : "horse", 188 | 20 : "sheep", 189 | 21 : "cow", 190 | 22 : "elephant", 191 | 23 : "bear", 192 | 24 : "zebra", 193 | 25 : "giraffe", 194 | 26 : "hat", 195 | 27 : "backpack", 196 | 28 : "umbrella", 197 | 29 : "shoe", 198 | 30 : "eye glasses", 199 | 31 : "handbag", 200 | 32 : "tie", 201 | 33 : "suitcase", 202 | 34 : "frisbee", 203 | 35 : "skis", 204 | 36 : "snowboard", 205 | 37 : "sports ball", 206 | 38 : "kite", 207 | 39 : "baseball bat", 208 | 40 : "baseball glove", 209 | 41 : "skateboard", 210 | 42 : "surfboard", 211 | 43 : "tennis racket", 212 | 44 : "bottle", 213 | 45 : "plate", 214 | 46 : "wine glass", 215 | 47 : "cup", 216 | 48 : "fork", 217 | 49 : "knife", 218 | 50 : "spoon", 219 | 51 : "bowl", 220 | 52 : "banana", 221 | 53 : "apple", 222 | 54 : "sandwich", 223 | 55 : "orange", 224 | 56 : "broccoli", 225 | 57 : "carrot", 226 | 58 : "hot dog", 227 | 59 : "pizza", 228 | 60 : "donut", 229 | 61 : "cake", 230 | 62 : "chair", 231 | 63 : "couch", 232 | 64 : "potted plant", 233 | 65 : "bed", 234 | 66 : "mirror", 235 | 67 : "dining table", 236 | 68 : "window", 237 | 69 : "desk", 238 | 70 : "toilet", 239 | 71 : "door", 240 | 72 : "tv", 241 | 73 : "laptop", 242 | 74 : "mouse", 243 | 75 : "remote", 244 | 76 : "keyboard", 245 | 77 : "cell phone", 246 | 78 : "microwave", 247 | 79 : "oven", 248 | 80 : "toaster", 249 | 81 : "sink", 250 | 82 : "refrigerator", 251 | 83 : "blender", 252 | 84 : "book", 253 | 85 : "clock", 254 | 86 : "vase", 255 | 87 : "scissors", 256 | 
88 : "teddy bear", 257 | 89 : "hair drier", 258 | 90 : "toothbrush", 259 | 91 : "hair brush", 260 | 92 : "banner", 261 | 93 : "blanket", 262 | 94 : "branch", 263 | 95 : "bridge", 264 | 96 : "building", 265 | 97 : "bush", 266 | 98 : "cabinet", 267 | 99 : "cage", 268 | 100 : "cardboard", 269 | 101 : "carpet", 270 | 102 : "ceiling", 271 | 103 : "tile ceiling", 272 | 104 : "cloth", 273 | 105 : "clothes", 274 | 106 : "clouds", 275 | 107 : "counter", 276 | 108 : "cupboard", 277 | 109 : "curtain", 278 | 110 : "desk", 279 | 111 : "dirt", 280 | 112 : "door", 281 | 113 : "fence", 282 | 114 : "marble floor", 283 | 115 : "floor", 284 | 116 : "stone floor", 285 | 117 : "tile floor", 286 | 118 : "wood floor", 287 | 119 : "flower", 288 | 120 : "fog", 289 | 121 : "food", 290 | 122 : "fruit", 291 | 123 : "furniture", 292 | 124 : "grass", 293 | 125 : "gravel", 294 | 126 : "ground", 295 | 127 : "hill", 296 | 128 : "house", 297 | 129 : "leaves", 298 | 130 : "light", 299 | 131 : "mat", 300 | 132 : "metal", 301 | 133 : "mirror", 302 | 134 : "moss", 303 | 135 : "mountain", 304 | 136 : "mud", 305 | 137 : "napkin", 306 | 138 : "net", 307 | 139 : "paper", 308 | 140 : "pavement", 309 | 141 : "pillow", 310 | 142 : "plant", 311 | 143 : "plastic", 312 | 144 : "platform", 313 | 145 : "playingfield", 314 | 146 : "railing", 315 | 147 : "railroad", 316 | 148 : "river", 317 | 149 : "road", 318 | 150 : "rock", 319 | 151 : "roof", 320 | 152 : "rug", 321 | 153 : "salad", 322 | 154 : "sand", 323 | 155 : "sea", 324 | 156 : "shelf", 325 | 157 : "sky", 326 | 158 : "skyscraper", 327 | 159 : "snow", 328 | 160 : "solid", 329 | 161 : "stairs", 330 | 162 : "stone", 331 | 163 : "straw", 332 | 164 : "structural", 333 | 165 : "table", 334 | 166 : "tent", 335 | 167 : "textile", 336 | 168 : "towel", 337 | 169 : "tree", 338 | 170 : "vegetable", 339 | 171 : "brick wall", 340 | 172 : "concrete wall", 341 | 173 : "wall", 342 | 174 : "panel wall", 343 | 175 : "stone wall", 344 | 176 : "tile wall", 345 | 177 : "wood wall", 346 | 178 : "water", 347 | 179 : "waterdrops", 348 | 180 : "blind window", 349 | 181 : "window", 350 | 182 : "wood" 351 | } 352 | 353 | def gettks(tkss): 354 | newtks = [] 355 | for i in range(77): 356 | if tkss[0,i]==49407: 357 | break 358 | elif tkss[0,i]==49406: 359 | continue 360 | elif tkss[0,i]==267: 361 | continue 362 | else: 363 | newtks.append(int(tkss[0,i])) 364 | return newtks 365 | 366 | class ADE20KDataset(Dataset): 367 | def __init__(self, data_root, phase='Val'): 368 | self.data = [] 369 | path = data_root 370 | prefix = 'validation' if phase=='Val' else 'training' 371 | files = os.listdir(path+'/images/'+prefix+'/') if phase=='Train' else os.listdir(path+'/images/'+prefix+'/') 372 | for fn in files: 373 | self.data.append({'target':path+'/images/'+prefix+'/'+fn,'source':path+'/annotations/'+prefix+'/'+fn.replace('.jpg','.png')}) 374 | 375 | self.labeldic={} 376 | self.namedic={} 377 | for k in ADE20K.keys(): 378 | self.namedic[int(k)-1] = ADE20K[k] 379 | 380 | batch_encoding = clip.tokenizer(ADE20K[k], truncation=True, max_length=clip.max_length, return_length=True, 381 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt") 382 | tokens = batch_encoding["input_ids"] 383 | corr_tks = gettks(tokens) 384 | 385 | self.labeldic[int(k)-1] = corr_tks 386 | 387 | def __len__(self): 388 | return len(self.data) 389 | 390 | def __getitem__(self, idx): 391 | item = self.data[idx] 392 | 393 | source_filename = item['source'] 394 | target_filename = item['target'] 395 | 396 | source = 
np.array(Image.open(source_filename).resize((512,512),Image.NEAREST),dtype=np.float) 397 | target = np.array(Image.open(target_filename).resize((512,512),Image.BICUBIC).convert('RGB')) 398 | 399 | tokens = np.full((77),49407) 400 | tokens[0] = 49406 401 | token_list = [] 402 | tokens_cls = np.full((77),49407) 403 | tokens_cls[0] = 49406 404 | tokens_cls_list = [] 405 | source_minus_1 = source - 1 406 | prompt = '' 407 | for lb in np.unique(source_minus_1): 408 | if lb==-1: 409 | continue 410 | token_list+=self.labeldic[lb] 411 | tokens_cls_list += [lb]*len(self.labeldic[lb]) 412 | prompt += self.namedic[lb] + ',' 413 | tokens[1:len(token_list)+1] = np.array(token_list) 414 | tokens_cls[1:len(token_list)+1] = np.array(tokens_cls_list) 415 | 416 | # Normalize source images to [0, 1]. 417 | viewsource = np.zeros((512,512,3)) 418 | viewsource[:,:,0] = source 419 | viewsource[:,:,1] = source 420 | viewsource[:,:,2] = source 421 | viewsource = np.array(viewsource,dtype=np.uint8) 422 | viewsource = viewsource.astype(np.float32) / 255.0 423 | 424 | # Normalize target images to [-1, 1]. 425 | target = (target.astype(np.float32) / 127.5) - 1.0 426 | 427 | source = source_minus_1[:,:,np.newaxis] 428 | 429 | assert viewsource.shape==target.shape, str(viewsource.shape)+' '+str(target.shape) + ' '+item['target']+' ' + item['source'] 430 | 431 | return dict(jpg=target, txt=prompt, tks=tokens, hint=source, targetpath=item['target'], sourcepath=item['source'], viewcontrol=viewsource, tokens_cls=tokens_cls) 432 | 433 | class COCODataset(Dataset): 434 | def __init__(self, data_root, phase='Val'): 435 | self.data = [] 436 | path = data_root 437 | prefix = 'val2017' if phase=='Val' else 'train2017' 438 | files = os.listdir(path+prefix) 439 | for fn in files: 440 | self.data.append({'target':path+'/'+prefix+'/'+fn,'source':path+'/stuffthingmaps_trainval2017/'+prefix+'/'+fn.replace('.jpg','.png')}) 441 | 442 | self.labeldic={} 443 | self.namedic={} 444 | for k in COCO.keys(): 445 | self.namedic[int(k)-1] = COCO[k] 446 | 447 | batch_encoding = clip.tokenizer(COCO[k], truncation=True, max_length=clip.max_length, return_length=True, 448 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt") 449 | tokens = batch_encoding["input_ids"] 450 | corr_tks = gettks(tokens) 451 | 452 | self.labeldic[int(k)-1] = corr_tks 453 | 454 | def __len__(self): 455 | return len(self.data) 456 | 457 | def __getitem__(self, idx): 458 | item = self.data[idx] 459 | 460 | source_filename = item['source'] 461 | target_filename = item['target'] 462 | 463 | source = np.array(Image.open(source_filename).resize((512,512),Image.NEAREST),dtype=np.float) 464 | source = np.where(source==255,-1,source) 465 | source = source + 1 466 | 467 | target = np.array(Image.open(target_filename).resize((512,512),Image.BILINEAR).convert('RGB')) 468 | 469 | tokens = np.full((77),49407) 470 | tokens[0] = 49406 471 | token_list = [] 472 | tokens_cls = np.full((77),49407) 473 | tokens_cls[0] = 49406 474 | tokens_cls_list = [] 475 | source_minus_1 = source - 1 476 | prompt = '' 477 | for lb in np.unique(source_minus_1): 478 | if lb==-1: 479 | continue 480 | token_list+=self.labeldic[lb] 481 | tokens_cls_list += [lb]*len(self.labeldic[lb]) 482 | prompt += self.namedic[lb] + ',' 483 | tokens[1:len(token_list)+1] = np.array(token_list) 484 | tokens_cls[1:len(token_list)+1] = np.array(tokens_cls_list) 485 | 486 | # Normalize source images to [0, 1]. 
487 | viewsource = np.zeros((512,512,3)) 488 | viewsource[:,:,0] = source 489 | viewsource[:,:,1] = source 490 | viewsource[:,:,2] = source 491 | viewsource = np.array(viewsource,dtype=np.uint8) 492 | viewsource = viewsource.astype(np.float32) / 255.0 493 | 494 | # Normalize target images to [-1, 1]. 495 | target = (target.astype(np.float32) / 127.5) - 1.0 496 | 497 | source = source_minus_1[:,:,np.newaxis] 498 | 499 | assert viewsource.shape==target.shape, str(viewsource.shape)+' '+str(target.shape) + ' '+item['target']+' ' + item['source'] 500 | 501 | return dict(jpg=target, txt=prompt, tks=tokens, hint=source, targetpath=item['target'], sourcepath=item['source'], viewcontrol=viewsource, tokens_cls=tokens_cls) -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: PLACE 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - python=3.8.5 7 | - pip=20.3 8 | - cudatoolkit=11.3 9 | - pytorch=1.12.1 10 | - torchvision=0.13.1 11 | - numpy=1.23.1 12 | - pip: 13 | - gradio==3.16.2 14 | - albumentations==1.3.0 15 | - opencv-contrib-python==4.3.0.36 16 | - imageio==2.9.0 17 | - imageio-ffmpeg==0.4.2 18 | - pytorch-lightning==1.5.0 19 | - omegaconf==2.1.1 20 | - test-tube>=0.7.5 21 | - streamlit==1.12.1 22 | - einops==0.3.0 23 | - transformers==4.19.2 24 | - webdataset==0.2.5 25 | - kornia==0.6 26 | - open_clip_torch==2.0.2 27 | - invisible-watermark>=0.1.5 28 | - streamlit-drawable-canvas==0.8.0 29 | - torchmetrics==0.6.0 30 | - timm==0.6.12 31 | - addict==2.4.0 32 | - yapf==0.32.0 33 | - prettytable==3.6.0 34 | - safetensors==0.2.7 35 | -------------------------------------------------------------------------------- /index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 118 | 119 | 120 | 121 | PLACE: Adaptive Layout-Semantic Fusion for Semantic Image Synthesis 122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | 136 | 137 | 138 | 139 |
PLACE: Adaptive Layout-Semantic Fusion for Semantic Image Synthesis (CVPR 2024)

Zhengyao Lv, Yuxiang Wei, Wangmeng Zuo, Kwan-Yee K. Wong

[Paper] [GitHub]

Abstract

Recent advancements in large-scale pre-trained text-to-image models have led to remarkable progress in semantic image synthesis. Nevertheless, synthesizing high-quality images with consistent semantics and layout remains a challenge. In this paper, we propose the adaPtive LAyout-semantiC fusion modulE (PLACE) that harnesses pre-trained models to alleviate the aforementioned issues. Specifically, we first employ the layout control map to faithfully represent layouts in the feature space. Subsequently, we combine the layout and semantic features in a timestep-adaptive manner to synthesize images with realistic details. During fine-tuning, we propose the Semantic Alignment (SA) loss to further enhance layout alignment. Additionally, we introduce the Layout-Free Prior Preservation (LFP) loss, which leverages unlabeled data to maintain the priors of pre-trained models, thereby improving the visual quality and semantic consistency of synthesized images. Extensive experiments demonstrate that our approach performs favorably in terms of visual quality, semantic consistency, and layout alignment.
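To make the timestep-adaptive combination mentioned in the abstract more concrete, the snippet below is a minimal, hypothetical PyTorch sketch that blends a layout feature with a semantic feature using a weight that depends on the diffusion timestep. The function name, the linear weight schedule, and the tensor shapes are illustrative assumptions only; the repository's own fusion logic (configured through `PLACEUnetModel` and `catype: 'PLACE'` in `configs/stable-diffusion/PLACE.yaml`) is not reproduced here.

```python
import torch

def timestep_adaptive_fusion(layout_feat: torch.Tensor,
                             semantic_feat: torch.Tensor,
                             t: torch.Tensor,
                             num_timesteps: int = 1000) -> torch.Tensor:
    """Illustrative only: blend layout and semantic features with a
    timestep-dependent weight (t has shape [batch]).

    Early (noisy) steps lean on the layout feature to pin down structure,
    later steps lean on the semantic feature for appearance detail. The
    linear ramp below is a made-up schedule, not the paper's learned rule.
    """
    w = (t.float() / num_timesteps).view(-1, 1, 1, 1)  # ~1 at t=T, ~0 at t=0
    return w * layout_feat + (1.0 - w) * semantic_feat

# Toy usage with random feature maps of matching shape.
b, c, h, w = 2, 320, 64, 64
layout = torch.randn(b, c, h, w)
semantic = torch.randn(b, c, h, w)
t = torch.randint(0, 1000, (b,))
fused = timestep_adaptive_fusion(layout, semantic, t)
print(fused.shape)  # torch.Size([2, 320, 64, 64])
```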
Overview

[Figure] Overview of our method. (a) We utilize the layout control map calculated from the semantic map and PLACE for layout control. During fine-tuning, we combine the LDM, SA, and LFP losses as the optimization objective. (b) Calculation of the layout control map and details of the adaptive layout-semantic fusion module. Each vector in the layout control map encodes all the semantic components in its receptive field. The adaptive layout-semantic fusion module blends the layout and semantic features in a timestep-adaptive way.
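As a rough illustration of the layout control map described above — each spatial vector recording which semantic classes fall inside its receptive field — the following NumPy sketch builds a multi-hot map from a label map. The patch-based pooling, the function name, and the encoding are assumptions for illustration and are not claimed to match the repository's actual construction.

```python
import numpy as np

def layout_control_map(seg: np.ndarray, num_classes: int, patch: int = 8) -> np.ndarray:
    """Illustrative only: build a multi-hot layout map from an (H, W) label map.

    Each output vector marks which classes appear inside the corresponding
    patch x patch region, mirroring the idea that every vector in the layout
    control map encodes all semantic components in its receptive field.
    """
    h, w = seg.shape
    gh, gw = h // patch, w // patch
    out = np.zeros((gh, gw, num_classes), dtype=np.float32)
    for i in range(gh):
        for j in range(gw):
            cell = seg[i * patch:(i + 1) * patch, j * patch:(j + 1) * patch]
            for cls in np.unique(cell):
                if 0 <= cls < num_classes:  # skip unlabeled (-1) regions
                    out[i, j, int(cls)] = 1.0
    return out

# Example: a 512x512 ADE20K-style label map pooled to a 64x64 layout grid.
seg = np.random.randint(0, 150, size=(512, 512))
print(layout_control_map(seg, num_classes=150).shape)  # (64, 64, 150)
```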

Results

[Figure] Visual comparisons on ADE20K and COCO-Stuff

[Figure] Visual comparisons for out-of-distribution synthesis

Paper and Supplementary Material

Zhengyao Lv, Yuxiang Wei, Wangmeng Zuo and Kwan-Yee K. Wong.
PLACE: Adaptive Layout-Semantic Fusion for Semantic Image Synthesis.
In CVPR, 2024.
(hosted on ArXiv)

[Bibtex]

Acknowledgements

This template was originally made by Phillip Isola and Richard Zhang for a colorful ECCV project; the code can be found here.
369 | 370 | 371 | 372 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import einops 3 | import numpy as np 4 | import torch 5 | from PIL import Image 6 | import os, argparse 7 | 8 | from ldm.models.diffusion.plms import PLMSSampler 9 | from torch.utils.data import DataLoader 10 | from dataset import ADE20KDataset, COCODataset 11 | from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution 12 | 13 | from omegaconf import OmegaConf 14 | from ldm.util import instantiate_from_config 15 | 16 | from pytorch_lightning import seed_everything 17 | 18 | 19 | def main(): 20 | parser = argparse.ArgumentParser() 21 | 22 | parser.add_argument( 23 | "--outdir", 24 | type=str, 25 | nargs="?", 26 | help="dir to write results to", 27 | default="output/ade20k" 28 | ) 29 | 30 | parser.add_argument( 31 | "--config", 32 | type=str, 33 | default="configs/stable-diffusion/xxx.yaml", 34 | help="path to config which constructs model", 35 | ) 36 | 37 | parser.add_argument( 38 | "--seed", 39 | type=int, 40 | default=42, 41 | help="the seed (for reproducible sampling)", 42 | ) 43 | 44 | parser.add_argument( 45 | "--data_root", 46 | type=str, 47 | required=True, 48 | help="Path to dataset directory" 49 | ) 50 | 51 | parser.add_argument( 52 | "--dataset", 53 | type=str, 54 | help="which dataset to evaluate", 55 | choices=["COCO", "ADE20K"], 56 | default="COCO" 57 | ) 58 | 59 | parser.add_argument( 60 | "--ckpt", 61 | type=str, 62 | default="models/ade20k.ckpt", 63 | help="path to checkpoint of model", 64 | ) 65 | 66 | opt = parser.parse_args() 67 | 68 | seed_everything(opt.seed) 69 | 70 | def get_state_dict(d): 71 | return d.get('state_dict', d) 72 | 73 | targetpth = opt.outdir 74 | if not os.path.exists(targetpth): 75 | os.mkdir(targetpth) 76 | 77 | config = OmegaConf.load(opt.config) 78 | model = instantiate_from_config(config.model).cpu() 79 | 80 | state_dict = get_state_dict(torch.load(opt.ckpt, map_location=torch.device('cuda'))) 81 | state_dict = get_state_dict(state_dict) 82 | model.load_state_dict(state_dict) 83 | model = model.cuda() 84 | 85 | sampler = PLMSSampler(model) 86 | 87 | if opt.dataset == 'ADE20K': 88 | dataset = ADE20KDataset(opt.data_root) 89 | elif opt.dataset == 'COCO': 90 | dataset = COCODataset(opt.data_root) 91 | 92 | dataloader = DataLoader(dataset, num_workers=0, batch_size=1, shuffle=False) 93 | for batch in dataloader: 94 | names = batch['sourcepath'][0].split('/')[-1] 95 | print('processing:',targetpth+'/'+names) 96 | N = 1 97 | z, c = model.get_input(batch, model.first_stage_key, bs=1) 98 | c_tkscls = c['tkscls'][0][:N] 99 | c_cat, c, view_ctrol = c["c_concat"][0][:N], c["c_crossattn"][0][:N], c['viewcontrol'][0][:N] 100 | 101 | uc_cross = torch.zeros((c.shape[0],77),dtype=torch.int64) + 49407 102 | uc_cross[:,0] = 49406 103 | uc_cross = model.cond_stage_model.encode(uc_cross.to(model.device)) 104 | if isinstance(uc_cross, DiagonalGaussianDistribution): 105 | uc_cross = uc_cross.mode() 106 | 107 | uc_cat = c_cat 108 | uc_tkscls = torch.zeros_like(c_tkscls) + 49407 109 | uc_tkscls[:,0] = 49406 110 | uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross], "tkscls": [uc_tkscls]} 111 | 112 | cond = {"c_concat": [c_cat], "c_crossattn": [c], "tkscls": [c_tkscls]} 113 | un_cond ={"c_concat": [uc_cat], "c_crossattn": [uc_cross], "tkscls": [uc_tkscls]} 114 | 115 | H,W=512,512 116 | shape = (4, H // 8, W // 8) 117 
| samples_ddim, _ = sampler.sample(50, 118 | conditioning=cond, 119 | batch_size=1, 120 | shape=shape, 121 | verbose=False, 122 | unconditional_guidance_scale=2.0, 123 | unconditional_conditioning=un_cond, 124 | eta=0.0, 125 | x_T=None) 126 | x_samples = model.decode_first_stage(samples_ddim) 127 | x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8) 128 | 129 | results = [x_samples[i] for i in range(1)] 130 | Image.fromarray(results[0]).save(os.path.join(targetpth,names.replace('png','jpg'))) 131 | 132 | if __name__ == "__main__": 133 | main() 134 | -------------------------------------------------------------------------------- /ldm/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszy98/PLACE/4594350491df5c7489eb4edf0802f303692dd530/ldm/.DS_Store -------------------------------------------------------------------------------- /ldm/__pycache__/util.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszy98/PLACE/4594350491df5c7489eb4edf0802f303692dd530/ldm/__pycache__/util.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszy98/PLACE/4594350491df5c7489eb4edf0802f303692dd530/ldm/data/__init__.py -------------------------------------------------------------------------------- /ldm/data/util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ldm.modules.midas.api import load_midas_transform 4 | 5 | 6 | class AddMiDaS(object): 7 | def __init__(self, model_type): 8 | super().__init__() 9 | self.transform = load_midas_transform(model_type) 10 | 11 | def pt2np(self, x): 12 | x = ((x + 1.0) * .5).detach().cpu().numpy() 13 | return x 14 | 15 | def np2pt(self, x): 16 | x = torch.from_numpy(x) * 2 - 1. 
17 | return x 18 | 19 | def __call__(self, sample): 20 | # sample['jpg'] is tensor hwc in [-1, 1] at this point 21 | x = self.pt2np(sample['jpg']) 22 | x = self.transform({"image": x})["image"] 23 | sample['midas_in'] = x 24 | return sample -------------------------------------------------------------------------------- /ldm/models/__pycache__/autoencoder.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszy98/PLACE/4594350491df5c7489eb4edf0802f303692dd530/ldm/models/__pycache__/autoencoder.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/models/autoencoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytorch_lightning as pl 3 | import torch.nn.functional as F 4 | from contextlib import contextmanager 5 | 6 | from ldm.modules.diffusionmodules.model import Encoder, Decoder 7 | from ldm.modules.distributions.distributions import DiagonalGaussianDistribution 8 | 9 | from ldm.util import instantiate_from_config 10 | from ldm.modules.ema import LitEma 11 | 12 | 13 | class AutoencoderKL(pl.LightningModule): 14 | def __init__(self, 15 | ddconfig, 16 | lossconfig, 17 | embed_dim, 18 | ckpt_path=None, 19 | ignore_keys=[], 20 | image_key="image", 21 | colorize_nlabels=None, 22 | monitor=None, 23 | ema_decay=None, 24 | learn_logvar=False 25 | ): 26 | super().__init__() 27 | self.learn_logvar = learn_logvar 28 | self.image_key = image_key 29 | self.encoder = Encoder(**ddconfig) 30 | self.decoder = Decoder(**ddconfig) 31 | self.loss = instantiate_from_config(lossconfig) 32 | assert ddconfig["double_z"] 33 | self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1) 34 | self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) 35 | self.embed_dim = embed_dim 36 | if colorize_nlabels is not None: 37 | assert type(colorize_nlabels)==int 38 | self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) 39 | if monitor is not None: 40 | self.monitor = monitor 41 | 42 | self.use_ema = ema_decay is not None 43 | if self.use_ema: 44 | self.ema_decay = ema_decay 45 | assert 0. < ema_decay < 1. 
46 | self.model_ema = LitEma(self, decay=ema_decay) 47 | print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") 48 | 49 | if ckpt_path is not None: 50 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) 51 | 52 | def init_from_ckpt(self, path, ignore_keys=list()): 53 | sd = torch.load(path, map_location="cpu")["state_dict"] 54 | keys = list(sd.keys()) 55 | for k in keys: 56 | for ik in ignore_keys: 57 | if k.startswith(ik): 58 | print("Deleting key {} from state_dict.".format(k)) 59 | del sd[k] 60 | self.load_state_dict(sd, strict=False) 61 | print(f"Restored from {path}") 62 | 63 | @contextmanager 64 | def ema_scope(self, context=None): 65 | if self.use_ema: 66 | self.model_ema.store(self.parameters()) 67 | self.model_ema.copy_to(self) 68 | if context is not None: 69 | print(f"{context}: Switched to EMA weights") 70 | try: 71 | yield None 72 | finally: 73 | if self.use_ema: 74 | self.model_ema.restore(self.parameters()) 75 | if context is not None: 76 | print(f"{context}: Restored training weights") 77 | 78 | def on_train_batch_end(self, *args, **kwargs): 79 | if self.use_ema: 80 | self.model_ema(self) 81 | 82 | def encode(self, x): 83 | h = self.encoder(x) 84 | moments = self.quant_conv(h) 85 | posterior = DiagonalGaussianDistribution(moments) 86 | return posterior 87 | 88 | def decode(self, z): 89 | z = self.post_quant_conv(z) 90 | dec = self.decoder(z) 91 | return dec 92 | 93 | def forward(self, input, sample_posterior=True): 94 | posterior = self.encode(input) 95 | if sample_posterior: 96 | z = posterior.sample() 97 | else: 98 | z = posterior.mode() 99 | dec = self.decode(z) 100 | return dec, posterior 101 | 102 | def get_input(self, batch, k): 103 | x = batch[k] 104 | if len(x.shape) == 3: 105 | x = x[..., None] 106 | x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() 107 | return x 108 | 109 | def training_step(self, batch, batch_idx, optimizer_idx): 110 | inputs = self.get_input(batch, self.image_key) 111 | reconstructions, posterior = self(inputs) 112 | 113 | if optimizer_idx == 0: 114 | # train encoder+decoder+logvar 115 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, 116 | last_layer=self.get_last_layer(), split="train") 117 | self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) 118 | self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) 119 | return aeloss 120 | 121 | if optimizer_idx == 1: 122 | # train the discriminator 123 | discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, 124 | last_layer=self.get_last_layer(), split="train") 125 | 126 | self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) 127 | self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False) 128 | return discloss 129 | 130 | def validation_step(self, batch, batch_idx): 131 | log_dict = self._validation_step(batch, batch_idx) 132 | with self.ema_scope(): 133 | log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema") 134 | return log_dict 135 | 136 | def _validation_step(self, batch, batch_idx, postfix=""): 137 | inputs = self.get_input(batch, self.image_key) 138 | reconstructions, posterior = self(inputs) 139 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step, 140 | last_layer=self.get_last_layer(), split="val"+postfix) 141 | 142 | discloss, log_dict_disc = self.loss(inputs, 
reconstructions, posterior, 1, self.global_step, 143 | last_layer=self.get_last_layer(), split="val"+postfix) 144 | 145 | self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"]) 146 | self.log_dict(log_dict_ae) 147 | self.log_dict(log_dict_disc) 148 | return self.log_dict 149 | 150 | def configure_optimizers(self): 151 | lr = self.learning_rate 152 | ae_params_list = list(self.encoder.parameters()) + list(self.decoder.parameters()) + list( 153 | self.quant_conv.parameters()) + list(self.post_quant_conv.parameters()) 154 | if self.learn_logvar: 155 | print(f"{self.__class__.__name__}: Learning logvar") 156 | ae_params_list.append(self.loss.logvar) 157 | opt_ae = torch.optim.Adam(ae_params_list, 158 | lr=lr, betas=(0.5, 0.9)) 159 | opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), 160 | lr=lr, betas=(0.5, 0.9)) 161 | return [opt_ae, opt_disc], [] 162 | 163 | def get_last_layer(self): 164 | return self.decoder.conv_out.weight 165 | 166 | @torch.no_grad() 167 | def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs): 168 | log = dict() 169 | x = self.get_input(batch, self.image_key) 170 | x = x.to(self.device) 171 | if not only_inputs: 172 | xrec, posterior = self(x) 173 | if x.shape[1] > 3: 174 | # colorize with random projection 175 | assert xrec.shape[1] > 3 176 | x = self.to_rgb(x) 177 | xrec = self.to_rgb(xrec) 178 | log["samples"] = self.decode(torch.randn_like(posterior.sample())) 179 | log["reconstructions"] = xrec 180 | if log_ema or self.use_ema: 181 | with self.ema_scope(): 182 | xrec_ema, posterior_ema = self(x) 183 | if x.shape[1] > 3: 184 | # colorize with random projection 185 | assert xrec_ema.shape[1] > 3 186 | xrec_ema = self.to_rgb(xrec_ema) 187 | log["samples_ema"] = self.decode(torch.randn_like(posterior_ema.sample())) 188 | log["reconstructions_ema"] = xrec_ema 189 | log["inputs"] = x 190 | return log 191 | 192 | def to_rgb(self, x): 193 | assert self.image_key == "segmentation" 194 | if not hasattr(self, "colorize"): 195 | self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) 196 | x = F.conv2d(x, weight=self.colorize) 197 | x = 2.*(x-x.min())/(x.max()-x.min()) - 1. 
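# The fixed random 1x1 convolution above projects an n-channel segmentation map down to
# 3 channels purely for visualization, and the min/max rescaling maps the result into [-1, 1].
#
# Illustrative sketch (not executed here) of how the KL autoencoder is used as a first
# stage; `ae` stands for a configured AutoencoderKL instance and `img` for a
# (B, 3, H, W) image tensor in [-1, 1]:
#   posterior = ae.encode(img)   # DiagonalGaussianDistribution over the latent
#   z = posterior.sample()       # or posterior.mode() for a deterministic latent
#   rec = ae.decode(z)           # reconstructed image tensor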
198 | return x 199 | 200 | 201 | class IdentityFirstStage(torch.nn.Module): 202 | def __init__(self, *args, vq_interface=False, **kwargs): 203 | self.vq_interface = vq_interface 204 | super().__init__() 205 | 206 | def encode(self, x, *args, **kwargs): 207 | return x 208 | 209 | def decode(self, x, *args, **kwargs): 210 | return x 211 | 212 | def quantize(self, x, *args, **kwargs): 213 | if self.vq_interface: 214 | return x, None, [None, None, None] 215 | return x 216 | 217 | def forward(self, x, *args, **kwargs): 218 | return x 219 | 220 | -------------------------------------------------------------------------------- /ldm/models/diffusion/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszy98/PLACE/4594350491df5c7489eb4edf0802f303692dd530/ldm/models/diffusion/__init__.py -------------------------------------------------------------------------------- /ldm/models/diffusion/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszy98/PLACE/4594350491df5c7489eb4edf0802f303692dd530/ldm/models/diffusion/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/models/diffusion/__pycache__/ddim.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszy98/PLACE/4594350491df5c7489eb4edf0802f303692dd530/ldm/models/diffusion/__pycache__/ddim.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/models/diffusion/__pycache__/ddpm.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszy98/PLACE/4594350491df5c7489eb4edf0802f303692dd530/ldm/models/diffusion/__pycache__/ddpm.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/models/diffusion/__pycache__/myplms.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszy98/PLACE/4594350491df5c7489eb4edf0802f303692dd530/ldm/models/diffusion/__pycache__/myplms.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/models/diffusion/__pycache__/plms.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszy98/PLACE/4594350491df5c7489eb4edf0802f303692dd530/ldm/models/diffusion/__pycache__/plms.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/models/diffusion/__pycache__/sampling_util.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszy98/PLACE/4594350491df5c7489eb4edf0802f303692dd530/ldm/models/diffusion/__pycache__/sampling_util.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/models/diffusion/__pycache__/testplms.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszy98/PLACE/4594350491df5c7489eb4edf0802f303692dd530/ldm/models/diffusion/__pycache__/testplms.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/models/diffusion/dpm_solver/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .sampler import DPMSolverSampler -------------------------------------------------------------------------------- /ldm/models/diffusion/dpm_solver/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszy98/PLACE/4594350491df5c7489eb4edf0802f303692dd530/ldm/models/diffusion/dpm_solver/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/models/diffusion/dpm_solver/__pycache__/dpm_solver.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszy98/PLACE/4594350491df5c7489eb4edf0802f303692dd530/ldm/models/diffusion/dpm_solver/__pycache__/dpm_solver.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/models/diffusion/dpm_solver/__pycache__/sampler.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszy98/PLACE/4594350491df5c7489eb4edf0802f303692dd530/ldm/models/diffusion/dpm_solver/__pycache__/sampler.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/models/diffusion/dpm_solver/sampler.py: -------------------------------------------------------------------------------- 1 | """SAMPLING ONLY.""" 2 | import torch 3 | 4 | from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver 5 | 6 | 7 | MODEL_TYPES = { 8 | "eps": "noise", 9 | "v": "v" 10 | } 11 | 12 | 13 | class DPMSolverSampler(object): 14 | def __init__(self, model, **kwargs): 15 | super().__init__() 16 | self.model = model 17 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device) 18 | self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod)) 19 | 20 | def register_buffer(self, name, attr): 21 | if type(attr) == torch.Tensor: 22 | if attr.device != torch.device("cuda"): 23 | attr = attr.to(torch.device("cuda")) 24 | setattr(self, name, attr) 25 | 26 | @torch.no_grad() 27 | def sample(self, 28 | S, 29 | batch_size, 30 | shape, 31 | conditioning=None, 32 | callback=None, 33 | normals_sequence=None, 34 | img_callback=None, 35 | quantize_x0=False, 36 | eta=0., 37 | mask=None, 38 | x0=None, 39 | temperature=1., 40 | noise_dropout=0., 41 | score_corrector=None, 42 | corrector_kwargs=None, 43 | verbose=True, 44 | x_T=None, 45 | log_every_t=100, 46 | unconditional_guidance_scale=1., 47 | unconditional_conditioning=None, 48 | # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 
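# unconditional_conditioning drives the classifier-free guidance wrapper built below:
# the wrapped model is evaluated with both conditionings and the two noise predictions
# are blended using unconditional_guidance_scale.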
49 | **kwargs 50 | ): 51 | if conditioning is not None: 52 | if isinstance(conditioning, dict): 53 | cbs = conditioning[list(conditioning.keys())[0]].shape[0] 54 | if cbs != batch_size: 55 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") 56 | else: 57 | if conditioning.shape[0] != batch_size: 58 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") 59 | 60 | # sampling 61 | C, H, W = shape 62 | size = (batch_size, C, H, W) 63 | 64 | print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}') 65 | 66 | device = self.model.betas.device 67 | if x_T is None: 68 | img = torch.randn(size, device=device) 69 | else: 70 | img = x_T 71 | 72 | ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod) 73 | 74 | model_fn = model_wrapper( 75 | lambda x, t, c: self.model.apply_model(x, t, c), 76 | ns, 77 | model_type=MODEL_TYPES[self.model.parameterization], 78 | guidance_type="classifier-free", 79 | condition=conditioning, 80 | unconditional_condition=unconditional_conditioning, 81 | guidance_scale=unconditional_guidance_scale, 82 | ) 83 | 84 | dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False) 85 | x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True) 86 | 87 | return x.to(device), None -------------------------------------------------------------------------------- /ldm/models/diffusion/myplms.py: -------------------------------------------------------------------------------- 1 | """SAMPLING ONLY.""" 2 | 3 | import torch 4 | import numpy as np 5 | from tqdm import tqdm 6 | from functools import partial 7 | 8 | from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like 9 | from ldm.models.diffusion.sampling_util import norm_thresholding 10 | 11 | 12 | class MYPLMSSampler(object): 13 | def __init__(self, model, schedule="linear", **kwargs): 14 | super().__init__() 15 | self.model = model 16 | self.ddpm_num_timesteps = model.num_timesteps 17 | self.schedule = schedule 18 | 19 | def register_buffer(self, name, attr): 20 | if type(attr) == torch.Tensor: 21 | if attr.device != torch.device("cuda"): 22 | attr = attr.to(torch.device("cuda")) 23 | setattr(self, name, attr) 24 | 25 | def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): 26 | if ddim_eta != 0: 27 | raise ValueError('ddim_eta must be 0 for PLMS') 28 | self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, 29 | num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) 30 | alphas_cumprod = self.model.alphas_cumprod 31 | assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' 32 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) 33 | 34 | self.register_buffer('betas', to_torch(self.model.betas)) 35 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) 36 | self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) 37 | 38 | # calculations for diffusion q(x_t | x_{t-1}) and others 39 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) 40 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) 41 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. 
- alphas_cumprod.cpu()))) 42 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) 43 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) 44 | 45 | # ddim sampling parameters 46 | ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), 47 | ddim_timesteps=self.ddim_timesteps, 48 | eta=ddim_eta,verbose=verbose) 49 | self.register_buffer('ddim_sigmas', ddim_sigmas) 50 | self.register_buffer('ddim_alphas', ddim_alphas) 51 | self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) 52 | self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) 53 | sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( 54 | (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( 55 | 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) 56 | self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) 57 | 58 | @torch.no_grad() 59 | def sample(self, 60 | S, 61 | batch_size, 62 | shape, 63 | conditioning=None, 64 | callback=None, 65 | normals_sequence=None, 66 | img_callback=None, 67 | quantize_x0=False, 68 | eta=0., 69 | mask=None, 70 | x0=None, 71 | temperature=1., 72 | noise_dropout=0., 73 | score_corrector=None, 74 | corrector_kwargs=None, 75 | verbose=True, 76 | x_T=None, 77 | log_every_t=100, 78 | unconditional_guidance_scale=1., 79 | unconditional_conditioning=None, 80 | # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 81 | dynamic_threshold=None, 82 | **kwargs 83 | ): 84 | ''' 85 | if conditioning is not None: 86 | if isinstance(conditioning, dict): 87 | cbs = conditioning[list(conditioning.keys())[0]].shape[0] 88 | if cbs != batch_size: 89 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") 90 | else: 91 | if conditioning.shape[0] != batch_size: 92 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") 93 | ''' 94 | if conditioning is not None: 95 | if isinstance(conditioning, dict): 96 | ctmp = conditioning[list(conditioning.keys())[0]] 97 | while isinstance(ctmp, list): ctmp = ctmp[0] 98 | cbs = ctmp.shape[0] 99 | if cbs != batch_size: 100 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") 101 | 102 | elif isinstance(conditioning, list): 103 | for ctmp in conditioning: 104 | if ctmp.shape[0] != batch_size: 105 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") 106 | 107 | else: 108 | if conditioning.shape[0] != batch_size: 109 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") 110 | 111 | 112 | self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) 113 | # sampling 114 | C, H, W = shape 115 | size = (batch_size, C, H, W) 116 | print(f'Data shape for PLMS sampling is {size}') 117 | 118 | samples, intermediates = self.plms_sampling(conditioning, size, 119 | callback=callback, 120 | img_callback=img_callback, 121 | quantize_denoised=quantize_x0, 122 | mask=mask, x0=x0, 123 | ddim_use_original_steps=False, 124 | noise_dropout=noise_dropout, 125 | temperature=temperature, 126 | score_corrector=score_corrector, 127 | corrector_kwargs=corrector_kwargs, 128 | x_T=x_T, 129 | log_every_t=log_every_t, 130 | unconditional_guidance_scale=unconditional_guidance_scale, 131 | unconditional_conditioning=unconditional_conditioning, 132 | dynamic_threshold=dynamic_threshold, 133 | ) 134 | return samples, 
intermediates 135 | 136 | @torch.no_grad() 137 | def plms_sampling(self, cond, shape, 138 | x_T=None, ddim_use_original_steps=False, 139 | callback=None, timesteps=None, quantize_denoised=False, 140 | mask=None, x0=None, img_callback=None, log_every_t=100, 141 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, 142 | unconditional_guidance_scale=1., unconditional_conditioning=None, 143 | dynamic_threshold=None): 144 | device = self.model.betas.device 145 | b = shape[0] 146 | if x_T is None: 147 | img = torch.randn(shape, device=device) 148 | print('updated!') 149 | img = torch.nn.init.trunc_normal_(img, mean=0.0, std=1.0, a=-3.0, b=3.0) 150 | else: 151 | img = x_T 152 | 153 | if timesteps is None: 154 | timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps 155 | elif timesteps is not None and not ddim_use_original_steps: 156 | subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 157 | timesteps = self.ddim_timesteps[:subset_end] 158 | 159 | intermediates = {'x_inter': [img], 'pred_x0': [img]} 160 | time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps) 161 | total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] 162 | print(f"Running PLMS Sampling with {total_steps} timesteps") 163 | 164 | iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps) 165 | old_eps = [] 166 | 167 | for i, step in enumerate(iterator): 168 | index = total_steps - i - 1 169 | ts = torch.full((b,), step, device=device, dtype=torch.long) 170 | ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long) 171 | 172 | if mask is not None: 173 | assert x0 is not None 174 | img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? 175 | img = img_orig * mask + (1. 
- mask) * img 176 | 177 | outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, 178 | quantize_denoised=quantize_denoised, temperature=temperature, 179 | noise_dropout=noise_dropout, score_corrector=score_corrector, 180 | corrector_kwargs=corrector_kwargs, 181 | unconditional_guidance_scale=unconditional_guidance_scale, 182 | unconditional_conditioning=unconditional_conditioning, 183 | old_eps=old_eps, t_next=ts_next, 184 | dynamic_threshold=dynamic_threshold) 185 | img, pred_x0, e_t = outs 186 | old_eps.append(e_t) 187 | if len(old_eps) >= 4: 188 | old_eps.pop(0) 189 | if callback: callback(i) 190 | if img_callback: img_callback(pred_x0, i) 191 | 192 | if index % log_every_t == 0 or index == total_steps - 1: 193 | intermediates['x_inter'].append(img) 194 | intermediates['pred_x0'].append(pred_x0) 195 | 196 | return img, intermediates 197 | 198 | @torch.no_grad() 199 | def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, 200 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, 201 | unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None, 202 | dynamic_threshold=None): 203 | b, *_, device = *x.shape, x.device 204 | 205 | def get_model_output(x, t): 206 | if unconditional_conditioning is None or unconditional_guidance_scale == 1.: 207 | e_t = self.model.apply_model(x, t, c) 208 | else: 209 | x_in = torch.cat([x] * 2) 210 | t_in = torch.cat([t] * 2) 211 | #c_in = torch.cat([unconditional_conditioning, c]) 212 | 213 | if isinstance(c, dict): 214 | assert isinstance(unconditional_conditioning, dict) 215 | c_in = dict() 216 | for k in c: 217 | if isinstance(c[k], list): 218 | c_in[k] = [torch.cat([ 219 | unconditional_conditioning[k][i], 220 | c[k][i]]) for i in range(len(c[k]))] 221 | else: 222 | c_in[k] = torch.cat([ 223 | unconditional_conditioning[k], 224 | c[k]]) 225 | elif isinstance(c, list): 226 | c_in = list() 227 | assert isinstance(unconditional_conditioning, list) 228 | for i in range(len(c)): 229 | c_in.append(torch.cat([unconditional_conditioning[i], c[i]])) 230 | else: 231 | c_in = torch.cat([unconditional_conditioning, c]) 232 | 233 | e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) 234 | e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) 235 | if score_corrector is not None: 236 | assert self.model.parameterization == "eps" 237 | e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) 238 | 239 | return e_t 240 | 241 | alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas 242 | alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev 243 | sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas 244 | sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas 245 | 246 | def get_x_prev_and_pred_x0(e_t, index): 247 | # select parameters corresponding to the currently considered timestep 248 | a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) 249 | a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) 250 | sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) 251 | sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) 252 | 253 | # current prediction for x_0 254 | pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() 
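# This inverts the forward relation x_t = sqrt(a_t) * x_0 + sqrt(1 - a_t) * eps for x_0,
# using the (guided) noise prediction e_t. A few lines below, the update recombines it as
# x_prev = sqrt(a_prev) * pred_x0 + dir_xt + noise; for PLMS the sigmas are zero, so the
# step is deterministic.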
255 | if quantize_denoised: 256 | pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) 257 | if dynamic_threshold is not None: 258 | pred_x0 = norm_thresholding(pred_x0, dynamic_threshold) 259 | # direction pointing to x_t 260 | dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t 261 | noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature 262 | if noise_dropout > 0.: 263 | noise = torch.nn.functional.dropout(noise, p=noise_dropout) 264 | x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise 265 | return x_prev, pred_x0 266 | 267 | e_t = get_model_output(x, t) 268 | if len(old_eps) == 0: 269 | # Pseudo Improved Euler (2nd order) 270 | x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index) 271 | e_t_next = get_model_output(x_prev, t_next) 272 | e_t_prime = (e_t + e_t_next) / 2 273 | elif len(old_eps) == 1: 274 | # 2nd order Pseudo Linear Multistep (Adams-Bashforth) 275 | e_t_prime = (3 * e_t - old_eps[-1]) / 2 276 | elif len(old_eps) == 2: 277 | # 3nd order Pseudo Linear Multistep (Adams-Bashforth) 278 | e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12 279 | elif len(old_eps) >= 3: 280 | # 4nd order Pseudo Linear Multistep (Adams-Bashforth) 281 | e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24 282 | 283 | x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index) 284 | 285 | return x_prev, pred_x0, e_t 286 | -------------------------------------------------------------------------------- /ldm/models/diffusion/plms.py: -------------------------------------------------------------------------------- 1 | """SAMPLING ONLY.""" 2 | 3 | import torch 4 | import numpy as np 5 | from tqdm import tqdm 6 | from functools import partial 7 | 8 | from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like 9 | from ldm.models.diffusion.sampling_util import norm_thresholding 10 | 11 | 12 | class PLMSSampler(object): 13 | def __init__(self, model, schedule="linear", **kwargs): 14 | super().__init__() 15 | self.model = model 16 | self.ddpm_num_timesteps = model.num_timesteps 17 | self.schedule = schedule 18 | 19 | def register_buffer(self, name, attr): 20 | if type(attr) == torch.Tensor: 21 | if attr.device != torch.device("cuda"): 22 | attr = attr.to(torch.device("cuda")) 23 | setattr(self, name, attr) 24 | 25 | def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): 26 | if ddim_eta != 0: 27 | raise ValueError('ddim_eta must be 0 for PLMS') 28 | self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, 29 | num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) 30 | alphas_cumprod = self.model.alphas_cumprod 31 | assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' 32 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) 33 | 34 | self.register_buffer('betas', to_torch(self.model.betas)) 35 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) 36 | self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) 37 | 38 | # calculations for diffusion q(x_t | x_{t-1}) and others 39 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) 40 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) 41 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. 
- alphas_cumprod.cpu()))) 42 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) 43 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) 44 | 45 | # ddim sampling parameters 46 | ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), 47 | ddim_timesteps=self.ddim_timesteps, 48 | eta=ddim_eta,verbose=verbose) 49 | self.register_buffer('ddim_sigmas', ddim_sigmas) 50 | self.register_buffer('ddim_alphas', ddim_alphas) 51 | self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) 52 | self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) 53 | sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( 54 | (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( 55 | 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) 56 | self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) 57 | 58 | @torch.no_grad() 59 | def sample(self, 60 | S, 61 | batch_size, 62 | shape, 63 | conditioning=None, 64 | callback=None, 65 | normals_sequence=None, 66 | img_callback=None, 67 | quantize_x0=False, 68 | eta=0., 69 | mask=None, 70 | x0=None, 71 | temperature=1., 72 | noise_dropout=0., 73 | score_corrector=None, 74 | corrector_kwargs=None, 75 | verbose=True, 76 | x_T=None, 77 | log_every_t=100, 78 | unconditional_guidance_scale=1., 79 | unconditional_conditioning=None, 80 | # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 81 | dynamic_threshold=None, 82 | **kwargs 83 | ): 84 | ''' 85 | if conditioning is not None: 86 | if isinstance(conditioning, dict): 87 | cbs = conditioning[list(conditioning.keys())[0]].shape[0] 88 | if cbs != batch_size: 89 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") 90 | else: 91 | if conditioning.shape[0] != batch_size: 92 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") 93 | ''' 94 | if conditioning is not None: 95 | if isinstance(conditioning, dict): 96 | ctmp = conditioning[list(conditioning.keys())[0]] 97 | while isinstance(ctmp, list): ctmp = ctmp[0] 98 | cbs = ctmp.shape[0] 99 | if cbs != batch_size: 100 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") 101 | 102 | elif isinstance(conditioning, list): 103 | for ctmp in conditioning: 104 | if ctmp.shape[0] != batch_size: 105 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") 106 | 107 | else: 108 | if conditioning.shape[0] != batch_size: 109 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") 110 | 111 | 112 | self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) 113 | # sampling 114 | C, H, W = shape 115 | size = (batch_size, C, H, W) 116 | print(f'Data shape for PLMS sampling is {size}') 117 | 118 | samples, intermediates = self.plms_sampling(conditioning, size, 119 | callback=callback, 120 | img_callback=img_callback, 121 | quantize_denoised=quantize_x0, 122 | mask=mask, x0=x0, 123 | ddim_use_original_steps=False, 124 | noise_dropout=noise_dropout, 125 | temperature=temperature, 126 | score_corrector=score_corrector, 127 | corrector_kwargs=corrector_kwargs, 128 | x_T=x_T, 129 | log_every_t=log_every_t, 130 | unconditional_guidance_scale=unconditional_guidance_scale, 131 | unconditional_conditioning=unconditional_conditioning, 132 | dynamic_threshold=dynamic_threshold, 133 | ) 134 | return samples, 
intermediates 135 | 136 | @torch.no_grad() 137 | def plms_sampling(self, cond, shape, 138 | x_T=None, ddim_use_original_steps=False, 139 | callback=None, timesteps=None, quantize_denoised=False, 140 | mask=None, x0=None, img_callback=None, log_every_t=100, 141 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, 142 | unconditional_guidance_scale=1., unconditional_conditioning=None, 143 | dynamic_threshold=None): 144 | device = self.model.betas.device 145 | b = shape[0] 146 | if x_T is None: 147 | img = torch.randn(shape, device=device) 148 | else: 149 | img = x_T 150 | 151 | if timesteps is None: 152 | timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps 153 | elif timesteps is not None and not ddim_use_original_steps: 154 | subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 155 | timesteps = self.ddim_timesteps[:subset_end] 156 | 157 | intermediates = {'x_inter': [img], 'pred_x0': [img]} 158 | time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps) 159 | total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] 160 | print(f"Running PLMS Sampling with {total_steps} timesteps") 161 | 162 | iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps) 163 | old_eps = [] 164 | 165 | for i, step in enumerate(iterator): 166 | index = total_steps - i - 1 167 | ts = torch.full((b,), step, device=device, dtype=torch.long) 168 | ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long) 169 | 170 | if mask is not None: 171 | assert x0 is not None 172 | img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? 173 | img = img_orig * mask + (1. 
- mask) * img 174 | 175 | outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, 176 | quantize_denoised=quantize_denoised, temperature=temperature, 177 | noise_dropout=noise_dropout, score_corrector=score_corrector, 178 | corrector_kwargs=corrector_kwargs, 179 | unconditional_guidance_scale=unconditional_guidance_scale, 180 | unconditional_conditioning=unconditional_conditioning, 181 | old_eps=old_eps, t_next=ts_next, 182 | dynamic_threshold=dynamic_threshold) 183 | img, pred_x0, e_t = outs 184 | old_eps.append(e_t) 185 | if len(old_eps) >= 4: 186 | old_eps.pop(0) 187 | if callback: callback(i) 188 | if img_callback: img_callback(pred_x0, i) 189 | 190 | if index % log_every_t == 0 or index == total_steps - 1: 191 | intermediates['x_inter'].append(img) 192 | intermediates['pred_x0'].append(pred_x0) 193 | 194 | return img, intermediates 195 | 196 | @torch.no_grad() 197 | def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, 198 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, 199 | unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None, 200 | dynamic_threshold=None): 201 | b, *_, device = *x.shape, x.device 202 | 203 | def get_model_output(x, t): 204 | if unconditional_conditioning is None or unconditional_guidance_scale == 1.: 205 | e_t = self.model.apply_model(x, t, c) 206 | else: 207 | x_in = torch.cat([x] * 2) 208 | t_in = torch.cat([t] * 2) 209 | #c_in = torch.cat([unconditional_conditioning, c]) 210 | 211 | if isinstance(c, dict): 212 | assert isinstance(unconditional_conditioning, dict) 213 | c_in = dict() 214 | for k in c: 215 | if isinstance(c[k], list): 216 | c_in[k] = [torch.cat([ 217 | unconditional_conditioning[k][i], 218 | c[k][i]]) for i in range(len(c[k]))] 219 | else: 220 | c_in[k] = torch.cat([ 221 | unconditional_conditioning[k], 222 | c[k]]) 223 | elif isinstance(c, list): 224 | c_in = list() 225 | assert isinstance(unconditional_conditioning, list) 226 | for i in range(len(c)): 227 | c_in.append(torch.cat([unconditional_conditioning[i], c[i]])) 228 | else: 229 | c_in = torch.cat([unconditional_conditioning, c]) 230 | 231 | e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) 232 | e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) 233 | if score_corrector is not None: 234 | assert self.model.parameterization == "eps" 235 | e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) 236 | 237 | return e_t 238 | 239 | alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas 240 | alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev 241 | sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas 242 | sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas 243 | 244 | def get_x_prev_and_pred_x0(e_t, index): 245 | # select parameters corresponding to the currently considered timestep 246 | a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) 247 | a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) 248 | sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) 249 | sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) 250 | 251 | # current prediction for x_0 252 | pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() 
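# After this helper returns, the sampler picks the noise estimate e_t_prime used for the
# actual step: a Heun-like warm-up (pseudo improved Euler) when no history exists, then
# 2nd-, 3rd- and 4th-order Adams-Bashforth combinations of the current and previous noise
# predictions (coefficients 3/2,-1/2; 23/12,-16/12,5/12; 55/24,-59/24,37/24,-9/24),
# following the pseudo linear multistep (PLMS/PNDM) scheme.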
253 | if quantize_denoised: 254 | pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) 255 | if dynamic_threshold is not None: 256 | pred_x0 = norm_thresholding(pred_x0, dynamic_threshold) 257 | # direction pointing to x_t 258 | dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t 259 | noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature 260 | if noise_dropout > 0.: 261 | noise = torch.nn.functional.dropout(noise, p=noise_dropout) 262 | x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise 263 | return x_prev, pred_x0 264 | 265 | e_t = get_model_output(x, t) 266 | if len(old_eps) == 0: 267 | # Pseudo Improved Euler (2nd order) 268 | x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index) 269 | e_t_next = get_model_output(x_prev, t_next) 270 | e_t_prime = (e_t + e_t_next) / 2 271 | elif len(old_eps) == 1: 272 | # 2nd order Pseudo Linear Multistep (Adams-Bashforth) 273 | e_t_prime = (3 * e_t - old_eps[-1]) / 2 274 | elif len(old_eps) == 2: 275 | # 3nd order Pseudo Linear Multistep (Adams-Bashforth) 276 | e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12 277 | elif len(old_eps) >= 3: 278 | # 4nd order Pseudo Linear Multistep (Adams-Bashforth) 279 | e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24 280 | 281 | x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index) 282 | 283 | 284 | ''' 285 | print(x_prev.shape,pred_x0.shape) 286 | xp = self.model.decode_first_stage(x_prev) 287 | import einops 288 | from PIL import Image 289 | xp = (einops.rearrange(xp, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)[0] 290 | # #print(xp.shape,'@@@@@@@@@@@@@@@') 291 | # #print(t,'############') 292 | Image.fromarray(xp).save('drawview_good225/xp_'+str(float(t[0]))+'.png') 293 | px0 = self.model.decode_first_stage(pred_x0) 294 | px0 = (einops.rearrange(px0, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)[0] 295 | Image.fromarray(px0).save('drawview_good225/px0_'+str(float(t[0]))+'.png') 296 | ''' 297 | 298 | return x_prev, pred_x0, e_t 299 | -------------------------------------------------------------------------------- /ldm/models/diffusion/sampling_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | def append_dims(x, target_dims): 6 | """Appends dimensions to the end of a tensor until it has target_dims dimensions. 
7 | From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py""" 8 | dims_to_append = target_dims - x.ndim 9 | if dims_to_append < 0: 10 | raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less') 11 | return x[(...,) + (None,) * dims_to_append] 12 | 13 | 14 | def norm_thresholding(x0, value): 15 | s = append_dims(x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value), x0.ndim) 16 | return x0 * (value / s) 17 | 18 | 19 | def spatial_norm_thresholding(x0, value): 20 | # b c h w 21 | s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value) 22 | return x0 * (value / s) -------------------------------------------------------------------------------- /ldm/models/diffusion/testplms.py: -------------------------------------------------------------------------------- 1 | """SAMPLING ONLY.""" 2 | 3 | import torch 4 | import numpy as np 5 | from tqdm import tqdm 6 | from functools import partial 7 | 8 | from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like 9 | from ldm.models.diffusion.sampling_util import norm_thresholding 10 | 11 | 12 | class TestPLMSSampler(object): 13 | def __init__(self, model, schedule="linear", **kwargs): 14 | super().__init__() 15 | self.model = model 16 | self.ddpm_num_timesteps = model.num_timesteps 17 | self.schedule = schedule 18 | 19 | def register_buffer(self, name, attr): 20 | if type(attr) == torch.Tensor: 21 | if attr.device != torch.device("cuda"): 22 | attr = attr.to(torch.device("cuda")) 23 | setattr(self, name, attr) 24 | 25 | def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): 26 | if ddim_eta != 0: 27 | raise ValueError('ddim_eta must be 0 for PLMS') 28 | self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, 29 | num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) 30 | alphas_cumprod = self.model.alphas_cumprod 31 | assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' 32 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) 33 | 34 | self.register_buffer('betas', to_torch(self.model.betas)) 35 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) 36 | self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) 37 | 38 | # calculations for diffusion q(x_t | x_{t-1}) and others 39 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) 40 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) 41 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) 42 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) 43 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) 44 | 45 | # ddim sampling parameters 46 | ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), 47 | ddim_timesteps=self.ddim_timesteps, 48 | eta=ddim_eta,verbose=verbose) 49 | self.register_buffer('ddim_sigmas', ddim_sigmas) 50 | self.register_buffer('ddim_alphas', ddim_alphas) 51 | self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) 52 | self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. 
- ddim_alphas)) 53 | sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( 54 | (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( 55 | 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) 56 | self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) 57 | 58 | @torch.no_grad() 59 | def sample(self, 60 | S, 61 | batch_size, 62 | shape, 63 | names=None, 64 | conditioning=None, 65 | callback=None, 66 | normals_sequence=None, 67 | img_callback=None, 68 | quantize_x0=False, 69 | eta=0., 70 | mask=None, 71 | x0=None, 72 | temperature=1., 73 | noise_dropout=0., 74 | score_corrector=None, 75 | corrector_kwargs=None, 76 | verbose=True, 77 | x_T=None, 78 | log_every_t=100, 79 | unconditional_guidance_scale=1., 80 | unconditional_conditioning=None, 81 | # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 82 | dynamic_threshold=None, 83 | **kwargs 84 | ): 85 | ''' 86 | if conditioning is not None: 87 | if isinstance(conditioning, dict): 88 | cbs = conditioning[list(conditioning.keys())[0]].shape[0] 89 | if cbs != batch_size: 90 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") 91 | else: 92 | if conditioning.shape[0] != batch_size: 93 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") 94 | ''' 95 | if conditioning is not None: 96 | if isinstance(conditioning, dict): 97 | ctmp = conditioning[list(conditioning.keys())[0]] 98 | while isinstance(ctmp, list): ctmp = ctmp[0] 99 | cbs = ctmp.shape[0] 100 | if cbs != batch_size: 101 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") 102 | 103 | elif isinstance(conditioning, list): 104 | for ctmp in conditioning: 105 | if ctmp.shape[0] != batch_size: 106 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") 107 | 108 | else: 109 | if conditioning.shape[0] != batch_size: 110 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") 111 | 112 | 113 | self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) 114 | # sampling 115 | C, H, W = shape 116 | size = (batch_size, C, H, W) 117 | print(f'Data shape for PLMS sampling is {size}') 118 | 119 | samples, intermediates = self.plms_sampling(conditioning, size, 120 | names=names, 121 | callback=callback, 122 | img_callback=img_callback, 123 | quantize_denoised=quantize_x0, 124 | mask=mask, x0=x0, 125 | ddim_use_original_steps=False, 126 | noise_dropout=noise_dropout, 127 | temperature=temperature, 128 | score_corrector=score_corrector, 129 | corrector_kwargs=corrector_kwargs, 130 | x_T=x_T, 131 | log_every_t=log_every_t, 132 | unconditional_guidance_scale=unconditional_guidance_scale, 133 | unconditional_conditioning=unconditional_conditioning, 134 | dynamic_threshold=dynamic_threshold, 135 | ) 136 | return samples, intermediates 137 | 138 | @torch.no_grad() 139 | def plms_sampling(self, cond, shape, names=None, 140 | x_T=None, ddim_use_original_steps=False, 141 | callback=None, timesteps=None, quantize_denoised=False, 142 | mask=None, x0=None, img_callback=None, log_every_t=100, 143 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, 144 | unconditional_guidance_scale=1., unconditional_conditioning=None, 145 | dynamic_threshold=None): 146 | device = self.model.betas.device 147 | b = shape[0] 148 | #if names!='ADE_val_00000027.png': 149 | # return None, None 150 | 151 | if x_T is None: 152 | img = torch.randn(shape, device=device) 153 | 
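# Reverse diffusion starts from pure Gaussian latent noise x_T ~ N(0, I) of shape
# (batch_size, C, H, W) unless an explicit starting latent x_T is supplied.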
else: 154 | img = x_T 155 | 156 | if names!='ADE_val_00000027.png': 157 | return None, None 158 | 159 | if timesteps is None: 160 | timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps 161 | elif timesteps is not None and not ddim_use_original_steps: 162 | subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 163 | timesteps = self.ddim_timesteps[:subset_end] 164 | 165 | intermediates = {'x_inter': [img], 'pred_x0': [img]} 166 | time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps) 167 | total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] 168 | print(f"Running PLMS Sampling with {total_steps} timesteps") 169 | 170 | iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps) 171 | old_eps = [] 172 | 173 | for i, step in enumerate(iterator): 174 | index = total_steps - i - 1 175 | ts = torch.full((b,), step, device=device, dtype=torch.long) 176 | ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long) 177 | 178 | if mask is not None: 179 | assert x0 is not None 180 | img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? 181 | img = img_orig * mask + (1. - mask) * img 182 | 183 | outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, 184 | quantize_denoised=quantize_denoised, temperature=temperature, 185 | noise_dropout=noise_dropout, score_corrector=score_corrector, 186 | corrector_kwargs=corrector_kwargs, 187 | unconditional_guidance_scale=unconditional_guidance_scale, 188 | unconditional_conditioning=unconditional_conditioning, 189 | old_eps=old_eps, t_next=ts_next, 190 | dynamic_threshold=dynamic_threshold) 191 | print('##############################',outs[1].shape, ts) 192 | #views_predx = outs[1].permute((0,2,3,1))[0].detach 193 | views_predx = self.model.decode_first_stage(outs[1]) 194 | import einops 195 | from PIL import Image 196 | x_pred0 = (einops.rearrange(views_predx, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8) 197 | Image.fromarray(x_pred0[0]).save('viewx0/'+str(int(ts[0]))+'_'+names) 198 | 199 | img, pred_x0, e_t = outs 200 | old_eps.append(e_t) 201 | if len(old_eps) >= 4: 202 | old_eps.pop(0) 203 | if callback: callback(i) 204 | if img_callback: img_callback(pred_x0, i) 205 | 206 | if index % log_every_t == 0 or index == total_steps - 1: 207 | intermediates['x_inter'].append(img) 208 | intermediates['pred_x0'].append(pred_x0) 209 | 210 | return img, intermediates 211 | 212 | @torch.no_grad() 213 | def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, 214 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, 215 | unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None, 216 | dynamic_threshold=None): 217 | b, *_, device = *x.shape, x.device 218 | 219 | def get_model_output(x, t): 220 | if unconditional_conditioning is None or unconditional_guidance_scale == 1.: 221 | e_t = self.model.apply_model(x, t, c) 222 | else: 223 | x_in = torch.cat([x] * 2) 224 | t_in = torch.cat([t] * 2) 225 | #c_in = torch.cat([unconditional_conditioning, c]) 226 | 227 | if isinstance(c, dict): 228 | assert isinstance(unconditional_conditioning, dict) 229 | c_in = dict() 230 | for k in c: 231 | if isinstance(c[k], list): 232 | c_in[k] = [torch.cat([ 233 | 
unconditional_conditioning[k][i], 234 | c[k][i]]) for i in range(len(c[k]))] 235 | else: 236 | c_in[k] = torch.cat([ 237 | unconditional_conditioning[k], 238 | c[k]]) 239 | elif isinstance(c, list): 240 | c_in = list() 241 | assert isinstance(unconditional_conditioning, list) 242 | for i in range(len(c)): 243 | c_in.append(torch.cat([unconditional_conditioning[i], c[i]])) 244 | else: 245 | c_in = torch.cat([unconditional_conditioning, c]) 246 | 247 | e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) 248 | e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) 249 | if score_corrector is not None: 250 | assert self.model.parameterization == "eps" 251 | e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) 252 | 253 | return e_t 254 | 255 | alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas 256 | alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev 257 | sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas 258 | sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas 259 | 260 | def get_x_prev_and_pred_x0(e_t, index): 261 | # select parameters corresponding to the currently considered timestep 262 | a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) 263 | a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) 264 | sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) 265 | sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) 266 | 267 | # current prediction for x_0 268 | pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() 269 | if quantize_denoised: 270 | pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) 271 | if dynamic_threshold is not None: 272 | pred_x0 = norm_thresholding(pred_x0, dynamic_threshold) 273 | # direction pointing to x_t 274 | dir_xt = (1. 
- a_prev - sigma_t**2).sqrt() * e_t 275 | noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature 276 | if noise_dropout > 0.: 277 | noise = torch.nn.functional.dropout(noise, p=noise_dropout) 278 | x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise 279 | return x_prev, pred_x0 280 | 281 | e_t = get_model_output(x, t) 282 | if len(old_eps) == 0: 283 | # Pseudo Improved Euler (2nd order) 284 | x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index) 285 | e_t_next = get_model_output(x_prev, t_next) 286 | e_t_prime = (e_t + e_t_next) / 2 287 | elif len(old_eps) == 1: 288 | # 2nd order Pseudo Linear Multistep (Adams-Bashforth) 289 | e_t_prime = (3 * e_t - old_eps[-1]) / 2 290 | elif len(old_eps) == 2: 291 | # 3nd order Pseudo Linear Multistep (Adams-Bashforth) 292 | e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12 293 | elif len(old_eps) >= 3: 294 | # 4nd order Pseudo Linear Multistep (Adams-Bashforth) 295 | e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24 296 | 297 | x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index) 298 | 299 | return x_prev, pred_x0, e_t 300 | -------------------------------------------------------------------------------- /ldm/modules/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszy98/PLACE/4594350491df5c7489eb4edf0802f303692dd530/ldm/modules/.DS_Store -------------------------------------------------------------------------------- /ldm/modules/__pycache__/attention.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszy98/PLACE/4594350491df5c7489eb4edf0802f303692dd530/ldm/modules/__pycache__/attention.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/__pycache__/ema.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszy98/PLACE/4594350491df5c7489eb4edf0802f303692dd530/ldm/modules/__pycache__/ema.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszy98/PLACE/4594350491df5c7489eb4edf0802f303692dd530/ldm/modules/diffusionmodules/__init__.py -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszy98/PLACE/4594350491df5c7489eb4edf0802f303692dd530/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__pycache__/model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszy98/PLACE/4594350491df5c7489eb4edf0802f303692dd530/ldm/modules/diffusionmodules/__pycache__/model.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/cszy98/PLACE/4594350491df5c7489eb4edf0802f303692dd530/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__pycache__/util.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszy98/PLACE/4594350491df5c7489eb4edf0802f303692dd530/ldm/modules/diffusionmodules/__pycache__/util.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/upscaling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from functools import partial 5 | 6 | from ldm.modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule 7 | from ldm.util import default 8 | 9 | 10 | class AbstractLowScaleModel(nn.Module): 11 | # for concatenating a downsampled image to the latent representation 12 | def __init__(self, noise_schedule_config=None): 13 | super(AbstractLowScaleModel, self).__init__() 14 | if noise_schedule_config is not None: 15 | self.register_schedule(**noise_schedule_config) 16 | 17 | def register_schedule(self, beta_schedule="linear", timesteps=1000, 18 | linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): 19 | betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, 20 | cosine_s=cosine_s) 21 | alphas = 1. - betas 22 | alphas_cumprod = np.cumprod(alphas, axis=0) 23 | alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) 24 | 25 | timesteps, = betas.shape 26 | self.num_timesteps = int(timesteps) 27 | self.linear_start = linear_start 28 | self.linear_end = linear_end 29 | assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep' 30 | 31 | to_torch = partial(torch.tensor, dtype=torch.float32) 32 | 33 | self.register_buffer('betas', to_torch(betas)) 34 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) 35 | self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) 36 | 37 | # calculations for diffusion q(x_t | x_{t-1}) and others 38 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) 39 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) 40 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod))) 41 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) 42 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. 
/ alphas_cumprod - 1))) 43 | 44 | def q_sample(self, x_start, t, noise=None): 45 | noise = default(noise, lambda: torch.randn_like(x_start)) 46 | return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + 47 | extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) 48 | 49 | def forward(self, x): 50 | return x, None 51 | 52 | def decode(self, x): 53 | return x 54 | 55 | 56 | class SimpleImageConcat(AbstractLowScaleModel): 57 | # no noise level conditioning 58 | def __init__(self): 59 | super(SimpleImageConcat, self).__init__(noise_schedule_config=None) 60 | self.max_noise_level = 0 61 | 62 | def forward(self, x): 63 | # fix to constant noise level 64 | return x, torch.zeros(x.shape[0], device=x.device).long() 65 | 66 | 67 | class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel): 68 | def __init__(self, noise_schedule_config, max_noise_level=1000, to_cuda=False): 69 | super().__init__(noise_schedule_config=noise_schedule_config) 70 | self.max_noise_level = max_noise_level 71 | 72 | def forward(self, x, noise_level=None): 73 | if noise_level is None: 74 | noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long() 75 | else: 76 | assert isinstance(noise_level, torch.Tensor) 77 | z = self.q_sample(x, noise_level) 78 | return z, noise_level 79 | 80 | 81 | 82 | -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/util.py: -------------------------------------------------------------------------------- 1 | # adopted from 2 | # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 3 | # and 4 | # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py 5 | # and 6 | # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py 7 | # 8 | # thanks! 
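# This module collects the schedule and embedding utilities used by the samplers above:
# beta schedules (make_beta_schedule), DDIM timestep subsets and the corresponding
# alpha/sigma tables (make_ddim_timesteps / make_ddim_sampling_parameters), gradient
# checkpointing, and sinusoidal timestep embeddings.
#
# Illustrative sketch (not executed here; values are the functions' own defaults):
#   betas = make_beta_schedule("linear", 1000, linear_start=1e-4, linear_end=2e-2)
#   ddim_steps = make_ddim_timesteps("uniform", num_ddim_timesteps=50, num_ddpm_timesteps=1000)
#   sigmas, alphas, alphas_prev = make_ddim_sampling_parameters(
#       alphacums=np.cumprod(1. - betas), ddim_timesteps=ddim_steps, eta=0.)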
9 | 10 | 11 | import os 12 | import math 13 | import torch 14 | import torch.nn as nn 15 | import numpy as np 16 | from einops import repeat 17 | 18 | from ldm.util import instantiate_from_config 19 | 20 | 21 | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): 22 | if schedule == "linear": 23 | betas = ( 24 | torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 25 | ) 26 | 27 | elif schedule == "cosine": 28 | timesteps = ( 29 | torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s 30 | ) 31 | alphas = timesteps / (1 + cosine_s) * np.pi / 2 32 | alphas = torch.cos(alphas).pow(2) 33 | alphas = alphas / alphas[0] 34 | betas = 1 - alphas[1:] / alphas[:-1] 35 | betas = np.clip(betas, a_min=0, a_max=0.999) 36 | 37 | elif schedule == "sqrt_linear": 38 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) 39 | elif schedule == "sqrt": 40 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 41 | else: 42 | raise ValueError(f"schedule '{schedule}' unknown.") 43 | return betas.numpy() 44 | 45 | 46 | def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): 47 | if ddim_discr_method == 'uniform': 48 | c = num_ddpm_timesteps // num_ddim_timesteps 49 | ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) 50 | elif ddim_discr_method == 'quad': 51 | ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) 52 | else: 53 | raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') 54 | 55 | # assert ddim_timesteps.shape[0] == num_ddim_timesteps 56 | # add one to get the final alpha values right (the ones from first scale to data during sampling) 57 | steps_out = ddim_timesteps + 1 58 | if verbose: 59 | print(f'Selected timesteps for ddim sampler: {steps_out}') 60 | return steps_out 61 | 62 | 63 | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): 64 | # select alphas for computing the variance schedule 65 | alphas = alphacums[ddim_timesteps] 66 | alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) 67 | 68 | # according the the formula provided in https://arxiv.org/abs/2010.02502 69 | sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) 70 | if verbose: 71 | print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') 72 | print(f'For the chosen value of eta, which is {eta}, ' 73 | f'this results in the following sigma_t schedule for ddim sampler {sigmas}') 74 | return sigmas, alphas, alphas_prev 75 | 76 | 77 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): 78 | """ 79 | Create a beta schedule that discretizes the given alpha_t_bar function, 80 | which defines the cumulative product of (1-beta) over time from t = [0,1]. 81 | :param num_diffusion_timesteps: the number of betas to produce. 82 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and 83 | produces the cumulative product of (1-beta) up to that 84 | part of the diffusion process. 85 | :param max_beta: the maximum beta to use; use values lower than 1 to 86 | prevent singularities. 
87 | """ 88 | betas = [] 89 | for i in range(num_diffusion_timesteps): 90 | t1 = i / num_diffusion_timesteps 91 | t2 = (i + 1) / num_diffusion_timesteps 92 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 93 | return np.array(betas) 94 | 95 | 96 | def extract_into_tensor(a, t, x_shape): 97 | b, *_ = t.shape 98 | out = a.gather(-1, t) 99 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 100 | 101 | 102 | def checkpoint(func, inputs, params, flag): 103 | """ 104 | Evaluate a function without caching intermediate activations, allowing for 105 | reduced memory at the expense of extra compute in the backward pass. 106 | :param func: the function to evaluate. 107 | :param inputs: the argument sequence to pass to `func`. 108 | :param params: a sequence of parameters `func` depends on but does not 109 | explicitly take as arguments. 110 | :param flag: if False, disable gradient checkpointing. 111 | """ 112 | if flag: 113 | args = tuple(inputs) + tuple(params) 114 | return CheckpointFunction.apply(func, len(inputs), *args) 115 | else: 116 | return func(*inputs) 117 | 118 | 119 | class CheckpointFunction(torch.autograd.Function): 120 | @staticmethod 121 | def forward(ctx, run_function, length, *args): 122 | ctx.run_function = run_function 123 | ctx.input_tensors = list(args[:length]) 124 | ctx.input_params = list(args[length:]) 125 | ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(), 126 | "dtype": torch.get_autocast_gpu_dtype(), 127 | "cache_enabled": torch.is_autocast_cache_enabled()} 128 | with torch.no_grad(): 129 | output_tensors = ctx.run_function(*ctx.input_tensors) 130 | return output_tensors 131 | 132 | @staticmethod 133 | def backward(ctx, *output_grads): 134 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 135 | with torch.enable_grad(), \ 136 | torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs): 137 | # Fixes a bug where the first op in run_function modifies the 138 | # Tensor storage in place, which is not allowed for detach()'d 139 | # Tensors. 140 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 141 | output_tensors = ctx.run_function(*shallow_copies) 142 | 143 | #print(len(ctx.input_tensors),len(ctx.input_params),ctx.run_function,'################') 144 | input_grads = torch.autograd.grad( 145 | output_tensors, 146 | ctx.input_tensors + ctx.input_params, 147 | output_grads, 148 | allow_unused=True, 149 | ) 150 | del ctx.input_tensors 151 | del ctx.input_params 152 | del output_tensors 153 | return (None, None) + input_grads 154 | 155 | 156 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): 157 | """ 158 | Create sinusoidal timestep embeddings. 159 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 160 | These may be fractional. 161 | :param dim: the dimension of the output. 162 | :param max_period: controls the minimum frequency of the embeddings. 163 | :return: an [N x dim] Tensor of positional embeddings. 
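:param repeat_only: if True, skip the sinusoidal projection and simply repeat each timestep value `dim` times along the last axis.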
164 | """ 165 | if not repeat_only: 166 | half = dim // 2 167 | freqs = torch.exp( 168 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 169 | ).to(device=timesteps.device) 170 | args = timesteps[:, None].float() * freqs[None] 171 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 172 | if dim % 2: 173 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 174 | else: 175 | embedding = repeat(timesteps, 'b -> b d', d=dim) 176 | return embedding 177 | 178 | 179 | def zero_module(module): 180 | """ 181 | Zero out the parameters of a module and return it. 182 | """ 183 | for p in module.parameters(): 184 | p.detach().zero_() 185 | return module 186 | 187 | 188 | def scale_module(module, scale): 189 | """ 190 | Scale the parameters of a module and return it. 191 | """ 192 | for p in module.parameters(): 193 | p.detach().mul_(scale) 194 | return module 195 | 196 | 197 | def mean_flat(tensor): 198 | """ 199 | Take the mean over all non-batch dimensions. 200 | """ 201 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 202 | 203 | 204 | def normalization(channels): 205 | """ 206 | Make a standard normalization layer. 207 | :param channels: number of input channels. 208 | :return: an nn.Module for normalization. 209 | """ 210 | return GroupNorm32(32, channels) 211 | 212 | 213 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 214 | class SiLU(nn.Module): 215 | def forward(self, x): 216 | return x * torch.sigmoid(x) 217 | 218 | 219 | class GroupNorm32(nn.GroupNorm): 220 | def forward(self, x): 221 | return super().forward(x.float()).type(x.dtype) 222 | 223 | def conv_nd(dims, *args, **kwargs): 224 | """ 225 | Create a 1D, 2D, or 3D convolution module. 226 | """ 227 | if dims == 1: 228 | return nn.Conv1d(*args, **kwargs) 229 | elif dims == 2: 230 | return nn.Conv2d(*args, **kwargs) 231 | elif dims == 3: 232 | return nn.Conv3d(*args, **kwargs) 233 | raise ValueError(f"unsupported dimensions: {dims}") 234 | 235 | 236 | def linear(*args, **kwargs): 237 | """ 238 | Create a linear module. 239 | """ 240 | return nn.Linear(*args, **kwargs) 241 | 242 | 243 | def avg_pool_nd(dims, *args, **kwargs): 244 | """ 245 | Create a 1D, 2D, or 3D average pooling module. 
246 | """ 247 | if dims == 1: 248 | return nn.AvgPool1d(*args, **kwargs) 249 | elif dims == 2: 250 | return nn.AvgPool2d(*args, **kwargs) 251 | elif dims == 3: 252 | return nn.AvgPool3d(*args, **kwargs) 253 | raise ValueError(f"unsupported dimensions: {dims}") 254 | 255 | 256 | class HybridConditioner(nn.Module): 257 | 258 | def __init__(self, c_concat_config, c_crossattn_config): 259 | super().__init__() 260 | self.concat_conditioner = instantiate_from_config(c_concat_config) 261 | self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) 262 | 263 | def forward(self, c_concat, c_crossattn): 264 | c_concat = self.concat_conditioner(c_concat) 265 | c_crossattn = self.crossattn_conditioner(c_crossattn) 266 | return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} 267 | 268 | 269 | def noise_like(shape, device, repeat=False): 270 | repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) 271 | noise = lambda: torch.randn(shape, device=device) 272 | return repeat_noise() if repeat else noise() 273 | -------------------------------------------------------------------------------- /ldm/modules/distributions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszy98/PLACE/4594350491df5c7489eb4edf0802f303692dd530/ldm/modules/distributions/__init__.py -------------------------------------------------------------------------------- /ldm/modules/distributions/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszy98/PLACE/4594350491df5c7489eb4edf0802f303692dd530/ldm/modules/distributions/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/distributions/__pycache__/distributions.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszy98/PLACE/4594350491df5c7489eb4edf0802f303692dd530/ldm/modules/distributions/__pycache__/distributions.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/distributions/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 34 | 35 | def sample(self): 36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) 37 | return x 38 | 39 | def kl(self, other=None): 40 | if self.deterministic: 41 | return torch.Tensor([0.]) 42 | else: 43 | 
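# closed-form KL between diagonal Gaussians: against the standard normal N(0, I) when `other` is None,
# i.e. 0.5 * sum(mu^2 + sigma^2 - 1 - log sigma^2) over the non-batch dims, otherwise against `other`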
if other is None: 44 | return 0.5 * torch.sum(torch.pow(self.mean, 2) 45 | + self.var - 1.0 - self.logvar, 46 | dim=[1, 2, 3]) 47 | else: 48 | return 0.5 * torch.sum( 49 | torch.pow(self.mean - other.mean, 2) / other.var 50 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 51 | dim=[1, 2, 3]) 52 | 53 | def nll(self, sample, dims=[1,2,3]): 54 | if self.deterministic: 55 | return torch.Tensor([0.]) 56 | logtwopi = np.log(2.0 * np.pi) 57 | return 0.5 * torch.sum( 58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 59 | dim=dims) 60 | 61 | def mode(self): 62 | return self.mean 63 | 64 | 65 | def normal_kl(mean1, logvar1, mean2, logvar2): 66 | """ 67 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 68 | Compute the KL divergence between two gaussians. 69 | Shapes are automatically broadcasted, so batches can be compared to 70 | scalars, among other use cases. 71 | """ 72 | tensor = None 73 | for obj in (mean1, logvar1, mean2, logvar2): 74 | if isinstance(obj, torch.Tensor): 75 | tensor = obj 76 | break 77 | assert tensor is not None, "at least one argument must be a Tensor" 78 | 79 | # Force variances to be Tensors. Broadcasting helps convert scalars to 80 | # Tensors, but it does not work for torch.exp(). 81 | logvar1, logvar2 = [ 82 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 83 | for x in (logvar1, logvar2) 84 | ] 85 | 86 | return 0.5 * ( 87 | -1.0 88 | + logvar2 89 | - logvar1 90 | + torch.exp(logvar1 - logvar2) 91 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 92 | ) 93 | -------------------------------------------------------------------------------- /ldm/modules/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError('Decay must be between 0 and 1') 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_upates 14 | else torch.tensor(-1, dtype=torch.int)) 15 | 16 | for name, p in model.named_parameters(): 17 | if p.requires_grad: 18 | # remove as '.'-character is not allowed in buffers 19 | s_name = name.replace('.', '') 20 | self.m_name2s_name.update({name: s_name}) 21 | self.register_buffer(s_name, p.clone().detach().data) 22 | 23 | self.collected_params = [] 24 | 25 | def reset_num_updates(self): 26 | del self.num_updates 27 | self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int)) 28 | 29 | def forward(self, model): 30 | decay = self.decay 31 | 32 | if self.num_updates >= 0: 33 | self.num_updates += 1 34 | decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) 35 | 36 | one_minus_decay = 1.0 - decay 37 | 38 | with torch.no_grad(): 39 | m_param = dict(model.named_parameters()) 40 | shadow_params = dict(self.named_buffers()) 41 | 42 | for key in m_param: 43 | if m_param[key].requires_grad: 44 | sname = self.m_name2s_name[key] 45 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 46 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) 47 | else: 48 | assert not key in self.m_name2s_name 49 | 50 | def copy_to(self, model): 51 | m_param = dict(model.named_parameters()) 52 | 
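# overwrite the model's trainable parameters with the EMA shadow weights kept as buffers;
# typically bracketed by store()/restore() so the original weights can be recovered after evaluation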
shadow_params = dict(self.named_buffers()) 53 | for key in m_param: 54 | if m_param[key].requires_grad: 55 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 56 | else: 57 | assert not key in self.m_name2s_name 58 | 59 | def store(self, parameters): 60 | """ 61 | Save the current parameters for restoring later. 62 | Args: 63 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 64 | temporarily stored. 65 | """ 66 | self.collected_params = [param.clone() for param in parameters] 67 | 68 | def restore(self, parameters): 69 | """ 70 | Restore the parameters stored with the `store` method. 71 | Useful to validate the model with EMA parameters without affecting the 72 | original optimization process. Store the parameters before the 73 | `copy_to` method. After validation (or model saving), use this to 74 | restore the former parameters. 75 | Args: 76 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 77 | updated with the stored parameters. 78 | """ 79 | for c_param, param in zip(self.collected_params, parameters): 80 | param.data.copy_(c_param.data) 81 | -------------------------------------------------------------------------------- /ldm/modules/encoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszy98/PLACE/4594350491df5c7489eb4edf0802f303692dd530/ldm/modules/encoders/__init__.py -------------------------------------------------------------------------------- /ldm/modules/encoders/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszy98/PLACE/4594350491df5c7489eb4edf0802f303692dd530/ldm/modules/encoders/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/encoders/__pycache__/modules.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszy98/PLACE/4594350491df5c7489eb4edf0802f303692dd530/ldm/modules/encoders/__pycache__/modules.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/encoders/modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.utils.checkpoint import checkpoint 4 | 5 | from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel 6 | 7 | import open_clip 8 | from ldm.util import default, count_params 9 | 10 | 11 | class AbstractEncoder(nn.Module): 12 | def __init__(self): 13 | super().__init__() 14 | 15 | def encode(self, *args, **kwargs): 16 | raise NotImplementedError 17 | 18 | 19 | class IdentityEncoder(AbstractEncoder): 20 | 21 | def encode(self, x): 22 | return x 23 | 24 | 25 | class ClassEmbedder(nn.Module): 26 | def __init__(self, embed_dim, n_classes=1000, key='class', ucg_rate=0.1): 27 | super().__init__() 28 | self.key = key 29 | self.embedding = nn.Embedding(n_classes, embed_dim) 30 | self.n_classes = n_classes 31 | self.ucg_rate = ucg_rate 32 | 33 | def forward(self, batch, key=None, disable_dropout=False): 34 | if key is None: 35 | key = self.key 36 | # this is for use in crossattn 37 | c = batch[key][:, None] 38 | if self.ucg_rate > 0. and not disable_dropout: 39 | mask = 1. 
- torch.bernoulli(torch.ones_like(c) * self.ucg_rate) 40 | c = mask * c + (1-mask) * torch.ones_like(c)*(self.n_classes-1) 41 | c = c.long() 42 | c = self.embedding(c) 43 | return c 44 | 45 | def get_unconditional_conditioning(self, bs, device="cuda"): 46 | uc_class = self.n_classes - 1 # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000) 47 | uc = torch.ones((bs,), device=device) * uc_class 48 | uc = {self.key: uc} 49 | return uc 50 | 51 | 52 | def disabled_train(self, mode=True): 53 | """Overwrite model.train with this function to make sure train/eval mode 54 | does not change anymore.""" 55 | return self 56 | 57 | 58 | class FrozenT5Embedder(AbstractEncoder): 59 | """Uses the T5 transformer encoder for text""" 60 | def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl 61 | super().__init__() 62 | self.tokenizer = T5Tokenizer.from_pretrained(version) 63 | self.transformer = T5EncoderModel.from_pretrained(version) 64 | self.device = device 65 | self.max_length = max_length # TODO: typical value? 66 | if freeze: 67 | self.freeze() 68 | 69 | def freeze(self): 70 | self.transformer = self.transformer.eval() 71 | #self.train = disabled_train 72 | for param in self.parameters(): 73 | param.requires_grad = False 74 | 75 | def forward(self, text): 76 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, 77 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt") 78 | tokens = batch_encoding["input_ids"].to(self.device) 79 | outputs = self.transformer(input_ids=tokens) 80 | 81 | z = outputs.last_hidden_state 82 | return z 83 | 84 | def encode(self, text): 85 | return self(text) 86 | 87 | 88 | class FrozenCLIPEmbedder(AbstractEncoder): 89 | """Uses the CLIP transformer encoder for text (from huggingface)""" 90 | LAYERS = [ 91 | "last", 92 | "pooled", 93 | "hidden" 94 | ] 95 | def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, 96 | freeze=True, layer="last", layer_idx=None): # clip-vit-base-patch32 97 | super().__init__() 98 | assert layer in self.LAYERS 99 | self.tokenizer = CLIPTokenizer.from_pretrained(version) 100 | self.transformer = CLIPTextModel.from_pretrained(version) 101 | self.device = device 102 | self.max_length = max_length 103 | if freeze: 104 | self.freeze() 105 | self.layer = layer 106 | self.layer_idx = layer_idx 107 | if layer == "hidden": 108 | assert layer_idx is not None 109 | assert 0 <= abs(layer_idx) <= 12 110 | 111 | def freeze(self): 112 | self.transformer = self.transformer.eval() 113 | #self.train = disabled_train 114 | for param in self.parameters(): 115 | param.requires_grad = False 116 | 117 | def forward(self, text): 118 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, 119 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt") 120 | tokens = batch_encoding["input_ids"].to(self.device) 121 | outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden") 122 | if self.layer == "last": 123 | z = outputs.last_hidden_state 124 | elif self.layer == "pooled": 125 | z = outputs.pooler_output[:, None, :] 126 | else: 127 | z = outputs.hidden_states[self.layer_idx] 128 | return z 129 | 130 | def encode(self, text): 131 | return self(text) 132 | 133 | 134 | class MYFrozenCLIPEmbedder(AbstractEncoder): 135 | """Uses the CLIP transformer encoder for text (from 
huggingface)""" 136 | LAYERS = [ 137 | "last", 138 | "pooled", 139 | "hidden" 140 | ] 141 | def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, 142 | freeze=True, layer="last", layer_idx=None): # clip-vit-base-patch32 143 | super().__init__() 144 | assert layer in self.LAYERS 145 | self.tokenizer = CLIPTokenizer.from_pretrained(version) 146 | self.transformer = CLIPTextModel.from_pretrained(version) 147 | self.device = device 148 | self.max_length = max_length 149 | if freeze: 150 | self.freeze() 151 | self.layer = layer 152 | self.layer_idx = layer_idx 153 | if layer == "hidden": 154 | assert layer_idx is not None 155 | assert 0 <= abs(layer_idx) <= 12 156 | 157 | def freeze(self): 158 | self.transformer = self.transformer.eval() 159 | #self.train = disabled_train 160 | for param in self.parameters(): 161 | param.requires_grad = False 162 | 163 | def forward(self, tokens): 164 | #batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, 165 | # return_overflowing_tokens=False, padding="max_length", return_tensors="pt") 166 | #print('realb',tokens.shape,tokens.dtype) 167 | tokens = tokens.to(self.device) 168 | outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden") 169 | if self.layer == "last": 170 | z = outputs.last_hidden_state 171 | elif self.layer == "pooled": 172 | z = outputs.pooler_output[:, None, :] 173 | else: 174 | z = outputs.hidden_states[self.layer_idx] 175 | #print('reala',z.shape,z.dtype) 176 | return z 177 | 178 | def encode(self, tokens): 179 | return self(tokens) 180 | 181 | 182 | 183 | 184 | class FrozenOpenCLIPEmbedder(AbstractEncoder): 185 | """ 186 | Uses the OpenCLIP transformer encoder for text 187 | """ 188 | LAYERS = [ 189 | #"pooled", 190 | "last", 191 | "penultimate" 192 | ] 193 | def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77, 194 | freeze=True, layer="last"): 195 | super().__init__() 196 | assert layer in self.LAYERS 197 | model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=version) 198 | del model.visual 199 | self.model = model 200 | 201 | self.device = device 202 | self.max_length = max_length 203 | if freeze: 204 | self.freeze() 205 | self.layer = layer 206 | if self.layer == "last": 207 | self.layer_idx = 0 208 | elif self.layer == "penultimate": 209 | self.layer_idx = 1 210 | else: 211 | raise NotImplementedError() 212 | 213 | def freeze(self): 214 | self.model = self.model.eval() 215 | for param in self.parameters(): 216 | param.requires_grad = False 217 | 218 | def forward(self, text): 219 | tokens = open_clip.tokenize(text) 220 | z = self.encode_with_transformer(tokens.to(self.device)) 221 | return z 222 | 223 | def encode_with_transformer(self, text): 224 | x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model] 225 | x = x + self.model.positional_embedding 226 | x = x.permute(1, 0, 2) # NLD -> LND 227 | x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask) 228 | x = x.permute(1, 0, 2) # LND -> NLD 229 | x = self.model.ln_final(x) 230 | return x 231 | 232 | def text_transformer_forward(self, x: torch.Tensor, attn_mask = None): 233 | for i, r in enumerate(self.model.transformer.resblocks): 234 | if i == len(self.model.transformer.resblocks) - self.layer_idx: 235 | break 236 | if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting(): 237 | x = checkpoint(r, x, attn_mask) 238 | else: 239 | x = r(x, 
attn_mask=attn_mask) 240 | return x 241 | 242 | def encode(self, text): 243 | return self(text) 244 | 245 | 246 | class FrozenCLIPT5Encoder(AbstractEncoder): 247 | def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device="cuda", 248 | clip_max_length=77, t5_max_length=77): 249 | super().__init__() 250 | self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length) 251 | self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length) 252 | print(f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder)*1.e-6:.2f} M parameters, " 253 | f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder)*1.e-6:.2f} M params.") 254 | 255 | def encode(self, text): 256 | return self(text) 257 | 258 | def forward(self, text): 259 | clip_z = self.clip_encoder.encode(text) 260 | t5_z = self.t5_encoder.encode(text) 261 | return [clip_z, t5_z] 262 | 263 | 264 | 265 | 266 | 267 | class MYFrozenOpenCLIPEmbedder(AbstractEncoder): 268 | """ 269 | Uses the OpenCLIP transformer encoder for text 270 | """ 271 | LAYERS = [ 272 | #"pooled", 273 | "last", 274 | "penultimate" 275 | ] 276 | def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77, 277 | freeze=True, layer="penultimate"): 278 | super().__init__() 279 | assert layer in self.LAYERS 280 | model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cuda'), pretrained=version) 281 | del model.visual 282 | self.model = model 283 | 284 | self.device = device 285 | self.max_length = max_length 286 | if freeze: 287 | self.freeze() 288 | self.layer = layer 289 | if self.layer == "last": 290 | self.layer_idx = 0 291 | elif self.layer == "penultimate": 292 | self.layer_idx = 1 293 | else: 294 | raise NotImplementedError() 295 | 296 | def freeze(self): 297 | self.model = self.model.eval() 298 | for param in self.parameters(): 299 | param.requires_grad = False 300 | 301 | def forward(self, text): 302 | tokens = open_clip.tokenize(text) 303 | z = self.encode_with_transformer(tokens.to(self.device)) 304 | return (z[torch.arange(z.shape[0]), 0], z[torch.arange(z.shape[0]), tokens.argmax(dim=-1)]) 305 | 306 | def encode_with_transformer(self, text): 307 | #print(text.device,self.device, self.model.device) 308 | x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model] 309 | x = x + self.model.positional_embedding 310 | x = x.permute(1, 0, 2) # NLD -> LND 311 | x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask) 312 | x = x.permute(1, 0, 2) # LND -> NLD 313 | x = self.model.ln_final(x) 314 | return x 315 | 316 | def text_transformer_forward(self, x: torch.Tensor, attn_mask = None): 317 | for i, r in enumerate(self.model.transformer.resblocks): 318 | if i == len(self.model.transformer.resblocks) - self.layer_idx: 319 | break 320 | if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting(): 321 | x = checkpoint(r, x, attn_mask) 322 | else: 323 | x = r(x, attn_mask=attn_mask) 324 | return x 325 | 326 | def encode(self, text): 327 | return self(text) 328 | -------------------------------------------------------------------------------- /ldm/modules/image_degradation/__init__.py: -------------------------------------------------------------------------------- 1 | from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr 2 | from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as 
degradation_fn_bsr_light 3 | -------------------------------------------------------------------------------- /ldm/modules/image_degradation/utils/test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszy98/PLACE/4594350491df5c7489eb4edf0802f303692dd530/ldm/modules/image_degradation/utils/test.png -------------------------------------------------------------------------------- /ldm/modules/midas/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszy98/PLACE/4594350491df5c7489eb4edf0802f303692dd530/ldm/modules/midas/__init__.py -------------------------------------------------------------------------------- /ldm/modules/midas/api.py: -------------------------------------------------------------------------------- 1 | # based on https://github.com/isl-org/MiDaS 2 | 3 | import cv2 4 | import torch 5 | import torch.nn as nn 6 | from torchvision.transforms import Compose 7 | 8 | from ldm.modules.midas.midas.dpt_depth import DPTDepthModel 9 | from ldm.modules.midas.midas.midas_net import MidasNet 10 | from ldm.modules.midas.midas.midas_net_custom import MidasNet_small 11 | from ldm.modules.midas.midas.transforms import Resize, NormalizeImage, PrepareForNet 12 | 13 | 14 | ISL_PATHS = { 15 | "dpt_large": "midas_models/dpt_large-midas-2f21e586.pt", 16 | "dpt_hybrid": "midas_models/dpt_hybrid-midas-501f0c75.pt", 17 | "midas_v21": "", 18 | "midas_v21_small": "", 19 | } 20 | 21 | 22 | def disabled_train(self, mode=True): 23 | """Overwrite model.train with this function to make sure train/eval mode 24 | does not change anymore.""" 25 | return self 26 | 27 | 28 | def load_midas_transform(model_type): 29 | # https://github.com/isl-org/MiDaS/blob/master/run.py 30 | # load transform only 31 | if model_type == "dpt_large": # DPT-Large 32 | net_w, net_h = 384, 384 33 | resize_mode = "minimal" 34 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 35 | 36 | elif model_type == "dpt_hybrid": # DPT-Hybrid 37 | net_w, net_h = 384, 384 38 | resize_mode = "minimal" 39 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 40 | 41 | elif model_type == "midas_v21": 42 | net_w, net_h = 384, 384 43 | resize_mode = "upper_bound" 44 | normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 45 | 46 | elif model_type == "midas_v21_small": 47 | net_w, net_h = 256, 256 48 | resize_mode = "upper_bound" 49 | normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 50 | 51 | else: 52 | assert False, f"model_type '{model_type}' not implemented, use: --model_type large" 53 | 54 | transform = Compose( 55 | [ 56 | Resize( 57 | net_w, 58 | net_h, 59 | resize_target=None, 60 | keep_aspect_ratio=True, 61 | ensure_multiple_of=32, 62 | resize_method=resize_mode, 63 | image_interpolation_method=cv2.INTER_CUBIC, 64 | ), 65 | normalization, 66 | PrepareForNet(), 67 | ] 68 | ) 69 | 70 | return transform 71 | 72 | 73 | def load_model(model_type): 74 | # https://github.com/isl-org/MiDaS/blob/master/run.py 75 | # load network 76 | model_path = ISL_PATHS[model_type] 77 | if model_type == "dpt_large": # DPT-Large 78 | model = DPTDepthModel( 79 | path=model_path, 80 | backbone="vitl16_384", 81 | non_negative=True, 82 | ) 83 | net_w, net_h = 384, 384 84 | resize_mode = "minimal" 85 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 86 | 87 | elif model_type == "dpt_hybrid": 
# DPT-Hybrid 88 | model = DPTDepthModel( 89 | path=model_path, 90 | backbone="vitb_rn50_384", 91 | non_negative=True, 92 | ) 93 | net_w, net_h = 384, 384 94 | resize_mode = "minimal" 95 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 96 | 97 | elif model_type == "midas_v21": 98 | model = MidasNet(model_path, non_negative=True) 99 | net_w, net_h = 384, 384 100 | resize_mode = "upper_bound" 101 | normalization = NormalizeImage( 102 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 103 | ) 104 | 105 | elif model_type == "midas_v21_small": 106 | model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True, 107 | non_negative=True, blocks={'expand': True}) 108 | net_w, net_h = 256, 256 109 | resize_mode = "upper_bound" 110 | normalization = NormalizeImage( 111 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 112 | ) 113 | 114 | else: 115 | print(f"model_type '{model_type}' not implemented, use: --model_type large") 116 | assert False 117 | 118 | transform = Compose( 119 | [ 120 | Resize( 121 | net_w, 122 | net_h, 123 | resize_target=None, 124 | keep_aspect_ratio=True, 125 | ensure_multiple_of=32, 126 | resize_method=resize_mode, 127 | image_interpolation_method=cv2.INTER_CUBIC, 128 | ), 129 | normalization, 130 | PrepareForNet(), 131 | ] 132 | ) 133 | 134 | return model.eval(), transform 135 | 136 | 137 | class MiDaSInference(nn.Module): 138 | MODEL_TYPES_TORCH_HUB = [ 139 | "DPT_Large", 140 | "DPT_Hybrid", 141 | "MiDaS_small" 142 | ] 143 | MODEL_TYPES_ISL = [ 144 | "dpt_large", 145 | "dpt_hybrid", 146 | "midas_v21", 147 | "midas_v21_small", 148 | ] 149 | 150 | def __init__(self, model_type): 151 | super().__init__() 152 | assert (model_type in self.MODEL_TYPES_ISL) 153 | model, _ = load_model(model_type) 154 | self.model = model 155 | self.model.train = disabled_train 156 | 157 | def forward(self, x): 158 | # x in 0..1 as produced by calling self.transform on a 0..1 float64 numpy array 159 | # NOTE: we expect that the correct transform has been called during dataloading. 160 | with torch.no_grad(): 161 | prediction = self.model(x) 162 | prediction = torch.nn.functional.interpolate( 163 | prediction.unsqueeze(1), 164 | size=x.shape[2:], 165 | mode="bicubic", 166 | align_corners=False, 167 | ) 168 | assert prediction.shape == (x.shape[0], 1, x.shape[2], x.shape[3]) 169 | return prediction 170 | 171 | -------------------------------------------------------------------------------- /ldm/modules/midas/midas/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszy98/PLACE/4594350491df5c7489eb4edf0802f303692dd530/ldm/modules/midas/midas/__init__.py -------------------------------------------------------------------------------- /ldm/modules/midas/midas/base_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class BaseModel(torch.nn.Module): 5 | def load(self, path): 6 | """Load model from file. 
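Checkpoints that also bundle optimizer state are unwrapped via their "model" key; the state dict is loaded with a CPU map_location.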
7 | 8 | Args: 9 | path (str): file path 10 | """ 11 | parameters = torch.load(path, map_location=torch.device('cpu')) 12 | 13 | if "optimizer" in parameters: 14 | parameters = parameters["model"] 15 | 16 | self.load_state_dict(parameters) 17 | -------------------------------------------------------------------------------- /ldm/modules/midas/midas/blocks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .vit import ( 5 | _make_pretrained_vitb_rn50_384, 6 | _make_pretrained_vitl16_384, 7 | _make_pretrained_vitb16_384, 8 | forward_vit, 9 | ) 10 | 11 | def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",): 12 | if backbone == "vitl16_384": 13 | pretrained = _make_pretrained_vitl16_384( 14 | use_pretrained, hooks=hooks, use_readout=use_readout 15 | ) 16 | scratch = _make_scratch( 17 | [256, 512, 1024, 1024], features, groups=groups, expand=expand 18 | ) # ViT-L/16 - 85.0% Top1 (backbone) 19 | elif backbone == "vitb_rn50_384": 20 | pretrained = _make_pretrained_vitb_rn50_384( 21 | use_pretrained, 22 | hooks=hooks, 23 | use_vit_only=use_vit_only, 24 | use_readout=use_readout, 25 | ) 26 | scratch = _make_scratch( 27 | [256, 512, 768, 768], features, groups=groups, expand=expand 28 | ) # ViT-H/16 - 85.0% Top1 (backbone) 29 | elif backbone == "vitb16_384": 30 | pretrained = _make_pretrained_vitb16_384( 31 | use_pretrained, hooks=hooks, use_readout=use_readout 32 | ) 33 | scratch = _make_scratch( 34 | [96, 192, 384, 768], features, groups=groups, expand=expand 35 | ) # ViT-B/16 - 84.6% Top1 (backbone) 36 | elif backbone == "resnext101_wsl": 37 | pretrained = _make_pretrained_resnext101_wsl(use_pretrained) 38 | scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # efficientnet_lite3 39 | elif backbone == "efficientnet_lite3": 40 | pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable) 41 | scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3 42 | else: 43 | print(f"Backbone '{backbone}' not implemented") 44 | assert False 45 | 46 | return pretrained, scratch 47 | 48 | 49 | def _make_scratch(in_shape, out_shape, groups=1, expand=False): 50 | scratch = nn.Module() 51 | 52 | out_shape1 = out_shape 53 | out_shape2 = out_shape 54 | out_shape3 = out_shape 55 | out_shape4 = out_shape 56 | if expand==True: 57 | out_shape1 = out_shape 58 | out_shape2 = out_shape*2 59 | out_shape3 = out_shape*4 60 | out_shape4 = out_shape*8 61 | 62 | scratch.layer1_rn = nn.Conv2d( 63 | in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups 64 | ) 65 | scratch.layer2_rn = nn.Conv2d( 66 | in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups 67 | ) 68 | scratch.layer3_rn = nn.Conv2d( 69 | in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups 70 | ) 71 | scratch.layer4_rn = nn.Conv2d( 72 | in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups 73 | ) 74 | 75 | return scratch 76 | 77 | 78 | def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False): 79 | efficientnet = torch.hub.load( 80 | "rwightman/gen-efficientnet-pytorch", 81 | "tf_efficientnet_lite3", 82 | pretrained=use_pretrained, 83 | exportable=exportable 84 | ) 85 | return _make_efficientnet_backbone(efficientnet) 86 | 87 | 88 
| def _make_efficientnet_backbone(effnet): 89 | pretrained = nn.Module() 90 | 91 | pretrained.layer1 = nn.Sequential( 92 | effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2] 93 | ) 94 | pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3]) 95 | pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5]) 96 | pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9]) 97 | 98 | return pretrained 99 | 100 | 101 | def _make_resnet_backbone(resnet): 102 | pretrained = nn.Module() 103 | pretrained.layer1 = nn.Sequential( 104 | resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1 105 | ) 106 | 107 | pretrained.layer2 = resnet.layer2 108 | pretrained.layer3 = resnet.layer3 109 | pretrained.layer4 = resnet.layer4 110 | 111 | return pretrained 112 | 113 | 114 | def _make_pretrained_resnext101_wsl(use_pretrained): 115 | resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl") 116 | return _make_resnet_backbone(resnet) 117 | 118 | 119 | 120 | class Interpolate(nn.Module): 121 | """Interpolation module. 122 | """ 123 | 124 | def __init__(self, scale_factor, mode, align_corners=False): 125 | """Init. 126 | 127 | Args: 128 | scale_factor (float): scaling 129 | mode (str): interpolation mode 130 | """ 131 | super(Interpolate, self).__init__() 132 | 133 | self.interp = nn.functional.interpolate 134 | self.scale_factor = scale_factor 135 | self.mode = mode 136 | self.align_corners = align_corners 137 | 138 | def forward(self, x): 139 | """Forward pass. 140 | 141 | Args: 142 | x (tensor): input 143 | 144 | Returns: 145 | tensor: interpolated data 146 | """ 147 | 148 | x = self.interp( 149 | x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners 150 | ) 151 | 152 | return x 153 | 154 | 155 | class ResidualConvUnit(nn.Module): 156 | """Residual convolution module. 157 | """ 158 | 159 | def __init__(self, features): 160 | """Init. 161 | 162 | Args: 163 | features (int): number of features 164 | """ 165 | super().__init__() 166 | 167 | self.conv1 = nn.Conv2d( 168 | features, features, kernel_size=3, stride=1, padding=1, bias=True 169 | ) 170 | 171 | self.conv2 = nn.Conv2d( 172 | features, features, kernel_size=3, stride=1, padding=1, bias=True 173 | ) 174 | 175 | self.relu = nn.ReLU(inplace=True) 176 | 177 | def forward(self, x): 178 | """Forward pass. 179 | 180 | Args: 181 | x (tensor): input 182 | 183 | Returns: 184 | tensor: output 185 | """ 186 | out = self.relu(x) 187 | out = self.conv1(out) 188 | out = self.relu(out) 189 | out = self.conv2(out) 190 | 191 | return out + x 192 | 193 | 194 | class FeatureFusionBlock(nn.Module): 195 | """Feature fusion block. 196 | """ 197 | 198 | def __init__(self, features): 199 | """Init. 200 | 201 | Args: 202 | features (int): number of features 203 | """ 204 | super(FeatureFusionBlock, self).__init__() 205 | 206 | self.resConfUnit1 = ResidualConvUnit(features) 207 | self.resConfUnit2 = ResidualConvUnit(features) 208 | 209 | def forward(self, *xs): 210 | """Forward pass. 211 | 212 | Returns: 213 | tensor: output 214 | """ 215 | output = xs[0] 216 | 217 | if len(xs) == 2: 218 | output += self.resConfUnit1(xs[1]) 219 | 220 | output = self.resConfUnit2(output) 221 | 222 | output = nn.functional.interpolate( 223 | output, scale_factor=2, mode="bilinear", align_corners=True 224 | ) 225 | 226 | return output 227 | 228 | 229 | 230 | 231 | class ResidualConvUnit_custom(nn.Module): 232 | """Residual convolution module. 233 | """ 234 | 235 | def __init__(self, features, activation, bn): 236 | """Init. 
237 | 238 | Args: 239 | features (int): number of features 240 | """ 241 | super().__init__() 242 | 243 | self.bn = bn 244 | 245 | self.groups=1 246 | 247 | self.conv1 = nn.Conv2d( 248 | features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups 249 | ) 250 | 251 | self.conv2 = nn.Conv2d( 252 | features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups 253 | ) 254 | 255 | if self.bn==True: 256 | self.bn1 = nn.BatchNorm2d(features) 257 | self.bn2 = nn.BatchNorm2d(features) 258 | 259 | self.activation = activation 260 | 261 | self.skip_add = nn.quantized.FloatFunctional() 262 | 263 | def forward(self, x): 264 | """Forward pass. 265 | 266 | Args: 267 | x (tensor): input 268 | 269 | Returns: 270 | tensor: output 271 | """ 272 | 273 | out = self.activation(x) 274 | out = self.conv1(out) 275 | if self.bn==True: 276 | out = self.bn1(out) 277 | 278 | out = self.activation(out) 279 | out = self.conv2(out) 280 | if self.bn==True: 281 | out = self.bn2(out) 282 | 283 | if self.groups > 1: 284 | out = self.conv_merge(out) 285 | 286 | return self.skip_add.add(out, x) 287 | 288 | # return out + x 289 | 290 | 291 | class FeatureFusionBlock_custom(nn.Module): 292 | """Feature fusion block. 293 | """ 294 | 295 | def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True): 296 | """Init. 297 | 298 | Args: 299 | features (int): number of features 300 | """ 301 | super(FeatureFusionBlock_custom, self).__init__() 302 | 303 | self.deconv = deconv 304 | self.align_corners = align_corners 305 | 306 | self.groups=1 307 | 308 | self.expand = expand 309 | out_features = features 310 | if self.expand==True: 311 | out_features = features//2 312 | 313 | self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1) 314 | 315 | self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn) 316 | self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn) 317 | 318 | self.skip_add = nn.quantized.FloatFunctional() 319 | 320 | def forward(self, *xs): 321 | """Forward pass. 
322 | 323 | Returns: 324 | tensor: output 325 | """ 326 | output = xs[0] 327 | 328 | if len(xs) == 2: 329 | res = self.resConfUnit1(xs[1]) 330 | output = self.skip_add.add(output, res) 331 | # output += res 332 | 333 | output = self.resConfUnit2(output) 334 | 335 | output = nn.functional.interpolate( 336 | output, scale_factor=2, mode="bilinear", align_corners=self.align_corners 337 | ) 338 | 339 | output = self.out_conv(output) 340 | 341 | return output 342 | 343 | -------------------------------------------------------------------------------- /ldm/modules/midas/midas/dpt_depth.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .base_model import BaseModel 6 | from .blocks import ( 7 | FeatureFusionBlock, 8 | FeatureFusionBlock_custom, 9 | Interpolate, 10 | _make_encoder, 11 | forward_vit, 12 | ) 13 | 14 | 15 | def _make_fusion_block(features, use_bn): 16 | return FeatureFusionBlock_custom( 17 | features, 18 | nn.ReLU(False), 19 | deconv=False, 20 | bn=use_bn, 21 | expand=False, 22 | align_corners=True, 23 | ) 24 | 25 | 26 | class DPT(BaseModel): 27 | def __init__( 28 | self, 29 | head, 30 | features=256, 31 | backbone="vitb_rn50_384", 32 | readout="project", 33 | channels_last=False, 34 | use_bn=False, 35 | ): 36 | 37 | super(DPT, self).__init__() 38 | 39 | self.channels_last = channels_last 40 | 41 | hooks = { 42 | "vitb_rn50_384": [0, 1, 8, 11], 43 | "vitb16_384": [2, 5, 8, 11], 44 | "vitl16_384": [5, 11, 17, 23], 45 | } 46 | 47 | # Instantiate backbone and reassemble blocks 48 | self.pretrained, self.scratch = _make_encoder( 49 | backbone, 50 | features, 51 | False, # Set to true of you want to train from scratch, uses ImageNet weights 52 | groups=1, 53 | expand=False, 54 | exportable=False, 55 | hooks=hooks[backbone], 56 | use_readout=readout, 57 | ) 58 | 59 | self.scratch.refinenet1 = _make_fusion_block(features, use_bn) 60 | self.scratch.refinenet2 = _make_fusion_block(features, use_bn) 61 | self.scratch.refinenet3 = _make_fusion_block(features, use_bn) 62 | self.scratch.refinenet4 = _make_fusion_block(features, use_bn) 63 | 64 | self.scratch.output_conv = head 65 | 66 | 67 | def forward(self, x): 68 | if self.channels_last == True: 69 | x.contiguous(memory_format=torch.channels_last) 70 | 71 | layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x) 72 | 73 | layer_1_rn = self.scratch.layer1_rn(layer_1) 74 | layer_2_rn = self.scratch.layer2_rn(layer_2) 75 | layer_3_rn = self.scratch.layer3_rn(layer_3) 76 | layer_4_rn = self.scratch.layer4_rn(layer_4) 77 | 78 | path_4 = self.scratch.refinenet4(layer_4_rn) 79 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 80 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 81 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 82 | 83 | out = self.scratch.output_conv(path_1) 84 | 85 | return out 86 | 87 | 88 | class DPTDepthModel(DPT): 89 | def __init__(self, path=None, non_negative=True, **kwargs): 90 | features = kwargs["features"] if "features" in kwargs else 256 91 | 92 | head = nn.Sequential( 93 | nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1), 94 | Interpolate(scale_factor=2, mode="bilinear", align_corners=True), 95 | nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1), 96 | nn.ReLU(True), 97 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 98 | nn.ReLU(True) if non_negative else nn.Identity(), 99 | nn.Identity(), 100 | ) 101 | 102 | 
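# the DPT parent builds the ViT backbone and feature-fusion blocks; `head` above maps the fused
# features to a single-channel depth map, kept non-negative by the final ReLU when `non_negative` is set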
super().__init__(head, **kwargs) 103 | 104 | if path is not None: 105 | self.load(path) 106 | 107 | def forward(self, x): 108 | return super().forward(x).squeeze(dim=1) 109 | 110 | -------------------------------------------------------------------------------- /ldm/modules/midas/midas/midas_net.py: -------------------------------------------------------------------------------- 1 | """MidashNet: Network for monocular depth estimation trained by mixing several datasets. 2 | This file contains code that is adapted from 3 | https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | 8 | from .base_model import BaseModel 9 | from .blocks import FeatureFusionBlock, Interpolate, _make_encoder 10 | 11 | 12 | class MidasNet(BaseModel): 13 | """Network for monocular depth estimation. 14 | """ 15 | 16 | def __init__(self, path=None, features=256, non_negative=True): 17 | """Init. 18 | 19 | Args: 20 | path (str, optional): Path to saved model. Defaults to None. 21 | features (int, optional): Number of features. Defaults to 256. 22 | backbone (str, optional): Backbone network for encoder. Defaults to resnet50 23 | """ 24 | print("Loading weights: ", path) 25 | 26 | super(MidasNet, self).__init__() 27 | 28 | use_pretrained = False if path is None else True 29 | 30 | self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained) 31 | 32 | self.scratch.refinenet4 = FeatureFusionBlock(features) 33 | self.scratch.refinenet3 = FeatureFusionBlock(features) 34 | self.scratch.refinenet2 = FeatureFusionBlock(features) 35 | self.scratch.refinenet1 = FeatureFusionBlock(features) 36 | 37 | self.scratch.output_conv = nn.Sequential( 38 | nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1), 39 | Interpolate(scale_factor=2, mode="bilinear"), 40 | nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1), 41 | nn.ReLU(True), 42 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 43 | nn.ReLU(True) if non_negative else nn.Identity(), 44 | ) 45 | 46 | if path: 47 | self.load(path) 48 | 49 | def forward(self, x): 50 | """Forward pass. 51 | 52 | Args: 53 | x (tensor): input data (image) 54 | 55 | Returns: 56 | tensor: depth 57 | """ 58 | 59 | layer_1 = self.pretrained.layer1(x) 60 | layer_2 = self.pretrained.layer2(layer_1) 61 | layer_3 = self.pretrained.layer3(layer_2) 62 | layer_4 = self.pretrained.layer4(layer_3) 63 | 64 | layer_1_rn = self.scratch.layer1_rn(layer_1) 65 | layer_2_rn = self.scratch.layer2_rn(layer_2) 66 | layer_3_rn = self.scratch.layer3_rn(layer_3) 67 | layer_4_rn = self.scratch.layer4_rn(layer_4) 68 | 69 | path_4 = self.scratch.refinenet4(layer_4_rn) 70 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 71 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 72 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 73 | 74 | out = self.scratch.output_conv(path_1) 75 | 76 | return torch.squeeze(out, dim=1) 77 | -------------------------------------------------------------------------------- /ldm/modules/midas/midas/midas_net_custom.py: -------------------------------------------------------------------------------- 1 | """MidashNet: Network for monocular depth estimation trained by mixing several datasets. 
2 | This file contains code that is adapted from 3 | https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | 8 | from .base_model import BaseModel 9 | from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder 10 | 11 | 12 | class MidasNet_small(BaseModel): 13 | """Network for monocular depth estimation. 14 | """ 15 | 16 | def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True, 17 | blocks={'expand': True}): 18 | """Init. 19 | 20 | Args: 21 | path (str, optional): Path to saved model. Defaults to None. 22 | features (int, optional): Number of features. Defaults to 256. 23 | backbone (str, optional): Backbone network for encoder. Defaults to resnet50 24 | """ 25 | print("Loading weights: ", path) 26 | 27 | super(MidasNet_small, self).__init__() 28 | 29 | use_pretrained = False if path else True 30 | 31 | self.channels_last = channels_last 32 | self.blocks = blocks 33 | self.backbone = backbone 34 | 35 | self.groups = 1 36 | 37 | features1=features 38 | features2=features 39 | features3=features 40 | features4=features 41 | self.expand = False 42 | if "expand" in self.blocks and self.blocks['expand'] == True: 43 | self.expand = True 44 | features1=features 45 | features2=features*2 46 | features3=features*4 47 | features4=features*8 48 | 49 | self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable) 50 | 51 | self.scratch.activation = nn.ReLU(False) 52 | 53 | self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) 54 | self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) 55 | self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) 56 | self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners) 57 | 58 | 59 | self.scratch.output_conv = nn.Sequential( 60 | nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups), 61 | Interpolate(scale_factor=2, mode="bilinear"), 62 | nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1), 63 | self.scratch.activation, 64 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 65 | nn.ReLU(True) if non_negative else nn.Identity(), 66 | nn.Identity(), 67 | ) 68 | 69 | if path: 70 | self.load(path) 71 | 72 | 73 | def forward(self, x): 74 | """Forward pass. 
75 | 76 | Args: 77 | x (tensor): input data (image) 78 | 79 | Returns: 80 | tensor: depth 81 | """ 82 | if self.channels_last==True: 83 | print("self.channels_last = ", self.channels_last) 84 | x.contiguous(memory_format=torch.channels_last) 85 | 86 | 87 | layer_1 = self.pretrained.layer1(x) 88 | layer_2 = self.pretrained.layer2(layer_1) 89 | layer_3 = self.pretrained.layer3(layer_2) 90 | layer_4 = self.pretrained.layer4(layer_3) 91 | 92 | layer_1_rn = self.scratch.layer1_rn(layer_1) 93 | layer_2_rn = self.scratch.layer2_rn(layer_2) 94 | layer_3_rn = self.scratch.layer3_rn(layer_3) 95 | layer_4_rn = self.scratch.layer4_rn(layer_4) 96 | 97 | 98 | path_4 = self.scratch.refinenet4(layer_4_rn) 99 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 100 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 101 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 102 | 103 | out = self.scratch.output_conv(path_1) 104 | 105 | return torch.squeeze(out, dim=1) 106 | 107 | 108 | 109 | def fuse_model(m): 110 | prev_previous_type = nn.Identity() 111 | prev_previous_name = '' 112 | previous_type = nn.Identity() 113 | previous_name = '' 114 | for name, module in m.named_modules(): 115 | if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU: 116 | # print("FUSED ", prev_previous_name, previous_name, name) 117 | torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True) 118 | elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d: 119 | # print("FUSED ", prev_previous_name, previous_name) 120 | torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True) 121 | # elif previous_type == nn.Conv2d and type(module) == nn.ReLU: 122 | # print("FUSED ", previous_name, name) 123 | # torch.quantization.fuse_modules(m, [previous_name, name], inplace=True) 124 | 125 | prev_previous_type = previous_type 126 | prev_previous_name = previous_name 127 | previous_type = type(module) 128 | previous_name = name -------------------------------------------------------------------------------- /ldm/modules/midas/midas/transforms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import math 4 | 5 | 6 | def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA): 7 | """Rezise the sample to ensure the given size. Keeps aspect ratio. 8 | 9 | Args: 10 | sample (dict): sample 11 | size (tuple): image size 12 | 13 | Returns: 14 | tuple: new size 15 | """ 16 | shape = list(sample["disparity"].shape) 17 | 18 | if shape[0] >= size[0] and shape[1] >= size[1]: 19 | return sample 20 | 21 | scale = [0, 0] 22 | scale[0] = size[0] / shape[0] 23 | scale[1] = size[1] / shape[1] 24 | 25 | scale = max(scale) 26 | 27 | shape[0] = math.ceil(scale * shape[0]) 28 | shape[1] = math.ceil(scale * shape[1]) 29 | 30 | # resize 31 | sample["image"] = cv2.resize( 32 | sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method 33 | ) 34 | 35 | sample["disparity"] = cv2.resize( 36 | sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST 37 | ) 38 | sample["mask"] = cv2.resize( 39 | sample["mask"].astype(np.float32), 40 | tuple(shape[::-1]), 41 | interpolation=cv2.INTER_NEAREST, 42 | ) 43 | sample["mask"] = sample["mask"].astype(bool) 44 | 45 | return tuple(shape) 46 | 47 | 48 | class Resize(object): 49 | """Resize sample to given size (width, height). 
50 | """ 51 | 52 | def __init__( 53 | self, 54 | width, 55 | height, 56 | resize_target=True, 57 | keep_aspect_ratio=False, 58 | ensure_multiple_of=1, 59 | resize_method="lower_bound", 60 | image_interpolation_method=cv2.INTER_AREA, 61 | ): 62 | """Init. 63 | 64 | Args: 65 | width (int): desired output width 66 | height (int): desired output height 67 | resize_target (bool, optional): 68 | True: Resize the full sample (image, mask, target). 69 | False: Resize image only. 70 | Defaults to True. 71 | keep_aspect_ratio (bool, optional): 72 | True: Keep the aspect ratio of the input sample. 73 | Output sample might not have the given width and height, and 74 | resize behaviour depends on the parameter 'resize_method'. 75 | Defaults to False. 76 | ensure_multiple_of (int, optional): 77 | Output width and height is constrained to be multiple of this parameter. 78 | Defaults to 1. 79 | resize_method (str, optional): 80 | "lower_bound": Output will be at least as large as the given size. 81 | "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.) 82 | "minimal": Scale as least as possible. (Output size might be smaller than given size.) 83 | Defaults to "lower_bound". 84 | """ 85 | self.__width = width 86 | self.__height = height 87 | 88 | self.__resize_target = resize_target 89 | self.__keep_aspect_ratio = keep_aspect_ratio 90 | self.__multiple_of = ensure_multiple_of 91 | self.__resize_method = resize_method 92 | self.__image_interpolation_method = image_interpolation_method 93 | 94 | def constrain_to_multiple_of(self, x, min_val=0, max_val=None): 95 | y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int) 96 | 97 | if max_val is not None and y > max_val: 98 | y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int) 99 | 100 | if y < min_val: 101 | y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int) 102 | 103 | return y 104 | 105 | def get_size(self, width, height): 106 | # determine new height and width 107 | scale_height = self.__height / height 108 | scale_width = self.__width / width 109 | 110 | if self.__keep_aspect_ratio: 111 | if self.__resize_method == "lower_bound": 112 | # scale such that output size is lower bound 113 | if scale_width > scale_height: 114 | # fit width 115 | scale_height = scale_width 116 | else: 117 | # fit height 118 | scale_width = scale_height 119 | elif self.__resize_method == "upper_bound": 120 | # scale such that output size is upper bound 121 | if scale_width < scale_height: 122 | # fit width 123 | scale_height = scale_width 124 | else: 125 | # fit height 126 | scale_width = scale_height 127 | elif self.__resize_method == "minimal": 128 | # scale as least as possbile 129 | if abs(1 - scale_width) < abs(1 - scale_height): 130 | # fit width 131 | scale_height = scale_width 132 | else: 133 | # fit height 134 | scale_width = scale_height 135 | else: 136 | raise ValueError( 137 | f"resize_method {self.__resize_method} not implemented" 138 | ) 139 | 140 | if self.__resize_method == "lower_bound": 141 | new_height = self.constrain_to_multiple_of( 142 | scale_height * height, min_val=self.__height 143 | ) 144 | new_width = self.constrain_to_multiple_of( 145 | scale_width * width, min_val=self.__width 146 | ) 147 | elif self.__resize_method == "upper_bound": 148 | new_height = self.constrain_to_multiple_of( 149 | scale_height * height, max_val=self.__height 150 | ) 151 | new_width = self.constrain_to_multiple_of( 152 | scale_width * width, max_val=self.__width 
153 | ) 154 | elif self.__resize_method == "minimal": 155 | new_height = self.constrain_to_multiple_of(scale_height * height) 156 | new_width = self.constrain_to_multiple_of(scale_width * width) 157 | else: 158 | raise ValueError(f"resize_method {self.__resize_method} not implemented") 159 | 160 | return (new_width, new_height) 161 | 162 | def __call__(self, sample): 163 | width, height = self.get_size( 164 | sample["image"].shape[1], sample["image"].shape[0] 165 | ) 166 | 167 | # resize sample 168 | sample["image"] = cv2.resize( 169 | sample["image"], 170 | (width, height), 171 | interpolation=self.__image_interpolation_method, 172 | ) 173 | 174 | if self.__resize_target: 175 | if "disparity" in sample: 176 | sample["disparity"] = cv2.resize( 177 | sample["disparity"], 178 | (width, height), 179 | interpolation=cv2.INTER_NEAREST, 180 | ) 181 | 182 | if "depth" in sample: 183 | sample["depth"] = cv2.resize( 184 | sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST 185 | ) 186 | 187 | sample["mask"] = cv2.resize( 188 | sample["mask"].astype(np.float32), 189 | (width, height), 190 | interpolation=cv2.INTER_NEAREST, 191 | ) 192 | sample["mask"] = sample["mask"].astype(bool) 193 | 194 | return sample 195 | 196 | 197 | class NormalizeImage(object): 198 | """Normlize image by given mean and std. 199 | """ 200 | 201 | def __init__(self, mean, std): 202 | self.__mean = mean 203 | self.__std = std 204 | 205 | def __call__(self, sample): 206 | sample["image"] = (sample["image"] - self.__mean) / self.__std 207 | 208 | return sample 209 | 210 | 211 | class PrepareForNet(object): 212 | """Prepare sample for usage as network input. 213 | """ 214 | 215 | def __init__(self): 216 | pass 217 | 218 | def __call__(self, sample): 219 | image = np.transpose(sample["image"], (2, 0, 1)) 220 | sample["image"] = np.ascontiguousarray(image).astype(np.float32) 221 | 222 | if "mask" in sample: 223 | sample["mask"] = sample["mask"].astype(np.float32) 224 | sample["mask"] = np.ascontiguousarray(sample["mask"]) 225 | 226 | if "disparity" in sample: 227 | disparity = sample["disparity"].astype(np.float32) 228 | sample["disparity"] = np.ascontiguousarray(disparity) 229 | 230 | if "depth" in sample: 231 | depth = sample["depth"].astype(np.float32) 232 | sample["depth"] = np.ascontiguousarray(depth) 233 | 234 | return sample 235 | -------------------------------------------------------------------------------- /ldm/modules/midas/utils.py: -------------------------------------------------------------------------------- 1 | """Utils for monoDepth.""" 2 | import sys 3 | import re 4 | import numpy as np 5 | import cv2 6 | import torch 7 | 8 | 9 | def read_pfm(path): 10 | """Read pfm file. 
11 | 12 | Args: 13 | path (str): path to file 14 | 15 | Returns: 16 | tuple: (data, scale) 17 | """ 18 | with open(path, "rb") as file: 19 | 20 | color = None 21 | width = None 22 | height = None 23 | scale = None 24 | endian = None 25 | 26 | header = file.readline().rstrip() 27 | if header.decode("ascii") == "PF": 28 | color = True 29 | elif header.decode("ascii") == "Pf": 30 | color = False 31 | else: 32 | raise Exception("Not a PFM file: " + path) 33 | 34 | dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii")) 35 | if dim_match: 36 | width, height = list(map(int, dim_match.groups())) 37 | else: 38 | raise Exception("Malformed PFM header.") 39 | 40 | scale = float(file.readline().decode("ascii").rstrip()) 41 | if scale < 0: 42 | # little-endian 43 | endian = "<" 44 | scale = -scale 45 | else: 46 | # big-endian 47 | endian = ">" 48 | 49 | data = np.fromfile(file, endian + "f") 50 | shape = (height, width, 3) if color else (height, width) 51 | 52 | data = np.reshape(data, shape) 53 | data = np.flipud(data) 54 | 55 | return data, scale 56 | 57 | 58 | def write_pfm(path, image, scale=1): 59 | """Write pfm file. 60 | 61 | Args: 62 | path (str): path to file 63 | image (array): data 64 | scale (int, optional): Scale. Defaults to 1. 65 | """ 66 | 67 | with open(path, "wb") as file: 68 | color = None 69 | 70 | if image.dtype.name != "float32": 71 | raise Exception("Image dtype must be float32.") 72 | 73 | image = np.flipud(image) 74 | 75 | if len(image.shape) == 3 and image.shape[2] == 3: # color image 76 | color = True 77 | elif ( 78 | len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1 79 | ): # greyscale 80 | color = False 81 | else: 82 | raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.") 83 | 84 | file.write(("PF\n" if color else "Pf\n").encode()) 85 | file.write("%d %d\n".encode() % (image.shape[1], image.shape[0])) 86 | 87 | endian = image.dtype.byteorder 88 | 89 | if endian == "<" or (endian == "=" and sys.byteorder == "little"): 90 | scale = -scale 91 | 92 | file.write("%f\n".encode() % scale) 93 | 94 | image.tofile(file) 95 | 96 | 97 | def read_image(path): 98 | """Read image and output RGB image (0-1). 99 | 100 | Args: 101 | path (str): path to file 102 | 103 | Returns: 104 | array: RGB image (0-1) 105 | """ 106 | img = cv2.imread(path) 107 | 108 | if img.ndim == 2: 109 | img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) 110 | 111 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0 112 | 113 | return img 114 | 115 | 116 | def resize_image(img): 117 | """Resize image and make it fit for network. 118 | 119 | Args: 120 | img (array): image 121 | 122 | Returns: 123 | tensor: data ready for network 124 | """ 125 | height_orig = img.shape[0] 126 | width_orig = img.shape[1] 127 | 128 | if width_orig > height_orig: 129 | scale = width_orig / 384 130 | else: 131 | scale = height_orig / 384 132 | 133 | height = (np.ceil(height_orig / scale / 32) * 32).astype(int) 134 | width = (np.ceil(width_orig / scale / 32) * 32).astype(int) 135 | 136 | img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA) 137 | 138 | img_resized = ( 139 | torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float() 140 | ) 141 | img_resized = img_resized.unsqueeze(0) 142 | 143 | return img_resized 144 | 145 | 146 | def resize_depth(depth, width, height): 147 | """Resize depth map and bring to CPU (numpy).
148 | 149 | Args: 150 | depth (tensor): depth 151 | width (int): image width 152 | height (int): image height 153 | 154 | Returns: 155 | array: processed depth 156 | """ 157 | depth = torch.squeeze(depth[0, :, :, :]).to("cpu") 158 | 159 | depth_resized = cv2.resize( 160 | depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC 161 | ) 162 | 163 | return depth_resized 164 | 165 | def write_depth(path, depth, bits=1): 166 | """Write depth map to pfm and png file. 167 | 168 | Args: 169 | path (str): filepath without extension 170 | depth (array): depth 171 | """ 172 | write_pfm(path + ".pfm", depth.astype(np.float32)) 173 | 174 | depth_min = depth.min() 175 | depth_max = depth.max() 176 | 177 | max_val = (2**(8*bits))-1 178 | 179 | if depth_max - depth_min > np.finfo("float").eps: 180 | out = max_val * (depth - depth_min) / (depth_max - depth_min) 181 | else: 182 | out = np.zeros(depth.shape, dtype=depth.dtype) 183 | 184 | if bits == 1: 185 | cv2.imwrite(path + ".png", out.astype("uint8")) 186 | elif bits == 2: 187 | cv2.imwrite(path + ".png", out.astype("uint16")) 188 | 189 | return 190 | -------------------------------------------------------------------------------- /ldm/util.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | import torch 4 | from torch import optim 5 | import numpy as np 6 | 7 | from inspect import isfunction 8 | from PIL import Image, ImageDraw, ImageFont 9 | 10 | 11 | def log_txt_as_img(wh, xc, size=10): 12 | # wh a tuple of (width, height) 13 | # xc a list of captions to plot 14 | b = len(xc) 15 | txts = list() 16 | for bi in range(b): 17 | txt = Image.new("RGB", wh, color="white") 18 | draw = ImageDraw.Draw(txt) 19 | font = ImageFont.truetype('font/DejaVuSans.ttf', size=size) 20 | nc = int(40 * (wh[0] / 256)) 21 | lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc)) 22 | 23 | try: 24 | draw.text((0, 0), lines, fill="black", font=font) 25 | except UnicodeEncodeError: 26 | print("Can't encode string for logging. Skipping.") 27 | 28 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 29 | txts.append(txt) 30 | txts = np.stack(txts) 31 | txts = torch.tensor(txts) 32 | return txts 33 | 34 | 35 | def ismap(x): 36 | if not isinstance(x, torch.Tensor): 37 | return False 38 | return (len(x.shape) == 4) and (x.shape[1] > 3) 39 | 40 | 41 | def isimage(x): 42 | if not isinstance(x,torch.Tensor): 43 | return False 44 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) 45 | 46 | 47 | def exists(x): 48 | return x is not None 49 | 50 | 51 | def default(val, d): 52 | if exists(val): 53 | return val 54 | return d() if isfunction(d) else d 55 | 56 | 57 | def mean_flat(tensor): 58 | """ 59 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 60 | Take the mean over all non-batch dimensions.
61 | """ 62 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 63 | 64 | 65 | def count_params(model, verbose=False): 66 | total_params = sum(p.numel() for p in model.parameters()) 67 | if verbose: 68 | print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.") 69 | return total_params 70 | 71 | 72 | def instantiate_from_config(config): 73 | if not "target" in config: 74 | if config == '__is_first_stage__': 75 | return None 76 | elif config == "__is_unconditional__": 77 | return None 78 | raise KeyError("Expected key `target` to instantiate.") 79 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 80 | 81 | 82 | def get_obj_from_str(string, reload=False): 83 | module, cls = string.rsplit(".", 1) 84 | if reload: 85 | module_imp = importlib.import_module(module) 86 | importlib.reload(module_imp) 87 | return getattr(importlib.import_module(module, package=None), cls) 88 | 89 | 90 | class AdamWwithEMAandWings(optim.Optimizer): 91 | # credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298 92 | def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8, # TODO: check hyperparameters before using 93 | weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999, # ema decay to match previous code 94 | ema_power=1., param_names=()): 95 | """AdamW that saves EMA versions of the parameters.""" 96 | if not 0.0 <= lr: 97 | raise ValueError("Invalid learning rate: {}".format(lr)) 98 | if not 0.0 <= eps: 99 | raise ValueError("Invalid epsilon value: {}".format(eps)) 100 | if not 0.0 <= betas[0] < 1.0: 101 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 102 | if not 0.0 <= betas[1] < 1.0: 103 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 104 | if not 0.0 <= weight_decay: 105 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) 106 | if not 0.0 <= ema_decay <= 1.0: 107 | raise ValueError("Invalid ema_decay value: {}".format(ema_decay)) 108 | defaults = dict(lr=lr, betas=betas, eps=eps, 109 | weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay, 110 | ema_power=ema_power, param_names=param_names) 111 | super().__init__(params, defaults) 112 | 113 | def __setstate__(self, state): 114 | super().__setstate__(state) 115 | for group in self.param_groups: 116 | group.setdefault('amsgrad', False) 117 | 118 | @torch.no_grad() 119 | def step(self, closure=None): 120 | """Performs a single optimization step. 121 | Args: 122 | closure (callable, optional): A closure that reevaluates the model 123 | and returns the loss. 
124 | """ 125 | loss = None 126 | if closure is not None: 127 | with torch.enable_grad(): 128 | loss = closure() 129 | 130 | for group in self.param_groups: 131 | params_with_grad = [] 132 | grads = [] 133 | exp_avgs = [] 134 | exp_avg_sqs = [] 135 | ema_params_with_grad = [] 136 | state_sums = [] 137 | max_exp_avg_sqs = [] 138 | state_steps = [] 139 | amsgrad = group['amsgrad'] 140 | beta1, beta2 = group['betas'] 141 | ema_decay = group['ema_decay'] 142 | ema_power = group['ema_power'] 143 | 144 | for p in group['params']: 145 | if p.grad is None: 146 | continue 147 | params_with_grad.append(p) 148 | if p.grad.is_sparse: 149 | raise RuntimeError('AdamW does not support sparse gradients') 150 | grads.append(p.grad) 151 | 152 | state = self.state[p] 153 | 154 | # State initialization 155 | if len(state) == 0: 156 | state['step'] = 0 157 | # Exponential moving average of gradient values 158 | state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) 159 | # Exponential moving average of squared gradient values 160 | state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) 161 | if amsgrad: 162 | # Maintains max of all exp. moving avg. of sq. grad. values 163 | state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) 164 | # Exponential moving average of parameter values 165 | state['param_exp_avg'] = p.detach().float().clone() 166 | 167 | exp_avgs.append(state['exp_avg']) 168 | exp_avg_sqs.append(state['exp_avg_sq']) 169 | ema_params_with_grad.append(state['param_exp_avg']) 170 | 171 | if amsgrad: 172 | max_exp_avg_sqs.append(state['max_exp_avg_sq']) 173 | 174 | # update the steps for each param group update 175 | state['step'] += 1 176 | # record the step after step update 177 | state_steps.append(state['step']) 178 | 179 | optim._functional.adamw(params_with_grad, 180 | grads, 181 | exp_avgs, 182 | exp_avg_sqs, 183 | max_exp_avg_sqs, 184 | state_steps, 185 | amsgrad=amsgrad, 186 | beta1=beta1, 187 | beta2=beta2, 188 | lr=group['lr'], 189 | weight_decay=group['weight_decay'], 190 | eps=group['eps'], 191 | maximize=False) 192 | 193 | cur_ema_decay = min(ema_decay, 1 - state['step'] ** -ema_power) 194 | for param, ema_param in zip(params_with_grad, ema_params_with_grad): 195 | ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay) 196 | 197 | return loss -------------------------------------------------------------------------------- /resources/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszy98/PLACE/4594350491df5c7489eb4edf0802f303692dd530/resources/.DS_Store -------------------------------------------------------------------------------- /resources/bibtex.txt: -------------------------------------------------------------------------------- 1 | @inproceedings{lv24cvpr, 2 | author = "Lv, Z. and Wei, Y. and Zuo, W. and Wong, K.-Y.~K.", 3 | title = "PLACE: Adaptive Layout-Semantic Fusion for Semantic Image Synthesis", 4 | booktitle = "Proc. 
IEEE/CVF Conference on Computer Vision and Pattern Recognition", 5 | volume = "", 6 | pages = "", 7 | address = "Seattle, Washington", 8 | month = "June", 9 | year = "2024" 10 | } -------------------------------------------------------------------------------- /resources/ind.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszy98/PLACE/4594350491df5c7489eb4edf0802f303692dd530/resources/ind.png -------------------------------------------------------------------------------- /resources/method_diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszy98/PLACE/4594350491df5c7489eb4edf0802f303692dd530/resources/method_diagram.png -------------------------------------------------------------------------------- /resources/newfig2_final.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszy98/PLACE/4594350491df5c7489eb4edf0802f303692dd530/resources/newfig2_final.pdf -------------------------------------------------------------------------------- /resources/od.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszy98/PLACE/4594350491df5c7489eb4edf0802f303692dd530/resources/od.png -------------------------------------------------------------------------------- /resources/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszy98/PLACE/4594350491df5c7489eb4edf0802f303692dd530/resources/overview.png -------------------------------------------------------------------------------- /resources/paper.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszy98/PLACE/4594350491df5c7489eb4edf0802f303692dd530/resources/paper.png -------------------------------------------------------------------------------- /resources/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cszy98/PLACE/4594350491df5c7489eb4edf0802f303692dd530/resources/teaser.png -------------------------------------------------------------------------------- /run_inference_ADE20K.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python inference.py --outdir output/ade20k \ 2 | --config configs/stable-diffusion/PLACE.yaml \ 3 | --ckpt ckpt/ade20k_best.ckpt \ 4 | --dataset ADE20K \ 5 | --data_root path_to_ADE20K 6 | -------------------------------------------------------------------------------- /run_inference_COCO.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python inference.py --outdir output/COCO \ 2 | --config configs/stable-diffusion/PLACE.yaml \ 3 | --ckpt ckpt/coco_best.ckpt \ 4 | --dataset COCO \ 5 | --data_root path_to_COCO 6 | --------------------------------------------------------------------------------
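
Editor's note: the MiDaS preprocessing classes in ldm/modules/midas/midas/transforms.py and the I/O helpers in ldm/modules/midas/utils.py shown above are plain callables over a sample dict, so they can be chained directly without any extra framework. The sketch below is illustrative only and is not part of the repository; the input path example.jpg and the ImageNet-style normalization constants are placeholders, not values taken from PLACE.

```python
# Minimal sketch (assumes the repo root is on sys.path): chain the MiDaS-style
# preprocessing callables from ldm/modules/midas/midas/transforms.py.
# "example.jpg" and the mean/std values are illustrative placeholders.
from ldm.modules.midas.midas.transforms import Resize, NormalizeImage, PrepareForNet
from ldm.modules.midas.utils import read_image

resize = Resize(
    384,
    384,
    resize_target=False,          # resize only sample["image"]; no mask/disparity keys needed
    keep_aspect_ratio=True,
    ensure_multiple_of=32,        # both output sides are constrained to multiples of 32
    resize_method="lower_bound",  # both output sides come out at least 384
)
normalize = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
prepare = PrepareForNet()

sample = {"image": read_image("example.jpg")}  # RGB float array in [0, 1], HWC layout
sample = prepare(normalize(resize(sample)))

net_input = sample["image"]  # contiguous float32 CHW array, ready for torch.from_numpy
print(net_input.shape, net_input.dtype)
```

With resize_target=False only the image key is touched, which sidesteps the mask/disparity handling that Resize.__call__ would otherwise expect in the sample dict.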