├── .gitignore ├── LICENSE ├── README.md ├── assets ├── a-container-0038.jpg ├── a-dog-on-top-of-sks-container-0023.jpg ├── a-red-sks-container-0021.jpg ├── photo-of-a-sks-container-0018.jpg ├── photo-of-a-sks-container-on-the-beach-0017.jpg └── photo-of-a-sks-container-on-the-moon-0016.jpg ├── configs ├── autoencoder │ ├── autoencoder_kl_16x16x16.yaml │ ├── autoencoder_kl_32x32x4.yaml │ ├── autoencoder_kl_64x64x3.yaml │ └── autoencoder_kl_8x8x64.yaml ├── latent-diffusion │ ├── celebahq-ldm-vq-4.yaml │ ├── cin-ldm-vq-f8.yaml │ ├── cin256-v2.yaml │ ├── ffhq-ldm-vq-4.yaml │ ├── lsun_bedrooms-ldm-vq-4.yaml │ ├── lsun_churches-ldm-kl-8.yaml │ ├── txt2img-1p4B-eval.yaml │ ├── txt2img-1p4B-eval_with_tokens.yaml │ ├── txt2img-1p4B-finetune.yaml │ └── txt2img-1p4B-finetune_style.yaml └── stable-diffusion │ ├── v1-finetune.yaml │ ├── v1-finetune_unfrozen.yaml │ └── v1-inference.yaml ├── data └── DejaVuSans.ttf ├── dreambooth_runpod_joepenna.ipynb ├── environment.yaml ├── evaluation └── clip_eval.py ├── gdrive ├── ldm ├── data │ ├── __init__.py │ ├── base.py │ ├── imagenet.py │ ├── lsun.py │ ├── personalized.py │ ├── personalized_batch.py │ └── personalized_style.py ├── lr_scheduler.py ├── models │ ├── autoencoder.py │ └── diffusion │ │ ├── __init__.py │ │ ├── classifier.py │ │ ├── ddim.py │ │ ├── ddpm.py │ │ └── plms.py ├── modules │ ├── attention.py │ ├── diffusionmodules │ │ ├── __init__.py │ │ ├── model.py │ │ ├── openaimodel.py │ │ └── util.py │ ├── distributions │ │ ├── __init__.py │ │ └── distributions.py │ ├── ema.py │ ├── embedding_manager.py │ ├── encoders │ │ ├── __init__.py │ │ ├── modules.py │ │ └── modules_bak.py │ ├── image_degradation │ │ ├── __init__.py │ │ ├── bsrgan.py │ │ ├── bsrgan_light.py │ │ ├── utils │ │ │ └── test.png │ │ └── utils_image.py │ ├── losses │ │ ├── __init__.py │ │ ├── contperceptual.py │ │ └── vqperceptual.py │ └── x_transformer.py └── util.py ├── main.py ├── merge_embeddings.py ├── models ├── first_stage_models │ ├── kl-f16 │ │ └── config.yaml │ ├── kl-f32 │ │ └── config.yaml │ ├── kl-f4 │ │ └── config.yaml │ ├── kl-f8 │ │ └── config.yaml │ ├── vq-f16 │ │ └── config.yaml │ ├── vq-f4-noattn │ │ └── config.yaml │ ├── vq-f4 │ │ └── config.yaml │ ├── vq-f8-n256 │ │ └── config.yaml │ └── vq-f8 │ │ └── config.yaml └── ldm │ ├── bsr_sr │ └── config.yaml │ ├── celeba256 │ └── config.yaml │ ├── cin256 │ └── config.yaml │ ├── ffhq256 │ └── config.yaml │ ├── inpainting_big │ └── config.yaml │ ├── layout2img-openimages256 │ └── config.yaml │ ├── lsun_beds256 │ └── config.yaml │ ├── lsun_churches256 │ └── config.yaml │ ├── semantic_synthesis256 │ └── config.yaml │ ├── semantic_synthesis512 │ └── config.yaml │ └── text2img256 │ └── config.yaml ├── prune_ckpt.py ├── readme-images ├── better-training-images.png ├── caution-training.png ├── vast-ai-step1-select-docker-image.png ├── vast-ai-step2-instance-filters.png ├── vast-ai-step3-instances.png ├── vast-ai-step4-get-repo.png ├── vast-ai-step5-clone-repo.png └── vast-ai-step6-open-notebook.png ├── scripts ├── download_first_stages.sh ├── download_models.sh ├── evaluate_model.py ├── inpaint.py ├── latent_imagenet_diffusion.ipynb ├── prune-ckpt.py ├── sample_diffusion.py ├── stable_img2img.py ├── stable_txt2img.py └── txt2img.py └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | latent_diffusion.egg-info 3 | .DS_Store 4 | gen.bat 5 | gen_ref.bat 6 | train.bat 7 | __pycache__ 8 | */**/__pycache__ 9 | logs 
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Rinon Gal, Yuval Alaluf, Yuval Atzmon, Or Patashnik and contributors 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /assets/a-container-0038.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kanewallmann/Dreambooth-Stable-Diffusion/6f04eae2331153f841cab502385415ac39503144/assets/a-container-0038.jpg -------------------------------------------------------------------------------- /assets/a-dog-on-top-of-sks-container-0023.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kanewallmann/Dreambooth-Stable-Diffusion/6f04eae2331153f841cab502385415ac39503144/assets/a-dog-on-top-of-sks-container-0023.jpg -------------------------------------------------------------------------------- /assets/a-red-sks-container-0021.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kanewallmann/Dreambooth-Stable-Diffusion/6f04eae2331153f841cab502385415ac39503144/assets/a-red-sks-container-0021.jpg -------------------------------------------------------------------------------- /assets/photo-of-a-sks-container-0018.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kanewallmann/Dreambooth-Stable-Diffusion/6f04eae2331153f841cab502385415ac39503144/assets/photo-of-a-sks-container-0018.jpg -------------------------------------------------------------------------------- /assets/photo-of-a-sks-container-on-the-beach-0017.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kanewallmann/Dreambooth-Stable-Diffusion/6f04eae2331153f841cab502385415ac39503144/assets/photo-of-a-sks-container-on-the-beach-0017.jpg -------------------------------------------------------------------------------- /assets/photo-of-a-sks-container-on-the-moon-0016.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/kanewallmann/Dreambooth-Stable-Diffusion/6f04eae2331153f841cab502385415ac39503144/assets/photo-of-a-sks-container-on-the-moon-0016.jpg -------------------------------------------------------------------------------- /configs/autoencoder/autoencoder_kl_16x16x16.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-6 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: "val/rec_loss" 6 | embed_dim: 16 7 | lossconfig: 8 | target: ldm.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 0.000001 12 | disc_weight: 0.5 13 | 14 | ddconfig: 15 | double_z: True 16 | z_channels: 16 17 | resolution: 256 18 | in_channels: 3 19 | out_ch: 3 20 | ch: 128 21 | ch_mult: [ 1,1,2,2,4] # num_down = len(ch_mult)-1 22 | num_res_blocks: 2 23 | attn_resolutions: [16] 24 | dropout: 0.0 25 | 26 | 27 | data: 28 | target: main.DataModuleFromConfig 29 | params: 30 | batch_size: 12 31 | wrap: True 32 | train: 33 | target: ldm.data.imagenet.ImageNetSRTrain 34 | params: 35 | size: 256 36 | degradation: pil_nearest 37 | validation: 38 | target: ldm.data.imagenet.ImageNetSRValidation 39 | params: 40 | size: 256 41 | degradation: pil_nearest 42 | 43 | lightning: 44 | callbacks: 45 | image_logger: 46 | target: main.ImageLogger 47 | params: 48 | batch_frequency: 1000 49 | max_images: 8 50 | increase_log_steps: True 51 | 52 | trainer: 53 | benchmark: True 54 | accumulate_grad_batches: 2 55 | -------------------------------------------------------------------------------- /configs/autoencoder/autoencoder_kl_32x32x4.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-6 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: "val/rec_loss" 6 | embed_dim: 4 7 | lossconfig: 8 | target: ldm.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 0.000001 12 | disc_weight: 0.5 13 | 14 | ddconfig: 15 | double_z: True 16 | z_channels: 4 17 | resolution: 256 18 | in_channels: 3 19 | out_ch: 3 20 | ch: 128 21 | ch_mult: [ 1,2,4,4 ] # num_down = len(ch_mult)-1 22 | num_res_blocks: 2 23 | attn_resolutions: [ ] 24 | dropout: 0.0 25 | 26 | data: 27 | target: main.DataModuleFromConfig 28 | params: 29 | batch_size: 12 30 | wrap: True 31 | train: 32 | target: ldm.data.imagenet.ImageNetSRTrain 33 | params: 34 | size: 256 35 | degradation: pil_nearest 36 | validation: 37 | target: ldm.data.imagenet.ImageNetSRValidation 38 | params: 39 | size: 256 40 | degradation: pil_nearest 41 | 42 | lightning: 43 | callbacks: 44 | image_logger: 45 | target: main.ImageLogger 46 | params: 47 | batch_frequency: 1000 48 | max_images: 8 49 | increase_log_steps: True 50 | 51 | trainer: 52 | benchmark: True 53 | accumulate_grad_batches: 2 54 | -------------------------------------------------------------------------------- /configs/autoencoder/autoencoder_kl_64x64x3.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-6 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: "val/rec_loss" 6 | embed_dim: 3 7 | lossconfig: 8 | target: ldm.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 0.000001 12 | disc_weight: 0.5 13 | 14 | ddconfig: 15 | double_z: True 16 | z_channels: 3 17 | resolution: 256 18 | in_channels: 3 19 | out_ch: 3 20 | ch: 128 21 | 
ch_mult: [ 1,2,4 ] # num_down = len(ch_mult)-1 22 | num_res_blocks: 2 23 | attn_resolutions: [ ] 24 | dropout: 0.0 25 | 26 | 27 | data: 28 | target: main.DataModuleFromConfig 29 | params: 30 | batch_size: 12 31 | wrap: True 32 | train: 33 | target: ldm.data.imagenet.ImageNetSRTrain 34 | params: 35 | size: 256 36 | degradation: pil_nearest 37 | validation: 38 | target: ldm.data.imagenet.ImageNetSRValidation 39 | params: 40 | size: 256 41 | degradation: pil_nearest 42 | 43 | lightning: 44 | callbacks: 45 | image_logger: 46 | target: main.ImageLogger 47 | params: 48 | batch_frequency: 1000 49 | max_images: 8 50 | increase_log_steps: True 51 | 52 | trainer: 53 | benchmark: True 54 | accumulate_grad_batches: 2 55 | -------------------------------------------------------------------------------- /configs/autoencoder/autoencoder_kl_8x8x64.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-6 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: "val/rec_loss" 6 | embed_dim: 64 7 | lossconfig: 8 | target: ldm.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 0.000001 12 | disc_weight: 0.5 13 | 14 | ddconfig: 15 | double_z: True 16 | z_channels: 64 17 | resolution: 256 18 | in_channels: 3 19 | out_ch: 3 20 | ch: 128 21 | ch_mult: [ 1,1,2,2,4,4] # num_down = len(ch_mult)-1 22 | num_res_blocks: 2 23 | attn_resolutions: [16,8] 24 | dropout: 0.0 25 | 26 | data: 27 | target: main.DataModuleFromConfig 28 | params: 29 | batch_size: 12 30 | wrap: True 31 | train: 32 | target: ldm.data.imagenet.ImageNetSRTrain 33 | params: 34 | size: 256 35 | degradation: pil_nearest 36 | validation: 37 | target: ldm.data.imagenet.ImageNetSRValidation 38 | params: 39 | size: 256 40 | degradation: pil_nearest 41 | 42 | lightning: 43 | callbacks: 44 | image_logger: 45 | target: main.ImageLogger 46 | params: 47 | batch_frequency: 1000 48 | max_images: 8 49 | increase_log_steps: True 50 | 51 | trainer: 52 | benchmark: True 53 | accumulate_grad_batches: 2 54 | -------------------------------------------------------------------------------- /configs/latent-diffusion/celebahq-ldm-vq-4.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 2.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | image_size: 64 12 | channels: 3 13 | monitor: val/loss_simple_ema 14 | 15 | unet_config: 16 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 17 | params: 18 | image_size: 64 19 | in_channels: 3 20 | out_channels: 3 21 | model_channels: 224 22 | attention_resolutions: 23 | # note: this isn\t actually the resolution but 24 | # the downsampling factor, i.e. 
this corresnponds to 25 | # attention on spatial resolution 8,16,32, as the 26 | # spatial reolution of the latents is 64 for f4 27 | - 8 28 | - 4 29 | - 2 30 | num_res_blocks: 2 31 | channel_mult: 32 | - 1 33 | - 2 34 | - 3 35 | - 4 36 | num_head_channels: 32 37 | first_stage_config: 38 | target: ldm.models.autoencoder.VQModelInterface 39 | params: 40 | embed_dim: 3 41 | n_embed: 8192 42 | ckpt_path: models/first_stage_models/vq-f4/model.ckpt 43 | ddconfig: 44 | double_z: false 45 | z_channels: 3 46 | resolution: 256 47 | in_channels: 3 48 | out_ch: 3 49 | ch: 128 50 | ch_mult: 51 | - 1 52 | - 2 53 | - 4 54 | num_res_blocks: 2 55 | attn_resolutions: [] 56 | dropout: 0.0 57 | lossconfig: 58 | target: torch.nn.Identity 59 | cond_stage_config: __is_unconditional__ 60 | data: 61 | target: main.DataModuleFromConfig 62 | params: 63 | batch_size: 48 64 | num_workers: 5 65 | wrap: false 66 | train: 67 | target: taming.data.faceshq.CelebAHQTrain 68 | params: 69 | size: 256 70 | validation: 71 | target: taming.data.faceshq.CelebAHQValidation 72 | params: 73 | size: 256 74 | 75 | 76 | lightning: 77 | callbacks: 78 | image_logger: 79 | target: main.ImageLogger 80 | params: 81 | batch_frequency: 5000 82 | max_images: 8 83 | increase_log_steps: False 84 | 85 | trainer: 86 | benchmark: True -------------------------------------------------------------------------------- /configs/latent-diffusion/cin-ldm-vq-f8.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: class_label 12 | image_size: 32 13 | channels: 4 14 | cond_stage_trainable: true 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | unet_config: 18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 19 | params: 20 | image_size: 32 21 | in_channels: 4 22 | out_channels: 4 23 | model_channels: 256 24 | attention_resolutions: 25 | #note: this isn\t actually the resolution but 26 | # the downsampling factor, i.e. 
this corresnponds to 27 | # attention on spatial resolution 8,16,32, as the 28 | # spatial reolution of the latents is 32 for f8 29 | - 4 30 | - 2 31 | - 1 32 | num_res_blocks: 2 33 | channel_mult: 34 | - 1 35 | - 2 36 | - 4 37 | num_head_channels: 32 38 | use_spatial_transformer: true 39 | transformer_depth: 1 40 | context_dim: 512 41 | first_stage_config: 42 | target: ldm.models.autoencoder.VQModelInterface 43 | params: 44 | embed_dim: 4 45 | n_embed: 16384 46 | ckpt_path: configs/first_stage_models/vq-f8/model.yaml 47 | ddconfig: 48 | double_z: false 49 | z_channels: 4 50 | resolution: 256 51 | in_channels: 3 52 | out_ch: 3 53 | ch: 128 54 | ch_mult: 55 | - 1 56 | - 2 57 | - 2 58 | - 4 59 | num_res_blocks: 2 60 | attn_resolutions: 61 | - 32 62 | dropout: 0.0 63 | lossconfig: 64 | target: torch.nn.Identity 65 | cond_stage_config: 66 | target: ldm.modules.encoders.modules.ClassEmbedder 67 | params: 68 | embed_dim: 512 69 | key: class_label 70 | data: 71 | target: main.DataModuleFromConfig 72 | params: 73 | batch_size: 64 74 | num_workers: 12 75 | wrap: false 76 | train: 77 | target: ldm.data.imagenet.ImageNetTrain 78 | params: 79 | config: 80 | size: 256 81 | validation: 82 | target: ldm.data.imagenet.ImageNetValidation 83 | params: 84 | config: 85 | size: 256 86 | 87 | 88 | lightning: 89 | callbacks: 90 | image_logger: 91 | target: main.ImageLogger 92 | params: 93 | batch_frequency: 5000 94 | max_images: 8 95 | increase_log_steps: False 96 | 97 | trainer: 98 | benchmark: True -------------------------------------------------------------------------------- /configs/latent-diffusion/cin256-v2.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 0.0001 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: class_label 12 | image_size: 64 13 | channels: 3 14 | cond_stage_trainable: true 15 | conditioning_key: crossattn 16 | monitor: val/loss 17 | use_ema: False 18 | 19 | unet_config: 20 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 21 | params: 22 | image_size: 64 23 | in_channels: 3 24 | out_channels: 3 25 | model_channels: 192 26 | attention_resolutions: 27 | - 8 28 | - 4 29 | - 2 30 | num_res_blocks: 2 31 | channel_mult: 32 | - 1 33 | - 2 34 | - 3 35 | - 5 36 | num_heads: 1 37 | use_spatial_transformer: true 38 | transformer_depth: 1 39 | context_dim: 512 40 | 41 | first_stage_config: 42 | target: ldm.models.autoencoder.VQModelInterface 43 | params: 44 | embed_dim: 3 45 | n_embed: 8192 46 | ddconfig: 47 | double_z: false 48 | z_channels: 3 49 | resolution: 256 50 | in_channels: 3 51 | out_ch: 3 52 | ch: 128 53 | ch_mult: 54 | - 1 55 | - 2 56 | - 4 57 | num_res_blocks: 2 58 | attn_resolutions: [] 59 | dropout: 0.0 60 | lossconfig: 61 | target: torch.nn.Identity 62 | 63 | cond_stage_config: 64 | target: ldm.modules.encoders.modules.ClassEmbedder 65 | params: 66 | n_classes: 1001 67 | embed_dim: 512 68 | key: class_label 69 | -------------------------------------------------------------------------------- /configs/latent-diffusion/ffhq-ldm-vq-4.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 2.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | 
timesteps: 1000 10 | first_stage_key: image 11 | image_size: 64 12 | channels: 3 13 | monitor: val/loss_simple_ema 14 | unet_config: 15 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 16 | params: 17 | image_size: 64 18 | in_channels: 3 19 | out_channels: 3 20 | model_channels: 224 21 | attention_resolutions: 22 | # note: this isn\t actually the resolution but 23 | # the downsampling factor, i.e. this corresnponds to 24 | # attention on spatial resolution 8,16,32, as the 25 | # spatial reolution of the latents is 64 for f4 26 | - 8 27 | - 4 28 | - 2 29 | num_res_blocks: 2 30 | channel_mult: 31 | - 1 32 | - 2 33 | - 3 34 | - 4 35 | num_head_channels: 32 36 | first_stage_config: 37 | target: ldm.models.autoencoder.VQModelInterface 38 | params: 39 | embed_dim: 3 40 | n_embed: 8192 41 | ckpt_path: configs/first_stage_models/vq-f4/model.yaml 42 | ddconfig: 43 | double_z: false 44 | z_channels: 3 45 | resolution: 256 46 | in_channels: 3 47 | out_ch: 3 48 | ch: 128 49 | ch_mult: 50 | - 1 51 | - 2 52 | - 4 53 | num_res_blocks: 2 54 | attn_resolutions: [] 55 | dropout: 0.0 56 | lossconfig: 57 | target: torch.nn.Identity 58 | cond_stage_config: __is_unconditional__ 59 | data: 60 | target: main.DataModuleFromConfig 61 | params: 62 | batch_size: 42 63 | num_workers: 5 64 | wrap: false 65 | train: 66 | target: taming.data.faceshq.FFHQTrain 67 | params: 68 | size: 256 69 | validation: 70 | target: taming.data.faceshq.FFHQValidation 71 | params: 72 | size: 256 73 | 74 | 75 | lightning: 76 | callbacks: 77 | image_logger: 78 | target: main.ImageLogger 79 | params: 80 | batch_frequency: 5000 81 | max_images: 8 82 | increase_log_steps: False 83 | 84 | trainer: 85 | benchmark: True -------------------------------------------------------------------------------- /configs/latent-diffusion/lsun_bedrooms-ldm-vq-4.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 2.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | image_size: 64 12 | channels: 3 13 | monitor: val/loss_simple_ema 14 | unet_config: 15 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 16 | params: 17 | image_size: 64 18 | in_channels: 3 19 | out_channels: 3 20 | model_channels: 224 21 | attention_resolutions: 22 | # note: this isn\t actually the resolution but 23 | # the downsampling factor, i.e. 
this corresnponds to 24 | # attention on spatial resolution 8,16,32, as the 25 | # spatial reolution of the latents is 64 for f4 26 | - 8 27 | - 4 28 | - 2 29 | num_res_blocks: 2 30 | channel_mult: 31 | - 1 32 | - 2 33 | - 3 34 | - 4 35 | num_head_channels: 32 36 | first_stage_config: 37 | target: ldm.models.autoencoder.VQModelInterface 38 | params: 39 | ckpt_path: configs/first_stage_models/vq-f4/model.yaml 40 | embed_dim: 3 41 | n_embed: 8192 42 | ddconfig: 43 | double_z: false 44 | z_channels: 3 45 | resolution: 256 46 | in_channels: 3 47 | out_ch: 3 48 | ch: 128 49 | ch_mult: 50 | - 1 51 | - 2 52 | - 4 53 | num_res_blocks: 2 54 | attn_resolutions: [] 55 | dropout: 0.0 56 | lossconfig: 57 | target: torch.nn.Identity 58 | cond_stage_config: __is_unconditional__ 59 | data: 60 | target: main.DataModuleFromConfig 61 | params: 62 | batch_size: 48 63 | num_workers: 5 64 | wrap: false 65 | train: 66 | target: ldm.data.lsun.LSUNBedroomsTrain 67 | params: 68 | size: 256 69 | validation: 70 | target: ldm.data.lsun.LSUNBedroomsValidation 71 | params: 72 | size: 256 73 | 74 | 75 | lightning: 76 | callbacks: 77 | image_logger: 78 | target: main.ImageLogger 79 | params: 80 | batch_frequency: 5000 81 | max_images: 8 82 | increase_log_steps: False 83 | 84 | trainer: 85 | benchmark: True -------------------------------------------------------------------------------- /configs/latent-diffusion/lsun_churches-ldm-kl-8.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 5.0e-5 # set to target_lr by starting main.py with '--scale_lr False' 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0155 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | loss_type: l1 11 | first_stage_key: "image" 12 | cond_stage_key: "image" 13 | image_size: 32 14 | channels: 4 15 | cond_stage_trainable: False 16 | concat_mode: False 17 | scale_by_std: True 18 | monitor: 'val/loss_simple_ema' 19 | 20 | scheduler_config: # 10000 warmup steps 21 | target: ldm.lr_scheduler.LambdaLinearScheduler 22 | params: 23 | warm_up_steps: [10000] 24 | cycle_lengths: [10000000000000] 25 | f_start: [1.e-6] 26 | f_max: [1.] 27 | f_min: [ 1.] 
28 | 29 | unet_config: 30 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 31 | params: 32 | image_size: 32 33 | in_channels: 4 34 | out_channels: 4 35 | model_channels: 192 36 | attention_resolutions: [ 1, 2, 4, 8 ] # 32, 16, 8, 4 37 | num_res_blocks: 2 38 | channel_mult: [ 1,2,2,4,4 ] # 32, 16, 8, 4, 2 39 | num_heads: 8 40 | use_scale_shift_norm: True 41 | resblock_updown: True 42 | 43 | first_stage_config: 44 | target: ldm.models.autoencoder.AutoencoderKL 45 | params: 46 | embed_dim: 4 47 | monitor: "val/rec_loss" 48 | ckpt_path: "models/first_stage_models/kl-f8/model.ckpt" 49 | ddconfig: 50 | double_z: True 51 | z_channels: 4 52 | resolution: 256 53 | in_channels: 3 54 | out_ch: 3 55 | ch: 128 56 | ch_mult: [ 1,2,4,4 ] # num_down = len(ch_mult)-1 57 | num_res_blocks: 2 58 | attn_resolutions: [ ] 59 | dropout: 0.0 60 | lossconfig: 61 | target: torch.nn.Identity 62 | 63 | cond_stage_config: "__is_unconditional__" 64 | 65 | data: 66 | target: main.DataModuleFromConfig 67 | params: 68 | batch_size: 96 69 | num_workers: 5 70 | wrap: False 71 | train: 72 | target: ldm.data.lsun.LSUNChurchesTrain 73 | params: 74 | size: 256 75 | validation: 76 | target: ldm.data.lsun.LSUNChurchesValidation 77 | params: 78 | size: 256 79 | 80 | lightning: 81 | callbacks: 82 | image_logger: 83 | target: main.ImageLogger 84 | params: 85 | batch_frequency: 5000 86 | max_images: 8 87 | increase_log_steps: False 88 | 89 | 90 | trainer: 91 | benchmark: True -------------------------------------------------------------------------------- /configs/latent-diffusion/txt2img-1p4B-eval.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 5.0e-05 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.012 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: caption 12 | image_size: 32 13 | channels: 4 14 | cond_stage_trainable: true 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | scale_factor: 0.18215 18 | use_ema: False 19 | 20 | unet_config: 21 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 22 | params: 23 | image_size: 32 24 | in_channels: 4 25 | out_channels: 4 26 | model_channels: 320 27 | attention_resolutions: 28 | - 4 29 | - 2 30 | - 1 31 | num_res_blocks: 2 32 | channel_mult: 33 | - 1 34 | - 2 35 | - 4 36 | - 4 37 | num_heads: 8 38 | use_spatial_transformer: true 39 | transformer_depth: 1 40 | context_dim: 1280 41 | use_checkpoint: true 42 | legacy: False 43 | 44 | first_stage_config: 45 | target: ldm.models.autoencoder.AutoencoderKL 46 | params: 47 | embed_dim: 4 48 | monitor: val/rec_loss 49 | ddconfig: 50 | double_z: true 51 | z_channels: 4 52 | resolution: 256 53 | in_channels: 3 54 | out_ch: 3 55 | ch: 128 56 | ch_mult: 57 | - 1 58 | - 2 59 | - 4 60 | - 4 61 | num_res_blocks: 2 62 | attn_resolutions: [] 63 | dropout: 0.0 64 | lossconfig: 65 | target: torch.nn.Identity 66 | 67 | cond_stage_config: 68 | target: ldm.modules.encoders.modules.BERTEmbedder 69 | params: 70 | n_embed: 1280 71 | n_layer: 32 72 | -------------------------------------------------------------------------------- /configs/latent-diffusion/txt2img-1p4B-eval_with_tokens.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 5.0e-05 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.012 
7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: caption 12 | image_size: 32 13 | channels: 4 14 | cond_stage_trainable: true 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | scale_factor: 0.18215 18 | use_ema: False 19 | 20 | personalization_config: 21 | target: ldm.modules.embedding_manager.EmbeddingManager 22 | params: 23 | placeholder_strings: ["*"] 24 | initializer_words: [] 25 | 26 | unet_config: 27 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 28 | params: 29 | image_size: 32 30 | in_channels: 4 31 | out_channels: 4 32 | model_channels: 320 33 | attention_resolutions: 34 | - 4 35 | - 2 36 | - 1 37 | num_res_blocks: 2 38 | channel_mult: 39 | - 1 40 | - 2 41 | - 4 42 | - 4 43 | num_heads: 8 44 | use_spatial_transformer: true 45 | transformer_depth: 1 46 | context_dim: 1280 47 | use_checkpoint: true 48 | legacy: False 49 | 50 | first_stage_config: 51 | target: ldm.models.autoencoder.AutoencoderKL 52 | params: 53 | embed_dim: 4 54 | monitor: val/rec_loss 55 | ddconfig: 56 | double_z: true 57 | z_channels: 4 58 | resolution: 256 59 | in_channels: 3 60 | out_ch: 3 61 | ch: 128 62 | ch_mult: 63 | - 1 64 | - 2 65 | - 4 66 | - 4 67 | num_res_blocks: 2 68 | attn_resolutions: [] 69 | dropout: 0.0 70 | lossconfig: 71 | target: torch.nn.Identity 72 | 73 | cond_stage_config: 74 | target: ldm.modules.encoders.modules.BERTEmbedder 75 | params: 76 | n_embed: 1280 77 | n_layer: 32 78 | -------------------------------------------------------------------------------- /configs/latent-diffusion/txt2img-1p4B-finetune.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 5.0e-3 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.012 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: caption 12 | image_size: 32 13 | channels: 4 14 | cond_stage_trainable: true 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | scale_factor: 0.18215 18 | use_ema: False 19 | embedding_reg_weight: 0.0 20 | 21 | personalization_config: 22 | target: ldm.modules.embedding_manager.EmbeddingManager 23 | params: 24 | placeholder_strings: ["*"] 25 | initializer_words: ["sculpture"] 26 | per_image_tokens: false 27 | num_vectors_per_token: 1 28 | progressive_words: False 29 | 30 | unet_config: 31 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 32 | params: 33 | image_size: 32 34 | in_channels: 4 35 | out_channels: 4 36 | model_channels: 320 37 | attention_resolutions: 38 | - 4 39 | - 2 40 | - 1 41 | num_res_blocks: 2 42 | channel_mult: 43 | - 1 44 | - 2 45 | - 4 46 | - 4 47 | num_heads: 8 48 | use_spatial_transformer: true 49 | transformer_depth: 1 50 | context_dim: 1280 51 | use_checkpoint: true 52 | legacy: False 53 | 54 | first_stage_config: 55 | target: ldm.models.autoencoder.AutoencoderKL 56 | params: 57 | embed_dim: 4 58 | monitor: val/rec_loss 59 | ddconfig: 60 | double_z: true 61 | z_channels: 4 62 | resolution: 256 63 | in_channels: 3 64 | out_ch: 3 65 | ch: 128 66 | ch_mult: 67 | - 1 68 | - 2 69 | - 4 70 | - 4 71 | num_res_blocks: 2 72 | attn_resolutions: [] 73 | dropout: 0.0 74 | lossconfig: 75 | target: torch.nn.Identity 76 | 77 | cond_stage_config: 78 | target: ldm.modules.encoders.modules.BERTEmbedder 79 | params: 80 | n_embed: 1280 81 | n_layer: 32 82 | 83 | 84 | data: 85 | target: 
main.DataModuleFromConfig 86 | params: 87 | batch_size: 4 88 | num_workers: 2 89 | wrap: false 90 | train: 91 | target: ldm.data.personalized.PersonalizedBase 92 | params: 93 | size: 256 94 | set: train 95 | per_image_tokens: false 96 | repeats: 100 97 | validation: 98 | target: ldm.data.personalized.PersonalizedBase 99 | params: 100 | size: 256 101 | set: val 102 | per_image_tokens: false 103 | repeats: 10 104 | 105 | lightning: 106 | modelcheckpoint: 107 | params: 108 | every_n_train_steps: 500 109 | callbacks: 110 | image_logger: 111 | target: main.ImageLogger 112 | params: 113 | batch_frequency: 500 114 | max_images: 8 115 | increase_log_steps: False 116 | 117 | trainer: 118 | benchmark: True 119 | max_steps: 6100 -------------------------------------------------------------------------------- /configs/latent-diffusion/txt2img-1p4B-finetune_style.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 5.0e-3 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.012 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: caption 12 | image_size: 32 13 | channels: 4 14 | cond_stage_trainable: true 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | scale_factor: 0.18215 18 | use_ema: False 19 | embedding_reg_weight: 0.0 20 | 21 | personalization_config: 22 | target: ldm.modules.embedding_manager.EmbeddingManager 23 | params: 24 | placeholder_strings: ["*"] 25 | initializer_words: ["painting"] 26 | per_image_tokens: false 27 | num_vectors_per_token: 1 28 | 29 | unet_config: 30 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 31 | params: 32 | image_size: 32 33 | in_channels: 4 34 | out_channels: 4 35 | model_channels: 320 36 | attention_resolutions: 37 | - 4 38 | - 2 39 | - 1 40 | num_res_blocks: 2 41 | channel_mult: 42 | - 1 43 | - 2 44 | - 4 45 | - 4 46 | num_heads: 8 47 | use_spatial_transformer: true 48 | transformer_depth: 1 49 | context_dim: 1280 50 | use_checkpoint: true 51 | legacy: False 52 | 53 | first_stage_config: 54 | target: ldm.models.autoencoder.AutoencoderKL 55 | params: 56 | embed_dim: 4 57 | monitor: val/rec_loss 58 | ddconfig: 59 | double_z: true 60 | z_channels: 4 61 | resolution: 256 62 | in_channels: 3 63 | out_ch: 3 64 | ch: 128 65 | ch_mult: 66 | - 1 67 | - 2 68 | - 4 69 | - 4 70 | num_res_blocks: 2 71 | attn_resolutions: [] 72 | dropout: 0.0 73 | lossconfig: 74 | target: torch.nn.Identity 75 | 76 | cond_stage_config: 77 | target: ldm.modules.encoders.modules.BERTEmbedder 78 | params: 79 | n_embed: 1280 80 | n_layer: 32 81 | 82 | 83 | data: 84 | target: main.DataModuleFromConfig 85 | params: 86 | batch_size: 4 87 | num_workers: 4 88 | wrap: false 89 | train: 90 | target: ldm.data.personalized_style.PersonalizedBase 91 | params: 92 | size: 256 93 | set: train 94 | per_image_tokens: false 95 | repeats: 100 96 | validation: 97 | target: ldm.data.personalized_style.PersonalizedBase 98 | params: 99 | size: 256 100 | set: val 101 | per_image_tokens: false 102 | repeats: 10 103 | 104 | lightning: 105 | modelcheckpoint: 106 | params: 107 | every_n_train_steps: 500 108 | callbacks: 109 | image_logger: 110 | target: main.ImageLogger 111 | params: 112 | batch_frequency: 500 113 | max_images: 8 114 | increase_log_steps: False 115 | 116 | trainer: 117 | benchmark: True -------------------------------------------------------------------------------- 
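Every config above (and the stable-diffusion configs that follow) uses the same `target:`/`params:` convention: `target` names a Python class by import path and `params` are its constructor arguments; the `data:` sections point at `main.DataModuleFromConfig` in the same way. The sketch below is an illustrative, standalone resolver for that pattern, not code from this repository (a similar helper is expected in `ldm/util.py`); only the config path and the `model` key are taken from the files above, everything else is an assumption.

```python
# Illustrative sketch only -- not code from this repository.
# Shows how the `target:`/`params:` pattern in the YAML configs above is
# typically turned into live objects.
import importlib

from omegaconf import OmegaConf


def instantiate_from_config(config):
    """Import the class named by `target` and construct it with `params`."""
    module_name, class_name = config["target"].rsplit(".", 1)
    cls = getattr(importlib.import_module(module_name), class_name)
    return cls(**config.get("params", dict()))


if __name__ == "__main__":
    # Config path taken from the tree above; the resulting object would be a
    # LatentDiffusion model with its nested first-stage and conditioning stages.
    config = OmegaConf.load("configs/stable-diffusion/v1-finetune.yaml")
    model = instantiate_from_config(config.model)
    print(type(model).__name__)
```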
/configs/stable-diffusion/v1-finetune.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 5.0e-03 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.0120 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: caption 12 | image_size: 64 13 | channels: 4 14 | cond_stage_trainable: true # Note: different from the one we trained before 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | scale_factor: 0.18215 18 | use_ema: False 19 | embedding_reg_weight: 0.0 20 | unfreeze_model: False 21 | model_lr: 0.0 22 | 23 | personalization_config: 24 | target: ldm.modules.embedding_manager.EmbeddingManager 25 | params: 26 | placeholder_strings: ["*"] 27 | initializer_words: ["sculpture"] 28 | per_image_tokens: false 29 | num_vectors_per_token: 1 30 | progressive_words: False 31 | 32 | unet_config: 33 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 34 | params: 35 | image_size: 32 # unused 36 | in_channels: 4 37 | out_channels: 4 38 | model_channels: 320 39 | attention_resolutions: [ 4, 2, 1 ] 40 | num_res_blocks: 2 41 | channel_mult: [ 1, 2, 4, 4 ] 42 | num_heads: 8 43 | use_spatial_transformer: True 44 | transformer_depth: 1 45 | context_dim: 768 46 | use_checkpoint: True 47 | legacy: False 48 | 49 | first_stage_config: 50 | target: ldm.models.autoencoder.AutoencoderKL 51 | params: 52 | embed_dim: 4 53 | monitor: val/rec_loss 54 | ddconfig: 55 | double_z: true 56 | z_channels: 4 57 | resolution: 512 58 | in_channels: 3 59 | out_ch: 3 60 | ch: 128 61 | ch_mult: 62 | - 1 63 | - 2 64 | - 4 65 | - 4 66 | num_res_blocks: 2 67 | attn_resolutions: [] 68 | dropout: 0.0 69 | lossconfig: 70 | target: torch.nn.Identity 71 | 72 | cond_stage_config: 73 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder 74 | 75 | data: 76 | target: main.DataModuleFromConfig 77 | params: 78 | batch_size: 2 79 | num_workers: 2 80 | wrap: false 81 | train: 82 | target: ldm.data.personalized.PersonalizedBase 83 | params: 84 | size: 512 85 | set: train 86 | per_image_tokens: false 87 | repeats: 100 88 | validation: 89 | target: ldm.data.personalized.PersonalizedBase 90 | params: 91 | size: 512 92 | set: val 93 | per_image_tokens: false 94 | repeats: 10 95 | 96 | lightning: 97 | modelcheckpoint: 98 | params: 99 | every_n_train_steps: 500 100 | callbacks: 101 | image_logger: 102 | target: main.ImageLogger 103 | params: 104 | batch_frequency: 500 105 | max_images: 8 106 | increase_log_steps: False 107 | 108 | trainer: 109 | benchmark: True 110 | max_steps: 6100 -------------------------------------------------------------------------------- /configs/stable-diffusion/v1-finetune_unfrozen.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | reg_weight: 1.0 6 | linear_start: 0.00085 7 | linear_end: 0.0120 8 | num_timesteps_cond: 1 9 | log_every_t: 200 10 | timesteps: 1000 11 | first_stage_key: image 12 | cond_stage_key: caption 13 | image_size: 64 14 | channels: 4 15 | cond_stage_trainable: true # Note: different from the one we trained before 16 | conditioning_key: crossattn 17 | monitor: val/loss_simple_ema 18 | scale_factor: 0.18215 19 | use_ema: False 20 | embedding_reg_weight: 0.0 21 | unfreeze_model: True 22 | model_lr: 1.0e-6 23 | 24 | personalization_config: 25 | 
target: ldm.modules.embedding_manager.EmbeddingManager 26 | params: 27 | placeholder_strings: ["*"] 28 | initializer_words: ["sculpture"] 29 | per_image_tokens: false 30 | num_vectors_per_token: 1 31 | progressive_words: False 32 | 33 | unet_config: 34 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 35 | params: 36 | image_size: 32 # unused 37 | in_channels: 4 38 | out_channels: 4 39 | model_channels: 320 40 | attention_resolutions: [ 4, 2, 1 ] 41 | num_res_blocks: 2 42 | channel_mult: [ 1, 2, 4, 4 ] 43 | num_heads: 8 44 | use_spatial_transformer: True 45 | transformer_depth: 1 46 | context_dim: 768 47 | use_checkpoint: True 48 | legacy: False 49 | 50 | first_stage_config: 51 | target: ldm.models.autoencoder.AutoencoderKL 52 | params: 53 | embed_dim: 4 54 | monitor: val/rec_loss 55 | ddconfig: 56 | double_z: true 57 | z_channels: 4 58 | resolution: 512 59 | in_channels: 3 60 | out_ch: 3 61 | ch: 128 62 | ch_mult: 63 | - 1 64 | - 2 65 | - 4 66 | - 4 67 | num_res_blocks: 2 68 | attn_resolutions: [] 69 | dropout: 0.0 70 | lossconfig: 71 | target: torch.nn.Identity 72 | 73 | cond_stage_config: 74 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder 75 | 76 | data: 77 | target: main.DataModuleFromConfig 78 | params: 79 | batch_size: 1 80 | num_workers: 1 81 | wrap: false 82 | train: 83 | target: ldm.data.personalized_batch.PersonalizedBatchBase 84 | params: 85 | size: 512 86 | set: train 87 | repeats: 100 88 | validation: 89 | target: ldm.data.personalized.PersonalizedBase 90 | params: 91 | size: 512 92 | set: val 93 | repeats: 10 94 | 95 | lightning: 96 | modelcheckpoint: 97 | params: 98 | every_n_train_steps: 500 99 | callbacks: 100 | image_logger: 101 | target: main.ImageLogger 102 | params: 103 | batch_frequency: 500 104 | max_images: 8 105 | increase_log_steps: False 106 | 107 | trainer: 108 | benchmark: True 109 | max_steps: 3000 110 | -------------------------------------------------------------------------------- /configs/stable-diffusion/v1-inference.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-04 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.0120 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: "jpg" 11 | cond_stage_key: "txt" 12 | image_size: 64 13 | channels: 4 14 | cond_stage_trainable: false # Note: different from the one we trained before 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | scale_factor: 0.18215 18 | use_ema: False 19 | 20 | personalization_config: 21 | target: ldm.modules.embedding_manager.EmbeddingManager 22 | params: 23 | placeholder_strings: ["*"] 24 | initializer_words: ["sculpture"] 25 | per_image_tokens: false 26 | num_vectors_per_token: 1 27 | progressive_words: False 28 | 29 | unet_config: 30 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 31 | params: 32 | image_size: 32 # unused 33 | in_channels: 4 34 | out_channels: 4 35 | model_channels: 320 36 | attention_resolutions: [ 4, 2, 1 ] 37 | num_res_blocks: 2 38 | channel_mult: [ 1, 2, 4, 4 ] 39 | num_heads: 8 40 | use_spatial_transformer: True 41 | transformer_depth: 1 42 | context_dim: 768 43 | use_checkpoint: True 44 | legacy: False 45 | 46 | first_stage_config: 47 | target: ldm.models.autoencoder.AutoencoderKL 48 | params: 49 | embed_dim: 4 50 | monitor: val/rec_loss 51 | ddconfig: 52 | double_z: true 53 | z_channels: 4 54 | resolution: 256 55 | in_channels: 3 56 | 
out_ch: 3 57 | ch: 128 58 | ch_mult: 59 | - 1 60 | - 2 61 | - 4 62 | - 4 63 | num_res_blocks: 2 64 | attn_resolutions: [] 65 | dropout: 0.0 66 | lossconfig: 67 | target: torch.nn.Identity 68 | 69 | cond_stage_config: 70 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder -------------------------------------------------------------------------------- /data/DejaVuSans.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kanewallmann/Dreambooth-Stable-Diffusion/6f04eae2331153f841cab502385415ac39503144/data/DejaVuSans.ttf -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: ldm 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - python=3.8.10 7 | - pip=20.3 8 | - cudatoolkit=11.3 9 | - pytorch=1.10.2 10 | - torchvision=0.11.3 11 | - numpy=1.22.3 12 | - pip: 13 | - albumentations==1.1.0 14 | - opencv-python==4.2.0.34 15 | - pudb==2019.2 16 | - imageio==2.14.1 17 | - imageio-ffmpeg==0.4.7 18 | - pytorch-lightning==1.5.9 19 | - omegaconf==2.1.1 20 | - test-tube>=0.7.5 21 | - streamlit>=0.73.1 22 | - setuptools==59.5.0 23 | - pillow==9.0.1 24 | - einops==0.4.1 25 | - torch-fidelity==0.3.0 26 | - transformers==4.18.0 27 | - torchmetrics==0.6.0 28 | - kornia==0.6 29 | - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers 30 | - -e git+https://github.com/openai/CLIP.git@main#egg=clip 31 | - -e . 32 | -------------------------------------------------------------------------------- /evaluation/clip_eval.py: -------------------------------------------------------------------------------- 1 | import clip 2 | import torch 3 | from torchvision import transforms 4 | 5 | from ldm.models.diffusion.ddim import DDIMSampler 6 | 7 | class CLIPEvaluator(object): 8 | def __init__(self, device, clip_model='ViT-B/32') -> None: 9 | self.device = device 10 | self.model, clip_preprocess = clip.load(clip_model, device=self.device) 11 | 12 | self.clip_preprocess = clip_preprocess 13 | 14 | self.preprocess = transforms.Compose([transforms.Normalize(mean=[-1.0, -1.0, -1.0], std=[2.0, 2.0, 2.0])] + # Un-normalize from [-1.0, 1.0] (generator output) to [0, 1]. 
15 | clip_preprocess.transforms[:2] + # to match CLIP input scale assumptions 16 | clip_preprocess.transforms[4:]) # + skip convert PIL to tensor 17 | 18 | def tokenize(self, strings: list): 19 | return clip.tokenize(strings).to(self.device) 20 | 21 | @torch.no_grad() 22 | def encode_text(self, tokens: list) -> torch.Tensor: 23 | return self.model.encode_text(tokens) 24 | 25 | @torch.no_grad() 26 | def encode_images(self, images: torch.Tensor) -> torch.Tensor: 27 | images = self.preprocess(images).to(self.device) 28 | return self.model.encode_image(images) 29 | 30 | def get_text_features(self, text: str, norm: bool = True) -> torch.Tensor: 31 | 32 | tokens = clip.tokenize(text).to(self.device) 33 | 34 | text_features = self.encode_text(tokens).detach() 35 | 36 | if norm: 37 | text_features /= text_features.norm(dim=-1, keepdim=True) 38 | 39 | return text_features 40 | 41 | def get_image_features(self, img: torch.Tensor, norm: bool = True) -> torch.Tensor: 42 | image_features = self.encode_images(img) 43 | 44 | if norm: 45 | image_features /= image_features.clone().norm(dim=-1, keepdim=True) 46 | 47 | return image_features 48 | 49 | def img_to_img_similarity(self, src_images, generated_images): 50 | src_img_features = self.get_image_features(src_images) 51 | gen_img_features = self.get_image_features(generated_images) 52 | 53 | return (src_img_features @ gen_img_features.T).mean() 54 | 55 | def txt_to_img_similarity(self, text, generated_images): 56 | text_features = self.get_text_features(text) 57 | gen_img_features = self.get_image_features(generated_images) 58 | 59 | return (text_features @ gen_img_features.T).mean() 60 | 61 | 62 | class LDMCLIPEvaluator(CLIPEvaluator): 63 | def __init__(self, device, clip_model='ViT-B/32') -> None: 64 | super().__init__(device, clip_model) 65 | 66 | def evaluate(self, ldm_model, src_images, target_text, n_samples=64, n_steps=50): 67 | 68 | sampler = DDIMSampler(ldm_model) 69 | 70 | samples_per_batch = 8 71 | n_batches = n_samples // samples_per_batch 72 | 73 | # generate samples 74 | all_samples=list() 75 | with torch.no_grad(): 76 | with ldm_model.ema_scope(): 77 | uc = ldm_model.get_learned_conditioning(samples_per_batch * [""]) 78 | 79 | for batch in range(n_batches): 80 | c = ldm_model.get_learned_conditioning(samples_per_batch * [target_text]) 81 | shape = [4, 256//8, 256//8] 82 | samples_ddim, _ = sampler.sample(S=n_steps, 83 | conditioning=c, 84 | batch_size=samples_per_batch, 85 | shape=shape, 86 | verbose=False, 87 | unconditional_guidance_scale=5.0, 88 | unconditional_conditioning=uc, 89 | eta=0.0) 90 | 91 | x_samples_ddim = ldm_model.decode_first_stage(samples_ddim) 92 | x_samples_ddim = torch.clamp(x_samples_ddim, min=-1.0, max=1.0) 93 | 94 | all_samples.append(x_samples_ddim) 95 | 96 | all_samples = torch.cat(all_samples, axis=0) 97 | 98 | sim_samples_to_img = self.img_to_img_similarity(src_images, all_samples) 99 | sim_samples_to_text = self.txt_to_img_similarity(target_text.replace("*", ""), all_samples) 100 | 101 | return sim_samples_to_img, sim_samples_to_text 102 | 103 | 104 | class ImageDirEvaluator(CLIPEvaluator): 105 | def __init__(self, device, clip_model='ViT-B/32') -> None: 106 | super().__init__(device, clip_model) 107 | 108 | def evaluate(self, gen_samples, src_images, target_text): 109 | 110 | sim_samples_to_img = self.img_to_img_similarity(src_images, gen_samples) 111 | sim_samples_to_text = self.txt_to_img_similarity(target_text.replace("*", ""), gen_samples) 112 | 113 | return sim_samples_to_img, sim_samples_to_text 
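For reference, a hedged usage sketch of the evaluator defined above: the tensors and the prompt are placeholders, and the import path assumes the script is run from the repository root.

```python
# Usage sketch (placeholder data) for the CLIP evaluator defined above.
# Assumes it is run from the repository root so `evaluation.clip_eval` imports.
import torch

from evaluation.clip_eval import ImageDirEvaluator

device = "cuda" if torch.cuda.is_available() else "cpu"
evaluator = ImageDirEvaluator(device)

# Both batches are expected in [-1, 1], matching the un-normalization step in
# CLIPEvaluator.preprocess; random tensors stand in for real images here.
gen_samples = torch.rand(8, 3, 512, 512, device=device) * 2.0 - 1.0
src_images = torch.rand(4, 3, 512, 512, device=device) * 2.0 - 1.0

sim_to_img, sim_to_text = evaluator.evaluate(gen_samples, src_images,
                                             "photo of a sks container")
print(f"image similarity: {sim_to_img.item():.3f}, "
      f"text similarity: {sim_to_text.item():.3f}")
```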
-------------------------------------------------------------------------------- /gdrive: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kanewallmann/Dreambooth-Stable-Diffusion/6f04eae2331153f841cab502385415ac39503144/gdrive -------------------------------------------------------------------------------- /ldm/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kanewallmann/Dreambooth-Stable-Diffusion/6f04eae2331153f841cab502385415ac39503144/ldm/data/__init__.py -------------------------------------------------------------------------------- /ldm/data/base.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset 3 | 4 | 5 | class Txt2ImgIterableBaseDataset(IterableDataset): 6 | ''' 7 | Define an interface to make the IterableDatasets for text2img data chainable 8 | ''' 9 | def __init__(self, num_records=0, valid_ids=None, size=256): 10 | super().__init__() 11 | self.num_records = num_records 12 | self.valid_ids = valid_ids 13 | self.sample_ids = valid_ids 14 | self.size = size 15 | 16 | print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.') 17 | 18 | def __len__(self): 19 | return self.num_records 20 | 21 | @abstractmethod 22 | def __iter__(self): 23 | pass -------------------------------------------------------------------------------- /ldm/data/lsun.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import PIL 4 | from PIL import Image 5 | from torch.utils.data import Dataset 6 | from torchvision import transforms 7 | 8 | 9 | class LSUNBase(Dataset): 10 | def __init__(self, 11 | txt_file, 12 | data_root, 13 | size=None, 14 | interpolation="bicubic", 15 | flip_p=0.5 16 | ): 17 | self.data_paths = txt_file 18 | self.data_root = data_root 19 | with open(self.data_paths, "r") as f: 20 | self.image_paths = f.read().splitlines() 21 | self._length = len(self.image_paths) 22 | self.labels = { 23 | "relative_file_path_": [l for l in self.image_paths], 24 | "file_path_": [os.path.join(self.data_root, l) 25 | for l in self.image_paths], 26 | } 27 | 28 | self.size = size 29 | self.interpolation = {"linear": PIL.Image.LINEAR, 30 | "bilinear": PIL.Image.BILINEAR, 31 | "bicubic": PIL.Image.BICUBIC, 32 | "lanczos": PIL.Image.LANCZOS, 33 | }[interpolation] 34 | self.flip = transforms.RandomHorizontalFlip(p=flip_p) 35 | 36 | def __len__(self): 37 | return self._length 38 | 39 | def __getitem__(self, i): 40 | example = dict((k, self.labels[k][i]) for k in self.labels) 41 | image = Image.open(example["file_path_"]) 42 | if not image.mode == "RGB": 43 | image = image.convert("RGB") 44 | 45 | # default to score-sde preprocessing 46 | img = np.array(image).astype(np.uint8) 47 | crop = min(img.shape[0], img.shape[1]) 48 | h, w, = img.shape[0], img.shape[1] 49 | img = img[(h - crop) // 2:(h + crop) // 2, 50 | (w - crop) // 2:(w + crop) // 2] 51 | 52 | image = Image.fromarray(img) 53 | if self.size is not None: 54 | image = image.resize((self.size, self.size), resample=self.interpolation) 55 | 56 | image = self.flip(image) 57 | image = np.array(image).astype(np.uint8) 58 | example["image"] = (image / 127.5 - 1.0).astype(np.float32) 59 | return example 60 | 61 | 62 | class LSUNChurchesTrain(LSUNBase): 63 | def 
__init__(self, **kwargs): 64 | super().__init__(txt_file="data/lsun/church_outdoor_train.txt", data_root="data/lsun/churches", **kwargs) 65 | 66 | 67 | class LSUNChurchesValidation(LSUNBase): 68 | def __init__(self, flip_p=0., **kwargs): 69 | super().__init__(txt_file="data/lsun/church_outdoor_val.txt", data_root="data/lsun/churches", 70 | flip_p=flip_p, **kwargs) 71 | 72 | 73 | class LSUNBedroomsTrain(LSUNBase): 74 | def __init__(self, **kwargs): 75 | super().__init__(txt_file="data/lsun/bedrooms_train.txt", data_root="data/lsun/bedrooms", **kwargs) 76 | 77 | 78 | class LSUNBedroomsValidation(LSUNBase): 79 | def __init__(self, flip_p=0.0, **kwargs): 80 | super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms", 81 | flip_p=flip_p, **kwargs) 82 | 83 | 84 | class LSUNCatsTrain(LSUNBase): 85 | def __init__(self, **kwargs): 86 | super().__init__(txt_file="data/lsun/cat_train.txt", data_root="data/lsun/cats", **kwargs) 87 | 88 | 89 | class LSUNCatsValidation(LSUNBase): 90 | def __init__(self, flip_p=0., **kwargs): 91 | super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats", 92 | flip_p=flip_p, **kwargs) 93 | -------------------------------------------------------------------------------- /ldm/data/personalized.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import PIL 4 | from PIL import Image 5 | from torch.utils.data import Dataset 6 | from torchvision import transforms 7 | from pathlib import Path 8 | 9 | class PersonalizedBase(Dataset): 10 | def __init__(self, 11 | data_root, 12 | size=None, 13 | repeats=100, 14 | interpolation="bicubic", 15 | flip_p=0.0, 16 | set="train", 17 | center_crop=False, 18 | reg=False 19 | ): 20 | 21 | self.data_root = data_root 22 | 23 | self.image_paths = [] 24 | 25 | classes = os.listdir(self.data_root) 26 | 27 | for cl in classes: 28 | class_path = os.path.join(self.data_root, cl) 29 | for file_path in os.listdir(class_path): 30 | image_path = os.path.join(class_path, file_path) 31 | self.image_paths.append(image_path) 32 | 33 | # self._length = len(self.image_paths) 34 | self.num_images = len(self.image_paths) 35 | self._length = self.num_images 36 | 37 | self.center_crop = center_crop 38 | 39 | if set == "train": 40 | self._length = self.num_images * repeats 41 | 42 | self.size = size 43 | self.interpolation = {"linear": PIL.Image.LINEAR, 44 | "bilinear": PIL.Image.BILINEAR, 45 | "bicubic": PIL.Image.BICUBIC, 46 | "lanczos": PIL.Image.LANCZOS, 47 | }[interpolation] 48 | self.flip = transforms.RandomHorizontalFlip(p=flip_p) 49 | self.reg = reg 50 | 51 | def __len__(self): 52 | return self._length 53 | 54 | def __getitem__(self, i): 55 | example = {} 56 | 57 | image = Image.open(self.image_paths[i % self.num_images]) 58 | 59 | if not image.mode == "RGB": 60 | image = image.convert("RGB") 61 | 62 | pathname = Path(self.image_paths[i % self.num_images]).name 63 | 64 | parts = pathname.split("_") 65 | identifier = parts[0] 66 | 67 | example["caption"] = identifier 68 | 69 | # default to score-sde preprocessing 70 | img = np.array(image).astype(np.uint8) 71 | 72 | if self.center_crop: 73 | crop = min(img.shape[0], img.shape[1]) 74 | h, w, = img.shape[0], img.shape[1] 75 | img = img[(h - crop) // 2:(h + crop) // 2, 76 | (w - crop) // 2:(w + crop) // 2] 77 | 78 | image = Image.fromarray(img) 79 | if self.size is not None: 80 | image = image.resize((self.size, self.size), 81 | resample=self.interpolation) 82 | 83 | image = 
self.flip(image) 84 | image = np.array(image).astype(np.uint8) 85 | example["image"] = (image / 127.5 - 1.0).astype(np.float32) 86 | return example 87 | -------------------------------------------------------------------------------- /ldm/data/personalized_batch.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import PIL 4 | from PIL import Image 5 | from torch.utils.data import Dataset 6 | from torchvision import transforms 7 | from pathlib import Path 8 | 9 | class PersonalizedBatchBase(Dataset): 10 | def __init__(self, 11 | data_root, 12 | reg_data_root, 13 | size=None, 14 | repeats=100, 15 | interpolation="bicubic", 16 | flip_p=0.0, 17 | set="train", 18 | center_crop=False, 19 | reg=False 20 | ): 21 | 22 | self.data_root = data_root 23 | self.reg_data_root = reg_data_root 24 | 25 | self.image_paths = [] 26 | self.image_classes = [] 27 | 28 | classes = os.listdir(self.data_root) 29 | 30 | for cl in classes: 31 | class_path = os.path.join(self.data_root, cl) 32 | for file_path in os.listdir(class_path): 33 | image_path = os.path.join(class_path, file_path) 34 | self.image_paths.append(image_path) 35 | self.image_classes.append(cl) 36 | 37 | self.reg_image_paths = {} 38 | 39 | classes = os.listdir(self.reg_data_root) 40 | 41 | for cl in classes: 42 | self.reg_image_paths[cl] = [] 43 | class_path = os.path.join(self.reg_data_root, cl) 44 | for file_path in os.listdir(class_path): 45 | image_path = os.path.join(class_path, file_path) 46 | self.reg_image_paths[cl].append(image_path) 47 | 48 | # self._length = len(self.image_paths) 49 | self.num_images = len(self.image_paths) 50 | self._length = self.num_images 51 | 52 | self.center_crop = center_crop 53 | 54 | if set == "train": 55 | self._length = self.num_images * repeats 56 | 57 | self.size = size 58 | self.interpolation = {"linear": PIL.Image.LINEAR, 59 | "bilinear": PIL.Image.BILINEAR, 60 | "bicubic": PIL.Image.BICUBIC, 61 | "lanczos": PIL.Image.LANCZOS, 62 | }[interpolation] 63 | self.flip = transforms.RandomHorizontalFlip(p=flip_p) 64 | self.reg = reg 65 | 66 | def __len__(self): 67 | return self._length 68 | 69 | def __getitem__(self, i): 70 | idx = i % len(self.image_paths) 71 | example = self.get_image(self.image_paths[idx]) 72 | cl = self.image_classes[idx] 73 | example_reg = self.get_image(self.reg_image_paths[cl][i % len(self.reg_image_paths[cl])]) 74 | return tuple([example, example_reg]) 75 | 76 | def get_image(self, image_path): 77 | example = {} 78 | 79 | image = Image.open(image_path) 80 | 81 | if not image.mode == "RGB": 82 | image = image.convert("RGB") 83 | 84 | pathname = Path(image_path).name 85 | 86 | parts = pathname.split("_") 87 | identifier = parts[0] 88 | 89 | example["caption"] = identifier 90 | 91 | # default to score-sde preprocessing 92 | img = np.array(image).astype(np.uint8) 93 | 94 | if self.center_crop: 95 | crop = min(img.shape[0], img.shape[1]) 96 | h, w, = img.shape[0], img.shape[1] 97 | img = img[(h - crop) // 2:(h + crop) // 2, 98 | (w - crop) // 2:(w + crop) // 2] 99 | 100 | image = Image.fromarray(img) 101 | if self.size is not None: 102 | image = image.resize((self.size, self.size), 103 | resample=self.interpolation) 104 | 105 | image = self.flip(image) 106 | image = np.array(image).astype(np.uint8) 107 | example["image"] = (image / 127.5 - 1.0).astype(np.float32) 108 | return example 109 | -------------------------------------------------------------------------------- /ldm/data/personalized_style.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import PIL 4 | from PIL import Image 5 | from torch.utils.data import Dataset 6 | from torchvision import transforms 7 | 8 | import random 9 | 10 | imagenet_templates_small = [ 11 | 'a painting in the style of {}', 12 | 'a rendering in the style of {}', 13 | 'a cropped painting in the style of {}', 14 | 'the painting in the style of {}', 15 | 'a clean painting in the style of {}', 16 | 'a dirty painting in the style of {}', 17 | 'a dark painting in the style of {}', 18 | 'a picture in the style of {}', 19 | 'a cool painting in the style of {}', 20 | 'a close-up painting in the style of {}', 21 | 'a bright painting in the style of {}', 22 | 'a cropped painting in the style of {}', 23 | 'a good painting in the style of {}', 24 | 'a close-up painting in the style of {}', 25 | 'a rendition in the style of {}', 26 | 'a nice painting in the style of {}', 27 | 'a small painting in the style of {}', 28 | 'a weird painting in the style of {}', 29 | 'a large painting in the style of {}', 30 | ] 31 | 32 | imagenet_dual_templates_small = [ 33 | 'a painting in the style of {} with {}', 34 | 'a rendering in the style of {} with {}', 35 | 'a cropped painting in the style of {} with {}', 36 | 'the painting in the style of {} with {}', 37 | 'a clean painting in the style of {} with {}', 38 | 'a dirty painting in the style of {} with {}', 39 | 'a dark painting in the style of {} with {}', 40 | 'a cool painting in the style of {} with {}', 41 | 'a close-up painting in the style of {} with {}', 42 | 'a bright painting in the style of {} with {}', 43 | 'a cropped painting in the style of {} with {}', 44 | 'a good painting in the style of {} with {}', 45 | 'a painting of one {} in the style of {}', 46 | 'a nice painting in the style of {} with {}', 47 | 'a small painting in the style of {} with {}', 48 | 'a weird painting in the style of {} with {}', 49 | 'a large painting in the style of {} with {}', 50 | ] 51 | 52 | per_img_token_list = [ 53 | 'א', 'ב', 'ג', 'ד', 'ה', 'ו', 'ז', 'ח', 'ט', 'י', 'כ', 'ל', 'מ', 'נ', 'ס', 'ע', 'פ', 'צ', 'ק', 'ר', 'ש', 'ת', 54 | ] 55 | 56 | class PersonalizedBase(Dataset): 57 | def __init__(self, 58 | data_root, 59 | size=None, 60 | repeats=100, 61 | interpolation="bicubic", 62 | flip_p=0.5, 63 | set="train", 64 | placeholder_token="*", 65 | per_image_tokens=False, 66 | center_crop=False, 67 | ): 68 | 69 | self.data_root = data_root 70 | 71 | self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)] 72 | 73 | # self._length = len(self.image_paths) 74 | self.num_images = len(self.image_paths) 75 | self._length = self.num_images 76 | 77 | self.placeholder_token = placeholder_token 78 | 79 | self.per_image_tokens = per_image_tokens 80 | self.center_crop = center_crop 81 | 82 | if per_image_tokens: 83 | assert self.num_images < len(per_img_token_list), f"Can't use per-image tokens when the training set contains more than {len(per_img_token_list)} tokens. To enable larger sets, add more tokens to 'per_img_token_list'." 
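# Illustrative note on per-image tokens (the example caption is made up; the mechanism is
# implemented in __getitem__ below): when per_image_tokens is enabled, each training image
# is paired with its own entry from per_img_token_list above, and roughly a quarter of the
# captions use imagenet_dual_templates_small, e.g. with placeholder_token="*" an image
# might be captioned
#     'a painting in the style of * with ג'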
84 | 85 | if set == "train": 86 | self._length = self.num_images * repeats 87 | 88 | self.size = size 89 | self.interpolation = {"linear": PIL.Image.LINEAR, 90 | "bilinear": PIL.Image.BILINEAR, 91 | "bicubic": PIL.Image.BICUBIC, 92 | "lanczos": PIL.Image.LANCZOS, 93 | }[interpolation] 94 | self.flip = transforms.RandomHorizontalFlip(p=flip_p) 95 | 96 | def __len__(self): 97 | return self._length 98 | 99 | def __getitem__(self, i): 100 | example = {} 101 | image = Image.open(self.image_paths[i % self.num_images]) 102 | 103 | if not image.mode == "RGB": 104 | image = image.convert("RGB") 105 | 106 | if self.per_image_tokens and np.random.uniform() < 0.25: 107 | text = random.choice(imagenet_dual_templates_small).format(self.placeholder_token, per_img_token_list[i % self.num_images]) 108 | else: 109 | text = random.choice(imagenet_templates_small).format(self.placeholder_token) 110 | 111 | example["caption"] = text 112 | 113 | # default to score-sde preprocessing 114 | img = np.array(image).astype(np.uint8) 115 | 116 | if self.center_crop: 117 | crop = min(img.shape[0], img.shape[1]) 118 | h, w, = img.shape[0], img.shape[1] 119 | img = img[(h - crop) // 2:(h + crop) // 2, 120 | (w - crop) // 2:(w + crop) // 2] 121 | 122 | image = Image.fromarray(img) 123 | if self.size is not None: 124 | image = image.resize((self.size, self.size), resample=self.interpolation) 125 | 126 | image = self.flip(image) 127 | image = np.array(image).astype(np.uint8) 128 | example["image"] = (image / 127.5 - 1.0).astype(np.float32) 129 | return example -------------------------------------------------------------------------------- /ldm/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class LambdaWarmUpCosineScheduler: 5 | """ 6 | note: use with a base_lr of 1.0 7 | """ 8 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): 9 | self.lr_warm_up_steps = warm_up_steps 10 | self.lr_start = lr_start 11 | self.lr_min = lr_min 12 | self.lr_max = lr_max 13 | self.lr_max_decay_steps = max_decay_steps 14 | self.last_lr = 0. 15 | self.verbosity_interval = verbosity_interval 16 | 17 | def schedule(self, n, **kwargs): 18 | if self.verbosity_interval > 0: 19 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") 20 | if n < self.lr_warm_up_steps: 21 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start 22 | self.last_lr = lr 23 | return lr 24 | else: 25 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) 26 | t = min(t, 1.0) 27 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( 28 | 1 + np.cos(t * np.pi)) 29 | self.last_lr = lr 30 | return lr 31 | 32 | def __call__(self, n, **kwargs): 33 | return self.schedule(n,**kwargs) 34 | 35 | 36 | class LambdaWarmUpCosineScheduler2: 37 | """ 38 | supports repeated iterations, configurable via lists 39 | note: use with a base_lr of 1.0. 40 | """ 41 | def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0): 42 | assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths) 43 | self.lr_warm_up_steps = warm_up_steps 44 | self.f_start = f_start 45 | self.f_min = f_min 46 | self.f_max = f_max 47 | self.cycle_lengths = cycle_lengths 48 | self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) 49 | self.last_f = 0. 
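# Usage sketch (parameter values here are illustrative, not taken from a config): these
# scheduler objects are handed to torch.optim.lr_scheduler.LambdaLR as the lr_lambda, as
# configure_optimizers in ldm/models/diffusion/classifier.py does, e.g.
#     sched = LambdaWarmUpCosineScheduler2(warm_up_steps=[1000], f_min=[0.0], f_max=[1.0],
#                                          f_start=[1e-6], cycle_lengths=[10_000_000])
#     lr_sched = LambdaLR(optimizer, lr_lambda=sched.schedule)
# The schedule returns the full multiplier applied to the optimizer's lr, hence the
# "use with a base_lr of 1.0" notes in the docstrings above.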
50 | self.verbosity_interval = verbosity_interval 51 | 52 | def find_in_interval(self, n): 53 | interval = 0 54 | for cl in self.cum_cycles[1:]: 55 | if n <= cl: 56 | return interval 57 | interval += 1 58 | 59 | def schedule(self, n, **kwargs): 60 | cycle = self.find_in_interval(n) 61 | n = n - self.cum_cycles[cycle] 62 | if self.verbosity_interval > 0: 63 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 64 | f"current cycle {cycle}") 65 | if n < self.lr_warm_up_steps[cycle]: 66 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 67 | self.last_f = f 68 | return f 69 | else: 70 | t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) 71 | t = min(t, 1.0) 72 | f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( 73 | 1 + np.cos(t * np.pi)) 74 | self.last_f = f 75 | return f 76 | 77 | def __call__(self, n, **kwargs): 78 | return self.schedule(n, **kwargs) 79 | 80 | 81 | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): 82 | 83 | def schedule(self, n, **kwargs): 84 | cycle = self.find_in_interval(n) 85 | n = n - self.cum_cycles[cycle] 86 | if self.verbosity_interval > 0: 87 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 88 | f"current cycle {cycle}") 89 | 90 | if n < self.lr_warm_up_steps[cycle]: 91 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 92 | self.last_f = f 93 | return f 94 | else: 95 | f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle]) 96 | self.last_f = f 97 | return f 98 | 99 | -------------------------------------------------------------------------------- /ldm/models/diffusion/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kanewallmann/Dreambooth-Stable-Diffusion/6f04eae2331153f841cab502385415ac39503144/ldm/models/diffusion/__init__.py -------------------------------------------------------------------------------- /ldm/models/diffusion/classifier.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import pytorch_lightning as pl 4 | from omegaconf import OmegaConf 5 | from torch.nn import functional as F 6 | from torch.optim import AdamW 7 | from torch.optim.lr_scheduler import LambdaLR 8 | from copy import deepcopy 9 | from einops import rearrange 10 | from glob import glob 11 | from natsort import natsorted 12 | 13 | from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel 14 | from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config 15 | 16 | __models__ = { 17 | 'class_label': EncoderUNetModel, 18 | 'segmentation': UNetModel 19 | } 20 | 21 | 22 | def disabled_train(self, mode=True): 23 | """Overwrite model.train with this function to make sure train/eval mode 24 | does not change anymore.""" 25 | return self 26 | 27 | 28 | class NoisyLatentImageClassifier(pl.LightningModule): 29 | 30 | def __init__(self, 31 | diffusion_path, 32 | num_classes, 33 | ckpt_path=None, 34 | pool='attention', 35 | label_key=None, 36 | diffusion_ckpt_path=None, 37 | scheduler_config=None, 38 | weight_decay=1.e-2, 39 | log_steps=10, 40 | monitor='val/loss', 41 | *args, 42 | **kwargs): 43 | super().__init__(*args, **kwargs) 44 | 
self.num_classes = num_classes 45 | # get latest config of diffusion model 46 | diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1] 47 | self.diffusion_config = OmegaConf.load(diffusion_config).model 48 | self.diffusion_config.params.ckpt_path = diffusion_ckpt_path 49 | self.load_diffusion() 50 | 51 | self.monitor = monitor 52 | self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1 53 | self.log_time_interval = self.diffusion_model.num_timesteps // log_steps 54 | self.log_steps = log_steps 55 | 56 | self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \ 57 | else self.diffusion_model.cond_stage_key 58 | 59 | assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params' 60 | 61 | if self.label_key not in __models__: 62 | raise NotImplementedError() 63 | 64 | self.load_classifier(ckpt_path, pool) 65 | 66 | self.scheduler_config = scheduler_config 67 | self.use_scheduler = self.scheduler_config is not None 68 | self.weight_decay = weight_decay 69 | 70 | def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): 71 | sd = torch.load(path, map_location="cpu") 72 | if "state_dict" in list(sd.keys()): 73 | sd = sd["state_dict"] 74 | keys = list(sd.keys()) 75 | for k in keys: 76 | for ik in ignore_keys: 77 | if k.startswith(ik): 78 | print("Deleting key {} from state_dict.".format(k)) 79 | del sd[k] 80 | missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( 81 | sd, strict=False) 82 | print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") 83 | if len(missing) > 0: 84 | print(f"Missing Keys: {missing}") 85 | if len(unexpected) > 0: 86 | print(f"Unexpected Keys: {unexpected}") 87 | 88 | def load_diffusion(self): 89 | model = instantiate_from_config(self.diffusion_config) 90 | self.diffusion_model = model.eval() 91 | self.diffusion_model.train = disabled_train 92 | for param in self.diffusion_model.parameters(): 93 | param.requires_grad = False 94 | 95 | def load_classifier(self, ckpt_path, pool): 96 | model_config = deepcopy(self.diffusion_config.params.unet_config.params) 97 | model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels 98 | model_config.out_channels = self.num_classes 99 | if self.label_key == 'class_label': 100 | model_config.pool = pool 101 | 102 | self.model = __models__[self.label_key](**model_config) 103 | if ckpt_path is not None: 104 | print('#####################################################################') 105 | print(f'load from ckpt "{ckpt_path}"') 106 | print('#####################################################################') 107 | self.init_from_ckpt(ckpt_path) 108 | 109 | @torch.no_grad() 110 | def get_x_noisy(self, x, t, noise=None): 111 | noise = default(noise, lambda: torch.randn_like(x)) 112 | continuous_sqrt_alpha_cumprod = None 113 | if self.diffusion_model.use_continuous_noise: 114 | continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1) 115 | # todo: make sure t+1 is correct here 116 | 117 | return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise, 118 | continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod) 119 | 120 | def forward(self, x_noisy, t, *args, **kwargs): 121 | return self.model(x_noisy, t) 122 | 123 | @torch.no_grad() 124 | def get_input(self, batch, k): 125 | x = batch[k] 126 | if len(x.shape) == 3: 127 | x = 
x[..., None] 128 | x = rearrange(x, 'b h w c -> b c h w') 129 | x = x.to(memory_format=torch.contiguous_format).float() 130 | return x 131 | 132 | @torch.no_grad() 133 | def get_conditioning(self, batch, k=None): 134 | if k is None: 135 | k = self.label_key 136 | assert k is not None, 'Needs to provide label key' 137 | 138 | targets = batch[k].to(self.device) 139 | 140 | if self.label_key == 'segmentation': 141 | targets = rearrange(targets, 'b h w c -> b c h w') 142 | for down in range(self.numd): 143 | h, w = targets.shape[-2:] 144 | targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest') 145 | 146 | # targets = rearrange(targets,'b c h w -> b h w c') 147 | 148 | return targets 149 | 150 | def compute_top_k(self, logits, labels, k, reduction="mean"): 151 | _, top_ks = torch.topk(logits, k, dim=1) 152 | if reduction == "mean": 153 | return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item() 154 | elif reduction == "none": 155 | return (top_ks == labels[:, None]).float().sum(dim=-1) 156 | 157 | def on_train_epoch_start(self): 158 | # save some memory 159 | self.diffusion_model.model.to('cpu') 160 | 161 | @torch.no_grad() 162 | def write_logs(self, loss, logits, targets): 163 | log_prefix = 'train' if self.training else 'val' 164 | log = {} 165 | log[f"{log_prefix}/loss"] = loss.mean() 166 | log[f"{log_prefix}/acc@1"] = self.compute_top_k( 167 | logits, targets, k=1, reduction="mean" 168 | ) 169 | log[f"{log_prefix}/acc@5"] = self.compute_top_k( 170 | logits, targets, k=5, reduction="mean" 171 | ) 172 | 173 | self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True) 174 | self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False) 175 | self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True) 176 | lr = self.optimizers().param_groups[0]['lr'] 177 | self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True) 178 | 179 | def shared_step(self, batch, t=None): 180 | x, *_ = self.diffusion_model.get_input(batch, k=self.diffusion_model.first_stage_key) 181 | targets = self.get_conditioning(batch) 182 | if targets.dim() == 4: 183 | targets = targets.argmax(dim=1) 184 | if t is None: 185 | t = torch.randint(0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device).long() 186 | else: 187 | t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long() 188 | x_noisy = self.get_x_noisy(x, t) 189 | logits = self(x_noisy, t) 190 | 191 | loss = F.cross_entropy(logits, targets, reduction='none') 192 | 193 | self.write_logs(loss.detach(), logits.detach(), targets.detach()) 194 | 195 | loss = loss.mean() 196 | return loss, logits, x_noisy, targets 197 | 198 | def training_step(self, batch, batch_idx): 199 | loss, *_ = self.shared_step(batch) 200 | return loss 201 | 202 | def reset_noise_accs(self): 203 | self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in 204 | range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)} 205 | 206 | def on_validation_start(self): 207 | self.reset_noise_accs() 208 | 209 | @torch.no_grad() 210 | def validation_step(self, batch, batch_idx): 211 | loss, *_ = self.shared_step(batch) 212 | 213 | for t in self.noisy_acc: 214 | _, logits, _, targets = self.shared_step(batch, t) 215 | self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean')) 216 | self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean')) 217 | 218 | return loss 219 | 220 | def 
configure_optimizers(self): 221 | optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay) 222 | 223 | if self.use_scheduler: 224 | scheduler = instantiate_from_config(self.scheduler_config) 225 | 226 | print("Setting up LambdaLR scheduler...") 227 | scheduler = [ 228 | { 229 | 'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule), 230 | 'interval': 'step', 231 | 'frequency': 1 232 | }] 233 | return [optimizer], scheduler 234 | 235 | return optimizer 236 | 237 | @torch.no_grad() 238 | def log_images(self, batch, N=8, *args, **kwargs): 239 | log = dict() 240 | x = self.get_input(batch, self.diffusion_model.first_stage_key) 241 | log['inputs'] = x 242 | 243 | y = self.get_conditioning(batch) 244 | 245 | if self.label_key == 'class_label': 246 | y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"]) 247 | log['labels'] = y 248 | 249 | if ismap(y): 250 | log['labels'] = self.diffusion_model.to_rgb(y) 251 | 252 | for step in range(self.log_steps): 253 | current_time = step * self.log_time_interval 254 | 255 | _, logits, x_noisy, _ = self.shared_step(batch, t=current_time) 256 | 257 | log[f'inputs@t{current_time}'] = x_noisy 258 | 259 | pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes) 260 | pred = rearrange(pred, 'b h w c -> b c h w') 261 | 262 | log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred) 263 | 264 | for key in log: 265 | log[key] = log[key][:N] 266 | 267 | return log 268 | -------------------------------------------------------------------------------- /ldm/modules/attention.py: -------------------------------------------------------------------------------- 1 | from inspect import isfunction 2 | import math 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn, einsum 6 | from einops import rearrange, repeat 7 | 8 | from ldm.modules.diffusionmodules.util import checkpoint 9 | 10 | 11 | def exists(val): 12 | return val is not None 13 | 14 | 15 | def uniq(arr): 16 | return{el: True for el in arr}.keys() 17 | 18 | 19 | def default(val, d): 20 | if exists(val): 21 | return val 22 | return d() if isfunction(d) else d 23 | 24 | 25 | def max_neg_value(t): 26 | return -torch.finfo(t.dtype).max 27 | 28 | 29 | def init_(tensor): 30 | dim = tensor.shape[-1] 31 | std = 1 / math.sqrt(dim) 32 | tensor.uniform_(-std, std) 33 | return tensor 34 | 35 | 36 | # feedforward 37 | class GEGLU(nn.Module): 38 | def __init__(self, dim_in, dim_out): 39 | super().__init__() 40 | self.proj = nn.Linear(dim_in, dim_out * 2) 41 | 42 | def forward(self, x): 43 | x, gate = self.proj(x).chunk(2, dim=-1) 44 | return x * F.gelu(gate) 45 | 46 | 47 | class FeedForward(nn.Module): 48 | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): 49 | super().__init__() 50 | inner_dim = int(dim * mult) 51 | dim_out = default(dim_out, dim) 52 | project_in = nn.Sequential( 53 | nn.Linear(dim, inner_dim), 54 | nn.GELU() 55 | ) if not glu else GEGLU(dim, inner_dim) 56 | 57 | self.net = nn.Sequential( 58 | project_in, 59 | nn.Dropout(dropout), 60 | nn.Linear(inner_dim, dim_out) 61 | ) 62 | 63 | def forward(self, x): 64 | return self.net(x) 65 | 66 | 67 | def zero_module(module): 68 | """ 69 | Zero out the parameters of a module and return it. 
70 | """ 71 | for p in module.parameters(): 72 | p.detach().zero_() 73 | return module 74 | 75 | 76 | def Normalize(in_channels): 77 | return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) 78 | 79 | 80 | class LinearAttention(nn.Module): 81 | def __init__(self, dim, heads=4, dim_head=32): 82 | super().__init__() 83 | self.heads = heads 84 | hidden_dim = dim_head * heads 85 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) 86 | self.to_out = nn.Conv2d(hidden_dim, dim, 1) 87 | 88 | def forward(self, x): 89 | b, c, h, w = x.shape 90 | qkv = self.to_qkv(x) 91 | q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3) 92 | k = k.softmax(dim=-1) 93 | context = torch.einsum('bhdn,bhen->bhde', k, v) 94 | out = torch.einsum('bhde,bhdn->bhen', context, q) 95 | out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w) 96 | return self.to_out(out) 97 | 98 | 99 | class SpatialSelfAttention(nn.Module): 100 | def __init__(self, in_channels): 101 | super().__init__() 102 | self.in_channels = in_channels 103 | 104 | self.norm = Normalize(in_channels) 105 | self.q = torch.nn.Conv2d(in_channels, 106 | in_channels, 107 | kernel_size=1, 108 | stride=1, 109 | padding=0) 110 | self.k = torch.nn.Conv2d(in_channels, 111 | in_channels, 112 | kernel_size=1, 113 | stride=1, 114 | padding=0) 115 | self.v = torch.nn.Conv2d(in_channels, 116 | in_channels, 117 | kernel_size=1, 118 | stride=1, 119 | padding=0) 120 | self.proj_out = torch.nn.Conv2d(in_channels, 121 | in_channels, 122 | kernel_size=1, 123 | stride=1, 124 | padding=0) 125 | 126 | def forward(self, x): 127 | h_ = x 128 | h_ = self.norm(h_) 129 | q = self.q(h_) 130 | k = self.k(h_) 131 | v = self.v(h_) 132 | 133 | # compute attention 134 | b,c,h,w = q.shape 135 | q = rearrange(q, 'b c h w -> b (h w) c') 136 | k = rearrange(k, 'b c h w -> b c (h w)') 137 | w_ = torch.einsum('bij,bjk->bik', q, k) 138 | 139 | w_ = w_ * (int(c)**(-0.5)) 140 | w_ = torch.nn.functional.softmax(w_, dim=2) 141 | 142 | # attend to values 143 | v = rearrange(v, 'b c h w -> b c (h w)') 144 | w_ = rearrange(w_, 'b i j -> b j i') 145 | h_ = torch.einsum('bij,bjk->bik', v, w_) 146 | h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h) 147 | h_ = self.proj_out(h_) 148 | 149 | return x+h_ 150 | 151 | 152 | class CrossAttention(nn.Module): 153 | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): 154 | super().__init__() 155 | inner_dim = dim_head * heads 156 | context_dim = default(context_dim, query_dim) 157 | 158 | self.scale = dim_head ** -0.5 159 | self.heads = heads 160 | 161 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False) 162 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False) 163 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False) 164 | 165 | self.to_out = nn.Sequential( 166 | nn.Linear(inner_dim, query_dim), 167 | nn.Dropout(dropout) 168 | ) 169 | 170 | def forward(self, x, context=None, mask=None): 171 | h = self.heads 172 | 173 | q = self.to_q(x) 174 | context = default(context, x) 175 | k = self.to_k(context) 176 | v = self.to_v(context) 177 | del context, x 178 | 179 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) 180 | 181 | r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device) 182 | 183 | # valid values for steps = 2,4,8,16,32,64 184 | # higher steps is slower but less memory usage 185 | # at 16 can run 1920x1536 on a 3090, at 64 can run over 1920x1920 186 | # 
speed seems to be impacted more on 30x series cards 187 | steps = 16 188 | slice_size = q.shape[1] // steps if q.shape[1] % steps == 0 else q.shape[1] 189 | for i in range(0, q.shape[1], slice_size): 190 | end = i + slice_size 191 | s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) 192 | s1 *= self.scale 193 | s2 = s1.softmax(dim=-1) 194 | del s1 195 | r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v) 196 | del s2 197 | r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h) 198 | del r1 199 | 200 | return self.to_out(r2) 201 | 202 | 203 | class BasicTransformerBlock(nn.Module): 204 | def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True): 205 | super().__init__() 206 | self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention 207 | self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) 208 | self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, 209 | heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none 210 | self.norm1 = nn.LayerNorm(dim) 211 | self.norm2 = nn.LayerNorm(dim) 212 | self.norm3 = nn.LayerNorm(dim) 213 | self.checkpoint = checkpoint 214 | 215 | def forward(self, x, context=None): 216 | return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) 217 | 218 | def _forward(self, x, context=None): 219 | x = self.attn1(self.norm1(x)) + x 220 | x = self.attn2(self.norm2(x), context=context) + x 221 | x = self.ff(self.norm3(x)) + x 222 | return x 223 | 224 | 225 | class SpatialTransformer(nn.Module): 226 | """ 227 | Transformer block for image-like data. 228 | First, project the input (aka embedding) 229 | and reshape to b, t, d. 230 | Then apply standard transformer action. 
231 | Finally, reshape to image 232 | """ 233 | def __init__(self, in_channels, n_heads, d_head, 234 | depth=1, dropout=0., context_dim=None): 235 | super().__init__() 236 | self.in_channels = in_channels 237 | inner_dim = n_heads * d_head 238 | self.norm = Normalize(in_channels) 239 | 240 | self.proj_in = nn.Conv2d(in_channels, 241 | inner_dim, 242 | kernel_size=1, 243 | stride=1, 244 | padding=0) 245 | 246 | self.transformer_blocks = nn.ModuleList( 247 | [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim) 248 | for d in range(depth)] 249 | ) 250 | 251 | self.proj_out = zero_module(nn.Conv2d(inner_dim, 252 | in_channels, 253 | kernel_size=1, 254 | stride=1, 255 | padding=0)) 256 | 257 | def forward(self, x, context=None): 258 | # note: if no context is given, cross-attention defaults to self-attention 259 | b, c, h, w = x.shape 260 | x_in = x 261 | x = self.norm(x) 262 | x = self.proj_in(x) 263 | x = rearrange(x, 'b c h w -> b (h w) c') 264 | for block in self.transformer_blocks: 265 | x = block(x, context=context) 266 | x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) 267 | x = self.proj_out(x) 268 | return x + x_in -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kanewallmann/Dreambooth-Stable-Diffusion/6f04eae2331153f841cab502385415ac39503144/ldm/modules/diffusionmodules/__init__.py -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/util.py: -------------------------------------------------------------------------------- 1 | # adopted from 2 | # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 3 | # and 4 | # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py 5 | # and 6 | # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py 7 | # 8 | # thanks! 
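# Illustrative usage of the helpers defined below (argument values are just examples):
#     betas = make_beta_schedule("linear", n_timestep=1000)      # (1000,) float64 numpy array
#     ddim_steps = make_ddim_timesteps("uniform", 50, 1000)      # 50 timesteps, spaced 1000 // 50 apart
#     emb = timestep_embedding(torch.tensor([0, 10]), dim=320)   # [2, 320] sinusoidal embeddings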
9 | 10 | 11 | import os 12 | import math 13 | import torch 14 | import torch.nn as nn 15 | import numpy as np 16 | from einops import repeat 17 | 18 | from ldm.util import instantiate_from_config 19 | 20 | 21 | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): 22 | if schedule == "linear": 23 | betas = ( 24 | torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 25 | ) 26 | 27 | elif schedule == "cosine": 28 | timesteps = ( 29 | torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s 30 | ) 31 | alphas = timesteps / (1 + cosine_s) * np.pi / 2 32 | alphas = torch.cos(alphas).pow(2) 33 | alphas = alphas / alphas[0] 34 | betas = 1 - alphas[1:] / alphas[:-1] 35 | betas = np.clip(betas, a_min=0, a_max=0.999) 36 | 37 | elif schedule == "sqrt_linear": 38 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) 39 | elif schedule == "sqrt": 40 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 41 | else: 42 | raise ValueError(f"schedule '{schedule}' unknown.") 43 | return betas.numpy() 44 | 45 | 46 | def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): 47 | if ddim_discr_method == 'uniform': 48 | c = num_ddpm_timesteps // num_ddim_timesteps 49 | ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) 50 | elif ddim_discr_method == 'quad': 51 | ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) 52 | else: 53 | raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') 54 | 55 | # assert ddim_timesteps.shape[0] == num_ddim_timesteps 56 | # add one to get the final alpha values right (the ones from first scale to data during sampling) 57 | steps_out = ddim_timesteps + 1 58 | if verbose: 59 | print(f'Selected timesteps for ddim sampler: {steps_out}') 60 | return steps_out 61 | 62 | 63 | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): 64 | # select alphas for computing the variance schedule 65 | alphas = alphacums[ddim_timesteps] 66 | alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) 67 | 68 | # according the the formula provided in https://arxiv.org/abs/2010.02502 69 | sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) 70 | if verbose: 71 | print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') 72 | print(f'For the chosen value of eta, which is {eta}, ' 73 | f'this results in the following sigma_t schedule for ddim sampler {sigmas}') 74 | return sigmas, alphas, alphas_prev 75 | 76 | 77 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): 78 | """ 79 | Create a beta schedule that discretizes the given alpha_t_bar function, 80 | which defines the cumulative product of (1-beta) over time from t = [0,1]. 81 | :param num_diffusion_timesteps: the number of betas to produce. 82 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and 83 | produces the cumulative product of (1-beta) up to that 84 | part of the diffusion process. 85 | :param max_beta: the maximum beta to use; use values lower than 1 to 86 | prevent singularities. 
87 | """ 88 | betas = [] 89 | for i in range(num_diffusion_timesteps): 90 | t1 = i / num_diffusion_timesteps 91 | t2 = (i + 1) / num_diffusion_timesteps 92 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 93 | return np.array(betas) 94 | 95 | 96 | def extract_into_tensor(a, t, x_shape): 97 | b, *_ = t.shape 98 | out = a.gather(-1, t) 99 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 100 | 101 | 102 | def checkpoint(func, inputs, params, flag): 103 | """ 104 | Evaluate a function without caching intermediate activations, allowing for 105 | reduced memory at the expense of extra compute in the backward pass. 106 | :param func: the function to evaluate. 107 | :param inputs: the argument sequence to pass to `func`. 108 | :param params: a sequence of parameters `func` depends on but does not 109 | explicitly take as arguments. 110 | :param flag: if False, disable gradient checkpointing. 111 | """ 112 | if flag: # disabled checkpointing to allow requires_grad = False for main model 113 | args = tuple(inputs) + tuple(params) 114 | return CheckpointFunction.apply(func, len(inputs), *args) 115 | else: 116 | return func(*inputs) 117 | 118 | 119 | class CheckpointFunction(torch.autograd.Function): 120 | @staticmethod 121 | def forward(ctx, run_function, length, *args): 122 | ctx.run_function = run_function 123 | ctx.input_tensors = list(args[:length]) 124 | ctx.input_params = list(args[length:]) 125 | 126 | with torch.no_grad(): 127 | output_tensors = ctx.run_function(*ctx.input_tensors) 128 | return output_tensors 129 | 130 | @staticmethod 131 | def backward(ctx, *output_grads): 132 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 133 | with torch.enable_grad(): 134 | # Fixes a bug where the first op in run_function modifies the 135 | # Tensor storage in place, which is not allowed for detach()'d 136 | # Tensors. 137 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 138 | output_tensors = ctx.run_function(*shallow_copies) 139 | input_grads = torch.autograd.grad( 140 | output_tensors, 141 | ctx.input_tensors + ctx.input_params, 142 | output_grads, 143 | allow_unused=True, 144 | ) 145 | del ctx.input_tensors 146 | del ctx.input_params 147 | del output_tensors 148 | return (None, None) + input_grads 149 | 150 | 151 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): 152 | """ 153 | Create sinusoidal timestep embeddings. 154 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 155 | These may be fractional. 156 | :param dim: the dimension of the output. 157 | :param max_period: controls the minimum frequency of the embeddings. 158 | :return: an [N x dim] Tensor of positional embeddings. 159 | """ 160 | if not repeat_only: 161 | half = dim // 2 162 | freqs = torch.exp( 163 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 164 | ).to(device=timesteps.device) 165 | args = timesteps[:, None].float() * freqs[None] 166 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 167 | if dim % 2: 168 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 169 | else: 170 | embedding = repeat(timesteps, 'b -> b d', d=dim) 171 | return embedding 172 | 173 | 174 | def zero_module(module): 175 | """ 176 | Zero out the parameters of a module and return it. 
177 | """ 178 | for p in module.parameters(): 179 | p.detach().zero_() 180 | return module 181 | 182 | 183 | def scale_module(module, scale): 184 | """ 185 | Scale the parameters of a module and return it. 186 | """ 187 | for p in module.parameters(): 188 | p.detach().mul_(scale) 189 | return module 190 | 191 | 192 | def mean_flat(tensor): 193 | """ 194 | Take the mean over all non-batch dimensions. 195 | """ 196 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 197 | 198 | 199 | def normalization(channels): 200 | """ 201 | Make a standard normalization layer. 202 | :param channels: number of input channels. 203 | :return: an nn.Module for normalization. 204 | """ 205 | return GroupNorm32(32, channels) 206 | 207 | 208 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 209 | class SiLU(nn.Module): 210 | def forward(self, x): 211 | return x * torch.sigmoid(x) 212 | 213 | 214 | class GroupNorm32(nn.GroupNorm): 215 | def forward(self, x): 216 | return super().forward(x.float()).type(x.dtype) 217 | 218 | def conv_nd(dims, *args, **kwargs): 219 | """ 220 | Create a 1D, 2D, or 3D convolution module. 221 | """ 222 | if dims == 1: 223 | return nn.Conv1d(*args, **kwargs) 224 | elif dims == 2: 225 | return nn.Conv2d(*args, **kwargs) 226 | elif dims == 3: 227 | return nn.Conv3d(*args, **kwargs) 228 | raise ValueError(f"unsupported dimensions: {dims}") 229 | 230 | 231 | def linear(*args, **kwargs): 232 | """ 233 | Create a linear module. 234 | """ 235 | return nn.Linear(*args, **kwargs) 236 | 237 | 238 | def avg_pool_nd(dims, *args, **kwargs): 239 | """ 240 | Create a 1D, 2D, or 3D average pooling module. 241 | """ 242 | if dims == 1: 243 | return nn.AvgPool1d(*args, **kwargs) 244 | elif dims == 2: 245 | return nn.AvgPool2d(*args, **kwargs) 246 | elif dims == 3: 247 | return nn.AvgPool3d(*args, **kwargs) 248 | raise ValueError(f"unsupported dimensions: {dims}") 249 | 250 | 251 | class HybridConditioner(nn.Module): 252 | 253 | def __init__(self, c_concat_config, c_crossattn_config): 254 | super().__init__() 255 | self.concat_conditioner = instantiate_from_config(c_concat_config) 256 | self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) 257 | 258 | def forward(self, c_concat, c_crossattn): 259 | c_concat = self.concat_conditioner(c_concat) 260 | c_crossattn = self.crossattn_conditioner(c_crossattn) 261 | return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} 262 | 263 | 264 | def noise_like(shape, device, repeat=False): 265 | repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) 266 | noise = lambda: torch.randn(shape, device=device) 267 | return repeat_noise() if repeat else noise() 268 | -------------------------------------------------------------------------------- /ldm/modules/distributions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kanewallmann/Dreambooth-Stable-Diffusion/6f04eae2331153f841cab502385415ac39503144/ldm/modules/distributions/__init__.py -------------------------------------------------------------------------------- /ldm/modules/distributions/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def 
__init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 34 | 35 | def sample(self): 36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) 37 | return x 38 | 39 | def kl(self, other=None): 40 | if self.deterministic: 41 | return torch.Tensor([0.]) 42 | else: 43 | if other is None: 44 | return 0.5 * torch.sum(torch.pow(self.mean, 2) 45 | + self.var - 1.0 - self.logvar, 46 | dim=[1, 2, 3]) 47 | else: 48 | return 0.5 * torch.sum( 49 | torch.pow(self.mean - other.mean, 2) / other.var 50 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 51 | dim=[1, 2, 3]) 52 | 53 | def nll(self, sample, dims=[1,2,3]): 54 | if self.deterministic: 55 | return torch.Tensor([0.]) 56 | logtwopi = np.log(2.0 * np.pi) 57 | return 0.5 * torch.sum( 58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 59 | dim=dims) 60 | 61 | def mode(self): 62 | return self.mean 63 | 64 | 65 | def normal_kl(mean1, logvar1, mean2, logvar2): 66 | """ 67 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 68 | Compute the KL divergence between two gaussians. 69 | Shapes are automatically broadcasted, so batches can be compared to 70 | scalars, among other use cases. 71 | """ 72 | tensor = None 73 | for obj in (mean1, logvar1, mean2, logvar2): 74 | if isinstance(obj, torch.Tensor): 75 | tensor = obj 76 | break 77 | assert tensor is not None, "at least one argument must be a Tensor" 78 | 79 | # Force variances to be Tensors. Broadcasting helps convert scalars to 80 | # Tensors, but it does not work for torch.exp(). 
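# The expression returned below is the standard closed form
#     KL( N(mean1, exp(logvar1)) || N(mean2, exp(logvar2)) )
#         = 0.5 * ( logvar2 - logvar1
#                   + exp(logvar1 - logvar2)
#                   + (mean1 - mean2)**2 * exp(-logvar2)
#                   - 1 )
# evaluated elementwise, with no reduction over dimensions.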
81 | logvar1, logvar2 = [ 82 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 83 | for x in (logvar1, logvar2) 84 | ] 85 | 86 | return 0.5 * ( 87 | -1.0 88 | + logvar2 89 | - logvar1 90 | + torch.exp(logvar1 - logvar2) 91 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 92 | ) 93 | -------------------------------------------------------------------------------- /ldm/modules/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError('Decay must be between 0 and 1') 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates 14 | else torch.tensor(-1,dtype=torch.int)) 15 | 16 | for name, p in model.named_parameters(): 17 | if p.requires_grad: 18 | #remove as '.'-character is not allowed in buffers 19 | s_name = name.replace('.','') 20 | self.m_name2s_name.update({name:s_name}) 21 | self.register_buffer(s_name,p.clone().detach().data) 22 | 23 | self.collected_params = [] 24 | 25 | def forward(self,model): 26 | decay = self.decay 27 | 28 | if self.num_updates >= 0: 29 | self.num_updates += 1 30 | decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates)) 31 | 32 | one_minus_decay = 1.0 - decay 33 | 34 | with torch.no_grad(): 35 | m_param = dict(model.named_parameters()) 36 | shadow_params = dict(self.named_buffers()) 37 | 38 | for key in m_param: 39 | if m_param[key].requires_grad: 40 | sname = self.m_name2s_name[key] 41 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 42 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) 43 | else: 44 | assert not key in self.m_name2s_name 45 | 46 | def copy_to(self, model): 47 | m_param = dict(model.named_parameters()) 48 | shadow_params = dict(self.named_buffers()) 49 | for key in m_param: 50 | if m_param[key].requires_grad: 51 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 52 | else: 53 | assert not key in self.m_name2s_name 54 | 55 | def store(self, parameters): 56 | """ 57 | Save the current parameters for restoring later. 58 | Args: 59 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 60 | temporarily stored. 61 | """ 62 | self.collected_params = [param.clone() for param in parameters] 63 | 64 | def restore(self, parameters): 65 | """ 66 | Restore the parameters stored with the `store` method. 67 | Useful to validate the model with EMA parameters without affecting the 68 | original optimization process. Store the parameters before the 69 | `copy_to` method. After validation (or model saving), use this to 70 | restore the former parameters. 71 | Args: 72 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 73 | updated with the stored parameters. 
74 | """ 75 | for c_param, param in zip(self.collected_params, parameters): 76 | param.data.copy_(c_param.data) 77 | -------------------------------------------------------------------------------- /ldm/modules/embedding_manager.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from ldm.data.personalized import per_img_token_list 5 | from transformers import CLIPTokenizer 6 | from functools import partial 7 | 8 | DEFAULT_PLACEHOLDER_TOKEN = ["*"] 9 | 10 | PROGRESSIVE_SCALE = 2000 11 | 12 | def get_clip_token_for_string(tokenizer, string): 13 | batch_encoding = tokenizer(string, truncation=True, max_length=77, return_length=True, 14 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt") 15 | tokens = batch_encoding["input_ids"] 16 | assert torch.count_nonzero(tokens - 49407) == 2, f"String '{string}' maps to more than a single token. Please use another string" 17 | 18 | return tokens[0, 1] 19 | 20 | def get_bert_token_for_string(tokenizer, string): 21 | token = tokenizer(string) 22 | assert torch.count_nonzero(token) == 3, f"String '{string}' maps to more than a single token. Please use another string" 23 | 24 | token = token[0, 1] 25 | 26 | return token 27 | 28 | def get_embedding_for_clip_token(embedder, token): 29 | return embedder(token.unsqueeze(0))[0, 0] 30 | 31 | 32 | class EmbeddingManager(nn.Module): 33 | def __init__( 34 | self, 35 | embedder, 36 | placeholder_strings=None, 37 | initializer_words=None, 38 | per_image_tokens=False, 39 | num_vectors_per_token=1, 40 | progressive_words=False, 41 | **kwargs 42 | ): 43 | super().__init__() 44 | 45 | self.string_to_token_dict = {} 46 | 47 | self.string_to_param_dict = nn.ParameterDict() 48 | 49 | self.initial_embeddings = nn.ParameterDict() # These should not be optimized 50 | 51 | self.progressive_words = progressive_words 52 | self.progressive_counter = 0 53 | 54 | self.max_vectors_per_token = num_vectors_per_token 55 | 56 | if hasattr(embedder, 'tokenizer'): # using Stable Diffusion's CLIP encoder 57 | self.is_clip = True 58 | get_token_for_string = partial(get_clip_token_for_string, embedder.tokenizer) 59 | get_embedding_for_tkn = partial(get_embedding_for_clip_token, embedder.transformer.text_model.embeddings) 60 | token_dim = 768 61 | else: # using LDM's BERT encoder 62 | self.is_clip = False 63 | get_token_for_string = partial(get_bert_token_for_string, embedder.tknz_fn) 64 | get_embedding_for_tkn = embedder.transformer.token_emb 65 | token_dim = 1280 66 | 67 | if per_image_tokens: 68 | placeholder_strings.extend(per_img_token_list) 69 | 70 | for idx, placeholder_string in enumerate(placeholder_strings): 71 | 72 | token = get_token_for_string(placeholder_string) 73 | 74 | if initializer_words and idx < len(initializer_words): 75 | init_word_token = get_token_for_string(initializer_words[idx]) 76 | 77 | with torch.no_grad(): 78 | init_word_embedding = get_embedding_for_tkn(init_word_token.cpu()) 79 | 80 | token_params = torch.nn.Parameter(init_word_embedding.unsqueeze(0).repeat(num_vectors_per_token, 1), requires_grad=True) 81 | self.initial_embeddings[placeholder_string] = torch.nn.Parameter(init_word_embedding.unsqueeze(0).repeat(num_vectors_per_token, 1), requires_grad=False) 82 | else: 83 | token_params = torch.nn.Parameter(torch.rand(size=(num_vectors_per_token, token_dim), requires_grad=True)) 84 | 85 | self.string_to_token_dict[placeholder_string] = token 86 | self.string_to_param_dict[placeholder_string] = 
token_params 87 | 88 | def forward( 89 | self, 90 | tokenized_text, 91 | embedded_text, 92 | ): 93 | b, n, device = *tokenized_text.shape, tokenized_text.device 94 | 95 | for placeholder_string, placeholder_token in self.string_to_token_dict.items(): 96 | 97 | placeholder_embedding = self.string_to_param_dict[placeholder_string].to(device) 98 | 99 | if self.max_vectors_per_token == 1: # If there's only one vector per token, we can do a simple replacement 100 | placeholder_idx = torch.where(tokenized_text == placeholder_token.to(device)) 101 | embedded_text[placeholder_idx] = placeholder_embedding 102 | else: # otherwise, need to insert and keep track of changing indices 103 | if self.progressive_words: 104 | self.progressive_counter += 1 105 | max_step_tokens = 1 + self.progressive_counter // PROGRESSIVE_SCALE 106 | else: 107 | max_step_tokens = self.max_vectors_per_token 108 | 109 | num_vectors_for_token = min(placeholder_embedding.shape[0], max_step_tokens) 110 | 111 | placeholder_rows, placeholder_cols = torch.where(tokenized_text == placeholder_token.to(device)) 112 | 113 | if placeholder_rows.nelement() == 0: 114 | continue 115 | 116 | sorted_cols, sort_idx = torch.sort(placeholder_cols, descending=True) 117 | sorted_rows = placeholder_rows[sort_idx] 118 | 119 | for idx in range(len(sorted_rows)): 120 | row = sorted_rows[idx] 121 | col = sorted_cols[idx] 122 | 123 | new_token_row = torch.cat([tokenized_text[row][:col], placeholder_token.repeat(num_vectors_for_token).to(device), tokenized_text[row][col + 1:]], axis=0)[:n] 124 | new_embed_row = torch.cat([embedded_text[row][:col], placeholder_embedding[:num_vectors_for_token], embedded_text[row][col + 1:]], axis=0)[:n] 125 | 126 | embedded_text[row] = new_embed_row 127 | tokenized_text[row] = new_token_row 128 | 129 | return embedded_text 130 | 131 | def save(self, ckpt_path): 132 | torch.save({"string_to_token": self.string_to_token_dict, 133 | "string_to_param": self.string_to_param_dict}, ckpt_path) 134 | 135 | def load(self, ckpt_path): 136 | ckpt = torch.load(ckpt_path, map_location='cpu') 137 | 138 | self.string_to_token_dict = ckpt["string_to_token"] 139 | self.string_to_param_dict = ckpt["string_to_param"] 140 | 141 | def get_embedding_norms_squared(self): 142 | all_params = torch.cat(list(self.string_to_param_dict.values()), axis=0) # num_placeholders x embedding_dim 143 | param_norm_squared = (all_params * all_params).sum(axis=-1) # num_placeholders 144 | 145 | return param_norm_squared 146 | 147 | def embedding_parameters(self): 148 | return self.string_to_param_dict.parameters() 149 | 150 | def embedding_to_coarse_loss(self): 151 | 152 | loss = 0. 
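# The loop below penalizes drift of each optimized placeholder embedding away from the
# frozen copy kept in self.initial_embeddings: for every placeholder it accumulates the
# Gram matrix of (optimized - initial), divided by the number of placeholders. With a
# single vector per token this reduces to the squared L2 distance from the initializer.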
153 | num_embeddings = len(self.initial_embeddings) 154 | 155 | for key in self.initial_embeddings: 156 | optimized = self.string_to_param_dict[key] 157 | coarse = self.initial_embeddings[key].clone().to(optimized.device) 158 | 159 | loss = loss + (optimized - coarse) @ (optimized - coarse).T / num_embeddings 160 | 161 | return loss -------------------------------------------------------------------------------- /ldm/modules/encoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kanewallmann/Dreambooth-Stable-Diffusion/6f04eae2331153f841cab502385415ac39503144/ldm/modules/encoders/__init__.py -------------------------------------------------------------------------------- /ldm/modules/image_degradation/__init__.py: -------------------------------------------------------------------------------- 1 | from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr 2 | from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light 3 | -------------------------------------------------------------------------------- /ldm/modules/image_degradation/utils/test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kanewallmann/Dreambooth-Stable-Diffusion/6f04eae2331153f841cab502385415ac39503144/ldm/modules/image_degradation/utils/test.png -------------------------------------------------------------------------------- /ldm/modules/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator -------------------------------------------------------------------------------- /ldm/modules/losses/contperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no? 
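# Rough usage sketch for the loss below (variable names are illustrative, not defined in
# this file): a two-optimizer autoencoder training step typically calls it once per
# optimizer, e.g.
#     aeloss, log = loss_fn(inputs, reconstructions, posterior, 0, global_step,
#                           last_layer=last_decoder_layer, split="train")
#     discloss, log = loss_fn(inputs, reconstructions, posterior, 1, global_step,
#                             split="train")
# optimizer_idx == 0 returns the reconstruction/KL/generator loss, optimizer_idx == 1 the
# discriminator loss; see forward() below.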
5 | 6 | 7 | class LPIPSWithDiscriminator(nn.Module): 8 | def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0, 9 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, 10 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, 11 | disc_loss="hinge"): 12 | 13 | super().__init__() 14 | assert disc_loss in ["hinge", "vanilla"] 15 | self.kl_weight = kl_weight 16 | self.pixel_weight = pixelloss_weight 17 | self.perceptual_loss = LPIPS().eval() 18 | self.perceptual_weight = perceptual_weight 19 | # output log variance 20 | self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init) 21 | 22 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, 23 | n_layers=disc_num_layers, 24 | use_actnorm=use_actnorm 25 | ).apply(weights_init) 26 | self.discriminator_iter_start = disc_start 27 | self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss 28 | self.disc_factor = disc_factor 29 | self.discriminator_weight = disc_weight 30 | self.disc_conditional = disc_conditional 31 | 32 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 33 | if last_layer is not None: 34 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 35 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 36 | else: 37 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] 38 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] 39 | 40 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 41 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 42 | d_weight = d_weight * self.discriminator_weight 43 | return d_weight 44 | 45 | def forward(self, inputs, reconstructions, posteriors, optimizer_idx, 46 | global_step, last_layer=None, cond=None, split="train", 47 | weights=None): 48 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 49 | if self.perceptual_weight > 0: 50 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) 51 | rec_loss = rec_loss + self.perceptual_weight * p_loss 52 | 53 | nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar 54 | weighted_nll_loss = nll_loss 55 | if weights is not None: 56 | weighted_nll_loss = weights*nll_loss 57 | weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] 58 | nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 59 | kl_loss = posteriors.kl() 60 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] 61 | 62 | # now the GAN part 63 | if optimizer_idx == 0: 64 | # generator update 65 | if cond is None: 66 | assert not self.disc_conditional 67 | logits_fake = self.discriminator(reconstructions.contiguous()) 68 | else: 69 | assert self.disc_conditional 70 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) 71 | g_loss = -torch.mean(logits_fake) 72 | 73 | if self.disc_factor > 0.0: 74 | try: 75 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) 76 | except RuntimeError: 77 | assert not self.training 78 | d_weight = torch.tensor(0.0) 79 | else: 80 | d_weight = torch.tensor(0.0) 81 | 82 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 83 | loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss 84 | 85 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(), 86 | 
"{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(), 87 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 88 | "{}/d_weight".format(split): d_weight.detach(), 89 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 90 | "{}/g_loss".format(split): g_loss.detach().mean(), 91 | } 92 | return loss, log 93 | 94 | if optimizer_idx == 1: 95 | # second pass for discriminator update 96 | if cond is None: 97 | logits_real = self.discriminator(inputs.contiguous().detach()) 98 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 99 | else: 100 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) 101 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) 102 | 103 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 104 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 105 | 106 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), 107 | "{}/logits_real".format(split): logits_real.detach().mean(), 108 | "{}/logits_fake".format(split): logits_fake.detach().mean() 109 | } 110 | return d_loss, log 111 | 112 | -------------------------------------------------------------------------------- /ldm/modules/losses/vqperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from einops import repeat 5 | 6 | from taming.modules.discriminator.model import NLayerDiscriminator, weights_init 7 | from taming.modules.losses.lpips import LPIPS 8 | from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss 9 | 10 | 11 | def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights): 12 | assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0] 13 | loss_real = torch.mean(F.relu(1. - logits_real), dim=[1,2,3]) 14 | loss_fake = torch.mean(F.relu(1. + logits_fake), dim=[1,2,3]) 15 | loss_real = (weights * loss_real).sum() / weights.sum() 16 | loss_fake = (weights * loss_fake).sum() / weights.sum() 17 | d_loss = 0.5 * (loss_real + loss_fake) 18 | return d_loss 19 | 20 | def adopt_weight(weight, global_step, threshold=0, value=0.): 21 | if global_step < threshold: 22 | weight = value 23 | return weight 24 | 25 | 26 | def measure_perplexity(predicted_indices, n_embed): 27 | # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py 28 | # eval cluster perplexity. 
when perplexity == num_embeddings then all clusters are used exactly equally 29 | encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed) 30 | avg_probs = encodings.mean(0) 31 | perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp() 32 | cluster_use = torch.sum(avg_probs > 0) 33 | return perplexity, cluster_use 34 | 35 | def l1(x, y): 36 | return torch.abs(x-y) 37 | 38 | 39 | def l2(x, y): 40 | return torch.pow((x-y), 2) 41 | 42 | 43 | class VQLPIPSWithDiscriminator(nn.Module): 44 | def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0, 45 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, 46 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, 47 | disc_ndf=64, disc_loss="hinge", n_classes=None, perceptual_loss="lpips", 48 | pixel_loss="l1"): 49 | super().__init__() 50 | assert disc_loss in ["hinge", "vanilla"] 51 | assert perceptual_loss in ["lpips", "clips", "dists"] 52 | assert pixel_loss in ["l1", "l2"] 53 | self.codebook_weight = codebook_weight 54 | self.pixel_weight = pixelloss_weight 55 | if perceptual_loss == "lpips": 56 | print(f"{self.__class__.__name__}: Running with LPIPS.") 57 | self.perceptual_loss = LPIPS().eval() 58 | else: 59 | raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<") 60 | self.perceptual_weight = perceptual_weight 61 | 62 | if pixel_loss == "l1": 63 | self.pixel_loss = l1 64 | else: 65 | self.pixel_loss = l2 66 | 67 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, 68 | n_layers=disc_num_layers, 69 | use_actnorm=use_actnorm, 70 | ndf=disc_ndf 71 | ).apply(weights_init) 72 | self.discriminator_iter_start = disc_start 73 | if disc_loss == "hinge": 74 | self.disc_loss = hinge_d_loss 75 | elif disc_loss == "vanilla": 76 | self.disc_loss = vanilla_d_loss 77 | else: 78 | raise ValueError(f"Unknown GAN loss '{disc_loss}'.") 79 | print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.") 80 | self.disc_factor = disc_factor 81 | self.discriminator_weight = disc_weight 82 | self.disc_conditional = disc_conditional 83 | self.n_classes = n_classes 84 | 85 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 86 | if last_layer is not None: 87 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 88 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 89 | else: 90 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] 91 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] 92 | 93 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 94 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 95 | d_weight = d_weight * self.discriminator_weight 96 | return d_weight 97 | 98 | def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx, 99 | global_step, last_layer=None, cond=None, split="train", predicted_indices=None): 100 | if not exists(codebook_loss): 101 | codebook_loss = torch.tensor([0.]).to(inputs.device) 102 | #rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 103 | rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous()) 104 | if self.perceptual_weight > 0: 105 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) 106 | rec_loss = rec_loss + self.perceptual_weight * p_loss 107 | else: 108 | p_loss = torch.tensor([0.0]) 109 | 110 | nll_loss = rec_loss 111 | #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 
112 | nll_loss = torch.mean(nll_loss) 113 | 114 | # now the GAN part 115 | if optimizer_idx == 0: 116 | # generator update 117 | if cond is None: 118 | assert not self.disc_conditional 119 | logits_fake = self.discriminator(reconstructions.contiguous()) 120 | else: 121 | assert self.disc_conditional 122 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) 123 | g_loss = -torch.mean(logits_fake) 124 | 125 | try: 126 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) 127 | except RuntimeError: 128 | assert not self.training 129 | d_weight = torch.tensor(0.0) 130 | 131 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 132 | loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean() 133 | 134 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), 135 | "{}/quant_loss".format(split): codebook_loss.detach().mean(), 136 | "{}/nll_loss".format(split): nll_loss.detach().mean(), 137 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 138 | "{}/p_loss".format(split): p_loss.detach().mean(), 139 | "{}/d_weight".format(split): d_weight.detach(), 140 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 141 | "{}/g_loss".format(split): g_loss.detach().mean(), 142 | } 143 | if predicted_indices is not None: 144 | assert self.n_classes is not None 145 | with torch.no_grad(): 146 | perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes) 147 | log[f"{split}/perplexity"] = perplexity 148 | log[f"{split}/cluster_usage"] = cluster_usage 149 | return loss, log 150 | 151 | if optimizer_idx == 1: 152 | # second pass for discriminator update 153 | if cond is None: 154 | logits_real = self.discriminator(inputs.contiguous().detach()) 155 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 156 | else: 157 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) 158 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) 159 | 160 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 161 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 162 | 163 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), 164 | "{}/logits_real".format(split): logits_real.detach().mean(), 165 | "{}/logits_fake".format(split): logits_fake.detach().mean() 166 | } 167 | return d_loss, log 168 | -------------------------------------------------------------------------------- /ldm/util.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | import torch 4 | import numpy as np 5 | from collections import abc 6 | from einops import rearrange 7 | from functools import partial 8 | 9 | import multiprocessing as mp 10 | from threading import Thread 11 | from queue import Queue 12 | 13 | from inspect import isfunction 14 | from PIL import Image, ImageDraw, ImageFont 15 | 16 | 17 | def log_txt_as_img(wh, xc, size=10): 18 | # wh a tuple of (width, height) 19 | # xc a list of captions to plot 20 | b = len(xc) 21 | txts = list() 22 | for bi in range(b): 23 | txt = Image.new("RGB", wh, color="white") 24 | draw = ImageDraw.Draw(txt) 25 | # Changed to work on Windows 26 | font = ImageFont.load_default() 27 | #font = ImageFont.truetype('data/DejaVuSans.ttf', size=size) 28 | nc = int(40 * (wh[0] / 256)) 29 | lines = 
"\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc)) 30 | 31 | try: 32 | draw.text((0, 0), lines, fill="black", font=font) 33 | except UnicodeEncodeError: 34 | print("Cant encode string for logging. Skipping.") 35 | 36 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 37 | txts.append(txt) 38 | txts = np.stack(txts) 39 | txts = torch.tensor(txts) 40 | return txts 41 | 42 | 43 | def ismap(x): 44 | if not isinstance(x, torch.Tensor): 45 | return False 46 | return (len(x.shape) == 4) and (x.shape[1] > 3) 47 | 48 | 49 | def isimage(x): 50 | if not isinstance(x, torch.Tensor): 51 | return False 52 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) 53 | 54 | 55 | def exists(x): 56 | return x is not None 57 | 58 | 59 | def default(val, d): 60 | if exists(val): 61 | return val 62 | return d() if isfunction(d) else d 63 | 64 | 65 | def mean_flat(tensor): 66 | """ 67 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 68 | Take the mean over all non-batch dimensions. 69 | """ 70 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 71 | 72 | 73 | def count_params(model, verbose=False): 74 | total_params = sum(p.numel() for p in model.parameters()) 75 | if verbose: 76 | print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") 77 | return total_params 78 | 79 | 80 | def instantiate_from_config(config, **kwargs): 81 | if not "target" in config: 82 | if config == '__is_first_stage__': 83 | return None 84 | elif config == "__is_unconditional__": 85 | return None 86 | raise KeyError("Expected key `target` to instantiate.") 87 | return get_obj_from_str(config["target"])(**config.get("params", dict()), **kwargs) 88 | 89 | 90 | def get_obj_from_str(string, reload=False): 91 | module, cls = string.rsplit(".", 1) 92 | if reload: 93 | module_imp = importlib.import_module(module) 94 | importlib.reload(module_imp) 95 | return getattr(importlib.import_module(module, package=None), cls) 96 | 97 | 98 | def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False): 99 | # create dummy dataset instance 100 | 101 | # run prefetching 102 | if idx_to_fn: 103 | res = func(data, worker_id=idx) 104 | else: 105 | res = func(data) 106 | Q.put([idx, res]) 107 | Q.put("Done") 108 | 109 | 110 | def parallel_data_prefetch( 111 | func: callable, data, n_proc, target_data_type="ndarray", cpu_intensive=True, use_worker_id=False 112 | ): 113 | # if target_data_type not in ["ndarray", "list"]: 114 | # raise ValueError( 115 | # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray." 116 | # ) 117 | if isinstance(data, np.ndarray) and target_data_type == "list": 118 | raise ValueError("list expected but function got ndarray.") 119 | elif isinstance(data, abc.Iterable): 120 | if isinstance(data, dict): 121 | print( 122 | f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.' 123 | ) 124 | data = list(data.values()) 125 | if target_data_type == "ndarray": 126 | data = np.asarray(data) 127 | else: 128 | data = list(data) 129 | else: 130 | raise TypeError( 131 | f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}." 
132 | ) 133 | 134 | if cpu_intensive: 135 | Q = mp.Queue(1000) 136 | proc = mp.Process 137 | else: 138 | Q = Queue(1000) 139 | proc = Thread 140 | # spawn processes 141 | if target_data_type == "ndarray": 142 | arguments = [ 143 | [func, Q, part, i, use_worker_id] 144 | for i, part in enumerate(np.array_split(data, n_proc)) 145 | ] 146 | else: 147 | step = ( 148 | int(len(data) / n_proc + 1) 149 | if len(data) % n_proc != 0 150 | else int(len(data) / n_proc) 151 | ) 152 | arguments = [ 153 | [func, Q, part, i, use_worker_id] 154 | for i, part in enumerate( 155 | [data[i: i + step] for i in range(0, len(data), step)] 156 | ) 157 | ] 158 | processes = [] 159 | for i in range(n_proc): 160 | p = proc(target=_do_parallel_data_prefetch, args=arguments[i]) 161 | processes += [p] 162 | 163 | # start processes 164 | print(f"Start prefetching...") 165 | import time 166 | 167 | start = time.time() 168 | gather_res = [[] for _ in range(n_proc)] 169 | try: 170 | for p in processes: 171 | p.start() 172 | 173 | k = 0 174 | while k < n_proc: 175 | # get result 176 | res = Q.get() 177 | if res == "Done": 178 | k += 1 179 | else: 180 | gather_res[res[0]] = res[1] 181 | 182 | except Exception as e: 183 | print("Exception: ", e) 184 | for p in processes: 185 | p.terminate() 186 | 187 | raise e 188 | finally: 189 | for p in processes: 190 | p.join() 191 | print(f"Prefetching complete. [{time.time() - start} sec.]") 192 | 193 | if target_data_type == 'ndarray': 194 | if not isinstance(gather_res[0], np.ndarray): 195 | return np.concatenate([np.asarray(r) for r in gather_res], axis=0) 196 | 197 | # order outputs 198 | return np.concatenate(gather_res, axis=0) 199 | elif target_data_type == 'list': 200 | out = [] 201 | for r in gather_res: 202 | out.extend(r) 203 | return out 204 | else: 205 | return gather_res 206 | -------------------------------------------------------------------------------- /merge_embeddings.py: -------------------------------------------------------------------------------- 1 | from ldm.modules.encoders.modules import FrozenCLIPEmbedder, BERTEmbedder 2 | from ldm.modules.embedding_manager import EmbeddingManager 3 | 4 | import argparse, os 5 | from functools import partial 6 | 7 | import torch 8 | 9 | def get_placeholder_loop(placeholder_string, embedder, is_sd): 10 | 11 | new_placeholder = None 12 | 13 | while True: 14 | if new_placeholder is None: 15 | new_placeholder = input(f"Placeholder string {placeholder_string} was already used. Please enter a replacement string: ") 16 | else: 17 | new_placeholder = input(f"Placeholder string '{new_placeholder}' maps to more than a single token. 
Please enter another string: ") 18 | 19 | token = get_clip_token_for_string(embedder.tokenizer, new_placeholder) if is_sd else get_bert_token_for_string(embedder.tknz_fn, new_placeholder) 20 | 21 | if token is not None: 22 | return new_placeholder, token 23 | 24 | def get_clip_token_for_string(tokenizer, string): 25 | batch_encoding = tokenizer(string, truncation=True, max_length=77, return_length=True, 26 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt") 27 | tokens = batch_encoding["input_ids"] 28 | 29 | if torch.count_nonzero(tokens - 49407) == 2: 30 | return tokens[0, 1] 31 | 32 | return None 33 | 34 | def get_bert_token_for_string(tokenizer, string): 35 | token = tokenizer(string) 36 | if torch.count_nonzero(token) == 3: 37 | return token[0, 1] 38 | 39 | return None 40 | 41 | 42 | if __name__ == "__main__": 43 | 44 | parser = argparse.ArgumentParser() 45 | 46 | parser.add_argument( 47 | "--manager_ckpts", 48 | type=str, 49 | nargs="+", 50 | required=True, 51 | help="Paths to a set of embedding managers to be merged." 52 | ) 53 | 54 | parser.add_argument( 55 | "--output_path", 56 | type=str, 57 | required=True, 58 | help="Output path for the merged manager", 59 | ) 60 | 61 | parser.add_argument( 62 | "-sd", "--stable_diffusion", 63 | action="store_true", 64 | help="Flag to denote that we are merging stable diffusion embeddings" 65 | ) 66 | 67 | args = parser.parse_args() 68 | 69 | if args.stable_diffusion: 70 | embedder = FrozenCLIPEmbedder().cuda() 71 | else: 72 | embedder = BERTEmbedder(n_embed=1280, n_layer=32).cuda() 73 | 74 | EmbeddingManager = partial(EmbeddingManager, embedder, ["*"]) 75 | 76 | string_to_token_dict = {} 77 | string_to_param_dict = torch.nn.ParameterDict() 78 | 79 | placeholder_to_src = {} 80 | 81 | for manager_ckpt in args.manager_ckpts: 82 | print(f"Parsing {manager_ckpt}...") 83 | 84 | manager = EmbeddingManager() 85 | manager.load(manager_ckpt) 86 | 87 | for placeholder_string in manager.string_to_token_dict: 88 | if not placeholder_string in string_to_token_dict: 89 | string_to_token_dict[placeholder_string] = manager.string_to_token_dict[placeholder_string] 90 | string_to_param_dict[placeholder_string] = manager.string_to_param_dict[placeholder_string] 91 | 92 | placeholder_to_src[placeholder_string] = manager_ckpt 93 | else: 94 | new_placeholder, new_token = get_placeholder_loop(placeholder_string, embedder, is_sd=args.stable_diffusion) 95 | string_to_token_dict[new_placeholder] = new_token 96 | string_to_param_dict[new_placeholder] = manager.string_to_param_dict[placeholder_string] 97 | 98 | placeholder_to_src[new_placeholder] = manager_ckpt 99 | 100 | print("Saving combined manager...") 101 | merged_manager = EmbeddingManager() 102 | merged_manager.string_to_param_dict = string_to_param_dict 103 | merged_manager.string_to_token_dict = string_to_token_dict 104 | merged_manager.save(args.output_path) 105 | 106 | print("Managers merged. 
Final list of placeholders: ") 107 | print(placeholder_to_src) 108 | 109 | 110 | 111 | 112 | -------------------------------------------------------------------------------- /models/first_stage_models/kl-f16/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: val/rec_loss 6 | embed_dim: 16 7 | lossconfig: 8 | target: ldm.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 1.0e-06 12 | disc_weight: 0.5 13 | ddconfig: 14 | double_z: true 15 | z_channels: 16 16 | resolution: 256 17 | in_channels: 3 18 | out_ch: 3 19 | ch: 128 20 | ch_mult: 21 | - 1 22 | - 1 23 | - 2 24 | - 2 25 | - 4 26 | num_res_blocks: 2 27 | attn_resolutions: 28 | - 16 29 | dropout: 0.0 30 | data: 31 | target: main.DataModuleFromConfig 32 | params: 33 | batch_size: 6 34 | wrap: true 35 | train: 36 | target: ldm.data.openimages.FullOpenImagesTrain 37 | params: 38 | size: 384 39 | crop_size: 256 40 | validation: 41 | target: ldm.data.openimages.FullOpenImagesValidation 42 | params: 43 | size: 384 44 | crop_size: 256 45 | -------------------------------------------------------------------------------- /models/first_stage_models/kl-f32/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: val/rec_loss 6 | embed_dim: 64 7 | lossconfig: 8 | target: ldm.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 1.0e-06 12 | disc_weight: 0.5 13 | ddconfig: 14 | double_z: true 15 | z_channels: 64 16 | resolution: 256 17 | in_channels: 3 18 | out_ch: 3 19 | ch: 128 20 | ch_mult: 21 | - 1 22 | - 1 23 | - 2 24 | - 2 25 | - 4 26 | - 4 27 | num_res_blocks: 2 28 | attn_resolutions: 29 | - 16 30 | - 8 31 | dropout: 0.0 32 | data: 33 | target: main.DataModuleFromConfig 34 | params: 35 | batch_size: 6 36 | wrap: true 37 | train: 38 | target: ldm.data.openimages.FullOpenImagesTrain 39 | params: 40 | size: 384 41 | crop_size: 256 42 | validation: 43 | target: ldm.data.openimages.FullOpenImagesValidation 44 | params: 45 | size: 384 46 | crop_size: 256 47 | -------------------------------------------------------------------------------- /models/first_stage_models/kl-f4/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: val/rec_loss 6 | embed_dim: 3 7 | lossconfig: 8 | target: ldm.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 1.0e-06 12 | disc_weight: 0.5 13 | ddconfig: 14 | double_z: true 15 | z_channels: 3 16 | resolution: 256 17 | in_channels: 3 18 | out_ch: 3 19 | ch: 128 20 | ch_mult: 21 | - 1 22 | - 2 23 | - 4 24 | num_res_blocks: 2 25 | attn_resolutions: [] 26 | dropout: 0.0 27 | data: 28 | target: main.DataModuleFromConfig 29 | params: 30 | batch_size: 10 31 | wrap: true 32 | train: 33 | target: ldm.data.openimages.FullOpenImagesTrain 34 | params: 35 | size: 384 36 | crop_size: 256 37 | validation: 38 | target: ldm.data.openimages.FullOpenImagesValidation 39 | params: 40 | size: 384 41 | crop_size: 256 42 | -------------------------------------------------------------------------------- /models/first_stage_models/kl-f8/config.yaml: 
-------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: val/rec_loss 6 | embed_dim: 4 7 | lossconfig: 8 | target: ldm.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 1.0e-06 12 | disc_weight: 0.5 13 | ddconfig: 14 | double_z: true 15 | z_channels: 4 16 | resolution: 256 17 | in_channels: 3 18 | out_ch: 3 19 | ch: 128 20 | ch_mult: 21 | - 1 22 | - 2 23 | - 4 24 | - 4 25 | num_res_blocks: 2 26 | attn_resolutions: [] 27 | dropout: 0.0 28 | data: 29 | target: main.DataModuleFromConfig 30 | params: 31 | batch_size: 4 32 | wrap: true 33 | train: 34 | target: ldm.data.openimages.FullOpenImagesTrain 35 | params: 36 | size: 384 37 | crop_size: 256 38 | validation: 39 | target: ldm.data.openimages.FullOpenImagesValidation 40 | params: 41 | size: 384 42 | crop_size: 256 43 | -------------------------------------------------------------------------------- /models/first_stage_models/vq-f16/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: ldm.models.autoencoder.VQModel 4 | params: 5 | embed_dim: 8 6 | n_embed: 16384 7 | ddconfig: 8 | double_z: false 9 | z_channels: 8 10 | resolution: 256 11 | in_channels: 3 12 | out_ch: 3 13 | ch: 128 14 | ch_mult: 15 | - 1 16 | - 1 17 | - 2 18 | - 2 19 | - 4 20 | num_res_blocks: 2 21 | attn_resolutions: 22 | - 16 23 | dropout: 0.0 24 | lossconfig: 25 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator 26 | params: 27 | disc_conditional: false 28 | disc_in_channels: 3 29 | disc_start: 250001 30 | disc_weight: 0.75 31 | disc_num_layers: 2 32 | codebook_weight: 1.0 33 | 34 | data: 35 | target: main.DataModuleFromConfig 36 | params: 37 | batch_size: 14 38 | num_workers: 20 39 | wrap: true 40 | train: 41 | target: ldm.data.openimages.FullOpenImagesTrain 42 | params: 43 | size: 384 44 | crop_size: 256 45 | validation: 46 | target: ldm.data.openimages.FullOpenImagesValidation 47 | params: 48 | size: 384 49 | crop_size: 256 50 | -------------------------------------------------------------------------------- /models/first_stage_models/vq-f4-noattn/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: ldm.models.autoencoder.VQModel 4 | params: 5 | embed_dim: 3 6 | n_embed: 8192 7 | monitor: val/rec_loss 8 | 9 | ddconfig: 10 | attn_type: none 11 | double_z: false 12 | z_channels: 3 13 | resolution: 256 14 | in_channels: 3 15 | out_ch: 3 16 | ch: 128 17 | ch_mult: 18 | - 1 19 | - 2 20 | - 4 21 | num_res_blocks: 2 22 | attn_resolutions: [] 23 | dropout: 0.0 24 | lossconfig: 25 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator 26 | params: 27 | disc_conditional: false 28 | disc_in_channels: 3 29 | disc_start: 11 30 | disc_weight: 0.75 31 | codebook_weight: 1.0 32 | 33 | data: 34 | target: main.DataModuleFromConfig 35 | params: 36 | batch_size: 8 37 | num_workers: 12 38 | wrap: true 39 | train: 40 | target: ldm.data.openimages.FullOpenImagesTrain 41 | params: 42 | crop_size: 256 43 | validation: 44 | target: ldm.data.openimages.FullOpenImagesValidation 45 | params: 46 | crop_size: 256 47 | -------------------------------------------------------------------------------- /models/first_stage_models/vq-f4/config.yaml: 
-------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: ldm.models.autoencoder.VQModel 4 | params: 5 | embed_dim: 3 6 | n_embed: 8192 7 | monitor: val/rec_loss 8 | 9 | ddconfig: 10 | double_z: false 11 | z_channels: 3 12 | resolution: 256 13 | in_channels: 3 14 | out_ch: 3 15 | ch: 128 16 | ch_mult: 17 | - 1 18 | - 2 19 | - 4 20 | num_res_blocks: 2 21 | attn_resolutions: [] 22 | dropout: 0.0 23 | lossconfig: 24 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator 25 | params: 26 | disc_conditional: false 27 | disc_in_channels: 3 28 | disc_start: 0 29 | disc_weight: 0.75 30 | codebook_weight: 1.0 31 | 32 | data: 33 | target: main.DataModuleFromConfig 34 | params: 35 | batch_size: 8 36 | num_workers: 16 37 | wrap: true 38 | train: 39 | target: ldm.data.openimages.FullOpenImagesTrain 40 | params: 41 | crop_size: 256 42 | validation: 43 | target: ldm.data.openimages.FullOpenImagesValidation 44 | params: 45 | crop_size: 256 46 | -------------------------------------------------------------------------------- /models/first_stage_models/vq-f8-n256/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: ldm.models.autoencoder.VQModel 4 | params: 5 | embed_dim: 4 6 | n_embed: 256 7 | monitor: val/rec_loss 8 | ddconfig: 9 | double_z: false 10 | z_channels: 4 11 | resolution: 256 12 | in_channels: 3 13 | out_ch: 3 14 | ch: 128 15 | ch_mult: 16 | - 1 17 | - 2 18 | - 2 19 | - 4 20 | num_res_blocks: 2 21 | attn_resolutions: 22 | - 32 23 | dropout: 0.0 24 | lossconfig: 25 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator 26 | params: 27 | disc_conditional: false 28 | disc_in_channels: 3 29 | disc_start: 250001 30 | disc_weight: 0.75 31 | codebook_weight: 1.0 32 | 33 | data: 34 | target: main.DataModuleFromConfig 35 | params: 36 | batch_size: 10 37 | num_workers: 20 38 | wrap: true 39 | train: 40 | target: ldm.data.openimages.FullOpenImagesTrain 41 | params: 42 | size: 384 43 | crop_size: 256 44 | validation: 45 | target: ldm.data.openimages.FullOpenImagesValidation 46 | params: 47 | size: 384 48 | crop_size: 256 49 | -------------------------------------------------------------------------------- /models/first_stage_models/vq-f8/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: ldm.models.autoencoder.VQModel 4 | params: 5 | embed_dim: 4 6 | n_embed: 16384 7 | monitor: val/rec_loss 8 | ddconfig: 9 | double_z: false 10 | z_channels: 4 11 | resolution: 256 12 | in_channels: 3 13 | out_ch: 3 14 | ch: 128 15 | ch_mult: 16 | - 1 17 | - 2 18 | - 2 19 | - 4 20 | num_res_blocks: 2 21 | attn_resolutions: 22 | - 32 23 | dropout: 0.0 24 | lossconfig: 25 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator 26 | params: 27 | disc_conditional: false 28 | disc_in_channels: 3 29 | disc_num_layers: 2 30 | disc_start: 1 31 | disc_weight: 0.6 32 | codebook_weight: 1.0 33 | data: 34 | target: main.DataModuleFromConfig 35 | params: 36 | batch_size: 10 37 | num_workers: 20 38 | wrap: true 39 | train: 40 | target: ldm.data.openimages.FullOpenImagesTrain 41 | params: 42 | size: 384 43 | crop_size: 256 44 | validation: 45 | target: ldm.data.openimages.FullOpenImagesValidation 46 | params: 47 | size: 384 48 | crop_size: 256 49 | -------------------------------------------------------------------------------- 
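The first-stage autoencoder configs above all delegate their training objective to the adversarial losses included earlier in this repository: the KL autoencoders point lossconfig.target at ldm.modules.losses.LPIPSWithDiscriminator, the VQ models at taming's VQLPIPSWithDiscriminator, and disc_start is the global step before which adopt_weight() zeroes the discriminator term. As a rough, illustrative sketch (not code from this repository), such a lossconfig block is resolved at runtime by instantiate_from_config() from ldm/util.py roughly as follows; the dict literal mirrors the lossconfig block of the kl-f16 config above:

# Illustrative sketch only, assuming the repo is importable as the `ldm` package.
from ldm.util import instantiate_from_config

lossconfig = {  # same structure as the lossconfig block in kl-f16/config.yaml
    "target": "ldm.modules.losses.LPIPSWithDiscriminator",
    "params": {"disc_start": 50001, "kl_weight": 1.0e-06, "disc_weight": 0.5},
}

# instantiate_from_config() resolves the dotted target with get_obj_from_str()
# and calls it with **params, so this is equivalent to
# LPIPSWithDiscriminator(disc_start=50001, kl_weight=1e-06, disc_weight=0.5).
loss_fn = instantiate_from_config(lossconfig)

# Until global_step reaches disc_start, adopt_weight() returns 0.0 for
# disc_factor, so only the reconstruction / LPIPS / KL terms contribute.
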
/models/ldm/bsr_sr/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0155 7 | log_every_t: 100 8 | timesteps: 1000 9 | loss_type: l2 10 | first_stage_key: image 11 | cond_stage_key: LR_image 12 | image_size: 64 13 | channels: 3 14 | concat_mode: true 15 | cond_stage_trainable: false 16 | unet_config: 17 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 18 | params: 19 | image_size: 64 20 | in_channels: 6 21 | out_channels: 3 22 | model_channels: 160 23 | attention_resolutions: 24 | - 16 25 | - 8 26 | num_res_blocks: 2 27 | channel_mult: 28 | - 1 29 | - 2 30 | - 2 31 | - 4 32 | num_head_channels: 32 33 | first_stage_config: 34 | target: ldm.models.autoencoder.VQModelInterface 35 | params: 36 | embed_dim: 3 37 | n_embed: 8192 38 | monitor: val/rec_loss 39 | ddconfig: 40 | double_z: false 41 | z_channels: 3 42 | resolution: 256 43 | in_channels: 3 44 | out_ch: 3 45 | ch: 128 46 | ch_mult: 47 | - 1 48 | - 2 49 | - 4 50 | num_res_blocks: 2 51 | attn_resolutions: [] 52 | dropout: 0.0 53 | lossconfig: 54 | target: torch.nn.Identity 55 | cond_stage_config: 56 | target: torch.nn.Identity 57 | data: 58 | target: main.DataModuleFromConfig 59 | params: 60 | batch_size: 64 61 | wrap: false 62 | num_workers: 12 63 | train: 64 | target: ldm.data.openimages.SuperresOpenImagesAdvancedTrain 65 | params: 66 | size: 256 67 | degradation: bsrgan_light 68 | downscale_f: 4 69 | min_crop_f: 0.5 70 | max_crop_f: 1.0 71 | random_crop: true 72 | validation: 73 | target: ldm.data.openimages.SuperresOpenImagesAdvancedValidation 74 | params: 75 | size: 256 76 | degradation: bsrgan_light 77 | downscale_f: 4 78 | min_crop_f: 0.5 79 | max_crop_f: 1.0 80 | random_crop: true 81 | -------------------------------------------------------------------------------- /models/ldm/celeba256/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 2.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: class_label 12 | image_size: 64 13 | channels: 3 14 | cond_stage_trainable: false 15 | concat_mode: false 16 | monitor: val/loss 17 | unet_config: 18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 19 | params: 20 | image_size: 64 21 | in_channels: 3 22 | out_channels: 3 23 | model_channels: 224 24 | attention_resolutions: 25 | - 8 26 | - 4 27 | - 2 28 | num_res_blocks: 2 29 | channel_mult: 30 | - 1 31 | - 2 32 | - 3 33 | - 4 34 | num_head_channels: 32 35 | first_stage_config: 36 | target: ldm.models.autoencoder.VQModelInterface 37 | params: 38 | embed_dim: 3 39 | n_embed: 8192 40 | ddconfig: 41 | double_z: false 42 | z_channels: 3 43 | resolution: 256 44 | in_channels: 3 45 | out_ch: 3 46 | ch: 128 47 | ch_mult: 48 | - 1 49 | - 2 50 | - 4 51 | num_res_blocks: 2 52 | attn_resolutions: [] 53 | dropout: 0.0 54 | lossconfig: 55 | target: torch.nn.Identity 56 | cond_stage_config: __is_unconditional__ 57 | data: 58 | target: main.DataModuleFromConfig 59 | params: 60 | batch_size: 48 61 | num_workers: 5 62 | wrap: false 63 | train: 64 | target: ldm.data.faceshq.CelebAHQTrain 65 | params: 66 | size: 256 67 | validation: 68 | target: ldm.data.faceshq.CelebAHQValidation 69 | 
params: 70 | size: 256 71 | -------------------------------------------------------------------------------- /models/ldm/cin256/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: class_label 12 | image_size: 32 13 | channels: 4 14 | cond_stage_trainable: true 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | unet_config: 18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 19 | params: 20 | image_size: 32 21 | in_channels: 4 22 | out_channels: 4 23 | model_channels: 256 24 | attention_resolutions: 25 | - 4 26 | - 2 27 | - 1 28 | num_res_blocks: 2 29 | channel_mult: 30 | - 1 31 | - 2 32 | - 4 33 | num_head_channels: 32 34 | use_spatial_transformer: true 35 | transformer_depth: 1 36 | context_dim: 512 37 | first_stage_config: 38 | target: ldm.models.autoencoder.VQModelInterface 39 | params: 40 | embed_dim: 4 41 | n_embed: 16384 42 | ddconfig: 43 | double_z: false 44 | z_channels: 4 45 | resolution: 256 46 | in_channels: 3 47 | out_ch: 3 48 | ch: 128 49 | ch_mult: 50 | - 1 51 | - 2 52 | - 2 53 | - 4 54 | num_res_blocks: 2 55 | attn_resolutions: 56 | - 32 57 | dropout: 0.0 58 | lossconfig: 59 | target: torch.nn.Identity 60 | cond_stage_config: 61 | target: ldm.modules.encoders.modules.ClassEmbedder 62 | params: 63 | embed_dim: 512 64 | key: class_label 65 | data: 66 | target: main.DataModuleFromConfig 67 | params: 68 | batch_size: 64 69 | num_workers: 12 70 | wrap: false 71 | train: 72 | target: ldm.data.imagenet.ImageNetTrain 73 | params: 74 | config: 75 | size: 256 76 | validation: 77 | target: ldm.data.imagenet.ImageNetValidation 78 | params: 79 | config: 80 | size: 256 81 | -------------------------------------------------------------------------------- /models/ldm/ffhq256/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 2.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: class_label 12 | image_size: 64 13 | channels: 3 14 | cond_stage_trainable: false 15 | concat_mode: false 16 | monitor: val/loss 17 | unet_config: 18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 19 | params: 20 | image_size: 64 21 | in_channels: 3 22 | out_channels: 3 23 | model_channels: 224 24 | attention_resolutions: 25 | - 8 26 | - 4 27 | - 2 28 | num_res_blocks: 2 29 | channel_mult: 30 | - 1 31 | - 2 32 | - 3 33 | - 4 34 | num_head_channels: 32 35 | first_stage_config: 36 | target: ldm.models.autoencoder.VQModelInterface 37 | params: 38 | embed_dim: 3 39 | n_embed: 8192 40 | ddconfig: 41 | double_z: false 42 | z_channels: 3 43 | resolution: 256 44 | in_channels: 3 45 | out_ch: 3 46 | ch: 128 47 | ch_mult: 48 | - 1 49 | - 2 50 | - 4 51 | num_res_blocks: 2 52 | attn_resolutions: [] 53 | dropout: 0.0 54 | lossconfig: 55 | target: torch.nn.Identity 56 | cond_stage_config: __is_unconditional__ 57 | data: 58 | target: main.DataModuleFromConfig 59 | params: 60 | batch_size: 42 61 | num_workers: 5 62 | wrap: false 63 | train: 64 | target: ldm.data.faceshq.FFHQTrain 65 | params: 66 | size: 256 67 | 
validation: 68 | target: ldm.data.faceshq.FFHQValidation 69 | params: 70 | size: 256 71 | -------------------------------------------------------------------------------- /models/ldm/inpainting_big/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0205 7 | log_every_t: 100 8 | timesteps: 1000 9 | loss_type: l1 10 | first_stage_key: image 11 | cond_stage_key: masked_image 12 | image_size: 64 13 | channels: 3 14 | concat_mode: true 15 | monitor: val/loss 16 | scheduler_config: 17 | target: ldm.lr_scheduler.LambdaWarmUpCosineScheduler 18 | params: 19 | verbosity_interval: 0 20 | warm_up_steps: 1000 21 | max_decay_steps: 50000 22 | lr_start: 0.001 23 | lr_max: 0.1 24 | lr_min: 0.0001 25 | unet_config: 26 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 27 | params: 28 | image_size: 64 29 | in_channels: 7 30 | out_channels: 3 31 | model_channels: 256 32 | attention_resolutions: 33 | - 8 34 | - 4 35 | - 2 36 | num_res_blocks: 2 37 | channel_mult: 38 | - 1 39 | - 2 40 | - 3 41 | - 4 42 | num_heads: 8 43 | resblock_updown: true 44 | first_stage_config: 45 | target: ldm.models.autoencoder.VQModelInterface 46 | params: 47 | embed_dim: 3 48 | n_embed: 8192 49 | monitor: val/rec_loss 50 | ddconfig: 51 | attn_type: none 52 | double_z: false 53 | z_channels: 3 54 | resolution: 256 55 | in_channels: 3 56 | out_ch: 3 57 | ch: 128 58 | ch_mult: 59 | - 1 60 | - 2 61 | - 4 62 | num_res_blocks: 2 63 | attn_resolutions: [] 64 | dropout: 0.0 65 | lossconfig: 66 | target: ldm.modules.losses.contperceptual.DummyLoss 67 | cond_stage_config: __is_first_stage__ 68 | -------------------------------------------------------------------------------- /models/ldm/layout2img-openimages256/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 2.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0205 7 | log_every_t: 100 8 | timesteps: 1000 9 | loss_type: l1 10 | first_stage_key: image 11 | cond_stage_key: coordinates_bbox 12 | image_size: 64 13 | channels: 3 14 | conditioning_key: crossattn 15 | cond_stage_trainable: true 16 | unet_config: 17 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 18 | params: 19 | image_size: 64 20 | in_channels: 3 21 | out_channels: 3 22 | model_channels: 128 23 | attention_resolutions: 24 | - 8 25 | - 4 26 | - 2 27 | num_res_blocks: 2 28 | channel_mult: 29 | - 1 30 | - 2 31 | - 3 32 | - 4 33 | num_head_channels: 32 34 | use_spatial_transformer: true 35 | transformer_depth: 3 36 | context_dim: 512 37 | first_stage_config: 38 | target: ldm.models.autoencoder.VQModelInterface 39 | params: 40 | embed_dim: 3 41 | n_embed: 8192 42 | monitor: val/rec_loss 43 | ddconfig: 44 | double_z: false 45 | z_channels: 3 46 | resolution: 256 47 | in_channels: 3 48 | out_ch: 3 49 | ch: 128 50 | ch_mult: 51 | - 1 52 | - 2 53 | - 4 54 | num_res_blocks: 2 55 | attn_resolutions: [] 56 | dropout: 0.0 57 | lossconfig: 58 | target: torch.nn.Identity 59 | cond_stage_config: 60 | target: ldm.modules.encoders.modules.BERTEmbedder 61 | params: 62 | n_embed: 512 63 | n_layer: 16 64 | vocab_size: 8192 65 | max_seq_len: 92 66 | use_tokenizer: false 67 | monitor: val/loss_simple_ema 68 | data: 69 | target: main.DataModuleFromConfig 70 | params: 71 | batch_size: 24 72 | wrap: 
false 73 | num_workers: 10 74 | train: 75 | target: ldm.data.openimages.OpenImagesBBoxTrain 76 | params: 77 | size: 256 78 | validation: 79 | target: ldm.data.openimages.OpenImagesBBoxValidation 80 | params: 81 | size: 256 82 | -------------------------------------------------------------------------------- /models/ldm/lsun_beds256/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 2.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: class_label 12 | image_size: 64 13 | channels: 3 14 | cond_stage_trainable: false 15 | concat_mode: false 16 | monitor: val/loss 17 | unet_config: 18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 19 | params: 20 | image_size: 64 21 | in_channels: 3 22 | out_channels: 3 23 | model_channels: 224 24 | attention_resolutions: 25 | - 8 26 | - 4 27 | - 2 28 | num_res_blocks: 2 29 | channel_mult: 30 | - 1 31 | - 2 32 | - 3 33 | - 4 34 | num_head_channels: 32 35 | first_stage_config: 36 | target: ldm.models.autoencoder.VQModelInterface 37 | params: 38 | embed_dim: 3 39 | n_embed: 8192 40 | ddconfig: 41 | double_z: false 42 | z_channels: 3 43 | resolution: 256 44 | in_channels: 3 45 | out_ch: 3 46 | ch: 128 47 | ch_mult: 48 | - 1 49 | - 2 50 | - 4 51 | num_res_blocks: 2 52 | attn_resolutions: [] 53 | dropout: 0.0 54 | lossconfig: 55 | target: torch.nn.Identity 56 | cond_stage_config: __is_unconditional__ 57 | data: 58 | target: main.DataModuleFromConfig 59 | params: 60 | batch_size: 48 61 | num_workers: 5 62 | wrap: false 63 | train: 64 | target: ldm.data.lsun.LSUNBedroomsTrain 65 | params: 66 | size: 256 67 | validation: 68 | target: ldm.data.lsun.LSUNBedroomsValidation 69 | params: 70 | size: 256 71 | -------------------------------------------------------------------------------- /models/ldm/lsun_churches256/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 5.0e-05 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0155 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | loss_type: l1 11 | first_stage_key: image 12 | cond_stage_key: image 13 | image_size: 32 14 | channels: 4 15 | cond_stage_trainable: false 16 | concat_mode: false 17 | scale_by_std: true 18 | monitor: val/loss_simple_ema 19 | scheduler_config: 20 | target: ldm.lr_scheduler.LambdaLinearScheduler 21 | params: 22 | warm_up_steps: 23 | - 10000 24 | cycle_lengths: 25 | - 10000000000000 26 | f_start: 27 | - 1.0e-06 28 | f_max: 29 | - 1.0 30 | f_min: 31 | - 1.0 32 | unet_config: 33 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 34 | params: 35 | image_size: 32 36 | in_channels: 4 37 | out_channels: 4 38 | model_channels: 192 39 | attention_resolutions: 40 | - 1 41 | - 2 42 | - 4 43 | - 8 44 | num_res_blocks: 2 45 | channel_mult: 46 | - 1 47 | - 2 48 | - 2 49 | - 4 50 | - 4 51 | num_heads: 8 52 | use_scale_shift_norm: true 53 | resblock_updown: true 54 | first_stage_config: 55 | target: ldm.models.autoencoder.AutoencoderKL 56 | params: 57 | embed_dim: 4 58 | monitor: val/rec_loss 59 | ddconfig: 60 | double_z: true 61 | z_channels: 4 62 | resolution: 256 63 | in_channels: 3 64 | out_ch: 3 65 | ch: 128 66 | ch_mult: 67 | - 1 68 | - 2 69 | - 4 70 | - 4 71 | 
num_res_blocks: 2 72 | attn_resolutions: [] 73 | dropout: 0.0 74 | lossconfig: 75 | target: torch.nn.Identity 76 | 77 | cond_stage_config: '__is_unconditional__' 78 | 79 | data: 80 | target: main.DataModuleFromConfig 81 | params: 82 | batch_size: 96 83 | num_workers: 5 84 | wrap: false 85 | train: 86 | target: ldm.data.lsun.LSUNChurchesTrain 87 | params: 88 | size: 256 89 | validation: 90 | target: ldm.data.lsun.LSUNChurchesValidation 91 | params: 92 | size: 256 93 | -------------------------------------------------------------------------------- /models/ldm/semantic_synthesis256/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0205 7 | log_every_t: 100 8 | timesteps: 1000 9 | loss_type: l1 10 | first_stage_key: image 11 | cond_stage_key: segmentation 12 | image_size: 64 13 | channels: 3 14 | concat_mode: true 15 | cond_stage_trainable: true 16 | unet_config: 17 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 18 | params: 19 | image_size: 64 20 | in_channels: 6 21 | out_channels: 3 22 | model_channels: 128 23 | attention_resolutions: 24 | - 32 25 | - 16 26 | - 8 27 | num_res_blocks: 2 28 | channel_mult: 29 | - 1 30 | - 4 31 | - 8 32 | num_heads: 8 33 | first_stage_config: 34 | target: ldm.models.autoencoder.VQModelInterface 35 | params: 36 | embed_dim: 3 37 | n_embed: 8192 38 | ddconfig: 39 | double_z: false 40 | z_channels: 3 41 | resolution: 256 42 | in_channels: 3 43 | out_ch: 3 44 | ch: 128 45 | ch_mult: 46 | - 1 47 | - 2 48 | - 4 49 | num_res_blocks: 2 50 | attn_resolutions: [] 51 | dropout: 0.0 52 | lossconfig: 53 | target: torch.nn.Identity 54 | cond_stage_config: 55 | target: ldm.modules.encoders.modules.SpatialRescaler 56 | params: 57 | n_stages: 2 58 | in_channels: 182 59 | out_channels: 3 60 | -------------------------------------------------------------------------------- /models/ldm/semantic_synthesis512/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0205 7 | log_every_t: 100 8 | timesteps: 1000 9 | loss_type: l1 10 | first_stage_key: image 11 | cond_stage_key: segmentation 12 | image_size: 128 13 | channels: 3 14 | concat_mode: true 15 | cond_stage_trainable: true 16 | unet_config: 17 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 18 | params: 19 | image_size: 128 20 | in_channels: 6 21 | out_channels: 3 22 | model_channels: 128 23 | attention_resolutions: 24 | - 32 25 | - 16 26 | - 8 27 | num_res_blocks: 2 28 | channel_mult: 29 | - 1 30 | - 4 31 | - 8 32 | num_heads: 8 33 | first_stage_config: 34 | target: ldm.models.autoencoder.VQModelInterface 35 | params: 36 | embed_dim: 3 37 | n_embed: 8192 38 | monitor: val/rec_loss 39 | ddconfig: 40 | double_z: false 41 | z_channels: 3 42 | resolution: 256 43 | in_channels: 3 44 | out_ch: 3 45 | ch: 128 46 | ch_mult: 47 | - 1 48 | - 2 49 | - 4 50 | num_res_blocks: 2 51 | attn_resolutions: [] 52 | dropout: 0.0 53 | lossconfig: 54 | target: torch.nn.Identity 55 | cond_stage_config: 56 | target: ldm.modules.encoders.modules.SpatialRescaler 57 | params: 58 | n_stages: 2 59 | in_channels: 182 60 | out_channels: 3 61 | data: 62 | target: main.DataModuleFromConfig 63 | params: 64 | batch_size: 8 65 | wrap: false 66 | 
num_workers: 10 67 | train: 68 | target: ldm.data.landscapes.RFWTrain 69 | params: 70 | size: 768 71 | crop_size: 512 72 | segmentation_to_float32: true 73 | validation: 74 | target: ldm.data.landscapes.RFWValidation 75 | params: 76 | size: 768 77 | crop_size: 512 78 | segmentation_to_float32: true 79 | -------------------------------------------------------------------------------- /models/ldm/text2img256/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 2.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: caption 12 | image_size: 64 13 | channels: 3 14 | cond_stage_trainable: true 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | unet_config: 18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 19 | params: 20 | image_size: 64 21 | in_channels: 3 22 | out_channels: 3 23 | model_channels: 192 24 | attention_resolutions: 25 | - 8 26 | - 4 27 | - 2 28 | num_res_blocks: 2 29 | channel_mult: 30 | - 1 31 | - 2 32 | - 3 33 | - 5 34 | num_head_channels: 32 35 | use_spatial_transformer: true 36 | transformer_depth: 1 37 | context_dim: 640 38 | first_stage_config: 39 | target: ldm.models.autoencoder.VQModelInterface 40 | params: 41 | embed_dim: 3 42 | n_embed: 8192 43 | ddconfig: 44 | double_z: false 45 | z_channels: 3 46 | resolution: 256 47 | in_channels: 3 48 | out_ch: 3 49 | ch: 128 50 | ch_mult: 51 | - 1 52 | - 2 53 | - 4 54 | num_res_blocks: 2 55 | attn_resolutions: [] 56 | dropout: 0.0 57 | lossconfig: 58 | target: torch.nn.Identity 59 | cond_stage_config: 60 | target: ldm.modules.encoders.modules.BERTEmbedder 61 | params: 62 | n_embed: 640 63 | n_layer: 32 64 | data: 65 | target: main.DataModuleFromConfig 66 | params: 67 | batch_size: 28 68 | num_workers: 10 69 | wrap: false 70 | train: 71 | target: ldm.data.previews.pytorch_dataset.PreviewsTrain 72 | params: 73 | size: 256 74 | validation: 75 | target: ldm.data.previews.pytorch_dataset.PreviewsValidation 76 | params: 77 | size: 256 78 | -------------------------------------------------------------------------------- /prune_ckpt.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import argparse 4 | import glob 5 | 6 | 7 | parser = argparse.ArgumentParser(description='Pruning') 8 | parser.add_argument('--ckpt', type=str, default=None, help='path to model ckpt') 9 | args = parser.parse_args() 10 | ckpt = args.ckpt 11 | 12 | def prune_it(p, keep_only_ema=False): 13 | print(f"prunin' in path: {p}") 14 | size_initial = os.path.getsize(p) 15 | nsd = dict() 16 | sd = torch.load(p, map_location="cpu") 17 | print(sd.keys()) 18 | for k in sd.keys(): 19 | if k != "optimizer_states": 20 | nsd[k] = sd[k] 21 | else: 22 | print(f"removing optimizer states for path {p}") 23 | if "global_step" in sd: 24 | print(f"This is global step {sd['global_step']}.") 25 | if keep_only_ema: 26 | sd = nsd["state_dict"].copy() 27 | # infer ema keys 28 | ema_keys = {k: "model_ema." 
+ k[6:].replace(".", ".") for k in sd.keys() if k.startswith("model.")} 29 | new_sd = dict() 30 | 31 | for k in sd: 32 | if k in ema_keys: 33 | new_sd[k] = sd[ema_keys[k]].half() 34 | elif not k.startswith("model_ema.") or k in ["model_ema.num_updates", "model_ema.decay"]: 35 | new_sd[k] = sd[k].half() 36 | 37 | assert len(new_sd) == len(sd) - len(ema_keys) 38 | nsd["state_dict"] = new_sd 39 | else: 40 | sd = nsd['state_dict'].copy() 41 | new_sd = dict() 42 | for k in sd: 43 | new_sd[k] = sd[k].half() 44 | nsd['state_dict'] = new_sd 45 | 46 | fn = f"{os.path.splitext(p)[0]}-pruned.ckpt" if not keep_only_ema else f"{os.path.splitext(p)[0]}-ema-pruned.ckpt" 47 | print(f"saving pruned checkpoint at: {fn}") 48 | torch.save(nsd, fn) 49 | newsize = os.path.getsize(fn) 50 | MSG = f"New ckpt size: {newsize*1e-9:.2f} GB. " + \ 51 | f"Saved {(size_initial - newsize)*1e-9:.2f} GB by removing optimizer states" 52 | if keep_only_ema: 53 | MSG += " and non-EMA weights" 54 | print(MSG) 55 | 56 | 57 | if __name__ == "__main__": 58 | prune_it(ckpt) -------------------------------------------------------------------------------- /readme-images/better-training-images.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kanewallmann/Dreambooth-Stable-Diffusion/6f04eae2331153f841cab502385415ac39503144/readme-images/better-training-images.png -------------------------------------------------------------------------------- /readme-images/caution-training.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kanewallmann/Dreambooth-Stable-Diffusion/6f04eae2331153f841cab502385415ac39503144/readme-images/caution-training.png -------------------------------------------------------------------------------- /readme-images/vast-ai-step1-select-docker-image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kanewallmann/Dreambooth-Stable-Diffusion/6f04eae2331153f841cab502385415ac39503144/readme-images/vast-ai-step1-select-docker-image.png -------------------------------------------------------------------------------- /readme-images/vast-ai-step2-instance-filters.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kanewallmann/Dreambooth-Stable-Diffusion/6f04eae2331153f841cab502385415ac39503144/readme-images/vast-ai-step2-instance-filters.png -------------------------------------------------------------------------------- /readme-images/vast-ai-step3-instances.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kanewallmann/Dreambooth-Stable-Diffusion/6f04eae2331153f841cab502385415ac39503144/readme-images/vast-ai-step3-instances.png -------------------------------------------------------------------------------- /readme-images/vast-ai-step4-get-repo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kanewallmann/Dreambooth-Stable-Diffusion/6f04eae2331153f841cab502385415ac39503144/readme-images/vast-ai-step4-get-repo.png -------------------------------------------------------------------------------- /readme-images/vast-ai-step5-clone-repo.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/kanewallmann/Dreambooth-Stable-Diffusion/6f04eae2331153f841cab502385415ac39503144/readme-images/vast-ai-step5-clone-repo.png -------------------------------------------------------------------------------- /readme-images/vast-ai-step6-open-notebook.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kanewallmann/Dreambooth-Stable-Diffusion/6f04eae2331153f841cab502385415ac39503144/readme-images/vast-ai-step6-open-notebook.png -------------------------------------------------------------------------------- /scripts/download_first_stages.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | wget -O models/first_stage_models/kl-f4/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f4.zip 3 | wget -O models/first_stage_models/kl-f8/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f8.zip 4 | wget -O models/first_stage_models/kl-f16/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f16.zip 5 | wget -O models/first_stage_models/kl-f32/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f32.zip 6 | wget -O models/first_stage_models/vq-f4/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f4.zip 7 | wget -O models/first_stage_models/vq-f4-noattn/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f4-noattn.zip 8 | wget -O models/first_stage_models/vq-f8/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f8.zip 9 | wget -O models/first_stage_models/vq-f8-n256/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f8-n256.zip 10 | wget -O models/first_stage_models/vq-f16/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f16.zip 11 | 12 | 13 | 14 | cd models/first_stage_models/kl-f4 15 | unzip -o model.zip 16 | 17 | cd ../kl-f8 18 | unzip -o model.zip 19 | 20 | cd ../kl-f16 21 | unzip -o model.zip 22 | 23 | cd ../kl-f32 24 | unzip -o model.zip 25 | 26 | cd ../vq-f4 27 | unzip -o model.zip 28 | 29 | cd ../vq-f4-noattn 30 | unzip -o model.zip 31 | 32 | cd ../vq-f8 33 | unzip -o model.zip 34 | 35 | cd ../vq-f8-n256 36 | unzip -o model.zip 37 | 38 | cd ../vq-f16 39 | unzip -o model.zip 40 | 41 | cd ../.. 
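The script above drops each first-stage checkpoint zip into the matching models/first_stage_models/<name>/ directory (next to the config.yaml files listed earlier) and unzips it in place. A rough Python equivalent for a single entry, assuming the same URL layout as the script; the shell script remains the canonical path:

# Illustrative Python sketch of one entry of download_first_stages.sh above.
import os
import urllib.request
import zipfile

name = "kl-f8"  # any first-stage model name used in the script above
target_dir = os.path.join("models", "first_stage_models", name)
zip_path = os.path.join(target_dir, "model.zip")

# wget -O models/first_stage_models/kl-f8/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f8.zip
urllib.request.urlretrieve(
    f"https://ommer-lab.com/files/latent-diffusion/{name}.zip", zip_path
)

# cd models/first_stage_models/kl-f8 && unzip -o model.zip
with zipfile.ZipFile(zip_path) as zf:
    zf.extractall(target_dir)
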
-------------------------------------------------------------------------------- /scripts/download_models.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | wget -O models/ldm/celeba256/celeba-256.zip https://ommer-lab.com/files/latent-diffusion/celeba.zip 3 | wget -O models/ldm/ffhq256/ffhq-256.zip https://ommer-lab.com/files/latent-diffusion/ffhq.zip 4 | wget -O models/ldm/lsun_churches256/lsun_churches-256.zip https://ommer-lab.com/files/latent-diffusion/lsun_churches.zip 5 | wget -O models/ldm/lsun_beds256/lsun_beds-256.zip https://ommer-lab.com/files/latent-diffusion/lsun_bedrooms.zip 6 | wget -O models/ldm/text2img256/model.zip https://ommer-lab.com/files/latent-diffusion/text2img.zip 7 | wget -O models/ldm/cin256/model.zip https://ommer-lab.com/files/latent-diffusion/cin.zip 8 | wget -O models/ldm/semantic_synthesis512/model.zip https://ommer-lab.com/files/latent-diffusion/semantic_synthesis.zip 9 | wget -O models/ldm/semantic_synthesis256/model.zip https://ommer-lab.com/files/latent-diffusion/semantic_synthesis256.zip 10 | wget -O models/ldm/bsr_sr/model.zip https://ommer-lab.com/files/latent-diffusion/sr_bsr.zip 11 | wget -O models/ldm/layout2img-openimages256/model.zip https://ommer-lab.com/files/latent-diffusion/layout2img_model.zip 12 | wget -O models/ldm/inpainting_big/model.zip https://ommer-lab.com/files/latent-diffusion/inpainting_big.zip 13 | 14 | 15 | 16 | cd models/ldm/celeba256 17 | unzip -o celeba-256.zip 18 | 19 | cd ../ffhq256 20 | unzip -o ffhq-256.zip 21 | 22 | cd ../lsun_churches256 23 | unzip -o lsun_churches-256.zip 24 | 25 | cd ../lsun_beds256 26 | unzip -o lsun_beds-256.zip 27 | 28 | cd ../text2img256 29 | unzip -o model.zip 30 | 31 | cd ../cin256 32 | unzip -o model.zip 33 | 34 | cd ../semantic_synthesis512 35 | unzip -o model.zip 36 | 37 | cd ../semantic_synthesis256 38 | unzip -o model.zip 39 | 40 | cd ../bsr_sr 41 | unzip -o model.zip 42 | 43 | cd ../layout2img-openimages256 44 | unzip -o model.zip 45 | 46 | cd ../inpainting_big 47 | unzip -o model.zip 48 | 49 | cd ../.. 
50 | -------------------------------------------------------------------------------- /scripts/evaluate_model.py: -------------------------------------------------------------------------------- 1 | import argparse, os, sys, glob 2 | 3 | sys.path.append(os.path.join(sys.path[0], '..')) 4 | 5 | import torch 6 | import numpy as np 7 | from omegaconf import OmegaConf 8 | from PIL import Image 9 | from tqdm import tqdm, trange 10 | from einops import rearrange 11 | from torchvision.utils import make_grid 12 | 13 | from ldm.util import instantiate_from_config 14 | from ldm.models.diffusion.ddim import DDIMSampler 15 | from ldm.models.diffusion.plms import PLMSSampler 16 | from ldm.data.personalized import PersonalizedBase 17 | from evaluation.clip_eval import LDMCLIPEvaluator 18 | 19 | def load_model_from_config(config, ckpt, verbose=False): 20 | print(f"Loading model from {ckpt}") 21 | pl_sd = torch.load(ckpt, map_location="cpu") 22 | sd = pl_sd["state_dict"] 23 | model = instantiate_from_config(config.model) 24 | m, u = model.load_state_dict(sd, strict=False) 25 | if len(m) > 0 and verbose: 26 | print("missing keys:") 27 | print(m) 28 | if len(u) > 0 and verbose: 29 | print("unexpected keys:") 30 | print(u) 31 | 32 | model.cuda() 33 | model.eval() 34 | return model 35 | 36 | 37 | if __name__ == "__main__": 38 | parser = argparse.ArgumentParser() 39 | 40 | parser.add_argument( 41 | "--prompt", 42 | type=str, 43 | nargs="?", 44 | default="a painting of a virus monster playing guitar", 45 | help="the prompt to render" 46 | ) 47 | 48 | parser.add_argument( 49 | "--ckpt_path", 50 | type=str, 51 | default="/data/pretrained_models/ldm/text2img-large/model.ckpt", 52 | help="Path to pretrained ldm text2img model") 53 | 54 | parser.add_argument( 55 | "--embedding_path", 56 | type=str, 57 | help="Path to a pre-trained embedding manager checkpoint") 58 | 59 | parser.add_argument( 60 | "--data_dir", 61 | type=str, 62 | help="Path to directory with images used to train the embedding vectors" 63 | ) 64 | 65 | opt = parser.parse_args() 66 | 67 | 68 | config = OmegaConf.load("configs/latent-diffusion/txt2img-1p4B-eval_with_tokens.yaml") # TODO: Optionally download from same location as ckpt and chnage this logic 69 | model = load_model_from_config(config, opt.ckpt_path) # TODO: check path 70 | model.embedding_manager.load(opt.embedding_path) 71 | 72 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 73 | model = model.to(device) 74 | 75 | evaluator = LDMCLIPEvaluator(device) 76 | 77 | prompt = opt.prompt 78 | 79 | data_loader = PersonalizedBase(opt.data_dir, size=256, flip_p=0.0) 80 | 81 | images = [torch.from_numpy(data_loader[i]["image"]).permute(2, 0, 1) for i in range(data_loader.num_images)] 82 | images = torch.stack(images, axis=0) 83 | 84 | sim_img, sim_text = evaluator.evaluate(model, images, opt.prompt) 85 | 86 | output_dir = os.path.join(opt.out_dir, prompt.replace(" ", "-")) 87 | 88 | print("Image similarity: ", sim_img) 89 | print("Text similarity: ", sim_text) -------------------------------------------------------------------------------- /scripts/inpaint.py: -------------------------------------------------------------------------------- 1 | import argparse, os, sys, glob 2 | from omegaconf import OmegaConf 3 | from PIL import Image 4 | from tqdm import tqdm 5 | import numpy as np 6 | import torch 7 | from main import instantiate_from_config 8 | from ldm.models.diffusion.ddim import DDIMSampler 9 | 10 | 11 | def make_batch(image, mask, device): 12 | image = 
np.array(Image.open(image).convert("RGB")) 13 | image = image.astype(np.float32)/255.0 14 | image = image[None].transpose(0,3,1,2) 15 | image = torch.from_numpy(image) 16 | 17 | mask = np.array(Image.open(mask).convert("L")) 18 | mask = mask.astype(np.float32)/255.0 19 | mask = mask[None,None] 20 | mask[mask < 0.5] = 0 21 | mask[mask >= 0.5] = 1 22 | mask = torch.from_numpy(mask) 23 | 24 | masked_image = (1-mask)*image 25 | 26 | batch = {"image": image, "mask": mask, "masked_image": masked_image} 27 | for k in batch: 28 | batch[k] = batch[k].to(device=device) 29 | batch[k] = batch[k]*2.0-1.0 30 | return batch 31 | 32 | 33 | if __name__ == "__main__": 34 | parser = argparse.ArgumentParser() 35 | parser.add_argument( 36 | "--indir", 37 | type=str, 38 | nargs="?", 39 | help="dir containing image-mask pairs (`example.png` and `example_mask.png`)", 40 | ) 41 | parser.add_argument( 42 | "--outdir", 43 | type=str, 44 | nargs="?", 45 | help="dir to write results to", 46 | ) 47 | parser.add_argument( 48 | "--steps", 49 | type=int, 50 | default=50, 51 | help="number of ddim sampling steps", 52 | ) 53 | opt = parser.parse_args() 54 | 55 | masks = sorted(glob.glob(os.path.join(opt.indir, "*_mask.png"))) 56 | images = [x.replace("_mask.png", ".png") for x in masks] 57 | print(f"Found {len(masks)} inputs.") 58 | 59 | config = OmegaConf.load("models/ldm/inpainting_big/config.yaml") 60 | model = instantiate_from_config(config.model) 61 | model.load_state_dict(torch.load("models/ldm/inpainting_big/last.ckpt")["state_dict"], 62 | strict=False) 63 | 64 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 65 | model = model.to(device) 66 | sampler = DDIMSampler(model) 67 | 68 | os.makedirs(opt.outdir, exist_ok=True) 69 | with torch.no_grad(): 70 | with model.ema_scope(): 71 | for image, mask in tqdm(zip(images, masks)): 72 | outpath = os.path.join(opt.outdir, os.path.split(image)[1]) 73 | batch = make_batch(image, mask, device=device) 74 | 75 | # encode masked image and concat downsampled mask 76 | c = model.cond_stage_model.encode(batch["masked_image"]) 77 | cc = torch.nn.functional.interpolate(batch["mask"], 78 | size=c.shape[-2:]) 79 | c = torch.cat((c, cc), dim=1) 80 | 81 | shape = (c.shape[1]-1,)+c.shape[2:] 82 | samples_ddim, _ = sampler.sample(S=opt.steps, 83 | conditioning=c, 84 | batch_size=c.shape[0], 85 | shape=shape, 86 | verbose=False) 87 | x_samples_ddim = model.decode_first_stage(samples_ddim) 88 | 89 | image = torch.clamp((batch["image"]+1.0)/2.0, 90 | min=0.0, max=1.0) 91 | mask = torch.clamp((batch["mask"]+1.0)/2.0, 92 | min=0.0, max=1.0) 93 | predicted_image = torch.clamp((x_samples_ddim+1.0)/2.0, 94 | min=0.0, max=1.0) 95 | 96 | inpainted = (1-mask)*image+mask*predicted_image 97 | inpainted = inpainted.cpu().numpy().transpose(0,2,3,1)[0]*255 98 | Image.fromarray(inpainted.astype(np.uint8)).save(outpath) 99 | -------------------------------------------------------------------------------- /scripts/prune-ckpt.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import argparse 4 | import glob 5 | 6 | 7 | parser = argparse.ArgumentParser(description='Pruning') 8 | parser.add_argument('--ckpt', type=str, default=None, help='path to model ckpt') 9 | args = parser.parse_args() 10 | ckpt = args.ckpt 11 | 12 | def prune_it(p, keep_only_ema=False): 13 | print(f"prunin' in path: {p}") 14 | size_initial = os.path.getsize(p) 15 | nsd = dict() 16 | sd = torch.load(p, map_location="cpu") 17 | 
print(sd.keys()) 18 | for k in sd.keys(): 19 | if k != "optimizer_states": 20 | nsd[k] = sd[k] 21 | else: 22 | print(f"removing optimizer states for path {p}") 23 | if "global_step" in sd: 24 | print(f"This is global step {sd['global_step']}.") 25 | if keep_only_ema: 26 | sd = nsd["state_dict"].copy() 27 | # infer ema keys (LitEma stores shadow parameters with the '.' characters stripped from the names) 28 | ema_keys = {k: "model_ema." + k[6:].replace(".", "") for k in sd.keys() if k.startswith("model.")} 29 | new_sd = dict() 30 | 31 | for k in sd: 32 | if k in ema_keys: 33 | new_sd[k] = sd[ema_keys[k]].half() 34 | elif not k.startswith("model_ema.") or k in ["model_ema.num_updates", "model_ema.decay"]: 35 | new_sd[k] = sd[k].half() 36 | 37 | assert len(new_sd) == len(sd) - len(ema_keys) 38 | nsd["state_dict"] = new_sd 39 | else: 40 | sd = nsd['state_dict'].copy() 41 | new_sd = dict() 42 | for k in sd: 43 | new_sd[k] = sd[k].half() 44 | nsd['state_dict'] = new_sd 45 | 46 | fn = f"{os.path.splitext(p)[0]}-pruned.ckpt" if not keep_only_ema else f"{os.path.splitext(p)[0]}-ema-pruned.ckpt" 47 | print(f"saving pruned checkpoint at: {fn}") 48 | torch.save(nsd, fn) 49 | newsize = os.path.getsize(fn) 50 | MSG = f"New ckpt size: {newsize*1e-9:.2f} GB. " + \ 51 | f"Saved {(size_initial - newsize)*1e-9:.2f} GB by removing optimizer states" 52 | if keep_only_ema: 53 | MSG += " and non-EMA weights" 54 | print(MSG) 55 | 56 | 57 | if __name__ == "__main__": 58 | prune_it(ckpt) -------------------------------------------------------------------------------- /scripts/sample_diffusion.py: -------------------------------------------------------------------------------- 1 | import argparse, os, sys, glob, datetime, yaml 2 | import torch 3 | import time 4 | import numpy as np 5 | from tqdm import trange 6 | 7 | from omegaconf import OmegaConf 8 | from PIL import Image 9 | 10 | from ldm.models.diffusion.ddim import DDIMSampler 11 | from ldm.util import instantiate_from_config 12 | 13 | rescale = lambda x: (x + 1.) / 2. 14 | 15 | def custom_to_pil(x): 16 | x = x.detach().cpu() 17 | x = torch.clamp(x, -1., 1.) 18 | x = (x + 1.) / 2. 19 | x = x.permute(1, 2, 0).numpy() 20 | x = (255 * x).astype(np.uint8) 21 | x = Image.fromarray(x) 22 | if not x.mode == "RGB": 23 | x = x.convert("RGB") 24 | return x 25 | 26 | 27 | def custom_to_np(x): 28 | # saves the batch in adm style as in https://github.com/openai/guided-diffusion/blob/main/scripts/image_sample.py 29 | sample = x.detach().cpu() 30 | sample = ((sample + 1) * 127.5).clamp(0, 255).to(torch.uint8) 31 | sample = sample.permute(0, 2, 3, 1) 32 | sample = sample.contiguous() 33 | return sample 34 | 35 | 36 | def logs2pil(logs, keys=["sample"]): 37 | imgs = dict() 38 | for k in logs: 39 | try: 40 | if len(logs[k].shape) == 4: 41 | img = custom_to_pil(logs[k][0, ...]) 42 | elif len(logs[k].shape) == 3: 43 | img = custom_to_pil(logs[k]) 44 | else: 45 | print(f"Unknown format for key {k}. 
") 46 | img = None 47 | except: 48 | img = None 49 | imgs[k] = img 50 | return imgs 51 | 52 | 53 | @torch.no_grad() 54 | def convsample(model, shape, return_intermediates=True, 55 | verbose=True, 56 | make_prog_row=False): 57 | 58 | 59 | if not make_prog_row: 60 | return model.p_sample_loop(None, shape, 61 | return_intermediates=return_intermediates, verbose=verbose) 62 | else: 63 | return model.progressive_denoising( 64 | None, shape, verbose=True 65 | ) 66 | 67 | 68 | @torch.no_grad() 69 | def convsample_ddim(model, steps, shape, eta=1.0 70 | ): 71 | ddim = DDIMSampler(model) 72 | bs = shape[0] 73 | shape = shape[1:] 74 | samples, intermediates = ddim.sample(steps, batch_size=bs, shape=shape, eta=eta, verbose=False,) 75 | return samples, intermediates 76 | 77 | 78 | @torch.no_grad() 79 | def make_convolutional_sample(model, batch_size, vanilla=False, custom_steps=None, eta=1.0,): 80 | 81 | 82 | log = dict() 83 | 84 | shape = [batch_size, 85 | model.model.diffusion_model.in_channels, 86 | model.model.diffusion_model.image_size, 87 | model.model.diffusion_model.image_size] 88 | 89 | with model.ema_scope("Plotting"): 90 | t0 = time.time() 91 | if vanilla: 92 | sample, progrow = convsample(model, shape, 93 | make_prog_row=True) 94 | else: 95 | sample, intermediates = convsample_ddim(model, steps=custom_steps, shape=shape, 96 | eta=eta) 97 | 98 | t1 = time.time() 99 | 100 | x_sample = model.decode_first_stage(sample) 101 | 102 | log["sample"] = x_sample 103 | log["time"] = t1 - t0 104 | log['throughput'] = sample.shape[0] / (t1 - t0) 105 | print(f'Throughput for this batch: {log["throughput"]}') 106 | return log 107 | 108 | def run(model, logdir, batch_size=50, vanilla=False, custom_steps=None, eta=None, n_samples=50000, nplog=None): 109 | if vanilla: 110 | print(f'Using Vanilla DDPM sampling with {model.num_timesteps} sampling steps.') 111 | else: 112 | print(f'Using DDIM sampling with {custom_steps} sampling steps and eta={eta}') 113 | 114 | 115 | tstart = time.time() 116 | n_saved = len(glob.glob(os.path.join(logdir,'*.png')))-1 117 | # path = logdir 118 | if model.cond_stage_model is None: 119 | all_images = [] 120 | 121 | print(f"Running unconditional sampling for {n_samples} samples") 122 | for _ in trange(n_samples // batch_size, desc="Sampling Batches (unconditional)"): 123 | logs = make_convolutional_sample(model, batch_size=batch_size, 124 | vanilla=vanilla, custom_steps=custom_steps, 125 | eta=eta) 126 | n_saved = save_logs(logs, logdir, n_saved=n_saved, key="sample") 127 | all_images.extend([custom_to_np(logs["sample"])]) 128 | if n_saved >= n_samples: 129 | print(f'Finish after generating {n_saved} samples') 130 | break 131 | all_img = np.concatenate(all_images, axis=0) 132 | all_img = all_img[:n_samples] 133 | shape_str = "x".join([str(x) for x in all_img.shape]) 134 | nppath = os.path.join(nplog, f"{shape_str}-samples.npz") 135 | np.savez(nppath, all_img) 136 | 137 | else: 138 | raise NotImplementedError('Currently only sampling for unconditional models supported.') 139 | 140 | print(f"sampling of {n_saved} images finished in {(time.time() - tstart) / 60.:.2f} minutes.") 141 | 142 | 143 | def save_logs(logs, path, n_saved=0, key="sample", np_path=None): 144 | for k in logs: 145 | if k == key: 146 | batch = logs[key] 147 | if np_path is None: 148 | for x in batch: 149 | img = custom_to_pil(x) 150 | imgpath = os.path.join(path, f"{key}_{n_saved:06}.png") 151 | img.save(imgpath) 152 | n_saved += 1 153 | else: 154 | npbatch = custom_to_np(batch) 155 | shape_str = "x".join([str(x) 
for x in npbatch.shape]) 156 | nppath = os.path.join(np_path, f"{n_saved}-{shape_str}-samples.npz") 157 | np.savez(nppath, npbatch) 158 | n_saved += npbatch.shape[0] 159 | return n_saved 160 | 161 | 162 | def get_parser(): 163 | parser = argparse.ArgumentParser() 164 | parser.add_argument( 165 | "-r", 166 | "--resume", 167 | type=str, 168 | nargs="?", 169 | help="load from logdir or checkpoint in logdir", 170 | ) 171 | parser.add_argument( 172 | "-n", 173 | "--n_samples", 174 | type=int, 175 | nargs="?", 176 | help="number of samples to draw", 177 | default=50000 178 | ) 179 | parser.add_argument( 180 | "-e", 181 | "--eta", 182 | type=float, 183 | nargs="?", 184 | help="eta for ddim sampling (0.0 yields deterministic sampling)", 185 | default=1.0 186 | ) 187 | parser.add_argument( 188 | "-v", 189 | "--vanilla_sample", 190 | default=False, 191 | action='store_true', 192 | help="vanilla sampling (default option is DDIM sampling)?", 193 | ) 194 | parser.add_argument( 195 | "-l", 196 | "--logdir", 197 | type=str, 198 | nargs="?", 199 | help="extra logdir", 200 | default="none" 201 | ) 202 | parser.add_argument( 203 | "-c", 204 | "--custom_steps", 205 | type=int, 206 | nargs="?", 207 | help="number of steps for ddim and fastdpm sampling", 208 | default=50 209 | ) 210 | parser.add_argument( 211 | "--batch_size", 212 | type=int, 213 | nargs="?", 214 | help="the bs", 215 | default=10 216 | ) 217 | return parser 218 | 219 | 220 | def load_model_from_config(config, sd): 221 | model = instantiate_from_config(config) 222 | model.load_state_dict(sd,strict=False) 223 | model.cuda() 224 | model.eval() 225 | return model 226 | 227 | 228 | def load_model(config, ckpt, gpu, eval_mode): 229 | if ckpt: 230 | print(f"Loading model from {ckpt}") 231 | pl_sd = torch.load(ckpt, map_location="cpu") 232 | global_step = pl_sd["global_step"] 233 | else: 234 | pl_sd = {"state_dict": None} 235 | global_step = None 236 | model = load_model_from_config(config.model, 237 | pl_sd["state_dict"]) 238 | 239 | return model, global_step 240 | 241 | 242 | if __name__ == "__main__": 243 | now = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") 244 | sys.path.append(os.getcwd()) 245 | command = " ".join(sys.argv) 246 | 247 | parser = get_parser() 248 | opt, unknown = parser.parse_known_args() 249 | ckpt = None 250 | 251 | if not os.path.exists(opt.resume): 252 | raise ValueError("Cannot find {}".format(opt.resume)) 253 | if os.path.isfile(opt.resume): 254 | # paths = opt.resume.split("/") 255 | try: 256 | logdir = '/'.join(opt.resume.split('/')[:-1]) 257 | # idx = len(paths)-paths[::-1].index("logs")+1 258 | print(f'Logdir is {logdir}') 259 | except ValueError: 260 | paths = opt.resume.split("/") 261 | idx = -2 # take a guess: path/to/logdir/checkpoints/model.ckpt 262 | logdir = "/".join(paths[:idx]) 263 | ckpt = opt.resume 264 | else: 265 | assert os.path.isdir(opt.resume), f"{opt.resume} is not a directory" 266 | logdir = opt.resume.rstrip("/") 267 | ckpt = os.path.join(logdir, "model.ckpt") 268 | 269 | base_configs = sorted(glob.glob(os.path.join(logdir, "config.yaml"))) 270 | opt.base = base_configs 271 | 272 | configs = [OmegaConf.load(cfg) for cfg in opt.base] 273 | cli = OmegaConf.from_dotlist(unknown) 274 | config = OmegaConf.merge(*configs, cli) 275 | 276 | gpu = True 277 | eval_mode = True 278 | 279 | if opt.logdir != "none": 280 | locallog = logdir.split(os.sep)[-1] 281 | if locallog == "": locallog = logdir.split(os.sep)[-2] 282 | print(f"Switching logdir from '{logdir}' to '{os.path.join(opt.logdir, locallog)}'") 
283 | logdir = os.path.join(opt.logdir, locallog) 284 | 285 | print(config) 286 | 287 | model, global_step = load_model(config, ckpt, gpu, eval_mode) 288 | print(f"global step: {global_step}") 289 | print(75 * "=") 290 | print("logging to:") 291 | logdir = os.path.join(logdir, "samples", f"{global_step:08}", now) 292 | imglogdir = os.path.join(logdir, "img") 293 | numpylogdir = os.path.join(logdir, "numpy") 294 | 295 | os.makedirs(imglogdir) 296 | os.makedirs(numpylogdir) 297 | print(logdir) 298 | print(75 * "=") 299 | 300 | # write config out 301 | sampling_file = os.path.join(logdir, "sampling_config.yaml") 302 | sampling_conf = vars(opt) 303 | 304 | with open(sampling_file, 'w') as f: 305 | yaml.dump(sampling_conf, f, default_flow_style=False) 306 | print(sampling_conf) 307 | 308 | 309 | run(model, imglogdir, eta=opt.eta, 310 | vanilla=opt.vanilla_sample, n_samples=opt.n_samples, custom_steps=opt.custom_steps, 311 | batch_size=opt.batch_size, nplog=numpylogdir) 312 | 313 | print("done.") 314 | -------------------------------------------------------------------------------- /scripts/stable_txt2img.py: -------------------------------------------------------------------------------- 1 | import argparse, os, sys, glob 2 | import torch 3 | import numpy as np 4 | from omegaconf import OmegaConf 5 | from PIL import Image 6 | from tqdm import tqdm, trange 7 | from itertools import islice 8 | from einops import rearrange 9 | from torchvision.utils import make_grid, save_image 10 | import time 11 | from pytorch_lightning import seed_everything 12 | from torch import autocast 13 | from contextlib import contextmanager, nullcontext 14 | 15 | from ldm.util import instantiate_from_config 16 | from ldm.models.diffusion.ddim import DDIMSampler 17 | from ldm.models.diffusion.plms import PLMSSampler 18 | 19 | 20 | def chunk(it, size): 21 | it = iter(it) 22 | return iter(lambda: tuple(islice(it, size)), ()) 23 | 24 | 25 | def load_model_from_config(config, ckpt, verbose=False): 26 | print(f"Loading model from {ckpt}") 27 | pl_sd = torch.load(ckpt, map_location="cpu") 28 | if "global_step" in pl_sd: 29 | print(f"Global Step: {pl_sd['global_step']}") 30 | sd = pl_sd["state_dict"] 31 | model = instantiate_from_config(config.model) 32 | m, u = model.load_state_dict(sd, strict=False) 33 | if len(m) > 0 and verbose: 34 | print("missing keys:") 35 | print(m) 36 | if len(u) > 0 and verbose: 37 | print("unexpected keys:") 38 | print(u) 39 | 40 | model.cuda() 41 | model.eval() 42 | return model 43 | 44 | 45 | def main(): 46 | parser = argparse.ArgumentParser() 47 | 48 | parser.add_argument( 49 | "--prompt", 50 | type=str, 51 | nargs="?", 52 | default="a painting of a virus monster playing guitar", 53 | help="the prompt to render" 54 | ) 55 | parser.add_argument( 56 | "--outdir", 57 | type=str, 58 | nargs="?", 59 | help="dir to write results to", 60 | default="outputs/txt2img-samples" 61 | ) 62 | parser.add_argument( 63 | "--skip_grid", 64 | action='store_true', 65 | help="do not save a grid, only individual samples. Helpful when evaluating lots of samples", 66 | ) 67 | parser.add_argument( 68 | "--skip_save", 69 | action='store_true', 70 | help="do not save individual samples. 
For speed measurements.", 71 | ) 72 | parser.add_argument( 73 | "--ddim_steps", 74 | type=int, 75 | default=50, 76 | help="number of ddim sampling steps", 77 | ) 78 | parser.add_argument( 79 | "--plms", 80 | action='store_true', 81 | help="use plms sampling", 82 | ) 83 | parser.add_argument( 84 | "--laion400m", 85 | action='store_true', 86 | help="uses the LAION400M model", 87 | ) 88 | parser.add_argument( 89 | "--fixed_code", 90 | action='store_true', 91 | help="if enabled, uses the same starting code across samples ", 92 | ) 93 | parser.add_argument( 94 | "--ddim_eta", 95 | type=float, 96 | default=0.0, 97 | help="ddim eta (eta=0.0 corresponds to deterministic sampling", 98 | ) 99 | parser.add_argument( 100 | "--n_iter", 101 | type=int, 102 | default=2, 103 | help="sample this often", 104 | ) 105 | parser.add_argument( 106 | "--H", 107 | type=int, 108 | default=512, 109 | help="image height, in pixel space", 110 | ) 111 | parser.add_argument( 112 | "--W", 113 | type=int, 114 | default=512, 115 | help="image width, in pixel space", 116 | ) 117 | parser.add_argument( 118 | "--C", 119 | type=int, 120 | default=4, 121 | help="latent channels", 122 | ) 123 | parser.add_argument( 124 | "--f", 125 | type=int, 126 | default=8, 127 | help="downsampling factor", 128 | ) 129 | parser.add_argument( 130 | "--n_samples", 131 | type=int, 132 | default=3, 133 | help="how many samples to produce for each given prompt. A.k.a. batch size", 134 | ) 135 | parser.add_argument( 136 | "--n_rows", 137 | type=int, 138 | default=0, 139 | help="rows in the grid (default: n_samples)", 140 | ) 141 | parser.add_argument( 142 | "--scale", 143 | type=float, 144 | default=7.5, 145 | help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))", 146 | ) 147 | parser.add_argument( 148 | "--from-file", 149 | type=str, 150 | help="if specified, load prompts from this file", 151 | ) 152 | parser.add_argument( 153 | "--config", 154 | type=str, 155 | default="configs/stable-diffusion/v1-inference.yaml", 156 | help="path to config which constructs model", 157 | ) 158 | parser.add_argument( 159 | "--ckpt", 160 | type=str, 161 | default="models/ldm/stable-diffusion-v1/model.ckpt", 162 | help="path to checkpoint of model", 163 | ) 164 | parser.add_argument( 165 | "--seed", 166 | type=int, 167 | default=42, 168 | help="the seed (for reproducible sampling)", 169 | ) 170 | parser.add_argument( 171 | "--precision", 172 | type=str, 173 | help="evaluate at this precision", 174 | choices=["full", "autocast"], 175 | default="autocast" 176 | ) 177 | 178 | 179 | parser.add_argument( 180 | "--embedding_path", 181 | type=str, 182 | help="Path to a pre-trained embedding manager checkpoint") 183 | 184 | opt = parser.parse_args() 185 | 186 | if opt.laion400m: 187 | print("Falling back to LAION 400M model...") 188 | opt.config = "configs/latent-diffusion/txt2img-1p4B-eval.yaml" 189 | opt.ckpt = "models/ldm/text2img-large/model.ckpt" 190 | opt.outdir = "outputs/txt2img-samples-laion400m" 191 | 192 | seed_everything(opt.seed) 193 | 194 | config = OmegaConf.load(f"{opt.config}") 195 | model = load_model_from_config(config, f"{opt.ckpt}") 196 | #model.embedding_manager.load(opt.embedding_path) 197 | 198 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 199 | model = model.to(device) 200 | 201 | if opt.plms: 202 | sampler = PLMSSampler(model) 203 | else: 204 | sampler = DDIMSampler(model) 205 | 206 | os.makedirs(opt.outdir, exist_ok=True) 207 | outpath = opt.outdir 208 | 209 | 
batch_size = opt.n_samples 210 | n_rows = opt.n_rows if opt.n_rows > 0 else batch_size 211 | if not opt.from_file: 212 | prompt = opt.prompt 213 | assert prompt is not None 214 | data = [batch_size * [prompt]] 215 | 216 | else: 217 | print(f"reading prompts from {opt.from_file}") 218 | with open(opt.from_file, "r") as f: 219 | data = f.read().splitlines() 220 | data = list(chunk(data, batch_size)) 221 | 222 | sample_path = os.path.join(outpath, "samples") 223 | os.makedirs(sample_path, exist_ok=True) 224 | base_count = len(os.listdir(sample_path)) 225 | grid_count = len(os.listdir(outpath)) - 1 226 | 227 | start_code = None 228 | if opt.fixed_code: 229 | start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device) 230 | 231 | precision_scope = autocast if opt.precision=="autocast" else nullcontext 232 | with torch.no_grad(): 233 | with precision_scope("cuda"): 234 | with model.ema_scope(): 235 | tic = time.time() 236 | all_samples = list() 237 | for n in trange(opt.n_iter, desc="Sampling"): 238 | for prompts in tqdm(data, desc="data"): 239 | uc = None 240 | if opt.scale != 1.0: 241 | uc = model.get_learned_conditioning(batch_size * [""]) 242 | if isinstance(prompts, tuple): 243 | prompts = list(prompts) 244 | c = model.get_learned_conditioning(prompts) 245 | shape = [opt.C, opt.H // opt.f, opt.W // opt.f] 246 | samples_ddim, _ = sampler.sample(S=opt.ddim_steps, 247 | conditioning=c, 248 | batch_size=opt.n_samples, 249 | shape=shape, 250 | verbose=False, 251 | unconditional_guidance_scale=opt.scale, 252 | unconditional_conditioning=uc, 253 | eta=opt.ddim_eta, 254 | x_T=start_code) 255 | 256 | x_samples_ddim = model.decode_first_stage(samples_ddim) 257 | x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) 258 | 259 | if not opt.skip_save: 260 | for x_sample in x_samples_ddim: 261 | x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') 262 | Image.fromarray(x_sample.astype(np.uint8)).save( 263 | os.path.join(sample_path, f"{base_count:05}.jpg")) 264 | base_count += 1 265 | 266 | if not opt.skip_grid: 267 | all_samples.append(x_samples_ddim) 268 | 269 | if not opt.skip_grid: 270 | # additionally, save as grid 271 | grid = torch.stack(all_samples, 0) 272 | grid = rearrange(grid, 'n b c h w -> (n b) c h w') 273 | 274 | for i in range(grid.size(0)): 275 | save_image(grid[i, :, :, :], os.path.join(outpath,opt.prompt+'_{}.png'.format(i))) 276 | grid = make_grid(grid, nrow=n_rows) 277 | 278 | # to image 279 | grid = 255. 
* rearrange(grid, 'c h w -> h w c').cpu().numpy() 280 | Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'{prompt.replace(" ", "-")}-{grid_count:04}.jpg')) 281 | grid_count += 1 282 | 283 | 284 | 285 | toc = time.time() 286 | 287 | print(f"Your samples are ready and waiting for you here: \n{outpath} \n" 288 | f" \nEnjoy.") 289 | 290 | 291 | if __name__ == "__main__": 292 | main() 293 | -------------------------------------------------------------------------------- /scripts/txt2img.py: -------------------------------------------------------------------------------- 1 | import argparse, os, sys, glob 2 | import torch 3 | import numpy as np 4 | from omegaconf import OmegaConf 5 | from PIL import Image 6 | from tqdm import tqdm, trange 7 | from einops import rearrange 8 | from torchvision.utils import make_grid, save_image 9 | 10 | from ldm.util import instantiate_from_config 11 | from ldm.models.diffusion.ddim import DDIMSampler 12 | from ldm.models.diffusion.plms import PLMSSampler 13 | 14 | def load_model_from_config(config, ckpt, verbose=False): 15 | print(f"Loading model from {ckpt}") 16 | pl_sd = torch.load(ckpt, map_location="cpu") 17 | sd = pl_sd["state_dict"] 18 | model = instantiate_from_config(config.model) 19 | m, u = model.load_state_dict(sd, strict=False) 20 | if len(m) > 0 and verbose: 21 | print("missing keys:") 22 | print(m) 23 | if len(u) > 0 and verbose: 24 | print("unexpected keys:") 25 | print(u) 26 | 27 | model.cuda() 28 | model.eval() 29 | return model 30 | 31 | 32 | if __name__ == "__main__": 33 | parser = argparse.ArgumentParser() 34 | 35 | parser.add_argument( 36 | "--prompt", 37 | type=str, 38 | nargs="?", 39 | default="a painting of a virus monster playing guitar", 40 | help="the prompt to render" 41 | ) 42 | 43 | parser.add_argument( 44 | "--outdir", 45 | type=str, 46 | nargs="?", 47 | help="dir to write results to", 48 | default="outputs/txt2img-samples" 49 | ) 50 | parser.add_argument( 51 | "--ddim_steps", 52 | type=int, 53 | default=200, 54 | help="number of ddim sampling steps", 55 | ) 56 | 57 | parser.add_argument( 58 | "--plms", 59 | action='store_true', 60 | help="use plms sampling", 61 | ) 62 | 63 | parser.add_argument( 64 | "--ddim_eta", 65 | type=float, 66 | default=0.0, 67 | help="ddim eta (eta=0.0 corresponds to deterministic sampling", 68 | ) 69 | parser.add_argument( 70 | "--n_iter", 71 | type=int, 72 | default=1, 73 | help="sample this often", 74 | ) 75 | 76 | parser.add_argument( 77 | "--H", 78 | type=int, 79 | default=256, 80 | help="image height, in pixel space", 81 | ) 82 | 83 | parser.add_argument( 84 | "--W", 85 | type=int, 86 | default=256, 87 | help="image width, in pixel space", 88 | ) 89 | 90 | parser.add_argument( 91 | "--n_samples", 92 | type=int, 93 | default=4, 94 | help="how many samples to produce for the given prompt", 95 | ) 96 | 97 | parser.add_argument( 98 | "--scale", 99 | type=float, 100 | default=5.0, 101 | help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))", 102 | ) 103 | 104 | parser.add_argument( 105 | "--ckpt_path", 106 | type=str, 107 | default="/data/pretrained_models/ldm/text2img-large/model.ckpt", 108 | help="Path to pretrained ldm text2img model") 109 | 110 | parser.add_argument( 111 | "--embedding_path", 112 | type=str, 113 | help="Path to a pre-trained embedding manager checkpoint") 114 | 115 | opt = parser.parse_args() 116 | 117 | 118 | config = OmegaConf.load("configs/latent-diffusion/txt2img-1p4B-eval_with_tokens.yaml") # TODO: Optionally download 
from same location as ckpt and change this logic 119 | model = load_model_from_config(config, opt.ckpt_path) # TODO: check path 120 | #model.embedding_manager.load(opt.embedding_path) 121 | 122 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 123 | model = model.to(device) 124 | 125 | if opt.plms: 126 | sampler = PLMSSampler(model) 127 | else: 128 | sampler = DDIMSampler(model) 129 | 130 | os.makedirs(opt.outdir, exist_ok=True) 131 | outpath = opt.outdir 132 | 133 | prompt = opt.prompt 134 | 135 | 136 | sample_path = os.path.join(outpath, "samples") 137 | os.makedirs(sample_path, exist_ok=True) 138 | base_count = len(os.listdir(sample_path)) 139 | 140 | all_samples=list() 141 | with torch.no_grad(): 142 | with model.ema_scope(): 143 | uc = None 144 | if opt.scale != 1.0: 145 | uc = model.get_learned_conditioning(opt.n_samples * [""]) 146 | for n in trange(opt.n_iter, desc="Sampling"): 147 | c = model.get_learned_conditioning(opt.n_samples * [prompt]) 148 | shape = [4, opt.H//8, opt.W//8] 149 | samples_ddim, _ = sampler.sample(S=opt.ddim_steps, 150 | conditioning=c, 151 | batch_size=opt.n_samples, 152 | shape=shape, 153 | verbose=False, 154 | unconditional_guidance_scale=opt.scale, 155 | unconditional_conditioning=uc, 156 | eta=opt.ddim_eta) 157 | 158 | x_samples_ddim = model.decode_first_stage(samples_ddim) 159 | x_samples_ddim = torch.clamp((x_samples_ddim+1.0)/2.0, min=0.0, max=1.0) 160 | 161 | for x_sample in x_samples_ddim: 162 | x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') 163 | Image.fromarray(x_sample.astype(np.uint8)).save(os.path.join(sample_path, f"{base_count:04}.jpg")) 164 | base_count += 1 165 | all_samples.append(x_samples_ddim) 166 | 167 | 168 | # additionally, save as grid 169 | grid = torch.stack(all_samples, 0) 170 | grid = rearrange(grid, 'n b c h w -> (n b) c h w') 171 | 172 | for i in range(grid.size(0)): 173 | save_image(grid[i, :, :, :], os.path.join(outpath,opt.prompt+'_{}.png'.format(i))) 174 | 175 | grid = make_grid(grid, nrow=opt.n_samples) 176 | 177 | 178 | # to image 179 | grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy() 180 | Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'{prompt.replace(" ", "-")}.jpg')) 181 | 182 | 183 | 184 | print(f"Your samples are ready and waiting for you here: \n{outpath} \nEnjoy.") 185 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='latent-diffusion', 5 | version='0.0.1', 6 | description='', 7 | packages=find_packages(), 8 | install_requires=[ 9 | 'torch', 10 | 'numpy', 11 | 'tqdm', 12 | ], 13 | ) --------------------------------------------------------------------------------
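Usage sketch: a minimal example of how the scripts above might be chained after training. The checkpoint path logs/my-run/checkpoints/last.ckpt is a hypothetical placeholder; the script names, flag names, default prompt, and the "-pruned.ckpt" output suffix are taken from the argparse blocks and code in scripts/prune-ckpt.py and scripts/stable_txt2img.py above.

# prune a fine-tuned checkpoint: drops optimizer_states and converts weights to fp16,
# writing logs/my-run/checkpoints/last-pruned.ckpt next to the input file
python scripts/prune-ckpt.py --ckpt logs/my-run/checkpoints/last.ckpt

# sample from the pruned checkpoint with the bundled Stable Diffusion inference config
python scripts/stable_txt2img.py \
    --config configs/stable-diffusion/v1-inference.yaml \
    --ckpt logs/my-run/checkpoints/last-pruned.ckpt \
    --prompt "a painting of a virus monster playing guitar" \
    --ddim_steps 50 --scale 7.5 --n_samples 3 --n_iter 2 --seed 42 \
    --outdir outputs/txt2img-samples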