├── .gitignore ├── LICENSE ├── README.md ├── codebases ├── edm │ ├── LICENSE.txt │ ├── README.md │ ├── compute_fid.py │ ├── dnnlib │ │ ├── __init__.py │ │ └── util.py │ ├── evaluation.py │ ├── sample.py │ ├── sample.sh │ ├── samplers │ │ ├── __init__.py │ │ ├── dpm_solver.py │ │ ├── dpm_solver_v3.py │ │ ├── heun.py │ │ ├── uni_pc.py │ │ └── utils.py │ └── torch_utils │ │ ├── __init__.py │ │ ├── distributed.py │ │ ├── misc.py │ │ ├── persistence.py │ │ └── training_stats.py ├── guided-diffusion │ ├── README.md │ ├── configs │ │ └── imagenet256_guided.yml │ ├── evaluate │ │ ├── fid_score.py │ │ └── inception.py │ ├── models │ │ └── guided_diffusion │ │ │ ├── __init__.py │ │ │ ├── fp16_util.py │ │ │ ├── logger.py │ │ │ ├── nn.py │ │ │ └── unet.py │ ├── runners │ │ ├── __init__.py │ │ └── diffusion.py │ ├── sample.py │ ├── sample.sh │ └── samplers │ │ ├── __init__.py │ │ ├── dpm_solver.py │ │ ├── dpm_solver_v3.py │ │ └── uni_pc.py ├── score_sde │ ├── LICENSE │ ├── README.md │ ├── compute_fid.py │ ├── configs │ │ ├── default_cifar10_configs.py │ │ └── vp │ │ │ └── cifar10_ddpmpp_deep_continuous.py │ ├── evaluation.py │ ├── models │ │ ├── __init__.py │ │ ├── ddpm.py │ │ ├── ema.py │ │ ├── layers.py │ │ ├── layerspp.py │ │ ├── ncsnpp.py │ │ ├── ncsnv2.py │ │ ├── normalization.py │ │ ├── up_or_down_sampling.py │ │ └── utils.py │ ├── op │ │ ├── __init__.py │ │ ├── fused_act.py │ │ ├── fused_bias_act.cpp │ │ ├── fused_bias_act_kernel.cu │ │ ├── upfirdn2d.cpp │ │ ├── upfirdn2d.py │ │ └── upfirdn2d_kernel.cu │ ├── sample.py │ ├── sample.sh │ ├── samplers │ │ ├── __init__.py │ │ ├── dpm_solver.py │ │ ├── dpm_solver_v3.py │ │ ├── uni_pc.py │ │ └── utils.py │ ├── sampling.py │ ├── sde_lib.py │ └── utils.py └── stable-diffusion │ ├── README.md │ ├── configs │ ├── autoencoder │ │ ├── autoencoder_kl_16x16x16.yaml │ │ ├── autoencoder_kl_32x32x4.yaml │ │ ├── autoencoder_kl_64x64x3.yaml │ │ └── autoencoder_kl_8x8x64.yaml │ ├── latent-diffusion │ │ ├── celebahq-ldm-vq-4.yaml │ │ ├── cin-ldm-vq-f8.yaml │ │ ├── cin256-v2.yaml │ │ ├── ffhq-ldm-vq-4.yaml │ │ ├── lsun_bedrooms-ldm-vq-4.yaml │ │ ├── lsun_churches-ldm-kl-8.yaml │ │ └── txt2img-1p4B-eval.yaml │ ├── retrieval-augmented-diffusion │ │ └── 768x768.yaml │ └── stable-diffusion │ │ └── v1-inference.yaml │ ├── ldm │ ├── __init__.py │ ├── data │ │ ├── __init__.py │ │ ├── base.py │ │ ├── imagenet.py │ │ └── lsun.py │ ├── lr_scheduler.py │ ├── models │ │ ├── autoencoder.py │ │ └── diffusion │ │ │ ├── __init__.py │ │ │ ├── classifier.py │ │ │ ├── ddim.py │ │ │ ├── ddpm.py │ │ │ ├── dpm_solver │ │ │ ├── __init__.py │ │ │ ├── dpm_solver.py │ │ │ └── sampler.py │ │ │ ├── dpm_solver_v3 │ │ │ ├── __init__.py │ │ │ ├── dpm_solver_v3.py │ │ │ └── sampler.py │ │ │ ├── plms.py │ │ │ └── uni_pc │ │ │ ├── __init__.py │ │ │ ├── sampler.py │ │ │ └── uni_pc.py │ ├── modules │ │ ├── attention.py │ │ ├── diffusionmodules │ │ │ ├── __init__.py │ │ │ ├── model.py │ │ │ ├── openaimodel.py │ │ │ └── util.py │ │ ├── distributions │ │ │ ├── __init__.py │ │ │ └── distributions.py │ │ ├── ema.py │ │ ├── encoders │ │ │ ├── __init__.py │ │ │ └── modules.py │ │ ├── image_degradation │ │ │ ├── __init__.py │ │ │ ├── bsrgan.py │ │ │ ├── bsrgan_light.py │ │ │ ├── utils │ │ │ │ └── test.png │ │ │ └── utils_image.py │ │ ├── losses │ │ │ ├── __init__.py │ │ │ ├── contperceptual.py │ │ │ └── vqperceptual.py │ │ └── x_transformer.py │ └── util.py │ ├── sample.sh │ ├── src │ ├── clip │ │ ├── .github │ │ │ └── workflows │ │ │ │ └── test.yml │ │ ├── .gitignore │ │ ├── CLIP.png │ │ ├── LICENSE │ │ ├── 
MANIFEST.in │ │ ├── README.md │ │ ├── clip │ │ │ ├── __init__.py │ │ │ ├── bpe_simple_vocab_16e6.txt.gz │ │ │ ├── clip.py │ │ │ ├── model.py │ │ │ └── simple_tokenizer.py │ │ ├── data │ │ │ ├── country211.md │ │ │ ├── prompts.md │ │ │ ├── rendered-sst2.md │ │ │ └── yfcc100m.md │ │ ├── hubconf.py │ │ ├── model-card.md │ │ ├── notebooks │ │ │ ├── Interacting_with_CLIP.ipynb │ │ │ └── Prompt_Engineering_for_ImageNet.ipynb │ │ ├── requirements.txt │ │ ├── setup.py │ │ └── tests │ │ │ └── test_consistency.py │ └── taming-transformers │ │ ├── License.txt │ │ ├── README.md │ │ ├── assets │ │ ├── birddrawnbyachild.png │ │ ├── coco_scene_images_training.svg │ │ ├── drin.jpg │ │ ├── faceshq.jpg │ │ ├── first_stage_mushrooms.png │ │ ├── first_stage_squirrels.png │ │ ├── imagenet.png │ │ ├── lake_in_the_mountains.png │ │ ├── mountain.jpeg │ │ ├── scene_images_samples.svg │ │ ├── stormy.jpeg │ │ ├── sunset_and_ocean.jpg │ │ └── teaser.png │ │ ├── configs │ │ ├── coco_cond_stage.yaml │ │ ├── coco_scene_images_transformer.yaml │ │ ├── custom_vqgan.yaml │ │ ├── drin_transformer.yaml │ │ ├── faceshq_transformer.yaml │ │ ├── faceshq_vqgan.yaml │ │ ├── imagenet_vqgan.yaml │ │ ├── imagenetdepth_vqgan.yaml │ │ ├── open_images_scene_images_transformer.yaml │ │ └── sflckr_cond_stage.yaml │ │ ├── environment.yaml │ │ ├── main.py │ │ ├── scripts │ │ ├── extract_depth.py │ │ ├── extract_segmentation.py │ │ ├── extract_submodel.py │ │ ├── make_samples.py │ │ ├── make_scene_samples.py │ │ ├── reconstruction_usage.ipynb │ │ ├── sample_conditional.py │ │ ├── sample_fast.py │ │ └── taming-transformers.ipynb │ │ ├── setup.py │ │ ├── taming │ │ ├── data │ │ │ ├── ade20k.py │ │ │ ├── annotated_objects_coco.py │ │ │ ├── annotated_objects_dataset.py │ │ │ ├── annotated_objects_open_images.py │ │ │ ├── base.py │ │ │ ├── coco.py │ │ │ ├── conditional_builder │ │ │ │ ├── objects_bbox.py │ │ │ │ ├── objects_center_points.py │ │ │ │ └── utils.py │ │ │ ├── custom.py │ │ │ ├── faceshq.py │ │ │ ├── helper_types.py │ │ │ ├── image_transforms.py │ │ │ ├── imagenet.py │ │ │ ├── open_images_helper.py │ │ │ ├── sflckr.py │ │ │ └── utils.py │ │ ├── lr_scheduler.py │ │ ├── models │ │ │ ├── cond_transformer.py │ │ │ ├── dummy_cond_stage.py │ │ │ └── vqgan.py │ │ ├── modules │ │ │ ├── diffusionmodules │ │ │ │ └── model.py │ │ │ ├── discriminator │ │ │ │ └── model.py │ │ │ ├── losses │ │ │ │ ├── __init__.py │ │ │ │ ├── lpips.py │ │ │ │ ├── segmentation.py │ │ │ │ └── vqperceptual.py │ │ │ ├── misc │ │ │ │ └── coord.py │ │ │ ├── transformer │ │ │ │ ├── mingpt.py │ │ │ │ └── permuter.py │ │ │ ├── util.py │ │ │ └── vqvae │ │ │ │ └── quantize.py │ │ └── util.py │ │ └── taming_transformers.egg-info │ │ ├── PKG-INFO │ │ ├── SOURCES.txt │ │ ├── dependency_links.txt │ │ ├── requires.txt │ │ └── top_level.txt │ └── txt2img.py ├── compute_EMS_scoresde.py └── dpm_solver_v3.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled python modules. 2 | *.pyc 3 | 4 | # Byte-compiled 5 | _pycache__/ 6 | .cache/ 7 | .idea/ 8 | 9 | # Python egg metadata, regenerated from source files by setuptools. 10 | /*.egg-info 11 | .eggs/ 12 | 13 | # PyPI distribution artifacts. 
14 | build/ 15 | dist/ 16 | 17 | # Tests 18 | .pytest_cache/ 19 | 20 | # Other 21 | *.DS_Store 22 | 23 | .vscode/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 TSAIL group 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /codebases/edm/README.md: -------------------------------------------------------------------------------- 1 | # DPM-Solver-v3 (EDM) 2 | ## Preparation 3 | 4 | To generate samples: 5 | 6 | - Download the pretrained models 7 | 8 | ```shell 9 | mkdir -p pretrained 10 | wget -O pretrained/edm-cifar10-32x32-uncond-vp.pkl https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-uncond-vp.pkl 11 | ``` 12 | 13 | - Download the folder `edm-cifar10-32x32-uncond-vp` from https://drive.google.com/drive/folders/1sWq-htX9c3Xdajmo1BG-QvkbaeVtJqaq and put it under the folder `statistics/`. 14 | 15 | - Install the packages 16 | 17 | ```shell 18 | pip install absl-py 19 | pip install torch==1.13.1+cu116 torchvision==0.14.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116 20 | ``` 21 | 22 | To compute FIDs: 23 | 24 | - Download `cifar10_stats.npz` from https://drive.google.com/drive/folders/1bofxWSwcoVGRqsUnAGUbco1z5lwP0Rb6 and put it under the folder `assets/stats/` 25 | 26 | - Install the packages 27 | 28 | ```shell 29 | pip install tqdm tensorflow==2.11.0 tensorflow_probability==0.19.0 tensorflow-gan 30 | ``` 31 | 32 | ## Generate Samples 33 | 34 | Run `bash sample.sh`, and the samples of different samplers under different numbers of steps will be generated under the folder `samples/edm-cifar10-32x32-uncond-vp/`. You can modify the script as you wish. 35 | 36 | ## Compute FIDs 37 | 38 | Run `python compute_fid.py`, and the FIDs of the generated samples will be computed and stored to `output.txt`. You can modify the script as you wish. 
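Note that `compute_fid.py` (shown below) expects each sample folder to contain `.npz` files with a `samples` array and asserts that at least 50k images are present in total. The following sketch, which is not part of the repository, counts the samples in one output folder before computing FIDs; the folder name is only an example of what `sample.sh` produces, and the `.npz` layout is assumed to match the script below.

```python
# Sanity-check sketch (assumes the .npz layout used by compute_fid.py below).
import os
import numpy as np

path = "samples/edm-cifar10-32x32-uncond-vp/dpm_solver_v3_5"  # example folder
total = 0
for file in os.listdir(path):
    if file.endswith(".npz"):
        # Each .npz stores the generated images under the "samples" key.
        total += np.load(os.path.join(path, file))["samples"].shape[0]
print(f"{total} samples found (compute_fid.py requires at least 50000)")
```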
39 | 40 | -------------------------------------------------------------------------------- /codebases/edm/compute_fid.py: -------------------------------------------------------------------------------- 1 | import tensorflow_gan as tfgan 2 | import tensorflow as tf 3 | import numpy as np 4 | import os 5 | 6 | from evaluation import * 7 | import gc 8 | from tqdm import tqdm 9 | 10 | 11 | inception_model = get_inception_model(inceptionv3=False) 12 | BATCH_SIZE = 1000 13 | 14 | 15 | def load_cifar10_stats(): 16 | """Load the pre-computed dataset statistics.""" 17 | filename = "assets/stats/cifar10_stats.npz" 18 | 19 | with tf.io.gfile.GFile(filename, "rb") as fin: 20 | stats = np.load(fin) 21 | return stats 22 | 23 | 24 | def compute_fid(path): 25 | images = [] 26 | for file in os.listdir(path): 27 | if file.endswith(".npz"): 28 | with tf.io.gfile.GFile(os.path.join(path, file), "rb") as fin: 29 | sample = np.load(fin) 30 | images.append(sample["samples"]) 31 | samples = np.concatenate(images, axis=0) 32 | all_pools = [] 33 | N = samples.shape[0] 34 | assert N >= 50000, "At least 50k samples are required to compute FID." 35 | for i in tqdm(range(N // BATCH_SIZE)): 36 | gc.collect() 37 | latents = run_inception_distributed( 38 | samples[i * BATCH_SIZE : (i + 1) * BATCH_SIZE, ...], inception_model, inceptionv3=False 39 | ) 40 | gc.collect() 41 | all_pools.append(latents["pool_3"]) 42 | all_pools = np.concatenate(all_pools, axis=0)[:50000, ...] 43 | data_stats = load_cifar10_stats() 44 | data_pools = data_stats["pool_3"] 45 | 46 | fid = tfgan.eval.frechet_classifier_distance_from_activations(data_pools, all_pools) 47 | return fid 48 | 49 | 50 | for name in ["dpm_solver++", "heun", "uni_pc_bh1", "uni_pc_bh2", "dpm_solver_v3"]: 51 | fids = [] 52 | for step in [5, 6, 8, 10, 12, 15, 20, 25]: 53 | path = f"samples/edm-cifar10-32x32-uncond-vp/{name}_{step}" 54 | fid = compute_fid(path) 55 | fids.append(float(fid)) 56 | print(name, fids, file=open("output.txt", "a")) 57 | -------------------------------------------------------------------------------- /codebases/edm/dnnlib/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | from .util import EasyDict, make_cache_dir_path 9 | -------------------------------------------------------------------------------- /codebases/edm/evaluation.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Utility functions for computing FID/Inception scores.""" 17 | 18 | import numpy as np 19 | import six 20 | import tensorflow as tf 21 | import tensorflow_gan as tfgan 22 | import tensorflow_hub as tfhub 23 | import torch 24 | 25 | INCEPTION_TFHUB = "https://tfhub.dev/tensorflow/tfgan/eval/inception/1" 26 | INCEPTION_OUTPUT = "logits" 27 | INCEPTION_FINAL_POOL = "pool_3" 28 | _DEFAULT_DTYPES = {INCEPTION_OUTPUT: tf.float32, INCEPTION_FINAL_POOL: tf.float32} 29 | INCEPTION_DEFAULT_IMAGE_SIZE = 299 30 | 31 | 32 | def get_inception_model(inceptionv3=False): 33 | if inceptionv3: 34 | return tfhub.load("https://tfhub.dev/google/imagenet/inception_v3/feature_vector/4") 35 | else: 36 | return tfhub.load(INCEPTION_TFHUB) 37 | 38 | 39 | def load_dataset_stats(config): 40 | """Load the pre-computed dataset statistics.""" 41 | if config.data.dataset == "CIFAR10": 42 | filename = "assets/stats/cifar10_stats.npz" 43 | elif config.data.dataset == "CELEBA": 44 | filename = "assets/stats/celeba_stats.npz" 45 | elif config.data.dataset == "LSUN": 46 | filename = f"assets/stats/lsun_{config.data.category}_{config.data.image_size}_stats.npz" 47 | else: 48 | raise ValueError(f"Dataset {config.data.dataset} stats not found.") 49 | 50 | with tf.io.gfile.GFile(filename, "rb") as fin: 51 | stats = np.load(fin) 52 | return stats 53 | 54 | 55 | def classifier_fn_from_tfhub(output_fields, inception_model, return_tensor=False): 56 | """Returns a function that can be as a classifier function. 57 | 58 | Copied from tfgan but avoid loading the model each time calling _classifier_fn 59 | 60 | Args: 61 | output_fields: A string, list, or `None`. If present, assume the module 62 | outputs a dictionary, and select this field. 63 | inception_model: A model loaded from TFHub. 64 | return_tensor: If `True`, return a single tensor instead of a dictionary. 65 | 66 | Returns: 67 | A one-argument function that takes an image Tensor and returns outputs. 68 | """ 69 | if isinstance(output_fields, six.string_types): 70 | output_fields = [output_fields] 71 | 72 | def _classifier_fn(images): 73 | output = inception_model(images) 74 | if output_fields is not None: 75 | output = {x: output[x] for x in output_fields} 76 | if return_tensor: 77 | assert len(output) == 1 78 | output = list(output.values())[0] 79 | return tf.nest.map_structure(tf.compat.v1.layers.flatten, output) 80 | 81 | return _classifier_fn 82 | 83 | 84 | @tf.function 85 | def run_inception_jit(inputs, inception_model, num_batches=1, inceptionv3=False): 86 | """Running the inception network. Assuming input is within [0, 255].""" 87 | if not inceptionv3: 88 | inputs = (tf.cast(inputs, tf.float32) - 127.5) / 127.5 89 | else: 90 | inputs = tf.cast(inputs, tf.float32) / 255.0 91 | 92 | return tfgan.eval.run_classifier_fn( 93 | inputs, 94 | num_batches=num_batches, 95 | classifier_fn=classifier_fn_from_tfhub(None, inception_model), 96 | dtypes=_DEFAULT_DTYPES, 97 | ) 98 | 99 | 100 | @tf.function 101 | def run_inception_distributed(input_tensor, inception_model, num_batches=1, inceptionv3=False): 102 | """Distribute the inception network computation to all available TPUs. 103 | 104 | Args: 105 | input_tensor: The input images. Assumed to be within [0, 255]. 106 | inception_model: The inception network model obtained from `tfhub`. 107 | num_batches: The number of batches used for dividing the input. 108 | inceptionv3: If `True`, use InceptionV3, otherwise use InceptionV1. 
109 | 110 | Returns: 111 | A dictionary with key `pool_3` and `logits`, representing the pool_3 and 112 | logits of the inception network respectively. 113 | """ 114 | num_gpus = torch.cuda.device_count() 115 | # num_gpus = jax.local_device_count() 116 | input_tensors = tf.split(input_tensor, num_gpus, axis=0) 117 | pool3 = [] 118 | logits = [] if not inceptionv3 else None 119 | device_format = "/GPU:{}" 120 | for i, tensor in enumerate(input_tensors): 121 | with tf.device(device_format.format(i)): 122 | tensor_on_device = tf.identity(tensor) 123 | res = run_inception_jit(tensor_on_device, inception_model, num_batches=num_batches, inceptionv3=inceptionv3) 124 | 125 | if not inceptionv3: 126 | pool3.append(res["pool_3"]) 127 | logits.append(res["logits"]) # pytype: disable=attribute-error 128 | else: 129 | pool3.append(res) 130 | 131 | with tf.device("/CPU"): 132 | return {"pool_3": tf.concat(pool3, axis=0), "logits": tf.concat(logits, axis=0) if not inceptionv3 else None} 133 | -------------------------------------------------------------------------------- /codebases/edm/sample.sh: -------------------------------------------------------------------------------- 1 | CKPT_PATH="pretrained/edm-cifar10-32x32-uncond-vp.pkl" 2 | 3 | for steps in 5 6 8 10 12 15 20 25; do 4 | 5 | if [ $steps -lt 10 ]; then 6 | STATS_DIR="statistics/edm-cifar10-32x32-uncond-vp/0.002_80.0_1200_1024" 7 | else 8 | STATS_DIR="statistics/edm-cifar10-32x32-uncond-vp/0.002_80.0_120_4096" 9 | fi 10 | 11 | python sample.py --sample_folder="heun_"$steps --ckp_path=$CKPT_PATH --method=heun --steps=$steps --skip_type=edm 12 | 13 | python sample.py --sample_folder="dpm_solver++_"$steps --ckp_path=$CKPT_PATH --method=dpm_solver++ --steps=$steps --skip_type=logSNR 14 | 15 | python sample.py --sample_folder="uni_pc_bh1_"$steps --unipc_variant=bh1 --ckp_path=$CKPT_PATH --method=uni_pc --steps=$steps --skip_type=logSNR 16 | 17 | python sample.py --sample_folder="uni_pc_bh2_"$steps --unipc_variant=bh2 --ckp_path=$CKPT_PATH --method=uni_pc --steps=$steps --skip_type=logSNR 18 | 19 | python sample.py --sample_folder="dpm_solver_v3_"$steps --statistics_dir=$STATS_DIR --ckp_path=$CKPT_PATH --method=dpm_solver_v3 --steps=$steps --skip_type=logSNR 20 | 21 | done -------------------------------------------------------------------------------- /codebases/edm/samplers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/DPM-Solver-v3/b615d707e9df512d309c4e59ce228283a608eb8b/codebases/edm/samplers/__init__.py -------------------------------------------------------------------------------- /codebases/edm/samplers/heun.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import math 4 | from .utils import expand_dims 5 | import numpy as np 6 | 7 | 8 | class Heun: 9 | def __init__(self, noise_schedule): 10 | self.noise_schedule = noise_schedule 11 | 12 | def model_fn(self, x, t): 13 | """ 14 | Return the noise prediction model. 15 | """ 16 | return self.model(x, t) 17 | 18 | def get_time_steps(self, skip_type, t_T, t_0, N, device): 19 | """Compute the intermediate time steps for sampling. 20 | 21 | Args: 22 | skip_type: A `str`. The type for the spacing of the time steps. We support three types: 23 | - 'logSNR': uniform logSNR for the time steps. 24 | - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.) 
25 | - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.) 26 | t_T: A `float`. The starting time of the sampling (default is T). 27 | t_0: A `float`. The ending time of the sampling (default is epsilon). 28 | N: A `int`. The total number of the spacing of the time steps. 29 | device: A torch device. 30 | Returns: 31 | A pytorch tensor of the time steps, with the shape (N + 1,). 32 | """ 33 | if skip_type == "logSNR": 34 | lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device)) 35 | lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device)) 36 | logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device) 37 | return self.noise_schedule.inverse_lambda(logSNR_steps) 38 | elif skip_type == "time_uniform": 39 | return torch.linspace(t_T, t_0, N + 1).to(device) 40 | elif skip_type == "time_quadratic": 41 | t_order = 2 42 | t = torch.linspace(t_T ** (1.0 / t_order), t_0 ** (1.0 / t_order), N + 1).pow(t_order).to(device) 43 | return t 44 | elif skip_type == "edm": 45 | rho = 7.0 # 7.0 is the value used in the paper 46 | 47 | sigma_min: float = t_0 48 | sigma_max: float = t_T 49 | ramp = np.linspace(0, 1, N + 1) 50 | min_inv_rho = sigma_min ** (1 / rho) 51 | max_inv_rho = sigma_max ** (1 / rho) 52 | sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho 53 | lambdas = torch.Tensor(-np.log(sigmas)).to(device) 54 | t = self.noise_schedule.inverse_lambda(lambdas) 55 | return t 56 | else: 57 | raise ValueError( 58 | "Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type) 59 | ) 60 | 61 | def sample( 62 | self, 63 | model_fn, 64 | x, 65 | steps=20, 66 | t_start=None, 67 | t_end=None, 68 | skip_type="time_uniform", 69 | ): 70 | self.model = lambda x, t: model_fn(x, t.expand((x.shape[0]))) 71 | t_0 = t_end 72 | t_T = t_start 73 | assert ( 74 | t_0 > 0 and t_T > 0 75 | ), "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array" 76 | device = x.device 77 | denoise_to_zero = (steps % 2) == 1 78 | steps //= 2 79 | with torch.no_grad(): 80 | timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device) 81 | assert timesteps.shape[0] - 1 == steps 82 | x_next = x 83 | for step in range(steps): 84 | t_cur, t_next = timesteps[step], timesteps[step + 1] 85 | x_cur = x_next 86 | 87 | # Euler step. 88 | d_cur = self.model_fn(x_cur, t_cur) 89 | x_next = x_cur + (t_next - t_cur) * d_cur 90 | 91 | # Apply 2nd order correction. 92 | d_prime = self.model_fn(x_next, t_next) 93 | x_next = x_cur + (t_next - t_cur) * (0.5 * d_cur + 0.5 * d_prime) 94 | # print((t_cur, t_next)) 95 | if denoise_to_zero: 96 | t_cur = timesteps[-1] 97 | x_cur = x_next 98 | 99 | # Euler step. 100 | d_cur = self.model_fn(x_cur, t_cur) 101 | x_next = x_cur + (0 - t_cur) * d_cur 102 | return x_next 103 | -------------------------------------------------------------------------------- /codebases/edm/samplers/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | 4 | 5 | class NoiseScheduleEDM: 6 | def marginal_log_mean_coeff(self, t): 7 | """ 8 | Compute log(alpha_t) of a given continuous-time label t in [0, T]. 9 | """ 10 | return torch.zeros_like(t).to(torch.float64) 11 | 12 | def marginal_alpha(self, t): 13 | """ 14 | Compute alpha_t of a given continuous-time label t in [0, T]. 
15 | """ 16 | return torch.ones_like(t).to(torch.float64) 17 | 18 | def marginal_std(self, t): 19 | """ 20 | Compute sigma_t of a given continuous-time label t in [0, T]. 21 | """ 22 | return t.to(torch.float64) 23 | 24 | def marginal_lambda(self, t): 25 | """ 26 | Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T]. 27 | """ 28 | 29 | return -torch.log(t).to(torch.float64) 30 | 31 | def inverse_lambda(self, lamb): 32 | """ 33 | Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t. 34 | """ 35 | return torch.exp(-lamb).to(torch.float64) 36 | 37 | 38 | def model_wrapper(model, noise_schedule, class_labels=None): 39 | def noise_pred_fn(x, t_continuous, cond=None): 40 | t_input = t_continuous 41 | output = model(x, t_input, cond) 42 | alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous) 43 | return (x - alpha_t[:, None, None, None] * output) / sigma_t[:, None, None, None] 44 | 45 | def model_fn(x, t_continuous): 46 | return noise_pred_fn(x, t_continuous, class_labels).to(torch.float64) 47 | 48 | return model_fn 49 | 50 | 51 | def expand_dims(v, dims): 52 | """ 53 | Expand the tensor `v` to the dim `dims`. 54 | 55 | Args: 56 | `v`: a PyTorch tensor with shape [N]. 57 | `dim`: a `int`. 58 | Returns: 59 | a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`. 60 | """ 61 | return v[(...,) + (None,) * (dims - 1)] 62 | -------------------------------------------------------------------------------- /codebases/edm/torch_utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | # empty 9 | -------------------------------------------------------------------------------- /codebases/edm/torch_utils/distributed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | import os 9 | import torch 10 | from . 
import training_stats 11 | 12 | #---------------------------------------------------------------------------- 13 | 14 | def init(): 15 | if 'MASTER_ADDR' not in os.environ: 16 | os.environ['MASTER_ADDR'] = 'localhost' 17 | if 'MASTER_PORT' not in os.environ: 18 | os.environ['MASTER_PORT'] = '29500' 19 | if 'RANK' not in os.environ: 20 | os.environ['RANK'] = '0' 21 | if 'LOCAL_RANK' not in os.environ: 22 | os.environ['LOCAL_RANK'] = '0' 23 | if 'WORLD_SIZE' not in os.environ: 24 | os.environ['WORLD_SIZE'] = '1' 25 | 26 | backend = 'gloo' if os.name == 'nt' else 'nccl' 27 | torch.distributed.init_process_group(backend=backend, init_method='env://') 28 | torch.cuda.set_device(int(os.environ.get('LOCAL_RANK', '0'))) 29 | 30 | sync_device = torch.device('cuda') if get_world_size() > 1 else None 31 | training_stats.init_multiprocessing(rank=get_rank(), sync_device=sync_device) 32 | 33 | #---------------------------------------------------------------------------- 34 | 35 | def get_rank(): 36 | return torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 37 | 38 | #---------------------------------------------------------------------------- 39 | 40 | def get_world_size(): 41 | return torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1 42 | 43 | #---------------------------------------------------------------------------- 44 | 45 | def should_stop(): 46 | return False 47 | 48 | #---------------------------------------------------------------------------- 49 | 50 | def update_progress(cur, total): 51 | _ = cur, total 52 | 53 | #---------------------------------------------------------------------------- 54 | 55 | def print0(*args, **kwargs): 56 | if get_rank() == 0: 57 | print(*args, **kwargs) 58 | 59 | #---------------------------------------------------------------------------- 60 | -------------------------------------------------------------------------------- /codebases/guided-diffusion/README.md: -------------------------------------------------------------------------------- 1 | # DPM-Solver-v3 (Guided-Diffusion) 2 | 3 | ## Preparation 4 | 5 | - Download the pretrained models 6 | 7 | ```shell 8 | mkdir -p ddpm_ckpt/imagenet256 9 | wget -O ddpm_ckpt/imagenet256/256x256_diffusion.pt https://openaipublic.blob.core.windows.net/diffusion/jul-2021/256x256_diffusion.pt 10 | wget -O ddpm_ckpt/imagenet256/256x256_classifier.pt https://openaipublic.blob.core.windows.net/diffusion/jul-2021/256x256_classifier.pt 11 | ``` 12 | 13 | - Download the folder `imagenet256_guided` from https://drive.google.com/drive/folders/1sWq-htX9c3Xdajmo1BG-QvkbaeVtJqaq and put it under the folder `statistics/`. 14 | 15 | - Download the stats file for computing FID 16 | 17 | ```shell 18 | mkdir -p fid_stats 19 | wget -O fid_stats/VIRTUAL_imagenet256_labeled.npz https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/VIRTUAL_imagenet256_labeled.npz 20 | ``` 21 | 22 | - Install the packages 23 | 24 | ```shell 25 | pip install PyYAML tqdm scipy pytorch_fid 26 | pip install torch==1.13.1+cu116 torchvision==0.14.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116 27 | ``` 28 | 29 | ## Generate Samples and Compute FIDs 30 | 31 | Run `bash sample.sh`, and the samples of different samplers under different numbers of steps will be generated under the folder `samples/256x256_diffusion/`. After the samples are generated, their FIDs will be computed and stored in `output.txt`. You can modify the script as you wish. 
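The run settings come from `configs/imagenet256_guided.yml` (shown below), while `sample.sh` passes run-specific options such as `--timesteps` and `--scale` on the command line. As a small sketch (not part of the repository), you can inspect the sampling-related fields with PyYAML, installed above; the `configs/` path assumes you run from this codebase's root, as in the file tree.

```python
# Inspection sketch (assumes the configs/ layout shown in the repository tree).
import yaml

with open("configs/imagenet256_guided.yml") as f:
    config = yaml.safe_load(f)

sampling = config["sampling"]
print("classifier guidance scale:", sampling["classifier_scale"])  # 2.5 in the config
print("FID reference statistics:", sampling["fid_stats_dir"])
print("samples used for FID:", sampling["fid_total_samples"])
```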
32 | -------------------------------------------------------------------------------- /codebases/guided-diffusion/configs/imagenet256_guided.yml: -------------------------------------------------------------------------------- 1 | data: 2 | dataset: "IMAGENET256" 3 | image_size: 256 4 | channels: 3 5 | logit_transform: false 6 | uniform_dequantization: false 7 | gaussian_dequantization: false 8 | random_flip: true 9 | rescaled: true 10 | num_workers: 32 11 | num_classes: 1000 12 | 13 | model: 14 | model_type: "guided_diffusion" 15 | is_upsampling: false 16 | image_size: 256 17 | in_channels: 3 18 | model_channels: 256 19 | out_channels: 6 20 | num_res_blocks: 2 21 | attention_resolutions: [8, 16, 32] # [256 // 32, 256 // 16, 256 // 8] 22 | dropout: 0.0 23 | channel_mult: [1, 1, 2, 2, 4, 4] 24 | conv_resample: true 25 | dims: 2 26 | num_classes: 1000 27 | use_checkpoint: false 28 | use_fp16: true 29 | num_heads: 4 30 | num_head_channels: 64 31 | num_heads_upsample: -1 32 | use_scale_shift_norm: true 33 | resblock_updown: true 34 | use_new_attention_order: false 35 | var_type: fixedlarge 36 | ema: false 37 | ckpt_dir: "ddpm_ckpt/imagenet256/256x256_diffusion.pt" 38 | 39 | classifier: 40 | ckpt_dir: "ddpm_ckpt/imagenet256/256x256_classifier.pt" 41 | image_size: 256 42 | in_channels: 3 43 | model_channels: 128 44 | out_channels: 1000 45 | num_res_blocks: 2 46 | attention_resolutions: [8, 16, 32] # [256 // 32, 256 // 16, 256 // 8] 47 | channel_mult: [1, 1, 2, 2, 4, 4] 48 | use_fp16: true 49 | num_head_channels: 64 50 | use_scale_shift_norm: true 51 | resblock_updown: true 52 | pool: "attention" 53 | 54 | diffusion: 55 | beta_schedule: linear 56 | beta_start: 0.0001 57 | beta_end: 0.02 58 | num_diffusion_timesteps: 1000 59 | 60 | sampling: 61 | total_N: 1000 62 | batch_size: 25 63 | last_only: True 64 | fid_stats_dir: "fid_stats/VIRTUAL_imagenet256_labeled.npz" 65 | fid_total_samples: 10000 66 | fid_batch_size: 200 67 | cond_class: true 68 | classifier_scale: 2.5 69 | -------------------------------------------------------------------------------- /codebases/guided-diffusion/models/guided_diffusion/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/DPM-Solver-v3/b615d707e9df512d309c4e59ce228283a608eb8b/codebases/guided-diffusion/models/guided_diffusion/__init__.py -------------------------------------------------------------------------------- /codebases/guided-diffusion/models/guided_diffusion/nn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Various utilities for neural networks. 3 | """ 4 | 5 | import math 6 | 7 | import torch as th 8 | import torch.nn as nn 9 | 10 | 11 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 12 | class SiLU(nn.Module): 13 | def forward(self, x): 14 | return x * th.sigmoid(x) 15 | 16 | 17 | class GroupNorm32(nn.GroupNorm): 18 | def forward(self, x): 19 | return super().forward(x.float()).type(x.dtype) 20 | 21 | 22 | def conv_nd(dims, *args, **kwargs): 23 | """ 24 | Create a 1D, 2D, or 3D convolution module. 25 | """ 26 | if dims == 1: 27 | return nn.Conv1d(*args, **kwargs) 28 | elif dims == 2: 29 | return nn.Conv2d(*args, **kwargs) 30 | elif dims == 3: 31 | return nn.Conv3d(*args, **kwargs) 32 | raise ValueError(f"unsupported dimensions: {dims}") 33 | 34 | 35 | def linear(*args, **kwargs): 36 | """ 37 | Create a linear module. 
38 | """ 39 | return nn.Linear(*args, **kwargs) 40 | 41 | 42 | def avg_pool_nd(dims, *args, **kwargs): 43 | """ 44 | Create a 1D, 2D, or 3D average pooling module. 45 | """ 46 | if dims == 1: 47 | return nn.AvgPool1d(*args, **kwargs) 48 | elif dims == 2: 49 | return nn.AvgPool2d(*args, **kwargs) 50 | elif dims == 3: 51 | return nn.AvgPool3d(*args, **kwargs) 52 | raise ValueError(f"unsupported dimensions: {dims}") 53 | 54 | 55 | def update_ema(target_params, source_params, rate=0.99): 56 | """ 57 | Update target parameters to be closer to those of source parameters using 58 | an exponential moving average. 59 | 60 | :param target_params: the target parameter sequence. 61 | :param source_params: the source parameter sequence. 62 | :param rate: the EMA rate (closer to 1 means slower). 63 | """ 64 | for targ, src in zip(target_params, source_params): 65 | targ.detach().mul_(rate).add_(src, alpha=1 - rate) 66 | 67 | 68 | def zero_module(module): 69 | """ 70 | Zero out the parameters of a module and return it. 71 | """ 72 | for p in module.parameters(): 73 | p.detach().zero_() 74 | return module 75 | 76 | 77 | def scale_module(module, scale): 78 | """ 79 | Scale the parameters of a module and return it. 80 | """ 81 | for p in module.parameters(): 82 | p.detach().mul_(scale) 83 | return module 84 | 85 | 86 | def mean_flat(tensor): 87 | """ 88 | Take the mean over all non-batch dimensions. 89 | """ 90 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 91 | 92 | 93 | def normalization(channels): 94 | """ 95 | Make a standard normalization layer. 96 | 97 | :param channels: number of input channels. 98 | :return: an nn.Module for normalization. 99 | """ 100 | return GroupNorm32(32, channels) 101 | 102 | 103 | def timestep_embedding(timesteps, dim, max_period=10000): 104 | """ 105 | Create sinusoidal timestep embeddings. 106 | 107 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 108 | These may be fractional. 109 | :param dim: the dimension of the output. 110 | :param max_period: controls the minimum frequency of the embeddings. 111 | :return: an [N x dim] Tensor of positional embeddings. 112 | """ 113 | half = dim // 2 114 | freqs = th.exp( 115 | -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half 116 | ).to(device=timesteps.device) 117 | args = timesteps[:, None].float() * freqs[None] 118 | embedding = th.cat([th.cos(args), th.sin(args)], dim=-1) 119 | if dim % 2: 120 | embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1) 121 | return embedding 122 | 123 | 124 | def checkpoint(func, inputs, params, flag): 125 | """ 126 | Evaluate a function without caching intermediate activations, allowing for 127 | reduced memory at the expense of extra compute in the backward pass. 128 | 129 | :param func: the function to evaluate. 130 | :param inputs: the argument sequence to pass to `func`. 131 | :param params: a sequence of parameters `func` depends on but does not 132 | explicitly take as arguments. 133 | :param flag: if False, disable gradient checkpointing. 
134 | """ 135 | if flag: 136 | args = tuple(inputs) + tuple(params) 137 | return CheckpointFunction.apply(func, len(inputs), *args) 138 | else: 139 | return func(*inputs) 140 | 141 | 142 | class CheckpointFunction(th.autograd.Function): 143 | @staticmethod 144 | def forward(ctx, run_function, length, *args): 145 | ctx.run_function = run_function 146 | ctx.input_tensors = list(args[:length]) 147 | ctx.input_params = list(args[length:]) 148 | with th.no_grad(): 149 | output_tensors = ctx.run_function(*ctx.input_tensors) 150 | return output_tensors 151 | 152 | @staticmethod 153 | def backward(ctx, *output_grads): 154 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 155 | with th.enable_grad(): 156 | # Fixes a bug where the first op in run_function modifies the 157 | # Tensor storage in place, which is not allowed for detach()'d 158 | # Tensors. 159 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 160 | output_tensors = ctx.run_function(*shallow_copies) 161 | input_grads = th.autograd.grad( 162 | output_tensors, 163 | ctx.input_tensors + ctx.input_params, 164 | output_grads, 165 | allow_unused=True, 166 | ) 167 | del ctx.input_tensors 168 | del ctx.input_params 169 | del output_tensors 170 | return (None, None) + input_grads -------------------------------------------------------------------------------- /codebases/guided-diffusion/runners/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/DPM-Solver-v3/b615d707e9df512d309c4e59ce228283a608eb8b/codebases/guided-diffusion/runners/__init__.py -------------------------------------------------------------------------------- /codebases/guided-diffusion/sample.sh: -------------------------------------------------------------------------------- 1 | CONFIG="imagenet256_guided.yml" 2 | scale="2.0" 3 | STATS_DIR="statistics/imagenet256_guided/500_1024" 4 | 5 | for steps in 5 6 8 10 12 15 20; do 6 | for sampleMethod in 'dpmsolver++' 'unipc' 'dpmsolver_v3'; do 7 | 8 | python sample.py --config=$CONFIG --exp=$sampleMethod"_"$steps"_scale"$scale --statistics_dir=$STATS_DIR --timesteps=$steps --sample_type=$sampleMethod --scale=$scale --lower_order_final 9 | 10 | done 11 | done -------------------------------------------------------------------------------- /codebases/guided-diffusion/samplers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/DPM-Solver-v3/b615d707e9df512d309c4e59ce228283a608eb8b/codebases/guided-diffusion/samplers/__init__.py -------------------------------------------------------------------------------- /codebases/score_sde/README.md: -------------------------------------------------------------------------------- 1 | # DPM-Solver-v3 (ScoreSDE) 2 | 3 | ## Preparation 4 | 5 | To generate samples: 6 | 7 | - Download `checkpoint_8.pth` from https://drive.google.com/drive/folders/1F74y6G_AGqPw8DG5uhdO_Kf9DCX1jKfL and put it under the folder `checkpoints/cifar10_ddpmpp_deep_continuous/`. 8 | 9 | - Download the folder `cifar10_ddpmpp_deep_continuous` from https://drive.google.com/drive/folders/1sWq-htX9c3Xdajmo1BG-QvkbaeVtJqaq and put it under the folder `statistics/`. 
10 | 11 | - Install the packages 12 | 13 | ```shell 14 | pip install absl-py ml-collections scipy 15 | pip install torch==1.13.1+cu116 torchvision==0.14.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116 16 | ``` 17 | 18 | To compute FIDs: 19 | 20 | - Download `cifar10_stats.npz` from https://drive.google.com/drive/folders/1bofxWSwcoVGRqsUnAGUbco1z5lwP0Rb6 and put it under the folder `assets/stats/` 21 | 22 | - Install the packages 23 | 24 | ```shell 25 | pip install tqdm tensorflow==2.11.0 tensorflow_probability==0.19.0 tensorflow-gan 26 | ``` 27 | 28 | ## Generate Samples 29 | 30 | Run `bash sample.sh`, and the samples of different samplers under different numbers of steps will be generated under the folder `samples/checkpoint_8/`. You can modify the script as you wish. 31 | 32 | ## Compute FIDs 33 | 34 | Run `python compute_fid.py`, and the FIDs of the generated samples will be computed and stored to `output.txt`. You can modify the script as you wish. 35 | -------------------------------------------------------------------------------- /codebases/score_sde/compute_fid.py: -------------------------------------------------------------------------------- 1 | import tensorflow_gan as tfgan 2 | import tensorflow as tf 3 | import numpy as np 4 | import os 5 | 6 | from evaluation import * 7 | import gc 8 | from tqdm import tqdm 9 | 10 | 11 | inception_model = get_inception_model(inceptionv3=False) 12 | BATCH_SIZE = 1000 13 | 14 | 15 | def load_cifar10_stats(): 16 | """Load the pre-computed dataset statistics.""" 17 | filename = "assets/stats/cifar10_stats.npz" 18 | 19 | with tf.io.gfile.GFile(filename, "rb") as fin: 20 | stats = np.load(fin) 21 | return stats 22 | 23 | 24 | def compute_fid(path): 25 | images = [] 26 | for file in os.listdir(path): 27 | if file.endswith(".npz"): 28 | with tf.io.gfile.GFile(os.path.join(path, file), "rb") as fin: 29 | sample = np.load(fin) 30 | images.append(sample["samples"]) 31 | samples = np.concatenate(images, axis=0) 32 | all_pools = [] 33 | N = samples.shape[0] 34 | assert N >= 50000, "At least 50k samples are required to compute FID." 35 | for i in tqdm(range(N // BATCH_SIZE)): 36 | gc.collect() 37 | latents = run_inception_distributed( 38 | samples[i * BATCH_SIZE : (i + 1) * BATCH_SIZE, ...], inception_model, inceptionv3=False 39 | ) 40 | gc.collect() 41 | all_pools.append(latents["pool_3"]) 42 | all_pools = np.concatenate(all_pools, axis=0)[:50000, ...] 
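# Compare the 50k activations collected above with the precomputed CIFAR-10
# reference statistics (assets/stats/cifar10_stats.npz) to obtain the FID.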
43 | data_stats = load_cifar10_stats() 44 | data_pools = data_stats["pool_3"] 45 | 46 | fid = tfgan.eval.frechet_classifier_distance_from_activations(data_pools, all_pools) 47 | return fid 48 | 49 | for name in ["DPM-Solver++", "UniPC_bh1", "UniPC_bh2", "DPM-Solver-v3"]: 50 | fids = [] 51 | for step in [5, 6, 8, 10, 12, 15, 20, 25]: 52 | path = f"samples/checkpoint_8/{name}_{step}" 53 | fid = compute_fid(path) 54 | fids.append(float(fid)) 55 | print(name, fids, file=open("output.txt", "a")) 56 | -------------------------------------------------------------------------------- /codebases/score_sde/configs/default_cifar10_configs.py: -------------------------------------------------------------------------------- 1 | import ml_collections 2 | import torch 3 | 4 | 5 | def get_default_configs(): 6 | config = ml_collections.ConfigDict() 7 | # training 8 | config.training = training = ml_collections.ConfigDict() 9 | config.training.batch_size = 128 10 | training.n_iters = 1300001 11 | training.snapshot_freq = 50000 12 | training.log_freq = 50 13 | training.eval_freq = 100 14 | ## store additional checkpoints for preemption in cloud computing environments 15 | training.snapshot_freq_for_preemption = 10000 16 | ## produce samples at each snapshot. 17 | training.snapshot_sampling = True 18 | training.likelihood_weighting = False 19 | training.continuous = True 20 | training.reduce_mean = False 21 | 22 | # sampling 23 | config.sampling = sampling = ml_collections.ConfigDict() 24 | sampling.n_steps_each = 1 25 | sampling.noise_removal = True 26 | sampling.probability_flow = False 27 | sampling.snr = 0.16 28 | 29 | # evaluation 30 | config.eval = evaluate = ml_collections.ConfigDict() 31 | evaluate.begin_ckpt = 8 32 | evaluate.end_ckpt = 8 33 | evaluate.batch_size = 256 34 | evaluate.enable_sampling = True 35 | evaluate.num_samples = 50000 36 | evaluate.enable_loss = False 37 | evaluate.enable_bpd = False 38 | evaluate.bpd_dataset = "test" 39 | 40 | # data 41 | config.data = data = ml_collections.ConfigDict() 42 | data.dataset = "CIFAR10" 43 | data.image_size = 32 44 | data.random_flip = True 45 | data.centered = False 46 | data.uniform_dequantization = False 47 | data.num_channels = 3 48 | 49 | # model 50 | config.model = model = ml_collections.ConfigDict() 51 | model.sigma_min = 0.01 52 | model.sigma_max = 50 53 | model.num_scales = 1000 54 | model.beta_min = 0.1 55 | model.beta_max = 20.0 56 | model.dropout = 0.1 57 | model.embedding_type = "fourier" 58 | 59 | # optimization 60 | config.optim = optim = ml_collections.ConfigDict() 61 | optim.weight_decay = 0 62 | optim.optimizer = "Adam" 63 | optim.lr = 2e-4 64 | optim.beta1 = 0.9 65 | optim.eps = 1e-8 66 | optim.warmup = 5000 67 | optim.grad_clip = 1.0 68 | 69 | config.seed = 42 70 | config.device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") 71 | 72 | return config 73 | -------------------------------------------------------------------------------- /codebases/score_sde/configs/vp/cifar10_ddpmpp_deep_continuous.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python3 17 | """Training NCSNv3 on CIFAR-10 with continuous sigmas.""" 18 | 19 | from configs.default_cifar10_configs import get_default_configs 20 | 21 | 22 | def get_config(): 23 | config = get_default_configs() 24 | # training 25 | training = config.training 26 | training.sde = "vpsde" 27 | training.continuous = True 28 | training.reduce_mean = True 29 | training.n_iters = 950001 30 | 31 | # sampling 32 | sampling = config.sampling 33 | 34 | # sampling.method = 'pc' 35 | # sampling.predictor = 'euler_maruyama' 36 | # sampling.corrector = 'none' 37 | 38 | # sampling.method = 'ode' 39 | # sampling.eps = 1e-4 40 | # sampling.noise_removal = False 41 | # sampling.rk45_rtol = 1e-5 42 | # sampling.rk45_atol = 1e-5 43 | 44 | # sampling.method = "dpm_solver" 45 | sampling.dpm_solver_method = "multistep" 46 | sampling.dpm_solver_algorithm_type = "dpmsolver++" 47 | sampling.rtol = 0.05 48 | 49 | # sampling.method = "uni_pc" 50 | sampling.uni_pc_method = "multistep" 51 | sampling.uni_pc_algorithm_type = "data_prediction" 52 | sampling.variant = "bh1" 53 | 54 | # dpm_solver and uni_pc 55 | sampling.thresholding = False 56 | sampling.noise_removal = False 57 | 58 | sampling.method = "dpm_solver_v3" 59 | sampling.eps = 1e-3 60 | sampling.order = 3 61 | sampling.steps = 10 62 | sampling.skip_type = "logSNR" 63 | sampling.predictor_pseudo = False 64 | sampling.use_corrector = True 65 | sampling.corrector_pseudo = False 66 | sampling.lower_order_final = True 67 | sampling.degenerated = False 68 | 69 | # data 70 | data = config.data 71 | data.centered = True 72 | 73 | # model 74 | model = config.model 75 | model.name = "ncsnpp" 76 | model.scale_by_sigma = False 77 | model.ema_rate = 0.9999 78 | model.normalization = "GroupNorm" 79 | model.nonlinearity = "swish" 80 | model.nf = 128 81 | model.ch_mult = (1, 2, 2, 2) 82 | model.num_res_blocks = 8 83 | model.attn_resolutions = (16,) 84 | model.resamp_with_conv = True 85 | model.conditional = True 86 | model.fir = False 87 | model.fir_kernel = [1, 3, 3, 1] 88 | model.skip_rescale = True 89 | model.resblock_type = "biggan" 90 | model.progressive = "none" 91 | model.progressive_input = "none" 92 | model.progressive_combine = "sum" 93 | model.attention_type = "ddpm" 94 | model.init_scale = 0.0 95 | model.embedding_type = "positional" 96 | model.fourier_scale = 16 97 | model.conv_size = 3 98 | 99 | return config 100 | -------------------------------------------------------------------------------- /codebases/score_sde/evaluation.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Utility functions for computing FID/Inception scores.""" 17 | 18 | import numpy as np 19 | import six 20 | import tensorflow as tf 21 | import tensorflow_gan as tfgan 22 | import tensorflow_hub as tfhub 23 | import torch 24 | 25 | INCEPTION_TFHUB = "https://tfhub.dev/tensorflow/tfgan/eval/inception/1" 26 | INCEPTION_OUTPUT = "logits" 27 | INCEPTION_FINAL_POOL = "pool_3" 28 | _DEFAULT_DTYPES = {INCEPTION_OUTPUT: tf.float32, INCEPTION_FINAL_POOL: tf.float32} 29 | INCEPTION_DEFAULT_IMAGE_SIZE = 299 30 | 31 | 32 | def get_inception_model(inceptionv3=False): 33 | if inceptionv3: 34 | return tfhub.load("https://tfhub.dev/google/imagenet/inception_v3/feature_vector/4") 35 | else: 36 | return tfhub.load(INCEPTION_TFHUB) 37 | 38 | 39 | def load_dataset_stats(config): 40 | """Load the pre-computed dataset statistics.""" 41 | if config.data.dataset == "CIFAR10": 42 | filename = "assets/stats/cifar10_stats.npz" 43 | elif config.data.dataset == "CELEBA": 44 | filename = "assets/stats/celeba_stats.npz" 45 | elif config.data.dataset == "LSUN": 46 | filename = f"assets/stats/lsun_{config.data.category}_{config.data.image_size}_stats.npz" 47 | else: 48 | raise ValueError(f"Dataset {config.data.dataset} stats not found.") 49 | 50 | with tf.io.gfile.GFile(filename, "rb") as fin: 51 | stats = np.load(fin) 52 | return stats 53 | 54 | 55 | def classifier_fn_from_tfhub(output_fields, inception_model, return_tensor=False): 56 | """Returns a function that can be as a classifier function. 57 | 58 | Copied from tfgan but avoid loading the model each time calling _classifier_fn 59 | 60 | Args: 61 | output_fields: A string, list, or `None`. If present, assume the module 62 | outputs a dictionary, and select this field. 63 | inception_model: A model loaded from TFHub. 64 | return_tensor: If `True`, return a single tensor instead of a dictionary. 65 | 66 | Returns: 67 | A one-argument function that takes an image Tensor and returns outputs. 68 | """ 69 | if isinstance(output_fields, six.string_types): 70 | output_fields = [output_fields] 71 | 72 | def _classifier_fn(images): 73 | output = inception_model(images) 74 | if output_fields is not None: 75 | output = {x: output[x] for x in output_fields} 76 | if return_tensor: 77 | assert len(output) == 1 78 | output = list(output.values())[0] 79 | return tf.nest.map_structure(tf.compat.v1.layers.flatten, output) 80 | 81 | return _classifier_fn 82 | 83 | 84 | @tf.function 85 | def run_inception_jit(inputs, inception_model, num_batches=1, inceptionv3=False): 86 | """Running the inception network. 
Assuming input is within [0, 255].""" 87 | if not inceptionv3: 88 | inputs = (tf.cast(inputs, tf.float32) - 127.5) / 127.5 89 | else: 90 | inputs = tf.cast(inputs, tf.float32) / 255.0 91 | 92 | return tfgan.eval.run_classifier_fn( 93 | inputs, 94 | num_batches=num_batches, 95 | classifier_fn=classifier_fn_from_tfhub(None, inception_model), 96 | dtypes=_DEFAULT_DTYPES, 97 | ) 98 | 99 | 100 | @tf.function 101 | def run_inception_distributed(input_tensor, inception_model, num_batches=1, inceptionv3=False): 102 | """Distribute the inception network computation to all available TPUs. 103 | 104 | Args: 105 | input_tensor: The input images. Assumed to be within [0, 255]. 106 | inception_model: The inception network model obtained from `tfhub`. 107 | num_batches: The number of batches used for dividing the input. 108 | inceptionv3: If `True`, use InceptionV3, otherwise use InceptionV1. 109 | 110 | Returns: 111 | A dictionary with key `pool_3` and `logits`, representing the pool_3 and 112 | logits of the inception network respectively. 113 | """ 114 | num_gpus = torch.cuda.device_count() 115 | # num_gpus = jax.local_device_count() 116 | input_tensors = tf.split(input_tensor, num_gpus, axis=0) 117 | pool3 = [] 118 | logits = [] if not inceptionv3 else None 119 | device_format = "/GPU:{}" 120 | for i, tensor in enumerate(input_tensors): 121 | with tf.device(device_format.format(i)): 122 | tensor_on_device = tf.identity(tensor) 123 | res = run_inception_jit(tensor_on_device, inception_model, num_batches=num_batches, inceptionv3=inceptionv3) 124 | 125 | if not inceptionv3: 126 | pool3.append(res["pool_3"]) 127 | logits.append(res["logits"]) # pytype: disable=attribute-error 128 | else: 129 | pool3.append(res) 130 | 131 | with tf.device("/CPU"): 132 | return {"pool_3": tf.concat(pool3, axis=0), "logits": tf.concat(logits, axis=0) if not inceptionv3 else None} 133 | -------------------------------------------------------------------------------- /codebases/score_sde/models/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /codebases/score_sde/models/ema.py: -------------------------------------------------------------------------------- 1 | # Modified from https://raw.githubusercontent.com/fadel/pytorch_ema/master/torch_ema/ema.py 2 | 3 | from __future__ import division 4 | from __future__ import unicode_literals 5 | 6 | import torch 7 | 8 | 9 | # Partially based on: https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/python/training/moving_averages.py 10 | class ExponentialMovingAverage: 11 | """ 12 | Maintains (exponential) moving average of a set of parameters. 
13 | """ 14 | 15 | def __init__(self, parameters, decay, use_num_updates=True): 16 | """ 17 | Args: 18 | parameters: Iterable of `torch.nn.Parameter`; usually the result of 19 | `model.parameters()`. 20 | decay: The exponential decay. 21 | use_num_updates: Whether to use number of updates when computing 22 | averages. 23 | """ 24 | if decay < 0.0 or decay > 1.0: 25 | raise ValueError('Decay must be between 0 and 1') 26 | self.decay = decay 27 | self.num_updates = 0 if use_num_updates else None 28 | self.shadow_params = [p.clone().detach() 29 | for p in parameters if p.requires_grad] 30 | self.collected_params = [] 31 | 32 | def update(self, parameters): 33 | """ 34 | Update currently maintained parameters. 35 | 36 | Call this every time the parameters are updated, such as the result of 37 | the `optimizer.step()` call. 38 | 39 | Args: 40 | parameters: Iterable of `torch.nn.Parameter`; usually the same set of 41 | parameters used to initialize this object. 42 | """ 43 | decay = self.decay 44 | if self.num_updates is not None: 45 | self.num_updates += 1 46 | decay = min(decay, (1 + self.num_updates) / (10 + self.num_updates)) 47 | one_minus_decay = 1.0 - decay 48 | with torch.no_grad(): 49 | parameters = [p for p in parameters if p.requires_grad] 50 | for s_param, param in zip(self.shadow_params, parameters): 51 | s_param.sub_(one_minus_decay * (s_param - param)) 52 | 53 | def copy_to(self, parameters): 54 | """ 55 | Copy current parameters into given collection of parameters. 56 | 57 | Args: 58 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 59 | updated with the stored moving averages. 60 | """ 61 | parameters = [p for p in parameters if p.requires_grad] 62 | for s_param, param in zip(self.shadow_params, parameters): 63 | if param.requires_grad: 64 | param.data.copy_(s_param.data) 65 | 66 | def store(self, parameters): 67 | """ 68 | Save the current parameters for restoring later. 69 | 70 | Args: 71 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 72 | temporarily stored. 73 | """ 74 | self.collected_params = [param.clone() for param in parameters] 75 | 76 | def restore(self, parameters): 77 | """ 78 | Restore the parameters stored with the `store` method. 79 | Useful to validate the model with EMA parameters without affecting the 80 | original optimization process. Store the parameters before the 81 | `copy_to` method. After validation (or model saving), use this to 82 | restore the former parameters. 83 | 84 | Args: 85 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 86 | updated with the stored parameters. 
87 | """ 88 | for c_param, param in zip(self.collected_params, parameters): 89 | param.data.copy_(c_param.data) 90 | 91 | def state_dict(self): 92 | return dict(decay=self.decay, num_updates=self.num_updates, 93 | shadow_params=self.shadow_params) 94 | 95 | def load_state_dict(self, state_dict): 96 | self.decay = state_dict['decay'] 97 | self.num_updates = state_dict['num_updates'] 98 | self.shadow_params = state_dict['shadow_params'] -------------------------------------------------------------------------------- /codebases/score_sde/op/__init__.py: -------------------------------------------------------------------------------- 1 | from .fused_act import FusedLeakyReLU, fused_leaky_relu 2 | from .upfirdn2d import upfirdn2d 3 | -------------------------------------------------------------------------------- /codebases/score_sde/op/fused_act.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | from torch.autograd import Function 7 | from torch.utils.cpp_extension import load 8 | 9 | 10 | module_path = os.path.dirname(__file__) 11 | fused = load( 12 | "fused", 13 | sources=[ 14 | os.path.join(module_path, "fused_bias_act.cpp"), 15 | os.path.join(module_path, "fused_bias_act_kernel.cu"), 16 | ], 17 | ) 18 | 19 | 20 | class FusedLeakyReLUFunctionBackward(Function): 21 | @staticmethod 22 | def forward(ctx, grad_output, out, negative_slope, scale): 23 | ctx.save_for_backward(out) 24 | ctx.negative_slope = negative_slope 25 | ctx.scale = scale 26 | 27 | empty = grad_output.new_empty(0) 28 | 29 | grad_input = fused.fused_bias_act( 30 | grad_output, empty, out, 3, 1, negative_slope, scale 31 | ) 32 | 33 | dim = [0] 34 | 35 | if grad_input.ndim > 2: 36 | dim += list(range(2, grad_input.ndim)) 37 | 38 | grad_bias = grad_input.sum(dim).detach() 39 | 40 | return grad_input, grad_bias 41 | 42 | @staticmethod 43 | def backward(ctx, gradgrad_input, gradgrad_bias): 44 | out, = ctx.saved_tensors 45 | gradgrad_out = fused.fused_bias_act( 46 | gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale 47 | ) 48 | 49 | return gradgrad_out, None, None, None 50 | 51 | 52 | class FusedLeakyReLUFunction(Function): 53 | @staticmethod 54 | def forward(ctx, input, bias, negative_slope, scale): 55 | empty = input.new_empty(0) 56 | out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) 57 | ctx.save_for_backward(out) 58 | ctx.negative_slope = negative_slope 59 | ctx.scale = scale 60 | 61 | return out 62 | 63 | @staticmethod 64 | def backward(ctx, grad_output): 65 | out, = ctx.saved_tensors 66 | 67 | grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply( 68 | grad_output, out, ctx.negative_slope, ctx.scale 69 | ) 70 | 71 | return grad_input, grad_bias, None, None 72 | 73 | 74 | class FusedLeakyReLU(nn.Module): 75 | def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5): 76 | super().__init__() 77 | 78 | self.bias = nn.Parameter(torch.zeros(channel)) 79 | self.negative_slope = negative_slope 80 | self.scale = scale 81 | 82 | def forward(self, input): 83 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) 84 | 85 | 86 | def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5): 87 | if input.device.type == "cpu": 88 | rest_dim = [1] * (input.ndim - bias.ndim - 1) 89 | return ( 90 | F.leaky_relu( 91 | input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2 92 | ) 93 | * scale 94 | ) 95 | 96 
| else: 97 | return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale) 98 | -------------------------------------------------------------------------------- /codebases/score_sde/op/fused_bias_act.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | 4 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 5 | int act, int grad, float alpha, float scale); 6 | 7 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 8 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 9 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 10 | 11 | torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 12 | int act, int grad, float alpha, float scale) { 13 | CHECK_CUDA(input); 14 | CHECK_CUDA(bias); 15 | 16 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); 17 | } 18 | 19 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 20 | m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); 21 | } -------------------------------------------------------------------------------- /codebases/score_sde/op/fused_bias_act_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | #include 15 | #include 16 | 17 | 18 | template 19 | static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref, 20 | int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) { 21 | int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; 22 | 23 | scalar_t zero = 0.0; 24 | 25 | for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) { 26 | scalar_t x = p_x[xi]; 27 | 28 | if (use_bias) { 29 | x += p_b[(xi / step_b) % size_b]; 30 | } 31 | 32 | scalar_t ref = use_ref ? p_ref[xi] : zero; 33 | 34 | scalar_t y; 35 | 36 | switch (act * 10 + grad) { 37 | default: 38 | case 10: y = x; break; 39 | case 11: y = x; break; 40 | case 12: y = 0.0; break; 41 | 42 | case 30: y = (x > 0.0) ? x : x * alpha; break; 43 | case 31: y = (ref > 0.0) ? x : x * alpha; break; 44 | case 32: y = 0.0; break; 45 | } 46 | 47 | out[xi] = y * scale; 48 | } 49 | } 50 | 51 | 52 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 53 | int act, int grad, float alpha, float scale) { 54 | int curDevice = -1; 55 | cudaGetDevice(&curDevice); 56 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 57 | 58 | auto x = input.contiguous(); 59 | auto b = bias.contiguous(); 60 | auto ref = refer.contiguous(); 61 | 62 | int use_bias = b.numel() ? 1 : 0; 63 | int use_ref = ref.numel() ? 
1 : 0; 64 | 65 | int size_x = x.numel(); 66 | int size_b = b.numel(); 67 | int step_b = 1; 68 | 69 | for (int i = 1 + 1; i < x.dim(); i++) { 70 | step_b *= x.size(i); 71 | } 72 | 73 | int loop_x = 4; 74 | int block_size = 4 * 32; 75 | int grid_size = (size_x - 1) / (loop_x * block_size) + 1; 76 | 77 | auto y = torch::empty_like(x); 78 | 79 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] { 80 | fused_bias_act_kernel<<>>( 81 | y.data_ptr(), 82 | x.data_ptr(), 83 | b.data_ptr(), 84 | ref.data_ptr(), 85 | act, 86 | grad, 87 | alpha, 88 | scale, 89 | loop_x, 90 | size_x, 91 | step_b, 92 | size_b, 93 | use_bias, 94 | use_ref 95 | ); 96 | }); 97 | 98 | return y; 99 | } -------------------------------------------------------------------------------- /codebases/score_sde/op/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | 4 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, 5 | int up_x, int up_y, int down_x, int down_y, 6 | int pad_x0, int pad_x1, int pad_y0, int pad_y1); 7 | 8 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 9 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 10 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 11 | 12 | torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel, 13 | int up_x, int up_y, int down_x, int down_y, 14 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) { 15 | CHECK_CUDA(input); 16 | CHECK_CUDA(kernel); 17 | 18 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1); 19 | } 20 | 21 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 22 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); 23 | } -------------------------------------------------------------------------------- /codebases/score_sde/sample.sh: -------------------------------------------------------------------------------- 1 | CKPT_PATH="checkpoints/cifar10_ddpmpp_deep_continuous/checkpoint_8.pth" 2 | CONFIG="configs/vp/cifar10_ddpmpp_deep_continuous.py" 3 | for steps in 5 6 8 10 12 15 20 25; do 4 | 5 | if [ $steps -le 10 ]; then 6 | EPS="1e-3" 7 | STATS_DIR="statistics/cifar10_ddpmpp_deep_continuous/0.001_1200_4096" 8 | if [ $steps -le 8 ]; then 9 | if [ $steps -le 5 ]; then 10 | p_pseudo="True" 11 | lower_order_final="False" 12 | use_corrector="True" 13 | else 14 | p_pseudo="False" 15 | lower_order_final="True" 16 | use_corrector="True" 17 | fi 18 | else 19 | p_pseudo="False" 20 | lower_order_final="True" 21 | use_corrector="False" 22 | fi 23 | else 24 | STATS_DIR="statistics/cifar10_ddpmpp_deep_continuous/0.0001_1200_4096" 25 | EPS="1e-4" 26 | p_pseudo="False" 27 | lower_order_final="True" 28 | use_corrector="True" 29 | fi 30 | 31 | python sample.py --config=$CONFIG --ckp_path=$CKPT_PATH --sample_folder="DPM-Solver++_"$steps --config.sampling.method=dpm_solver --config.sampling.steps=$steps --config.sampling.eps=$EPS 32 | python sample.py --config=$CONFIG --ckp_path=$CKPT_PATH --sample_folder="UniPC_bh1_"$steps --config.sampling.method=uni_pc --config.sampling.steps=$steps --config.sampling.variant=bh1 --config.sampling.eps=$EPS 33 | python sample.py --config=$CONFIG --ckp_path=$CKPT_PATH --sample_folder="UniPC_bh2_"$steps --config.sampling.method=uni_pc --config.sampling.steps=$steps --config.sampling.variant=bh2 --config.sampling.eps=$EPS 34 | 35 | python sample.py --config=$CONFIG --ckp_path=$CKPT_PATH 
--statistics_dir=$STATS_DIR --sample_folder="DPM-Solver-v3_"$steps --config.sampling.method=dpm_solver_v3 --config.sampling.eps=$EPS --config.sampling.steps=$steps --config.sampling.predictor_pseudo=$p_pseudo --config.sampling.use_corrector=$use_corrector --config.sampling.lower_order_final=$lower_order_final --config.sampling.corrector_pseudo=False 36 | done -------------------------------------------------------------------------------- /codebases/score_sde/samplers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/DPM-Solver-v3/b615d707e9df512d309c4e59ce228283a608eb8b/codebases/score_sde/samplers/__init__.py -------------------------------------------------------------------------------- /codebases/score_sde/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import logging 4 | 5 | 6 | def restore_checkpoint(ckpt_dir, state, device): 7 | if not os.path.exists(ckpt_dir): 8 | os.makedirs(os.path.dirname(ckpt_dir), exist_ok=True) 9 | logging.warning(f"No checkpoint found at {ckpt_dir}. " f"Returned the same state as input") 10 | return state 11 | else: 12 | loaded_state = torch.load(ckpt_dir, map_location=device) 13 | # state['optimizer'].load_state_dict(loaded_state['optimizer']) 14 | state["model"].load_state_dict(loaded_state["model"], strict=False) 15 | state["ema"].load_state_dict(loaded_state["ema"]) 16 | state["step"] = loaded_state["step"] 17 | return state 18 | 19 | 20 | def save_checkpoint(ckpt_dir, state): 21 | saved_state = { 22 | "optimizer": state["optimizer"].state_dict(), 23 | "model": state["model"].state_dict(), 24 | "ema": state["ema"].state_dict(), 25 | "step": state["step"], 26 | } 27 | torch.save(saved_state, ckpt_dir) 28 | -------------------------------------------------------------------------------- /codebases/stable-diffusion/README.md: -------------------------------------------------------------------------------- 1 | # DPM-Solver-v3 (Latent-Diffusion, Stable-Diffusion) 2 | 3 | ## Preparation 4 | 5 | Install the packages 6 | 7 | ```shell 8 | pip install opencv-python omegaconf tqdm einops pytorch-lightning==1.6.5 transformers kornia 9 | pip install torch==1.13.1+cu116 torchvision==0.14.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116 10 | pip install -e ./src/clip/ 11 | pip install -e ./src/taming-transformers/ 12 | ``` 13 | 14 | 15 | For Latent-Diffusion on LSUN-Bedroom: 16 | 17 | - Download the pretrained models 18 | 19 | ```shell 20 | mkdir -p models/first_stage_models/vq-f4 21 | wget -O models/first_stage_models/vq-f4/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f4.zip 22 | cd models/first_stage_models/vq-f4 23 | unzip -o model.zip 24 | cd ../../.. 25 | 26 | mkdir -p models/ldm/lsun_beds256 27 | wget -O models/ldm/lsun_beds256/lsun_beds-256.zip https://ommer-lab.com/files/latent-diffusion/lsun_bedrooms.zip 28 | cd models/ldm/lsun_beds256 29 | unzip -o lsun_beds-256.zip 30 | cd ../../.. 31 | ``` 32 | 33 | - Download the folder `lsun_beds256` from https://drive.google.com/drive/folders/1sWq-htX9c3Xdajmo1BG-QvkbaeVtJqaq and put it under the folder `statistics/`. 
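- Optionally, sanity-check that the assets are in place before sampling. The sketch below assumes the zips extract to the usual Latent-Diffusion layout (e.g. `model.ckpt`); adjust the names if your extracted archives differ:

  ```shell
  ls models/first_stage_models/vq-f4/   # VQ-f4 first-stage autoencoder (assumed model.ckpt)
  ls models/ldm/lsun_beds256/           # LSUN-Bedroom LDM checkpoint
  ls statistics/lsun_beds256/           # EMS statistics from the Google Drive folder
  ```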
34 | 35 | For Stable-Diffusion-v1.4: 36 | 37 | - Download https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt from [CompVis/stable-diffusion-v-1-4-original · Hugging Face](https://huggingface.co/CompVis/stable-diffusion-v-1-4-original) and put it under the folder `models/ldm/stable-diffusion-v1/`. 38 | 39 | - Download the folder `sd-v1-4` from https://drive.google.com/drive/folders/1sWq-htX9c3Xdajmo1BG-QvkbaeVtJqaq and put it under the folder `statistics/`. 40 | 41 | 42 | ## Generate Samples 43 | 44 | For Latent-Diffusion on LSUN-Bedroom: 45 | 46 | - Run `bash sample.sh lsun_beds256 ` 47 | 48 | - For example: 49 | 50 | ```shell 51 | bash sample.sh lsun_beds256 5 52 | ``` 53 | 54 | For Stable-Diffusion-v1.4: 55 | 56 | - Run `bash sample.sh sd-v1-4 ` 57 | 58 | - For example: 59 | 60 | ```shell 61 | bash sample.sh sd-v1-4 5 7.5 "A beautiful castle beside a waterfall in the woods, by Josef Thoma, matte painting, trending on artstation HQ" 62 | ``` 63 | 64 | The samples of different samplers will be generated under the folder `outputs/`. You can modify the script as you wish. 65 | 66 | ## Compute FID and MSE 67 | 68 | TODO -------------------------------------------------------------------------------- /codebases/stable-diffusion/configs/autoencoder/autoencoder_kl_16x16x16.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-6 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: "val/rec_loss" 6 | embed_dim: 16 7 | lossconfig: 8 | target: ldm.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 0.000001 12 | disc_weight: 0.5 13 | 14 | ddconfig: 15 | double_z: True 16 | z_channels: 16 17 | resolution: 256 18 | in_channels: 3 19 | out_ch: 3 20 | ch: 128 21 | ch_mult: [ 1,1,2,2,4] # num_down = len(ch_mult)-1 22 | num_res_blocks: 2 23 | attn_resolutions: [16] 24 | dropout: 0.0 25 | 26 | 27 | data: 28 | target: main.DataModuleFromConfig 29 | params: 30 | batch_size: 12 31 | wrap: True 32 | train: 33 | target: ldm.data.imagenet.ImageNetSRTrain 34 | params: 35 | size: 256 36 | degradation: pil_nearest 37 | validation: 38 | target: ldm.data.imagenet.ImageNetSRValidation 39 | params: 40 | size: 256 41 | degradation: pil_nearest 42 | 43 | lightning: 44 | callbacks: 45 | image_logger: 46 | target: main.ImageLogger 47 | params: 48 | batch_frequency: 1000 49 | max_images: 8 50 | increase_log_steps: True 51 | 52 | trainer: 53 | benchmark: True 54 | accumulate_grad_batches: 2 55 | -------------------------------------------------------------------------------- /codebases/stable-diffusion/configs/autoencoder/autoencoder_kl_32x32x4.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-6 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: "val/rec_loss" 6 | embed_dim: 4 7 | lossconfig: 8 | target: ldm.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 0.000001 12 | disc_weight: 0.5 13 | 14 | ddconfig: 15 | double_z: True 16 | z_channels: 4 17 | resolution: 256 18 | in_channels: 3 19 | out_ch: 3 20 | ch: 128 21 | ch_mult: [ 1,2,4,4 ] # num_down = len(ch_mult)-1 22 | num_res_blocks: 2 23 | attn_resolutions: [ ] 24 | dropout: 0.0 25 | 26 | data: 27 | target: main.DataModuleFromConfig 28 | params: 29 | batch_size: 12 30 | wrap: True 31 | train: 32 | target: ldm.data.imagenet.ImageNetSRTrain 33 | params: 34 | size: 256 35 
| degradation: pil_nearest 36 | validation: 37 | target: ldm.data.imagenet.ImageNetSRValidation 38 | params: 39 | size: 256 40 | degradation: pil_nearest 41 | 42 | lightning: 43 | callbacks: 44 | image_logger: 45 | target: main.ImageLogger 46 | params: 47 | batch_frequency: 1000 48 | max_images: 8 49 | increase_log_steps: True 50 | 51 | trainer: 52 | benchmark: True 53 | accumulate_grad_batches: 2 54 | -------------------------------------------------------------------------------- /codebases/stable-diffusion/configs/autoencoder/autoencoder_kl_64x64x3.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-6 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: "val/rec_loss" 6 | embed_dim: 3 7 | lossconfig: 8 | target: ldm.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 0.000001 12 | disc_weight: 0.5 13 | 14 | ddconfig: 15 | double_z: True 16 | z_channels: 3 17 | resolution: 256 18 | in_channels: 3 19 | out_ch: 3 20 | ch: 128 21 | ch_mult: [ 1,2,4 ] # num_down = len(ch_mult)-1 22 | num_res_blocks: 2 23 | attn_resolutions: [ ] 24 | dropout: 0.0 25 | 26 | 27 | data: 28 | target: main.DataModuleFromConfig 29 | params: 30 | batch_size: 12 31 | wrap: True 32 | train: 33 | target: ldm.data.imagenet.ImageNetSRTrain 34 | params: 35 | size: 256 36 | degradation: pil_nearest 37 | validation: 38 | target: ldm.data.imagenet.ImageNetSRValidation 39 | params: 40 | size: 256 41 | degradation: pil_nearest 42 | 43 | lightning: 44 | callbacks: 45 | image_logger: 46 | target: main.ImageLogger 47 | params: 48 | batch_frequency: 1000 49 | max_images: 8 50 | increase_log_steps: True 51 | 52 | trainer: 53 | benchmark: True 54 | accumulate_grad_batches: 2 55 | -------------------------------------------------------------------------------- /codebases/stable-diffusion/configs/autoencoder/autoencoder_kl_8x8x64.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-6 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: "val/rec_loss" 6 | embed_dim: 64 7 | lossconfig: 8 | target: ldm.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 0.000001 12 | disc_weight: 0.5 13 | 14 | ddconfig: 15 | double_z: True 16 | z_channels: 64 17 | resolution: 256 18 | in_channels: 3 19 | out_ch: 3 20 | ch: 128 21 | ch_mult: [ 1,1,2,2,4,4] # num_down = len(ch_mult)-1 22 | num_res_blocks: 2 23 | attn_resolutions: [16,8] 24 | dropout: 0.0 25 | 26 | data: 27 | target: main.DataModuleFromConfig 28 | params: 29 | batch_size: 12 30 | wrap: True 31 | train: 32 | target: ldm.data.imagenet.ImageNetSRTrain 33 | params: 34 | size: 256 35 | degradation: pil_nearest 36 | validation: 37 | target: ldm.data.imagenet.ImageNetSRValidation 38 | params: 39 | size: 256 40 | degradation: pil_nearest 41 | 42 | lightning: 43 | callbacks: 44 | image_logger: 45 | target: main.ImageLogger 46 | params: 47 | batch_frequency: 1000 48 | max_images: 8 49 | increase_log_steps: True 50 | 51 | trainer: 52 | benchmark: True 53 | accumulate_grad_batches: 2 54 | -------------------------------------------------------------------------------- /codebases/stable-diffusion/configs/latent-diffusion/celebahq-ldm-vq-4.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 2.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | 
linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | image_size: 64 12 | channels: 3 13 | monitor: val/loss_simple_ema 14 | 15 | unet_config: 16 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 17 | params: 18 | image_size: 64 19 | in_channels: 3 20 | out_channels: 3 21 | model_channels: 224 22 | attention_resolutions: 23 | # note: this isn\t actually the resolution but 24 | # the downsampling factor, i.e. this corresnponds to 25 | # attention on spatial resolution 8,16,32, as the 26 | # spatial reolution of the latents is 64 for f4 27 | - 8 28 | - 4 29 | - 2 30 | num_res_blocks: 2 31 | channel_mult: 32 | - 1 33 | - 2 34 | - 3 35 | - 4 36 | num_head_channels: 32 37 | first_stage_config: 38 | target: ldm.models.autoencoder.VQModelInterface 39 | params: 40 | embed_dim: 3 41 | n_embed: 8192 42 | ckpt_path: models/first_stage_models/vq-f4/model.ckpt 43 | ddconfig: 44 | double_z: false 45 | z_channels: 3 46 | resolution: 256 47 | in_channels: 3 48 | out_ch: 3 49 | ch: 128 50 | ch_mult: 51 | - 1 52 | - 2 53 | - 4 54 | num_res_blocks: 2 55 | attn_resolutions: [] 56 | dropout: 0.0 57 | lossconfig: 58 | target: torch.nn.Identity 59 | cond_stage_config: __is_unconditional__ 60 | data: 61 | target: main.DataModuleFromConfig 62 | params: 63 | batch_size: 48 64 | num_workers: 5 65 | wrap: false 66 | train: 67 | target: taming.data.faceshq.CelebAHQTrain 68 | params: 69 | size: 256 70 | validation: 71 | target: taming.data.faceshq.CelebAHQValidation 72 | params: 73 | size: 256 74 | 75 | 76 | lightning: 77 | callbacks: 78 | image_logger: 79 | target: main.ImageLogger 80 | params: 81 | batch_frequency: 5000 82 | max_images: 8 83 | increase_log_steps: False 84 | 85 | trainer: 86 | benchmark: True -------------------------------------------------------------------------------- /codebases/stable-diffusion/configs/latent-diffusion/cin-ldm-vq-f8.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: class_label 12 | image_size: 32 13 | channels: 4 14 | cond_stage_trainable: true 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | unet_config: 18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 19 | params: 20 | image_size: 32 21 | in_channels: 4 22 | out_channels: 4 23 | model_channels: 256 24 | attention_resolutions: 25 | #note: this isn\t actually the resolution but 26 | # the downsampling factor, i.e. 
this corresnponds to 27 | # attention on spatial resolution 8,16,32, as the 28 | # spatial reolution of the latents is 32 for f8 29 | - 4 30 | - 2 31 | - 1 32 | num_res_blocks: 2 33 | channel_mult: 34 | - 1 35 | - 2 36 | - 4 37 | num_head_channels: 32 38 | use_spatial_transformer: true 39 | transformer_depth: 1 40 | context_dim: 512 41 | first_stage_config: 42 | target: ldm.models.autoencoder.VQModelInterface 43 | params: 44 | embed_dim: 4 45 | n_embed: 16384 46 | ckpt_path: configs/first_stage_models/vq-f8/model.yaml 47 | ddconfig: 48 | double_z: false 49 | z_channels: 4 50 | resolution: 256 51 | in_channels: 3 52 | out_ch: 3 53 | ch: 128 54 | ch_mult: 55 | - 1 56 | - 2 57 | - 2 58 | - 4 59 | num_res_blocks: 2 60 | attn_resolutions: 61 | - 32 62 | dropout: 0.0 63 | lossconfig: 64 | target: torch.nn.Identity 65 | cond_stage_config: 66 | target: ldm.modules.encoders.modules.ClassEmbedder 67 | params: 68 | embed_dim: 512 69 | key: class_label 70 | data: 71 | target: main.DataModuleFromConfig 72 | params: 73 | batch_size: 64 74 | num_workers: 12 75 | wrap: false 76 | train: 77 | target: ldm.data.imagenet.ImageNetTrain 78 | params: 79 | config: 80 | size: 256 81 | validation: 82 | target: ldm.data.imagenet.ImageNetValidation 83 | params: 84 | config: 85 | size: 256 86 | 87 | 88 | lightning: 89 | callbacks: 90 | image_logger: 91 | target: main.ImageLogger 92 | params: 93 | batch_frequency: 5000 94 | max_images: 8 95 | increase_log_steps: False 96 | 97 | trainer: 98 | benchmark: True -------------------------------------------------------------------------------- /codebases/stable-diffusion/configs/latent-diffusion/cin256-v2.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 0.0001 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: class_label 12 | image_size: 64 13 | channels: 3 14 | cond_stage_trainable: true 15 | conditioning_key: crossattn 16 | monitor: val/loss 17 | use_ema: False 18 | 19 | unet_config: 20 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 21 | params: 22 | image_size: 64 23 | in_channels: 3 24 | out_channels: 3 25 | model_channels: 192 26 | attention_resolutions: 27 | - 8 28 | - 4 29 | - 2 30 | num_res_blocks: 2 31 | channel_mult: 32 | - 1 33 | - 2 34 | - 3 35 | - 5 36 | num_heads: 1 37 | use_spatial_transformer: true 38 | transformer_depth: 1 39 | context_dim: 512 40 | 41 | first_stage_config: 42 | target: ldm.models.autoencoder.VQModelInterface 43 | params: 44 | embed_dim: 3 45 | n_embed: 8192 46 | ddconfig: 47 | double_z: false 48 | z_channels: 3 49 | resolution: 256 50 | in_channels: 3 51 | out_ch: 3 52 | ch: 128 53 | ch_mult: 54 | - 1 55 | - 2 56 | - 4 57 | num_res_blocks: 2 58 | attn_resolutions: [] 59 | dropout: 0.0 60 | lossconfig: 61 | target: torch.nn.Identity 62 | 63 | cond_stage_config: 64 | target: ldm.modules.encoders.modules.ClassEmbedder 65 | params: 66 | n_classes: 1001 67 | embed_dim: 512 68 | key: class_label 69 | -------------------------------------------------------------------------------- /codebases/stable-diffusion/configs/latent-diffusion/ffhq-ldm-vq-4.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 2.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 
7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | image_size: 64 12 | channels: 3 13 | monitor: val/loss_simple_ema 14 | unet_config: 15 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 16 | params: 17 | image_size: 64 18 | in_channels: 3 19 | out_channels: 3 20 | model_channels: 224 21 | attention_resolutions: 22 | # note: this isn\t actually the resolution but 23 | # the downsampling factor, i.e. this corresnponds to 24 | # attention on spatial resolution 8,16,32, as the 25 | # spatial reolution of the latents is 64 for f4 26 | - 8 27 | - 4 28 | - 2 29 | num_res_blocks: 2 30 | channel_mult: 31 | - 1 32 | - 2 33 | - 3 34 | - 4 35 | num_head_channels: 32 36 | first_stage_config: 37 | target: ldm.models.autoencoder.VQModelInterface 38 | params: 39 | embed_dim: 3 40 | n_embed: 8192 41 | ckpt_path: configs/first_stage_models/vq-f4/model.yaml 42 | ddconfig: 43 | double_z: false 44 | z_channels: 3 45 | resolution: 256 46 | in_channels: 3 47 | out_ch: 3 48 | ch: 128 49 | ch_mult: 50 | - 1 51 | - 2 52 | - 4 53 | num_res_blocks: 2 54 | attn_resolutions: [] 55 | dropout: 0.0 56 | lossconfig: 57 | target: torch.nn.Identity 58 | cond_stage_config: __is_unconditional__ 59 | data: 60 | target: main.DataModuleFromConfig 61 | params: 62 | batch_size: 42 63 | num_workers: 5 64 | wrap: false 65 | train: 66 | target: taming.data.faceshq.FFHQTrain 67 | params: 68 | size: 256 69 | validation: 70 | target: taming.data.faceshq.FFHQValidation 71 | params: 72 | size: 256 73 | 74 | 75 | lightning: 76 | callbacks: 77 | image_logger: 78 | target: main.ImageLogger 79 | params: 80 | batch_frequency: 5000 81 | max_images: 8 82 | increase_log_steps: False 83 | 84 | trainer: 85 | benchmark: True -------------------------------------------------------------------------------- /codebases/stable-diffusion/configs/latent-diffusion/lsun_bedrooms-ldm-vq-4.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 2.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | image_size: 64 12 | channels: 3 13 | monitor: val/loss_simple_ema 14 | unet_config: 15 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 16 | params: 17 | image_size: 64 18 | in_channels: 3 19 | out_channels: 3 20 | model_channels: 224 21 | attention_resolutions: 22 | # note: this isn\t actually the resolution but 23 | # the downsampling factor, i.e. 
this corresnponds to 24 | # attention on spatial resolution 8,16,32, as the 25 | # spatial reolution of the latents is 64 for f4 26 | - 8 27 | - 4 28 | - 2 29 | num_res_blocks: 2 30 | channel_mult: 31 | - 1 32 | - 2 33 | - 3 34 | - 4 35 | num_head_channels: 32 36 | first_stage_config: 37 | target: ldm.models.autoencoder.VQModelInterface 38 | params: 39 | ckpt_path: models/first_stage_models/vq-f4/model.ckpt 40 | embed_dim: 3 41 | n_embed: 8192 42 | ddconfig: 43 | double_z: false 44 | z_channels: 3 45 | resolution: 256 46 | in_channels: 3 47 | out_ch: 3 48 | ch: 128 49 | ch_mult: 50 | - 1 51 | - 2 52 | - 4 53 | num_res_blocks: 2 54 | attn_resolutions: [] 55 | dropout: 0.0 56 | lossconfig: 57 | target: torch.nn.Identity 58 | cond_stage_config: __is_unconditional__ 59 | data: 60 | target: main.DataModuleFromConfig 61 | params: 62 | batch_size: 48 63 | num_workers: 5 64 | wrap: false 65 | train: 66 | target: ldm.data.lsun.LSUNBedroomsTrain 67 | params: 68 | size: 256 69 | validation: 70 | target: ldm.data.lsun.LSUNBedroomsValidation 71 | params: 72 | size: 256 73 | 74 | 75 | lightning: 76 | callbacks: 77 | image_logger: 78 | target: main.ImageLogger 79 | params: 80 | batch_frequency: 5000 81 | max_images: 8 82 | increase_log_steps: False 83 | 84 | trainer: 85 | benchmark: True -------------------------------------------------------------------------------- /codebases/stable-diffusion/configs/latent-diffusion/lsun_churches-ldm-kl-8.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 5.0e-5 # set to target_lr by starting main.py with '--scale_lr False' 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0155 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | loss_type: l1 11 | first_stage_key: "image" 12 | cond_stage_key: "image" 13 | image_size: 32 14 | channels: 4 15 | cond_stage_trainable: False 16 | concat_mode: False 17 | scale_by_std: True 18 | monitor: 'val/loss_simple_ema' 19 | 20 | scheduler_config: # 10000 warmup steps 21 | target: ldm.lr_scheduler.LambdaLinearScheduler 22 | params: 23 | warm_up_steps: [10000] 24 | cycle_lengths: [10000000000000] 25 | f_start: [1.e-6] 26 | f_max: [1.] 27 | f_min: [ 1.] 
28 | 29 | unet_config: 30 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 31 | params: 32 | image_size: 32 33 | in_channels: 4 34 | out_channels: 4 35 | model_channels: 192 36 | attention_resolutions: [ 1, 2, 4, 8 ] # 32, 16, 8, 4 37 | num_res_blocks: 2 38 | channel_mult: [ 1,2,2,4,4 ] # 32, 16, 8, 4, 2 39 | num_heads: 8 40 | use_scale_shift_norm: True 41 | resblock_updown: True 42 | 43 | first_stage_config: 44 | target: ldm.models.autoencoder.AutoencoderKL 45 | params: 46 | embed_dim: 4 47 | monitor: "val/rec_loss" 48 | ckpt_path: "models/first_stage_models/kl-f8/model.ckpt" 49 | ddconfig: 50 | double_z: True 51 | z_channels: 4 52 | resolution: 256 53 | in_channels: 3 54 | out_ch: 3 55 | ch: 128 56 | ch_mult: [ 1,2,4,4 ] # num_down = len(ch_mult)-1 57 | num_res_blocks: 2 58 | attn_resolutions: [ ] 59 | dropout: 0.0 60 | lossconfig: 61 | target: torch.nn.Identity 62 | 63 | cond_stage_config: "__is_unconditional__" 64 | 65 | data: 66 | target: main.DataModuleFromConfig 67 | params: 68 | batch_size: 96 69 | num_workers: 5 70 | wrap: False 71 | train: 72 | target: ldm.data.lsun.LSUNChurchesTrain 73 | params: 74 | size: 256 75 | validation: 76 | target: ldm.data.lsun.LSUNChurchesValidation 77 | params: 78 | size: 256 79 | 80 | lightning: 81 | callbacks: 82 | image_logger: 83 | target: main.ImageLogger 84 | params: 85 | batch_frequency: 5000 86 | max_images: 8 87 | increase_log_steps: False 88 | 89 | 90 | trainer: 91 | benchmark: True -------------------------------------------------------------------------------- /codebases/stable-diffusion/configs/latent-diffusion/txt2img-1p4B-eval.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 5.0e-05 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.012 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: caption 12 | image_size: 32 13 | channels: 4 14 | cond_stage_trainable: true 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | scale_factor: 0.18215 18 | use_ema: False 19 | 20 | unet_config: 21 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 22 | params: 23 | image_size: 32 24 | in_channels: 4 25 | out_channels: 4 26 | model_channels: 320 27 | attention_resolutions: 28 | - 4 29 | - 2 30 | - 1 31 | num_res_blocks: 2 32 | channel_mult: 33 | - 1 34 | - 2 35 | - 4 36 | - 4 37 | num_heads: 8 38 | use_spatial_transformer: true 39 | transformer_depth: 1 40 | context_dim: 1280 41 | use_checkpoint: true 42 | legacy: False 43 | 44 | first_stage_config: 45 | target: ldm.models.autoencoder.AutoencoderKL 46 | params: 47 | embed_dim: 4 48 | monitor: val/rec_loss 49 | ddconfig: 50 | double_z: true 51 | z_channels: 4 52 | resolution: 256 53 | in_channels: 3 54 | out_ch: 3 55 | ch: 128 56 | ch_mult: 57 | - 1 58 | - 2 59 | - 4 60 | - 4 61 | num_res_blocks: 2 62 | attn_resolutions: [] 63 | dropout: 0.0 64 | lossconfig: 65 | target: torch.nn.Identity 66 | 67 | cond_stage_config: 68 | target: ldm.modules.encoders.modules.BERTEmbedder 69 | params: 70 | n_embed: 1280 71 | n_layer: 32 72 | -------------------------------------------------------------------------------- /codebases/stable-diffusion/configs/retrieval-augmented-diffusion/768x768.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 0.0001 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | 
linear_start: 0.0015 6 | linear_end: 0.015 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: jpg 11 | cond_stage_key: nix 12 | image_size: 48 13 | channels: 16 14 | cond_stage_trainable: false 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | scale_by_std: false 18 | scale_factor: 0.22765929 19 | unet_config: 20 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 21 | params: 22 | image_size: 48 23 | in_channels: 16 24 | out_channels: 16 25 | model_channels: 448 26 | attention_resolutions: 27 | - 4 28 | - 2 29 | - 1 30 | num_res_blocks: 2 31 | channel_mult: 32 | - 1 33 | - 2 34 | - 3 35 | - 4 36 | use_scale_shift_norm: false 37 | resblock_updown: false 38 | num_head_channels: 32 39 | use_spatial_transformer: true 40 | transformer_depth: 1 41 | context_dim: 768 42 | use_checkpoint: true 43 | first_stage_config: 44 | target: ldm.models.autoencoder.AutoencoderKL 45 | params: 46 | monitor: val/rec_loss 47 | embed_dim: 16 48 | ddconfig: 49 | double_z: true 50 | z_channels: 16 51 | resolution: 256 52 | in_channels: 3 53 | out_ch: 3 54 | ch: 128 55 | ch_mult: 56 | - 1 57 | - 1 58 | - 2 59 | - 2 60 | - 4 61 | num_res_blocks: 2 62 | attn_resolutions: 63 | - 16 64 | dropout: 0.0 65 | lossconfig: 66 | target: torch.nn.Identity 67 | cond_stage_config: 68 | target: torch.nn.Identity -------------------------------------------------------------------------------- /codebases/stable-diffusion/configs/stable-diffusion/v1-inference.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-04 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.0120 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: "jpg" 11 | cond_stage_key: "txt" 12 | image_size: 64 13 | channels: 4 14 | cond_stage_trainable: false # Note: different from the one we trained before 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | scale_factor: 0.18215 18 | use_ema: False 19 | 20 | scheduler_config: # 10000 warmup steps 21 | target: ldm.lr_scheduler.LambdaLinearScheduler 22 | params: 23 | warm_up_steps: [ 10000 ] 24 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases 25 | f_start: [ 1.e-6 ] 26 | f_max: [ 1. ] 27 | f_min: [ 1. 
] 28 | 29 | unet_config: 30 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 31 | params: 32 | image_size: 32 # unused 33 | in_channels: 4 34 | out_channels: 4 35 | model_channels: 320 36 | attention_resolutions: [ 4, 2, 1 ] 37 | num_res_blocks: 2 38 | channel_mult: [ 1, 2, 4, 4 ] 39 | num_heads: 8 40 | use_spatial_transformer: True 41 | transformer_depth: 1 42 | context_dim: 768 43 | use_checkpoint: False 44 | legacy: False 45 | 46 | first_stage_config: 47 | target: ldm.models.autoencoder.AutoencoderKL 48 | params: 49 | embed_dim: 4 50 | monitor: val/rec_loss 51 | ddconfig: 52 | double_z: true 53 | z_channels: 4 54 | resolution: 256 55 | in_channels: 3 56 | out_ch: 3 57 | ch: 128 58 | ch_mult: 59 | - 1 60 | - 2 61 | - 4 62 | - 4 63 | num_res_blocks: 2 64 | attn_resolutions: [] 65 | dropout: 0.0 66 | lossconfig: 67 | target: torch.nn.Identity 68 | 69 | cond_stage_config: 70 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder 71 | -------------------------------------------------------------------------------- /codebases/stable-diffusion/ldm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/DPM-Solver-v3/b615d707e9df512d309c4e59ce228283a608eb8b/codebases/stable-diffusion/ldm/__init__.py -------------------------------------------------------------------------------- /codebases/stable-diffusion/ldm/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/DPM-Solver-v3/b615d707e9df512d309c4e59ce228283a608eb8b/codebases/stable-diffusion/ldm/data/__init__.py -------------------------------------------------------------------------------- /codebases/stable-diffusion/ldm/data/base.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset 3 | 4 | 5 | class Txt2ImgIterableBaseDataset(IterableDataset): 6 | ''' 7 | Define an interface to make the IterableDatasets for text2img data chainable 8 | ''' 9 | def __init__(self, num_records=0, valid_ids=None, size=256): 10 | super().__init__() 11 | self.num_records = num_records 12 | self.valid_ids = valid_ids 13 | self.sample_ids = valid_ids 14 | self.size = size 15 | 16 | print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.') 17 | 18 | def __len__(self): 19 | return self.num_records 20 | 21 | @abstractmethod 22 | def __iter__(self): 23 | pass -------------------------------------------------------------------------------- /codebases/stable-diffusion/ldm/data/lsun.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import PIL 4 | from PIL import Image 5 | from torch.utils.data import Dataset 6 | from torchvision import transforms 7 | 8 | 9 | class LSUNBase(Dataset): 10 | def __init__(self, 11 | txt_file, 12 | data_root, 13 | size=None, 14 | interpolation="bicubic", 15 | flip_p=0.5 16 | ): 17 | self.data_paths = txt_file 18 | self.data_root = data_root 19 | with open(self.data_paths, "r") as f: 20 | self.image_paths = f.read().splitlines() 21 | self._length = len(self.image_paths) 22 | self.labels = { 23 | "relative_file_path_": [l for l in self.image_paths], 24 | "file_path_": [os.path.join(self.data_root, l) 25 | for l in self.image_paths], 26 | } 27 | 28 | self.size = size 29 | self.interpolation = {"linear": PIL.Image.LINEAR, 30 
| "bilinear": PIL.Image.BILINEAR, 31 | "bicubic": PIL.Image.BICUBIC, 32 | "lanczos": PIL.Image.LANCZOS, 33 | }[interpolation] 34 | self.flip = transforms.RandomHorizontalFlip(p=flip_p) 35 | 36 | def __len__(self): 37 | return self._length 38 | 39 | def __getitem__(self, i): 40 | example = dict((k, self.labels[k][i]) for k in self.labels) 41 | image = Image.open(example["file_path_"]) 42 | if not image.mode == "RGB": 43 | image = image.convert("RGB") 44 | 45 | # default to score-sde preprocessing 46 | img = np.array(image).astype(np.uint8) 47 | crop = min(img.shape[0], img.shape[1]) 48 | h, w, = img.shape[0], img.shape[1] 49 | img = img[(h - crop) // 2:(h + crop) // 2, 50 | (w - crop) // 2:(w + crop) // 2] 51 | 52 | image = Image.fromarray(img) 53 | if self.size is not None: 54 | image = image.resize((self.size, self.size), resample=self.interpolation) 55 | 56 | image = self.flip(image) 57 | image = np.array(image).astype(np.uint8) 58 | example["image"] = (image / 127.5 - 1.0).astype(np.float32) 59 | return example 60 | 61 | 62 | class LSUNChurchesTrain(LSUNBase): 63 | def __init__(self, **kwargs): 64 | super().__init__(txt_file="data/lsun/church_outdoor_train.txt", data_root="data/lsun/churches", **kwargs) 65 | 66 | 67 | class LSUNChurchesValidation(LSUNBase): 68 | def __init__(self, flip_p=0., **kwargs): 69 | super().__init__(txt_file="data/lsun/church_outdoor_val.txt", data_root="data/lsun/churches", 70 | flip_p=flip_p, **kwargs) 71 | 72 | 73 | class LSUNBedroomsTrain(LSUNBase): 74 | def __init__(self, **kwargs): 75 | super().__init__(txt_file="data/lsun/bedrooms_train.txt", data_root="data/lsun/bedrooms", **kwargs) 76 | 77 | 78 | class LSUNBedroomsValidation(LSUNBase): 79 | def __init__(self, flip_p=0.0, **kwargs): 80 | super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms", 81 | flip_p=flip_p, **kwargs) 82 | 83 | 84 | class LSUNCatsTrain(LSUNBase): 85 | def __init__(self, **kwargs): 86 | super().__init__(txt_file="data/lsun/cat_train.txt", data_root="data/lsun/cats", **kwargs) 87 | 88 | 89 | class LSUNCatsValidation(LSUNBase): 90 | def __init__(self, flip_p=0., **kwargs): 91 | super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats", 92 | flip_p=flip_p, **kwargs) 93 | -------------------------------------------------------------------------------- /codebases/stable-diffusion/ldm/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class LambdaWarmUpCosineScheduler: 5 | """ 6 | note: use with a base_lr of 1.0 7 | """ 8 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): 9 | self.lr_warm_up_steps = warm_up_steps 10 | self.lr_start = lr_start 11 | self.lr_min = lr_min 12 | self.lr_max = lr_max 13 | self.lr_max_decay_steps = max_decay_steps 14 | self.last_lr = 0. 
15 | self.verbosity_interval = verbosity_interval 16 | 17 | def schedule(self, n, **kwargs): 18 | if self.verbosity_interval > 0: 19 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") 20 | if n < self.lr_warm_up_steps: 21 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start 22 | self.last_lr = lr 23 | return lr 24 | else: 25 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) 26 | t = min(t, 1.0) 27 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( 28 | 1 + np.cos(t * np.pi)) 29 | self.last_lr = lr 30 | return lr 31 | 32 | def __call__(self, n, **kwargs): 33 | return self.schedule(n,**kwargs) 34 | 35 | 36 | class LambdaWarmUpCosineScheduler2: 37 | """ 38 | supports repeated iterations, configurable via lists 39 | note: use with a base_lr of 1.0. 40 | """ 41 | def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0): 42 | assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths) 43 | self.lr_warm_up_steps = warm_up_steps 44 | self.f_start = f_start 45 | self.f_min = f_min 46 | self.f_max = f_max 47 | self.cycle_lengths = cycle_lengths 48 | self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) 49 | self.last_f = 0. 50 | self.verbosity_interval = verbosity_interval 51 | 52 | def find_in_interval(self, n): 53 | interval = 0 54 | for cl in self.cum_cycles[1:]: 55 | if n <= cl: 56 | return interval 57 | interval += 1 58 | 59 | def schedule(self, n, **kwargs): 60 | cycle = self.find_in_interval(n) 61 | n = n - self.cum_cycles[cycle] 62 | if self.verbosity_interval > 0: 63 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 64 | f"current cycle {cycle}") 65 | if n < self.lr_warm_up_steps[cycle]: 66 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 67 | self.last_f = f 68 | return f 69 | else: 70 | t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) 71 | t = min(t, 1.0) 72 | f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( 73 | 1 + np.cos(t * np.pi)) 74 | self.last_f = f 75 | return f 76 | 77 | def __call__(self, n, **kwargs): 78 | return self.schedule(n, **kwargs) 79 | 80 | 81 | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): 82 | 83 | def schedule(self, n, **kwargs): 84 | cycle = self.find_in_interval(n) 85 | n = n - self.cum_cycles[cycle] 86 | if self.verbosity_interval > 0: 87 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 88 | f"current cycle {cycle}") 89 | 90 | if n < self.lr_warm_up_steps[cycle]: 91 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 92 | self.last_f = f 93 | return f 94 | else: 95 | f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle]) 96 | self.last_f = f 97 | return f 98 | 99 | -------------------------------------------------------------------------------- /codebases/stable-diffusion/ldm/models/diffusion/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/DPM-Solver-v3/b615d707e9df512d309c4e59ce228283a608eb8b/codebases/stable-diffusion/ldm/models/diffusion/__init__.py 
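The sampler wrappers dumped below (`DPMSolverSampler`, `DPMSolverv3Sampler`, `UniPCSampler`) expose a similar `sample()` interface around `model.apply_model`. A minimal usage sketch for the DPM-Solver wrapper follows; it is not part of the repository, and the loaded `LatentDiffusion` instance `model`, the prompt, the latent shape, and the helper calls `get_learned_conditioning`/`decode_first_stage` are assumptions for illustration:

```python
# Hypothetical sketch (not part of the repository): assumes `model` is a loaded
# ldm LatentDiffusion instance, e.g. Stable-Diffusion-v1.4 with 4x64x64 latents.
import torch
from ldm.models.diffusion.dpm_solver import DPMSolverSampler

prompt = "A beautiful castle beside a waterfall in the woods"
cond = model.get_learned_conditioning([prompt])   # conditional embedding
uncond = model.get_learned_conditioning([""])     # unconditional embedding for CFG

sampler = DPMSolverSampler(model)
samples, _ = sampler.sample(
    S=20,                              # number of solver steps
    batch_size=1,
    shape=(4, 64, 64),                 # latent (C, H, W) for 512x512 images
    conditioning=cond,
    unconditional_conditioning=uncond,
    unconditional_guidance_scale=7.5,  # classifier-free guidance scale
)
with torch.no_grad():
    images = model.decode_first_stage(samples)  # decode latents to image space
```

As in the wrappers themselves, unconditional sampling uses a third-order multistep solver, while classifier-free guidance falls back to second order.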
-------------------------------------------------------------------------------- /codebases/stable-diffusion/ldm/models/diffusion/dpm_solver/__init__.py: -------------------------------------------------------------------------------- 1 | from .sampler import DPMSolverSampler -------------------------------------------------------------------------------- /codebases/stable-diffusion/ldm/models/diffusion/dpm_solver/sampler.py: -------------------------------------------------------------------------------- 1 | """SAMPLING ONLY.""" 2 | 3 | import torch 4 | 5 | from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver 6 | 7 | 8 | class DPMSolverSampler(object): 9 | def __init__(self, model, **kwargs): 10 | super().__init__() 11 | self.model = model 12 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device) 13 | self.register_buffer("alphas_cumprod", to_torch(model.alphas_cumprod)) 14 | 15 | def register_buffer(self, name, attr): 16 | if type(attr) == torch.Tensor: 17 | if attr.device != torch.device("cuda"): 18 | attr = attr.to(torch.device("cuda")) 19 | setattr(self, name, attr) 20 | 21 | @torch.no_grad() 22 | def sample( 23 | self, 24 | S, 25 | batch_size, 26 | shape, 27 | conditioning=None, 28 | callback=None, 29 | normals_sequence=None, 30 | img_callback=None, 31 | quantize_x0=False, 32 | eta=0.0, 33 | mask=None, 34 | x0=None, 35 | temperature=1.0, 36 | noise_dropout=0.0, 37 | score_corrector=None, 38 | corrector_kwargs=None, 39 | verbose=True, 40 | x_T=None, 41 | log_every_t=100, 42 | unconditional_guidance_scale=1.0, 43 | unconditional_conditioning=None, 44 | # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 45 | **kwargs, 46 | ): 47 | if conditioning is not None: 48 | if isinstance(conditioning, dict): 49 | cbs = conditioning[list(conditioning.keys())[0]].shape[0] 50 | if cbs != batch_size: 51 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") 52 | else: 53 | if conditioning.shape[0] != batch_size: 54 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") 55 | 56 | # sampling 57 | C, H, W = shape 58 | size = (batch_size, C, H, W) 59 | 60 | # print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}') 61 | 62 | device = self.model.betas.device 63 | if x_T is None: 64 | img = torch.randn(size, device=device) 65 | else: 66 | img = x_T 67 | 68 | ns = NoiseScheduleVP("discrete", alphas_cumprod=self.alphas_cumprod) 69 | 70 | if conditioning is None: 71 | model_fn = model_wrapper( 72 | lambda x, t, c: self.model.apply_model(x, t, c), 73 | ns, 74 | model_type="noise", 75 | guidance_type="uncond", 76 | ) 77 | ORDER = 3 78 | else: 79 | model_fn = model_wrapper( 80 | lambda x, t, c: self.model.apply_model(x, t, c), 81 | ns, 82 | model_type="noise", 83 | guidance_type="classifier-free", 84 | condition=conditioning, 85 | unconditional_condition=unconditional_conditioning, 86 | guidance_scale=unconditional_guidance_scale, 87 | ) 88 | ORDER = 2 89 | 90 | dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False) 91 | x = dpm_solver.sample( 92 | img, steps=S, skip_type="time_uniform", method="multistep", order=ORDER, lower_order_final=True 93 | ) 94 | return x.to(device), None 95 | -------------------------------------------------------------------------------- /codebases/stable-diffusion/ldm/models/diffusion/dpm_solver_v3/__init__.py: -------------------------------------------------------------------------------- 1 | from .sampler import 
DPMSolverv3Sampler -------------------------------------------------------------------------------- /codebases/stable-diffusion/ldm/models/diffusion/dpm_solver_v3/sampler.py: -------------------------------------------------------------------------------- 1 | """SAMPLING ONLY.""" 2 | 3 | import torch 4 | 5 | from .dpm_solver_v3 import NoiseScheduleVP, model_wrapper, DPM_Solver_v3 6 | 7 | 8 | class DPMSolverv3Sampler: 9 | def __init__(self, ckp_path, stats_dir, model, steps, guidance_scale, **kwargs): 10 | super().__init__() 11 | self.model = model 12 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device) 13 | self.alphas_cumprod = to_torch(model.alphas_cumprod) 14 | self.device = self.model.betas.device 15 | self.guidance_scale = guidance_scale 16 | 17 | self.ns = NoiseScheduleVP("discrete", alphas_cumprod=self.alphas_cumprod) 18 | 19 | assert stats_dir is not None, f"No statistics file found in {stats_base}." 20 | print("Use statistics", stats_dir) 21 | self.dpm_solver_v3 = DPM_Solver_v3( 22 | statistics_dir=stats_dir, 23 | noise_schedule=self.ns, 24 | steps=steps, 25 | t_start=None, 26 | t_end=None, 27 | skip_type="time_uniform", 28 | degenerated=False, 29 | device=self.device, 30 | ) 31 | self.steps = steps 32 | 33 | @torch.no_grad() 34 | def sample( 35 | self, 36 | batch_size, 37 | shape, 38 | conditioning=None, 39 | x_T=None, 40 | unconditional_conditioning=None, 41 | use_corrector=False, 42 | half=False, 43 | # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 44 | **kwargs, 45 | ): 46 | if conditioning is not None: 47 | if isinstance(conditioning, dict): 48 | cbs = conditioning[list(conditioning.keys())[0]].shape[0] 49 | if cbs != batch_size: 50 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") 51 | else: 52 | if conditioning.shape[0] != batch_size: 53 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") 54 | 55 | # sampling 56 | C, H, W = shape 57 | size = (batch_size, C, H, W) 58 | 59 | if x_T is None: 60 | img = torch.randn(size, device=self.device) 61 | else: 62 | img = x_T 63 | 64 | if conditioning is None: 65 | model_fn = model_wrapper( 66 | lambda x, t, c: self.model.apply_model(x, t, c), 67 | self.ns, 68 | model_type="noise", 69 | guidance_type="uncond", 70 | ) 71 | ORDER = 3 72 | else: 73 | model_fn = model_wrapper( 74 | lambda x, t, c: self.model.apply_model(x, t, c), 75 | self.ns, 76 | model_type="noise", 77 | guidance_type="classifier-free", 78 | condition=conditioning, 79 | unconditional_condition=unconditional_conditioning, 80 | guidance_scale=self.guidance_scale, 81 | ) 82 | ORDER = 2 83 | 84 | x = self.dpm_solver_v3.sample( 85 | img, 86 | model_fn, 87 | order=ORDER, 88 | p_pseudo=False, 89 | c_pseudo=True, 90 | lower_order_final=True, 91 | use_corrector=use_corrector, 92 | half=half, 93 | ) 94 | 95 | return x.to(self.device), None 96 | -------------------------------------------------------------------------------- /codebases/stable-diffusion/ldm/models/diffusion/uni_pc/__init__.py: -------------------------------------------------------------------------------- 1 | from .sampler import UniPCSampler -------------------------------------------------------------------------------- /codebases/stable-diffusion/ldm/models/diffusion/uni_pc/sampler.py: -------------------------------------------------------------------------------- 1 | """SAMPLING ONLY.""" 2 | 3 | import torch 4 | 5 | from .uni_pc import NoiseScheduleVP, model_wrapper, UniPC 6 | 7 | 
8 | class UniPCSampler(object): 9 | def __init__(self, model, **kwargs): 10 | super().__init__() 11 | self.model = model 12 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device) 13 | self.register_buffer("alphas_cumprod", to_torch(model.alphas_cumprod)) 14 | 15 | def register_buffer(self, name, attr): 16 | if type(attr) == torch.Tensor: 17 | if attr.device != torch.device("cuda"): 18 | attr = attr.to(torch.device("cuda")) 19 | setattr(self, name, attr) 20 | 21 | @torch.no_grad() 22 | def sample( 23 | self, 24 | S, 25 | batch_size, 26 | shape, 27 | conditioning=None, 28 | callback=None, 29 | normals_sequence=None, 30 | img_callback=None, 31 | quantize_x0=False, 32 | eta=0.0, 33 | mask=None, 34 | x0=None, 35 | temperature=1.0, 36 | noise_dropout=0.0, 37 | score_corrector=None, 38 | corrector_kwargs=None, 39 | verbose=True, 40 | x_T=None, 41 | log_every_t=100, 42 | unconditional_guidance_scale=1.0, 43 | unconditional_conditioning=None, 44 | # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 45 | **kwargs, 46 | ): 47 | if conditioning is not None: 48 | if isinstance(conditioning, dict): 49 | cbs = conditioning[list(conditioning.keys())[0]].shape[0] 50 | if cbs != batch_size: 51 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") 52 | else: 53 | if conditioning.shape[0] != batch_size: 54 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") 55 | 56 | # sampling 57 | C, H, W = shape 58 | size = (batch_size, C, H, W) 59 | 60 | device = self.model.betas.device 61 | if x_T is None: 62 | img = torch.randn(size, device=device) 63 | else: 64 | img = x_T 65 | 66 | ns = NoiseScheduleVP("discrete", alphas_cumprod=self.alphas_cumprod) 67 | 68 | if conditioning is None: 69 | model_fn = model_wrapper( 70 | lambda x, t, c: self.model.apply_model(x, t, c), 71 | ns, 72 | model_type="noise", 73 | guidance_type="uncond", 74 | ) 75 | ORDER = 3 76 | else: 77 | model_fn = model_wrapper( 78 | lambda x, t, c: self.model.apply_model(x, t, c), 79 | ns, 80 | model_type="noise", 81 | guidance_type="classifier-free", 82 | condition=conditioning, 83 | unconditional_condition=unconditional_conditioning, 84 | guidance_scale=unconditional_guidance_scale, 85 | ) 86 | ORDER = 2 87 | 88 | uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, variant="bh2") 89 | x = uni_pc.sample( 90 | img, steps=S, skip_type="time_uniform", method="multistep", order=ORDER, lower_order_final=True 91 | ) 92 | 93 | return x.to(device), None 94 | -------------------------------------------------------------------------------- /codebases/stable-diffusion/ldm/modules/diffusionmodules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/DPM-Solver-v3/b615d707e9df512d309c4e59ce228283a608eb8b/codebases/stable-diffusion/ldm/modules/diffusionmodules/__init__.py -------------------------------------------------------------------------------- /codebases/stable-diffusion/ldm/modules/distributions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/DPM-Solver-v3/b615d707e9df512d309c4e59ce228283a608eb8b/codebases/stable-diffusion/ldm/modules/distributions/__init__.py -------------------------------------------------------------------------------- /codebases/stable-diffusion/ldm/modules/distributions/distributions.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 34 | 35 | def sample(self): 36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) 37 | return x 38 | 39 | def kl(self, other=None): 40 | if self.deterministic: 41 | return torch.Tensor([0.]) 42 | else: 43 | if other is None: 44 | return 0.5 * torch.sum(torch.pow(self.mean, 2) 45 | + self.var - 1.0 - self.logvar, 46 | dim=[1, 2, 3]) 47 | else: 48 | return 0.5 * torch.sum( 49 | torch.pow(self.mean - other.mean, 2) / other.var 50 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 51 | dim=[1, 2, 3]) 52 | 53 | def nll(self, sample, dims=[1,2,3]): 54 | if self.deterministic: 55 | return torch.Tensor([0.]) 56 | logtwopi = np.log(2.0 * np.pi) 57 | return 0.5 * torch.sum( 58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 59 | dim=dims) 60 | 61 | def mode(self): 62 | return self.mean 63 | 64 | 65 | def normal_kl(mean1, logvar1, mean2, logvar2): 66 | """ 67 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 68 | Compute the KL divergence between two gaussians. 69 | Shapes are automatically broadcasted, so batches can be compared to 70 | scalars, among other use cases. 71 | """ 72 | tensor = None 73 | for obj in (mean1, logvar1, mean2, logvar2): 74 | if isinstance(obj, torch.Tensor): 75 | tensor = obj 76 | break 77 | assert tensor is not None, "at least one argument must be a Tensor" 78 | 79 | # Force variances to be Tensors. Broadcasting helps convert scalars to 80 | # Tensors, but it does not work for torch.exp(). 
81 | logvar1, logvar2 = [ 82 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 83 | for x in (logvar1, logvar2) 84 | ] 85 | 86 | return 0.5 * ( 87 | -1.0 88 | + logvar2 89 | - logvar1 90 | + torch.exp(logvar1 - logvar2) 91 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 92 | ) 93 | -------------------------------------------------------------------------------- /codebases/stable-diffusion/ldm/modules/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError('Decay must be between 0 and 1') 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates 14 | else torch.tensor(-1,dtype=torch.int)) 15 | 16 | for name, p in model.named_parameters(): 17 | if p.requires_grad: 18 | #remove as '.'-character is not allowed in buffers 19 | s_name = name.replace('.','') 20 | self.m_name2s_name.update({name:s_name}) 21 | self.register_buffer(s_name,p.clone().detach().data) 22 | 23 | self.collected_params = [] 24 | 25 | def forward(self,model): 26 | decay = self.decay 27 | 28 | if self.num_updates >= 0: 29 | self.num_updates += 1 30 | decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates)) 31 | 32 | one_minus_decay = 1.0 - decay 33 | 34 | with torch.no_grad(): 35 | m_param = dict(model.named_parameters()) 36 | shadow_params = dict(self.named_buffers()) 37 | 38 | for key in m_param: 39 | if m_param[key].requires_grad: 40 | sname = self.m_name2s_name[key] 41 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 42 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) 43 | else: 44 | assert not key in self.m_name2s_name 45 | 46 | def copy_to(self, model): 47 | m_param = dict(model.named_parameters()) 48 | shadow_params = dict(self.named_buffers()) 49 | for key in m_param: 50 | if m_param[key].requires_grad: 51 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 52 | else: 53 | assert not key in self.m_name2s_name 54 | 55 | def store(self, parameters): 56 | """ 57 | Save the current parameters for restoring later. 58 | Args: 59 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 60 | temporarily stored. 61 | """ 62 | self.collected_params = [param.clone() for param in parameters] 63 | 64 | def restore(self, parameters): 65 | """ 66 | Restore the parameters stored with the `store` method. 67 | Useful to validate the model with EMA parameters without affecting the 68 | original optimization process. Store the parameters before the 69 | `copy_to` method. After validation (or model saving), use this to 70 | restore the former parameters. 71 | Args: 72 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 73 | updated with the stored parameters. 
74 | """ 75 | for c_param, param in zip(self.collected_params, parameters): 76 | param.data.copy_(c_param.data) 77 | -------------------------------------------------------------------------------- /codebases/stable-diffusion/ldm/modules/encoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/DPM-Solver-v3/b615d707e9df512d309c4e59ce228283a608eb8b/codebases/stable-diffusion/ldm/modules/encoders/__init__.py -------------------------------------------------------------------------------- /codebases/stable-diffusion/ldm/modules/image_degradation/__init__.py: -------------------------------------------------------------------------------- 1 | from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr 2 | from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light 3 | -------------------------------------------------------------------------------- /codebases/stable-diffusion/ldm/modules/image_degradation/utils/test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/DPM-Solver-v3/b615d707e9df512d309c4e59ce228283a608eb8b/codebases/stable-diffusion/ldm/modules/image_degradation/utils/test.png -------------------------------------------------------------------------------- /codebases/stable-diffusion/ldm/modules/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator -------------------------------------------------------------------------------- /codebases/stable-diffusion/sample.sh: -------------------------------------------------------------------------------- 1 | case $1 in 2 | lsun_beds256) 3 | config="configs/latent-diffusion/lsun_bedrooms-ldm-vq-4.yaml" 4 | ckpt="models/ldm/lsun_beds256/model.ckpt" 5 | H=256 6 | W=256 7 | C=3 8 | f=4 9 | scale=0.0 10 | prompt="" 11 | STATS_DIR="statistics/lsun_beds256/120_1024" 12 | ;; 13 | sd-v1-4) 14 | config="configs/stable-diffusion/v1-inference.yaml" 15 | ckpt="models/ldm/stable-diffusion-v1/sd-v1-4.ckpt" 16 | H=512 17 | W=512 18 | C=4 19 | f=8 20 | scale=$3 21 | prompt=$4 22 | STATS_DIR="statistics/sd-v1-4/"$scale"_250_1024" 23 | ;; 24 | esac 25 | 26 | steps=$2 27 | 28 | for sampleMethod in 'dpm_solver++' 'uni_pc' 'dpm_solver_v3'; do 29 | python txt2img.py --prompt "$prompt" --steps $steps --statistics_dir $STATS_DIR --outdir "outputs/"$1"/"$sampleMethod"_steps"$steps"_scale"$scale --method $sampleMethod --scale $scale --config $config --ckpt $ckpt --H $H --W $W --C $C --f $f 30 | done -------------------------------------------------------------------------------- /codebases/stable-diffusion/src/clip/.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: test 2 | on: 3 | push: 4 | branches: 5 | - main 6 | pull_request: 7 | branches: 8 | - main 9 | jobs: 10 | CLIP-test: 11 | runs-on: ubuntu-latest 12 | strategy: 13 | matrix: 14 | python-version: [3.8] 15 | pytorch-version: [1.7.1, 1.9.1, 1.10.1] 16 | include: 17 | - python-version: 3.8 18 | pytorch-version: 1.7.1 19 | torchvision-version: 0.8.2 20 | - python-version: 3.8 21 | pytorch-version: 1.9.1 22 | torchvision-version: 0.10.1 23 | - python-version: 3.8 24 | pytorch-version: 1.10.1 25 | torchvision-version: 0.11.2 26 | steps: 27 | - uses: 
conda-incubator/setup-miniconda@v2 28 | - run: conda install -n test python=${{ matrix.python-version }} pytorch=${{ matrix.pytorch-version }} torchvision=${{ matrix.torchvision-version }} cpuonly -c pytorch 29 | - uses: actions/checkout@v2 30 | - run: echo "$CONDA/envs/test/bin" >> $GITHUB_PATH 31 | - run: pip install pytest 32 | - run: pip install . 33 | - run: pytest 34 | -------------------------------------------------------------------------------- /codebases/stable-diffusion/src/clip/.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | *.py[cod] 3 | *$py.class 4 | *.egg-info 5 | .pytest_cache 6 | .ipynb_checkpoints 7 | 8 | thumbs.db 9 | .DS_Store 10 | .idea 11 | -------------------------------------------------------------------------------- /codebases/stable-diffusion/src/clip/CLIP.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/DPM-Solver-v3/b615d707e9df512d309c4e59ce228283a608eb8b/codebases/stable-diffusion/src/clip/CLIP.png -------------------------------------------------------------------------------- /codebases/stable-diffusion/src/clip/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 OpenAI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | 23 | -------------------------------------------------------------------------------- /codebases/stable-diffusion/src/clip/MANIFEST.in: -------------------------------------------------------------------------------- 1 | include clip/bpe_simple_vocab_16e6.txt.gz 2 | -------------------------------------------------------------------------------- /codebases/stable-diffusion/src/clip/clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .clip import * 2 | -------------------------------------------------------------------------------- /codebases/stable-diffusion/src/clip/clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/DPM-Solver-v3/b615d707e9df512d309c4e59ce228283a608eb8b/codebases/stable-diffusion/src/clip/clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /codebases/stable-diffusion/src/clip/clip/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | """ 18 | Returns list of utf-8 byte and a corresponding list of unicode strings. 19 | The reversible bpe codes work on unicode strings. 20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 22 | This is a signficant percentage of your normal, say, 32K bpe vocab. 23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 24 | And avoids mapping to whitespace/control characters the bpe code barfs on. 25 | """ 26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 27 | cs = bs[:] 28 | n = 0 29 | for b in range(2**8): 30 | if b not in bs: 31 | bs.append(b) 32 | cs.append(2**8+n) 33 | n += 1 34 | cs = [chr(n) for n in cs] 35 | return dict(zip(bs, cs)) 36 | 37 | 38 | def get_pairs(word): 39 | """Return set of symbol pairs in a word. 40 | Word is represented as tuple of symbols (symbols being variable-length strings). 
41 | """ 42 | pairs = set() 43 | prev_char = word[0] 44 | for char in word[1:]: 45 | pairs.add((prev_char, char)) 46 | prev_char = char 47 | return pairs 48 | 49 | 50 | def basic_clean(text): 51 | text = ftfy.fix_text(text) 52 | text = html.unescape(html.unescape(text)) 53 | return text.strip() 54 | 55 | 56 | def whitespace_clean(text): 57 | text = re.sub(r'\s+', ' ', text) 58 | text = text.strip() 59 | return text 60 | 61 | 62 | class SimpleTokenizer(object): 63 | def __init__(self, bpe_path: str = default_bpe()): 64 | self.byte_encoder = bytes_to_unicode() 65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 67 | merges = merges[1:49152-256-2+1] 68 | merges = [tuple(merge.split()) for merge in merges] 69 | vocab = list(bytes_to_unicode().values()) 70 | vocab = vocab + [v+'' for v in vocab] 71 | for merge in merges: 72 | vocab.append(''.join(merge)) 73 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 74 | self.encoder = dict(zip(vocab, range(len(vocab)))) 75 | self.decoder = {v: k for k, v in self.encoder.items()} 76 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 79 | 80 | def bpe(self, token): 81 | if token in self.cache: 82 | return self.cache[token] 83 | word = tuple(token[:-1]) + ( token[-1] + '',) 84 | pairs = get_pairs(word) 85 | 86 | if not pairs: 87 | return token+'' 88 | 89 | while True: 90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 91 | if bigram not in self.bpe_ranks: 92 | break 93 | first, second = bigram 94 | new_word = [] 95 | i = 0 96 | while i < len(word): 97 | try: 98 | j = word.index(first, i) 99 | new_word.extend(word[i:j]) 100 | i = j 101 | except: 102 | new_word.extend(word[i:]) 103 | break 104 | 105 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 106 | new_word.append(first+second) 107 | i += 2 108 | else: 109 | new_word.append(word[i]) 110 | i += 1 111 | new_word = tuple(new_word) 112 | word = new_word 113 | if len(word) == 1: 114 | break 115 | else: 116 | pairs = get_pairs(word) 117 | word = ' '.join(word) 118 | self.cache[token] = word 119 | return word 120 | 121 | def encode(self, text): 122 | bpe_tokens = [] 123 | text = whitespace_clean(basic_clean(text)).lower() 124 | for token in re.findall(self.pat, text): 125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 127 | return bpe_tokens 128 | 129 | def decode(self, tokens): 130 | text = ''.join([self.decoder[token] for token in tokens]) 131 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 132 | return text 133 | -------------------------------------------------------------------------------- /codebases/stable-diffusion/src/clip/data/country211.md: -------------------------------------------------------------------------------- 1 | # The Country211 Dataset 2 | 3 | In the paper, we used an image classification dataset called Country211, to evaluate the model's capability on geolocation. 
To do so, we filtered the YFCC100m dataset for images with GPS coordinates corresponding to an [ISO-3166 country code](https://en.wikipedia.org/wiki/List_of_ISO_3166_country_codes) and created a balanced dataset by sampling 150 train images, 50 validation images, and 100 test images for each country. 4 | 5 | The following command will download an 11GB archive containing the images and extract it into a subdirectory `country211`: 6 | 7 | ```bash 8 | wget https://openaipublic.azureedge.net/clip/data/country211.tgz 9 | tar zxvf country211.tgz 10 | ``` 11 | 12 | These images are a subset of the YFCC100m dataset. Use of the underlying media files is subject to the Creative Commons licenses chosen by their creators/uploaders. For more information about the YFCC100M dataset, visit [the official website](https://multimediacommons.wordpress.com/yfcc100m-core-dataset/). -------------------------------------------------------------------------------- /codebases/stable-diffusion/src/clip/data/rendered-sst2.md: -------------------------------------------------------------------------------- 1 | # The Rendered SST2 Dataset 2 | 3 | In the paper, we used an image classification dataset called Rendered SST2, to evaluate the model's capability on optical character recognition. To do so, we rendered the sentences in the [Stanford Sentiment Treebank v2](https://nlp.stanford.edu/sentiment/treebank.html) dataset and used those as the input to the CLIP image encoder. 4 | 5 | The following command will download a 131MB archive containing the images and extract it into a subdirectory `rendered-sst2`: 6 | 7 | ```bash 8 | wget https://openaipublic.azureedge.net/clip/data/rendered-sst2.tgz 9 | tar zxvf rendered-sst2.tgz 10 | ``` 11 | 12 | -------------------------------------------------------------------------------- /codebases/stable-diffusion/src/clip/data/yfcc100m.md: -------------------------------------------------------------------------------- 1 | # The YFCC100M Subset 2 | 3 | In the paper, we performed a dataset ablation using a subset of the YFCC100M dataset and showed that the performance remained largely similar. 4 | 5 | The subset contains 14,829,396 images, about 15% of the full dataset, which have been filtered to only keep those with natural language titles and/or descriptions in English. 6 | 7 | We provide the list of (line number, photo identifier, photo hash) of each image contained in this subset. These correspond to the first three columns in the dataset's metadata TSV file. 8 | 9 | ```bash 10 | wget https://openaipublic.azureedge.net/clip/data/yfcc100m_subset_data.tsv.bz2 11 | bunzip2 yfcc100m_subset_data.tsv.bz2 12 | ``` 13 | 14 | Use of the underlying media files is subject to the Creative Commons licenses chosen by their creators/uploaders. For more information about the YFCC100M dataset, visit [the official website](https://multimediacommons.wordpress.com/yfcc100m-core-dataset/).
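A quick way to sanity-check the subset list after running the commands above is to load the decompressed TSV — a minimal, editorial sketch, not part of the original data notes; the column names below are assumptions (the list is documented only as line number, photo identifier, photo hash):

```python
# Sketch only: inspect the YFCC100M subset list downloaded above.
# Assumes yfcc100m_subset_data.tsv is in the working directory and carries no
# header row; the column names are illustrative, not official.
import pandas as pd

subset = pd.read_csv(
    "yfcc100m_subset_data.tsv",
    sep="\t",
    header=None,
    names=["line_number", "photo_id", "photo_hash"],
)
print(len(subset))   # expected: 14,829,396 rows, per the figure quoted above
print(subset.head())
```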
-------------------------------------------------------------------------------- /codebases/stable-diffusion/src/clip/hubconf.py: -------------------------------------------------------------------------------- 1 | from clip.clip import tokenize as _tokenize, load as _load, available_models as _available_models 2 | import re 3 | import string 4 | 5 | dependencies = ["torch", "torchvision", "ftfy", "regex", "tqdm"] 6 | 7 | # For compatibility (cannot include special characters in function name) 8 | model_functions = { model: re.sub(f'[{string.punctuation}]', '_', model) for model in _available_models()} 9 | 10 | def _create_hub_entrypoint(model): 11 | def entrypoint(**kwargs): 12 | return _load(model, **kwargs) 13 | 14 | entrypoint.__doc__ = f"""Loads the {model} CLIP model 15 | 16 | Parameters 17 | ---------- 18 | device : Union[str, torch.device] 19 | The device to put the loaded model 20 | 21 | jit : bool 22 | Whether to load the optimized JIT model or more hackable non-JIT model (default). 23 | 24 | download_root: str 25 | path to download the model files; by default, it uses "~/.cache/clip" 26 | 27 | Returns 28 | ------- 29 | model : torch.nn.Module 30 | The {model} CLIP model 31 | 32 | preprocess : Callable[[PIL.Image], torch.Tensor] 33 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 34 | """ 35 | return entrypoint 36 | 37 | def tokenize(): 38 | return _tokenize 39 | 40 | _entrypoints = {model_functions[model]: _create_hub_entrypoint(model) for model in _available_models()} 41 | 42 | globals().update(_entrypoints) -------------------------------------------------------------------------------- /codebases/stable-diffusion/src/clip/requirements.txt: -------------------------------------------------------------------------------- 1 | ftfy 2 | regex 3 | tqdm 4 | torch 5 | torchvision 6 | -------------------------------------------------------------------------------- /codebases/stable-diffusion/src/clip/setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pkg_resources 4 | from setuptools import setup, find_packages 5 | 6 | setup( 7 | name="clip", 8 | py_modules=["clip"], 9 | version="1.0", 10 | description="", 11 | author="OpenAI", 12 | packages=find_packages(exclude=["tests*"]), 13 | install_requires=[ 14 | str(r) 15 | for r in pkg_resources.parse_requirements( 16 | open(os.path.join(os.path.dirname(__file__), "requirements.txt")) 17 | ) 18 | ], 19 | include_package_data=True, 20 | extras_require={'dev': ['pytest']}, 21 | ) 22 | -------------------------------------------------------------------------------- /codebases/stable-diffusion/src/clip/tests/test_consistency.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | import torch 4 | from PIL import Image 5 | 6 | import clip 7 | 8 | 9 | @pytest.mark.parametrize('model_name', clip.available_models()) 10 | def test_consistency(model_name): 11 | device = "cpu" 12 | jit_model, transform = clip.load(model_name, device=device, jit=True) 13 | py_model, _ = clip.load(model_name, device=device, jit=False) 14 | 15 | image = transform(Image.open("CLIP.png")).unsqueeze(0).to(device) 16 | text = clip.tokenize(["a diagram", "a dog", "a cat"]).to(device) 17 | 18 | with torch.no_grad(): 19 | logits_per_image, _ = jit_model(image, text) 20 | jit_probs = logits_per_image.softmax(dim=-1).cpu().numpy() 21 | 22 | logits_per_image, _ = py_model(image, 
text) 23 | py_probs = logits_per_image.softmax(dim=-1).cpu().numpy() 24 | 25 | assert np.allclose(jit_probs, py_probs, atol=0.01, rtol=0.1) 26 | -------------------------------------------------------------------------------- /codebases/stable-diffusion/src/taming-transformers/License.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all 11 | copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 14 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 15 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 16 | IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, 17 | DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 18 | OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE 19 | OR OTHER DEALINGS IN THE SOFTWARE./ 20 | -------------------------------------------------------------------------------- /codebases/stable-diffusion/src/taming-transformers/assets/birddrawnbyachild.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/DPM-Solver-v3/b615d707e9df512d309c4e59ce228283a608eb8b/codebases/stable-diffusion/src/taming-transformers/assets/birddrawnbyachild.png -------------------------------------------------------------------------------- /codebases/stable-diffusion/src/taming-transformers/assets/drin.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/DPM-Solver-v3/b615d707e9df512d309c4e59ce228283a608eb8b/codebases/stable-diffusion/src/taming-transformers/assets/drin.jpg -------------------------------------------------------------------------------- /codebases/stable-diffusion/src/taming-transformers/assets/faceshq.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/DPM-Solver-v3/b615d707e9df512d309c4e59ce228283a608eb8b/codebases/stable-diffusion/src/taming-transformers/assets/faceshq.jpg -------------------------------------------------------------------------------- /codebases/stable-diffusion/src/taming-transformers/assets/first_stage_mushrooms.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/DPM-Solver-v3/b615d707e9df512d309c4e59ce228283a608eb8b/codebases/stable-diffusion/src/taming-transformers/assets/first_stage_mushrooms.png -------------------------------------------------------------------------------- /codebases/stable-diffusion/src/taming-transformers/assets/first_stage_squirrels.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/thu-ml/DPM-Solver-v3/b615d707e9df512d309c4e59ce228283a608eb8b/codebases/stable-diffusion/src/taming-transformers/assets/first_stage_squirrels.png -------------------------------------------------------------------------------- /codebases/stable-diffusion/src/taming-transformers/assets/imagenet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/DPM-Solver-v3/b615d707e9df512d309c4e59ce228283a608eb8b/codebases/stable-diffusion/src/taming-transformers/assets/imagenet.png -------------------------------------------------------------------------------- /codebases/stable-diffusion/src/taming-transformers/assets/lake_in_the_mountains.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/DPM-Solver-v3/b615d707e9df512d309c4e59ce228283a608eb8b/codebases/stable-diffusion/src/taming-transformers/assets/lake_in_the_mountains.png -------------------------------------------------------------------------------- /codebases/stable-diffusion/src/taming-transformers/assets/mountain.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/DPM-Solver-v3/b615d707e9df512d309c4e59ce228283a608eb8b/codebases/stable-diffusion/src/taming-transformers/assets/mountain.jpeg -------------------------------------------------------------------------------- /codebases/stable-diffusion/src/taming-transformers/assets/stormy.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/DPM-Solver-v3/b615d707e9df512d309c4e59ce228283a608eb8b/codebases/stable-diffusion/src/taming-transformers/assets/stormy.jpeg -------------------------------------------------------------------------------- /codebases/stable-diffusion/src/taming-transformers/assets/sunset_and_ocean.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/DPM-Solver-v3/b615d707e9df512d309c4e59ce228283a608eb8b/codebases/stable-diffusion/src/taming-transformers/assets/sunset_and_ocean.jpg -------------------------------------------------------------------------------- /codebases/stable-diffusion/src/taming-transformers/assets/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/DPM-Solver-v3/b615d707e9df512d309c4e59ce228283a608eb8b/codebases/stable-diffusion/src/taming-transformers/assets/teaser.png -------------------------------------------------------------------------------- /codebases/stable-diffusion/src/taming-transformers/configs/coco_cond_stage.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: taming.models.vqgan.VQSegmentationModel 4 | params: 5 | embed_dim: 256 6 | n_embed: 1024 7 | image_key: "segmentation" 8 | n_labels: 183 9 | ddconfig: 10 | double_z: false 11 | z_channels: 256 12 | resolution: 256 13 | in_channels: 183 14 | out_ch: 183 15 | ch: 128 16 | ch_mult: 17 | - 1 18 | - 1 19 | - 2 20 | - 2 21 | - 4 22 | num_res_blocks: 2 23 | attn_resolutions: 24 | - 16 25 | dropout: 0.0 26 | 27 | lossconfig: 28 | target: taming.modules.losses.segmentation.BCELossWithQuant 29 | params: 30 | codebook_weight: 1.0 31 | 32 | data: 33 | target: main.DataModuleFromConfig 34 | params: 35 | batch_size: 12 36 | train: 
37 | target: taming.data.coco.CocoImagesAndCaptionsTrain 38 | params: 39 | size: 296 40 | crop_size: 256 41 | onehot_segmentation: true 42 | use_stuffthing: true 43 | validation: 44 | target: taming.data.coco.CocoImagesAndCaptionsValidation 45 | params: 46 | size: 256 47 | crop_size: 256 48 | onehot_segmentation: true 49 | use_stuffthing: true 50 | -------------------------------------------------------------------------------- /codebases/stable-diffusion/src/taming-transformers/configs/coco_scene_images_transformer.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: taming.models.cond_transformer.Net2NetTransformer 4 | params: 5 | cond_stage_key: objects_bbox 6 | transformer_config: 7 | target: taming.modules.transformer.mingpt.GPT 8 | params: 9 | vocab_size: 8192 10 | block_size: 348 # = 256 + 92 = dim(vqgan_latent_space,16x16) + dim(conditional_builder.embedding_dim) 11 | n_layer: 40 12 | n_head: 16 13 | n_embd: 1408 14 | embd_pdrop: 0.1 15 | resid_pdrop: 0.1 16 | attn_pdrop: 0.1 17 | first_stage_config: 18 | target: taming.models.vqgan.VQModel 19 | params: 20 | ckpt_path: /path/to/coco_epoch117.ckpt # https://heibox.uni-heidelberg.de/f/78dea9589974474c97c1/ 21 | embed_dim: 256 22 | n_embed: 8192 23 | ddconfig: 24 | double_z: false 25 | z_channels: 256 26 | resolution: 256 27 | in_channels: 3 28 | out_ch: 3 29 | ch: 128 30 | ch_mult: 31 | - 1 32 | - 1 33 | - 2 34 | - 2 35 | - 4 36 | num_res_blocks: 2 37 | attn_resolutions: 38 | - 16 39 | dropout: 0.0 40 | lossconfig: 41 | target: taming.modules.losses.DummyLoss 42 | cond_stage_config: 43 | target: taming.models.dummy_cond_stage.DummyCondStage 44 | params: 45 | conditional_key: objects_bbox 46 | 47 | data: 48 | target: main.DataModuleFromConfig 49 | params: 50 | batch_size: 6 51 | train: 52 | target: taming.data.annotated_objects_coco.AnnotatedObjectsCoco 53 | params: 54 | data_path: data/coco_annotations_100 # substitute with path to full dataset 55 | split: train 56 | keys: [image, objects_bbox, file_name, annotations] 57 | no_tokens: 8192 58 | target_image_size: 256 59 | min_object_area: 0.00001 60 | min_objects_per_image: 2 61 | max_objects_per_image: 30 62 | crop_method: random-1d 63 | random_flip: true 64 | use_group_parameter: true 65 | encode_crop: true 66 | validation: 67 | target: taming.data.annotated_objects_coco.AnnotatedObjectsCoco 68 | params: 69 | data_path: data/coco_annotations_100 # substitute with path to full dataset 70 | split: validation 71 | keys: [image, objects_bbox, file_name, annotations] 72 | no_tokens: 8192 73 | target_image_size: 256 74 | min_object_area: 0.00001 75 | min_objects_per_image: 2 76 | max_objects_per_image: 30 77 | crop_method: center 78 | random_flip: false 79 | use_group_parameter: true 80 | encode_crop: true 81 | -------------------------------------------------------------------------------- /codebases/stable-diffusion/src/taming-transformers/configs/custom_vqgan.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-6 3 | target: taming.models.vqgan.VQModel 4 | params: 5 | embed_dim: 256 6 | n_embed: 1024 7 | ddconfig: 8 | double_z: False 9 | z_channels: 256 10 | resolution: 256 11 | in_channels: 3 12 | out_ch: 3 13 | ch: 128 14 | ch_mult: [ 1,1,2,2,4] # num_down = len(ch_mult)-1 15 | num_res_blocks: 2 16 | attn_resolutions: [16] 17 | dropout: 0.0 18 | 19 | lossconfig: 20 | target: 
taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator 21 | params: 22 | disc_conditional: False 23 | disc_in_channels: 3 24 | disc_start: 10000 25 | disc_weight: 0.8 26 | codebook_weight: 1.0 27 | 28 | data: 29 | target: main.DataModuleFromConfig 30 | params: 31 | batch_size: 5 32 | num_workers: 8 33 | train: 34 | target: taming.data.custom.CustomTrain 35 | params: 36 | training_images_list_file: some/training.txt 37 | size: 256 38 | validation: 39 | target: taming.data.custom.CustomTest 40 | params: 41 | test_images_list_file: some/test.txt 42 | size: 256 43 | 44 | -------------------------------------------------------------------------------- /codebases/stable-diffusion/src/taming-transformers/configs/drin_transformer.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: taming.models.cond_transformer.Net2NetTransformer 4 | params: 5 | cond_stage_key: depth 6 | transformer_config: 7 | target: taming.modules.transformer.mingpt.GPT 8 | params: 9 | vocab_size: 1024 10 | block_size: 512 11 | n_layer: 24 12 | n_head: 16 13 | n_embd: 1024 14 | first_stage_config: 15 | target: taming.models.vqgan.VQModel 16 | params: 17 | ckpt_path: logs/2020-09-23T17-56-33_imagenet_vqgan/checkpoints/last.ckpt 18 | embed_dim: 256 19 | n_embed: 1024 20 | ddconfig: 21 | double_z: false 22 | z_channels: 256 23 | resolution: 256 24 | in_channels: 3 25 | out_ch: 3 26 | ch: 128 27 | ch_mult: 28 | - 1 29 | - 1 30 | - 2 31 | - 2 32 | - 4 33 | num_res_blocks: 2 34 | attn_resolutions: 35 | - 16 36 | dropout: 0.0 37 | lossconfig: 38 | target: taming.modules.losses.DummyLoss 39 | cond_stage_config: 40 | target: taming.models.vqgan.VQModel 41 | params: 42 | ckpt_path: logs/2020-11-03T15-34-24_imagenetdepth_vqgan/checkpoints/last.ckpt 43 | embed_dim: 256 44 | n_embed: 1024 45 | ddconfig: 46 | double_z: false 47 | z_channels: 256 48 | resolution: 256 49 | in_channels: 1 50 | out_ch: 1 51 | ch: 128 52 | ch_mult: 53 | - 1 54 | - 1 55 | - 2 56 | - 2 57 | - 4 58 | num_res_blocks: 2 59 | attn_resolutions: 60 | - 16 61 | dropout: 0.0 62 | lossconfig: 63 | target: taming.modules.losses.DummyLoss 64 | 65 | data: 66 | target: main.DataModuleFromConfig 67 | params: 68 | batch_size: 2 69 | num_workers: 8 70 | train: 71 | target: taming.data.imagenet.RINTrainWithDepth 72 | params: 73 | size: 256 74 | validation: 75 | target: taming.data.imagenet.RINValidationWithDepth 76 | params: 77 | size: 256 78 | -------------------------------------------------------------------------------- /codebases/stable-diffusion/src/taming-transformers/configs/faceshq_transformer.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: taming.models.cond_transformer.Net2NetTransformer 4 | params: 5 | cond_stage_key: coord 6 | transformer_config: 7 | target: taming.modules.transformer.mingpt.GPT 8 | params: 9 | vocab_size: 1024 10 | block_size: 512 11 | n_layer: 24 12 | n_head: 16 13 | n_embd: 1024 14 | first_stage_config: 15 | target: taming.models.vqgan.VQModel 16 | params: 17 | ckpt_path: logs/2020-11-09T13-33-36_faceshq_vqgan/checkpoints/last.ckpt 18 | embed_dim: 256 19 | n_embed: 1024 20 | ddconfig: 21 | double_z: false 22 | z_channels: 256 23 | resolution: 256 24 | in_channels: 3 25 | out_ch: 3 26 | ch: 128 27 | ch_mult: 28 | - 1 29 | - 1 30 | - 2 31 | - 2 32 | - 4 33 | num_res_blocks: 2 34 | attn_resolutions: 35 | - 16 36 | dropout: 0.0 37 | lossconfig: 38 | target: 
taming.modules.losses.DummyLoss 39 | cond_stage_config: 40 | target: taming.modules.misc.coord.CoordStage 41 | params: 42 | n_embed: 1024 43 | down_factor: 16 44 | 45 | data: 46 | target: main.DataModuleFromConfig 47 | params: 48 | batch_size: 2 49 | num_workers: 8 50 | train: 51 | target: taming.data.faceshq.FacesHQTrain 52 | params: 53 | size: 256 54 | crop_size: 256 55 | coord: True 56 | validation: 57 | target: taming.data.faceshq.FacesHQValidation 58 | params: 59 | size: 256 60 | crop_size: 256 61 | coord: True 62 | -------------------------------------------------------------------------------- /codebases/stable-diffusion/src/taming-transformers/configs/faceshq_vqgan.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-6 3 | target: taming.models.vqgan.VQModel 4 | params: 5 | embed_dim: 256 6 | n_embed: 1024 7 | ddconfig: 8 | double_z: False 9 | z_channels: 256 10 | resolution: 256 11 | in_channels: 3 12 | out_ch: 3 13 | ch: 128 14 | ch_mult: [ 1,1,2,2,4] # num_down = len(ch_mult)-1 15 | num_res_blocks: 2 16 | attn_resolutions: [16] 17 | dropout: 0.0 18 | 19 | lossconfig: 20 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator 21 | params: 22 | disc_conditional: False 23 | disc_in_channels: 3 24 | disc_start: 30001 25 | disc_weight: 0.8 26 | codebook_weight: 1.0 27 | 28 | data: 29 | target: main.DataModuleFromConfig 30 | params: 31 | batch_size: 3 32 | num_workers: 8 33 | train: 34 | target: taming.data.faceshq.FacesHQTrain 35 | params: 36 | size: 256 37 | crop_size: 256 38 | validation: 39 | target: taming.data.faceshq.FacesHQValidation 40 | params: 41 | size: 256 42 | crop_size: 256 43 | -------------------------------------------------------------------------------- /codebases/stable-diffusion/src/taming-transformers/configs/imagenet_vqgan.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-6 3 | target: taming.models.vqgan.VQModel 4 | params: 5 | embed_dim: 256 6 | n_embed: 1024 7 | ddconfig: 8 | double_z: False 9 | z_channels: 256 10 | resolution: 256 11 | in_channels: 3 12 | out_ch: 3 13 | ch: 128 14 | ch_mult: [ 1,1,2,2,4] # num_down = len(ch_mult)-1 15 | num_res_blocks: 2 16 | attn_resolutions: [16] 17 | dropout: 0.0 18 | 19 | lossconfig: 20 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator 21 | params: 22 | disc_conditional: False 23 | disc_in_channels: 3 24 | disc_start: 250001 25 | disc_weight: 0.8 26 | codebook_weight: 1.0 27 | 28 | data: 29 | target: main.DataModuleFromConfig 30 | params: 31 | batch_size: 12 32 | num_workers: 24 33 | train: 34 | target: taming.data.imagenet.ImageNetTrain 35 | params: 36 | config: 37 | size: 256 38 | validation: 39 | target: taming.data.imagenet.ImageNetValidation 40 | params: 41 | config: 42 | size: 256 43 | -------------------------------------------------------------------------------- /codebases/stable-diffusion/src/taming-transformers/configs/imagenetdepth_vqgan.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-6 3 | target: taming.models.vqgan.VQModel 4 | params: 5 | embed_dim: 256 6 | n_embed: 1024 7 | image_key: depth 8 | ddconfig: 9 | double_z: False 10 | z_channels: 256 11 | resolution: 256 12 | in_channels: 1 13 | out_ch: 1 14 | ch: 128 15 | ch_mult: [ 1,1,2,2,4] # num_down = len(ch_mult)-1 16 | num_res_blocks: 2 17 | attn_resolutions: [16] 18 | dropout: 0.0 19 | 20 | 
lossconfig: 21 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator 22 | params: 23 | disc_conditional: False 24 | disc_in_channels: 1 25 | disc_start: 50001 26 | disc_weight: 0.75 27 | codebook_weight: 1.0 28 | 29 | data: 30 | target: main.DataModuleFromConfig 31 | params: 32 | batch_size: 3 33 | num_workers: 8 34 | train: 35 | target: taming.data.imagenet.ImageNetTrainWithDepth 36 | params: 37 | size: 256 38 | validation: 39 | target: taming.data.imagenet.ImageNetValidationWithDepth 40 | params: 41 | size: 256 42 | -------------------------------------------------------------------------------- /codebases/stable-diffusion/src/taming-transformers/configs/open_images_scene_images_transformer.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: taming.models.cond_transformer.Net2NetTransformer 4 | params: 5 | cond_stage_key: objects_bbox 6 | transformer_config: 7 | target: taming.modules.transformer.mingpt.GPT 8 | params: 9 | vocab_size: 8192 10 | block_size: 348 # = 256 + 92 = dim(vqgan_latent_space,16x16) + dim(conditional_builder.embedding_dim) 11 | n_layer: 36 12 | n_head: 16 13 | n_embd: 1536 14 | embd_pdrop: 0.1 15 | resid_pdrop: 0.1 16 | attn_pdrop: 0.1 17 | first_stage_config: 18 | target: taming.models.vqgan.VQModel 19 | params: 20 | ckpt_path: /path/to/coco_oi_epoch12.ckpt # https://heibox.uni-heidelberg.de/f/461d9a9f4fcf48ab84f4/ 21 | embed_dim: 256 22 | n_embed: 8192 23 | ddconfig: 24 | double_z: false 25 | z_channels: 256 26 | resolution: 256 27 | in_channels: 3 28 | out_ch: 3 29 | ch: 128 30 | ch_mult: 31 | - 1 32 | - 1 33 | - 2 34 | - 2 35 | - 4 36 | num_res_blocks: 2 37 | attn_resolutions: 38 | - 16 39 | dropout: 0.0 40 | lossconfig: 41 | target: taming.modules.losses.DummyLoss 42 | cond_stage_config: 43 | target: taming.models.dummy_cond_stage.DummyCondStage 44 | params: 45 | conditional_key: objects_bbox 46 | 47 | data: 48 | target: main.DataModuleFromConfig 49 | params: 50 | batch_size: 6 51 | train: 52 | target: taming.data.annotated_objects_open_images.AnnotatedObjectsOpenImages 53 | params: 54 | data_path: data/open_images_annotations_100 # substitute with path to full dataset 55 | split: train 56 | keys: [image, objects_bbox, file_name, annotations] 57 | no_tokens: 8192 58 | target_image_size: 256 59 | category_allow_list_target: taming.data.open_images_helper.top_300_classes_plus_coco_compatibility 60 | category_mapping_target: taming.data.open_images_helper.open_images_unify_categories_for_coco 61 | min_object_area: 0.0001 62 | min_objects_per_image: 2 63 | max_objects_per_image: 30 64 | crop_method: random-2d 65 | random_flip: true 66 | use_group_parameter: true 67 | use_additional_parameters: true 68 | encode_crop: true 69 | validation: 70 | target: taming.data.annotated_objects_open_images.AnnotatedObjectsOpenImages 71 | params: 72 | data_path: data/open_images_annotations_100 # substitute with path to full dataset 73 | split: validation 74 | keys: [image, objects_bbox, file_name, annotations] 75 | no_tokens: 8192 76 | target_image_size: 256 77 | category_allow_list_target: taming.data.open_images_helper.top_300_classes_plus_coco_compatibility 78 | category_mapping_target: taming.data.open_images_helper.open_images_unify_categories_for_coco 79 | min_object_area: 0.0001 80 | min_objects_per_image: 2 81 | max_objects_per_image: 30 82 | crop_method: center 83 | random_flip: false 84 | use_group_parameter: true 85 | use_additional_parameters: true 86 | 
encode_crop: true 87 | -------------------------------------------------------------------------------- /codebases/stable-diffusion/src/taming-transformers/configs/sflckr_cond_stage.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: taming.models.vqgan.VQSegmentationModel 4 | params: 5 | embed_dim: 256 6 | n_embed: 1024 7 | image_key: "segmentation" 8 | n_labels: 182 9 | ddconfig: 10 | double_z: false 11 | z_channels: 256 12 | resolution: 256 13 | in_channels: 182 14 | out_ch: 182 15 | ch: 128 16 | ch_mult: 17 | - 1 18 | - 1 19 | - 2 20 | - 2 21 | - 4 22 | num_res_blocks: 2 23 | attn_resolutions: 24 | - 16 25 | dropout: 0.0 26 | 27 | lossconfig: 28 | target: taming.modules.losses.segmentation.BCELossWithQuant 29 | params: 30 | codebook_weight: 1.0 31 | 32 | data: 33 | target: cutlit.DataModuleFromConfig 34 | params: 35 | batch_size: 12 36 | train: 37 | target: taming.data.sflckr.Examples # adjust 38 | params: 39 | size: 256 40 | validation: 41 | target: taming.data.sflckr.Examples # adjust 42 | params: 43 | size: 256 44 | -------------------------------------------------------------------------------- /codebases/stable-diffusion/src/taming-transformers/environment.yaml: -------------------------------------------------------------------------------- 1 | name: taming 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - python=3.8.5 7 | - pip=20.3 8 | - cudatoolkit=10.2 9 | - pytorch=1.7.0 10 | - torchvision=0.8.1 11 | - numpy=1.19.2 12 | - pip: 13 | - albumentations==0.4.3 14 | - opencv-python==4.1.2.30 15 | - pudb==2019.2 16 | - imageio==2.9.0 17 | - imageio-ffmpeg==0.4.2 18 | - pytorch-lightning==1.0.8 19 | - omegaconf==2.0.0 20 | - test-tube>=0.7.5 21 | - streamlit>=0.73.1 22 | - einops==0.3.0 23 | - more-itertools>=8.0.0 24 | - transformers==4.3.1 25 | - -e . 
26 | -------------------------------------------------------------------------------- /codebases/stable-diffusion/src/taming-transformers/scripts/extract_depth.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from tqdm import trange 5 | from PIL import Image 6 | 7 | 8 | def get_state(gpu): 9 | import torch 10 | midas = torch.hub.load("intel-isl/MiDaS", "MiDaS") 11 | if gpu: 12 | midas.cuda() 13 | midas.eval() 14 | 15 | midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms") 16 | transform = midas_transforms.default_transform 17 | 18 | state = {"model": midas, 19 | "transform": transform} 20 | return state 21 | 22 | 23 | def depth_to_rgba(x): 24 | assert x.dtype == np.float32 25 | assert len(x.shape) == 2 26 | y = x.copy() 27 | y.dtype = np.uint8 28 | y = y.reshape(x.shape+(4,)) 29 | return np.ascontiguousarray(y) 30 | 31 | 32 | def rgba_to_depth(x): 33 | assert x.dtype == np.uint8 34 | assert len(x.shape) == 3 and x.shape[2] == 4 35 | y = x.copy() 36 | y.dtype = np.float32 37 | y = y.reshape(x.shape[:2]) 38 | return np.ascontiguousarray(y) 39 | 40 | 41 | def run(x, state): 42 | model = state["model"] 43 | transform = state["transform"] 44 | hw = x.shape[:2] 45 | with torch.no_grad(): 46 | prediction = model(transform((x + 1.0) * 127.5).cuda()) 47 | prediction = torch.nn.functional.interpolate( 48 | prediction.unsqueeze(1), 49 | size=hw, 50 | mode="bicubic", 51 | align_corners=False, 52 | ).squeeze() 53 | output = prediction.cpu().numpy() 54 | return output 55 | 56 | 57 | def get_filename(relpath, level=-2): 58 | # save class folder structure and filename: 59 | fn = relpath.split(os.sep)[level:] 60 | folder = fn[-2] 61 | file = fn[-1].split('.')[0] 62 | return folder, file 63 | 64 | 65 | def save_depth(dataset, path, debug=False): 66 | os.makedirs(path) 67 | N = len(dset) 68 | if debug: 69 | N = 10 70 | state = get_state(gpu=True) 71 | for idx in trange(N, desc="Data"): 72 | ex = dataset[idx] 73 | image, relpath = ex["image"], ex["relpath"] 74 | folder, filename = get_filename(relpath) 75 | # prepare 76 | folderabspath = os.path.join(path, folder) 77 | os.makedirs(folderabspath, exist_ok=True) 78 | savepath = os.path.join(folderabspath, filename) 79 | # run model 80 | xout = run(image, state) 81 | I = depth_to_rgba(xout) 82 | Image.fromarray(I).save("{}.png".format(savepath)) 83 | 84 | 85 | if __name__ == "__main__": 86 | from taming.data.imagenet import ImageNetTrain, ImageNetValidation 87 | out = "data/imagenet_depth" 88 | if not os.path.exists(out): 89 | print("Please create a folder or symlink '{}' to extract depth data ".format(out) + 90 | "(be prepared that the output size will be larger than ImageNet itself).") 91 | exit(1) 92 | 93 | # go 94 | dset = ImageNetValidation() 95 | abspath = os.path.join(out, "val") 96 | if os.path.exists(abspath): 97 | print("{} exists - not doing anything.".format(abspath)) 98 | else: 99 | print("preparing {}".format(abspath)) 100 | save_depth(dset, abspath) 101 | print("done with validation split") 102 | 103 | dset = ImageNetTrain() 104 | abspath = os.path.join(out, "train") 105 | if os.path.exists(abspath): 106 | print("{} exists - not doing anything.".format(abspath)) 107 | else: 108 | print("preparing {}".format(abspath)) 109 | save_depth(dset, abspath) 110 | print("done with train split") 111 | 112 | print("done done.") 113 | -------------------------------------------------------------------------------- 
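The `depth_to_rgba`/`rgba_to_depth` pair in `extract_depth.py` above losslessly reinterprets a float32 depth map as a 4-channel uint8 array so the prediction can be stored as a PNG. A minimal round-trip sketch (illustrative only; a random array stands in for a real MiDaS prediction):

```python
# Self-contained copy of the byte-reinterpretation trick from
# depth_to_rgba / rgba_to_depth above, exercised on random data.
import numpy as np

def depth_to_rgba(x):
    y = x.copy()
    y.dtype = np.uint8                      # reinterpret the raw float32 bytes
    return np.ascontiguousarray(y.reshape(x.shape + (4,)))

def rgba_to_depth(x):
    y = x.copy()
    y.dtype = np.float32                    # reinterpret the uint8 bytes back
    return np.ascontiguousarray(y.reshape(x.shape[:2]))

depth = np.random.rand(64, 64).astype(np.float32)    # stand-in for MiDaS output
rgba = depth_to_rgba(depth)                           # (64, 64, 4) uint8, PNG-friendly
assert np.array_equal(rgba_to_depth(rgba), depth)     # lossless round trip
```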
/codebases/stable-diffusion/src/taming-transformers/scripts/extract_segmentation.py: -------------------------------------------------------------------------------- 1 | import sys, os 2 | import numpy as np 3 | import scipy 4 | import torch 5 | import torch.nn as nn 6 | from scipy import ndimage 7 | from tqdm import tqdm, trange 8 | from PIL import Image 9 | import torch.hub 10 | import torchvision 11 | import torch.nn.functional as F 12 | 13 | # download deeplabv2_resnet101_msc-cocostuff164k-100000.pth from 14 | # https://github.com/kazuto1011/deeplab-pytorch/releases/download/v1.0/deeplabv2_resnet101_msc-cocostuff164k-100000.pth 15 | # and put the path here 16 | CKPT_PATH = "TODO" 17 | 18 | rescale = lambda x: (x + 1.) / 2. 19 | 20 | def rescale_bgr(x): 21 | x = (x+1)*127.5 22 | x = torch.flip(x, dims=[0]) 23 | return x 24 | 25 | 26 | class COCOStuffSegmenter(nn.Module): 27 | def __init__(self, config): 28 | super().__init__() 29 | self.config = config 30 | self.n_labels = 182 31 | model = torch.hub.load("kazuto1011/deeplab-pytorch", "deeplabv2_resnet101", n_classes=self.n_labels) 32 | ckpt_path = CKPT_PATH 33 | model.load_state_dict(torch.load(ckpt_path)) 34 | self.model = model 35 | 36 | normalize = torchvision.transforms.Normalize(mean=self.mean, std=self.std) 37 | self.image_transform = torchvision.transforms.Compose([ 38 | torchvision.transforms.Lambda(lambda image: torch.stack( 39 | [normalize(rescale_bgr(x)) for x in image])) 40 | ]) 41 | 42 | def forward(self, x, upsample=None): 43 | x = self._pre_process(x) 44 | x = self.model(x) 45 | if upsample is not None: 46 | x = torch.nn.functional.upsample_bilinear(x, size=upsample) 47 | return x 48 | 49 | def _pre_process(self, x): 50 | x = self.image_transform(x) 51 | return x 52 | 53 | @property 54 | def mean(self): 55 | # bgr 56 | return [104.008, 116.669, 122.675] 57 | 58 | @property 59 | def std(self): 60 | return [1.0, 1.0, 1.0] 61 | 62 | @property 63 | def input_size(self): 64 | return [3, 224, 224] 65 | 66 | 67 | def run_model(img, model): 68 | model = model.eval() 69 | with torch.no_grad(): 70 | segmentation = model(img, upsample=(img.shape[2], img.shape[3])) 71 | segmentation = torch.argmax(segmentation, dim=1, keepdim=True) 72 | return segmentation.detach().cpu() 73 | 74 | 75 | def get_input(batch, k): 76 | x = batch[k] 77 | if len(x.shape) == 3: 78 | x = x[..., None] 79 | x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format) 80 | return x.float() 81 | 82 | 83 | def save_segmentation(segmentation, path): 84 | # --> class label to uint8, save as png 85 | os.makedirs(os.path.dirname(path), exist_ok=True) 86 | assert len(segmentation.shape)==4 87 | assert segmentation.shape[0]==1 88 | for seg in segmentation: 89 | seg = seg.permute(1,2,0).numpy().squeeze().astype(np.uint8) 90 | seg = Image.fromarray(seg) 91 | seg.save(path) 92 | 93 | 94 | def iterate_dataset(dataloader, destpath, model): 95 | os.makedirs(destpath, exist_ok=True) 96 | num_processed = 0 97 | for i, batch in tqdm(enumerate(dataloader), desc="Data"): 98 | try: 99 | img = get_input(batch, "image") 100 | img = img.cuda() 101 | seg = run_model(img, model) 102 | 103 | path = batch["relative_file_path_"][0] 104 | path = os.path.splitext(path)[0] 105 | 106 | path = os.path.join(destpath, path + ".png") 107 | save_segmentation(seg, path) 108 | num_processed += 1 109 | except Exception as e: 110 | print(e) 111 | print("but anyhow..") 112 | 113 | print("Processed {} files. 
Bye.".format(num_processed)) 114 | 115 | 116 | from taming.data.sflckr import Examples 117 | from torch.utils.data import DataLoader 118 | 119 | if __name__ == "__main__": 120 | dest = sys.argv[1] 121 | batchsize = 1 122 | print("Running with batch-size {}, saving to {}...".format(batchsize, dest)) 123 | 124 | model = COCOStuffSegmenter({}).cuda() 125 | print("Instantiated model.") 126 | 127 | dataset = Examples() 128 | dloader = DataLoader(dataset, batch_size=batchsize) 129 | iterate_dataset(dataloader=dloader, destpath=dest, model=model) 130 | print("done.") 131 | -------------------------------------------------------------------------------- /codebases/stable-diffusion/src/taming-transformers/scripts/extract_submodel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import sys 3 | 4 | if __name__ == "__main__": 5 | inpath = sys.argv[1] 6 | outpath = sys.argv[2] 7 | submodel = "cond_stage_model" 8 | if len(sys.argv) > 3: 9 | submodel = sys.argv[3] 10 | 11 | print("Extracting {} from {} to {}.".format(submodel, inpath, outpath)) 12 | 13 | sd = torch.load(inpath, map_location="cpu") 14 | new_sd = {"state_dict": dict((k.split(".", 1)[-1],v) 15 | for k,v in sd["state_dict"].items() 16 | if k.startswith("cond_stage_model"))} 17 | torch.save(new_sd, outpath) 18 | -------------------------------------------------------------------------------- /codebases/stable-diffusion/src/taming-transformers/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='taming-transformers', 5 | version='0.0.1', 6 | description='Taming Transformers for High-Resolution Image Synthesis', 7 | packages=find_packages(), 8 | install_requires=[ 9 | 'torch', 10 | 'numpy', 11 | 'tqdm', 12 | ], 13 | ) 14 | -------------------------------------------------------------------------------- /codebases/stable-diffusion/src/taming-transformers/taming/data/base.py: -------------------------------------------------------------------------------- 1 | import bisect 2 | import numpy as np 3 | import albumentations 4 | from PIL import Image 5 | from torch.utils.data import Dataset, ConcatDataset 6 | 7 | 8 | class ConcatDatasetWithIndex(ConcatDataset): 9 | """Modified from original pytorch code to return dataset idx""" 10 | def __getitem__(self, idx): 11 | if idx < 0: 12 | if -idx > len(self): 13 | raise ValueError("absolute value of index should not exceed dataset length") 14 | idx = len(self) + idx 15 | dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) 16 | if dataset_idx == 0: 17 | sample_idx = idx 18 | else: 19 | sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] 20 | return self.datasets[dataset_idx][sample_idx], dataset_idx 21 | 22 | 23 | class ImagePaths(Dataset): 24 | def __init__(self, paths, size=None, random_crop=False, labels=None): 25 | self.size = size 26 | self.random_crop = random_crop 27 | 28 | self.labels = dict() if labels is None else labels 29 | self.labels["file_path_"] = paths 30 | self._length = len(paths) 31 | 32 | if self.size is not None and self.size > 0: 33 | self.rescaler = albumentations.SmallestMaxSize(max_size = self.size) 34 | if not self.random_crop: 35 | self.cropper = albumentations.CenterCrop(height=self.size,width=self.size) 36 | else: 37 | self.cropper = albumentations.RandomCrop(height=self.size,width=self.size) 38 | self.preprocessor = albumentations.Compose([self.rescaler, self.cropper]) 39 | else: 40 | 
self.preprocessor = lambda **kwargs: kwargs 41 | 42 | def __len__(self): 43 | return self._length 44 | 45 | def preprocess_image(self, image_path): 46 | image = Image.open(image_path) 47 | if not image.mode == "RGB": 48 | image = image.convert("RGB") 49 | image = np.array(image).astype(np.uint8) 50 | image = self.preprocessor(image=image)["image"] 51 | image = (image/127.5 - 1.0).astype(np.float32) 52 | return image 53 | 54 | def __getitem__(self, i): 55 | example = dict() 56 | example["image"] = self.preprocess_image(self.labels["file_path_"][i]) 57 | for k in self.labels: 58 | example[k] = self.labels[k][i] 59 | return example 60 | 61 | 62 | class NumpyPaths(ImagePaths): 63 | def preprocess_image(self, image_path): 64 | image = np.load(image_path).squeeze(0) # 3 x 1024 x 1024 65 | image = np.transpose(image, (1,2,0)) 66 | image = Image.fromarray(image, mode="RGB") 67 | image = np.array(image).astype(np.uint8) 68 | image = self.preprocessor(image=image)["image"] 69 | image = (image/127.5 - 1.0).astype(np.float32) 70 | return image 71 | -------------------------------------------------------------------------------- /codebases/stable-diffusion/src/taming-transformers/taming/data/conditional_builder/objects_bbox.py: -------------------------------------------------------------------------------- 1 | from itertools import cycle 2 | from typing import List, Tuple, Callable, Optional 3 | 4 | from PIL import Image as pil_image, ImageDraw as pil_img_draw, ImageFont 5 | from more_itertools.recipes import grouper 6 | from taming.data.image_transforms import convert_pil_to_tensor 7 | from torch import LongTensor, Tensor 8 | 9 | from taming.data.helper_types import BoundingBox, Annotation 10 | from taming.data.conditional_builder.objects_center_points import ObjectsCenterPointsConditionalBuilder 11 | from taming.data.conditional_builder.utils import COLOR_PALETTE, WHITE, GRAY_75, BLACK, additional_parameters_string, \ 12 | pad_list, get_plot_font_size, absolute_bbox 13 | 14 | 15 | class ObjectsBoundingBoxConditionalBuilder(ObjectsCenterPointsConditionalBuilder): 16 | @property 17 | def object_descriptor_length(self) -> int: 18 | return 3 19 | 20 | def _make_object_descriptors(self, annotations: List[Annotation]) -> List[Tuple[int, ...]]: 21 | object_triples = [ 22 | (self.object_representation(ann), *self.token_pair_from_bbox(ann.bbox)) 23 | for ann in annotations 24 | ] 25 | empty_triple = (self.none, self.none, self.none) 26 | object_triples = pad_list(object_triples, empty_triple, self.no_max_objects) 27 | return object_triples 28 | 29 | def inverse_build(self, conditional: LongTensor) -> Tuple[List[Tuple[int, BoundingBox]], Optional[BoundingBox]]: 30 | conditional_list = conditional.tolist() 31 | crop_coordinates = None 32 | if self.encode_crop: 33 | crop_coordinates = self.bbox_from_token_pair(conditional_list[-2], conditional_list[-1]) 34 | conditional_list = conditional_list[:-2] 35 | object_triples = grouper(conditional_list, 3) 36 | assert conditional.shape[0] == self.embedding_dim 37 | return [ 38 | (object_triple[0], self.bbox_from_token_pair(object_triple[1], object_triple[2])) 39 | for object_triple in object_triples if object_triple[0] != self.none 40 | ], crop_coordinates 41 | 42 | def plot(self, conditional: LongTensor, label_for_category_no: Callable[[int], str], figure_size: Tuple[int, int], 43 | line_width: int = 3, font_size: Optional[int] = None) -> Tensor: 44 | plot = pil_image.new('RGB', figure_size, WHITE) 45 | draw = pil_img_draw.Draw(plot) 46 | font = ImageFont.truetype( 47 
| "/usr/share/fonts/truetype/lato/Lato-Regular.ttf", 48 | size=get_plot_font_size(font_size, figure_size) 49 | ) 50 | width, height = plot.size 51 | description, crop_coordinates = self.inverse_build(conditional) 52 | for (representation, bbox), color in zip(description, cycle(COLOR_PALETTE)): 53 | annotation = self.representation_to_annotation(representation) 54 | class_label = label_for_category_no(annotation.category_no) + ' ' + additional_parameters_string(annotation) 55 | bbox = absolute_bbox(bbox, width, height) 56 | draw.rectangle(bbox, outline=color, width=line_width) 57 | draw.text((bbox[0] + line_width, bbox[1] + line_width), class_label, anchor='la', fill=BLACK, font=font) 58 | if crop_coordinates is not None: 59 | draw.rectangle(absolute_bbox(crop_coordinates, width, height), outline=GRAY_75, width=line_width) 60 | return convert_pil_to_tensor(plot) / 127.5 - 1. 61 | -------------------------------------------------------------------------------- /codebases/stable-diffusion/src/taming-transformers/taming/data/conditional_builder/utils.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from typing import List, Any, Tuple, Optional 3 | 4 | from taming.data.helper_types import BoundingBox, Annotation 5 | 6 | # source: seaborn, color palette tab10 7 | COLOR_PALETTE = [(30, 118, 179), (255, 126, 13), (43, 159, 43), (213, 38, 39), (147, 102, 188), 8 | (139, 85, 74), (226, 118, 193), (126, 126, 126), (187, 188, 33), (22, 189, 206)] 9 | BLACK = (0, 0, 0) 10 | GRAY_75 = (63, 63, 63) 11 | GRAY_50 = (127, 127, 127) 12 | GRAY_25 = (191, 191, 191) 13 | WHITE = (255, 255, 255) 14 | FULL_CROP = (0., 0., 1., 1.) 15 | 16 | 17 | def intersection_area(rectangle1: BoundingBox, rectangle2: BoundingBox) -> float: 18 | """ 19 | Give intersection area of two rectangles. 20 | @param rectangle1: (x0, y0, w, h) of first rectangle 21 | @param rectangle2: (x0, y0, w, h) of second rectangle 22 | """ 23 | rectangle1 = rectangle1[0], rectangle1[1], rectangle1[0] + rectangle1[2], rectangle1[1] + rectangle1[3] 24 | rectangle2 = rectangle2[0], rectangle2[1], rectangle2[0] + rectangle2[2], rectangle2[1] + rectangle2[3] 25 | x_overlap = max(0., min(rectangle1[2], rectangle2[2]) - max(rectangle1[0], rectangle2[0])) 26 | y_overlap = max(0., min(rectangle1[3], rectangle2[3]) - max(rectangle1[1], rectangle2[1])) 27 | return x_overlap * y_overlap 28 | 29 | 30 | def horizontally_flip_bbox(bbox: BoundingBox) -> BoundingBox: 31 | return 1 - (bbox[0] + bbox[2]), bbox[1], bbox[2], bbox[3] 32 | 33 | 34 | def absolute_bbox(relative_bbox: BoundingBox, width: int, height: int) -> Tuple[int, int, int, int]: 35 | bbox = relative_bbox 36 | bbox = bbox[0] * width, bbox[1] * height, (bbox[0] + bbox[2]) * width, (bbox[1] + bbox[3]) * height 37 | return int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3]) 38 | 39 | 40 | def pad_list(list_: List, pad_element: Any, pad_to_length: int) -> List: 41 | return list_ + [pad_element for _ in range(pad_to_length - len(list_))] 42 | 43 | 44 | def rescale_annotations(annotations: List[Annotation], crop_coordinates: BoundingBox, flip: bool) -> \ 45 | List[Annotation]: 46 | def clamp(x: float): 47 | return max(min(x, 1.), 0.) 
48 | 49 | def rescale_bbox(bbox: BoundingBox) -> BoundingBox: 50 | x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2]) 51 | y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3]) 52 | w = min(bbox[2] / crop_coordinates[2], 1 - x0) 53 | h = min(bbox[3] / crop_coordinates[3], 1 - y0) 54 | if flip: 55 | x0 = 1 - (x0 + w) 56 | return x0, y0, w, h 57 | 58 | return [a._replace(bbox=rescale_bbox(a.bbox)) for a in annotations] 59 | 60 | 61 | def filter_annotations(annotations: List[Annotation], crop_coordinates: BoundingBox) -> List: 62 | return [a for a in annotations if intersection_area(a.bbox, crop_coordinates) > 0.0] 63 | 64 | 65 | def additional_parameters_string(annotation: Annotation, short: bool = True) -> str: 66 | sl = slice(1) if short else slice(None) 67 | string = '' 68 | if not (annotation.is_group_of or annotation.is_occluded or annotation.is_depiction or annotation.is_inside): 69 | return string 70 | if annotation.is_group_of: 71 | string += 'group'[sl] + ',' 72 | if annotation.is_occluded: 73 | string += 'occluded'[sl] + ',' 74 | if annotation.is_depiction: 75 | string += 'depiction'[sl] + ',' 76 | if annotation.is_inside: 77 | string += 'inside'[sl] 78 | return '(' + string.strip(",") + ')' 79 | 80 | 81 | def get_plot_font_size(font_size: Optional[int], figure_size: Tuple[int, int]) -> int: 82 | if font_size is None: 83 | font_size = 10 84 | if max(figure_size) >= 256: 85 | font_size = 12 86 | if max(figure_size) >= 512: 87 | font_size = 15 88 | return font_size 89 | 90 | 91 | def get_circle_size(figure_size: Tuple[int, int]) -> int: 92 | circle_size = 2 93 | if max(figure_size) >= 256: 94 | circle_size = 3 95 | if max(figure_size) >= 512: 96 | circle_size = 4 97 | return circle_size 98 | 99 | 100 | def load_object_from_string(object_string: str) -> Any: 101 | """ 102 | Source: https://stackoverflow.com/a/10773699 103 | """ 104 | module_name, class_name = object_string.rsplit(".", 1) 105 | return getattr(importlib.import_module(module_name), class_name) 106 | -------------------------------------------------------------------------------- /codebases/stable-diffusion/src/taming-transformers/taming/data/custom.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import albumentations 4 | from torch.utils.data import Dataset 5 | 6 | from taming.data.base import ImagePaths, NumpyPaths, ConcatDatasetWithIndex 7 | 8 | 9 | class CustomBase(Dataset): 10 | def __init__(self, *args, **kwargs): 11 | super().__init__() 12 | self.data = None 13 | 14 | def __len__(self): 15 | return len(self.data) 16 | 17 | def __getitem__(self, i): 18 | example = self.data[i] 19 | return example 20 | 21 | 22 | 23 | class CustomTrain(CustomBase): 24 | def __init__(self, size, training_images_list_file): 25 | super().__init__() 26 | with open(training_images_list_file, "r") as f: 27 | paths = f.read().splitlines() 28 | self.data = ImagePaths(paths=paths, size=size, random_crop=False) 29 | 30 | 31 | class CustomTest(CustomBase): 32 | def __init__(self, size, test_images_list_file): 33 | super().__init__() 34 | with open(test_images_list_file, "r") as f: 35 | paths = f.read().splitlines() 36 | self.data = ImagePaths(paths=paths, size=size, random_crop=False) 37 | 38 | 39 | -------------------------------------------------------------------------------- /codebases/stable-diffusion/src/taming-transformers/taming/data/faceshq.py: -------------------------------------------------------------------------------- 1 | 
import os 2 | import numpy as np 3 | import albumentations 4 | from torch.utils.data import Dataset 5 | 6 | from taming.data.base import ImagePaths, NumpyPaths, ConcatDatasetWithIndex 7 | 8 | 9 | class FacesBase(Dataset): 10 | def __init__(self, *args, **kwargs): 11 | super().__init__() 12 | self.data = None 13 | self.keys = None 14 | 15 | def __len__(self): 16 | return len(self.data) 17 | 18 | def __getitem__(self, i): 19 | example = self.data[i] 20 | ex = {} 21 | if self.keys is not None: 22 | for k in self.keys: 23 | ex[k] = example[k] 24 | else: 25 | ex = example 26 | return ex 27 | 28 | 29 | class CelebAHQTrain(FacesBase): 30 | def __init__(self, size, keys=None): 31 | super().__init__() 32 | root = "data/celebahq" 33 | with open("data/celebahqtrain.txt", "r") as f: 34 | relpaths = f.read().splitlines() 35 | paths = [os.path.join(root, relpath) for relpath in relpaths] 36 | self.data = NumpyPaths(paths=paths, size=size, random_crop=False) 37 | self.keys = keys 38 | 39 | 40 | class CelebAHQValidation(FacesBase): 41 | def __init__(self, size, keys=None): 42 | super().__init__() 43 | root = "data/celebahq" 44 | with open("data/celebahqvalidation.txt", "r") as f: 45 | relpaths = f.read().splitlines() 46 | paths = [os.path.join(root, relpath) for relpath in relpaths] 47 | self.data = NumpyPaths(paths=paths, size=size, random_crop=False) 48 | self.keys = keys 49 | 50 | 51 | class FFHQTrain(FacesBase): 52 | def __init__(self, size, keys=None): 53 | super().__init__() 54 | root = "data/ffhq" 55 | with open("data/ffhqtrain.txt", "r") as f: 56 | relpaths = f.read().splitlines() 57 | paths = [os.path.join(root, relpath) for relpath in relpaths] 58 | self.data = ImagePaths(paths=paths, size=size, random_crop=False) 59 | self.keys = keys 60 | 61 | 62 | class FFHQValidation(FacesBase): 63 | def __init__(self, size, keys=None): 64 | super().__init__() 65 | root = "data/ffhq" 66 | with open("data/ffhqvalidation.txt", "r") as f: 67 | relpaths = f.read().splitlines() 68 | paths = [os.path.join(root, relpath) for relpath in relpaths] 69 | self.data = ImagePaths(paths=paths, size=size, random_crop=False) 70 | self.keys = keys 71 | 72 | 73 | class FacesHQTrain(Dataset): 74 | # CelebAHQ [0] + FFHQ [1] 75 | def __init__(self, size, keys=None, crop_size=None, coord=False): 76 | d1 = CelebAHQTrain(size=size, keys=keys) 77 | d2 = FFHQTrain(size=size, keys=keys) 78 | self.data = ConcatDatasetWithIndex([d1, d2]) 79 | self.coord = coord 80 | if crop_size is not None: 81 | self.cropper = albumentations.RandomCrop(height=crop_size,width=crop_size) 82 | if self.coord: 83 | self.cropper = albumentations.Compose([self.cropper], 84 | additional_targets={"coord": "image"}) 85 | 86 | def __len__(self): 87 | return len(self.data) 88 | 89 | def __getitem__(self, i): 90 | ex, y = self.data[i] 91 | if hasattr(self, "cropper"): 92 | if not self.coord: 93 | out = self.cropper(image=ex["image"]) 94 | ex["image"] = out["image"] 95 | else: 96 | h,w,_ = ex["image"].shape 97 | coord = np.arange(h*w).reshape(h,w,1)/(h*w) 98 | out = self.cropper(image=ex["image"], coord=coord) 99 | ex["image"] = out["image"] 100 | ex["coord"] = out["coord"] 101 | ex["class"] = y 102 | return ex 103 | 104 | 105 | class FacesHQValidation(Dataset): 106 | # CelebAHQ [0] + FFHQ [1] 107 | def __init__(self, size, keys=None, crop_size=None, coord=False): 108 | d1 = CelebAHQValidation(size=size, keys=keys) 109 | d2 = FFHQValidation(size=size, keys=keys) 110 | self.data = ConcatDatasetWithIndex([d1, d2]) 111 | self.coord = coord 112 | if crop_size is not 
None: 113 | self.cropper = albumentations.CenterCrop(height=crop_size,width=crop_size) 114 | if self.coord: 115 | self.cropper = albumentations.Compose([self.cropper], 116 | additional_targets={"coord": "image"}) 117 | 118 | def __len__(self): 119 | return len(self.data) 120 | 121 | def __getitem__(self, i): 122 | ex, y = self.data[i] 123 | if hasattr(self, "cropper"): 124 | if not self.coord: 125 | out = self.cropper(image=ex["image"]) 126 | ex["image"] = out["image"] 127 | else: 128 | h,w,_ = ex["image"].shape 129 | coord = np.arange(h*w).reshape(h,w,1)/(h*w) 130 | out = self.cropper(image=ex["image"], coord=coord) 131 | ex["image"] = out["image"] 132 | ex["coord"] = out["coord"] 133 | ex["class"] = y 134 | return ex 135 | -------------------------------------------------------------------------------- /codebases/stable-diffusion/src/taming-transformers/taming/data/helper_types.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Tuple, Optional, NamedTuple, Union 2 | from PIL.Image import Image as pil_image 3 | from torch import Tensor 4 | 5 | try: 6 | from typing import Literal 7 | except ImportError: 8 | from typing_extensions import Literal 9 | 10 | Image = Union[Tensor, pil_image] 11 | BoundingBox = Tuple[float, float, float, float] # x0, y0, w, h 12 | CropMethodType = Literal['none', 'random', 'center', 'random-2d'] 13 | SplitType = Literal['train', 'validation', 'test'] 14 | 15 | 16 | class ImageDescription(NamedTuple): 17 | id: int 18 | file_name: str 19 | original_size: Tuple[int, int] # w, h 20 | url: Optional[str] = None 21 | license: Optional[int] = None 22 | coco_url: Optional[str] = None 23 | date_captured: Optional[str] = None 24 | flickr_url: Optional[str] = None 25 | flickr_id: Optional[str] = None 26 | coco_id: Optional[str] = None 27 | 28 | 29 | class Category(NamedTuple): 30 | id: str 31 | super_category: Optional[str] 32 | name: str 33 | 34 | 35 | class Annotation(NamedTuple): 36 | area: float 37 | image_id: str 38 | bbox: BoundingBox 39 | category_no: int 40 | category_id: str 41 | id: Optional[int] = None 42 | source: Optional[str] = None 43 | confidence: Optional[float] = None 44 | is_group_of: Optional[bool] = None 45 | is_truncated: Optional[bool] = None 46 | is_occluded: Optional[bool] = None 47 | is_depiction: Optional[bool] = None 48 | is_inside: Optional[bool] = None 49 | segmentation: Optional[Dict] = None 50 | -------------------------------------------------------------------------------- /codebases/stable-diffusion/src/taming-transformers/taming/data/image_transforms.py: -------------------------------------------------------------------------------- 1 | import random 2 | import warnings 3 | from typing import Union 4 | 5 | import torch 6 | from torch import Tensor 7 | from torchvision.transforms import RandomCrop, functional as F, CenterCrop, RandomHorizontalFlip, PILToTensor 8 | from torchvision.transforms.functional import _get_image_size as get_image_size 9 | 10 | from taming.data.helper_types import BoundingBox, Image 11 | 12 | pil_to_tensor = PILToTensor() 13 | 14 | 15 | def convert_pil_to_tensor(image: Image) -> Tensor: 16 | with warnings.catch_warnings(): 17 | # to filter PyTorch UserWarning as described here: https://github.com/pytorch/vision/issues/2194 18 | warnings.simplefilter("ignore") 19 | return pil_to_tensor(image) 20 | 21 | 22 | class RandomCrop1dReturnCoordinates(RandomCrop): 23 | def forward(self, img: Image) -> (BoundingBox, Image): 24 | """ 25 | Additionally to cropping, 
returns the relative coordinates of the crop bounding box. 26 | Args: 27 | img (PIL Image or Tensor): Image to be cropped. 28 | 29 | Returns: 30 | Bounding box: x0, y0, w, h 31 | PIL Image or Tensor: Cropped image. 32 | 33 | Based on: 34 | torchvision.transforms.RandomCrop, torchvision 1.7.0 35 | """ 36 | if self.padding is not None: 37 | img = F.pad(img, self.padding, self.fill, self.padding_mode) 38 | 39 | width, height = get_image_size(img) 40 | # pad the width if needed 41 | if self.pad_if_needed and width < self.size[1]: 42 | padding = [self.size[1] - width, 0] 43 | img = F.pad(img, padding, self.fill, self.padding_mode) 44 | # pad the height if needed 45 | if self.pad_if_needed and height < self.size[0]: 46 | padding = [0, self.size[0] - height] 47 | img = F.pad(img, padding, self.fill, self.padding_mode) 48 | 49 | i, j, h, w = self.get_params(img, self.size) 50 | bbox = (j / width, i / height, w / width, h / height) # x0, y0, w, h 51 | return bbox, F.crop(img, i, j, h, w) 52 | 53 | 54 | class Random2dCropReturnCoordinates(torch.nn.Module): 55 | """ 56 | Additionally to cropping, returns the relative coordinates of the crop bounding box. 57 | Args: 58 | img (PIL Image or Tensor): Image to be cropped. 59 | 60 | Returns: 61 | Bounding box: x0, y0, w, h 62 | PIL Image or Tensor: Cropped image. 63 | 64 | Based on: 65 | torchvision.transforms.RandomCrop, torchvision 1.7.0 66 | """ 67 | 68 | def __init__(self, min_size: int): 69 | super().__init__() 70 | self.min_size = min_size 71 | 72 | def forward(self, img: Image) -> (BoundingBox, Image): 73 | width, height = get_image_size(img) 74 | max_size = min(width, height) 75 | if max_size <= self.min_size: 76 | size = max_size 77 | else: 78 | size = random.randint(self.min_size, max_size) 79 | top = random.randint(0, height - size) 80 | left = random.randint(0, width - size) 81 | bbox = left / width, top / height, size / width, size / height 82 | return bbox, F.crop(img, top, left, size, size) 83 | 84 | 85 | class CenterCropReturnCoordinates(CenterCrop): 86 | @staticmethod 87 | def get_bbox_of_center_crop(width: int, height: int) -> BoundingBox: 88 | if width > height: 89 | w = height / width 90 | h = 1.0 91 | x0 = 0.5 - w / 2 92 | y0 = 0. 93 | else: 94 | w = 1.0 95 | h = width / height 96 | x0 = 0. 97 | y0 = 0.5 - h / 2 98 | return x0, y0, w, h 99 | 100 | def forward(self, img: Union[Image, Tensor]) -> (BoundingBox, Union[Image, Tensor]): 101 | """ 102 | Additionally to cropping, returns the relative coordinates of the crop bounding box. 103 | Args: 104 | img (PIL Image or Tensor): Image to be cropped. 105 | 106 | Returns: 107 | Bounding box: x0, y0, w, h 108 | PIL Image or Tensor: Cropped image. 109 | Based on: 110 | torchvision.transforms.RandomHorizontalFlip (version 1.7.0) 111 | """ 112 | width, height = get_image_size(img) 113 | return self.get_bbox_of_center_crop(width, height), F.center_crop(img, self.size) 114 | 115 | 116 | class RandomHorizontalFlipReturn(RandomHorizontalFlip): 117 | def forward(self, img: Image) -> (bool, Image): 118 | """ 119 | Additionally to flipping, returns a boolean whether it was flipped or not. 120 | Args: 121 | img (PIL Image or Tensor): Image to be flipped. 122 | 123 | Returns: 124 | flipped: whether the image was flipped or not 125 | PIL Image or Tensor: Randomly flipped image. 
126 | 127 | Based on: 128 | torchvision.transforms.RandomHorizontalFlip (version 1.7.0) 129 | """ 130 | if torch.rand(1) < self.p: 131 | return True, F.hflip(img) 132 | return False, img 133 | -------------------------------------------------------------------------------- /codebases/stable-diffusion/src/taming-transformers/taming/data/sflckr.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import cv2 4 | import albumentations 5 | from PIL import Image 6 | from torch.utils.data import Dataset 7 | 8 | 9 | class SegmentationBase(Dataset): 10 | def __init__(self, 11 | data_csv, data_root, segmentation_root, 12 | size=None, random_crop=False, interpolation="bicubic", 13 | n_labels=182, shift_segmentation=False, 14 | ): 15 | self.n_labels = n_labels 16 | self.shift_segmentation = shift_segmentation 17 | self.data_csv = data_csv 18 | self.data_root = data_root 19 | self.segmentation_root = segmentation_root 20 | with open(self.data_csv, "r") as f: 21 | self.image_paths = f.read().splitlines() 22 | self._length = len(self.image_paths) 23 | self.labels = { 24 | "relative_file_path_": [l for l in self.image_paths], 25 | "file_path_": [os.path.join(self.data_root, l) 26 | for l in self.image_paths], 27 | "segmentation_path_": [os.path.join(self.segmentation_root, l.replace(".jpg", ".png")) 28 | for l in self.image_paths] 29 | } 30 | 31 | size = None if size is not None and size<=0 else size 32 | self.size = size 33 | if self.size is not None: 34 | self.interpolation = interpolation 35 | self.interpolation = { 36 | "nearest": cv2.INTER_NEAREST, 37 | "bilinear": cv2.INTER_LINEAR, 38 | "bicubic": cv2.INTER_CUBIC, 39 | "area": cv2.INTER_AREA, 40 | "lanczos": cv2.INTER_LANCZOS4}[self.interpolation] 41 | self.image_rescaler = albumentations.SmallestMaxSize(max_size=self.size, 42 | interpolation=self.interpolation) 43 | self.segmentation_rescaler = albumentations.SmallestMaxSize(max_size=self.size, 44 | interpolation=cv2.INTER_NEAREST) 45 | self.center_crop = not random_crop 46 | if self.center_crop: 47 | self.cropper = albumentations.CenterCrop(height=self.size, width=self.size) 48 | else: 49 | self.cropper = albumentations.RandomCrop(height=self.size, width=self.size) 50 | self.preprocessor = self.cropper 51 | 52 | def __len__(self): 53 | return self._length 54 | 55 | def __getitem__(self, i): 56 | example = dict((k, self.labels[k][i]) for k in self.labels) 57 | image = Image.open(example["file_path_"]) 58 | if not image.mode == "RGB": 59 | image = image.convert("RGB") 60 | image = np.array(image).astype(np.uint8) 61 | if self.size is not None: 62 | image = self.image_rescaler(image=image)["image"] 63 | segmentation = Image.open(example["segmentation_path_"]) 64 | assert segmentation.mode == "L", segmentation.mode 65 | segmentation = np.array(segmentation).astype(np.uint8) 66 | if self.shift_segmentation: 67 | # used to support segmentations containing unlabeled==255 label 68 | segmentation = segmentation+1 69 | if self.size is not None: 70 | segmentation = self.segmentation_rescaler(image=segmentation)["image"] 71 | if self.size is not None: 72 | processed = self.preprocessor(image=image, 73 | mask=segmentation 74 | ) 75 | else: 76 | processed = {"image": image, 77 | "mask": segmentation 78 | } 79 | example["image"] = (processed["image"]/127.5 - 1.0).astype(np.float32) 80 | segmentation = processed["mask"] 81 | onehot = np.eye(self.n_labels)[segmentation] 82 | example["segmentation"] = onehot 83 | return example 84 | 85 | 86 | 
class Examples(SegmentationBase): 87 | def __init__(self, size=None, random_crop=False, interpolation="bicubic"): 88 | super().__init__(data_csv="data/sflckr_examples.txt", 89 | data_root="data/sflckr_images", 90 | segmentation_root="data/sflckr_segmentations", 91 | size=size, random_crop=random_crop, interpolation=interpolation) 92 | -------------------------------------------------------------------------------- /codebases/stable-diffusion/src/taming-transformers/taming/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class LambdaWarmUpCosineScheduler: 5 | """ 6 | note: use with a base_lr of 1.0 7 | """ 8 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): 9 | self.lr_warm_up_steps = warm_up_steps 10 | self.lr_start = lr_start 11 | self.lr_min = lr_min 12 | self.lr_max = lr_max 13 | self.lr_max_decay_steps = max_decay_steps 14 | self.last_lr = 0. 15 | self.verbosity_interval = verbosity_interval 16 | 17 | def schedule(self, n): 18 | if self.verbosity_interval > 0: 19 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") 20 | if n < self.lr_warm_up_steps: 21 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start 22 | self.last_lr = lr 23 | return lr 24 | else: 25 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) 26 | t = min(t, 1.0) 27 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( 28 | 1 + np.cos(t * np.pi)) 29 | self.last_lr = lr 30 | return lr 31 | 32 | def __call__(self, n): 33 | return self.schedule(n) 34 | 35 | -------------------------------------------------------------------------------- /codebases/stable-diffusion/src/taming-transformers/taming/models/dummy_cond_stage.py: -------------------------------------------------------------------------------- 1 | from torch import Tensor 2 | 3 | 4 | class DummyCondStage: 5 | def __init__(self, conditional_key): 6 | self.conditional_key = conditional_key 7 | self.train = None 8 | 9 | def eval(self): 10 | return self 11 | 12 | @staticmethod 13 | def encode(c: Tensor): 14 | return c, None, (None, None, c) 15 | 16 | @staticmethod 17 | def decode(c: Tensor): 18 | return c 19 | 20 | @staticmethod 21 | def to_rgb(c: Tensor): 22 | return c 23 | -------------------------------------------------------------------------------- /codebases/stable-diffusion/src/taming-transformers/taming/modules/discriminator/model.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import torch.nn as nn 3 | 4 | 5 | from taming.modules.util import ActNorm 6 | 7 | 8 | def weights_init(m): 9 | classname = m.__class__.__name__ 10 | if classname.find('Conv') != -1: 11 | nn.init.normal_(m.weight.data, 0.0, 0.02) 12 | elif classname.find('BatchNorm') != -1: 13 | nn.init.normal_(m.weight.data, 1.0, 0.02) 14 | nn.init.constant_(m.bias.data, 0) 15 | 16 | 17 | class NLayerDiscriminator(nn.Module): 18 | """Defines a PatchGAN discriminator as in Pix2Pix 19 | --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py 20 | """ 21 | def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False): 22 | """Construct a PatchGAN discriminator 23 | Parameters: 24 | input_nc (int) -- the number of channels in input images 25 | ndf (int) -- the number of filters in the last conv layer 26 | n_layers (int) -- the number of conv layers in 
the discriminator 27 | norm_layer -- normalization layer 28 | """ 29 | super(NLayerDiscriminator, self).__init__() 30 | if not use_actnorm: 31 | norm_layer = nn.BatchNorm2d 32 | else: 33 | norm_layer = ActNorm 34 | if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters 35 | use_bias = norm_layer.func != nn.BatchNorm2d 36 | else: 37 | use_bias = norm_layer != nn.BatchNorm2d 38 | 39 | kw = 4 40 | padw = 1 41 | sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] 42 | nf_mult = 1 43 | nf_mult_prev = 1 44 | for n in range(1, n_layers): # gradually increase the number of filters 45 | nf_mult_prev = nf_mult 46 | nf_mult = min(2 ** n, 8) 47 | sequence += [ 48 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), 49 | norm_layer(ndf * nf_mult), 50 | nn.LeakyReLU(0.2, True) 51 | ] 52 | 53 | nf_mult_prev = nf_mult 54 | nf_mult = min(2 ** n_layers, 8) 55 | sequence += [ 56 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias), 57 | norm_layer(ndf * nf_mult), 58 | nn.LeakyReLU(0.2, True) 59 | ] 60 | 61 | sequence += [ 62 | nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map 63 | self.main = nn.Sequential(*sequence) 64 | 65 | def forward(self, input): 66 | """Standard forward.""" 67 | return self.main(input) 68 | -------------------------------------------------------------------------------- /codebases/stable-diffusion/src/taming-transformers/taming/modules/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from taming.modules.losses.vqperceptual import DummyLoss 2 | 3 | -------------------------------------------------------------------------------- /codebases/stable-diffusion/src/taming-transformers/taming/modules/losses/lpips.py: -------------------------------------------------------------------------------- 1 | """Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models""" 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torchvision import models 6 | from collections import namedtuple 7 | 8 | from taming.util import get_ckpt_path 9 | 10 | 11 | class LPIPS(nn.Module): 12 | # Learned perceptual metric 13 | def __init__(self, use_dropout=True): 14 | super().__init__() 15 | self.scaling_layer = ScalingLayer() 16 | self.chns = [64, 128, 256, 512, 512] # vg16 features 17 | self.net = vgg16(pretrained=True, requires_grad=False) 18 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) 19 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) 20 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) 21 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) 22 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) 23 | self.load_from_pretrained() 24 | for param in self.parameters(): 25 | param.requires_grad = False 26 | 27 | def load_from_pretrained(self, name="vgg_lpips"): 28 | ckpt = get_ckpt_path(name, "taming/modules/autoencoder/lpips") 29 | self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) 30 | print("loaded pretrained LPIPS loss from {}".format(ckpt)) 31 | 32 | @classmethod 33 | def from_pretrained(cls, name="vgg_lpips"): 34 | if name != "vgg_lpips": 35 | raise NotImplementedError 36 | model = cls() 37 | ckpt = get_ckpt_path(name) 38 | model.load_state_dict(torch.load(ckpt, 
map_location=torch.device("cpu")), strict=False) 39 | return model 40 | 41 | def forward(self, input, target): 42 | in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target)) 43 | outs0, outs1 = self.net(in0_input), self.net(in1_input) 44 | feats0, feats1, diffs = {}, {}, {} 45 | lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] 46 | for kk in range(len(self.chns)): 47 | feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk]) 48 | diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 49 | 50 | res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))] 51 | val = res[0] 52 | for l in range(1, len(self.chns)): 53 | val += res[l] 54 | return val 55 | 56 | 57 | class ScalingLayer(nn.Module): 58 | def __init__(self): 59 | super(ScalingLayer, self).__init__() 60 | self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) 61 | self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None]) 62 | 63 | def forward(self, inp): 64 | return (inp - self.shift) / self.scale 65 | 66 | 67 | class NetLinLayer(nn.Module): 68 | """ A single linear layer which does a 1x1 conv """ 69 | def __init__(self, chn_in, chn_out=1, use_dropout=False): 70 | super(NetLinLayer, self).__init__() 71 | layers = [nn.Dropout(), ] if (use_dropout) else [] 72 | layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ] 73 | self.model = nn.Sequential(*layers) 74 | 75 | 76 | class vgg16(torch.nn.Module): 77 | def __init__(self, requires_grad=False, pretrained=True): 78 | super(vgg16, self).__init__() 79 | vgg_pretrained_features = models.vgg16(pretrained=pretrained).features 80 | self.slice1 = torch.nn.Sequential() 81 | self.slice2 = torch.nn.Sequential() 82 | self.slice3 = torch.nn.Sequential() 83 | self.slice4 = torch.nn.Sequential() 84 | self.slice5 = torch.nn.Sequential() 85 | self.N_slices = 5 86 | for x in range(4): 87 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 88 | for x in range(4, 9): 89 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 90 | for x in range(9, 16): 91 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 92 | for x in range(16, 23): 93 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 94 | for x in range(23, 30): 95 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 96 | if not requires_grad: 97 | for param in self.parameters(): 98 | param.requires_grad = False 99 | 100 | def forward(self, X): 101 | h = self.slice1(X) 102 | h_relu1_2 = h 103 | h = self.slice2(h) 104 | h_relu2_2 = h 105 | h = self.slice3(h) 106 | h_relu3_3 = h 107 | h = self.slice4(h) 108 | h_relu4_3 = h 109 | h = self.slice5(h) 110 | h_relu5_3 = h 111 | vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) 112 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) 113 | return out 114 | 115 | 116 | def normalize_tensor(x,eps=1e-10): 117 | norm_factor = torch.sqrt(torch.sum(x**2,dim=1,keepdim=True)) 118 | return x/(norm_factor+eps) 119 | 120 | 121 | def spatial_average(x, keepdim=True): 122 | return x.mean([2,3],keepdim=keepdim) 123 | 124 | -------------------------------------------------------------------------------- /codebases/stable-diffusion/src/taming-transformers/taming/modules/losses/segmentation.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | 5 | class 
BCELoss(nn.Module): 6 | def forward(self, prediction, target): 7 | loss = F.binary_cross_entropy_with_logits(prediction,target) 8 | return loss, {} 9 | 10 | 11 | class BCELossWithQuant(nn.Module): 12 | def __init__(self, codebook_weight=1.): 13 | super().__init__() 14 | self.codebook_weight = codebook_weight 15 | 16 | def forward(self, qloss, target, prediction, split): 17 | bce_loss = F.binary_cross_entropy_with_logits(prediction,target) 18 | loss = bce_loss + self.codebook_weight*qloss 19 | return loss, {"{}/total_loss".format(split): loss.clone().detach().mean(), 20 | "{}/bce_loss".format(split): bce_loss.detach().mean(), 21 | "{}/quant_loss".format(split): qloss.detach().mean() 22 | } 23 | -------------------------------------------------------------------------------- /codebases/stable-diffusion/src/taming-transformers/taming/modules/misc/coord.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class CoordStage(object): 4 | def __init__(self, n_embed, down_factor): 5 | self.n_embed = n_embed 6 | self.down_factor = down_factor 7 | 8 | def eval(self): 9 | return self 10 | 11 | def encode(self, c): 12 | """fake vqmodel interface""" 13 | assert 0.0 <= c.min() and c.max() <= 1.0 14 | b,ch,h,w = c.shape 15 | assert ch == 1 16 | 17 | c = torch.nn.functional.interpolate(c, scale_factor=1/self.down_factor, 18 | mode="area") 19 | c = c.clamp(0.0, 1.0) 20 | c = self.n_embed*c 21 | c_quant = c.round() 22 | c_ind = c_quant.to(dtype=torch.long) 23 | 24 | info = None, None, c_ind 25 | return c_quant, None, info 26 | 27 | def decode(self, c): 28 | c = c/self.n_embed 29 | c = torch.nn.functional.interpolate(c, scale_factor=self.down_factor, 30 | mode="nearest") 31 | return c 32 | -------------------------------------------------------------------------------- /codebases/stable-diffusion/src/taming-transformers/taming/modules/util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def count_params(model): 6 | total_params = sum(p.numel() for p in model.parameters()) 7 | return total_params 8 | 9 | 10 | class ActNorm(nn.Module): 11 | def __init__(self, num_features, logdet=False, affine=True, 12 | allow_reverse_init=False): 13 | assert affine 14 | super().__init__() 15 | self.logdet = logdet 16 | self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1)) 17 | self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1)) 18 | self.allow_reverse_init = allow_reverse_init 19 | 20 | self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8)) 21 | 22 | def initialize(self, input): 23 | with torch.no_grad(): 24 | flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1) 25 | mean = ( 26 | flatten.mean(1) 27 | .unsqueeze(1) 28 | .unsqueeze(2) 29 | .unsqueeze(3) 30 | .permute(1, 0, 2, 3) 31 | ) 32 | std = ( 33 | flatten.std(1) 34 | .unsqueeze(1) 35 | .unsqueeze(2) 36 | .unsqueeze(3) 37 | .permute(1, 0, 2, 3) 38 | ) 39 | 40 | self.loc.data.copy_(-mean) 41 | self.scale.data.copy_(1 / (std + 1e-6)) 42 | 43 | def forward(self, input, reverse=False): 44 | if reverse: 45 | return self.reverse(input) 46 | if len(input.shape) == 2: 47 | input = input[:,:,None,None] 48 | squeeze = True 49 | else: 50 | squeeze = False 51 | 52 | _, _, height, width = input.shape 53 | 54 | if self.training and self.initialized.item() == 0: 55 | self.initialize(input) 56 | self.initialized.fill_(1) 57 | 58 | h = self.scale * (input + self.loc) 59 | 60 | if squeeze: 61 
| h = h.squeeze(-1).squeeze(-1) 62 | 63 | if self.logdet: 64 | log_abs = torch.log(torch.abs(self.scale)) 65 | logdet = height*width*torch.sum(log_abs) 66 | logdet = logdet * torch.ones(input.shape[0]).to(input) 67 | return h, logdet 68 | 69 | return h 70 | 71 | def reverse(self, output): 72 | if self.training and self.initialized.item() == 0: 73 | if not self.allow_reverse_init: 74 | raise RuntimeError( 75 | "Initializing ActNorm in reverse direction is " 76 | "disabled by default. Use allow_reverse_init=True to enable." 77 | ) 78 | else: 79 | self.initialize(output) 80 | self.initialized.fill_(1) 81 | 82 | if len(output.shape) == 2: 83 | output = output[:,:,None,None] 84 | squeeze = True 85 | else: 86 | squeeze = False 87 | 88 | h = output / self.scale - self.loc 89 | 90 | if squeeze: 91 | h = h.squeeze(-1).squeeze(-1) 92 | return h 93 | 94 | 95 | class AbstractEncoder(nn.Module): 96 | def __init__(self): 97 | super().__init__() 98 | 99 | def encode(self, *args, **kwargs): 100 | raise NotImplementedError 101 | 102 | 103 | class Labelator(AbstractEncoder): 104 | """Net2Net Interface for Class-Conditional Model""" 105 | def __init__(self, n_classes, quantize_interface=True): 106 | super().__init__() 107 | self.n_classes = n_classes 108 | self.quantize_interface = quantize_interface 109 | 110 | def encode(self, c): 111 | c = c[:,None] 112 | if self.quantize_interface: 113 | return c, None, [None, None, c.long()] 114 | return c 115 | 116 | 117 | class SOSProvider(AbstractEncoder): 118 | # for unconditional training 119 | def __init__(self, sos_token, quantize_interface=True): 120 | super().__init__() 121 | self.sos_token = sos_token 122 | self.quantize_interface = quantize_interface 123 | 124 | def encode(self, x): 125 | # get batch size from data and replicate sos_token 126 | c = torch.ones(x.shape[0], 1)*self.sos_token 127 | c = c.long().to(x.device) 128 | if self.quantize_interface: 129 | return c, None, [None, None, c] 130 | return c 131 | -------------------------------------------------------------------------------- /codebases/stable-diffusion/src/taming-transformers/taming/util.py: -------------------------------------------------------------------------------- 1 | import os, hashlib 2 | import requests 3 | from tqdm import tqdm 4 | 5 | URL_MAP = { 6 | "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1" 7 | } 8 | 9 | CKPT_MAP = { 10 | "vgg_lpips": "vgg.pth" 11 | } 12 | 13 | MD5_MAP = { 14 | "vgg_lpips": "d507d7349b931f0638a25a48a722f98a" 15 | } 16 | 17 | 18 | def download(url, local_path, chunk_size=1024): 19 | os.makedirs(os.path.split(local_path)[0], exist_ok=True) 20 | with requests.get(url, stream=True) as r: 21 | total_size = int(r.headers.get("content-length", 0)) 22 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: 23 | with open(local_path, "wb") as f: 24 | for data in r.iter_content(chunk_size=chunk_size): 25 | if data: 26 | f.write(data) 27 | pbar.update(chunk_size) 28 | 29 | 30 | def md5_hash(path): 31 | with open(path, "rb") as f: 32 | content = f.read() 33 | return hashlib.md5(content).hexdigest() 34 | 35 | 36 | def get_ckpt_path(name, root, check=False): 37 | assert name in URL_MAP 38 | path = os.path.join(root, CKPT_MAP[name]) 39 | if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): 40 | print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path)) 41 | download(URL_MAP[name], path) 42 | md5 = md5_hash(path) 43 | assert md5 == MD5_MAP[name], md5 44 | return path 45 | 46 | 47 | class 
KeyNotFoundError(Exception): 48 | def __init__(self, cause, keys=None, visited=None): 49 | self.cause = cause 50 | self.keys = keys 51 | self.visited = visited 52 | messages = list() 53 | if keys is not None: 54 | messages.append("Key not found: {}".format(keys)) 55 | if visited is not None: 56 | messages.append("Visited: {}".format(visited)) 57 | messages.append("Cause:\n{}".format(cause)) 58 | message = "\n".join(messages) 59 | super().__init__(message) 60 | 61 | 62 | def retrieve( 63 | list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False 64 | ): 65 | """Given a nested list or dict return the desired value at key expanding 66 | callable nodes if necessary and :attr:`expand` is ``True``. The expansion 67 | is done in-place. 68 | 69 | Parameters 70 | ---------- 71 | list_or_dict : list or dict 72 | Possibly nested list or dictionary. 73 | key : str 74 | key/to/value, path like string describing all keys necessary to 75 | consider to get to the desired value. List indices can also be 76 | passed here. 77 | splitval : str 78 | String that defines the delimiter between keys of the 79 | different depth levels in `key`. 80 | default : obj 81 | Value returned if :attr:`key` is not found. 82 | expand : bool 83 | Whether to expand callable nodes on the path or not. 84 | 85 | Returns 86 | ------- 87 | The desired value or if :attr:`default` is not ``None`` and the 88 | :attr:`key` is not found returns ``default``. 89 | 90 | Raises 91 | ------ 92 | Exception if ``key`` not in ``list_or_dict`` and :attr:`default` is 93 | ``None``. 94 | """ 95 | 96 | keys = key.split(splitval) 97 | 98 | success = True 99 | try: 100 | visited = [] 101 | parent = None 102 | last_key = None 103 | for key in keys: 104 | if callable(list_or_dict): 105 | if not expand: 106 | raise KeyNotFoundError( 107 | ValueError( 108 | "Trying to get past callable node with expand=False." 
109 | ), 110 | keys=keys, 111 | visited=visited, 112 | ) 113 | list_or_dict = list_or_dict() 114 | parent[last_key] = list_or_dict 115 | 116 | last_key = key 117 | parent = list_or_dict 118 | 119 | try: 120 | if isinstance(list_or_dict, dict): 121 | list_or_dict = list_or_dict[key] 122 | else: 123 | list_or_dict = list_or_dict[int(key)] 124 | except (KeyError, IndexError, ValueError) as e: 125 | raise KeyNotFoundError(e, keys=keys, visited=visited) 126 | 127 | visited += [key] 128 | # final expansion of retrieved value 129 | if expand and callable(list_or_dict): 130 | list_or_dict = list_or_dict() 131 | parent[last_key] = list_or_dict 132 | except KeyNotFoundError as e: 133 | if default is None: 134 | raise e 135 | else: 136 | list_or_dict = default 137 | success = False 138 | 139 | if not pass_success: 140 | return list_or_dict 141 | else: 142 | return list_or_dict, success 143 | 144 | 145 | if __name__ == "__main__": 146 | config = {"keya": "a", 147 | "keyb": "b", 148 | "keyc": 149 | {"cc1": 1, 150 | "cc2": 2, 151 | } 152 | } 153 | from omegaconf import OmegaConf 154 | config = OmegaConf.create(config) 155 | print(config) 156 | retrieve(config, "keya") 157 | 158 | -------------------------------------------------------------------------------- /codebases/stable-diffusion/src/taming-transformers/taming_transformers.egg-info/PKG-INFO: -------------------------------------------------------------------------------- 1 | Metadata-Version: 2.1 2 | Name: taming-transformers 3 | Version: 0.0.1 4 | Summary: Taming Transformers for High-Resolution Image Synthesis 5 | -------------------------------------------------------------------------------- /codebases/stable-diffusion/src/taming-transformers/taming_transformers.egg-info/SOURCES.txt: -------------------------------------------------------------------------------- 1 | README.md 2 | setup.py 3 | taming_transformers.egg-info/PKG-INFO 4 | taming_transformers.egg-info/SOURCES.txt 5 | taming_transformers.egg-info/dependency_links.txt 6 | taming_transformers.egg-info/requires.txt 7 | taming_transformers.egg-info/top_level.txt -------------------------------------------------------------------------------- /codebases/stable-diffusion/src/taming-transformers/taming_transformers.egg-info/dependency_links.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /codebases/stable-diffusion/src/taming-transformers/taming_transformers.egg-info/requires.txt: -------------------------------------------------------------------------------- 1 | torch 2 | numpy 3 | tqdm 4 | -------------------------------------------------------------------------------- /codebases/stable-diffusion/src/taming-transformers/taming_transformers.egg-info/top_level.txt: -------------------------------------------------------------------------------- 1 | 2 | --------------------------------------------------------------------------------
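The dataset classes dumped above (taming/data/base.py and the custom/faceshq/sflckr subclasses) all funnel through ImagePaths, which rescales, crops, and normalizes images to float32 arrays in [-1, 1]. The following usage sketch is not part of the repository: the image paths are placeholders, and it assumes the taming package plus albumentations and Pillow are installed.

from taming.data.base import ImagePaths

# Placeholder image files; any RGB images on disk work here.
paths = ["data/example_0.png", "data/example_1.png"]
dataset = ImagePaths(paths=paths, size=256, random_crop=False)

example = dataset[0]
print(example["image"].shape)   # (256, 256, 3), float32 scaled to [-1, 1]
print(example["file_path_"])    # the originating path is kept as a label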
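LambdaWarmUpCosineScheduler in taming/lr_scheduler.py returns an absolute learning rate per step (linear warm-up, then cosine decay), which is why its docstring says to use it with a base_lr of 1.0. The sketch below is illustrative only: the model and the hyperparameter values are placeholders, and it assumes PyTorch is available.

import torch
from taming.lr_scheduler import LambdaWarmUpCosineScheduler

model = torch.nn.Linear(8, 8)                              # placeholder model
optimizer = torch.optim.Adam(model.parameters(), lr=1.0)   # base_lr of 1.0

schedule = LambdaWarmUpCosineScheduler(
    warm_up_steps=1000,        # linear ramp from lr_start to lr_max
    lr_start=1e-6,
    lr_max=4.5e-6,
    lr_min=1e-8,               # floor of the cosine decay
    max_decay_steps=100000,
)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=schedule)

for step in range(5):
    optimizer.step()
    scheduler.step()           # lr becomes 1.0 * schedule(step + 1)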
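The retrieve helper in taming/util.py (documented above) resolves path-like keys in nested configs and can expand callable nodes along the way. A short sketch of how it behaves follows; the config contents are invented for illustration, and it assumes the taming package and its dependencies (requests, tqdm) are importable.

from taming.util import retrieve, KeyNotFoundError

config = {
    "model": {
        "target": "taming.models.vqgan.VQModel",
        "params": {"embed_dim": 256, "n_embed": 1024},
    }
}

# Path-like keys descend through nested dicts (list indices also work).
print(retrieve(config, "model/params/embed_dim"))                # 256

# A default suppresses the KeyNotFoundError for missing keys.
print(retrieve(config, "model/params/z_channels", default=16))   # 16

# pass_success=True additionally reports whether the key was found.
value, found = retrieve(config, "model/target", pass_success=True)
print(value, found)                              # taming.models.vqgan.VQModel True

# Without a default, a missing key raises KeyNotFoundError.
try:
    retrieve(config, "model/params/missing")
except KeyNotFoundError as e:
    print("missing key:", e.keys)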