├── .gitignore ├── LICENSE ├── README.md ├── assets ├── inpainting_example.png ├── intro.png ├── sdm-1.png ├── sdm-2.png └── sdm-3.png ├── dpm_solver_jax.py ├── dpm_solver_pytorch.py └── examples ├── ddpm_and_guided-diffusion ├── README.md ├── configs │ ├── bedroom_guided.yml │ ├── celeba.yml │ ├── cifar10.yml │ ├── imagenet128_guided.yml │ ├── imagenet256_guided.yml │ ├── imagenet512_guided.yml │ └── imagenet64.yml ├── datasets │ ├── __init__.py │ ├── celeba.py │ ├── ffhq.py │ ├── lsun.py │ ├── utils.py │ └── vision.py ├── dpm_solver │ └── sampler.py ├── evaluate │ ├── fid_score.py │ └── inception.py ├── functions │ ├── __init__.py │ ├── ckpt_util.py │ ├── denoising.py │ └── losses.py ├── main.py ├── models │ ├── diffusion.py │ ├── ema.py │ ├── guided_diffusion │ │ ├── __init__.py │ │ ├── fp16_util.py │ │ ├── logger.py │ │ ├── nn.py │ │ └── unet.py │ └── improved_ddpm │ │ ├── __init__.py │ │ ├── fp16_util.py │ │ ├── nn.py │ │ └── unet.py ├── runners │ ├── __init__.py │ └── diffusion.py └── sample.sh ├── score_sde_jax ├── .gitignore ├── LICENSE ├── README.md ├── Score_SDE_demo.ipynb ├── assets │ ├── bedroom.jpeg │ ├── celebahq_256.jpg │ ├── church.jpeg │ ├── ffhq_1024.jpeg │ ├── ffhq_256.jpg │ ├── ffhq_samples.jpg │ └── schematic.jpg ├── configs │ ├── default_celeba_configs.py │ ├── default_cifar10_configs.py │ ├── default_lsun_configs.py │ ├── subvp │ │ ├── cifar10_ddpm_continuous.py │ │ ├── cifar10_ddpmpp_continuous.py │ │ ├── cifar10_ddpmpp_deep_continuous.py │ │ ├── cifar10_ncsnpp_continuous.py │ │ └── cifar10_ncsnpp_deep_continuous.py │ ├── ve │ │ ├── bedroom_ncsnpp_continuous.py │ │ ├── celeba_ncsnpp.py │ │ ├── celebahq_256_ncsnpp_continuous.py │ │ ├── celebahq_ncsnpp_continuous.py │ │ ├── church_ncsnpp_continuous.py │ │ ├── cifar10_ddpm.py │ │ ├── cifar10_ncsnpp.py │ │ ├── cifar10_ncsnpp_continuous.py │ │ ├── cifar10_ncsnpp_deep_continuous.py │ │ ├── ffhq_256_ncsnpp_continuous.py │ │ ├── ffhq_ncsnpp_continuous.py │ │ ├── ncsn │ │ │ ├── celeba.py │ │ │ ├── celeba_124.py │ │ │ ├── celeba_1245.py │ │ │ ├── celeba_5.py │ │ │ ├── cifar10.py │ │ │ ├── cifar10_124.py │ │ │ ├── cifar10_1245.py │ │ │ └── cifar10_5.py │ │ └── ncsnv2 │ │ │ ├── bedroom.py │ │ │ ├── celeba.py │ │ │ └── cifar10.py │ └── vp │ │ ├── cifar10_ddpmpp.py │ │ ├── cifar10_ddpmpp_continuous.py │ │ ├── cifar10_ddpmpp_deep_continuous.py │ │ ├── cifar10_ncsnpp.py │ │ ├── cifar10_ncsnpp_continuous.py │ │ ├── cifar10_ncsnpp_deep_continuous.py │ │ └── ddpm │ │ ├── bedroom.py │ │ ├── celebahq.py │ │ ├── church.py │ │ ├── cifar10.py │ │ ├── cifar10_continuous.py │ │ └── cifar10_unconditional.py ├── controllable_generation.py ├── datasets.py ├── dpm_solver.py ├── evaluation.py ├── likelihood.py ├── losses.py ├── main.py ├── models │ ├── __init__.py │ ├── ddpm.py │ ├── layers.py │ ├── layerspp.py │ ├── ncsnpp.py │ ├── ncsnv2.py │ ├── normalization.py │ ├── up_or_down_sampling.py │ ├── utils.py │ └── wideresnet_noise_conditional.py ├── requirements.txt ├── run_lib.py ├── sample.sh ├── sampling.py ├── sde_lib.py └── utils.py ├── score_sde_pytorch ├── .gitignore ├── LICENSE ├── README.md ├── Score_SDE_demo_PyTorch.ipynb ├── assets │ ├── bedroom.jpeg │ ├── celebahq_256.jpg │ ├── church.jpeg │ ├── ffhq_1024.jpeg │ ├── ffhq_256.jpg │ ├── ffhq_samples.jpg │ └── schematic.jpg ├── configs │ ├── default_celeba_configs.py │ ├── default_cifar10_configs.py │ ├── default_lsun_configs.py │ ├── subvp │ │ ├── cifar10_ddpm_continuous.py │ │ ├── cifar10_ddpmpp_continuous.py │ │ ├── cifar10_ddpmpp_deep_continuous.py │ │ ├── 
cifar10_ncsnpp_continuous.py │ │ └── cifar10_ncsnpp_deep_continuous.py │ ├── ve │ │ ├── bedroom_ncsnpp_continuous.py │ │ ├── celeba_ncsnpp.py │ │ ├── celebahq_256_ncsnpp_continuous.py │ │ ├── celebahq_ncsnpp_continuous.py │ │ ├── church_ncsnpp_continuous.py │ │ ├── cifar10_ddpm.py │ │ ├── cifar10_ncsnpp.py │ │ ├── cifar10_ncsnpp_continuous.py │ │ ├── cifar10_ncsnpp_deep_continuous.py │ │ ├── ffhq_256_ncsnpp_continuous.py │ │ ├── ffhq_ncsnpp_continuous.py │ │ ├── ncsn │ │ │ ├── celeba.py │ │ │ ├── celeba_124.py │ │ │ ├── celeba_1245.py │ │ │ ├── celeba_5.py │ │ │ ├── cifar10.py │ │ │ ├── cifar10_124.py │ │ │ ├── cifar10_1245.py │ │ │ └── cifar10_5.py │ │ └── ncsnv2 │ │ │ ├── bedroom.py │ │ │ ├── celeba.py │ │ │ └── cifar10.py │ └── vp │ │ ├── cifar10_ddpmpp.py │ │ ├── cifar10_ddpmpp_continuous.py │ │ ├── cifar10_ddpmpp_deep_continuous.py │ │ ├── cifar10_ncsnpp.py │ │ ├── cifar10_ncsnpp_continuous.py │ │ ├── cifar10_ncsnpp_deep_continuous.py │ │ └── ddpm │ │ ├── bedroom.py │ │ ├── celebahq.py │ │ ├── church.py │ │ ├── cifar10.py │ │ ├── cifar10_continuous.py │ │ └── cifar10_unconditional.py ├── controllable_generation.py ├── datasets.py ├── debug.py ├── dpm_solver.py ├── evaluation.py ├── likelihood.py ├── losses.py ├── main.py ├── models │ ├── __init__.py │ ├── ddpm.py │ ├── ema.py │ ├── layers.py │ ├── layerspp.py │ ├── ncsnpp.py │ ├── ncsnv2.py │ ├── normalization.py │ ├── up_or_down_sampling.py │ └── utils.py ├── op │ ├── __init__.py │ ├── fused_act.py │ ├── fused_bias_act.cpp │ ├── fused_bias_act_kernel.cu │ ├── upfirdn2d.cpp │ ├── upfirdn2d.py │ └── upfirdn2d_kernel.cu ├── requirements.txt ├── run_lib.py ├── sample.sh ├── sampling.py ├── sde_lib.py └── utils.py └── stable-diffusion ├── LICENSE ├── README.md ├── Stable_Diffusion_v1_Model_Card.md ├── assets ├── a-painting-of-a-fire.png ├── a-photograph-of-a-fire.png ├── a-shirt-with-a-fire-printed-on-it.png ├── a-shirt-with-the-inscription-'fire'.png ├── a-watercolor-painting-of-a-fire.png ├── birdhouse.png ├── fire.png ├── inpainting.png ├── modelfigure.png ├── rdm-preview.jpg ├── reconstruction1.png ├── reconstruction2.png ├── results.gif ├── rick.jpeg ├── stable-samples │ ├── img2img │ │ ├── mountains-1.png │ │ ├── mountains-2.png │ │ ├── mountains-3.png │ │ ├── sketch-mountains-input.jpg │ │ ├── upscaling-in.png │ │ └── upscaling-out.png │ └── txt2img │ │ ├── 000002025.png │ │ ├── 000002035.png │ │ ├── merged-0005.png │ │ ├── merged-0006.png │ │ └── merged-0007.png ├── the-earth-is-on-fire,-oil-on-canvas.png ├── txt2img-convsample.png ├── txt2img-preview.png └── v1-variants-scores.jpg ├── configs ├── autoencoder │ ├── autoencoder_kl_16x16x16.yaml │ ├── autoencoder_kl_32x32x4.yaml │ ├── autoencoder_kl_64x64x3.yaml │ └── autoencoder_kl_8x8x64.yaml ├── latent-diffusion │ ├── celebahq-ldm-vq-4.yaml │ ├── cin-ldm-vq-f8.yaml │ ├── cin256-v2.yaml │ ├── ffhq-ldm-vq-4.yaml │ ├── lsun_bedrooms-ldm-vq-4.yaml │ ├── lsun_churches-ldm-kl-8.yaml │ └── txt2img-1p4B-eval.yaml ├── retrieval-augmented-diffusion │ └── 768x768.yaml └── stable-diffusion │ └── v1-inference.yaml ├── data ├── DejaVuSans.ttf ├── example_conditioning │ ├── superresolution │ │ └── sample_0.jpg │ └── text_conditional │ │ └── sample_0.txt ├── fruit.png ├── imagenet_clsidx_to_label.txt ├── imagenet_train_hr_indices.p ├── imagenet_val_hr_indices.p ├── index_synset.yaml └── inpainting_examples │ ├── 6458524847_2f4c361183_k.png │ ├── 6458524847_2f4c361183_k_mask.png │ ├── 8399166846_f6fb4e4b8e_k.png │ ├── 8399166846_f6fb4e4b8e_k_mask.png │ ├── alex-iby-G_Pk4D9rMLs.png │ ├── 
alex-iby-G_Pk4D9rMLs_mask.png │ ├── bench2.png │ ├── bench2_mask.png │ ├── bertrand-gabioud-CpuFzIsHYJ0.png │ ├── bertrand-gabioud-CpuFzIsHYJ0_mask.png │ ├── billow926-12-Wc-Zgx6Y.png │ ├── billow926-12-Wc-Zgx6Y_mask.png │ ├── overture-creations-5sI6fQgYIuo.png │ ├── overture-creations-5sI6fQgYIuo_mask.png │ ├── photo-1583445095369-9c651e7e5d34.png │ └── photo-1583445095369-9c651e7e5d34_mask.png ├── environment.yaml ├── ldm ├── data │ ├── __init__.py │ ├── base.py │ ├── imagenet.py │ └── lsun.py ├── lr_scheduler.py ├── models │ ├── autoencoder.py │ └── diffusion │ │ ├── __init__.py │ │ ├── classifier.py │ │ ├── ddim.py │ │ ├── ddpm.py │ │ ├── dpm_solver │ │ ├── __init__.py │ │ ├── dpm_solver.py │ │ └── sampler.py │ │ └── plms.py ├── modules │ ├── attention.py │ ├── diffusionmodules │ │ ├── __init__.py │ │ ├── model.py │ │ ├── openaimodel.py │ │ └── util.py │ ├── distributions │ │ ├── __init__.py │ │ └── distributions.py │ ├── ema.py │ ├── encoders │ │ ├── __init__.py │ │ └── modules.py │ ├── image_degradation │ │ ├── __init__.py │ │ ├── bsrgan.py │ │ ├── bsrgan_light.py │ │ ├── utils │ │ │ └── test.png │ │ └── utils_image.py │ ├── losses │ │ ├── __init__.py │ │ ├── contperceptual.py │ │ └── vqperceptual.py │ └── x_transformer.py └── util.py ├── main.py ├── models ├── first_stage_models │ ├── kl-f16 │ │ └── config.yaml │ ├── kl-f32 │ │ └── config.yaml │ ├── kl-f4 │ │ └── config.yaml │ ├── kl-f8 │ │ └── config.yaml │ ├── vq-f16 │ │ └── config.yaml │ ├── vq-f4-noattn │ │ └── config.yaml │ ├── vq-f4 │ │ └── config.yaml │ ├── vq-f8-n256 │ │ └── config.yaml │ └── vq-f8 │ │ └── config.yaml └── ldm │ ├── bsr_sr │ └── config.yaml │ ├── celeba256 │ └── config.yaml │ ├── cin256 │ └── config.yaml │ ├── ffhq256 │ └── config.yaml │ ├── inpainting_big │ └── config.yaml │ ├── layout2img-openimages256 │ └── config.yaml │ ├── lsun_beds256 │ └── config.yaml │ ├── lsun_churches256 │ └── config.yaml │ ├── semantic_synthesis256 │ └── config.yaml │ ├── semantic_synthesis512 │ └── config.yaml │ └── text2img256 │ └── config.yaml ├── notebook_helpers.py ├── scripts ├── diffedit_inpaint.ipynb ├── download_first_stages.sh ├── download_models.sh ├── img2img.py ├── inpaint.py ├── knn2img.py ├── latent_imagenet_diffusion.ipynb ├── sample_diffusion.py ├── tests │ └── test_watermark.py ├── train_searcher.py └── txt2img.py └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | **/__pycache__/ 2 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Cheng Lu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /assets/inpainting_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LuChengTHU/dpm-solver/52bc3fbcd5de56d60917b826b15d2b69460fc2fa/assets/inpainting_example.png -------------------------------------------------------------------------------- /assets/intro.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LuChengTHU/dpm-solver/52bc3fbcd5de56d60917b826b15d2b69460fc2fa/assets/intro.png -------------------------------------------------------------------------------- /assets/sdm-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LuChengTHU/dpm-solver/52bc3fbcd5de56d60917b826b15d2b69460fc2fa/assets/sdm-1.png -------------------------------------------------------------------------------- /assets/sdm-2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LuChengTHU/dpm-solver/52bc3fbcd5de56d60917b826b15d2b69460fc2fa/assets/sdm-2.png -------------------------------------------------------------------------------- /assets/sdm-3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LuChengTHU/dpm-solver/52bc3fbcd5de56d60917b826b15d2b69460fc2fa/assets/sdm-3.png -------------------------------------------------------------------------------- /examples/ddpm_and_guided-diffusion/configs/bedroom_guided.yml: -------------------------------------------------------------------------------- 1 | data: 2 | dataset: "LSUN" 3 | category: "bedroom" 4 | image_size: 256 5 | channels: 3 6 | logit_transform: false 7 | uniform_dequantization: false 8 | gaussian_dequantization: false 9 | random_flip: true 10 | rescaled: true 11 | num_workers: 32 12 | num_classes: 1 13 | 14 | model: 15 | model_type: "guided_diffusion" 16 | is_upsampling: false 17 | image_size: 256 18 | in_channels: 3 19 | model_channels: 256 20 | out_channels: 6 21 | num_res_blocks: 2 22 | attention_resolutions: [8, 16, 32] # [256 // 32, 256 // 16, 256 // 8] 23 | dropout: 0.1 24 | channel_mult: [1, 1, 2, 2, 4, 4] 25 | conv_resample: true 26 | dims: 2 27 | num_classes: null 28 | use_checkpoint: false 29 | use_fp16: true 30 | num_heads: 4 31 | num_head_channels: 64 32 | num_heads_upsample: -1 33 | use_scale_shift_norm: true 34 | resblock_updown: true 35 | use_new_attention_order: false 36 | var_type: fixedlarge 37 | ema: false 38 | ckpt_dir: "~/ddpm_ckpt/bedroom/lsun_bedroom.pt" 39 | 40 | diffusion: 41 | beta_schedule: linear 42 | beta_start: 0.0001 43 | beta_end: 0.02 44 | num_diffusion_timesteps: 1000 45 | 46 | sampling: 47 | total_N: 1000 48 | batch_size: 50 49 | last_only: True 50 | fid_stats_dir: "fid_stats/VIRTUAL_lsun_bedroom256.npz" 51 | fid_total_samples: 50000 52 | fid_batch_size: 1000 53 | cond_class: false 54 | classifier_scale: 0.0 55 | -------------------------------------------------------------------------------- /examples/ddpm_and_guided-diffusion/configs/celeba.yml: 
-------------------------------------------------------------------------------- 1 | data: 2 | dataset: "CELEBA" 3 | image_size: 64 4 | channels: 3 5 | logit_transform: false 6 | uniform_dequantization: false 7 | gaussian_dequantization: false 8 | random_flip: true 9 | rescaled: true 10 | num_workers: 4 11 | num_classes: 1 12 | 13 | model: 14 | model_type: "ddpm" 15 | is_upsampling: false 16 | type: "simple" 17 | in_channels: 3 18 | out_ch: 3 19 | ch: 128 20 | ch_mult: [1, 2, 2, 2, 4] 21 | num_res_blocks: 2 22 | attn_resolutions: [16, ] 23 | dropout: 0.1 24 | var_type: fixedlarge 25 | ema_rate: 0.9999 26 | ema: True 27 | resamp_with_conv: True 28 | ckpt_dir: "~/ddpm_ckpt/celeba/ckpt.pth" 29 | 30 | diffusion: 31 | beta_schedule: linear 32 | beta_start: 0.0001 33 | beta_end: 0.02 34 | num_diffusion_timesteps: 1000 35 | 36 | training: 37 | batch_size: 128 38 | n_epochs: 10000 39 | n_iters: 5000000 40 | snapshot_freq: 5000 41 | validation_freq: 20000 42 | 43 | sampling: 44 | total_N: 1000 45 | batch_size: 500 46 | last_only: True 47 | fid_stats_dir: "fid_stats/fid_stats_celeba64_train_50000_ddim.npz" 48 | fid_total_samples: 50000 49 | fid_batch_size: 1000 50 | cond_class: false 51 | classifier_scale: 0.0 52 | 53 | optim: 54 | weight_decay: 0.000 55 | optimizer: "Adam" 56 | lr: 0.0002 57 | beta1: 0.9 58 | amsgrad: false 59 | eps: 0.00000001 60 | grad_clip: 1.0 61 | -------------------------------------------------------------------------------- /examples/ddpm_and_guided-diffusion/configs/cifar10.yml: -------------------------------------------------------------------------------- 1 | data: 2 | dataset: "CIFAR10" 3 | image_size: 32 4 | channels: 3 5 | logit_transform: false 6 | uniform_dequantization: false 7 | gaussian_dequantization: false 8 | random_flip: true 9 | rescaled: true 10 | num_workers: 4 11 | num_classes: 10 12 | 13 | model: 14 | model_type: "ddpm" 15 | is_upsampling: false 16 | type: "simple" 17 | in_channels: 3 18 | out_ch: 3 19 | ch: 128 20 | ch_mult: [1, 2, 2, 2] 21 | num_res_blocks: 2 22 | attn_resolutions: [16, ] 23 | dropout: 0.1 24 | var_type: fixedlarge 25 | ema_rate: 0.9999 26 | ema: True 27 | resamp_with_conv: True 28 | 29 | diffusion: 30 | beta_schedule: linear 31 | beta_start: 0.0001 32 | beta_end: 0.02 33 | num_diffusion_timesteps: 1000 34 | 35 | training: 36 | batch_size: 128 37 | n_epochs: 10000 38 | n_iters: 5000000 39 | snapshot_freq: 5000 40 | validation_freq: 2000 41 | 42 | sampling: 43 | total_N: 1000 44 | batch_size: 1000 45 | last_only: True 46 | fid_stats_dir: "fid_stats/fid_stats_cifar10_train_pytorch.npz" 47 | fid_total_samples: 50000 48 | fid_batch_size: 1000 49 | cond_class: false 50 | classifier_scale: 0.0 51 | 52 | optim: 53 | weight_decay: 0.000 54 | optimizer: "Adam" 55 | lr: 0.0002 56 | beta1: 0.9 57 | amsgrad: false 58 | eps: 0.00000001 59 | grad_clip: 1.0 60 | -------------------------------------------------------------------------------- /examples/ddpm_and_guided-diffusion/configs/imagenet128_guided.yml: -------------------------------------------------------------------------------- 1 | data: 2 | dataset: "IMAGENET128" 3 | image_size: 128 4 | channels: 3 5 | logit_transform: false 6 | uniform_dequantization: false 7 | gaussian_dequantization: false 8 | random_flip: true 9 | rescaled: true 10 | num_workers: 32 11 | num_classes: 1000 12 | 13 | model: 14 | model_type: "guided_diffusion" 15 | is_upsampling: false 16 | image_size: 128 17 | in_channels: 3 18 | model_channels: 256 19 | out_channels: 6 20 | num_res_blocks: 2 21 | attention_resolutions: 
[4, 8, 16] # [128 // 32, 128 // 16, 128 // 8] 22 | dropout: 0.0 23 | channel_mult: [1, 1, 2, 3, 4] 24 | conv_resample: true 25 | dims: 2 26 | num_classes: 1000 27 | use_checkpoint: false 28 | use_fp16: true 29 | num_heads: 4 30 | num_head_channels: -1 31 | num_heads_upsample: -1 32 | use_scale_shift_norm: true 33 | resblock_updown: true 34 | use_new_attention_order: false 35 | var_type: fixedlarge 36 | ema: false 37 | ckpt_dir: "~/ddpm_ckpt/imagenet128/128x128_diffusion.pt" 38 | 39 | classifier: 40 | ckpt_dir: "~/ddpm_ckpt/imagenet128/128x128_classifier.pt" 41 | image_size: 128 42 | in_channels: 3 43 | model_channels: 128 44 | out_channels: 1000 45 | num_res_blocks: 2 46 | attention_resolutions: [4, 8, 16] # [128 // 32, 128 // 16, 128 // 8] 47 | channel_mult: [1, 1, 2, 3, 4] 48 | use_fp16: true 49 | num_head_channels: 64 50 | use_scale_shift_norm: true 51 | resblock_updown: true 52 | pool: "attention" 53 | 54 | diffusion: 55 | beta_schedule: linear 56 | beta_start: 0.0001 57 | beta_end: 0.02 58 | num_diffusion_timesteps: 1000 59 | 60 | sampling: 61 | total_N: 1000 62 | schedule: "linear" 63 | time_input_type: '1' 64 | batch_size: 500 65 | last_only: True 66 | fid_stats_dir: "fid_stats/VIRTUAL_imagenet128_labeled.npz" 67 | fid_total_samples: 50000 68 | fid_batch_size: 200 69 | cond_class: true 70 | classifier_scale: 1.25 71 | -------------------------------------------------------------------------------- /examples/ddpm_and_guided-diffusion/configs/imagenet256_guided.yml: -------------------------------------------------------------------------------- 1 | data: 2 | dataset: "IMAGENET256" 3 | image_size: 256 4 | channels: 3 5 | logit_transform: false 6 | uniform_dequantization: false 7 | gaussian_dequantization: false 8 | random_flip: true 9 | rescaled: true 10 | num_workers: 32 11 | num_classes: 1000 12 | 13 | model: 14 | model_type: "guided_diffusion" 15 | is_upsampling: false 16 | image_size: 256 17 | in_channels: 3 18 | model_channels: 256 19 | out_channels: 6 20 | num_res_blocks: 2 21 | attention_resolutions: [8, 16, 32] # [256 // 32, 256 // 16, 256 // 8] 22 | dropout: 0.0 23 | channel_mult: [1, 1, 2, 2, 4, 4] 24 | conv_resample: true 25 | dims: 2 26 | num_classes: 1000 27 | use_checkpoint: false 28 | use_fp16: true 29 | num_heads: 4 30 | num_head_channels: 64 31 | num_heads_upsample: -1 32 | use_scale_shift_norm: true 33 | resblock_updown: true 34 | use_new_attention_order: false 35 | var_type: fixedlarge 36 | ema: false 37 | ckpt_dir: "~/ddpm_ckpt/imagenet256/256x256_diffusion.pt" 38 | 39 | classifier: 40 | ckpt_dir: "~/ddpm_ckpt/imagenet256/256x256_classifier.pt" 41 | image_size: 256 42 | in_channels: 3 43 | model_channels: 128 44 | out_channels: 1000 45 | num_res_blocks: 2 46 | attention_resolutions: [8, 16, 32] # [256 // 32, 256 // 16, 256 // 8] 47 | channel_mult: [1, 1, 2, 2, 4, 4] 48 | use_fp16: true 49 | num_head_channels: 64 50 | use_scale_shift_norm: true 51 | resblock_updown: true 52 | pool: "attention" 53 | 54 | diffusion: 55 | beta_schedule: linear 56 | beta_start: 0.0001 57 | beta_end: 0.02 58 | num_diffusion_timesteps: 1000 59 | 60 | sampling: 61 | total_N: 1000 62 | batch_size: 50 63 | last_only: True 64 | fid_stats_dir: "fid_stats/VIRTUAL_imagenet256_labeled.npz" 65 | fid_total_samples: 10000 66 | fid_batch_size: 200 67 | cond_class: true 68 | classifier_scale: 2.5 69 | -------------------------------------------------------------------------------- /examples/ddpm_and_guided-diffusion/configs/imagenet512_guided.yml: 
-------------------------------------------------------------------------------- 1 | data: 2 | dataset: "IMAGENET512" 3 | image_size: 512 4 | channels: 3 5 | logit_transform: false 6 | uniform_dequantization: false 7 | gaussian_dequantization: false 8 | random_flip: true 9 | rescaled: true 10 | num_workers: 32 11 | num_classes: 1000 12 | 13 | model: 14 | model_type: "guided_diffusion" 15 | is_upsampling: false 16 | image_size: 512 17 | in_channels: 3 18 | model_channels: 256 19 | out_channels: 6 20 | num_res_blocks: 2 21 | attention_resolutions: [16, 32, 64] # [512 // 32, 512 // 16, 512 // 8] 22 | dropout: 0.0 23 | channel_mult: [0.5, 1, 1, 2, 2, 4, 4] 24 | conv_resample: true 25 | dims: 2 26 | num_classes: 1000 27 | use_checkpoint: false 28 | use_fp16: false 29 | num_heads: 4 30 | num_head_channels: 64 31 | num_heads_upsample: -1 32 | use_scale_shift_norm: true 33 | resblock_updown: true 34 | use_new_attention_order: false 35 | var_type: fixedlarge 36 | ema: false 37 | ckpt_dir: "~/ddpm_ckpt/imagenet512/512x512_diffusion.pt" 38 | 39 | classifier: 40 | ckpt_dir: "~/ddpm_ckpt/imagenet512/512x512_classifier.pt" 41 | image_size: 512 42 | in_channels: 3 43 | model_channels: 128 44 | out_channels: 1000 45 | num_res_blocks: 2 46 | attention_resolutions: [16, 32, 64] # [256 // 32, 256 // 16, 256 // 8] 47 | channel_mult: [0.5, 1, 1, 2, 2, 4, 4] 48 | use_fp16: false 49 | num_head_channels: 64 50 | use_scale_shift_norm: true 51 | resblock_updown: true 52 | pool: "attention" 53 | 54 | diffusion: 55 | beta_schedule: linear 56 | beta_start: 0.0001 57 | beta_end: 0.02 58 | num_diffusion_timesteps: 1000 59 | 60 | sampling: 61 | total_N: 1000 62 | schedule: "linear" 63 | batch_size: 20 64 | last_only: True 65 | fid_stats_dir: "fid_stats/VIRTUAL_imagenet512.npz" 66 | fid_total_samples: 50000 67 | fid_batch_size: 200 68 | cond_class: true 69 | classifier_scale: 4.0 70 | -------------------------------------------------------------------------------- /examples/ddpm_and_guided-diffusion/configs/imagenet64.yml: -------------------------------------------------------------------------------- 1 | data: 2 | dataset: "IMAGENET64" 3 | image_size: 64 4 | channels: 3 5 | logit_transform: false 6 | uniform_dequantization: false 7 | gaussian_dequantization: false 8 | random_flip: true 9 | rescaled: true 10 | num_workers: 4 11 | num_classes: 1000 12 | 13 | model: 14 | model_type: "improved_ddpm" 15 | is_upsampling: false 16 | in_channels: 3 17 | model_channels: 128 18 | out_channels: 6 19 | num_res_blocks: 3 20 | attention_resolutions: [4, 8] 21 | dropout: 0.0 22 | channel_mult: [1, 2, 3, 4] 23 | conv_resample: true 24 | dims: 2 25 | use_checkpoint: false 26 | num_heads: 4 27 | num_heads_upsample: -1 28 | use_scale_shift_norm: true 29 | var_type: fixedlarge 30 | use_fp16: false 31 | ema: false 32 | ckpt_dir: "~/ddpm_ckpt/imagenet64/imagenet64_uncond_100M_1500K.pt" 33 | 34 | diffusion: 35 | beta_schedule: cosine 36 | beta_start: null 37 | beta_end: null 38 | num_diffusion_timesteps: 4000 39 | 40 | sampling: 41 | total_N: 4000 42 | batch_size: 500 43 | last_only: True 44 | fid_stats_dir: "fid_stats/fid_stats_imagenet64_train.npz" 45 | fid_total_samples: 50000 46 | fid_batch_size: 1000 47 | cond_class: false 48 | classifier_scale: 0.0 49 | -------------------------------------------------------------------------------- /examples/ddpm_and_guided-diffusion/datasets/ffhq.py: -------------------------------------------------------------------------------- 1 | from io import BytesIO 2 | 3 | import lmdb 4 | from PIL import Image 
from torch.utils.data import Dataset


class FFHQ(Dataset):
    def __init__(self, path, transform, resolution=8):
        self.env = lmdb.open(
            path,
            max_readers=32,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False,
        )

        if not self.env:
            raise IOError('Cannot open lmdb dataset', path)

        with self.env.begin(write=False) as txn:
            self.length = int(txn.get('length'.encode('utf-8')).decode('utf-8'))

        self.resolution = resolution
        self.transform = transform

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        with self.env.begin(write=False) as txn:
            key = f'{self.resolution}-{str(index).zfill(5)}'.encode('utf-8')
            img_bytes = txn.get(key)

        buffer = BytesIO(img_bytes)
        img = Image.open(buffer)
        img = self.transform(img)
        target = 0

        return img, target

-------------------------------------------------------------------------------- /examples/ddpm_and_guided-diffusion/functions/__init__.py: --------------------------------------------------------------------------------

import torch.optim as optim


def get_optimizer(config, parameters):
    if config.optim.optimizer == 'Adam':
        return optim.Adam(parameters, lr=config.optim.lr, weight_decay=config.optim.weight_decay,
                          betas=(config.optim.beta1, 0.999), amsgrad=config.optim.amsgrad,
                          eps=config.optim.eps)
    elif config.optim.optimizer == 'RMSProp':
        return optim.RMSprop(parameters, lr=config.optim.lr, weight_decay=config.optim.weight_decay)
    elif config.optim.optimizer == 'SGD':
        return optim.SGD(parameters, lr=config.optim.lr, momentum=0.9)
    else:
        raise NotImplementedError(
            'Optimizer {} not understood.'.format(config.optim.optimizer))

-------------------------------------------------------------------------------- /examples/ddpm_and_guided-diffusion/functions/losses.py: --------------------------------------------------------------------------------

import torch


def noise_estimation_loss(model,
                          x0: torch.Tensor,
                          t: torch.LongTensor,
                          e: torch.Tensor,
                          b: torch.Tensor, keepdim=False):
    a = (1-b).cumprod(dim=0).index_select(0, t).view(-1, 1, 1, 1)
    x = x0 * a.sqrt() + e * (1.0 - a).sqrt()
    output = model(x, t.float())
    if keepdim:
        return (e - output).square().sum(dim=(1, 2, 3))
    else:
        return (e - output).square().sum(dim=(1, 2, 3)).mean(dim=0)


loss_registry = {
    'simple': noise_estimation_loss,
}

-------------------------------------------------------------------------------- /examples/ddpm_and_guided-diffusion/models/ema.py: --------------------------------------------------------------------------------

import torch.nn as nn


class EMAHelper(object):
    def __init__(self, mu=0.999):
        self.mu = mu
        self.shadow = {}

    def register(self, module):
        if isinstance(module, nn.DataParallel):
            module = module.module
        for name, param in module.named_parameters():
            if param.requires_grad:
                self.shadow[name] = param.data.clone()

    def update(self, module):
        if isinstance(module, nn.DataParallel):
            module = module.module
        for name, param in module.named_parameters():
            if param.requires_grad:
                self.shadow[name].data = (
                    1.
- self.mu) * param.data + self.mu * self.shadow[name].data 23 | 24 | def ema(self, module): 25 | if isinstance(module, nn.DataParallel): 26 | module = module.module 27 | for name, param in module.named_parameters(): 28 | if param.requires_grad: 29 | param.data.copy_(self.shadow[name].data) 30 | 31 | def ema_copy(self, module): 32 | if isinstance(module, nn.DataParallel): 33 | inner_module = module.module 34 | module_copy = type(inner_module)( 35 | inner_module.config).to(inner_module.config.device) 36 | module_copy.load_state_dict(inner_module.state_dict()) 37 | module_copy = nn.DataParallel(module_copy) 38 | else: 39 | module_copy = type(module)(module.config).to(module.config.device) 40 | module_copy.load_state_dict(module.state_dict()) 41 | # module_copy = copy.deepcopy(module) 42 | self.ema(module_copy) 43 | return module_copy 44 | 45 | def state_dict(self): 46 | return self.shadow 47 | 48 | def load_state_dict(self, state_dict): 49 | self.shadow = state_dict 50 | -------------------------------------------------------------------------------- /examples/ddpm_and_guided-diffusion/models/guided_diffusion/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LuChengTHU/dpm-solver/52bc3fbcd5de56d60917b826b15d2b69460fc2fa/examples/ddpm_and_guided-diffusion/models/guided_diffusion/__init__.py -------------------------------------------------------------------------------- /examples/ddpm_and_guided-diffusion/models/improved_ddpm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LuChengTHU/dpm-solver/52bc3fbcd5de56d60917b826b15d2b69460fc2fa/examples/ddpm_and_guided-diffusion/models/improved_ddpm/__init__.py -------------------------------------------------------------------------------- /examples/ddpm_and_guided-diffusion/runners/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LuChengTHU/dpm-solver/52bc3fbcd5de56d60917b826b15d2b69460fc2fa/examples/ddpm_and_guided-diffusion/runners/__init__.py -------------------------------------------------------------------------------- /examples/ddpm_and_guided-diffusion/sample.sh: -------------------------------------------------------------------------------- 1 | DEVICES='4,5' 2 | 3 | 4 | ########################## 5 | 6 | # CIFAR-10 (DDPM checkpoint) example 7 | 8 | data="cifar10" 9 | sampleMethod='dpmsolver++' 10 | type="dpmsolver" 11 | steps="10" 12 | DIS="logSNR" 13 | order="3" 14 | method="multistep" 15 | workdir="experiments/"$data"/"$sampleMethod"_"$method"_order"$order"_"$steps"_"$DIS"_type-"$type 16 | 17 | CUDA_VISIBLE_DEVICES=$DEVICES python main.py --config $data".yml" --exp=$workdir --sample --fid --timesteps=$steps --eta 0 --ni --skip_type=$DIS --sample_type=$sampleMethod --dpm_solver_order=$order --dpm_solver_method=$method --dpm_solver_type=$type --port 12350 18 | 19 | 20 | ######################### 21 | 22 | # ImageNet64 (improved-DDPM checkpoint) example 23 | 24 | data="imagenet64" 25 | sampleMethod='dpmsolver++' 26 | type="dpmsolver" 27 | steps="10" 28 | DIS="logSNR" 29 | order="3" 30 | method="multistep" 31 | workdir="experiments/"$data"/"$sampleMethod"_"$method"_order"$order"_"$steps"_"$DIS"_type-"$type 32 | 33 | CUDA_VISIBLE_DEVICES=$DEVICES python main.py --config $data".yml" --exp=$workdir --sample --fid --timesteps=$steps --eta 0 --ni --skip_type=$DIS --sample_type=$sampleMethod --dpm_solver_order=$order 
--dpm_solver_method=$method --dpm_solver_type=$type --port 12350 34 | 35 | 36 | ######################### 37 | 38 | # ImageNet256 with classifier guidance (large guidance scale) example 39 | 40 | data="imagenet256_guided" 41 | scale="8.0" 42 | sampleMethod='dpmsolver++' 43 | type="dpmsolver" 44 | steps="20" 45 | DIS="time_uniform" 46 | order="2" 47 | method="multistep" 48 | 49 | workdir="experiments/"$data"/"$sampleMethod"_"$method"_order"$order"_"$steps"_"$DIS"_scale"$scale"_type-"$type"_thresholding" 50 | CUDA_VISIBLE_DEVICES=$DEVICES python main.py --config $data".yml" --exp=$workdir --sample --fid --timesteps=$steps --eta 0 --ni --skip_type=$DIS --sample_type=$sampleMethod --dpm_solver_order=$order --dpm_solver_method=$method --dpm_solver_type=$type --port 12350 --scale=$scale --thresholding 51 | -------------------------------------------------------------------------------- /examples/score_sde_jax/.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled python modules. 2 | *.pyc 3 | 4 | # Byte-compiled 5 | _pycache__/ 6 | .cache/ 7 | .idea/ 8 | 9 | # Python egg metadata, regenerated from source files by setuptools. 10 | /*.egg-info 11 | .eggs/ 12 | 13 | # PyPI distribution artifacts. 14 | build/ 15 | dist/ 16 | 17 | # Tests 18 | .pytest_cache/ 19 | 20 | # Other 21 | *.DS_Store 22 | 23 | experiments 24 | assets/stats -------------------------------------------------------------------------------- /examples/score_sde_jax/assets/bedroom.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LuChengTHU/dpm-solver/52bc3fbcd5de56d60917b826b15d2b69460fc2fa/examples/score_sde_jax/assets/bedroom.jpeg -------------------------------------------------------------------------------- /examples/score_sde_jax/assets/celebahq_256.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LuChengTHU/dpm-solver/52bc3fbcd5de56d60917b826b15d2b69460fc2fa/examples/score_sde_jax/assets/celebahq_256.jpg -------------------------------------------------------------------------------- /examples/score_sde_jax/assets/church.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LuChengTHU/dpm-solver/52bc3fbcd5de56d60917b826b15d2b69460fc2fa/examples/score_sde_jax/assets/church.jpeg -------------------------------------------------------------------------------- /examples/score_sde_jax/assets/ffhq_1024.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LuChengTHU/dpm-solver/52bc3fbcd5de56d60917b826b15d2b69460fc2fa/examples/score_sde_jax/assets/ffhq_1024.jpeg -------------------------------------------------------------------------------- /examples/score_sde_jax/assets/ffhq_256.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LuChengTHU/dpm-solver/52bc3fbcd5de56d60917b826b15d2b69460fc2fa/examples/score_sde_jax/assets/ffhq_256.jpg -------------------------------------------------------------------------------- /examples/score_sde_jax/assets/ffhq_samples.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LuChengTHU/dpm-solver/52bc3fbcd5de56d60917b826b15d2b69460fc2fa/examples/score_sde_jax/assets/ffhq_samples.jpg 
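A minimal sketch of calling the solver in dpm_solver_pytorch.py directly, corresponding to the multistep DPM-Solver++ settings used in the sample.sh commands above; here `betas`, `model`, and `x_T` are placeholders for a discrete beta schedule, a pretrained noise-prediction network, and the initial Gaussian noise, and exact argument names may vary between versions of the file:

    import torch
    from dpm_solver_pytorch import NoiseScheduleVP, model_wrapper, DPM_Solver

    # Wrap the discrete-time beta schedule of the pretrained DDPM checkpoint.
    noise_schedule = NoiseScheduleVP(schedule='discrete', betas=betas)

    # Turn the noise-prediction network into a continuous-time model function.
    model_fn = model_wrapper(model, noise_schedule, model_type="noise", model_kwargs={})

    # Multistep DPM-Solver++ sampler (10 steps, order 3, logSNR spacing), matching sample.sh.
    dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++")
    with torch.no_grad():
        x_sample = dpm_solver.sample(x_T, steps=10, order=3,
                                     skip_type="logSNR", method="multistep")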
-------------------------------------------------------------------------------- /examples/score_sde_jax/assets/schematic.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LuChengTHU/dpm-solver/52bc3fbcd5de56d60917b826b15d2b69460fc2fa/examples/score_sde_jax/assets/schematic.jpg -------------------------------------------------------------------------------- /examples/score_sde_jax/configs/subvp/cifar10_ddpm_continuous.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python3 17 | """Training DDPM with sub-VP SDE.""" 18 | 19 | from configs.default_cifar10_configs import get_default_configs 20 | 21 | 22 | def get_config(): 23 | config = get_default_configs() 24 | 25 | # training 26 | training = config.training 27 | training.sde = 'subvpsde' 28 | training.continuous = True 29 | training.reduce_mean = True 30 | 31 | # sampling 32 | sampling = config.sampling 33 | sampling.method = 'pc' 34 | sampling.predictor = 'euler_maruyama' 35 | sampling.corrector = 'none' 36 | 37 | # data 38 | data = config.data 39 | data.centered = True 40 | 41 | # model 42 | model = config.model 43 | model.name = 'ddpm' 44 | model.scale_by_sigma = False 45 | model.ema_rate = 0.9999 46 | model.normalization = 'GroupNorm' 47 | model.nonlinearity = 'swish' 48 | model.nf = 128 49 | model.ch_mult = (1, 2, 2, 2) 50 | model.num_res_blocks = 2 51 | model.attn_resolutions = (16,) 52 | model.resamp_with_conv = True 53 | model.conditional = True 54 | 55 | return config 56 | -------------------------------------------------------------------------------- /examples/score_sde_jax/configs/subvp/cifar10_ddpmpp_continuous.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | # Lint as: python3 17 | """Training NCSNv3 on CIFAR-10 with continuous sigmas.""" 18 | 19 | from configs.default_cifar10_configs import get_default_configs 20 | 21 | 22 | def get_config(): 23 | config = get_default_configs() 24 | # training 25 | training = config.training 26 | training.sde = 'subvpsde' 27 | training.continuous = True 28 | training.reduce_mean = True 29 | 30 | # sampling 31 | sampling = config.sampling 32 | sampling.method = 'pc' 33 | sampling.predictor = 'euler_maruyama' 34 | sampling.corrector = 'none' 35 | 36 | # data 37 | data = config.data 38 | data.centered = True 39 | 40 | # model 41 | model = config.model 42 | model.name = 'ncsnpp' 43 | model.scale_by_sigma = False 44 | model.ema_rate = 0.9999 45 | model.normalization = 'GroupNorm' 46 | model.nonlinearity = 'swish' 47 | model.nf = 128 48 | model.ch_mult = (1, 2, 2, 2) 49 | model.num_res_blocks = 4 50 | model.attn_resolutions = (16,) 51 | model.resamp_with_conv = True 52 | model.conditional = True 53 | model.fir = False 54 | model.fir_kernel = [1, 3, 3, 1] 55 | model.skip_rescale = True 56 | model.resblock_type = 'biggan' 57 | model.progressive = 'none' 58 | model.progressive_input = 'none' 59 | model.progressive_combine = 'sum' 60 | model.attention_type = 'ddpm' 61 | model.init_scale = 0. 62 | model.embedding_type = 'positional' 63 | model.fourier_scale = 16 64 | model.conv_size = 3 65 | 66 | return config 67 | -------------------------------------------------------------------------------- /examples/score_sde_jax/configs/subvp/cifar10_ncsnpp_continuous.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | # Lint as: python3 17 | """Training NCSN++ on CIFAR-10 with sub-VP SDE.""" 18 | from configs.default_cifar10_configs import get_default_configs 19 | 20 | 21 | def get_config(): 22 | config = get_default_configs() 23 | # training 24 | training = config.training 25 | training.sde = 'subvpsde' 26 | training.continuous = True 27 | training.reduce_mean = True 28 | 29 | # sampling 30 | sampling = config.sampling 31 | sampling.method = 'pc' 32 | sampling.predictor = 'euler_maruyama' 33 | sampling.corrector = 'none' 34 | 35 | # data 36 | data = config.data 37 | data.centered = True 38 | 39 | # model 40 | model = config.model 41 | model.name = 'ncsnpp' 42 | model.scale_by_sigma = False 43 | model.ema_rate = 0.9999 44 | model.normalization = 'GroupNorm' 45 | model.nonlinearity = 'swish' 46 | model.nf = 128 47 | model.ch_mult = (1, 2, 2, 2) 48 | model.num_res_blocks = 4 49 | model.attn_resolutions = (16,) 50 | model.resamp_with_conv = True 51 | model.conditional = True 52 | model.fir = True 53 | model.fir_kernel = [1, 3, 3, 1] 54 | model.skip_rescale = True 55 | model.resblock_type = 'biggan' 56 | model.progressive = 'none' 57 | model.progressive_input = 'residual' 58 | model.progressive_combine = 'sum' 59 | model.attention_type = 'ddpm' 60 | model.embedding_type = 'positional' 61 | model.init_scale = 0. 62 | model.fourier_scale = 16 63 | model.conv_size = 3 64 | 65 | return config 66 | -------------------------------------------------------------------------------- /examples/score_sde_jax/configs/ve/bedroom_ncsnpp_continuous.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | # Lint as: python3 17 | """Training NCSN++ on bedroom with VE SDE.""" 18 | 19 | from configs.default_lsun_configs import get_default_configs 20 | 21 | 22 | def get_config(): 23 | config = get_default_configs() 24 | # training 25 | training = config.training 26 | training.sde = 'vesde' 27 | training.continuous = True 28 | 29 | # sampling 30 | sampling = config.sampling 31 | sampling.method = 'pc' 32 | sampling.predictor = 'reverse_diffusion' 33 | sampling.corrector = 'langevin' 34 | 35 | # data 36 | data = config.data 37 | data.category = 'bedroom' 38 | 39 | # model 40 | model = config.model 41 | model.name = 'ncsnpp' 42 | model.scale_by_sigma = True 43 | model.ema_rate = 0.999 44 | model.normalization = 'GroupNorm' 45 | model.nonlinearity = 'swish' 46 | model.nf = 128 47 | model.ch_mult = (1, 1, 2, 2, 2, 2, 2) 48 | model.num_res_blocks = 2 49 | model.attn_resolutions = (16,) 50 | model.resamp_with_conv = True 51 | model.conditional = True 52 | model.fir = True 53 | model.fir_kernel = [1, 3, 3, 1] 54 | model.skip_rescale = True 55 | model.resblock_type = 'biggan' 56 | model.progressive = 'output_skip' 57 | model.progressive_input = 'input_skip' 58 | model.progressive_combine = 'sum' 59 | model.attention_type = 'ddpm' 60 | model.init_scale = 0. 61 | model.fourier_scale = 16 62 | model.conv_size = 3 63 | 64 | return config 65 | -------------------------------------------------------------------------------- /examples/score_sde_jax/configs/ve/celeba_ncsnpp.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | # Lint as: python3 17 | """Training NCSN++ on CelebA with SMLD.""" 18 | 19 | from configs.default_celeba_configs import get_default_configs 20 | 21 | 22 | def get_config(): 23 | config = get_default_configs() 24 | # training 25 | training = config.training 26 | training.sde = 'vesde' 27 | training.continuous = False 28 | 29 | # sampling 30 | sampling = config.sampling 31 | sampling.method = 'pc' 32 | sampling.predictor = 'reverse_diffusion' 33 | sampling.corrector = 'langevin' 34 | 35 | # model 36 | model = config.model 37 | model.name = 'ncsnpp' 38 | model.scale_by_sigma = True 39 | model.sigma_begin = 90 40 | model.ema_rate = 0.999 41 | model.normalization = 'GroupNorm' 42 | model.nonlinearity = 'swish' 43 | model.nf = 128 44 | model.ch_mult = (1, 2, 2, 2) 45 | model.num_res_blocks = 4 46 | model.attn_resolutions = (16,) 47 | model.resamp_with_conv = True 48 | model.conditional = True 49 | model.fir = True 50 | model.fir_kernel = [1, 3, 3, 1] 51 | model.skip_rescale = True 52 | model.resblock_type = 'biggan' 53 | model.progressive = 'none' 54 | model.progressive_input = 'residual' 55 | model.progressive_combine = 'sum' 56 | model.attention_type = 'ddpm' 57 | model.init_scale = 0.0 58 | model.conv_size = 3 59 | model.embedding_type = 'positional' 60 | 61 | return config 62 | -------------------------------------------------------------------------------- /examples/score_sde_jax/configs/ve/church_ncsnpp_continuous.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python3 17 | """Training NCSN++ on Church with VE SDE.""" 18 | 19 | from configs.default_lsun_configs import get_default_configs 20 | 21 | 22 | def get_config(): 23 | config = get_default_configs() 24 | # training 25 | training = config.training 26 | training.sde = 'vesde' 27 | training.continuous = True 28 | 29 | # sampling 30 | sampling = config.sampling 31 | sampling.method = 'pc' 32 | sampling.predictor = 'reverse_diffusion' 33 | sampling.corrector = 'langevin' 34 | 35 | # data 36 | data = config.data 37 | data.category = 'church_outdoor' 38 | 39 | # model 40 | model = config.model 41 | model.name = 'ncsnpp' 42 | model.sigma_max = 380 43 | model.scale_by_sigma = True 44 | model.ema_rate = 0.999 45 | model.normalization = 'GroupNorm' 46 | model.nonlinearity = 'swish' 47 | model.nf = 128 48 | model.ch_mult = (1, 1, 2, 2, 2, 2, 2) 49 | model.num_res_blocks = 2 50 | model.attn_resolutions = (16,) 51 | model.resamp_with_conv = True 52 | model.conditional = True 53 | model.fir = True 54 | model.fir_kernel = [1, 3, 3, 1] 55 | model.skip_rescale = True 56 | model.resblock_type = 'biggan' 57 | model.progressive = 'output_skip' 58 | model.progressive_input = 'input_skip' 59 | model.progressive_combine = 'sum' 60 | model.attention_type = 'ddpm' 61 | model.init_scale = 0. 
62 | model.fourier_scale = 16 63 | model.conv_size = 3 64 | 65 | return config 66 | -------------------------------------------------------------------------------- /examples/score_sde_jax/configs/ve/cifar10_ddpm.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python3 17 | """Train the original DDPM model with SMLD.""" 18 | 19 | from configs.default_cifar10_configs import get_default_configs 20 | 21 | 22 | def get_config(): 23 | config = get_default_configs() 24 | # training 25 | training = config.training 26 | training.sde = 'vesde' 27 | training.continuous = False 28 | 29 | # sampling 30 | sampling = config.sampling 31 | sampling.method = 'pc' 32 | sampling.predictor = 'reverse_diffusion' 33 | sampling.corrector = 'langevin' 34 | 35 | # model 36 | model = config.model 37 | model.name = 'ddpm' 38 | model.scale_by_sigma = True 39 | model.ema_rate = 0.999 40 | model.normalization = 'GroupNorm' 41 | model.nonlinearity = 'swish' 42 | model.nf = 128 43 | model.ch_mult = (1, 2, 2, 2) 44 | model.num_res_blocks = 2 45 | model.attn_resolutions = (16,) 46 | model.resamp_with_conv = True 47 | model.conditional = True 48 | model.conv_size = 3 49 | 50 | return config 51 | -------------------------------------------------------------------------------- /examples/score_sde_jax/configs/ve/cifar10_ncsnpp.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | # Lint as: python3 17 | """Training NCSN++ on CIFAR-10 with SMLD.""" 18 | 19 | from configs.default_cifar10_configs import get_default_configs 20 | 21 | 22 | def get_config(): 23 | config = get_default_configs() 24 | # training 25 | training = config.training 26 | training.sde = 'vesde' 27 | training.continuous = False 28 | 29 | # sampling 30 | sampling = config.sampling 31 | sampling.method = 'pc' 32 | sampling.predictor = 'reverse_diffusion' 33 | sampling.corrector = 'langevin' 34 | 35 | # model 36 | model = config.model 37 | model.name = 'ncsnpp' 38 | model.scale_by_sigma = True 39 | model.ema_rate = 0.999 40 | model.normalization = 'GroupNorm' 41 | model.nonlinearity = 'swish' 42 | model.nf = 128 43 | model.ch_mult = (1, 2, 2, 2) 44 | model.num_res_blocks = 4 45 | model.attn_resolutions = (16,) 46 | model.resamp_with_conv = True 47 | model.conditional = True 48 | model.fir = True 49 | model.fir_kernel = [1, 3, 3, 1] 50 | model.skip_rescale = True 51 | model.resblock_type = 'biggan' 52 | model.progressive = 'none' 53 | model.progressive_input = 'residual' 54 | model.progressive_combine = 'sum' 55 | model.attention_type = 'ddpm' 56 | model.init_scale = 0.0 57 | model.embedding_type = 'positional' 58 | model.conv_size = 3 59 | 60 | return config 61 | -------------------------------------------------------------------------------- /examples/score_sde_jax/configs/ve/cifar10_ncsnpp_continuous.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python3 17 | """Training NCSN++ on CIFAR-10 with VE SDE.""" 18 | from configs.default_cifar10_configs import get_default_configs 19 | 20 | 21 | def get_config(): 22 | config = get_default_configs() 23 | # training 24 | training = config.training 25 | training.sde = 'vesde' 26 | training.continuous = True 27 | 28 | # sampling 29 | sampling = config.sampling 30 | sampling.method = 'pc' 31 | sampling.predictor = 'reverse_diffusion' 32 | sampling.corrector = 'langevin' 33 | 34 | # model 35 | model = config.model 36 | model.name = 'ncsnpp' 37 | model.scale_by_sigma = True 38 | model.ema_rate = 0.999 39 | model.normalization = 'GroupNorm' 40 | model.nonlinearity = 'swish' 41 | model.nf = 128 42 | model.ch_mult = (1, 2, 2, 2) 43 | model.num_res_blocks = 4 44 | model.attn_resolutions = (16,) 45 | model.resamp_with_conv = True 46 | model.conditional = True 47 | model.fir = True 48 | model.fir_kernel = [1, 3, 3, 1] 49 | model.skip_rescale = True 50 | model.resblock_type = 'biggan' 51 | model.progressive = 'none' 52 | model.progressive_input = 'residual' 53 | model.progressive_combine = 'sum' 54 | model.attention_type = 'ddpm' 55 | model.init_scale = 0. 
56 | model.fourier_scale = 16 57 | model.conv_size = 3 58 | 59 | return config 60 | -------------------------------------------------------------------------------- /examples/score_sde_jax/configs/ve/cifar10_ncsnpp_deep_continuous.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python3 17 | """Training NCSN++ on CIFAR-10 with VE SDE.""" 18 | 19 | from configs.default_cifar10_configs import get_default_configs 20 | 21 | 22 | def get_config(): 23 | config = get_default_configs() 24 | # training 25 | training = config.training 26 | training.sde = 'vesde' 27 | training.continuous = True 28 | training.n_iters = 950001 29 | 30 | # sampling 31 | sampling = config.sampling 32 | sampling.method = 'pc' 33 | sampling.predictor = 'reverse_diffusion' 34 | sampling.corrector = 'langevin' 35 | 36 | # model 37 | model = config.model 38 | model.name = 'ncsnpp' 39 | model.fourier_scale = 16 40 | model.scale_by_sigma = True 41 | model.ema_rate = 0.999 42 | model.normalization = 'GroupNorm' 43 | model.nonlinearity = 'swish' 44 | model.nf = 128 45 | model.ch_mult = (1, 2, 2, 2) 46 | model.num_res_blocks = 8 47 | model.attn_resolutions = (16,) 48 | model.resamp_with_conv = True 49 | model.conditional = True 50 | model.fir = True 51 | model.fir_kernel = [1, 3, 3, 1] 52 | model.skip_rescale = True 53 | model.resblock_type = 'biggan' 54 | model.progressive = 'none' 55 | model.progressive_input = 'residual' 56 | model.progressive_combine = 'sum' 57 | model.attention_type = 'ddpm' 58 | model.init_scale = 0.0 59 | model.conv_size = 3 60 | 61 | return config 62 | -------------------------------------------------------------------------------- /examples/score_sde_jax/configs/ve/ncsn/celeba.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | # Lint as: python3 17 | """Config file for reproducing NCSNv1 on CelebA.""" 18 | 19 | from configs.default_celeba_configs import get_default_configs 20 | 21 | 22 | def get_config(): 23 | config = get_default_configs() 24 | # training 25 | training = config.training 26 | training.loss = 'vesde' 27 | training.continuous = False 28 | # sampling 29 | sampling = config.sampling 30 | sampling.method = 'pc' 31 | sampling.predictor = 'none' 32 | sampling.corrector = 'ald' 33 | sampling.n_steps_each = 100 34 | sampling.snr = 0.316 35 | # model 36 | model = config.model 37 | model.name = 'ncsn' 38 | model.scale_by_sigma = False 39 | model.sigma_max = 1 40 | model.num_scales = 10 41 | model.ema_rate = 0. 42 | model.normalization = 'InstanceNorm++' 43 | model.nonlinearity = 'elu' 44 | model.nf = 128 45 | model.interpolation = 'bilinear' 46 | # optim 47 | optim = config.optim 48 | optim.weight_decay = 0 49 | optim.optimizer = 'Adam' 50 | optim.lr = 1e-3 51 | optim.beta1 = 0.9 52 | optim.amsgrad = False 53 | optim.eps = 1e-8 54 | optim.warmup = 0 55 | optim.grad_clip = -1. 56 | 57 | return config 58 | -------------------------------------------------------------------------------- /examples/score_sde_jax/configs/ve/ncsn/celeba_124.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python3 17 | """Config file for training NCSN with technique 1,2,4 only.""" 18 | 19 | from configs.default_celeba_configs import get_default_configs 20 | 21 | 22 | def get_config(): 23 | config = get_default_configs() 24 | # training 25 | training = config.training 26 | training.sde = 'vesde' 27 | training.continuous = False 28 | # sampling 29 | sampling = config.sampling 30 | sampling.method = 'pc' 31 | sampling.predictor = 'none' 32 | sampling.corrector = 'ald' 33 | sampling.n_steps_each = 5 34 | sampling.snr = 0.128 35 | # model 36 | model = config.model 37 | model.name = 'ncsn' 38 | model.scale_by_sigma = False 39 | model.num_scales = 500 40 | model.ema_rate = 0. 41 | model.normalization = 'InstanceNorm++' 42 | model.nonlinearity = 'elu' 43 | model.nf = 128 44 | model.interpolation = 'bilinear' 45 | # optim 46 | optim = config.optim 47 | optim.weight_decay = 0 48 | optim.optimizer = 'Adam' 49 | optim.lr = 1e-3 50 | optim.beta1 = 0.9 51 | optim.amsgrad = False 52 | optim.eps = 1e-8 53 | optim.warmup = 0 54 | optim.grad_clip = -1. 55 | 56 | return config 57 | -------------------------------------------------------------------------------- /examples/score_sde_jax/configs/ve/ncsn/celeba_1245.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python3 17 | """Config file for training NCSN with technique 1245 only.""" 18 | 19 | from configs.default_celeba_configs import get_default_configs 20 | 21 | 22 | def get_config(): 23 | config = get_default_configs() 24 | # training 25 | training = config.training 26 | training.sde = 'vesde' 27 | training.continuous = False 28 | # sampling 29 | sampling = config.sampling 30 | sampling.method = 'pc' 31 | sampling.predictor = 'none' 32 | sampling.corrector = 'ald' 33 | sampling.n_steps_each = 5 34 | sampling.snr = 0.128 35 | # model 36 | model = config.model 37 | model.name = 'ncsn' 38 | model.scale_by_sigma = False 39 | model.num_scales = 500 40 | model.ema_rate = 0.999 41 | model.normalization = 'InstanceNorm++' 42 | model.nonlinearity = 'elu' 43 | model.nf = 128 44 | model.interpolation = 'bilinear' 45 | # optim 46 | optim = config.optim 47 | optim.weight_decay = 0 48 | optim.optimizer = 'Adam' 49 | optim.lr = 1e-3 50 | optim.beta1 = 0.9 51 | optim.amsgrad = False 52 | optim.eps = 1e-8 53 | optim.warmup = 0 54 | optim.grad_clip = -1. 55 | 56 | return config 57 | -------------------------------------------------------------------------------- /examples/score_sde_jax/configs/ve/ncsn/celeba_5.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python3 17 | """Config file for training NCSNv1 model with technique 5 only.""" 18 | 19 | from configs.default_celeba_configs import get_default_configs 20 | 21 | 22 | def get_config(): 23 | config = get_default_configs() 24 | # training 25 | training = config.training 26 | training.sde = 'vesde' 27 | training.continuous = False 28 | # sampling 29 | sampling = config.sampling 30 | sampling.method = 'pc' 31 | sampling.predictor = 'none' 32 | sampling.corrector = 'ald' 33 | sampling.n_steps_each = 100 34 | sampling.snr = 0.316 35 | # model 36 | model = config.model 37 | model.name = 'ncsn' 38 | model.scale_by_sigma = False 39 | model.sigma_max = 1. 40 | model.num_scales = 10 41 | model.ema_rate = 0.999 42 | model.normalization = 'InstanceNorm++' 43 | model.nonlinearity = 'elu' 44 | model.nf = 128 45 | model.interpolation = 'bilinear' 46 | # optim 47 | optim = config.optim 48 | optim.weight_decay = 0 49 | optim.optimizer = 'Adam' 50 | optim.lr = 1e-3 51 | optim.beta1 = 0.9 52 | optim.amsgrad = False 53 | optim.eps = 1e-8 54 | optim.warmup = 0 55 | optim.grad_clip = -1. 
56 | 57 | return config 58 | -------------------------------------------------------------------------------- /examples/score_sde_jax/configs/ve/ncsn/cifar10.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python3 17 | """Config file for reproducing NCSNv1 on CIFAR-10.""" 18 | 19 | from configs.default_cifar10_configs import get_default_configs 20 | 21 | 22 | def get_config(): 23 | config = get_default_configs() 24 | # training 25 | training = config.training 26 | training.sde = 'vesde' 27 | training.continuous = False 28 | # sampling 29 | sampling = config.sampling 30 | sampling.method = 'pc' 31 | sampling.predictor = 'none' 32 | sampling.corrector = 'ald' 33 | sampling.n_steps_each = 100 34 | sampling.snr = 0.316 35 | # model 36 | model = config.model 37 | model.name = 'ncsn' 38 | model.scale_by_sigma = False 39 | model.sigma_max = 1 40 | model.num_scales = 10 41 | model.ema_rate = 0. 42 | model.normalization = 'InstanceNorm++' 43 | model.nonlinearity = 'elu' 44 | model.nf = 128 45 | model.interpolation = 'bilinear' 46 | # optim 47 | optim = config.optim 48 | optim.weight_decay = 0 49 | optim.optimizer = 'Adam' 50 | optim.lr = 1e-3 51 | optim.beta1 = 0.9 52 | optim.amsgrad = False 53 | optim.eps = 1e-8 54 | optim.warmup = 0 55 | optim.grad_clip = -1. 56 | 57 | return config 58 | -------------------------------------------------------------------------------- /examples/score_sde_jax/configs/ve/ncsn/cifar10_124.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | # Lint as: python3 17 | """Config file for training NCSN with technique 1,2,4 only.""" 18 | 19 | from configs.default_cifar10_configs import get_default_configs 20 | 21 | 22 | def get_config(): 23 | config = get_default_configs() 24 | # training 25 | training = config.training 26 | training.sde = 'vesde' 27 | training.continuous = False 28 | # sampling 29 | sampling = config.sampling 30 | sampling.method = 'pc' 31 | sampling.predictor = 'none' 32 | sampling.corrector = 'ald' 33 | sampling.n_steps_each = 5 34 | sampling.snr = 0.176 35 | # model 36 | model = config.model 37 | model.name = 'ncsn' 38 | model.scale_by_sigma = False 39 | model.num_scales = 232 40 | model.ema_rate = 0. 41 | model.normalization = 'InstanceNorm++' 42 | model.nonlinearity = 'elu' 43 | model.nf = 128 44 | model.interpolation = 'bilinear' 45 | # optim 46 | optim = config.optim 47 | optim.weight_decay = 0 48 | optim.optimizer = 'Adam' 49 | optim.lr = 1e-3 50 | optim.beta1 = 0.9 51 | optim.amsgrad = False 52 | optim.eps = 1e-8 53 | optim.warmup = 0 54 | optim.grad_clip = -1. 55 | 56 | return config 57 | -------------------------------------------------------------------------------- /examples/score_sde_jax/configs/ve/ncsn/cifar10_1245.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python3 17 | """Config file for training NCSN with technique 1,2,4,5 only.""" 18 | 19 | from configs.default_cifar10_configs import get_default_configs 20 | 21 | 22 | def get_config(): 23 | config = get_default_configs() 24 | # training 25 | training = config.training 26 | training.sde = 'vesde' 27 | training.continuous = False 28 | # shared configs for sample generation 29 | step_size = 0.0000062 30 | n_steps_each = 5 31 | ckpt_id = 300000 32 | final_only = True 33 | noise_removal = False 34 | # sampling 35 | sampling = config.sampling 36 | sampling.method = 'pc' 37 | sampling.predictor = 'none' 38 | sampling.corrector = 'ald' 39 | sampling.n_steps_each = 5 40 | sampling.snr = 0.176 41 | # model 42 | model = config.model 43 | model.name = 'ncsn' 44 | model.scale_by_sigma = False 45 | model.num_scales = 232 46 | model.ema_rate = 0.999 47 | model.normalization = 'InstanceNorm++' 48 | model.nonlinearity = 'elu' 49 | model.nf = 128 50 | model.interpolation = 'bilinear' 51 | # optim 52 | optim = config.optim 53 | optim.weight_decay = 0 54 | optim.optimizer = 'Adam' 55 | optim.lr = 1e-3 56 | optim.beta1 = 0.9 57 | optim.amsgrad = False 58 | optim.eps = 1e-8 59 | optim.warmup = 0 60 | optim.grad_clip = -1. 61 | 62 | return config 63 | -------------------------------------------------------------------------------- /examples/score_sde_jax/configs/ve/ncsn/cifar10_5.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python3 17 | """Config file for training NCSN with technique 5 only.""" 18 | 19 | from configs.default_cifar10_configs import get_default_configs 20 | 21 | 22 | def get_config(): 23 | config = get_default_configs() 24 | # training 25 | training = config.training 26 | training.sde = 'vesde' 27 | training.continuous = False 28 | # sampling 29 | sampling = config.sampling 30 | sampling.method = 'pc' 31 | sampling.predictor = 'none' 32 | sampling.corrector = 'ald' 33 | sampling.snr = 0.316 34 | sampling.n_steps_each = 100 35 | # model 36 | model = config.model 37 | model.name = 'ncsn' 38 | model.scale_by_sigma = False 39 | model.sigma_max = 1 40 | model.num_scales = 10 41 | model.ema_rate = 0.999 42 | model.normalization = 'InstanceNorm++' 43 | model.nonlinearity = 'elu' 44 | model.nf = 128 45 | model.interpolation = 'bilinear' 46 | # optim 47 | optim = config.optim 48 | optim.weight_decay = 0 49 | optim.optimizer = 'Adam' 50 | optim.lr = 1e-3 51 | optim.beta1 = 0.9 52 | optim.amsgrad = False 53 | optim.eps = 1e-8 54 | optim.warmup = 0 55 | optim.grad_clip = -1. 56 | 57 | return config 58 | -------------------------------------------------------------------------------- /examples/score_sde_jax/configs/ve/ncsnv2/bedroom.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | # Lint as: python3 17 | """Config file for training NCSNv2 on bedroom.""" 18 | 19 | from configs.default_lsun_configs import get_default_configs 20 | 21 | 22 | def get_config(): 23 | config = get_default_configs() 24 | # training 25 | training = config.training 26 | training.batch_size = 128 27 | training.sde = 'vesde' 28 | training.continuouse = False 29 | # sampling 30 | sampling = config.sampling 31 | sampling.method = 'pc' 32 | sampling.predictor = 'none' 33 | sampling.corrector = 'ald' 34 | sampling.n_steps_each = 3 35 | sampling.snr = 0.095 36 | # data 37 | data = config.data 38 | data.category = 'bedroom' 39 | data.image_size = 128 40 | # model 41 | model = config.model 42 | model.name = 'ncsnv2_128' 43 | model.scale_by_sigma = True 44 | model.sigma_max = 190 45 | model.num_scales = 1086 46 | model.ema_rate = 0.9999 47 | model.sigma_min = 0.01 48 | model.normalization = 'InstanceNorm++' 49 | model.nonlinearity = 'elu' 50 | model.nf = 128 51 | model.interpolation = 'bilinear' 52 | # optim 53 | optim = config.optim 54 | optim.weight_decay = 0 55 | optim.optimizer = 'Adam' 56 | optim.lr = 1e-4 57 | optim.beta1 = 0.9 58 | optim.amsgrad = False 59 | optim.eps = 1e-8 60 | optim.warmup = 0 61 | optim.grad_clip = -1 62 | 63 | return config 64 | -------------------------------------------------------------------------------- /examples/score_sde_jax/configs/ve/ncsnv2/celeba.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python3 17 | """Config file for training NCSNv2 on CelebA.""" 18 | 19 | from configs.default_celeba_configs import get_default_configs 20 | 21 | 22 | def get_config(): 23 | config = get_default_configs() 24 | # training 25 | training = config.training 26 | training.sde = 'vesde' 27 | training.continuous = False 28 | # shared configs for sample generation 29 | step_size = 0.0000033 30 | n_steps_each = 5 31 | ckpt_id = 210000 32 | final_only = True 33 | noise_removal = False 34 | # sampling 35 | sampling = config.sampling 36 | sampling.method = 'pc' 37 | sampling.predictor = 'none' 38 | sampling.corrector = 'ald' 39 | sampling.n_steps_each = 5 40 | sampling.snr = 0.128 41 | # model 42 | model = config.model 43 | model.name = 'ncsnv2_64' 44 | model.scale_by_sigma = True 45 | model.num_scales = 500 46 | model.ema_rate = 0.999 47 | model.normalization = 'InstanceNorm++' 48 | model.nonlinearity = 'elu' 49 | model.nf = 128 50 | model.interpolation = 'bilinear' 51 | # optim 52 | optim = config.optim 53 | optim.weight_decay = 0 54 | optim.optimizer = 'Adam' 55 | optim.lr = 1e-4 56 | optim.beta1 = 0.9 57 | optim.amsgrad = False 58 | optim.eps = 1e-8 59 | optim.warmup = 0 60 | optim.grad_clip = -1. 
61 | 62 | return config 63 | -------------------------------------------------------------------------------- /examples/score_sde_jax/configs/ve/ncsnv2/cifar10.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python3 17 | """Config file for training NCSNv2 on CIFAR-10.""" 18 | 19 | from configs.default_cifar10_configs import get_default_configs 20 | 21 | 22 | def get_config(): 23 | config = get_default_configs() 24 | # training 25 | training = config.training 26 | training.sde = 'vesde' 27 | training.continuous = False 28 | # sampling 29 | sampling = config.sampling 30 | sampling.method = 'pc' 31 | sampling.predictor = 'none' 32 | sampling.corrector = 'ald' 33 | sampling.n_steps_each = 5 34 | sampling.snr = 0.176 35 | # model 36 | model = config.model 37 | model.name = 'ncsnv2_64' 38 | model.scale_by_sigma = True 39 | model.num_scales = 232 40 | model.ema_rate = 0.999 41 | model.normalization = 'InstanceNorm++' 42 | model.nonlinearity = 'elu' 43 | model.nf = 128 44 | model.interpolation = 'bilinear' 45 | # optim 46 | optim = config.optim 47 | optim.weight_decay = 0 48 | optim.optimizer = 'Adam' 49 | optim.lr = 1e-4 50 | optim.beta1 = 0.9 51 | optim.amsgrad = False 52 | optim.eps = 1e-8 53 | optim.warmup = 0 54 | optim.grad_clip = -1. 55 | 56 | return config 57 | -------------------------------------------------------------------------------- /examples/score_sde_jax/configs/vp/cifar10_ddpmpp.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | # Lint as: python3 17 | """Training NCSNv3 on CIFAR-10 with continuous sigmas.""" 18 | 19 | from configs.default_cifar10_configs import get_default_configs 20 | 21 | 22 | def get_config(): 23 | config = get_default_configs() 24 | # training 25 | training = config.training 26 | training.sde = 'vpsde' 27 | training.continuous = False 28 | training.reduce_mean = True 29 | 30 | # sampling 31 | sampling = config.sampling 32 | sampling.method = 'pc' 33 | sampling.predictor = 'ancestral_sampling' 34 | sampling.corrector = 'none' 35 | 36 | # data 37 | data = config.data 38 | data.centered = True 39 | 40 | # model 41 | model = config.model 42 | model.name = 'ncsnpp' 43 | model.scale_by_sigma = False 44 | model.ema_rate = 0.9999 45 | model.normalization = 'GroupNorm' 46 | model.nonlinearity = 'swish' 47 | model.nf = 128 48 | model.ch_mult = (1, 2, 2, 2) 49 | model.num_res_blocks = 4 50 | model.attn_resolutions = (16,) 51 | model.resamp_with_conv = True 52 | model.conditional = True 53 | model.fir = False 54 | model.fir_kernel = [1, 3, 3, 1] 55 | model.skip_rescale = True 56 | model.resblock_type = 'biggan' 57 | model.progressive = 'none' 58 | model.progressive_input = 'none' 59 | model.progressive_combine = 'sum' 60 | model.attention_type = 'ddpm' 61 | model.init_scale = 0. 62 | model.embedding_type = 'positional' 63 | model.fourier_scale = 16 64 | model.conv_size = 3 65 | 66 | return config 67 | -------------------------------------------------------------------------------- /examples/score_sde_jax/configs/vp/cifar10_ddpmpp_continuous.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | # Lint as: python3 17 | """Training NCSNv3 on CIFAR-10 with continuous sigmas.""" 18 | 19 | from configs.default_cifar10_configs import get_default_configs 20 | 21 | 22 | def get_config(): 23 | config = get_default_configs() 24 | # training 25 | training = config.training 26 | training.sde = 'vpsde' 27 | training.continuous = True 28 | training.reduce_mean = True 29 | 30 | # sampling 31 | sampling = config.sampling 32 | sampling.method = 'pc' 33 | sampling.predictor = 'euler_maruyama' 34 | sampling.corrector = 'none' 35 | 36 | # data 37 | data = config.data 38 | data.centered = True 39 | 40 | # model 41 | model = config.model 42 | model.name = 'ncsnpp' 43 | model.scale_by_sigma = False 44 | model.ema_rate = 0.9999 45 | model.normalization = 'GroupNorm' 46 | model.nonlinearity = 'swish' 47 | model.nf = 128 48 | model.ch_mult = (1, 2, 2, 2) 49 | model.num_res_blocks = 4 50 | model.attn_resolutions = (16,) 51 | model.resamp_with_conv = True 52 | model.conditional = True 53 | model.fir = False 54 | model.fir_kernel = [1, 3, 3, 1] 55 | model.skip_rescale = True 56 | model.resblock_type = 'biggan' 57 | model.progressive = 'none' 58 | model.progressive_input = 'none' 59 | model.progressive_combine = 'sum' 60 | model.attention_type = 'ddpm' 61 | model.init_scale = 0. 62 | model.embedding_type = 'positional' 63 | model.fourier_scale = 16 64 | model.conv_size = 3 65 | 66 | return config 67 | -------------------------------------------------------------------------------- /examples/score_sde_jax/configs/vp/cifar10_ncsnpp.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | # Lint as: python3 17 | """Training NCSN++ on CIFAR-10 with DDPM.""" 18 | 19 | from configs.default_cifar10_configs import get_default_configs 20 | 21 | 22 | def get_config(): 23 | config = get_default_configs() 24 | # training 25 | training = config.training 26 | training.sde = 'vpsde' 27 | training.continuous = False 28 | training.reduce_mean = True 29 | 30 | # sampling 31 | sampling = config.sampling 32 | sampling.method = 'pc' 33 | sampling.predictor = 'reverse_diffusion' 34 | sampling.corrector = 'none' 35 | 36 | # data 37 | data = config.data 38 | data.centered = True 39 | 40 | # model 41 | model = config.model 42 | model.name = 'ncsnpp' 43 | model.scale_by_sigma = False 44 | model.ema_rate = 0.9999 45 | model.normalization = 'GroupNorm' 46 | model.nonlinearity = 'swish' 47 | model.nf = 128 48 | model.ch_mult = (1, 2, 2, 2) 49 | model.num_res_blocks = 4 50 | model.attn_resolutions = (16,) 51 | model.resamp_with_conv = True 52 | model.conditional = True 53 | model.fir = True 54 | model.fir_kernel = [1, 3, 3, 1] 55 | model.skip_rescale = True 56 | model.resblock_type = 'biggan' 57 | model.progressive = 'none' 58 | model.progressive_input = 'residual' 59 | model.progressive_combine = 'sum' 60 | model.attention_type = 'ddpm' 61 | model.init_scale = 0.0 62 | model.embedding_type = 'positional' 63 | model.conv_size = 3 64 | 65 | return config 66 | -------------------------------------------------------------------------------- /examples/score_sde_jax/configs/vp/cifar10_ncsnpp_continuous.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | # Lint as: python3 17 | """Training NCSN++ on CIFAR-10 with VP SDE.""" 18 | from configs.default_cifar10_configs import get_default_configs 19 | 20 | 21 | def get_config(): 22 | config = get_default_configs() 23 | # training 24 | training = config.training 25 | training.sde = 'vpsde' 26 | training.continuous = True 27 | training.reduce_mean = True 28 | 29 | # sampling 30 | sampling = config.sampling 31 | sampling.method = 'pc' 32 | sampling.predictor = 'euler_maruyama' 33 | sampling.corrector = 'none' 34 | 35 | # data 36 | data = config.data 37 | data.centered = True 38 | 39 | # model 40 | model = config.model 41 | model.name = 'ncsnpp' 42 | model.scale_by_sigma = False 43 | model.ema_rate = 0.9999 44 | model.normalization = 'GroupNorm' 45 | model.nonlinearity = 'swish' 46 | model.nf = 128 47 | model.ch_mult = (1, 2, 2, 2) 48 | model.num_res_blocks = 4 49 | model.attn_resolutions = (16,) 50 | model.resamp_with_conv = True 51 | model.conditional = True 52 | model.fir = True 53 | model.fir_kernel = [1, 3, 3, 1] 54 | model.skip_rescale = True 55 | model.resblock_type = 'biggan' 56 | model.progressive = 'none' 57 | model.progressive_input = 'residual' 58 | model.progressive_combine = 'sum' 59 | model.attention_type = 'ddpm' 60 | model.embedding_type = 'positional' 61 | model.init_scale = 0. 62 | model.fourier_scale = 16 63 | model.conv_size = 3 64 | 65 | return config 66 | -------------------------------------------------------------------------------- /examples/score_sde_jax/configs/vp/cifar10_ncsnpp_deep_continuous.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | # Lint as: python3 17 | """Training NCSN++ on CIFAR-10.""" 18 | 19 | from configs.default_cifar10_configs import get_default_configs 20 | 21 | 22 | def get_config(): 23 | config = get_default_configs() 24 | # training 25 | training = config.training 26 | training.sde = 'vpsde' 27 | training.continuous = True 28 | training.n_iters = 950001 29 | training.reduce_mean = True 30 | 31 | # sampling 32 | sampling = config.sampling 33 | sampling.method = 'pc' 34 | sampling.predictor = 'euler_maruyama' 35 | sampling.corrector = 'none' 36 | 37 | # data 38 | data = config.data 39 | data.centered = True 40 | 41 | # model 42 | model = config.model 43 | model.name = 'ncsnpp' 44 | model.fourier_scale = 16 45 | model.scale_by_sigma = False 46 | model.ema_rate = 0.9999 47 | model.normalization = 'GroupNorm' 48 | model.nonlinearity = 'swish' 49 | model.nf = 128 50 | model.ch_mult = (1, 2, 2, 2) 51 | model.num_res_blocks = 8 52 | model.attn_resolutions = (16,) 53 | model.resamp_with_conv = True 54 | model.conditional = True 55 | model.fir = True 56 | model.fir_kernel = [1, 3, 3, 1] 57 | model.skip_rescale = True 58 | model.resblock_type = 'biggan' 59 | model.progressive = 'none' 60 | model.progressive_input = 'residual' 61 | model.progressive_combine = 'sum' 62 | model.attention_type = 'ddpm' 63 | model.embedding_type = 'positional' 64 | model.init_scale = 0.0 65 | model.conv_size = 3 66 | 67 | return config 68 | -------------------------------------------------------------------------------- /examples/score_sde_jax/configs/vp/ddpm/bedroom.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | # Lint as: python3 17 | """Config file for reproducing the results of DDPM on bedrooms.""" 18 | 19 | from configs.default_lsun_configs import get_default_configs 20 | 21 | 22 | def get_config(): 23 | config = get_default_configs() 24 | 25 | # training 26 | training = config.training 27 | training.sde = 'vpsde' 28 | training.continuous = False 29 | training.reduce_mean = True 30 | 31 | # sampling 32 | sampling = config.sampling 33 | sampling.method = 'pc' 34 | sampling.predictor = 'ancestral_sampling' 35 | sampling.corrector = 'none' 36 | 37 | # data 38 | data = config.data 39 | data.category = 'bedroom' 40 | data.centered = True 41 | 42 | # model 43 | model = config.model 44 | model.name = 'ddpm' 45 | model.scale_by_sigma = False 46 | model.num_scales = 1000 47 | model.ema_rate = 0.9999 48 | model.normalization = 'GroupNorm' 49 | model.nonlinearity = 'swish' 50 | model.nf = 128 51 | model.ch_mult = (1, 1, 2, 2, 4, 4) 52 | model.num_res_blocks = 2 53 | model.attn_resolutions = (16,) 54 | model.resamp_with_conv = True 55 | model.conditional = True 56 | 57 | # optim 58 | optim = config.optim 59 | optim.lr = 2e-5 60 | 61 | return config 62 | -------------------------------------------------------------------------------- /examples/score_sde_jax/configs/vp/ddpm/celebahq.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | # Lint as: python3 17 | """Config file for reproducing the results of DDPM on bedrooms.""" 18 | 19 | from configs.default_lsun_configs import get_default_configs 20 | 21 | 22 | def get_config(): 23 | config = get_default_configs() 24 | 25 | # training 26 | training = config.training 27 | training.sde = 'vpsde' 28 | training.continuous = False 29 | training.reduce_mean = True 30 | 31 | # sampling 32 | sampling = config.sampling 33 | sampling.method = 'pc' 34 | sampling.predictor = 'ancestral_sampling' 35 | sampling.corrector = 'none' 36 | 37 | # data 38 | data = config.data 39 | data.dataset = 'CelebAHQ' 40 | data.centered = True 41 | data.tfrecords_path = '/atlas/u/yangsong/celeba_hq/-r10.tfrecords' 42 | data.image_size = 256 43 | 44 | # model 45 | model = config.model 46 | model.name = 'ddpm' 47 | model.scale_by_sigma = False 48 | model.num_scales = 1000 49 | model.ema_rate = 0.9999 50 | model.normalization = 'GroupNorm' 51 | model.nonlinearity = 'swish' 52 | model.nf = 128 53 | model.ch_mult = (1, 1, 2, 2, 4, 4) 54 | model.num_res_blocks = 2 55 | model.attn_resolutions = (16,) 56 | model.resamp_with_conv = True 57 | model.conditional = True 58 | 59 | # optim 60 | optim = config.optim 61 | optim.lr = 2e-5 62 | 63 | return config 64 | -------------------------------------------------------------------------------- /examples/score_sde_jax/configs/vp/ddpm/church.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | # Lint as: python3 17 | """Config file for reproducing the results of DDPM on church_outdoor.""" 18 | 19 | from configs.default_lsun_configs import get_default_configs 20 | 21 | 22 | def get_config(): 23 | config = get_default_configs() 24 | 25 | # training 26 | training = config.training 27 | training.sde = 'vpsde' 28 | training.continuous = False 29 | training.reduce_mean = True 30 | 31 | # sampling 32 | sampling = config.sampling 33 | sampling.method = 'pc' 34 | sampling.predictor = 'ancestral_sampling' 35 | sampling.corrector = 'none' 36 | 37 | # data 38 | data = config.data 39 | data.category = 'church_outdoor' 40 | data.centered = True 41 | 42 | # model 43 | model = config.model 44 | model.name = 'ddpm' 45 | model.scale_by_sigma = False 46 | model.num_scales = 1000 47 | model.ema_rate = 0.9999 48 | model.normalization = 'GroupNorm' 49 | model.nonlinearity = 'swish' 50 | model.nf = 128 51 | model.ch_mult = (1, 1, 2, 2, 4, 4) 52 | model.num_res_blocks = 2 53 | model.attn_resolutions = (16,) 54 | model.resamp_with_conv = True 55 | model.conditional = True 56 | 57 | # optim 58 | optim = config.optim 59 | optim.lr = 2e-5 60 | 61 | return config 62 | -------------------------------------------------------------------------------- /examples/score_sde_jax/configs/vp/ddpm/cifar10.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python3 17 | """Config file for reproducing the results of DDPM on cifar-10.""" 18 | 19 | from configs.default_cifar10_configs import get_default_configs 20 | 21 | 22 | def get_config(): 23 | config = get_default_configs() 24 | 25 | # training 26 | training = config.training 27 | training.sde = 'vpsde' 28 | training.continuous = False 29 | training.reduce_mean = True 30 | 31 | # sampling 32 | sampling = config.sampling 33 | sampling.method = 'pc' 34 | sampling.predictor = 'ancestral_sampling' 35 | sampling.corrector = 'none' 36 | 37 | # data 38 | data = config.data 39 | data.centered = True 40 | 41 | # model 42 | model = config.model 43 | model.name = 'ddpm' 44 | model.scale_by_sigma = False 45 | model.ema_rate = 0.9999 46 | model.normalization = 'GroupNorm' 47 | model.nonlinearity = 'swish' 48 | model.nf = 128 49 | model.ch_mult = (1, 2, 2, 2) 50 | model.num_res_blocks = 2 51 | model.attn_resolutions = (16,) 52 | model.resamp_with_conv = True 53 | model.conditional = True 54 | 55 | return config 56 | -------------------------------------------------------------------------------- /examples/score_sde_jax/configs/vp/ddpm/cifar10_continuous.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python3 17 | """Training DDPM with VP SDE.""" 18 | 19 | from configs.default_cifar10_configs import get_default_configs 20 | 21 | 22 | def get_config(): 23 | config = get_default_configs() 24 | 25 | # training 26 | training = config.training 27 | training.sde = 'vpsde' 28 | training.continuous = True 29 | training.reduce_mean = True 30 | 31 | # sampling 32 | sampling = config.sampling 33 | sampling.method = 'pc' 34 | sampling.predictor = 'euler_maruyama' 35 | sampling.corrector = 'none' 36 | 37 | # data 38 | data = config.data 39 | data.centered = True 40 | 41 | # model 42 | model = config.model 43 | model.name = 'ddpm' 44 | model.scale_by_sigma = False 45 | model.ema_rate = 0.9999 46 | model.normalization = 'GroupNorm' 47 | model.nonlinearity = 'swish' 48 | model.nf = 128 49 | model.ch_mult = (1, 2, 2, 2) 50 | model.num_res_blocks = 2 51 | model.attn_resolutions = (16,) 52 | model.resamp_with_conv = True 53 | model.conditional = True 54 | 55 | return config 56 | -------------------------------------------------------------------------------- /examples/score_sde_jax/configs/vp/ddpm/cifar10_unconditional.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python3 17 | """Training DDPM on CIFAR-10 without explicitly conditioning on time steps. 
(NCSNv2 technique 3)""" 18 | 19 | from configs.default_cifar10_configs import get_default_configs 20 | 21 | 22 | def get_config(): 23 | config = get_default_configs() 24 | 25 | # training 26 | training = config.training 27 | training.sde = 'vpsde' 28 | training.continuous = False 29 | training.reduce_mean = True 30 | 31 | # sampling 32 | sampling = config.sampling 33 | sampling.method = 'pc' 34 | sampling.predictor = 'ancestral_sampling' 35 | sampling.corrector = 'none' 36 | 37 | # data 38 | data = config.data 39 | data.centered = True 40 | 41 | # model 42 | model = config.model 43 | model.name = 'ddpm' 44 | model.scale_by_sigma = False 45 | model.ema_rate = 0.9999 46 | model.normalization = 'GroupNorm' 47 | model.nonlinearity = 'swish' 48 | model.nf = 128 49 | model.ch_mult = (1, 2, 2, 2) 50 | model.num_res_blocks = 2 51 | model.attn_resolutions = (16,) 52 | model.resamp_with_conv = True 53 | model.conditional = False 54 | 55 | return config 56 | -------------------------------------------------------------------------------- /examples/score_sde_jax/models/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /examples/score_sde_jax/requirements.txt: -------------------------------------------------------------------------------- 1 | ml-collections==0.1.0 2 | tensorflow-gan==2.0.0 3 | tensorflow_io 4 | tensorflow_datasets==3.1.0 5 | tensorflow==2.4.0 6 | tensorflow-addons==0.12.0 7 | tensorboard==2.4.0 8 | absl-py==0.10.0 9 | flax==0.3.1 10 | jax==0.2.8 11 | jaxlib==0.1.59 12 | -------------------------------------------------------------------------------- /examples/score_sde_jax/sample.sh: -------------------------------------------------------------------------------- 1 | devices="0" 2 | 3 | steps="10" 4 | eps="1e-3" 5 | skip="logSNR" 6 | method="singlestep" 7 | order="3" 8 | dir="experiments/cifar10_ddpmpp_deep_continuous_steps" 9 | 10 | CUDA_VISIBLE_DEVICES=$devices python main.py --config "configs/vp/cifar10_ddpmpp_deep_continuous.py" --mode "eval" --workdir $dir --config.sampling.eps=$eps --config.sampling.method="dpm_solver" --config.sampling.steps=$steps --config.sampling.skip_type=$skip --config.sampling.dpm_solver_order=$order --config.sampling.dpm_solver_method=$method --config.eval.batch_size=1000 11 | -------------------------------------------------------------------------------- /examples/score_sde_pytorch/.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled python modules. 2 | *.pyc 3 | 4 | # Byte-compiled 5 | _pycache__/ 6 | .cache/ 7 | .idea/ 8 | 9 | # Python egg metadata, regenerated from source files by setuptools. 10 | /*.egg-info 11 | .eggs/ 12 | 13 | # PyPI distribution artifacts. 
14 | build/ 15 | dist/ 16 | 17 | # Tests 18 | .pytest_cache/ 19 | 20 | # Other 21 | *.DS_Store 22 | 23 | experiments 24 | assets/stats -------------------------------------------------------------------------------- /examples/score_sde_pytorch/assets/bedroom.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LuChengTHU/dpm-solver/52bc3fbcd5de56d60917b826b15d2b69460fc2fa/examples/score_sde_pytorch/assets/bedroom.jpeg -------------------------------------------------------------------------------- /examples/score_sde_pytorch/assets/celebahq_256.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LuChengTHU/dpm-solver/52bc3fbcd5de56d60917b826b15d2b69460fc2fa/examples/score_sde_pytorch/assets/celebahq_256.jpg -------------------------------------------------------------------------------- /examples/score_sde_pytorch/assets/church.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LuChengTHU/dpm-solver/52bc3fbcd5de56d60917b826b15d2b69460fc2fa/examples/score_sde_pytorch/assets/church.jpeg -------------------------------------------------------------------------------- /examples/score_sde_pytorch/assets/ffhq_1024.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LuChengTHU/dpm-solver/52bc3fbcd5de56d60917b826b15d2b69460fc2fa/examples/score_sde_pytorch/assets/ffhq_1024.jpeg -------------------------------------------------------------------------------- /examples/score_sde_pytorch/assets/ffhq_256.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LuChengTHU/dpm-solver/52bc3fbcd5de56d60917b826b15d2b69460fc2fa/examples/score_sde_pytorch/assets/ffhq_256.jpg -------------------------------------------------------------------------------- /examples/score_sde_pytorch/assets/ffhq_samples.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LuChengTHU/dpm-solver/52bc3fbcd5de56d60917b826b15d2b69460fc2fa/examples/score_sde_pytorch/assets/ffhq_samples.jpg -------------------------------------------------------------------------------- /examples/score_sde_pytorch/assets/schematic.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LuChengTHU/dpm-solver/52bc3fbcd5de56d60917b826b15d2b69460fc2fa/examples/score_sde_pytorch/assets/schematic.jpg -------------------------------------------------------------------------------- /examples/score_sde_pytorch/configs/subvp/cifar10_ddpm_continuous.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | # Lint as: python3 17 | """Training DDPM with sub-VP SDE.""" 18 | 19 | from configs.default_cifar10_configs import get_default_configs 20 | 21 | 22 | def get_config(): 23 | config = get_default_configs() 24 | 25 | # training 26 | training = config.training 27 | training.sde = 'subvpsde' 28 | training.continuous = True 29 | training.reduce_mean = True 30 | 31 | # sampling 32 | sampling = config.sampling 33 | sampling.method = 'pc' 34 | sampling.predictor = 'euler_maruyama' 35 | sampling.corrector = 'none' 36 | 37 | # data 38 | data = config.data 39 | data.centered = True 40 | 41 | # model 42 | model = config.model 43 | model.name = 'ddpm' 44 | model.scale_by_sigma = False 45 | model.ema_rate = 0.9999 46 | model.normalization = 'GroupNorm' 47 | model.nonlinearity = 'swish' 48 | model.nf = 128 49 | model.ch_mult = (1, 2, 2, 2) 50 | model.num_res_blocks = 2 51 | model.attn_resolutions = (16,) 52 | model.resamp_with_conv = True 53 | model.conditional = True 54 | 55 | return config 56 | -------------------------------------------------------------------------------- /examples/score_sde_pytorch/configs/subvp/cifar10_ddpmpp_continuous.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python3 17 | """Training NCSNv3 on CIFAR-10 with continuous sigmas.""" 18 | 19 | from configs.default_cifar10_configs import get_default_configs 20 | 21 | 22 | def get_config(): 23 | config = get_default_configs() 24 | # training 25 | training = config.training 26 | training.sde = 'subvpsde' 27 | training.continuous = True 28 | training.reduce_mean = True 29 | 30 | # sampling 31 | sampling = config.sampling 32 | sampling.method = 'pc' 33 | sampling.predictor = 'euler_maruyama' 34 | sampling.corrector = 'none' 35 | 36 | # data 37 | data = config.data 38 | data.centered = True 39 | 40 | # model 41 | model = config.model 42 | model.name = 'ncsnpp' 43 | model.scale_by_sigma = False 44 | model.ema_rate = 0.9999 45 | model.normalization = 'GroupNorm' 46 | model.nonlinearity = 'swish' 47 | model.nf = 128 48 | model.ch_mult = (1, 2, 2, 2) 49 | model.num_res_blocks = 4 50 | model.attn_resolutions = (16,) 51 | model.resamp_with_conv = True 52 | model.conditional = True 53 | model.fir = False 54 | model.fir_kernel = [1, 3, 3, 1] 55 | model.skip_rescale = True 56 | model.resblock_type = 'biggan' 57 | model.progressive = 'none' 58 | model.progressive_input = 'none' 59 | model.progressive_combine = 'sum' 60 | model.attention_type = 'ddpm' 61 | model.init_scale = 0. 
62 | model.embedding_type = 'positional' 63 | model.fourier_scale = 16 64 | model.conv_size = 3 65 | 66 | return config 67 | -------------------------------------------------------------------------------- /examples/score_sde_pytorch/configs/subvp/cifar10_ncsnpp_continuous.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python3 17 | """Training NCSN++ on CIFAR-10 with sub-VP SDE.""" 18 | from configs.default_cifar10_configs import get_default_configs 19 | 20 | 21 | def get_config(): 22 | config = get_default_configs() 23 | # training 24 | training = config.training 25 | training.sde = 'subvpsde' 26 | training.continuous = True 27 | training.reduce_mean = True 28 | 29 | # sampling 30 | sampling = config.sampling 31 | sampling.method = 'pc' 32 | sampling.predictor = 'euler_maruyama' 33 | sampling.corrector = 'none' 34 | 35 | # data 36 | data = config.data 37 | data.centered = True 38 | 39 | # model 40 | model = config.model 41 | model.name = 'ncsnpp' 42 | model.scale_by_sigma = False 43 | model.ema_rate = 0.9999 44 | model.normalization = 'GroupNorm' 45 | model.nonlinearity = 'swish' 46 | model.nf = 128 47 | model.ch_mult = (1, 2, 2, 2) 48 | model.num_res_blocks = 4 49 | model.attn_resolutions = (16,) 50 | model.resamp_with_conv = True 51 | model.conditional = True 52 | model.fir = True 53 | model.fir_kernel = [1, 3, 3, 1] 54 | model.skip_rescale = True 55 | model.resblock_type = 'biggan' 56 | model.progressive = 'none' 57 | model.progressive_input = 'residual' 58 | model.progressive_combine = 'sum' 59 | model.attention_type = 'ddpm' 60 | model.embedding_type = 'positional' 61 | model.init_scale = 0. 62 | model.fourier_scale = 16 63 | model.conv_size = 3 64 | 65 | return config 66 | -------------------------------------------------------------------------------- /examples/score_sde_pytorch/configs/ve/bedroom_ncsnpp_continuous.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | # Lint as: python3 17 | """Training NCSN++ on bedroom with VE SDE.""" 18 | 19 | from configs.default_lsun_configs import get_default_configs 20 | 21 | 22 | def get_config(): 23 | config = get_default_configs() 24 | # training 25 | training = config.training 26 | training.sde = 'vesde' 27 | training.continuous = True 28 | 29 | # sampling 30 | sampling = config.sampling 31 | sampling.method = 'pc' 32 | sampling.predictor = 'reverse_diffusion' 33 | sampling.corrector = 'langevin' 34 | 35 | # data 36 | data = config.data 37 | data.category = 'bedroom' 38 | 39 | # model 40 | model = config.model 41 | model.name = 'ncsnpp' 42 | model.scale_by_sigma = True 43 | model.ema_rate = 0.999 44 | model.normalization = 'GroupNorm' 45 | model.nonlinearity = 'swish' 46 | model.nf = 128 47 | model.ch_mult = (1, 1, 2, 2, 2, 2, 2) 48 | model.num_res_blocks = 2 49 | model.attn_resolutions = (16,) 50 | model.resamp_with_conv = True 51 | model.conditional = True 52 | model.fir = True 53 | model.fir_kernel = [1, 3, 3, 1] 54 | model.skip_rescale = True 55 | model.resblock_type = 'biggan' 56 | model.progressive = 'output_skip' 57 | model.progressive_input = 'input_skip' 58 | model.progressive_combine = 'sum' 59 | model.attention_type = 'ddpm' 60 | model.init_scale = 0. 61 | model.fourier_scale = 16 62 | model.conv_size = 3 63 | 64 | return config 65 | -------------------------------------------------------------------------------- /examples/score_sde_pytorch/configs/ve/celeba_ncsnpp.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | # Lint as: python3 17 | """Training NCSN++ on CelebA with SMLD.""" 18 | 19 | from configs.default_celeba_configs import get_default_configs 20 | 21 | 22 | def get_config(): 23 | config = get_default_configs() 24 | # training 25 | training = config.training 26 | training.sde = 'vesde' 27 | training.continuous = False 28 | 29 | # sampling 30 | sampling = config.sampling 31 | sampling.method = 'pc' 32 | sampling.predictor = 'reverse_diffusion' 33 | sampling.corrector = 'langevin' 34 | 35 | # model 36 | model = config.model 37 | model.name = 'ncsnpp' 38 | model.scale_by_sigma = True 39 | model.sigma_begin = 90 40 | model.ema_rate = 0.999 41 | model.normalization = 'GroupNorm' 42 | model.nonlinearity = 'swish' 43 | model.nf = 128 44 | model.ch_mult = (1, 2, 2, 2) 45 | model.num_res_blocks = 4 46 | model.attn_resolutions = (16,) 47 | model.resamp_with_conv = True 48 | model.conditional = True 49 | model.fir = True 50 | model.fir_kernel = [1, 3, 3, 1] 51 | model.skip_rescale = True 52 | model.resblock_type = 'biggan' 53 | model.progressive = 'none' 54 | model.progressive_input = 'residual' 55 | model.progressive_combine = 'sum' 56 | model.attention_type = 'ddpm' 57 | model.init_scale = 0.0 58 | model.conv_size = 3 59 | model.embedding_type = 'positional' 60 | 61 | return config 62 | -------------------------------------------------------------------------------- /examples/score_sde_pytorch/configs/ve/church_ncsnpp_continuous.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python3 17 | """Training NCSN++ on Church with VE SDE.""" 18 | 19 | from configs.default_lsun_configs import get_default_configs 20 | 21 | 22 | def get_config(): 23 | config = get_default_configs() 24 | # training 25 | training = config.training 26 | training.sde = 'vesde' 27 | training.continuous = True 28 | 29 | # sampling 30 | sampling = config.sampling 31 | sampling.method = 'pc' 32 | sampling.predictor = 'reverse_diffusion' 33 | sampling.corrector = 'langevin' 34 | 35 | # data 36 | data = config.data 37 | data.category = 'church_outdoor' 38 | 39 | # model 40 | model = config.model 41 | model.name = 'ncsnpp' 42 | model.sigma_max = 380 43 | model.scale_by_sigma = True 44 | model.ema_rate = 0.999 45 | model.normalization = 'GroupNorm' 46 | model.nonlinearity = 'swish' 47 | model.nf = 128 48 | model.ch_mult = (1, 1, 2, 2, 2, 2, 2) 49 | model.num_res_blocks = 2 50 | model.attn_resolutions = (16,) 51 | model.resamp_with_conv = True 52 | model.conditional = True 53 | model.fir = True 54 | model.fir_kernel = [1, 3, 3, 1] 55 | model.skip_rescale = True 56 | model.resblock_type = 'biggan' 57 | model.progressive = 'output_skip' 58 | model.progressive_input = 'input_skip' 59 | model.progressive_combine = 'sum' 60 | model.attention_type = 'ddpm' 61 | model.init_scale = 0. 
62 | model.fourier_scale = 16 63 | model.conv_size = 3 64 | 65 | return config 66 | -------------------------------------------------------------------------------- /examples/score_sde_pytorch/configs/ve/cifar10_ddpm.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python3 17 | """Train the original DDPM model with SMLD.""" 18 | 19 | from configs.default_cifar10_configs import get_default_configs 20 | 21 | 22 | def get_config(): 23 | config = get_default_configs() 24 | # training 25 | training = config.training 26 | training.sde = 'vesde' 27 | training.continuous = False 28 | 29 | # sampling 30 | sampling = config.sampling 31 | sampling.method = 'pc' 32 | sampling.predictor = 'reverse_diffusion' 33 | sampling.corrector = 'langevin' 34 | 35 | # model 36 | model = config.model 37 | model.name = 'ddpm' 38 | model.scale_by_sigma = True 39 | model.ema_rate = 0.999 40 | model.normalization = 'GroupNorm' 41 | model.nonlinearity = 'swish' 42 | model.nf = 128 43 | model.ch_mult = (1, 2, 2, 2) 44 | model.num_res_blocks = 2 45 | model.attn_resolutions = (16,) 46 | model.resamp_with_conv = True 47 | model.conditional = True 48 | model.conv_size = 3 49 | 50 | return config 51 | -------------------------------------------------------------------------------- /examples/score_sde_pytorch/configs/ve/cifar10_ncsnpp.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | # Lint as: python3 17 | """Training NCSN++ on CIFAR-10 with SMLD.""" 18 | 19 | from configs.default_cifar10_configs import get_default_configs 20 | 21 | 22 | def get_config(): 23 | config = get_default_configs() 24 | # training 25 | training = config.training 26 | training.sde = 'vesde' 27 | training.continuous = False 28 | 29 | # sampling 30 | sampling = config.sampling 31 | sampling.method = 'pc' 32 | sampling.predictor = 'reverse_diffusion' 33 | sampling.corrector = 'langevin' 34 | 35 | # model 36 | model = config.model 37 | model.name = 'ncsnpp' 38 | model.scale_by_sigma = True 39 | model.ema_rate = 0.999 40 | model.normalization = 'GroupNorm' 41 | model.nonlinearity = 'swish' 42 | model.nf = 128 43 | model.ch_mult = (1, 2, 2, 2) 44 | model.num_res_blocks = 4 45 | model.attn_resolutions = (16,) 46 | model.resamp_with_conv = True 47 | model.conditional = True 48 | model.fir = True 49 | model.fir_kernel = [1, 3, 3, 1] 50 | model.skip_rescale = True 51 | model.resblock_type = 'biggan' 52 | model.progressive = 'none' 53 | model.progressive_input = 'residual' 54 | model.progressive_combine = 'sum' 55 | model.attention_type = 'ddpm' 56 | model.init_scale = 0.0 57 | model.embedding_type = 'positional' 58 | model.conv_size = 3 59 | 60 | return config 61 | -------------------------------------------------------------------------------- /examples/score_sde_pytorch/configs/ve/cifar10_ncsnpp_continuous.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python3 17 | """Training NCSN++ on CIFAR-10 with VE SDE.""" 18 | from configs.default_cifar10_configs import get_default_configs 19 | 20 | 21 | def get_config(): 22 | config = get_default_configs() 23 | # training 24 | training = config.training 25 | training.sde = 'vesde' 26 | training.continuous = True 27 | 28 | # sampling 29 | sampling = config.sampling 30 | sampling.method = 'pc' 31 | sampling.predictor = 'reverse_diffusion' 32 | sampling.corrector = 'langevin' 33 | 34 | # model 35 | model = config.model 36 | model.name = 'ncsnpp' 37 | model.scale_by_sigma = True 38 | model.ema_rate = 0.999 39 | model.normalization = 'GroupNorm' 40 | model.nonlinearity = 'swish' 41 | model.nf = 128 42 | model.ch_mult = (1, 2, 2, 2) 43 | model.num_res_blocks = 4 44 | model.attn_resolutions = (16,) 45 | model.resamp_with_conv = True 46 | model.conditional = True 47 | model.fir = True 48 | model.fir_kernel = [1, 3, 3, 1] 49 | model.skip_rescale = True 50 | model.resblock_type = 'biggan' 51 | model.progressive = 'none' 52 | model.progressive_input = 'residual' 53 | model.progressive_combine = 'sum' 54 | model.attention_type = 'ddpm' 55 | model.init_scale = 0. 
56 | model.fourier_scale = 16 57 | model.conv_size = 3 58 | 59 | return config 60 | -------------------------------------------------------------------------------- /examples/score_sde_pytorch/configs/ve/cifar10_ncsnpp_deep_continuous.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python3 17 | """Training NCSN++ on CIFAR-10 with VE SDE.""" 18 | 19 | from configs.default_cifar10_configs import get_default_configs 20 | 21 | 22 | def get_config(): 23 | config = get_default_configs() 24 | # training 25 | training = config.training 26 | training.sde = 'vesde' 27 | training.continuous = True 28 | training.n_iters = 950001 29 | 30 | # sampling 31 | sampling = config.sampling 32 | sampling.method = 'pc' 33 | sampling.predictor = 'reverse_diffusion' 34 | sampling.corrector = 'langevin' 35 | 36 | # model 37 | model = config.model 38 | model.name = 'ncsnpp' 39 | model.fourier_scale = 16 40 | model.scale_by_sigma = True 41 | model.ema_rate = 0.999 42 | model.normalization = 'GroupNorm' 43 | model.nonlinearity = 'swish' 44 | model.nf = 128 45 | model.ch_mult = (1, 2, 2, 2) 46 | model.num_res_blocks = 8 47 | model.attn_resolutions = (16,) 48 | model.resamp_with_conv = True 49 | model.conditional = True 50 | model.fir = True 51 | model.fir_kernel = [1, 3, 3, 1] 52 | model.skip_rescale = True 53 | model.resblock_type = 'biggan' 54 | model.progressive = 'none' 55 | model.progressive_input = 'residual' 56 | model.progressive_combine = 'sum' 57 | model.attention_type = 'ddpm' 58 | model.init_scale = 0.0 59 | model.conv_size = 3 60 | 61 | return config 62 | -------------------------------------------------------------------------------- /examples/score_sde_pytorch/configs/ve/ncsn/celeba.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | # Lint as: python3 17 | """Config file for reproducing NCSNv1 on CelebA.""" 18 | 19 | from configs.default_celeba_configs import get_default_configs 20 | 21 | 22 | def get_config(): 23 | config = get_default_configs() 24 | # training 25 | training = config.training 26 | training.sde = 'vesde' 27 | training.continuous = False 28 | # sampling 29 | sampling = config.sampling 30 | sampling.method = 'pc' 31 | sampling.predictor = 'none' 32 | sampling.corrector = 'ald' 33 | sampling.n_steps_each = 100 34 | sampling.snr = 0.316 35 | # model 36 | model = config.model 37 | model.name = 'ncsn' 38 | model.scale_by_sigma = False 39 | model.sigma_max = 1 40 | model.num_scales = 10 41 | model.ema_rate = 0. 42 | model.normalization = 'InstanceNorm++' 43 | model.nonlinearity = 'elu' 44 | model.nf = 128 45 | model.interpolation = 'bilinear' 46 | # optim 47 | optim = config.optim 48 | optim.weight_decay = 0 49 | optim.optimizer = 'Adam' 50 | optim.lr = 1e-3 51 | optim.beta1 = 0.9 52 | optim.amsgrad = False 53 | optim.eps = 1e-8 54 | optim.warmup = 0 55 | optim.grad_clip = -1. 56 | 57 | return config 58 | -------------------------------------------------------------------------------- /examples/score_sde_pytorch/configs/ve/ncsn/celeba_124.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python3 17 | """Config file for training NCSN with technique 1,2,4 only.""" 18 | 19 | from configs.default_celeba_configs import get_default_configs 20 | 21 | 22 | def get_config(): 23 | config = get_default_configs() 24 | # training 25 | training = config.training 26 | training.sde = 'vesde' 27 | training.continuous = False 28 | # sampling 29 | sampling = config.sampling 30 | sampling.method = 'pc' 31 | sampling.predictor = 'none' 32 | sampling.corrector = 'ald' 33 | sampling.n_steps_each = 5 34 | sampling.snr = 0.128 35 | # model 36 | model = config.model 37 | model.name = 'ncsn' 38 | model.scale_by_sigma = False 39 | model.num_scales = 500 40 | model.ema_rate = 0. 41 | model.normalization = 'InstanceNorm++' 42 | model.nonlinearity = 'elu' 43 | model.nf = 128 44 | model.interpolation = 'bilinear' 45 | # optim 46 | optim = config.optim 47 | optim.weight_decay = 0 48 | optim.optimizer = 'Adam' 49 | optim.lr = 1e-3 50 | optim.beta1 = 0.9 51 | optim.amsgrad = False 52 | optim.eps = 1e-8 53 | optim.warmup = 0 54 | optim.grad_clip = -1. 55 | 56 | return config 57 | -------------------------------------------------------------------------------- /examples/score_sde_pytorch/configs/ve/ncsn/celeba_1245.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python3 17 | """Config file for training NCSN with technique 1245 only.""" 18 | 19 | from configs.default_celeba_configs import get_default_configs 20 | 21 | 22 | def get_config(): 23 | config = get_default_configs() 24 | # training 25 | training = config.training 26 | training.sde = 'vesde' 27 | training.continuous = False 28 | # sampling 29 | sampling = config.sampling 30 | sampling.method = 'pc' 31 | sampling.predictor = 'none' 32 | sampling.corrector = 'ald' 33 | sampling.n_steps_each = 5 34 | sampling.snr = 0.128 35 | # model 36 | model = config.model 37 | model.name = 'ncsn' 38 | model.scale_by_sigma = False 39 | model.num_scales = 500 40 | model.ema_rate = 0.999 41 | model.normalization = 'InstanceNorm++' 42 | model.nonlinearity = 'elu' 43 | model.nf = 128 44 | model.interpolation = 'bilinear' 45 | # optim 46 | optim = config.optim 47 | optim.weight_decay = 0 48 | optim.optimizer = 'Adam' 49 | optim.lr = 1e-3 50 | optim.beta1 = 0.9 51 | optim.amsgrad = False 52 | optim.eps = 1e-8 53 | optim.warmup = 0 54 | optim.grad_clip = -1. 55 | 56 | return config 57 | -------------------------------------------------------------------------------- /examples/score_sde_pytorch/configs/ve/ncsn/celeba_5.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python3 17 | """Config file for training NCSNv1 model with technique 5 only.""" 18 | 19 | from configs.default_celeba_configs import get_default_configs 20 | 21 | 22 | def get_config(): 23 | config = get_default_configs() 24 | # training 25 | training = config.training 26 | training.sde = 'vesde' 27 | training.continuous = False 28 | # sampling 29 | sampling = config.sampling 30 | sampling.method = 'pc' 31 | sampling.predictor = 'none' 32 | sampling.corrector = 'ald' 33 | sampling.n_steps_each = 100 34 | sampling.snr = 0.316 35 | # model 36 | model = config.model 37 | model.name = 'ncsn' 38 | model.scale_by_sigma = False 39 | model.sigma_max = 1. 40 | model.num_scales = 10 41 | model.ema_rate = 0.999 42 | model.normalization = 'InstanceNorm++' 43 | model.nonlinearity = 'elu' 44 | model.nf = 128 45 | model.interpolation = 'bilinear' 46 | # optim 47 | optim = config.optim 48 | optim.weight_decay = 0 49 | optim.optimizer = 'Adam' 50 | optim.lr = 1e-3 51 | optim.beta1 = 0.9 52 | optim.amsgrad = False 53 | optim.eps = 1e-8 54 | optim.warmup = 0 55 | optim.grad_clip = -1. 
56 | 57 | return config 58 | -------------------------------------------------------------------------------- /examples/score_sde_pytorch/configs/ve/ncsn/cifar10.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python3 17 | """Config file for reproducing NCSNv1 on CIFAR-10.""" 18 | 19 | from configs.default_cifar10_configs import get_default_configs 20 | 21 | 22 | def get_config(): 23 | config = get_default_configs() 24 | # training 25 | training = config.training 26 | training.sde = 'vesde' 27 | training.continuous = False 28 | # sampling 29 | sampling = config.sampling 30 | sampling.method = 'pc' 31 | sampling.predictor = 'none' 32 | sampling.corrector = 'ald' 33 | sampling.n_steps_each = 100 34 | sampling.snr = 0.316 35 | # model 36 | model = config.model 37 | model.name = 'ncsn' 38 | model.scale_by_sigma = False 39 | model.sigma_max = 1 40 | model.num_scales = 10 41 | model.ema_rate = 0. 42 | model.normalization = 'InstanceNorm++' 43 | model.nonlinearity = 'elu' 44 | model.nf = 128 45 | model.interpolation = 'bilinear' 46 | # optim 47 | optim = config.optim 48 | optim.weight_decay = 0 49 | optim.optimizer = 'Adam' 50 | optim.lr = 1e-3 51 | optim.beta1 = 0.9 52 | optim.amsgrad = False 53 | optim.eps = 1e-8 54 | optim.warmup = 0 55 | optim.grad_clip = -1. 56 | 57 | return config 58 | -------------------------------------------------------------------------------- /examples/score_sde_pytorch/configs/ve/ncsn/cifar10_124.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | # Lint as: python3 17 | """Config file for training NCSN with technique 1,2,4 only.""" 18 | 19 | from configs.default_cifar10_configs import get_default_configs 20 | 21 | 22 | def get_config(): 23 | config = get_default_configs() 24 | # training 25 | training = config.training 26 | training.sde = 'vesde' 27 | training.continuous = False 28 | # sampling 29 | sampling = config.sampling 30 | sampling.method = 'pc' 31 | sampling.predictor = 'none' 32 | sampling.corrector = 'ald' 33 | sampling.n_steps_each = 5 34 | sampling.snr = 0.176 35 | # model 36 | model = config.model 37 | model.name = 'ncsn' 38 | model.scale_by_sigma = False 39 | model.num_scales = 232 40 | model.ema_rate = 0. 41 | model.normalization = 'InstanceNorm++' 42 | model.nonlinearity = 'elu' 43 | model.nf = 128 44 | model.interpolation = 'bilinear' 45 | # optim 46 | optim = config.optim 47 | optim.weight_decay = 0 48 | optim.optimizer = 'Adam' 49 | optim.lr = 1e-3 50 | optim.beta1 = 0.9 51 | optim.amsgrad = False 52 | optim.eps = 1e-8 53 | optim.warmup = 0 54 | optim.grad_clip = -1. 55 | 56 | return config 57 | -------------------------------------------------------------------------------- /examples/score_sde_pytorch/configs/ve/ncsn/cifar10_1245.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python3 17 | """Config file for training NCSN with technique 1,2,4,5 only.""" 18 | 19 | from configs.default_cifar10_configs import get_default_configs 20 | 21 | 22 | def get_config(): 23 | config = get_default_configs() 24 | # training 25 | training = config.training 26 | training.sde = 'vesde' 27 | training.continuous = False 28 | # shared configs for sample generation 29 | step_size = 0.0000062 30 | n_steps_each = 5 31 | ckpt_id = 300000 32 | final_only = True 33 | noise_removal = False 34 | # sampling 35 | sampling = config.sampling 36 | sampling.method = 'pc' 37 | sampling.predictor = 'none' 38 | sampling.corrector = 'ald' 39 | sampling.n_steps_each = 5 40 | sampling.snr = 0.176 41 | # model 42 | model = config.model 43 | model.name = 'ncsn' 44 | model.scale_by_sigma = False 45 | model.num_scales = 232 46 | model.ema_rate = 0.999 47 | model.normalization = 'InstanceNorm++' 48 | model.nonlinearity = 'elu' 49 | model.nf = 128 50 | model.interpolation = 'bilinear' 51 | # optim 52 | optim = config.optim 53 | optim.weight_decay = 0 54 | optim.optimizer = 'Adam' 55 | optim.lr = 1e-3 56 | optim.beta1 = 0.9 57 | optim.amsgrad = False 58 | optim.eps = 1e-8 59 | optim.warmup = 0 60 | optim.grad_clip = -1. 61 | 62 | return config 63 | -------------------------------------------------------------------------------- /examples/score_sde_pytorch/configs/ve/ncsn/cifar10_5.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python3 17 | """Config file for training NCSN with technique 5 only.""" 18 | 19 | from configs.default_cifar10_configs import get_default_configs 20 | 21 | 22 | def get_config(): 23 | config = get_default_configs() 24 | # training 25 | training = config.training 26 | training.sde = 'vesde' 27 | training.continuous = False 28 | # sampling 29 | sampling = config.sampling 30 | sampling.method = 'pc' 31 | sampling.predictor = 'none' 32 | sampling.corrector = 'ald' 33 | sampling.snr = 0.316 34 | sampling.n_steps_each = 100 35 | # model 36 | model = config.model 37 | model.name = 'ncsn' 38 | model.scale_by_sigma = False 39 | model.sigma_max = 1 40 | model.num_scales = 10 41 | model.ema_rate = 0.999 42 | model.normalization = 'InstanceNorm++' 43 | model.nonlinearity = 'elu' 44 | model.nf = 128 45 | model.interpolation = 'bilinear' 46 | # optim 47 | optim = config.optim 48 | optim.weight_decay = 0 49 | optim.optimizer = 'Adam' 50 | optim.lr = 1e-3 51 | optim.beta1 = 0.9 52 | optim.amsgrad = False 53 | optim.eps = 1e-8 54 | optim.warmup = 0 55 | optim.grad_clip = -1. 56 | 57 | return config 58 | -------------------------------------------------------------------------------- /examples/score_sde_pytorch/configs/ve/ncsnv2/bedroom.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | # Lint as: python3 17 | """Config file for training NCSNv2 on bedroom.""" 18 | 19 | from configs.default_lsun_configs import get_default_configs 20 | 21 | 22 | def get_config(): 23 | config = get_default_configs() 24 | # training 25 | training = config.training 26 | training.batch_size = 128 27 | training.sde = 'vesde' 28 | training.continuous = False 29 | # sampling 30 | sampling = config.sampling 31 | sampling.method = 'pc' 32 | sampling.predictor = 'none' 33 | sampling.corrector = 'ald' 34 | sampling.n_steps_each = 3 35 | sampling.snr = 0.095 36 | # data 37 | data = config.data 38 | data.category = 'bedroom' 39 | data.image_size = 128 40 | # model 41 | model = config.model 42 | model.name = 'ncsnv2_128' 43 | model.scale_by_sigma = True 44 | model.sigma_max = 190 45 | model.num_scales = 1086 46 | model.ema_rate = 0.9999 47 | model.sigma_min = 0.01 48 | model.normalization = 'InstanceNorm++' 49 | model.nonlinearity = 'elu' 50 | model.nf = 128 51 | model.interpolation = 'bilinear' 52 | # optim 53 | optim = config.optim 54 | optim.weight_decay = 0 55 | optim.optimizer = 'Adam' 56 | optim.lr = 1e-4 57 | optim.beta1 = 0.9 58 | optim.amsgrad = False 59 | optim.eps = 1e-8 60 | optim.warmup = 0 61 | optim.grad_clip = -1 62 | 63 | return config 64 | -------------------------------------------------------------------------------- /examples/score_sde_pytorch/configs/ve/ncsnv2/celeba.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python3 17 | """Config file for training NCSNv2 on CelebA.""" 18 | 19 | from configs.default_celeba_configs import get_default_configs 20 | 21 | 22 | def get_config(): 23 | config = get_default_configs() 24 | # training 25 | training = config.training 26 | training.sde = 'vesde' 27 | training.continuous = False 28 | # shared configs for sample generation 29 | step_size = 0.0000033 30 | n_steps_each = 5 31 | ckpt_id = 210000 32 | final_only = True 33 | noise_removal = False 34 | # sampling 35 | sampling = config.sampling 36 | sampling.method = 'pc' 37 | sampling.predictor = 'none' 38 | sampling.corrector = 'ald' 39 | sampling.n_steps_each = 5 40 | sampling.snr = 0.128 41 | # model 42 | model = config.model 43 | model.name = 'ncsnv2_64' 44 | model.scale_by_sigma = True 45 | model.num_scales = 500 46 | model.ema_rate = 0.999 47 | model.normalization = 'InstanceNorm++' 48 | model.nonlinearity = 'elu' 49 | model.nf = 128 50 | model.interpolation = 'bilinear' 51 | # optim 52 | optim = config.optim 53 | optim.weight_decay = 0 54 | optim.optimizer = 'Adam' 55 | optim.lr = 1e-4 56 | optim.beta1 = 0.9 57 | optim.amsgrad = False 58 | optim.eps = 1e-8 59 | optim.warmup = 0 60 | optim.grad_clip = -1.
61 | 62 | return config 63 | -------------------------------------------------------------------------------- /examples/score_sde_pytorch/configs/ve/ncsnv2/cifar10.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python3 17 | """Config file for training NCSNv2 on CIFAR-10.""" 18 | 19 | from configs.default_cifar10_configs import get_default_configs 20 | 21 | 22 | def get_config(): 23 | config = get_default_configs() 24 | # training 25 | training = config.training 26 | training.sde = 'vesde' 27 | training.continuous = False 28 | # sampling 29 | sampling = config.sampling 30 | sampling.method = 'pc' 31 | sampling.predictor = 'none' 32 | sampling.corrector = 'ald' 33 | sampling.n_steps_each = 5 34 | sampling.snr = 0.176 35 | # model 36 | model = config.model 37 | model.name = 'ncsnv2_64' 38 | model.scale_by_sigma = True 39 | model.num_scales = 232 40 | model.ema_rate = 0.999 41 | model.normalization = 'InstanceNorm++' 42 | model.nonlinearity = 'elu' 43 | model.nf = 128 44 | model.interpolation = 'bilinear' 45 | # optim 46 | optim = config.optim 47 | optim.weight_decay = 0 48 | optim.optimizer = 'Adam' 49 | optim.lr = 1e-4 50 | optim.beta1 = 0.9 51 | optim.amsgrad = False 52 | optim.eps = 1e-8 53 | optim.warmup = 0 54 | optim.grad_clip = -1. 55 | 56 | return config 57 | -------------------------------------------------------------------------------- /examples/score_sde_pytorch/configs/vp/cifar10_ddpmpp.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | # Lint as: python3 17 | """Training DDPM++ on CIFAR-10 with VP SDE.""" 18 | 19 | from configs.default_cifar10_configs import get_default_configs 20 | 21 | 22 | def get_config(): 23 | config = get_default_configs() 24 | # training 25 | training = config.training 26 | training.sde = 'vpsde' 27 | training.continuous = False 28 | training.reduce_mean = True 29 | 30 | # sampling 31 | sampling = config.sampling 32 | sampling.method = 'pc' 33 | sampling.predictor = 'ancestral_sampling' 34 | sampling.corrector = 'none' 35 | 36 | # data 37 | data = config.data 38 | data.centered = True 39 | 40 | # model 41 | model = config.model 42 | model.name = 'ncsnpp' 43 | model.scale_by_sigma = False 44 | model.ema_rate = 0.9999 45 | model.normalization = 'GroupNorm' 46 | model.nonlinearity = 'swish' 47 | model.nf = 128 48 | model.ch_mult = (1, 2, 2, 2) 49 | model.num_res_blocks = 4 50 | model.attn_resolutions = (16,) 51 | model.resamp_with_conv = True 52 | model.conditional = True 53 | model.fir = False 54 | model.fir_kernel = [1, 3, 3, 1] 55 | model.skip_rescale = True 56 | model.resblock_type = 'biggan' 57 | model.progressive = 'none' 58 | model.progressive_input = 'none' 59 | model.progressive_combine = 'sum' 60 | model.attention_type = 'ddpm' 61 | model.init_scale = 0. 62 | model.embedding_type = 'positional' 63 | model.fourier_scale = 16 64 | model.conv_size = 3 65 | 66 | return config 67 | -------------------------------------------------------------------------------- /examples/score_sde_pytorch/configs/vp/cifar10_ddpmpp_continuous.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License.
15 | 16 | # Lint as: python3 17 | """Training NCSNv3 on CIFAR-10 with continuous sigmas.""" 18 | 19 | from configs.default_cifar10_configs import get_default_configs 20 | 21 | 22 | def get_config(): 23 | config = get_default_configs() 24 | # training 25 | training = config.training 26 | training.sde = 'vpsde' 27 | training.continuous = True 28 | training.reduce_mean = True 29 | 30 | # sampling 31 | sampling = config.sampling 32 | sampling.method = 'pc' 33 | sampling.predictor = 'euler_maruyama' 34 | sampling.corrector = 'none' 35 | 36 | # data 37 | data = config.data 38 | data.centered = True 39 | 40 | # model 41 | model = config.model 42 | model.name = 'ncsnpp' 43 | model.scale_by_sigma = False 44 | model.ema_rate = 0.9999 45 | model.normalization = 'GroupNorm' 46 | model.nonlinearity = 'swish' 47 | model.nf = 128 48 | model.ch_mult = (1, 2, 2, 2) 49 | model.num_res_blocks = 4 50 | model.attn_resolutions = (16,) 51 | model.resamp_with_conv = True 52 | model.conditional = True 53 | model.fir = False 54 | model.fir_kernel = [1, 3, 3, 1] 55 | model.skip_rescale = True 56 | model.resblock_type = 'biggan' 57 | model.progressive = 'none' 58 | model.progressive_input = 'none' 59 | model.progressive_combine = 'sum' 60 | model.attention_type = 'ddpm' 61 | model.init_scale = 0. 62 | model.embedding_type = 'positional' 63 | model.fourier_scale = 16 64 | model.conv_size = 3 65 | 66 | return config 67 | -------------------------------------------------------------------------------- /examples/score_sde_pytorch/configs/vp/cifar10_ncsnpp.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | # Lint as: python3 17 | """Training NCSN++ on CIFAR-10 with DDPM.""" 18 | 19 | from configs.default_cifar10_configs import get_default_configs 20 | 21 | 22 | def get_config(): 23 | config = get_default_configs() 24 | # training 25 | training = config.training 26 | training.sde = 'vpsde' 27 | training.continuous = False 28 | training.reduce_mean = True 29 | 30 | # sampling 31 | sampling = config.sampling 32 | sampling.method = 'pc' 33 | sampling.predictor = 'reverse_diffusion' 34 | sampling.corrector = 'none' 35 | 36 | # data 37 | data = config.data 38 | data.centered = True 39 | 40 | # model 41 | model = config.model 42 | model.name = 'ncsnpp' 43 | model.scale_by_sigma = False 44 | model.ema_rate = 0.9999 45 | model.normalization = 'GroupNorm' 46 | model.nonlinearity = 'swish' 47 | model.nf = 128 48 | model.ch_mult = (1, 2, 2, 2) 49 | model.num_res_blocks = 4 50 | model.attn_resolutions = (16,) 51 | model.resamp_with_conv = True 52 | model.conditional = True 53 | model.fir = True 54 | model.fir_kernel = [1, 3, 3, 1] 55 | model.skip_rescale = True 56 | model.resblock_type = 'biggan' 57 | model.progressive = 'none' 58 | model.progressive_input = 'residual' 59 | model.progressive_combine = 'sum' 60 | model.attention_type = 'ddpm' 61 | model.init_scale = 0.0 62 | model.embedding_type = 'positional' 63 | model.conv_size = 3 64 | 65 | return config 66 | -------------------------------------------------------------------------------- /examples/score_sde_pytorch/configs/vp/cifar10_ncsnpp_continuous.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | # Lint as: python3 17 | """Training NCSN++ on CIFAR-10 with VP SDE.""" 18 | from configs.default_cifar10_configs import get_default_configs 19 | 20 | 21 | def get_config(): 22 | config = get_default_configs() 23 | # training 24 | training = config.training 25 | training.sde = 'vpsde' 26 | training.continuous = True 27 | training.reduce_mean = True 28 | 29 | # sampling 30 | sampling = config.sampling 31 | sampling.method = 'pc' 32 | sampling.predictor = 'euler_maruyama' 33 | sampling.corrector = 'none' 34 | 35 | # data 36 | data = config.data 37 | data.centered = True 38 | 39 | # model 40 | model = config.model 41 | model.name = 'ncsnpp' 42 | model.scale_by_sigma = False 43 | model.ema_rate = 0.9999 44 | model.normalization = 'GroupNorm' 45 | model.nonlinearity = 'swish' 46 | model.nf = 128 47 | model.ch_mult = (1, 2, 2, 2) 48 | model.num_res_blocks = 4 49 | model.attn_resolutions = (16,) 50 | model.resamp_with_conv = True 51 | model.conditional = True 52 | model.fir = True 53 | model.fir_kernel = [1, 3, 3, 1] 54 | model.skip_rescale = True 55 | model.resblock_type = 'biggan' 56 | model.progressive = 'none' 57 | model.progressive_input = 'residual' 58 | model.progressive_combine = 'sum' 59 | model.attention_type = 'ddpm' 60 | model.embedding_type = 'positional' 61 | model.init_scale = 0. 62 | model.fourier_scale = 16 63 | model.conv_size = 3 64 | 65 | return config 66 | -------------------------------------------------------------------------------- /examples/score_sde_pytorch/configs/vp/ddpm/bedroom.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | # Lint as: python3 17 | """Config file for reproducing the results of DDPM on bedrooms.""" 18 | 19 | from configs.default_lsun_configs import get_default_configs 20 | 21 | 22 | def get_config(): 23 | config = get_default_configs() 24 | 25 | # training 26 | training = config.training 27 | training.sde = 'vpsde' 28 | training.continuous = False 29 | training.reduce_mean = True 30 | 31 | # sampling 32 | sampling = config.sampling 33 | sampling.method = 'pc' 34 | sampling.predictor = 'ancestral_sampling' 35 | sampling.corrector = 'none' 36 | 37 | # data 38 | data = config.data 39 | data.category = 'bedroom' 40 | data.centered = True 41 | 42 | # model 43 | model = config.model 44 | model.name = 'ddpm' 45 | model.scale_by_sigma = False 46 | model.num_scales = 1000 47 | model.ema_rate = 0.9999 48 | model.normalization = 'GroupNorm' 49 | model.nonlinearity = 'swish' 50 | model.nf = 128 51 | model.ch_mult = (1, 1, 2, 2, 4, 4) 52 | model.num_res_blocks = 2 53 | model.attn_resolutions = (16,) 54 | model.resamp_with_conv = True 55 | model.conditional = True 56 | 57 | # optim 58 | optim = config.optim 59 | optim.lr = 2e-5 60 | 61 | return config 62 | -------------------------------------------------------------------------------- /examples/score_sde_pytorch/configs/vp/ddpm/celebahq.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | # Lint as: python3 17 | """Config file for reproducing the results of DDPM on CelebA-HQ.""" 18 | 19 | from configs.default_lsun_configs import get_default_configs 20 | 21 | 22 | def get_config(): 23 | config = get_default_configs() 24 | 25 | # training 26 | training = config.training 27 | training.sde = 'vpsde' 28 | training.continuous = False 29 | training.reduce_mean = True 30 | 31 | # sampling 32 | sampling = config.sampling 33 | sampling.method = 'pc' 34 | sampling.predictor = 'ancestral_sampling' 35 | sampling.corrector = 'none' 36 | 37 | # data 38 | data = config.data 39 | data.dataset = 'CelebAHQ' 40 | data.centered = True 41 | data.tfrecords_path = '/atlas/u/yangsong/celeba_hq/-r10.tfrecords' 42 | data.image_size = 256 43 | 44 | # model 45 | model = config.model 46 | model.name = 'ddpm' 47 | model.scale_by_sigma = False 48 | model.num_scales = 1000 49 | model.ema_rate = 0.9999 50 | model.normalization = 'GroupNorm' 51 | model.nonlinearity = 'swish' 52 | model.nf = 128 53 | model.ch_mult = (1, 1, 2, 2, 4, 4) 54 | model.num_res_blocks = 2 55 | model.attn_resolutions = (16,) 56 | model.resamp_with_conv = True 57 | model.conditional = True 58 | 59 | # optim 60 | optim = config.optim 61 | optim.lr = 2e-5 62 | 63 | return config 64 | -------------------------------------------------------------------------------- /examples/score_sde_pytorch/configs/vp/ddpm/church.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License.
15 | 16 | # Lint as: python3 17 | """Config file for reproducing the results of DDPM on church_outdoor.""" 18 | 19 | from configs.default_lsun_configs import get_default_configs 20 | 21 | 22 | def get_config(): 23 | config = get_default_configs() 24 | 25 | # training 26 | training = config.training 27 | training.sde = 'vpsde' 28 | training.continuous = False 29 | training.reduce_mean = True 30 | 31 | # sampling 32 | sampling = config.sampling 33 | sampling.method = 'pc' 34 | sampling.predictor = 'ancestral_sampling' 35 | sampling.corrector = 'none' 36 | 37 | # data 38 | data = config.data 39 | data.category = 'church_outdoor' 40 | data.centered = True 41 | 42 | # model 43 | model = config.model 44 | model.name = 'ddpm' 45 | model.scale_by_sigma = False 46 | model.num_scales = 1000 47 | model.ema_rate = 0.9999 48 | model.normalization = 'GroupNorm' 49 | model.nonlinearity = 'swish' 50 | model.nf = 128 51 | model.ch_mult = (1, 1, 2, 2, 4, 4) 52 | model.num_res_blocks = 2 53 | model.attn_resolutions = (16,) 54 | model.resamp_with_conv = True 55 | model.conditional = True 56 | 57 | # optim 58 | optim = config.optim 59 | optim.lr = 2e-5 60 | 61 | return config 62 | -------------------------------------------------------------------------------- /examples/score_sde_pytorch/configs/vp/ddpm/cifar10.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python3 17 | """Config file for reproducing the results of DDPM on cifar-10.""" 18 | 19 | from configs.default_cifar10_configs import get_default_configs 20 | 21 | 22 | def get_config(): 23 | config = get_default_configs() 24 | 25 | # training 26 | training = config.training 27 | training.sde = 'vpsde' 28 | training.continuous = False 29 | training.reduce_mean = True 30 | 31 | # sampling 32 | sampling = config.sampling 33 | sampling.method = 'pc' 34 | sampling.predictor = 'ancestral_sampling' 35 | sampling.corrector = 'none' 36 | 37 | # data 38 | data = config.data 39 | data.centered = True 40 | 41 | # model 42 | model = config.model 43 | model.name = 'ddpm' 44 | model.scale_by_sigma = False 45 | model.ema_rate = 0.9999 46 | model.normalization = 'GroupNorm' 47 | model.nonlinearity = 'swish' 48 | model.nf = 128 49 | model.ch_mult = (1, 2, 2, 2) 50 | model.num_res_blocks = 2 51 | model.attn_resolutions = (16,) 52 | model.resamp_with_conv = True 53 | model.conditional = True 54 | 55 | return config 56 | -------------------------------------------------------------------------------- /examples/score_sde_pytorch/configs/vp/ddpm/cifar10_continuous.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python3 17 | """Training DDPM with VP SDE.""" 18 | 19 | from configs.default_cifar10_configs import get_default_configs 20 | 21 | 22 | def get_config(): 23 | config = get_default_configs() 24 | 25 | # training 26 | training = config.training 27 | training.sde = 'vpsde' 28 | training.continuous = True 29 | training.reduce_mean = True 30 | 31 | # sampling 32 | sampling = config.sampling 33 | sampling.method = 'pc' 34 | sampling.predictor = 'euler_maruyama' 35 | sampling.corrector = 'none' 36 | 37 | # data 38 | data = config.data 39 | data.centered = True 40 | 41 | # model 42 | model = config.model 43 | model.name = 'ddpm' 44 | model.scale_by_sigma = False 45 | model.ema_rate = 0.9999 46 | model.normalization = 'GroupNorm' 47 | model.nonlinearity = 'swish' 48 | model.nf = 128 49 | model.ch_mult = (1, 2, 2, 2) 50 | model.num_res_blocks = 2 51 | model.attn_resolutions = (16,) 52 | model.resamp_with_conv = True 53 | model.conditional = True 54 | 55 | return config 56 | -------------------------------------------------------------------------------- /examples/score_sde_pytorch/configs/vp/ddpm/cifar10_unconditional.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python3 17 | """Training DDPM on CIFAR-10 without explicitly conditioning on time steps. 
(NCSNv2 technique 3)""" 18 | 19 | from configs.default_cifar10_configs import get_default_configs 20 | 21 | 22 | def get_config(): 23 | config = get_default_configs() 24 | 25 | # training 26 | training = config.training 27 | training.sde = 'vpsde' 28 | training.continuous = False 29 | training.reduce_mean = True 30 | 31 | # sampling 32 | sampling = config.sampling 33 | sampling.method = 'pc' 34 | sampling.predictor = 'ancestral_sampling' 35 | sampling.corrector = 'none' 36 | 37 | # data 38 | data = config.data 39 | data.centered = True 40 | 41 | # model 42 | model = config.model 43 | model.name = 'ddpm' 44 | model.scale_by_sigma = False 45 | model.ema_rate = 0.9999 46 | model.normalization = 'GroupNorm' 47 | model.nonlinearity = 'swish' 48 | model.nf = 128 49 | model.ch_mult = (1, 2, 2, 2) 50 | model.num_res_blocks = 2 51 | model.attn_resolutions = (16,) 52 | model.resamp_with_conv = True 53 | model.conditional = False 54 | 55 | return config 56 | -------------------------------------------------------------------------------- /examples/score_sde_pytorch/debug.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | import matplotlib.pyplot as plt 3 | import io 4 | import csv 5 | import numpy as np 6 | import pandas as pd 7 | import seaborn as sns 8 | import matplotlib 9 | import importlib 10 | import os 11 | import functools 12 | import itertools 13 | import torch 14 | 15 | import torch.nn as nn 16 | import numpy as np 17 | import tensorflow as tf 18 | import tensorflow_datasets as tfds 19 | import tensorflow_gan as tfgan 20 | import tqdm 21 | import io 22 | import inspect 23 | sns.set(font_scale=2) 24 | sns.set(style="whitegrid") 25 | 26 | import models 27 | from models import utils as mutils 28 | from models import ncsnv2 29 | from models import ncsnpp 30 | from models import ddpm as ddpm_model 31 | from models import layerspp 32 | from models import layers 33 | from models import normalization 34 | 35 | #from configs.ncsnpp import cifar10_continuous_ve as configs 36 | from configs.ddpm import cifar10_continuous_vp as configs 37 | config = configs.get_config() 38 | 39 | checkpoint = torch.load('exp/ddpm_continuous_vp.pth') 40 | 41 | #score_model = ncsnpp.NCSNpp(config) 42 | score_model = ddpm_model.DDPM(config) 43 | score_model.load_state_dict(checkpoint) 44 | score_model = score_model.eval() 45 | x = torch.ones(8, 3, 32, 32) 46 | y = torch.tensor([1] * 8) 47 | breakpoint() 48 | with torch.no_grad(): 49 | score = score_model(x, y) -------------------------------------------------------------------------------- /examples/score_sde_pytorch/models/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | -------------------------------------------------------------------------------- /examples/score_sde_pytorch/op/__init__.py: -------------------------------------------------------------------------------- 1 | from .fused_act import FusedLeakyReLU, fused_leaky_relu 2 | from .upfirdn2d import upfirdn2d 3 | -------------------------------------------------------------------------------- /examples/score_sde_pytorch/op/fused_bias_act.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | 3 | 4 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 5 | int act, int grad, float alpha, float scale); 6 | 7 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 8 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 9 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 10 | 11 | torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 12 | int act, int grad, float alpha, float scale) { 13 | CHECK_CUDA(input); 14 | CHECK_CUDA(bias); 15 | 16 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); 17 | } 18 | 19 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 20 | m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); 21 | } -------------------------------------------------------------------------------- /examples/score_sde_pytorch/op/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | 3 | 4 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, 5 | int up_x, int up_y, int down_x, int down_y, 6 | int pad_x0, int pad_x1, int pad_y0, int pad_y1); 7 | 8 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 9 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 10 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 11 | 12 | torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel, 13 | int up_x, int up_y, int down_x, int down_y, 14 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) { 15 | CHECK_CUDA(input); 16 | CHECK_CUDA(kernel); 17 | 18 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1); 19 | } 20 | 21 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 22 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); 23 | } -------------------------------------------------------------------------------- /examples/score_sde_pytorch/requirements.txt: -------------------------------------------------------------------------------- 1 | ml-collections==0.1.0 2 | tensorflow-gan==2.1.0 3 | tensorflow_io==0.22.0 4 | tensorflow_datasets==4.4.0 5 | tensorflow==2.7.0 6 | tensorflow-addons==0.15.0 7 | tensorboard==2.8.0 8 | tensorflow-probability==0.15.0 9 | absl-py==0.10.0 10 | torch>=1.7.0 11 | torchvision 12 | ninja 13 | -------------------------------------------------------------------------------- /examples/score_sde_pytorch/sample.sh: -------------------------------------------------------------------------------- 1 | devices="0" 2 | 3 | steps="10" 4 | eps="1e-3" 5 | skip="logSNR" 6 | method="singlestep" 7 | order="3" 8 | dir="experiments/cifar10_ddpmpp_deep_continuous_steps" 9 | 10 | CUDA_VISIBLE_DEVICES=$devices python main.py --config "configs/vp/cifar10_ddpmpp_deep_continuous.py" --mode "eval" --workdir $dir
--config.sampling.eps=$eps --config.sampling.method="dpm_solver" --config.sampling.steps=$steps --config.sampling.skip_type=$skip --config.sampling.dpm_solver_order=$order --config.sampling.dpm_solver_method=$method --config.eval.batch_size=1000 11 | -------------------------------------------------------------------------------- /examples/score_sde_pytorch/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import tensorflow as tf 3 | import os 4 | import logging 5 | 6 | 7 | def restore_checkpoint(ckpt_dir, state, device): 8 | if not tf.io.gfile.exists(ckpt_dir): 9 | tf.io.gfile.makedirs(os.path.dirname(ckpt_dir)) 10 | logging.warning(f"No checkpoint found at {ckpt_dir}. " 11 | f"Returned the same state as input") 12 | return state 13 | else: 14 | loaded_state = torch.load(ckpt_dir, map_location=device) 15 | state['optimizer'].load_state_dict(loaded_state['optimizer']) 16 | state['model'].load_state_dict(loaded_state['model'], strict=False) 17 | state['ema'].load_state_dict(loaded_state['ema']) 18 | state['step'] = loaded_state['step'] 19 | return state 20 | 21 | 22 | def save_checkpoint(ckpt_dir, state): 23 | saved_state = { 24 | 'optimizer': state['optimizer'].state_dict(), 25 | 'model': state['model'].state_dict(), 26 | 'ema': state['ema'].state_dict(), 27 | 'step': state['step'] 28 | } 29 | torch.save(saved_state, ckpt_dir) -------------------------------------------------------------------------------- /examples/stable-diffusion/assets/a-painting-of-a-fire.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LuChengTHU/dpm-solver/52bc3fbcd5de56d60917b826b15d2b69460fc2fa/examples/stable-diffusion/assets/a-painting-of-a-fire.png -------------------------------------------------------------------------------- /examples/stable-diffusion/assets/a-photograph-of-a-fire.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LuChengTHU/dpm-solver/52bc3fbcd5de56d60917b826b15d2b69460fc2fa/examples/stable-diffusion/assets/a-photograph-of-a-fire.png -------------------------------------------------------------------------------- /examples/stable-diffusion/assets/a-shirt-with-a-fire-printed-on-it.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LuChengTHU/dpm-solver/52bc3fbcd5de56d60917b826b15d2b69460fc2fa/examples/stable-diffusion/assets/a-shirt-with-a-fire-printed-on-it.png -------------------------------------------------------------------------------- /examples/stable-diffusion/assets/a-shirt-with-the-inscription-'fire'.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LuChengTHU/dpm-solver/52bc3fbcd5de56d60917b826b15d2b69460fc2fa/examples/stable-diffusion/assets/a-shirt-with-the-inscription-'fire'.png -------------------------------------------------------------------------------- /examples/stable-diffusion/assets/a-watercolor-painting-of-a-fire.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LuChengTHU/dpm-solver/52bc3fbcd5de56d60917b826b15d2b69460fc2fa/examples/stable-diffusion/assets/a-watercolor-painting-of-a-fire.png -------------------------------------------------------------------------------- /examples/stable-diffusion/assets/birdhouse.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/LuChengTHU/dpm-solver/52bc3fbcd5de56d60917b826b15d2b69460fc2fa/examples/stable-diffusion/assets/birdhouse.png -------------------------------------------------------------------------------- /examples/stable-diffusion/assets/fire.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LuChengTHU/dpm-solver/52bc3fbcd5de56d60917b826b15d2b69460fc2fa/examples/stable-diffusion/assets/fire.png -------------------------------------------------------------------------------- /examples/stable-diffusion/assets/inpainting.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LuChengTHU/dpm-solver/52bc3fbcd5de56d60917b826b15d2b69460fc2fa/examples/stable-diffusion/assets/inpainting.png -------------------------------------------------------------------------------- /examples/stable-diffusion/assets/modelfigure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LuChengTHU/dpm-solver/52bc3fbcd5de56d60917b826b15d2b69460fc2fa/examples/stable-diffusion/assets/modelfigure.png -------------------------------------------------------------------------------- /examples/stable-diffusion/assets/rdm-preview.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LuChengTHU/dpm-solver/52bc3fbcd5de56d60917b826b15d2b69460fc2fa/examples/stable-diffusion/assets/rdm-preview.jpg -------------------------------------------------------------------------------- /examples/stable-diffusion/assets/reconstruction1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LuChengTHU/dpm-solver/52bc3fbcd5de56d60917b826b15d2b69460fc2fa/examples/stable-diffusion/assets/reconstruction1.png -------------------------------------------------------------------------------- /examples/stable-diffusion/assets/reconstruction2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LuChengTHU/dpm-solver/52bc3fbcd5de56d60917b826b15d2b69460fc2fa/examples/stable-diffusion/assets/reconstruction2.png -------------------------------------------------------------------------------- /examples/stable-diffusion/assets/results.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LuChengTHU/dpm-solver/52bc3fbcd5de56d60917b826b15d2b69460fc2fa/examples/stable-diffusion/assets/results.gif -------------------------------------------------------------------------------- /examples/stable-diffusion/assets/rick.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LuChengTHU/dpm-solver/52bc3fbcd5de56d60917b826b15d2b69460fc2fa/examples/stable-diffusion/assets/rick.jpeg -------------------------------------------------------------------------------- /examples/stable-diffusion/assets/stable-samples/img2img/mountains-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LuChengTHU/dpm-solver/52bc3fbcd5de56d60917b826b15d2b69460fc2fa/examples/stable-diffusion/assets/stable-samples/img2img/mountains-1.png 
-------------------------------------------------------------------------------- /examples/stable-diffusion/assets/stable-samples/img2img/mountains-2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LuChengTHU/dpm-solver/52bc3fbcd5de56d60917b826b15d2b69460fc2fa/examples/stable-diffusion/assets/stable-samples/img2img/mountains-2.png -------------------------------------------------------------------------------- /examples/stable-diffusion/assets/stable-samples/img2img/mountains-3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LuChengTHU/dpm-solver/52bc3fbcd5de56d60917b826b15d2b69460fc2fa/examples/stable-diffusion/assets/stable-samples/img2img/mountains-3.png -------------------------------------------------------------------------------- /examples/stable-diffusion/assets/stable-samples/img2img/sketch-mountains-input.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LuChengTHU/dpm-solver/52bc3fbcd5de56d60917b826b15d2b69460fc2fa/examples/stable-diffusion/assets/stable-samples/img2img/sketch-mountains-input.jpg -------------------------------------------------------------------------------- /examples/stable-diffusion/assets/stable-samples/img2img/upscaling-in.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LuChengTHU/dpm-solver/52bc3fbcd5de56d60917b826b15d2b69460fc2fa/examples/stable-diffusion/assets/stable-samples/img2img/upscaling-in.png -------------------------------------------------------------------------------- /examples/stable-diffusion/assets/stable-samples/img2img/upscaling-out.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LuChengTHU/dpm-solver/52bc3fbcd5de56d60917b826b15d2b69460fc2fa/examples/stable-diffusion/assets/stable-samples/img2img/upscaling-out.png -------------------------------------------------------------------------------- /examples/stable-diffusion/assets/stable-samples/txt2img/000002025.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LuChengTHU/dpm-solver/52bc3fbcd5de56d60917b826b15d2b69460fc2fa/examples/stable-diffusion/assets/stable-samples/txt2img/000002025.png -------------------------------------------------------------------------------- /examples/stable-diffusion/assets/stable-samples/txt2img/000002035.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LuChengTHU/dpm-solver/52bc3fbcd5de56d60917b826b15d2b69460fc2fa/examples/stable-diffusion/assets/stable-samples/txt2img/000002035.png -------------------------------------------------------------------------------- /examples/stable-diffusion/assets/stable-samples/txt2img/merged-0005.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LuChengTHU/dpm-solver/52bc3fbcd5de56d60917b826b15d2b69460fc2fa/examples/stable-diffusion/assets/stable-samples/txt2img/merged-0005.png -------------------------------------------------------------------------------- /examples/stable-diffusion/assets/stable-samples/txt2img/merged-0006.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/LuChengTHU/dpm-solver/52bc3fbcd5de56d60917b826b15d2b69460fc2fa/examples/stable-diffusion/assets/stable-samples/txt2img/merged-0006.png -------------------------------------------------------------------------------- /examples/stable-diffusion/assets/stable-samples/txt2img/merged-0007.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LuChengTHU/dpm-solver/52bc3fbcd5de56d60917b826b15d2b69460fc2fa/examples/stable-diffusion/assets/stable-samples/txt2img/merged-0007.png -------------------------------------------------------------------------------- /examples/stable-diffusion/assets/the-earth-is-on-fire,-oil-on-canvas.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LuChengTHU/dpm-solver/52bc3fbcd5de56d60917b826b15d2b69460fc2fa/examples/stable-diffusion/assets/the-earth-is-on-fire,-oil-on-canvas.png -------------------------------------------------------------------------------- /examples/stable-diffusion/assets/txt2img-convsample.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LuChengTHU/dpm-solver/52bc3fbcd5de56d60917b826b15d2b69460fc2fa/examples/stable-diffusion/assets/txt2img-convsample.png -------------------------------------------------------------------------------- /examples/stable-diffusion/assets/txt2img-preview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LuChengTHU/dpm-solver/52bc3fbcd5de56d60917b826b15d2b69460fc2fa/examples/stable-diffusion/assets/txt2img-preview.png -------------------------------------------------------------------------------- /examples/stable-diffusion/assets/v1-variants-scores.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LuChengTHU/dpm-solver/52bc3fbcd5de56d60917b826b15d2b69460fc2fa/examples/stable-diffusion/assets/v1-variants-scores.jpg -------------------------------------------------------------------------------- /examples/stable-diffusion/configs/autoencoder/autoencoder_kl_16x16x16.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-6 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: "val/rec_loss" 6 | embed_dim: 16 7 | lossconfig: 8 | target: ldm.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 0.000001 12 | disc_weight: 0.5 13 | 14 | ddconfig: 15 | double_z: True 16 | z_channels: 16 17 | resolution: 256 18 | in_channels: 3 19 | out_ch: 3 20 | ch: 128 21 | ch_mult: [ 1,1,2,2,4] # num_down = len(ch_mult)-1 22 | num_res_blocks: 2 23 | attn_resolutions: [16] 24 | dropout: 0.0 25 | 26 | 27 | data: 28 | target: main.DataModuleFromConfig 29 | params: 30 | batch_size: 12 31 | wrap: True 32 | train: 33 | target: ldm.data.imagenet.ImageNetSRTrain 34 | params: 35 | size: 256 36 | degradation: pil_nearest 37 | validation: 38 | target: ldm.data.imagenet.ImageNetSRValidation 39 | params: 40 | size: 256 41 | degradation: pil_nearest 42 | 43 | lightning: 44 | callbacks: 45 | image_logger: 46 | target: main.ImageLogger 47 | params: 48 | batch_frequency: 1000 49 | max_images: 8 50 | increase_log_steps: True 51 | 52 | trainer: 53 | benchmark: True 54 | accumulate_grad_batches: 2 55 | 
-------------------------------------------------------------------------------- /examples/stable-diffusion/configs/autoencoder/autoencoder_kl_32x32x4.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-6 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: "val/rec_loss" 6 | embed_dim: 4 7 | lossconfig: 8 | target: ldm.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 0.000001 12 | disc_weight: 0.5 13 | 14 | ddconfig: 15 | double_z: True 16 | z_channels: 4 17 | resolution: 256 18 | in_channels: 3 19 | out_ch: 3 20 | ch: 128 21 | ch_mult: [ 1,2,4,4 ] # num_down = len(ch_mult)-1 22 | num_res_blocks: 2 23 | attn_resolutions: [ ] 24 | dropout: 0.0 25 | 26 | data: 27 | target: main.DataModuleFromConfig 28 | params: 29 | batch_size: 12 30 | wrap: True 31 | train: 32 | target: ldm.data.imagenet.ImageNetSRTrain 33 | params: 34 | size: 256 35 | degradation: pil_nearest 36 | validation: 37 | target: ldm.data.imagenet.ImageNetSRValidation 38 | params: 39 | size: 256 40 | degradation: pil_nearest 41 | 42 | lightning: 43 | callbacks: 44 | image_logger: 45 | target: main.ImageLogger 46 | params: 47 | batch_frequency: 1000 48 | max_images: 8 49 | increase_log_steps: True 50 | 51 | trainer: 52 | benchmark: True 53 | accumulate_grad_batches: 2 54 | -------------------------------------------------------------------------------- /examples/stable-diffusion/configs/autoencoder/autoencoder_kl_64x64x3.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-6 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: "val/rec_loss" 6 | embed_dim: 3 7 | lossconfig: 8 | target: ldm.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 0.000001 12 | disc_weight: 0.5 13 | 14 | ddconfig: 15 | double_z: True 16 | z_channels: 3 17 | resolution: 256 18 | in_channels: 3 19 | out_ch: 3 20 | ch: 128 21 | ch_mult: [ 1,2,4 ] # num_down = len(ch_mult)-1 22 | num_res_blocks: 2 23 | attn_resolutions: [ ] 24 | dropout: 0.0 25 | 26 | 27 | data: 28 | target: main.DataModuleFromConfig 29 | params: 30 | batch_size: 12 31 | wrap: True 32 | train: 33 | target: ldm.data.imagenet.ImageNetSRTrain 34 | params: 35 | size: 256 36 | degradation: pil_nearest 37 | validation: 38 | target: ldm.data.imagenet.ImageNetSRValidation 39 | params: 40 | size: 256 41 | degradation: pil_nearest 42 | 43 | lightning: 44 | callbacks: 45 | image_logger: 46 | target: main.ImageLogger 47 | params: 48 | batch_frequency: 1000 49 | max_images: 8 50 | increase_log_steps: True 51 | 52 | trainer: 53 | benchmark: True 54 | accumulate_grad_batches: 2 55 | -------------------------------------------------------------------------------- /examples/stable-diffusion/configs/autoencoder/autoencoder_kl_8x8x64.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-6 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: "val/rec_loss" 6 | embed_dim: 64 7 | lossconfig: 8 | target: ldm.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 0.000001 12 | disc_weight: 0.5 13 | 14 | ddconfig: 15 | double_z: True 16 | z_channels: 64 17 | resolution: 256 18 | in_channels: 3 19 | out_ch: 3 20 | ch: 128 21 | ch_mult: [ 1,1,2,2,4,4] # num_down = len(ch_mult)-1 22 | num_res_blocks: 2 23 | attn_resolutions: 
[16,8] 24 | dropout: 0.0 25 | 26 | data: 27 | target: main.DataModuleFromConfig 28 | params: 29 | batch_size: 12 30 | wrap: True 31 | train: 32 | target: ldm.data.imagenet.ImageNetSRTrain 33 | params: 34 | size: 256 35 | degradation: pil_nearest 36 | validation: 37 | target: ldm.data.imagenet.ImageNetSRValidation 38 | params: 39 | size: 256 40 | degradation: pil_nearest 41 | 42 | lightning: 43 | callbacks: 44 | image_logger: 45 | target: main.ImageLogger 46 | params: 47 | batch_frequency: 1000 48 | max_images: 8 49 | increase_log_steps: True 50 | 51 | trainer: 52 | benchmark: True 53 | accumulate_grad_batches: 2 54 | -------------------------------------------------------------------------------- /examples/stable-diffusion/configs/latent-diffusion/cin256-v2.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 0.0001 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: class_label 12 | image_size: 64 13 | channels: 3 14 | cond_stage_trainable: true 15 | conditioning_key: crossattn 16 | monitor: val/loss 17 | use_ema: False 18 | 19 | unet_config: 20 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 21 | params: 22 | image_size: 64 23 | in_channels: 3 24 | out_channels: 3 25 | model_channels: 192 26 | attention_resolutions: 27 | - 8 28 | - 4 29 | - 2 30 | num_res_blocks: 2 31 | channel_mult: 32 | - 1 33 | - 2 34 | - 3 35 | - 5 36 | num_heads: 1 37 | use_spatial_transformer: true 38 | transformer_depth: 1 39 | context_dim: 512 40 | 41 | first_stage_config: 42 | target: ldm.models.autoencoder.VQModelInterface 43 | params: 44 | embed_dim: 3 45 | n_embed: 8192 46 | ddconfig: 47 | double_z: false 48 | z_channels: 3 49 | resolution: 256 50 | in_channels: 3 51 | out_ch: 3 52 | ch: 128 53 | ch_mult: 54 | - 1 55 | - 2 56 | - 4 57 | num_res_blocks: 2 58 | attn_resolutions: [] 59 | dropout: 0.0 60 | lossconfig: 61 | target: torch.nn.Identity 62 | 63 | cond_stage_config: 64 | target: ldm.modules.encoders.modules.ClassEmbedder 65 | params: 66 | n_classes: 1001 67 | embed_dim: 512 68 | key: class_label 69 | -------------------------------------------------------------------------------- /examples/stable-diffusion/configs/latent-diffusion/txt2img-1p4B-eval.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 5.0e-05 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.012 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: caption 12 | image_size: 32 13 | channels: 4 14 | cond_stage_trainable: true 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | scale_factor: 0.18215 18 | use_ema: False 19 | 20 | unet_config: 21 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 22 | params: 23 | image_size: 32 24 | in_channels: 4 25 | out_channels: 4 26 | model_channels: 320 27 | attention_resolutions: 28 | - 4 29 | - 2 30 | - 1 31 | num_res_blocks: 2 32 | channel_mult: 33 | - 1 34 | - 2 35 | - 4 36 | - 4 37 | num_heads: 8 38 | use_spatial_transformer: true 39 | transformer_depth: 1 40 | context_dim: 1280 41 | use_checkpoint: true 42 | legacy: False 43 | 44 | first_stage_config: 45 | target: ldm.models.autoencoder.AutoencoderKL 46 | params: 47 | 
embed_dim: 4 48 | monitor: val/rec_loss 49 | ddconfig: 50 | double_z: true 51 | z_channels: 4 52 | resolution: 256 53 | in_channels: 3 54 | out_ch: 3 55 | ch: 128 56 | ch_mult: 57 | - 1 58 | - 2 59 | - 4 60 | - 4 61 | num_res_blocks: 2 62 | attn_resolutions: [] 63 | dropout: 0.0 64 | lossconfig: 65 | target: torch.nn.Identity 66 | 67 | cond_stage_config: 68 | target: ldm.modules.encoders.modules.BERTEmbedder 69 | params: 70 | n_embed: 1280 71 | n_layer: 32 72 | -------------------------------------------------------------------------------- /examples/stable-diffusion/configs/retrieval-augmented-diffusion/768x768.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 0.0001 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.015 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: jpg 11 | cond_stage_key: nix 12 | image_size: 48 13 | channels: 16 14 | cond_stage_trainable: false 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | scale_by_std: false 18 | scale_factor: 0.22765929 19 | unet_config: 20 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 21 | params: 22 | image_size: 48 23 | in_channels: 16 24 | out_channels: 16 25 | model_channels: 448 26 | attention_resolutions: 27 | - 4 28 | - 2 29 | - 1 30 | num_res_blocks: 2 31 | channel_mult: 32 | - 1 33 | - 2 34 | - 3 35 | - 4 36 | use_scale_shift_norm: false 37 | resblock_updown: false 38 | num_head_channels: 32 39 | use_spatial_transformer: true 40 | transformer_depth: 1 41 | context_dim: 768 42 | use_checkpoint: true 43 | first_stage_config: 44 | target: ldm.models.autoencoder.AutoencoderKL 45 | params: 46 | monitor: val/rec_loss 47 | embed_dim: 16 48 | ddconfig: 49 | double_z: true 50 | z_channels: 16 51 | resolution: 256 52 | in_channels: 3 53 | out_ch: 3 54 | ch: 128 55 | ch_mult: 56 | - 1 57 | - 1 58 | - 2 59 | - 2 60 | - 4 61 | num_res_blocks: 2 62 | attn_resolutions: 63 | - 16 64 | dropout: 0.0 65 | lossconfig: 66 | target: torch.nn.Identity 67 | cond_stage_config: 68 | target: torch.nn.Identity -------------------------------------------------------------------------------- /examples/stable-diffusion/data/DejaVuSans.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LuChengTHU/dpm-solver/52bc3fbcd5de56d60917b826b15d2b69460fc2fa/examples/stable-diffusion/data/DejaVuSans.ttf -------------------------------------------------------------------------------- /examples/stable-diffusion/data/example_conditioning/superresolution/sample_0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LuChengTHU/dpm-solver/52bc3fbcd5de56d60917b826b15d2b69460fc2fa/examples/stable-diffusion/data/example_conditioning/superresolution/sample_0.jpg -------------------------------------------------------------------------------- /examples/stable-diffusion/data/example_conditioning/text_conditional/sample_0.txt: -------------------------------------------------------------------------------- 1 | A basket of cerries 2 | -------------------------------------------------------------------------------- /examples/stable-diffusion/data/fruit.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/LuChengTHU/dpm-solver/52bc3fbcd5de56d60917b826b15d2b69460fc2fa/examples/stable-diffusion/data/fruit.png -------------------------------------------------------------------------------- /examples/stable-diffusion/data/imagenet_train_hr_indices.p: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LuChengTHU/dpm-solver/52bc3fbcd5de56d60917b826b15d2b69460fc2fa/examples/stable-diffusion/data/imagenet_train_hr_indices.p -------------------------------------------------------------------------------- /examples/stable-diffusion/data/imagenet_val_hr_indices.p: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LuChengTHU/dpm-solver/52bc3fbcd5de56d60917b826b15d2b69460fc2fa/examples/stable-diffusion/data/imagenet_val_hr_indices.p -------------------------------------------------------------------------------- /examples/stable-diffusion/data/inpainting_examples/6458524847_2f4c361183_k.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LuChengTHU/dpm-solver/52bc3fbcd5de56d60917b826b15d2b69460fc2fa/examples/stable-diffusion/data/inpainting_examples/6458524847_2f4c361183_k.png -------------------------------------------------------------------------------- /examples/stable-diffusion/data/inpainting_examples/6458524847_2f4c361183_k_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LuChengTHU/dpm-solver/52bc3fbcd5de56d60917b826b15d2b69460fc2fa/examples/stable-diffusion/data/inpainting_examples/6458524847_2f4c361183_k_mask.png -------------------------------------------------------------------------------- /examples/stable-diffusion/data/inpainting_examples/8399166846_f6fb4e4b8e_k.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LuChengTHU/dpm-solver/52bc3fbcd5de56d60917b826b15d2b69460fc2fa/examples/stable-diffusion/data/inpainting_examples/8399166846_f6fb4e4b8e_k.png -------------------------------------------------------------------------------- /examples/stable-diffusion/data/inpainting_examples/8399166846_f6fb4e4b8e_k_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LuChengTHU/dpm-solver/52bc3fbcd5de56d60917b826b15d2b69460fc2fa/examples/stable-diffusion/data/inpainting_examples/8399166846_f6fb4e4b8e_k_mask.png -------------------------------------------------------------------------------- /examples/stable-diffusion/data/inpainting_examples/alex-iby-G_Pk4D9rMLs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LuChengTHU/dpm-solver/52bc3fbcd5de56d60917b826b15d2b69460fc2fa/examples/stable-diffusion/data/inpainting_examples/alex-iby-G_Pk4D9rMLs.png -------------------------------------------------------------------------------- /examples/stable-diffusion/data/inpainting_examples/alex-iby-G_Pk4D9rMLs_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LuChengTHU/dpm-solver/52bc3fbcd5de56d60917b826b15d2b69460fc2fa/examples/stable-diffusion/data/inpainting_examples/alex-iby-G_Pk4D9rMLs_mask.png -------------------------------------------------------------------------------- 
/examples/stable-diffusion/data/inpainting_examples/bench2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LuChengTHU/dpm-solver/52bc3fbcd5de56d60917b826b15d2b69460fc2fa/examples/stable-diffusion/data/inpainting_examples/bench2.png -------------------------------------------------------------------------------- /examples/stable-diffusion/data/inpainting_examples/bench2_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LuChengTHU/dpm-solver/52bc3fbcd5de56d60917b826b15d2b69460fc2fa/examples/stable-diffusion/data/inpainting_examples/bench2_mask.png -------------------------------------------------------------------------------- /examples/stable-diffusion/data/inpainting_examples/bertrand-gabioud-CpuFzIsHYJ0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LuChengTHU/dpm-solver/52bc3fbcd5de56d60917b826b15d2b69460fc2fa/examples/stable-diffusion/data/inpainting_examples/bertrand-gabioud-CpuFzIsHYJ0.png -------------------------------------------------------------------------------- /examples/stable-diffusion/data/inpainting_examples/bertrand-gabioud-CpuFzIsHYJ0_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LuChengTHU/dpm-solver/52bc3fbcd5de56d60917b826b15d2b69460fc2fa/examples/stable-diffusion/data/inpainting_examples/bertrand-gabioud-CpuFzIsHYJ0_mask.png -------------------------------------------------------------------------------- /examples/stable-diffusion/data/inpainting_examples/billow926-12-Wc-Zgx6Y.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LuChengTHU/dpm-solver/52bc3fbcd5de56d60917b826b15d2b69460fc2fa/examples/stable-diffusion/data/inpainting_examples/billow926-12-Wc-Zgx6Y.png -------------------------------------------------------------------------------- /examples/stable-diffusion/data/inpainting_examples/billow926-12-Wc-Zgx6Y_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LuChengTHU/dpm-solver/52bc3fbcd5de56d60917b826b15d2b69460fc2fa/examples/stable-diffusion/data/inpainting_examples/billow926-12-Wc-Zgx6Y_mask.png -------------------------------------------------------------------------------- /examples/stable-diffusion/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LuChengTHU/dpm-solver/52bc3fbcd5de56d60917b826b15d2b69460fc2fa/examples/stable-diffusion/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png -------------------------------------------------------------------------------- /examples/stable-diffusion/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LuChengTHU/dpm-solver/52bc3fbcd5de56d60917b826b15d2b69460fc2fa/examples/stable-diffusion/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png -------------------------------------------------------------------------------- /examples/stable-diffusion/data/inpainting_examples/photo-1583445095369-9c651e7e5d34.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/LuChengTHU/dpm-solver/52bc3fbcd5de56d60917b826b15d2b69460fc2fa/examples/stable-diffusion/data/inpainting_examples/photo-1583445095369-9c651e7e5d34.png -------------------------------------------------------------------------------- /examples/stable-diffusion/data/inpainting_examples/photo-1583445095369-9c651e7e5d34_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LuChengTHU/dpm-solver/52bc3fbcd5de56d60917b826b15d2b69460fc2fa/examples/stable-diffusion/data/inpainting_examples/photo-1583445095369-9c651e7e5d34_mask.png -------------------------------------------------------------------------------- /examples/stable-diffusion/environment.yaml: -------------------------------------------------------------------------------- 1 | name: ldm 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - python=3.8.5 7 | - pip=20.3 8 | - cudatoolkit=11.3 9 | - pytorch=1.11.0 10 | - torchvision=0.12.0 11 | - numpy=1.19.2 12 | - pip: 13 | - albumentations==0.4.3 14 | - diffusers 15 | - opencv-python==4.1.2.30 16 | - pudb==2019.2 17 | - invisible-watermark 18 | - imageio==2.9.0 19 | - imageio-ffmpeg==0.4.2 20 | - pytorch-lightning==1.4.2 21 | - omegaconf==2.1.1 22 | - test-tube>=0.7.5 23 | - streamlit>=0.73.1 24 | - einops==0.3.0 25 | - torch-fidelity==0.3.0 26 | - transformers==4.19.2 27 | - torchmetrics==0.6.0 28 | - kornia==0.6 29 | - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers 30 | - -e git+https://github.com/openai/CLIP.git@main#egg=clip 31 | - -e . 32 | -------------------------------------------------------------------------------- /examples/stable-diffusion/ldm/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LuChengTHU/dpm-solver/52bc3fbcd5de56d60917b826b15d2b69460fc2fa/examples/stable-diffusion/ldm/data/__init__.py -------------------------------------------------------------------------------- /examples/stable-diffusion/ldm/data/base.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset 3 | 4 | 5 | class Txt2ImgIterableBaseDataset(IterableDataset): 6 | ''' 7 | Define an interface to make the IterableDatasets for text2img data chainable 8 | ''' 9 | def __init__(self, num_records=0, valid_ids=None, size=256): 10 | super().__init__() 11 | self.num_records = num_records 12 | self.valid_ids = valid_ids 13 | self.sample_ids = valid_ids 14 | self.size = size 15 | 16 | print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.') 17 | 18 | def __len__(self): 19 | return self.num_records 20 | 21 | @abstractmethod 22 | def __iter__(self): 23 | pass -------------------------------------------------------------------------------- /examples/stable-diffusion/ldm/models/diffusion/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LuChengTHU/dpm-solver/52bc3fbcd5de56d60917b826b15d2b69460fc2fa/examples/stable-diffusion/ldm/models/diffusion/__init__.py -------------------------------------------------------------------------------- /examples/stable-diffusion/ldm/models/diffusion/dpm_solver/__init__.py: -------------------------------------------------------------------------------- 1 | from .sampler import DPMSolverSampler 2 | 
from .dpm_solver import ( 3 | model_wrapper, 4 | NoiseScheduleVP, 5 | DPM_Solver 6 | ) -------------------------------------------------------------------------------- /examples/stable-diffusion/ldm/modules/diffusionmodules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LuChengTHU/dpm-solver/52bc3fbcd5de56d60917b826b15d2b69460fc2fa/examples/stable-diffusion/ldm/modules/diffusionmodules/__init__.py -------------------------------------------------------------------------------- /examples/stable-diffusion/ldm/modules/distributions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LuChengTHU/dpm-solver/52bc3fbcd5de56d60917b826b15d2b69460fc2fa/examples/stable-diffusion/ldm/modules/distributions/__init__.py -------------------------------------------------------------------------------- /examples/stable-diffusion/ldm/modules/encoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LuChengTHU/dpm-solver/52bc3fbcd5de56d60917b826b15d2b69460fc2fa/examples/stable-diffusion/ldm/modules/encoders/__init__.py -------------------------------------------------------------------------------- /examples/stable-diffusion/ldm/modules/image_degradation/__init__.py: -------------------------------------------------------------------------------- 1 | from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr 2 | from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light 3 | -------------------------------------------------------------------------------- /examples/stable-diffusion/ldm/modules/image_degradation/utils/test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LuChengTHU/dpm-solver/52bc3fbcd5de56d60917b826b15d2b69460fc2fa/examples/stable-diffusion/ldm/modules/image_degradation/utils/test.png -------------------------------------------------------------------------------- /examples/stable-diffusion/ldm/modules/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator -------------------------------------------------------------------------------- /examples/stable-diffusion/models/first_stage_models/kl-f16/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: val/rec_loss 6 | embed_dim: 16 7 | lossconfig: 8 | target: ldm.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 1.0e-06 12 | disc_weight: 0.5 13 | ddconfig: 14 | double_z: true 15 | z_channels: 16 16 | resolution: 256 17 | in_channels: 3 18 | out_ch: 3 19 | ch: 128 20 | ch_mult: 21 | - 1 22 | - 1 23 | - 2 24 | - 2 25 | - 4 26 | num_res_blocks: 2 27 | attn_resolutions: 28 | - 16 29 | dropout: 0.0 30 | data: 31 | target: main.DataModuleFromConfig 32 | params: 33 | batch_size: 6 34 | wrap: true 35 | train: 36 | target: ldm.data.openimages.FullOpenImagesTrain 37 | params: 38 | size: 384 39 | crop_size: 256 40 | validation: 41 | target: ldm.data.openimages.FullOpenImagesValidation 42 | params: 43 | size: 384 44 | crop_size: 256 45 | 
-------------------------------------------------------------------------------- /examples/stable-diffusion/models/first_stage_models/kl-f32/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: val/rec_loss 6 | embed_dim: 64 7 | lossconfig: 8 | target: ldm.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 1.0e-06 12 | disc_weight: 0.5 13 | ddconfig: 14 | double_z: true 15 | z_channels: 64 16 | resolution: 256 17 | in_channels: 3 18 | out_ch: 3 19 | ch: 128 20 | ch_mult: 21 | - 1 22 | - 1 23 | - 2 24 | - 2 25 | - 4 26 | - 4 27 | num_res_blocks: 2 28 | attn_resolutions: 29 | - 16 30 | - 8 31 | dropout: 0.0 32 | data: 33 | target: main.DataModuleFromConfig 34 | params: 35 | batch_size: 6 36 | wrap: true 37 | train: 38 | target: ldm.data.openimages.FullOpenImagesTrain 39 | params: 40 | size: 384 41 | crop_size: 256 42 | validation: 43 | target: ldm.data.openimages.FullOpenImagesValidation 44 | params: 45 | size: 384 46 | crop_size: 256 47 | -------------------------------------------------------------------------------- /examples/stable-diffusion/models/first_stage_models/kl-f4/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: val/rec_loss 6 | embed_dim: 3 7 | lossconfig: 8 | target: ldm.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 1.0e-06 12 | disc_weight: 0.5 13 | ddconfig: 14 | double_z: true 15 | z_channels: 3 16 | resolution: 256 17 | in_channels: 3 18 | out_ch: 3 19 | ch: 128 20 | ch_mult: 21 | - 1 22 | - 2 23 | - 4 24 | num_res_blocks: 2 25 | attn_resolutions: [] 26 | dropout: 0.0 27 | data: 28 | target: main.DataModuleFromConfig 29 | params: 30 | batch_size: 10 31 | wrap: true 32 | train: 33 | target: ldm.data.openimages.FullOpenImagesTrain 34 | params: 35 | size: 384 36 | crop_size: 256 37 | validation: 38 | target: ldm.data.openimages.FullOpenImagesValidation 39 | params: 40 | size: 384 41 | crop_size: 256 42 | -------------------------------------------------------------------------------- /examples/stable-diffusion/models/first_stage_models/kl-f8/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: val/rec_loss 6 | embed_dim: 4 7 | lossconfig: 8 | target: ldm.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 1.0e-06 12 | disc_weight: 0.5 13 | ddconfig: 14 | double_z: true 15 | z_channels: 4 16 | resolution: 256 17 | in_channels: 3 18 | out_ch: 3 19 | ch: 128 20 | ch_mult: 21 | - 1 22 | - 2 23 | - 4 24 | - 4 25 | num_res_blocks: 2 26 | attn_resolutions: [] 27 | dropout: 0.0 28 | data: 29 | target: main.DataModuleFromConfig 30 | params: 31 | batch_size: 4 32 | wrap: true 33 | train: 34 | target: ldm.data.openimages.FullOpenImagesTrain 35 | params: 36 | size: 384 37 | crop_size: 256 38 | validation: 39 | target: ldm.data.openimages.FullOpenImagesValidation 40 | params: 41 | size: 384 42 | crop_size: 256 43 | -------------------------------------------------------------------------------- /examples/stable-diffusion/models/first_stage_models/vq-f16/config.yaml: 
-------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: ldm.models.autoencoder.VQModel 4 | params: 5 | embed_dim: 8 6 | n_embed: 16384 7 | ddconfig: 8 | double_z: false 9 | z_channels: 8 10 | resolution: 256 11 | in_channels: 3 12 | out_ch: 3 13 | ch: 128 14 | ch_mult: 15 | - 1 16 | - 1 17 | - 2 18 | - 2 19 | - 4 20 | num_res_blocks: 2 21 | attn_resolutions: 22 | - 16 23 | dropout: 0.0 24 | lossconfig: 25 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator 26 | params: 27 | disc_conditional: false 28 | disc_in_channels: 3 29 | disc_start: 250001 30 | disc_weight: 0.75 31 | disc_num_layers: 2 32 | codebook_weight: 1.0 33 | 34 | data: 35 | target: main.DataModuleFromConfig 36 | params: 37 | batch_size: 14 38 | num_workers: 20 39 | wrap: true 40 | train: 41 | target: ldm.data.openimages.FullOpenImagesTrain 42 | params: 43 | size: 384 44 | crop_size: 256 45 | validation: 46 | target: ldm.data.openimages.FullOpenImagesValidation 47 | params: 48 | size: 384 49 | crop_size: 256 50 | -------------------------------------------------------------------------------- /examples/stable-diffusion/models/first_stage_models/vq-f4-noattn/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: ldm.models.autoencoder.VQModel 4 | params: 5 | embed_dim: 3 6 | n_embed: 8192 7 | monitor: val/rec_loss 8 | 9 | ddconfig: 10 | attn_type: none 11 | double_z: false 12 | z_channels: 3 13 | resolution: 256 14 | in_channels: 3 15 | out_ch: 3 16 | ch: 128 17 | ch_mult: 18 | - 1 19 | - 2 20 | - 4 21 | num_res_blocks: 2 22 | attn_resolutions: [] 23 | dropout: 0.0 24 | lossconfig: 25 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator 26 | params: 27 | disc_conditional: false 28 | disc_in_channels: 3 29 | disc_start: 11 30 | disc_weight: 0.75 31 | codebook_weight: 1.0 32 | 33 | data: 34 | target: main.DataModuleFromConfig 35 | params: 36 | batch_size: 8 37 | num_workers: 12 38 | wrap: true 39 | train: 40 | target: ldm.data.openimages.FullOpenImagesTrain 41 | params: 42 | crop_size: 256 43 | validation: 44 | target: ldm.data.openimages.FullOpenImagesValidation 45 | params: 46 | crop_size: 256 47 | -------------------------------------------------------------------------------- /examples/stable-diffusion/models/first_stage_models/vq-f4/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: ldm.models.autoencoder.VQModel 4 | params: 5 | embed_dim: 3 6 | n_embed: 8192 7 | monitor: val/rec_loss 8 | 9 | ddconfig: 10 | double_z: false 11 | z_channels: 3 12 | resolution: 256 13 | in_channels: 3 14 | out_ch: 3 15 | ch: 128 16 | ch_mult: 17 | - 1 18 | - 2 19 | - 4 20 | num_res_blocks: 2 21 | attn_resolutions: [] 22 | dropout: 0.0 23 | lossconfig: 24 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator 25 | params: 26 | disc_conditional: false 27 | disc_in_channels: 3 28 | disc_start: 0 29 | disc_weight: 0.75 30 | codebook_weight: 1.0 31 | 32 | data: 33 | target: main.DataModuleFromConfig 34 | params: 35 | batch_size: 8 36 | num_workers: 16 37 | wrap: true 38 | train: 39 | target: ldm.data.openimages.FullOpenImagesTrain 40 | params: 41 | crop_size: 256 42 | validation: 43 | target: ldm.data.openimages.FullOpenImagesValidation 44 | params: 45 | crop_size: 256 46 | 
-------------------------------------------------------------------------------- /examples/stable-diffusion/models/first_stage_models/vq-f8-n256/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: ldm.models.autoencoder.VQModel 4 | params: 5 | embed_dim: 4 6 | n_embed: 256 7 | monitor: val/rec_loss 8 | ddconfig: 9 | double_z: false 10 | z_channels: 4 11 | resolution: 256 12 | in_channels: 3 13 | out_ch: 3 14 | ch: 128 15 | ch_mult: 16 | - 1 17 | - 2 18 | - 2 19 | - 4 20 | num_res_blocks: 2 21 | attn_resolutions: 22 | - 32 23 | dropout: 0.0 24 | lossconfig: 25 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator 26 | params: 27 | disc_conditional: false 28 | disc_in_channels: 3 29 | disc_start: 250001 30 | disc_weight: 0.75 31 | codebook_weight: 1.0 32 | 33 | data: 34 | target: main.DataModuleFromConfig 35 | params: 36 | batch_size: 10 37 | num_workers: 20 38 | wrap: true 39 | train: 40 | target: ldm.data.openimages.FullOpenImagesTrain 41 | params: 42 | size: 384 43 | crop_size: 256 44 | validation: 45 | target: ldm.data.openimages.FullOpenImagesValidation 46 | params: 47 | size: 384 48 | crop_size: 256 49 | -------------------------------------------------------------------------------- /examples/stable-diffusion/models/first_stage_models/vq-f8/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: ldm.models.autoencoder.VQModel 4 | params: 5 | embed_dim: 4 6 | n_embed: 16384 7 | monitor: val/rec_loss 8 | ddconfig: 9 | double_z: false 10 | z_channels: 4 11 | resolution: 256 12 | in_channels: 3 13 | out_ch: 3 14 | ch: 128 15 | ch_mult: 16 | - 1 17 | - 2 18 | - 2 19 | - 4 20 | num_res_blocks: 2 21 | attn_resolutions: 22 | - 32 23 | dropout: 0.0 24 | lossconfig: 25 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator 26 | params: 27 | disc_conditional: false 28 | disc_in_channels: 3 29 | disc_num_layers: 2 30 | disc_start: 1 31 | disc_weight: 0.6 32 | codebook_weight: 1.0 33 | data: 34 | target: main.DataModuleFromConfig 35 | params: 36 | batch_size: 10 37 | num_workers: 20 38 | wrap: true 39 | train: 40 | target: ldm.data.openimages.FullOpenImagesTrain 41 | params: 42 | size: 384 43 | crop_size: 256 44 | validation: 45 | target: ldm.data.openimages.FullOpenImagesValidation 46 | params: 47 | size: 384 48 | crop_size: 256 49 | -------------------------------------------------------------------------------- /examples/stable-diffusion/models/ldm/celeba256/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 2.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: class_label 12 | image_size: 64 13 | channels: 3 14 | cond_stage_trainable: false 15 | concat_mode: false 16 | monitor: val/loss 17 | unet_config: 18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 19 | params: 20 | image_size: 64 21 | in_channels: 3 22 | out_channels: 3 23 | model_channels: 224 24 | attention_resolutions: 25 | - 8 26 | - 4 27 | - 2 28 | num_res_blocks: 2 29 | channel_mult: 30 | - 1 31 | - 2 32 | - 3 33 | - 4 34 | num_head_channels: 32 35 | first_stage_config: 36 | target: ldm.models.autoencoder.VQModelInterface 37 | params: 38 | 
embed_dim: 3 39 | n_embed: 8192 40 | ddconfig: 41 | double_z: false 42 | z_channels: 3 43 | resolution: 256 44 | in_channels: 3 45 | out_ch: 3 46 | ch: 128 47 | ch_mult: 48 | - 1 49 | - 2 50 | - 4 51 | num_res_blocks: 2 52 | attn_resolutions: [] 53 | dropout: 0.0 54 | lossconfig: 55 | target: torch.nn.Identity 56 | cond_stage_config: __is_unconditional__ 57 | data: 58 | target: main.DataModuleFromConfig 59 | params: 60 | batch_size: 48 61 | num_workers: 5 62 | wrap: false 63 | train: 64 | target: ldm.data.faceshq.CelebAHQTrain 65 | params: 66 | size: 256 67 | validation: 68 | target: ldm.data.faceshq.CelebAHQValidation 69 | params: 70 | size: 256 71 | -------------------------------------------------------------------------------- /examples/stable-diffusion/models/ldm/ffhq256/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 2.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: class_label 12 | image_size: 64 13 | channels: 3 14 | cond_stage_trainable: false 15 | concat_mode: false 16 | monitor: val/loss 17 | unet_config: 18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 19 | params: 20 | image_size: 64 21 | in_channels: 3 22 | out_channels: 3 23 | model_channels: 224 24 | attention_resolutions: 25 | - 8 26 | - 4 27 | - 2 28 | num_res_blocks: 2 29 | channel_mult: 30 | - 1 31 | - 2 32 | - 3 33 | - 4 34 | num_head_channels: 32 35 | first_stage_config: 36 | target: ldm.models.autoencoder.VQModelInterface 37 | params: 38 | embed_dim: 3 39 | n_embed: 8192 40 | ddconfig: 41 | double_z: false 42 | z_channels: 3 43 | resolution: 256 44 | in_channels: 3 45 | out_ch: 3 46 | ch: 128 47 | ch_mult: 48 | - 1 49 | - 2 50 | - 4 51 | num_res_blocks: 2 52 | attn_resolutions: [] 53 | dropout: 0.0 54 | lossconfig: 55 | target: torch.nn.Identity 56 | cond_stage_config: __is_unconditional__ 57 | data: 58 | target: main.DataModuleFromConfig 59 | params: 60 | batch_size: 42 61 | num_workers: 5 62 | wrap: false 63 | train: 64 | target: ldm.data.faceshq.FFHQTrain 65 | params: 66 | size: 256 67 | validation: 68 | target: ldm.data.faceshq.FFHQValidation 69 | params: 70 | size: 256 71 | -------------------------------------------------------------------------------- /examples/stable-diffusion/models/ldm/inpainting_big/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0205 7 | log_every_t: 100 8 | timesteps: 1000 9 | loss_type: l1 10 | first_stage_key: image 11 | cond_stage_key: masked_image 12 | image_size: 64 13 | channels: 3 14 | concat_mode: true 15 | monitor: val/loss 16 | scheduler_config: 17 | target: ldm.lr_scheduler.LambdaWarmUpCosineScheduler 18 | params: 19 | verbosity_interval: 0 20 | warm_up_steps: 1000 21 | max_decay_steps: 50000 22 | lr_start: 0.001 23 | lr_max: 0.1 24 | lr_min: 0.0001 25 | unet_config: 26 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 27 | params: 28 | image_size: 64 29 | in_channels: 7 30 | out_channels: 3 31 | model_channels: 256 32 | attention_resolutions: 33 | - 8 34 | - 4 35 | - 2 36 | num_res_blocks: 2 37 | channel_mult: 38 | - 1 39 | - 2 40 | - 3 41 | - 4 42 | num_heads: 8 43 | resblock_updown: true 44 
| first_stage_config: 45 | target: ldm.models.autoencoder.VQModelInterface 46 | params: 47 | embed_dim: 3 48 | n_embed: 8192 49 | monitor: val/rec_loss 50 | ddconfig: 51 | attn_type: none 52 | double_z: false 53 | z_channels: 3 54 | resolution: 256 55 | in_channels: 3 56 | out_ch: 3 57 | ch: 128 58 | ch_mult: 59 | - 1 60 | - 2 61 | - 4 62 | num_res_blocks: 2 63 | attn_resolutions: [] 64 | dropout: 0.0 65 | lossconfig: 66 | target: ldm.modules.losses.contperceptual.DummyLoss 67 | cond_stage_config: __is_first_stage__ 68 | -------------------------------------------------------------------------------- /examples/stable-diffusion/models/ldm/lsun_beds256/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 2.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: class_label 12 | image_size: 64 13 | channels: 3 14 | cond_stage_trainable: false 15 | concat_mode: false 16 | monitor: val/loss 17 | unet_config: 18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 19 | params: 20 | image_size: 64 21 | in_channels: 3 22 | out_channels: 3 23 | model_channels: 224 24 | attention_resolutions: 25 | - 8 26 | - 4 27 | - 2 28 | num_res_blocks: 2 29 | channel_mult: 30 | - 1 31 | - 2 32 | - 3 33 | - 4 34 | num_head_channels: 32 35 | first_stage_config: 36 | target: ldm.models.autoencoder.VQModelInterface 37 | params: 38 | embed_dim: 3 39 | n_embed: 8192 40 | ddconfig: 41 | double_z: false 42 | z_channels: 3 43 | resolution: 256 44 | in_channels: 3 45 | out_ch: 3 46 | ch: 128 47 | ch_mult: 48 | - 1 49 | - 2 50 | - 4 51 | num_res_blocks: 2 52 | attn_resolutions: [] 53 | dropout: 0.0 54 | lossconfig: 55 | target: torch.nn.Identity 56 | cond_stage_config: __is_unconditional__ 57 | data: 58 | target: main.DataModuleFromConfig 59 | params: 60 | batch_size: 48 61 | num_workers: 5 62 | wrap: false 63 | train: 64 | target: ldm.data.lsun.LSUNBedroomsTrain 65 | params: 66 | size: 256 67 | validation: 68 | target: ldm.data.lsun.LSUNBedroomsValidation 69 | params: 70 | size: 256 71 | -------------------------------------------------------------------------------- /examples/stable-diffusion/models/ldm/semantic_synthesis256/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0205 7 | log_every_t: 100 8 | timesteps: 1000 9 | loss_type: l1 10 | first_stage_key: image 11 | cond_stage_key: segmentation 12 | image_size: 64 13 | channels: 3 14 | concat_mode: true 15 | cond_stage_trainable: true 16 | unet_config: 17 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 18 | params: 19 | image_size: 64 20 | in_channels: 6 21 | out_channels: 3 22 | model_channels: 128 23 | attention_resolutions: 24 | - 32 25 | - 16 26 | - 8 27 | num_res_blocks: 2 28 | channel_mult: 29 | - 1 30 | - 4 31 | - 8 32 | num_heads: 8 33 | first_stage_config: 34 | target: ldm.models.autoencoder.VQModelInterface 35 | params: 36 | embed_dim: 3 37 | n_embed: 8192 38 | ddconfig: 39 | double_z: false 40 | z_channels: 3 41 | resolution: 256 42 | in_channels: 3 43 | out_ch: 3 44 | ch: 128 45 | ch_mult: 46 | - 1 47 | - 2 48 | - 4 49 | num_res_blocks: 2 50 | attn_resolutions: [] 51 | dropout: 0.0 52 | 
        lossconfig:
          target: torch.nn.Identity
    cond_stage_config:
      target: ldm.modules.encoders.modules.SpatialRescaler
      params:
        n_stages: 2
        in_channels: 182
        out_channels: 3
--------------------------------------------------------------------------------
/examples/stable-diffusion/scripts/download_first_stages.sh:
--------------------------------------------------------------------------------
#!/bin/bash
wget -O models/first_stage_models/kl-f4/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f4.zip
wget -O models/first_stage_models/kl-f8/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f8.zip
wget -O models/first_stage_models/kl-f16/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f16.zip
wget -O models/first_stage_models/kl-f32/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f32.zip
wget -O models/first_stage_models/vq-f4/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f4.zip
wget -O models/first_stage_models/vq-f4-noattn/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f4-noattn.zip
wget -O models/first_stage_models/vq-f8/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f8.zip
wget -O models/first_stage_models/vq-f8-n256/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f8-n256.zip
wget -O models/first_stage_models/vq-f16/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f16.zip



cd models/first_stage_models/kl-f4
unzip -o model.zip

cd ../kl-f8
unzip -o model.zip

cd ../kl-f16
unzip -o model.zip

cd ../kl-f32
unzip -o model.zip

cd ../vq-f4
unzip -o model.zip

cd ../vq-f4-noattn
unzip -o model.zip

cd ../vq-f8
unzip -o model.zip

cd ../vq-f8-n256
unzip -o model.zip

cd ../vq-f16
unzip -o model.zip

cd ../..
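(Note: this script and the model download script that follows write with `wget -O` to fixed relative paths, so they assume they are run from the `examples/stable-diffusion` root with the corresponding `models/first_stage_models/<name>/` and `models/ldm/<name>/` directories already in place; `wget -O` does not create missing directories. Each archive is then unpacked next to its config with `unzip -o`.)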
--------------------------------------------------------------------------------
/examples/stable-diffusion/scripts/download_models.sh:
--------------------------------------------------------------------------------
#!/bin/bash
wget -O models/ldm/celeba256/celeba-256.zip https://ommer-lab.com/files/latent-diffusion/celeba.zip
wget -O models/ldm/ffhq256/ffhq-256.zip https://ommer-lab.com/files/latent-diffusion/ffhq.zip
wget -O models/ldm/lsun_churches256/lsun_churches-256.zip https://ommer-lab.com/files/latent-diffusion/lsun_churches.zip
wget -O models/ldm/lsun_beds256/lsun_beds-256.zip https://ommer-lab.com/files/latent-diffusion/lsun_bedrooms.zip
wget -O models/ldm/text2img256/model.zip https://ommer-lab.com/files/latent-diffusion/text2img.zip
wget -O models/ldm/cin256/model.zip https://ommer-lab.com/files/latent-diffusion/cin.zip
wget -O models/ldm/semantic_synthesis512/model.zip https://ommer-lab.com/files/latent-diffusion/semantic_synthesis.zip
wget -O models/ldm/semantic_synthesis256/model.zip https://ommer-lab.com/files/latent-diffusion/semantic_synthesis256.zip
wget -O models/ldm/bsr_sr/model.zip https://ommer-lab.com/files/latent-diffusion/sr_bsr.zip
wget -O models/ldm/layout2img-openimages256/model.zip https://ommer-lab.com/files/latent-diffusion/layout2img_model.zip
wget -O models/ldm/inpainting_big/model.zip https://ommer-lab.com/files/latent-diffusion/inpainting_big.zip



cd models/ldm/celeba256
unzip -o celeba-256.zip

cd ../ffhq256
unzip -o ffhq-256.zip

cd ../lsun_churches256
unzip -o lsun_churches-256.zip

cd ../lsun_beds256
unzip -o lsun_beds-256.zip

cd ../text2img256
unzip -o model.zip

cd ../cin256
unzip -o model.zip

cd ../semantic_synthesis512
unzip -o model.zip

cd ../semantic_synthesis256
unzip -o model.zip

cd ../bsr_sr
unzip -o model.zip

cd ../layout2img-openimages256
unzip -o model.zip

cd ../inpainting_big
unzip -o model.zip

cd ../..
--------------------------------------------------------------------------------
/examples/stable-diffusion/scripts/tests/test_watermark.py:
--------------------------------------------------------------------------------
import cv2
import fire
from imwatermark import WatermarkDecoder


def testit(img_path):
    # Read the image and try to recover the 136-bit invisible watermark
    # embedded with the 'dwtDct' method.
    bgr = cv2.imread(img_path)
    decoder = WatermarkDecoder('bytes', 136)
    watermark = decoder.decode(bgr, 'dwtDct')
    try:
        dec = watermark.decode('utf-8')
    except Exception:
        # Recovered bytes are not valid UTF-8 (or no watermark was found).
        dec = "null"
    print(dec)


if __name__ == "__main__":
    fire.Fire(testit)
--------------------------------------------------------------------------------
/examples/stable-diffusion/setup.py:
--------------------------------------------------------------------------------
from setuptools import setup, find_packages

setup(
    name='latent-diffusion',
    version='0.0.1',
    description='',
    packages=find_packages(),
    install_requires=[
        'torch',
        'numpy',
        'tqdm',
    ],
)
--------------------------------------------------------------------------------
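All of the LDM configs above follow the same pattern: `model.target` names the LatentDiffusion class, its `params` nest the UNet and first-stage autoencoder configs, and the download scripts unpack a `model.ckpt` next to each `config.yaml`. Below is a minimal sketch (not part of the repository) of how such a config is typically consumed, assuming the latent-diffusion code layout with `OmegaConf` and `ldm.util.instantiate_from_config`; the paths and the `state_dict` key are illustrative.

# Sketch: build a LatentDiffusion model from one of the configs above.
import torch
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config  # assumed: ldm package from this example is installed

config = OmegaConf.load("models/ldm/ffhq256/config.yaml")
model = instantiate_from_config(config.model)  # instantiates config.model.target with config.model.params

# The download scripts unpack a model.ckpt next to each config.yaml (illustrative path).
state = torch.load("models/ldm/ffhq256/model.ckpt", map_location="cpu")
model.load_state_dict(state["state_dict"], strict=False)
model.eval()

The resulting model exposes the noise-prediction UNet and the VQ first stage described by the config, which is the interface the DPM-Solver samplers in this repository wrap.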