├── README.md ├── StegaStamp ├── models.py ├── train.py ├── train.sh └── utils_img.py ├── guided-diffusion ├── LICENSE ├── datasets │ ├── README.md │ └── lsun_bedroom.py ├── evaluations │ ├── README.md │ ├── evaluator.py │ └── requirements.txt ├── generate.sh ├── guided_diffusion │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-310.pyc │ │ ├── dist_util.cpython-310.pyc │ │ ├── fp16_util.cpython-310.pyc │ │ ├── gaussian_diffusion.cpython-310.pyc │ │ ├── image_datasets.cpython-310.pyc │ │ ├── logger.cpython-310.pyc │ │ ├── losses.cpython-310.pyc │ │ ├── nn.cpython-310.pyc │ │ ├── resample.cpython-310.pyc │ │ ├── respace.cpython-310.pyc │ │ ├── script_util.cpython-310.pyc │ │ ├── stega_model.cpython-310.pyc │ │ ├── train_util.cpython-310.pyc │ │ └── unet.cpython-310.pyc │ ├── dist_util.py │ ├── fp16_util.py │ ├── gaussian_diffusion.py │ ├── image_datasets.py │ ├── logger.py │ ├── losses.py │ ├── nn.py │ ├── resample.py │ ├── respace.py │ ├── script_util.py │ ├── stega_model.py │ ├── train_util.py │ └── unet.py ├── scripts │ ├── classifier_sample.py │ ├── classifier_train.py │ ├── image_nll.py │ ├── image_sample.py │ ├── image_train.py │ ├── super_res_sample.py │ └── super_res_train.py ├── setup.py └── train.sh ├── pics └── framework.png ├── stable-diffusion ├── LICENSE ├── Stable_Diffusion_v1_Model_Card.md ├── configs │ ├── autoencoder │ │ ├── autoencoder_kl_16x16x16.yaml │ │ ├── autoencoder_kl_32x32x4.yaml │ │ ├── autoencoder_kl_64x64x3.yaml │ │ └── autoencoder_kl_8x8x64.yaml │ ├── latent-diffusion │ │ ├── celebahq-ldm-vq-4.yaml │ │ ├── cin-ldm-vq-f8.yaml │ │ ├── cin256-v2.yaml │ │ ├── ffhq-ldm-vq-4.yaml │ │ ├── lsun_bedrooms-ldm-vq-4.yaml │ │ ├── lsun_churches-ldm-kl-8.yaml │ │ └── txt2img-1p4B-eval.yaml │ ├── retrieval-augmented-diffusion │ │ └── 768x768.yaml │ └── stable-diffusion │ │ └── v1-inference.yaml ├── environment.yaml ├── ldm │ ├── __pycache__ │ │ └── util.cpython-310.pyc │ ├── data │ │ ├── __init__.py │ │ ├── base.py │ │ ├── imagenet.py │ │ 
└── lsun.py │ ├── lr_scheduler.py │ ├── models │ │ ├── __pycache__ │ │ │ └── autoencoder.cpython-310.pyc │ │ ├── autoencoder.py │ │ └── diffusion │ │ │ ├── __init__.py │ │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-310.pyc │ │ │ ├── ddim.cpython-310.pyc │ │ │ ├── ddpm.cpython-310.pyc │ │ │ └── plms.cpython-310.pyc │ │ │ ├── classifier.py │ │ │ ├── ddim.py │ │ │ ├── ddpm.py │ │ │ ├── dpm_solver │ │ │ ├── __init__.py │ │ │ ├── __pycache__ │ │ │ │ ├── __init__.cpython-310.pyc │ │ │ │ ├── dpm_solver.cpython-310.pyc │ │ │ │ └── sampler.cpython-310.pyc │ │ │ ├── dpm_solver.py │ │ │ └── sampler.py │ │ │ └── plms.py │ ├── modules │ │ ├── __pycache__ │ │ │ ├── attention.cpython-310.pyc │ │ │ ├── ema.cpython-310.pyc │ │ │ └── x_transformer.cpython-310.pyc │ │ ├── attention.py │ │ ├── diffusionmodules │ │ │ ├── __init__.py │ │ │ ├── __pycache__ │ │ │ │ ├── __init__.cpython-310.pyc │ │ │ │ ├── model.cpython-310.pyc │ │ │ │ ├── openaimodel.cpython-310.pyc │ │ │ │ └── util.cpython-310.pyc │ │ │ ├── model.py │ │ │ ├── openaimodel.py │ │ │ └── util.py │ │ ├── distributions │ │ │ ├── __init__.py │ │ │ ├── __pycache__ │ │ │ │ ├── __init__.cpython-310.pyc │ │ │ │ └── distributions.cpython-310.pyc │ │ │ └── distributions.py │ │ ├── ema.py │ │ ├── encoders │ │ │ ├── __init__.py │ │ │ ├── __pycache__ │ │ │ │ ├── __init__.cpython-310.pyc │ │ │ │ └── modules.cpython-310.pyc │ │ │ └── modules.py │ │ ├── image_degradation │ │ │ ├── __init__.py │ │ │ ├── bsrgan.py │ │ │ ├── bsrgan_light.py │ │ │ ├── utils │ │ │ │ └── test.png │ │ │ └── utils_image.py │ │ ├── losses │ │ │ ├── __init__.py │ │ │ ├── contperceptual.py │ │ │ └── vqperceptual.py │ │ └── x_transformer.py │ └── util.py ├── main.py ├── models │ ├── first_stage_models │ │ ├── kl-f16 │ │ │ └── config.yaml │ │ ├── kl-f32 │ │ │ └── config.yaml │ │ ├── kl-f4 │ │ │ └── config.yaml │ │ ├── kl-f8 │ │ │ └── config.yaml │ │ ├── vq-f16 │ │ │ └── config.yaml │ │ ├── vq-f4-noattn │ │ │ └── config.yaml │ │ ├── vq-f4 │ │ │ └── config.yaml │ │ ├── 
vq-f8-n256 │ │ │ └── config.yaml │ │ └── vq-f8 │ │ │ └── config.yaml │ └── ldm │ │ ├── bsr_sr │ │ └── config.yaml │ │ ├── celeba256 │ │ └── config.yaml │ │ ├── cin256 │ │ └── config.yaml │ │ ├── ffhq256 │ │ └── config.yaml │ │ ├── inpainting_big │ │ └── config.yaml │ │ ├── layout2img-openimages256 │ │ └── config.yaml │ │ ├── lsun_beds256 │ │ └── config.yaml │ │ ├── lsun_churches256 │ │ └── config.yaml │ │ ├── semantic_synthesis256 │ │ └── config.yaml │ │ ├── semantic_synthesis512 │ │ └── config.yaml │ │ └── text2img256 │ │ └── config.yaml ├── notebook_helpers.py ├── outputs │ └── txt2img-samples │ │ └── samples │ │ ├── 00000.png │ │ ├── 00001.png │ │ ├── 00002.png │ │ ├── 00003.png │ │ ├── 00004.png │ │ └── 00005.png ├── scripts │ ├── download_first_stages.sh │ ├── download_models.sh │ ├── img2img.py │ ├── inpaint.py │ ├── knn2img.py │ ├── latent_imagenet_diffusion.ipynb │ ├── sample_diffusion.py │ ├── tests │ │ └── test_watermark.py │ ├── train_searcher.py │ └── txt2img.py └── setup.py └── trace.py /README.md: -------------------------------------------------------------------------------- 1 | ### A Watermark-Conditioned Diffusion Model for IP Protection (ECCV 2024) 2 | This code is the official implementation of [A Watermark-Conditioned Diffusion Model for IP Protection](https://arxiv.org/abs/2403.10893). 3 | 4 | ---- 5 |
6 | 7 | ### Abstract 8 | 9 | The ethical need to protect AI-generated content has been a significant concern in recent years. While existing watermarking strategies have demonstrated success in detecting synthetic content (detection), there has been limited exploration in identifying the users responsible for generating these outputs from a single model (owner identification). In this paper, we focus on both practical scenarios and propose a unified watermarking framework for content copyright protection within the context of diffusion models. Specifically, we consider two parties: the model provider, who grants public access to a diffusion model via an API, and the users, who can solely query the model API and generate images in a black-box manner. Our task is to embed hidden information into the generated contents, which facilitates further detection and owner identification. To tackle this challenge, we propose a Watermark-conditioned Diffusion model called WaDiff, which manipulates the watermark as a conditioned input and incorporates fingerprinting into the generation process. All the generative outputs from our WaDiff carry user-specific information, which can be recovered by an image extractor and further facilitate forensic identification. Extensive experiments are conducted on two popular diffusion models, and we demonstrate that our method is effective and robust in both the detection and owner identification tasks. Meanwhile, our watermarking framework only exerts a negligible impact on the original generation and is more stealthy and efficient in comparison to existing watermarking strategies. 
10 | 11 | ### Setup 12 | To configure the environment, you can refer to [WatermarkDM](https://github.com/yunqing-me/WatermarkDM) for training the StegaStamp decoder, [guided-diffusion](https://github.com/openai/guided-diffusion) for fine-tuning the ImageNet diffusion model, and [stable-diffusion](https://github.com/CompVis/stable-diffusion) for fine-tuning Stable Diffusion. 13 | 14 | ### Pipeline 15 | #### Step 1: Pre-train Watermark Decoder 16 | 17 | First, you need to pre-train the watermark encoder and decoder jointly. Go to the [StegaStamp](StegaStamp) folder and simply run: 18 | ```cmd 19 | cd StegaStamp 20 | sh train.sh 21 | ``` 22 | Note that running the script directly may fail, because you first need to specify the path of your training data with ```--data_dir```. In addition, you can customize your experiments by adjusting hyperparameters such as the number of watermark bits ```--bit_length```, the image resolution ```--image_resolution```, the number of training epochs ```--num_epochs```, and the GPU device ```--cuda```. 23 | 24 | #### Step 2: Fine-tune Diffusion Model 25 | Once you have finished the pre-training process, you can use the watermark decoder to guide the diffusion model's fine-tuning process. For the ImageNet diffusion model, run the following commands: 26 | ```cmd 27 | cd ../guided-diffusion 28 | sh train.sh 29 | ``` 30 | Before running the script, you need to configure it properly, i.e., set the path of the pre-trained decoder checkpoint ```--wm_decoder_path``` (from Step 1), the path of the training data ```--data_dir``` (usually the same as in Step 1), the number of watermark bits ```--wm_length```, the balance parameter $\alpha$ ```--alpha```, and the time threshold $\tau$ ```--threshold```. You also need to download the pre-trained diffusion model [checkpoint](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/256x256_diffusion_uncond.pt) and put it into the ```models/``` folder.
31 | 32 | #### Step 3: Generate Watermarked Images 33 | After the fine-tuning step, you can use this watermark-conditioned diffusion model to generate watermarked images with the following command: 34 | ```cmd 35 | sh generate.sh 36 | ``` 37 | All generated images are saved in a default folder named ```saved_images``` (specified by ```--output_path```) and are organized into individual subfolders indexed by the specific ID of the watermark. You can also change ```--batch_size``` to adjust the number of generated images within each subfolder. 38 | 39 | #### Step 4: Source Identification 40 | Finally, run the following commands to perform tracing: 41 | ```cmd 42 | cd .. 43 | python trace.py 44 | ``` 45 | Note that ```--image_path``` indicates where the watermarked images are saved and should be consistent with the ```--output_path``` specified in Step 3. 46 | 47 | The code for Stable Diffusion is not finished yet. 48 | 49 | ### Citation 50 | ``` 51 | @article{min2024watermark, 52 | title={A watermark-conditioned diffusion model for ip protection}, 53 | author={Min, Rui and Li, Sen and Chen, Hongyang and Cheng, Minhao}, 54 | journal={arXiv preprint arXiv:2403.10893}, 55 | year={2024} 56 | } 57 | ``` 58 | 59 | #### Our code is heavily built upon [WatermarkDM](https://github.com/yunqing-me/WatermarkDM), [guided-diffusion](https://github.com/openai/guided-diffusion), and [stable-diffusion](https://github.com/CompVis/stable-diffusion).
60 | -------------------------------------------------------------------------------- /StegaStamp/models.py: --------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import nn 4 | from torch.nn.functional import relu, sigmoid 5 | 6 | 7 | class StegaStampEncoder(nn.Module): # Embeds a fingerprint vector into an image via a U-Net-style down/up path with skip concatenations 8 | def __init__( 9 | self, 10 | resolution=32, 11 | IMAGE_CHANNELS=1, 12 | fingerprint_size=100, 13 | return_residual=False, 14 | ): 15 | super(StegaStampEncoder, self).__init__() 16 | self.fingerprint_size = fingerprint_size 17 | self.IMAGE_CHANNELS = IMAGE_CHANNELS 18 | self.return_residual = return_residual 19 | self.secret_dense = nn.Linear(self.fingerprint_size, 64 * 64 * IMAGE_CHANNELS) # one 64x64 map per image channel (reshaped in forward) 20 | 21 | log_resolution = int(math.log(resolution, 2)) 22 | assert resolution == 2 ** log_resolution, f"Image resolution must be a power of 2, got {resolution}." 23 | 24 | self.fingerprint_upsample = nn.Upsample(scale_factor=(2**(log_resolution-6), 2**(log_resolution-6))) # rescales the 64x64 map to resolution x resolution (shrinks it when resolution < 64) 25 | self.conv1 = nn.Conv2d(2 * IMAGE_CHANNELS, 32, 3, 1, 1) # input is the fingerprint map concatenated with the image 26 | self.conv2 = nn.Conv2d(32, 32, 3, 2, 1) 27 | self.conv3 = nn.Conv2d(32, 64, 3, 2, 1) 28 | self.conv4 = nn.Conv2d(64, 128, 3, 2, 1) 29 | self.conv5 = nn.Conv2d(128, 256, 3, 2, 1) 30 | self.pad6 = nn.ZeroPad2d((0, 1, 0, 1)) # pad right/bottom by 1 so the 2x2 conv that follows keeps the upsampled size 31 | self.up6 = nn.Conv2d(256, 128, 2, 1) 32 | self.upsample6 = nn.Upsample(scale_factor=(2, 2)) 33 | self.conv6 = nn.Conv2d(128 + 128, 128, 3, 1, 1) 34 | self.pad7 = nn.ZeroPad2d((0, 1, 0, 1)) 35 | self.up7 = nn.Conv2d(128, 64, 2, 1) 36 | self.upsample7 = nn.Upsample(scale_factor=(2, 2)) 37 | self.conv7 = nn.Conv2d(64 + 64, 64, 3, 1, 1) 38 | self.pad8 = 
nn.ZeroPad2d((0, 1, 0, 1)) 39 | self.up8 = nn.Conv2d(64, 32, 2, 1) 40 | self.upsample8 = nn.Upsample(scale_factor=(2, 2)) 41 | self.conv8 = nn.Conv2d(32 + 32, 32, 3, 1, 1) 42 | self.pad9 = nn.ZeroPad2d((0, 1, 0, 1)) 43 | self.up9 = nn.Conv2d(32, 32, 2, 1) 44 | self.upsample9 = nn.Upsample(scale_factor=(2, 2)) 45 | self.conv9 = 
nn.Conv2d(32 + 32 + 2 * IMAGE_CHANNELS, 32, 3, 1, 1) # takes conv1 + up9 + the raw inputs (see merge9 in forward) 46 | self.conv10 = nn.Conv2d(32, 32, 3, 1, 1) 47 | self.residual = nn.Conv2d(32, IMAGE_CHANNELS, 1) 48 | 49 | def forward(self, fingerprint, image): 50 | fingerprint = relu(self.secret_dense(fingerprint)) 51 | fingerprint = fingerprint.view((-1, self.IMAGE_CHANNELS, 64, 64)) 52 | fingerprint_enlarged = self.fingerprint_upsample(fingerprint) 53 | inputs = torch.cat([fingerprint_enlarged, image], dim=1) 54 | conv1 = relu(self.conv1(inputs)) 55 | conv2 = relu(self.conv2(conv1)) 56 | conv3 = relu(self.conv3(conv2)) 57 | conv4 = relu(self.conv4(conv3)) 58 | conv5 = relu(self.conv5(conv4)) 59 | up6 = relu(self.up6(self.pad6(self.upsample6(conv5)))) 60 | merge6 = torch.cat([conv4, up6], dim=1) 61 | conv6 = relu(self.conv6(merge6)) 62 | up7 = relu(self.up7(self.pad7(self.upsample7(conv6)))) 63 | merge7 = torch.cat([conv3, up7], dim=1) 64 | conv7 = relu(self.conv7(merge7)) 65 | up8 = relu(self.up8(self.pad8(self.upsample8(conv7)))) 66 | merge8 = torch.cat([conv2, up8], dim=1) 67 | conv8 = relu(self.conv8(merge8)) 68 | up9 = relu(self.up9(self.pad9(self.upsample9(conv8)))) 69 | merge9 = torch.cat([conv1, up9, inputs], dim=1) 70 | conv9 = relu(self.conv9(merge9)) 71 | conv10 = relu(self.conv10(conv9)) 72 | residual = self.residual(conv10) 73 | if not self.return_residual: # squash to (0, 1) with sigmoid unless the raw residual was requested 74 | residual = sigmoid(residual) 75 | return residual 76 | 77 | 78 | class StegaStampDecoder(nn.Module): # Recovers raw (un-activated) fingerprint scores from an image 79 | def __init__(self, resolution=32, IMAGE_CHANNELS=1, fingerprint_size=1): 80 | super(StegaStampDecoder, self).__init__() 81 | self.resolution = resolution 82 | self.IMAGE_CHANNELS = IMAGE_CHANNELS 83 | self.decoder = nn.Sequential( # spatial sizes in the comments below assume the default resolution=32 84 | nn.Conv2d(IMAGE_CHANNELS, 32, (3, 3), 2, 1), # 16 85 | nn.ReLU(), 86 | nn.Conv2d(32, 32, 
3, 1, 1), 87 | nn.ReLU(), 88 | nn.Conv2d(32, 64, 3, 2, 1), # 8 89 | nn.ReLU(), 90 | nn.Conv2d(64, 64, 3, 1, 1), 91 | nn.ReLU(), 92 | nn.Conv2d(64, 64, 3, 2, 1), # 4 93 | nn.ReLU(), 94 | nn.Conv2d(64, 128, 3, 2, 1), # 2 95 | 
nn.ReLU(), 96 | nn.Conv2d(128, 128, (3, 3), 2, 1), 97 | nn.ReLU(), 98 | ) 99 | self.dense = nn.Sequential( 100 | nn.Linear(resolution * resolution * 128 // 32 // 32, 512), # five stride-2 convs shrink each side by 32; 128 channels remain 101 | nn.ReLU(), 102 | nn.Linear(512, fingerprint_size), 103 | ) 104 | 105 | def forward(self, image): 106 | x = self.decoder(image) 107 | x = x.view(-1, self.resolution * self.resolution * 128 // 32 // 32) # NOTE(review): an input of the wrong size may be silently folded into the batch dim instead of raising 108 | return self.dense(x) 109 | -------------------------------------------------------------------------------- /StegaStamp/train.sh: --------------------------------------------------------------------------------- 1 | python train.py --data_dir ./data/coco/train2014 \ 2 | --bit_length 48 --image_resolution 512 --num_epochs 100 --cuda 0 3 | -------------------------------------------------------------------------------- /StegaStamp/utils_img.py: --------------------------------------------------------------------------------- 1 | # This script is modified from "https://github.com/facebookresearch/stable_signature" 2 | 3 | import os 4 | import numpy as np 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from torch.autograd.variable import Variable 10 | from torchvision import transforms 11 | from torchvision.transforms import functional 12 | from augly.image import functional as aug_functional 13 | 14 | import kornia.augmentation as K 15 | 16 | from PIL import Image 17 | 18 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 19 | 20 | NORMALIZE_IMAGENET = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 21 | default_transform = transforms.Compose([ 22 | transforms.ToTensor(), 23 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 24 | ]) 25 | image_mean = torch.Tensor([0.485, 0.456, 0.406]).view(-1, 1, 1) # shaped (3, 1, 1) to broadcast over CxHxW 
tensors 26 | image_std = torch.Tensor([0.229, 0.224, 0.225]).view(-1, 1, 1) 27 | 28 | normalize_rgb = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 29 | 
unnormalize_rgb = transforms.Normalize(mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225], std=[1/0.229, 1/0.224, 1/0.225]) 30 | normalize_yuv = transforms.Normalize(mean=[0.5, 0, 0], std=[0.5, 1, 1]) 31 | unnormalize_yuv = transforms.Normalize(mean=[-0.5/0.5, 0, 0], std=[1/0.5, 1/1, 1/1]) 32 | 33 | 34 | def normalize_img(x): 35 | """ Normalize image to approx. [-1,1] """ 36 | return (x - image_mean.to(x.device)) / image_std.to(x.device) 37 | 38 | def unnormalize_img(x): 39 | """ Unnormalize image to [0,1] """ 40 | return (x * image_std.to(x.device)) + image_mean.to(x.device) 41 | 42 | def round_pixel(x): 43 | """ 44 | Round pixel values to nearest integer. 45 | Args: 46 | x: Image tensor with values approx. between [-1,1] 47 | Returns: 48 | y: Rounded image tensor with values approx. between [-1,1] 49 | """ 50 | x_pixel = 255 * unnormalize_img(x) 51 | y = torch.round(x_pixel).clamp(0, 255) 52 | y = normalize_img(y/255.0) 53 | return y 54 | 55 | def clamp_pixel(x): 56 | """ 57 | Clamp pixel values to 0 255. 58 | Args: 59 | x: Image tensor with values approx. between [-1,1] 60 | Returns: 61 | y: Rounded image tensor with values approx. between [-1,1] 62 | """ 63 | x_pixel = 255 * unnormalize_img(x) 64 | y = x_pixel.clamp(0, 255) 65 | y = normalize_img(y/255.0) 66 | return y 67 | 68 | def project_linf(x, y, radius): 69 | """ 70 | Clamp x so that Linf(x,y)<=radius 71 | Args: 72 | x: Image tensor with values approx. between [-1,1] 73 | y: Image tensor with values approx. between [-1,1], ex: original image 74 | radius: Radius of Linf ball for the images in pixel space [0, 255] 75 | """ 76 | delta = x - y 77 | delta = 255 * (delta * image_std.to(x.device)) 78 | delta = torch.clamp(delta, -radius, radius) 79 | delta = (delta / 255.0) / image_std.to(x.device) 80 | return y + delta 81 | 82 | def psnr(x, y): 83 | """ 84 | Return PSNR 85 | Args: 86 | x: Image tensor with values approx. between [-1,1] 87 | y: Image tensor with values approx. 
between [-1,1], ex: original image 88 | """ 89 | delta = x - y 90 | delta = 255 * (delta * image_std.to(x.device)) 91 | delta = delta.reshape(-1, x.shape[-3], x.shape[-2], x.shape[-1]) # BxCxHxW 92 | psnr = 20*np.log10(255) - 10*torch.log10(torch.mean(delta**2, dim=(1,2,3))) # B 93 | return psnr 94 | 95 | def center_crop(x, scale): 96 | """ Perform center crop such that the target area of the crop is at a given scale 97 | Args: 98 | x: PIL image 99 | scale: target area scale 100 | """ 101 | scale = np.sqrt(scale) 102 | new_edges_size = [int(s*scale) for s in x.shape[-2:]][::-1] 103 | 104 | # left = int(x.size[0]/2-new_edges_size[0]/2) 105 | # upper = int(x.size[1]/2-new_edges_size[1]/2) 106 | # right = left + new_edges_size[0] 107 | # lower = upper + new_edges_size[1] 108 | 109 | # return x.crop((left, upper, right, lower)) 110 | x = functional.center_crop(x, new_edges_size) 111 | return x 112 | 113 | def resize(x, scale): 114 | """ Perform center crop such that the target area of the crop is at a given scale 115 | Args: 116 | x: PIL image 117 | scale: target area scale 118 | """ 119 | scale = np.sqrt(scale) 120 | new_edges_size = [int(s*scale) for s in x.shape[-2:]][::-1] 121 | return functional.resize(x, new_edges_size) 122 | 123 | def rotate(x, angle): 124 | """ Rotate image by angle 125 | Args: 126 | x: image (PIl or tensor) 127 | angle: angle in degrees 128 | """ 129 | return functional.rotate(x, angle) 130 | 131 | def adjust_brightness(x, brightness_factor): 132 | """ Adjust brightness of an image 133 | Args: 134 | x: PIL image 135 | brightness_factor: brightness factor 136 | """ 137 | return normalize_img(functional.adjust_brightness(unnormalize_img(x), brightness_factor)) 138 | 139 | def adjust_contrast(x, contrast_factor): 140 | """ Adjust constrast of an image 141 | Args: 142 | x: PIL image 143 | contrast_factor: contrast factor 144 | """ 145 | return normalize_img(functional.adjust_contrast(unnormalize_img(x), contrast_factor)) 146 | 147 | def 
jpeg_compress(x, quality_factor): 148 | """ Apply jpeg compression to image 149 | Args: 150 | x: Tensor image 151 | quality_factor: quality factor 152 | """ 153 | to_pil = transforms.ToPILImage() 154 | to_tensor = transforms.ToTensor() 155 | img_aug = torch.zeros_like(x, device=x.device) 156 | x = unnormalize_img(x) 157 | for ii,img in enumerate(x): 158 | pil_img = to_pil(img) 159 | img_aug[ii] = to_tensor(aug_functional.encoding_quality(pil_img, quality=quality_factor)) 160 | return normalize_img(img_aug) 161 | 162 | def gaussian_blur(x, sigma=1): 163 | """ Add gaussian blur to image 164 | Args: 165 | x: Tensor image 166 | sigma: sigma of gaussian kernel 167 | """ 168 | x = unnormalize_img(x) 169 | x = functional.gaussian_blur(x, sigma=sigma, kernel_size=21) 170 | x = normalize_img(x) 171 | return x 172 | -------------------------------------------------------------------------------- /guided-diffusion/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 OpenAI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /guided-diffusion/datasets/README.md: -------------------------------------------------------------------------------- 1 | # Downloading datasets 2 | 3 | This directory includes instructions and scripts for downloading ImageNet and LSUN bedrooms for use in this codebase. 4 | 5 | ## Class-conditional ImageNet 6 | 7 | For our class-conditional models, we use the official ILSVRC2012 dataset with manual center cropping and downsampling. To obtain this dataset, navigate to [this page on image-net.org](http://www.image-net.org/challenges/LSVRC/2012/downloads) and sign in (or create an account if you do not already have one). Then click on the link reading "Training images (Task 1 & 2)". This is a 138GB tar file containing 1000 sub-tar files, one per class. 8 | 9 | Once the file is downloaded, extract it and look inside. You should see 1000 `.tar` files. You need to extract each of these, which may be impractical to do by hand on your operating system. To automate the process on a Unix-based system, you can `cd` into the directory and run this short shell script: 10 | 11 | ``` 12 | for file in *.tar; do tar xf "$file"; rm "$file"; done 13 | ``` 14 | 15 | This will extract and remove each tar file in turn. 16 | 17 | Once all of the images have been extracted, the resulting directory should be usable as a data directory (the `--data_dir` argument for the training script). The filenames should all start with WNID (class ids) followed by underscores, like `n01440764_2708.JPEG`. Conveniently (but not by accident) this is how the automated data-loader expects to discover class labels. 
18 | 19 | ## LSUN bedroom 20 | 21 | To download and pre-process LSUN bedroom, clone [fyu/lsun](https://github.com/fyu/lsun) on GitHub and run their download script `python3 download.py bedroom`. The result will be an "lmdb" database named like `bedroom_train_lmdb`. You can pass this to our [lsun_bedroom.py](lsun_bedroom.py) script like so: 22 | 23 | ``` 24 | python lsun_bedroom.py bedroom_train_lmdb lsun_train_output_dir 25 | ``` 26 | 27 | This creates a directory called `lsun_train_output_dir`. This directory can be passed to the training scripts via the `--data_dir` argument. 28 | -------------------------------------------------------------------------------- /guided-diffusion/datasets/lsun_bedroom.py: -------------------------------------------------------------------------------- 1 | """ 2 | Convert an LSUN lmdb database into a directory of images. 3 | """ 4 | 5 | import argparse 6 | import io 7 | import os 8 | 9 | from PIL import Image 10 | import lmdb 11 | import numpy as np 12 | 13 | 14 | def read_images(lmdb_path, image_size): 15 | env = lmdb.open(lmdb_path, map_size=1099511627776, max_readers=100, readonly=True) 16 | with env.begin(write=False) as transaction: 17 | cursor = transaction.cursor() 18 | for _, webp_data in cursor: 19 | img = Image.open(io.BytesIO(webp_data)) 20 | width, height = img.size 21 | scale = image_size / min(width, height) 22 | img = img.resize( 23 | (int(round(scale * width)), int(round(scale * height))), 24 | resample=Image.BOX, 25 | ) 26 | arr = np.array(img) 27 | h, w, _ = arr.shape 28 | h_off = (h - image_size) // 2 29 | w_off = (w - image_size) // 2 30 | arr = arr[h_off : h_off + image_size, w_off : w_off + image_size] 31 | yield arr 32 | 33 | 34 | def dump_images(out_dir, images, prefix): 35 | if not os.path.exists(out_dir): 36 | os.mkdir(out_dir) 37 | for i, img in enumerate(images): 38 | Image.fromarray(img).save(os.path.join(out_dir, f"{prefix}_{i:07d}.png")) 39 | 40 | 41 | def main(): 42 | parser = 
argparse.ArgumentParser() 43 | parser.add_argument("--image-size", help="new image size", type=int, default=256) 44 | parser.add_argument("--prefix", help="class name", type=str, default="bedroom") 45 | parser.add_argument("lmdb_path", help="path to an LSUN lmdb database") 46 | parser.add_argument("out_dir", help="path to output directory") 47 | args = parser.parse_args() 48 | 49 | images = read_images(args.lmdb_path, args.image_size) 50 | dump_images(args.out_dir, images, args.prefix) 51 | 52 | 53 | if __name__ == "__main__": 54 | main() 55 | -------------------------------------------------------------------------------- /guided-diffusion/evaluations/README.md: -------------------------------------------------------------------------------- 1 | # Evaluations 2 | 3 | To compare different generative models, we use FID, sFID, Precision, Recall, and Inception Score. These metrics can all be calculated using batches of samples, which we store in `.npz` (numpy) files. 4 | 5 | # Download batches 6 | 7 | We provide pre-computed sample batches for the reference datasets, our diffusion models, and several baselines we compare against. These are all stored in `.npz` format. 8 | 9 | Reference dataset batches contain pre-computed statistics over the whole dataset, as well as 10,000 images for computing Precision and Recall. All other batches contain 50,000 images which can be used to compute statistics and Precision/Recall. 
10 | 11 | Here are links to download all of the sample and reference batches: 12 | 13 | * LSUN 14 | * LSUN bedroom: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/bedroom/VIRTUAL_lsun_bedroom256.npz) 15 | * [ADM (dropout)](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/bedroom/admnet_dropout_lsun_bedroom.npz) 16 | * [DDPM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/bedroom/ddpm_lsun_bedroom.npz) 17 | * [IDDPM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/bedroom/iddpm_lsun_bedroom.npz) 18 | * [StyleGAN](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/bedroom/stylegan_lsun_bedroom.npz) 19 | * LSUN cat: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/cat/VIRTUAL_lsun_cat256.npz) 20 | * [ADM (dropout)](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/cat/admnet_dropout_lsun_cat.npz) 21 | * [StyleGAN2](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/cat/stylegan2_lsun_cat.npz) 22 | * LSUN horse: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/horse/VIRTUAL_lsun_horse256.npz) 23 | * [ADM (dropout)](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/horse/admnet_dropout_lsun_horse.npz) 24 | * [ADM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/horse/admnet_lsun_horse.npz) 25 | 26 | * ImageNet 27 | * ImageNet 64x64: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/64/VIRTUAL_imagenet64_labeled.npz) 28 | * [ADM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/64/admnet_imagenet64.npz) 29 | * 
[IDDPM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/64/iddpm_imagenet64.npz) 30 | * [BigGAN](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/64/biggan_deep_imagenet64.npz) 31 | * ImageNet 128x128: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/128/VIRTUAL_imagenet128_labeled.npz) 32 | * [ADM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/128/admnet_imagenet128.npz) 33 | * [ADM-G](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/128/admnet_guided_imagenet128.npz) 34 | * [ADM-G, 25 steps](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/128/admnet_guided_25step_imagenet128.npz) 35 | * [BigGAN-deep (trunc=1.0)](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/128/biggan_deep_trunc1_imagenet128.npz) 36 | * ImageNet 256x256: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/VIRTUAL_imagenet256_labeled.npz) 37 | * [ADM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/admnet_imagenet256.npz) 38 | * [ADM-G](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/admnet_guided_imagenet256.npz) 39 | * [ADM-G, 25 step](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/admnet_guided_25step_imagenet256.npz) 40 | * [ADM-G + ADM-U](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/admnet_guided_upsampled_imagenet256.npz) 41 | * [ADM-U](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/admnet_upsampled_imagenet256.npz) 42 | * [BigGAN-deep (trunc=1.0)](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/biggan_deep_trunc1_imagenet256.npz) 43 | * ImageNet 
512x512: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/VIRTUAL_imagenet512.npz) 44 | * [ADM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/admnet_imagenet512.npz) 45 | * [ADM-G](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/admnet_guided_imagenet512.npz) 46 | * [ADM-G, 25 step](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/admnet_guided_25step_imagenet512.npz) 47 | * [ADM-G + ADM-U](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/admnet_guided_upsampled_imagenet512.npz) 48 | * [ADM-U](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/admnet_upsampled_imagenet512.npz) 49 | * [BigGAN-deep (trunc=1.0)](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/biggan_deep_trunc1_imagenet512.npz) 50 | 51 | # Run evaluations 52 | 53 | First, generate or download a batch of samples and download the corresponding reference batch for the given dataset. For this example, we'll use ImageNet 256x256, so the reference batch is `VIRTUAL_imagenet256_labeled.npz` and we can use the sample batch `admnet_guided_upsampled_imagenet256.npz`. 54 | 55 | Next, run the `evaluator.py` script. The requirements of this script can be found in [requirements.txt](requirements.txt). Pass two arguments to the script: the reference batch and the sample batch. The script will download the InceptionV3 model used for evaluations into the current working directory (if it is not already present). This file is roughly 100MB. 56 | 57 | The output of the script will look something like this, where the first `...` is a bunch of verbose TensorFlow logging: 58 | 59 | ``` 60 | $ python evaluator.py VIRTUAL_imagenet256_labeled.npz admnet_guided_upsampled_imagenet256.npz 61 | ... 62 | computing reference batch activations...
63 | computing/reading reference batch statistics... 64 | computing sample batch activations... 65 | computing/reading sample batch statistics... 66 | Computing evaluations... 67 | Inception Score: 215.8370361328125 68 | FID: 3.9425574129223264 69 | sFID: 6.140433703346162 70 | Precision: 0.8265 71 | Recall: 0.5309 72 | ``` 73 | -------------------------------------------------------------------------------- /guided-diffusion/evaluations/requirements.txt: -------------------------------------------------------------------------------- 1 | tensorflow-gpu>=2.0 2 | scipy 3 | requests 4 | tqdm -------------------------------------------------------------------------------- /guided-diffusion/generate.sh: -------------------------------------------------------------------------------- 1 | MODEL_FLAGS="--wm_length 48 --attention_resolutions 32,16,8 --class_cond False --diffusion_steps 1000 --image_size 256 --learn_sigma True --noise_schedule linear --num_channels 256 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 True --use_scale_shift_norm True" 2 | SAMPLE_FLAGS="--batch_size 4 --num_samples 8 --timestep_respacing 100 --use_ddim True" 3 | python scripts/image_sample.py $MODEL_FLAGS --output_path saved_images/ --model_path models/256x256_diffusion_uncond.pt $SAMPLE_FLAGS -------------------------------------------------------------------------------- /guided-diffusion/guided_diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Codebase for "Improved Denoising Diffusion Probabilistic Models". 
3 | """ 4 | -------------------------------------------------------------------------------- /guided-diffusion/guided_diffusion/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmin2000/WaDiff/33e657de73c3e75e82ea66bcec14012a2f4fcb6a/guided-diffusion/guided_diffusion/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /guided-diffusion/guided_diffusion/__pycache__/dist_util.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmin2000/WaDiff/33e657de73c3e75e82ea66bcec14012a2f4fcb6a/guided-diffusion/guided_diffusion/__pycache__/dist_util.cpython-310.pyc -------------------------------------------------------------------------------- /guided-diffusion/guided_diffusion/__pycache__/fp16_util.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmin2000/WaDiff/33e657de73c3e75e82ea66bcec14012a2f4fcb6a/guided-diffusion/guided_diffusion/__pycache__/fp16_util.cpython-310.pyc -------------------------------------------------------------------------------- /guided-diffusion/guided_diffusion/__pycache__/gaussian_diffusion.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmin2000/WaDiff/33e657de73c3e75e82ea66bcec14012a2f4fcb6a/guided-diffusion/guided_diffusion/__pycache__/gaussian_diffusion.cpython-310.pyc -------------------------------------------------------------------------------- /guided-diffusion/guided_diffusion/__pycache__/image_datasets.cpython-310.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/rmin2000/WaDiff/33e657de73c3e75e82ea66bcec14012a2f4fcb6a/guided-diffusion/guided_diffusion/__pycache__/image_datasets.cpython-310.pyc -------------------------------------------------------------------------------- /guided-diffusion/guided_diffusion/__pycache__/logger.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmin2000/WaDiff/33e657de73c3e75e82ea66bcec14012a2f4fcb6a/guided-diffusion/guided_diffusion/__pycache__/logger.cpython-310.pyc -------------------------------------------------------------------------------- /guided-diffusion/guided_diffusion/__pycache__/losses.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmin2000/WaDiff/33e657de73c3e75e82ea66bcec14012a2f4fcb6a/guided-diffusion/guided_diffusion/__pycache__/losses.cpython-310.pyc -------------------------------------------------------------------------------- /guided-diffusion/guided_diffusion/__pycache__/nn.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmin2000/WaDiff/33e657de73c3e75e82ea66bcec14012a2f4fcb6a/guided-diffusion/guided_diffusion/__pycache__/nn.cpython-310.pyc -------------------------------------------------------------------------------- /guided-diffusion/guided_diffusion/__pycache__/resample.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmin2000/WaDiff/33e657de73c3e75e82ea66bcec14012a2f4fcb6a/guided-diffusion/guided_diffusion/__pycache__/resample.cpython-310.pyc -------------------------------------------------------------------------------- /guided-diffusion/guided_diffusion/__pycache__/respace.cpython-310.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/rmin2000/WaDiff/33e657de73c3e75e82ea66bcec14012a2f4fcb6a/guided-diffusion/guided_diffusion/__pycache__/respace.cpython-310.pyc -------------------------------------------------------------------------------- /guided-diffusion/guided_diffusion/__pycache__/script_util.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmin2000/WaDiff/33e657de73c3e75e82ea66bcec14012a2f4fcb6a/guided-diffusion/guided_diffusion/__pycache__/script_util.cpython-310.pyc -------------------------------------------------------------------------------- /guided-diffusion/guided_diffusion/__pycache__/stega_model.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmin2000/WaDiff/33e657de73c3e75e82ea66bcec14012a2f4fcb6a/guided-diffusion/guided_diffusion/__pycache__/stega_model.cpython-310.pyc -------------------------------------------------------------------------------- /guided-diffusion/guided_diffusion/__pycache__/train_util.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmin2000/WaDiff/33e657de73c3e75e82ea66bcec14012a2f4fcb6a/guided-diffusion/guided_diffusion/__pycache__/train_util.cpython-310.pyc -------------------------------------------------------------------------------- /guided-diffusion/guided_diffusion/__pycache__/unet.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmin2000/WaDiff/33e657de73c3e75e82ea66bcec14012a2f4fcb6a/guided-diffusion/guided_diffusion/__pycache__/unet.cpython-310.pyc -------------------------------------------------------------------------------- /guided-diffusion/guided_diffusion/dist_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers for 
distributed training. 3 | """ 4 | 5 | import io 6 | import os 7 | import socket 8 | 9 | import blobfile as bf 10 | from mpi4py import MPI 11 | import torch as th 12 | import torch.distributed as dist 13 | 14 | # Change this to reflect your cluster layout. 15 | # The GPU for a given rank is (rank % GPUS_PER_NODE). 16 | GPUS_PER_NODE = 8 17 | 18 | SETUP_RETRY_COUNT = 3 19 | 20 | 21 | def setup_dist(): 22 | """ 23 | Setup a distributed process group. 24 | """ 25 | if dist.is_initialized(): 26 | return 27 | os.environ["CUDA_VISIBLE_DEVICES"] = f"{MPI.COMM_WORLD.Get_rank() % GPUS_PER_NODE}" 28 | 29 | comm = MPI.COMM_WORLD 30 | backend = "gloo" if not th.cuda.is_available() else "nccl" 31 | 32 | if backend == "gloo": 33 | hostname = "localhost" 34 | else: 35 | hostname = socket.gethostbyname(socket.getfqdn()) 36 | os.environ["MASTER_ADDR"] = comm.bcast(hostname, root=0) 37 | os.environ["RANK"] = str(comm.rank) 38 | os.environ["WORLD_SIZE"] = str(comm.size) 39 | 40 | port = comm.bcast(_find_free_port(), root=0) 41 | os.environ["MASTER_PORT"] = str(port) 42 | dist.init_process_group(backend=backend, init_method="env://") 43 | 44 | 45 | def dev(): 46 | """ 47 | Get the device to use for torch.distributed. 48 | """ 49 | if th.cuda.is_available(): 50 | return th.device(f"cuda") 51 | return th.device("cpu") 52 | 53 | 54 | def load_state_dict(path, **kwargs): 55 | """ 56 | Load a PyTorch file without redundant fetches across MPI ranks. 
57 | """ 58 | chunk_size = 2 ** 30 # MPI has a relatively small size limit 59 | if MPI.COMM_WORLD.Get_rank() == 0: 60 | with bf.BlobFile(path, "rb") as f: 61 | data = f.read() 62 | num_chunks = len(data) // chunk_size 63 | if len(data) % chunk_size: 64 | num_chunks += 1 65 | MPI.COMM_WORLD.bcast(num_chunks) 66 | for i in range(0, len(data), chunk_size): 67 | MPI.COMM_WORLD.bcast(data[i : i + chunk_size]) 68 | else: 69 | num_chunks = MPI.COMM_WORLD.bcast(None) 70 | data = bytes() 71 | for _ in range(num_chunks): 72 | data += MPI.COMM_WORLD.bcast(None) 73 | 74 | return th.load(io.BytesIO(data), **kwargs) 75 | 76 | 77 | def sync_params(params): 78 | """ 79 | Synchronize a sequence of Tensors across ranks from rank 0. 80 | """ 81 | for p in params: 82 | with th.no_grad(): 83 | dist.broadcast(p, 0) 84 | 85 | 86 | def _find_free_port(): 87 | try: 88 | s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 89 | s.bind(("", 0)) 90 | s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) 91 | return s.getsockname()[1] 92 | finally: 93 | s.close() 94 | -------------------------------------------------------------------------------- /guided-diffusion/guided_diffusion/image_datasets.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | 4 | from PIL import Image 5 | import blobfile as bf 6 | from mpi4py import MPI 7 | import numpy as np 8 | from torch.utils.data import DataLoader, Dataset 9 | 10 | 11 | def load_data( 12 | *, 13 | data_dir, 14 | batch_size, 15 | image_size, 16 | class_cond=False, 17 | deterministic=False, 18 | random_crop=False, 19 | random_flip=True, 20 | ): 21 | """ 22 | For a dataset, create a generator over (images, kwargs) pairs. 23 | 24 | Each images is an NCHW float tensor, and the kwargs dict contains zero or 25 | more keys, each of which map to a batched Tensor of their own. 
26 | The kwargs dict can be used for class labels, in which case the key is "y" 27 | and the values are integer tensors of class labels. 28 | 29 | :param data_dir: a dataset directory. 30 | :param batch_size: the batch size of each returned pair. 31 | :param image_size: the size to which images are resized. 32 | :param class_cond: if True, include a "y" key in returned dicts for class 33 | label. If classes are not available and this is true, an 34 | exception will be raised. 35 | :param deterministic: if True, yield results in a deterministic order. 36 | :param random_crop: if True, randomly crop the images for augmentation. 37 | :param random_flip: if True, randomly flip the images for augmentation. 38 | """ 39 | if not data_dir: 40 | raise ValueError("unspecified data directory") 41 | all_files = _list_image_files_recursively(data_dir) 42 | classes = None 43 | if class_cond: 44 | # Assume classes are the first part of the filename, 45 | # before an underscore. 46 | class_names = [bf.basename(path).split("_")[0] for path in all_files] 47 | sorted_classes = {x: i for i, x in enumerate(sorted(set(class_names)))} 48 | classes = [sorted_classes[x] for x in class_names] 49 | dataset = ImageDataset( 50 | image_size, 51 | all_files, 52 | classes=classes, 53 | shard=MPI.COMM_WORLD.Get_rank(), 54 | num_shards=MPI.COMM_WORLD.Get_size(), 55 | random_crop=random_crop, 56 | random_flip=random_flip, 57 | ) 58 | if deterministic: 59 | loader = DataLoader( 60 | dataset, batch_size=batch_size, shuffle=False, num_workers=1, drop_last=True 61 | ) 62 | else: 63 | loader = DataLoader( 64 | dataset, batch_size=batch_size, shuffle=True, num_workers=1, drop_last=True 65 | ) 66 | while True: 67 | yield from loader 68 | 69 | 70 | def _list_image_files_recursively(data_dir): 71 | results = [] 72 | for entry in sorted(bf.listdir(data_dir)): 73 | full_path = bf.join(data_dir, entry) 74 | ext = entry.split(".")[-1] 75 | if "." 
in entry and ext.lower() in ["jpg", "jpeg", "png", "gif"]: 76 | results.append(full_path) 77 | elif bf.isdir(full_path): 78 | results.extend(_list_image_files_recursively(full_path)) 79 | return results 80 | 81 | 82 | class ImageDataset(Dataset): 83 | def __init__( 84 | self, 85 | resolution, 86 | image_paths, 87 | classes=None, 88 | shard=0, 89 | num_shards=1, 90 | random_crop=False, 91 | random_flip=True, 92 | ): 93 | super().__init__() 94 | self.resolution = resolution 95 | self.local_images = image_paths[shard:][::num_shards] 96 | self.local_classes = None if classes is None else classes[shard:][::num_shards] 97 | self.random_crop = random_crop 98 | self.random_flip = random_flip 99 | 100 | def __len__(self): 101 | return len(self.local_images) 102 | 103 | def __getitem__(self, idx): 104 | path = self.local_images[idx] 105 | with bf.BlobFile(path, "rb") as f: 106 | pil_image = Image.open(f) 107 | pil_image.load() 108 | pil_image = pil_image.convert("RGB") 109 | 110 | if self.random_crop: 111 | arr = random_crop_arr(pil_image, self.resolution) 112 | else: 113 | arr = center_crop_arr(pil_image, self.resolution) 114 | 115 | if self.random_flip and random.random() < 0.5: 116 | arr = arr[:, ::-1] 117 | 118 | arr = arr.astype(np.float32) / 127.5 - 1 119 | 120 | out_dict = {} 121 | if self.local_classes is not None: 122 | out_dict["y"] = np.array(self.local_classes[idx], dtype=np.int64) 123 | return np.transpose(arr, [2, 0, 1]), out_dict 124 | 125 | 126 | def center_crop_arr(pil_image, image_size): 127 | # We are not on a new enough PIL to support the `reducing_gap` 128 | # argument, which uses BOX downsampling at powers of two first. 129 | # Thus, we do it by hand to improve downsample quality. 
130 | while min(*pil_image.size) >= 2 * image_size: 131 | pil_image = pil_image.resize( 132 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX 133 | ) 134 | 135 | scale = image_size / min(*pil_image.size) 136 | pil_image = pil_image.resize( 137 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC 138 | ) 139 | 140 | arr = np.array(pil_image) 141 | crop_y = (arr.shape[0] - image_size) // 2 142 | crop_x = (arr.shape[1] - image_size) // 2 143 | return arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size] 144 | 145 | 146 | def random_crop_arr(pil_image, image_size, min_crop_frac=0.8, max_crop_frac=1.0): 147 | min_smaller_dim_size = math.ceil(image_size / max_crop_frac) 148 | max_smaller_dim_size = math.ceil(image_size / min_crop_frac) 149 | smaller_dim_size = random.randrange(min_smaller_dim_size, max_smaller_dim_size + 1) 150 | 151 | # We are not on a new enough PIL to support the `reducing_gap` 152 | # argument, which uses BOX downsampling at powers of two first. 153 | # Thus, we do it by hand to improve downsample quality. 154 | while min(*pil_image.size) >= 2 * smaller_dim_size: 155 | pil_image = pil_image.resize( 156 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX 157 | ) 158 | 159 | scale = smaller_dim_size / min(*pil_image.size) 160 | pil_image = pil_image.resize( 161 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC 162 | ) 163 | 164 | arr = np.array(pil_image) 165 | crop_y = random.randrange(arr.shape[0] - image_size + 1) 166 | crop_x = random.randrange(arr.shape[1] - image_size + 1) 167 | return arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size] 168 | -------------------------------------------------------------------------------- /guided-diffusion/guided_diffusion/losses.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers for various likelihood-based losses. These are ported from the original 3 | Ho et al. 
diffusion models codebase: 4 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/utils.py 5 | """ 6 | 7 | import numpy as np 8 | 9 | import torch as th 10 | 11 | 12 | def normal_kl(mean1, logvar1, mean2, logvar2): 13 | """ 14 | Compute the KL divergence between two gaussians. 15 | 16 | Shapes are automatically broadcasted, so batches can be compared to 17 | scalars, among other use cases. 18 | """ 19 | tensor = None 20 | for obj in (mean1, logvar1, mean2, logvar2): 21 | if isinstance(obj, th.Tensor): 22 | tensor = obj 23 | break 24 | assert tensor is not None, "at least one argument must be a Tensor" 25 | 26 | # Force variances to be Tensors. Broadcasting helps convert scalars to 27 | # Tensors, but it does not work for th.exp(). 28 | logvar1, logvar2 = [ 29 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) 30 | for x in (logvar1, logvar2) 31 | ] 32 | 33 | return 0.5 * ( 34 | -1.0 35 | + logvar2 36 | - logvar1 37 | + th.exp(logvar1 - logvar2) 38 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2) 39 | ) 40 | 41 | 42 | def approx_standard_normal_cdf(x): 43 | """ 44 | A fast approximation of the cumulative distribution function of the 45 | standard normal. 46 | """ 47 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) 48 | 49 | 50 | def discretized_gaussian_log_likelihood(x, *, means, log_scales): 51 | """ 52 | Compute the log-likelihood of a Gaussian distribution discretizing to a 53 | given image. 54 | 55 | :param x: the target images. It is assumed that this was uint8 values, 56 | rescaled to the range [-1, 1]. 57 | :param means: the Gaussian mean Tensor. 58 | :param log_scales: the Gaussian log stddev Tensor. 59 | :return: a tensor like x of log probabilities (in nats). 
60 | """ 61 | assert x.shape == means.shape == log_scales.shape 62 | centered_x = x - means 63 | inv_stdv = th.exp(-log_scales) 64 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0) 65 | cdf_plus = approx_standard_normal_cdf(plus_in) 66 | min_in = inv_stdv * (centered_x - 1.0 / 255.0) 67 | cdf_min = approx_standard_normal_cdf(min_in) 68 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) 69 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) 70 | cdf_delta = cdf_plus - cdf_min 71 | log_probs = th.where( 72 | x < -0.999, 73 | log_cdf_plus, 74 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), 75 | ) 76 | assert log_probs.shape == x.shape 77 | return log_probs 78 | -------------------------------------------------------------------------------- /guided-diffusion/guided_diffusion/nn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Various utilities for neural networks. 3 | """ 4 | 5 | import math 6 | 7 | import torch as th 8 | import torch.nn as nn 9 | 10 | 11 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 12 | class SiLU(nn.Module): 13 | def forward(self, x): 14 | return x * th.sigmoid(x) 15 | 16 | 17 | class GroupNorm32(nn.GroupNorm): 18 | def forward(self, x): 19 | return super().forward(x.float()).type(x.dtype) 20 | 21 | 22 | def conv_nd(dims, *args, **kwargs): 23 | """ 24 | Create a 1D, 2D, or 3D convolution module. 25 | """ 26 | if dims == 1: 27 | return nn.Conv1d(*args, **kwargs) 28 | elif dims == 2: 29 | return nn.Conv2d(*args, **kwargs) 30 | elif dims == 3: 31 | return nn.Conv3d(*args, **kwargs) 32 | raise ValueError(f"unsupported dimensions: {dims}") 33 | 34 | 35 | def linear(*args, **kwargs): 36 | """ 37 | Create a linear module. 38 | """ 39 | return nn.Linear(*args, **kwargs) 40 | 41 | 42 | def avg_pool_nd(dims, *args, **kwargs): 43 | """ 44 | Create a 1D, 2D, or 3D average pooling module. 
45 | """ 46 | if dims == 1: 47 | return nn.AvgPool1d(*args, **kwargs) 48 | elif dims == 2: 49 | return nn.AvgPool2d(*args, **kwargs) 50 | elif dims == 3: 51 | return nn.AvgPool3d(*args, **kwargs) 52 | raise ValueError(f"unsupported dimensions: {dims}") 53 | 54 | 55 | def update_ema(target_params, source_params, rate=0.99): 56 | """ 57 | Update target parameters to be closer to those of source parameters using 58 | an exponential moving average. 59 | 60 | :param target_params: the target parameter sequence. 61 | :param source_params: the source parameter sequence. 62 | :param rate: the EMA rate (closer to 1 means slower). 63 | """ 64 | for targ, src in zip(target_params, source_params): 65 | targ.detach().mul_(rate).add_(src, alpha=1 - rate) 66 | 67 | 68 | def zero_module(module): 69 | """ 70 | Zero out the parameters of a module and return it. 71 | """ 72 | for p in module.parameters(): 73 | p.detach().zero_() 74 | return module 75 | 76 | 77 | def scale_module(module, scale): 78 | """ 79 | Scale the parameters of a module and return it. 80 | """ 81 | for p in module.parameters(): 82 | p.detach().mul_(scale) 83 | return module 84 | 85 | 86 | def mean_flat(tensor): 87 | """ 88 | Take the mean over all non-batch dimensions. 89 | """ 90 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 91 | 92 | 93 | def normalization(channels): 94 | """ 95 | Make a standard normalization layer. 96 | 97 | :param channels: number of input channels. 98 | :return: an nn.Module for normalization. 99 | """ 100 | return GroupNorm32(32, channels) 101 | 102 | 103 | def timestep_embedding(timesteps, dim, max_period=10000): 104 | """ 105 | Create sinusoidal timestep embeddings. 106 | 107 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 108 | These may be fractional. 109 | :param dim: the dimension of the output. 110 | :param max_period: controls the minimum frequency of the embeddings. 111 | :return: an [N x dim] Tensor of positional embeddings. 
112 | """ 113 | half = dim // 2 114 | freqs = th.exp( 115 | -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half 116 | ).to(device=timesteps.device) 117 | args = timesteps[:, None].float() * freqs[None] 118 | embedding = th.cat([th.cos(args), th.sin(args)], dim=-1) 119 | if dim % 2: 120 | embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1) 121 | return embedding 122 | 123 | 124 | def checkpoint(func, inputs, params, flag): 125 | """ 126 | Evaluate a function without caching intermediate activations, allowing for 127 | reduced memory at the expense of extra compute in the backward pass. 128 | 129 | :param func: the function to evaluate. 130 | :param inputs: the argument sequence to pass to `func`. 131 | :param params: a sequence of parameters `func` depends on but does not 132 | explicitly take as arguments. 133 | :param flag: if False, disable gradient checkpointing. 134 | """ 135 | if flag: 136 | args = tuple(inputs) + tuple(params) 137 | return CheckpointFunction.apply(func, len(inputs), *args) 138 | else: 139 | return func(*inputs) 140 | 141 | 142 | class CheckpointFunction(th.autograd.Function): 143 | @staticmethod 144 | def forward(ctx, run_function, length, *args): 145 | ctx.run_function = run_function 146 | ctx.input_tensors = list(args[:length]) 147 | ctx.input_params = list(args[length:]) 148 | with th.no_grad(): 149 | output_tensors = ctx.run_function(*ctx.input_tensors) 150 | return output_tensors 151 | 152 | @staticmethod 153 | def backward(ctx, *output_grads): 154 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 155 | with th.enable_grad(): 156 | # Fixes a bug where the first op in run_function modifies the 157 | # Tensor storage in place, which is not allowed for detach()'d 158 | # Tensors. 
159 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 160 | output_tensors = ctx.run_function(*shallow_copies) 161 | input_grads = th.autograd.grad( 162 | output_tensors, 163 | ctx.input_tensors + ctx.input_params, 164 | output_grads, 165 | allow_unused=True, 166 | ) 167 | del ctx.input_tensors 168 | del ctx.input_params 169 | del output_tensors 170 | return (None, None) + input_grads 171 | -------------------------------------------------------------------------------- /guided-diffusion/guided_diffusion/resample.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import numpy as np 4 | import torch as th 5 | import torch.distributed as dist 6 | 7 | 8 | def create_named_schedule_sampler(name, diffusion): 9 | """ 10 | Create a ScheduleSampler from a library of pre-defined samplers. 11 | 12 | :param name: the name of the sampler. 13 | :param diffusion: the diffusion object to sample for. 14 | """ 15 | if name == "uniform": 16 | return UniformSampler(diffusion) 17 | elif name == "loss-second-moment": 18 | return LossSecondMomentResampler(diffusion) 19 | else: 20 | raise NotImplementedError(f"unknown schedule sampler: {name}") 21 | 22 | 23 | class ScheduleSampler(ABC): 24 | """ 25 | A distribution over timesteps in the diffusion process, intended to reduce 26 | variance of the objective. 27 | 28 | By default, samplers perform unbiased importance sampling, in which the 29 | objective's mean is unchanged. 30 | However, subclasses may override sample() to change how the resampled 31 | terms are reweighted, allowing for actual changes in the objective. 32 | """ 33 | 34 | @abstractmethod 35 | def weights(self): 36 | """ 37 | Get a numpy array of weights, one per diffusion step. 38 | 39 | The weights needn't be normalized, but must be positive. 40 | """ 41 | 42 | def sample(self, batch_size, device): 43 | """ 44 | Importance-sample timesteps for a batch. 
45 | 46 | :param batch_size: the number of timesteps. 47 | :param device: the torch device to save to. 48 | :return: a tuple (timesteps, weights): 49 | - timesteps: a tensor of timestep indices. 50 | - weights: a tensor of weights to scale the resulting losses. 51 | """ 52 | w = self.weights() 53 | p = w / np.sum(w) 54 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p) 55 | indices = th.from_numpy(indices_np).long().to(device) 56 | weights_np = 1 / (len(p) * p[indices_np]) 57 | weights = th.from_numpy(weights_np).float().to(device) 58 | return indices, weights 59 | 60 | 61 | class UniformSampler(ScheduleSampler): 62 | def __init__(self, diffusion): 63 | self.diffusion = diffusion 64 | self._weights = np.ones([diffusion.num_timesteps]) 65 | 66 | def weights(self): 67 | return self._weights 68 | 69 | 70 | class LossAwareSampler(ScheduleSampler): 71 | def update_with_local_losses(self, local_ts, local_losses): 72 | """ 73 | Update the reweighting using losses from a model. 74 | 75 | Call this method from each rank with a batch of timesteps and the 76 | corresponding losses for each of those timesteps. 77 | This method will perform synchronization to make sure all of the ranks 78 | maintain the exact same reweighting. 79 | 80 | :param local_ts: an integer Tensor of timesteps. 81 | :param local_losses: a 1D Tensor of losses. 82 | """ 83 | batch_sizes = [ 84 | th.tensor([0], dtype=th.int32, device=local_ts.device) 85 | for _ in range(dist.get_world_size()) 86 | ] 87 | dist.all_gather( 88 | batch_sizes, 89 | th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device), 90 | ) 91 | 92 | # Pad all_gather batches to be the maximum batch size. 
93 | batch_sizes = [x.item() for x in batch_sizes] 94 | max_bs = max(batch_sizes) 95 | 96 | timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes] 97 | loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes] 98 | dist.all_gather(timestep_batches, local_ts) 99 | dist.all_gather(loss_batches, local_losses) 100 | timesteps = [ 101 | x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs] 102 | ] 103 | losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]] 104 | self.update_with_all_losses(timesteps, losses) 105 | 106 | @abstractmethod 107 | def update_with_all_losses(self, ts, losses): 108 | """ 109 | Update the reweighting using losses from a model. 110 | 111 | Sub-classes should override this method to update the reweighting 112 | using losses from the model. 113 | 114 | This method directly updates the reweighting without synchronizing 115 | between workers. It is called by update_with_local_losses from all 116 | ranks with identical arguments. Thus, it should have deterministic 117 | behavior to maintain state across workers. 118 | 119 | :param ts: a list of int timesteps. 120 | :param losses: a list of float losses, one per timestep. 
121 | """ 122 | 123 | 124 | class LossSecondMomentResampler(LossAwareSampler): 125 | def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001): 126 | self.diffusion = diffusion 127 | self.history_per_term = history_per_term 128 | self.uniform_prob = uniform_prob 129 | self._loss_history = np.zeros( 130 | [diffusion.num_timesteps, history_per_term], dtype=np.float64 131 | ) 132 | self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int) 133 | 134 | def weights(self): 135 | if not self._warmed_up(): 136 | return np.ones([self.diffusion.num_timesteps], dtype=np.float64) 137 | weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1)) 138 | weights /= np.sum(weights) 139 | weights *= 1 - self.uniform_prob 140 | weights += self.uniform_prob / len(weights) 141 | return weights 142 | 143 | def update_with_all_losses(self, ts, losses): 144 | for t, loss in zip(ts, losses): 145 | if self._loss_counts[t] == self.history_per_term: 146 | # Shift out the oldest loss term. 147 | self._loss_history[t, :-1] = self._loss_history[t, 1:] 148 | self._loss_history[t, -1] = loss 149 | else: 150 | self._loss_history[t, self._loss_counts[t]] = loss 151 | self._loss_counts[t] += 1 152 | 153 | def _warmed_up(self): 154 | return (self._loss_counts == self.history_per_term).all() 155 | -------------------------------------------------------------------------------- /guided-diffusion/guided_diffusion/respace.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch as th 3 | 4 | from .gaussian_diffusion import GaussianDiffusion 5 | 6 | 7 | def space_timesteps(num_timesteps, section_counts): 8 | """ 9 | Create a list of timesteps to use from an original diffusion process, 10 | given the number of timesteps we want to take from equally-sized portions 11 | of the original process. 
12 | 13 | For example, if there's 300 timesteps and the section counts are [10,15,20] 14 | then the first 100 timesteps are strided to be 10 timesteps, the second 100 15 | are strided to be 15 timesteps, and the final 100 are strided to be 20. 16 | 17 | If the stride is a string starting with "ddim", then the fixed striding 18 | from the DDIM paper is used, and only one section is allowed. 19 | 20 | :param num_timesteps: the number of diffusion steps in the original 21 | process to divide up. 22 | :param section_counts: either a list of numbers, or a string containing 23 | comma-separated numbers, indicating the step count 24 | per section. As a special case, use "ddimN" where N 25 | is a number of steps to use the striding from the 26 | DDIM paper. 27 | :return: a set of diffusion steps from the original process to use. 28 | """ 29 | if isinstance(section_counts, str): 30 | if section_counts.startswith("ddim"): 31 | desired_count = int(section_counts[len("ddim") :]) 32 | for i in range(1, num_timesteps): 33 | if len(range(0, num_timesteps, i)) == desired_count: 34 | return set(range(0, num_timesteps, i)) 35 | raise ValueError( 36 | f"cannot create exactly {num_timesteps} steps with an integer stride" 37 | ) 38 | section_counts = [int(x) for x in section_counts.split(",")] 39 | size_per = num_timesteps // len(section_counts) 40 | extra = num_timesteps % len(section_counts) 41 | start_idx = 0 42 | all_steps = [] 43 | for i, section_count in enumerate(section_counts): 44 | size = size_per + (1 if i < extra else 0) 45 | if size < section_count: 46 | raise ValueError( 47 | f"cannot divide section of {size} steps into {section_count}" 48 | ) 49 | if section_count <= 1: 50 | frac_stride = 1 51 | else: 52 | frac_stride = (size - 1) / (section_count - 1) 53 | cur_idx = 0.0 54 | taken_steps = [] 55 | for _ in range(section_count): 56 | taken_steps.append(start_idx + round(cur_idx)) 57 | cur_idx += frac_stride 58 | all_steps += taken_steps 59 | start_idx += size 60 | return 
set(all_steps) 61 | 62 | 63 | class SpacedDiffusion(GaussianDiffusion): 64 | """ 65 | A diffusion process which can skip steps in a base diffusion process. 66 | 67 | :param use_timesteps: a collection (sequence or set) of timesteps from the 68 | original diffusion process to retain. 69 | :param kwargs: the kwargs to create the base diffusion process. 70 | """ 71 | 72 | def __init__(self, use_timesteps, **kwargs): 73 | self.use_timesteps = set(use_timesteps) 74 | self.timestep_map = [] 75 | self.original_num_steps = len(kwargs["betas"]) 76 | 77 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa 78 | last_alpha_cumprod = 1.0 79 | new_betas = [] 80 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): 81 | if i in self.use_timesteps: 82 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) 83 | last_alpha_cumprod = alpha_cumprod 84 | self.timestep_map.append(i) 85 | kwargs["betas"] = np.array(new_betas) 86 | super().__init__(**kwargs) 87 | 88 | def p_mean_variance( 89 | self, model, *args, **kwargs 90 | ): # pylint: disable=signature-differs 91 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) 92 | 93 | def training_losses( 94 | self, model, *args, **kwargs 95 | ): # pylint: disable=signature-differs 96 | return super().training_losses(self._wrap_model(model), *args, **kwargs) 97 | 98 | def condition_mean(self, cond_fn, *args, **kwargs): 99 | return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) 100 | 101 | def condition_score(self, cond_fn, *args, **kwargs): 102 | return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) 103 | 104 | def _wrap_model(self, model): 105 | if isinstance(model, _WrappedModel): 106 | return model 107 | return _WrappedModel( 108 | model, self.timestep_map, self.rescale_timesteps, self.original_num_steps 109 | ) 110 | 111 | def _scale_timesteps(self, t): 112 | # Scaling is done by the wrapped model. 
113 | return t 114 | 115 | 116 | class _WrappedModel: 117 | def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps): 118 | self.model = model 119 | self.timestep_map = timestep_map 120 | self.rescale_timesteps = rescale_timesteps 121 | self.original_num_steps = original_num_steps 122 | 123 | def __call__(self, x, ts, **kwargs): 124 | map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) 125 | new_ts = map_tensor[ts] 126 | if self.rescale_timesteps: 127 | new_ts = new_ts.float() * (1000.0 / self.original_num_steps) 128 | return self.model(x, new_ts, **kwargs) 129 | -------------------------------------------------------------------------------- /guided-diffusion/guided_diffusion/stega_model.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import nn 4 | from torch.nn.functional import relu, sigmoid 5 | 6 | 7 | class StegaStampEncoder(nn.Module): 8 | def __init__( 9 | self, 10 | resolution=32, 11 | IMAGE_CHANNELS=1, 12 | fingerprint_size=100, 13 | return_residual=False, 14 | ): 15 | super(StegaStampEncoder, self).__init__() 16 | self.fingerprint_size = fingerprint_size 17 | self.IMAGE_CHANNELS = IMAGE_CHANNELS 18 | self.return_residual = return_residual 19 | self.secret_dense = nn.Linear(self.fingerprint_size, 64 * 64 * IMAGE_CHANNELS) 20 | 21 | log_resolution = int(math.log(resolution, 2)) 22 | assert resolution == 2 ** log_resolution, f"Image resolution must be a power of 2, got {resolution}." 
23 | 24 | self.fingerprint_upsample = nn.Upsample(scale_factor=(2**(log_resolution-6), 2**(log_resolution-6))) 25 | self.conv1 = nn.Conv2d(2 * IMAGE_CHANNELS, 32, 3, 1, 1) 26 | self.conv2 = nn.Conv2d(32, 32, 3, 2, 1) 27 | self.conv3 = nn.Conv2d(32, 64, 3, 2, 1) 28 | self.conv4 = nn.Conv2d(64, 128, 3, 2, 1) 29 | self.conv5 = nn.Conv2d(128, 256, 3, 2, 1) 30 | self.pad6 = nn.ZeroPad2d((0, 1, 0, 1)) 31 | self.up6 = nn.Conv2d(256, 128, 2, 1) 32 | self.upsample6 = nn.Upsample(scale_factor=(2, 2)) 33 | self.conv6 = nn.Conv2d(128 + 128, 128, 3, 1, 1) 34 | self.pad7 = nn.ZeroPad2d((0, 1, 0, 1)) 35 | self.up7 = nn.Conv2d(128, 64, 2, 1) 36 | self.upsample7 = nn.Upsample(scale_factor=(2, 2)) 37 | self.conv7 = nn.Conv2d(64 + 64, 64, 3, 1, 1) 38 | self.pad8 = nn.ZeroPad2d((0, 1, 0, 1)) 39 | self.up8 = nn.Conv2d(64, 32, 2, 1) 40 | self.upsample8 = nn.Upsample(scale_factor=(2, 2)) 41 | self.conv8 = nn.Conv2d(32 + 32, 32, 3, 1, 1) 42 | self.pad9 = nn.ZeroPad2d((0, 1, 0, 1)) 43 | self.up9 = nn.Conv2d(32, 32, 2, 1) 44 | self.upsample9 = nn.Upsample(scale_factor=(2, 2)) 45 | self.conv9 = nn.Conv2d(32 + 32 + 2 * IMAGE_CHANNELS, 32, 3, 1, 1) 46 | self.conv10 = nn.Conv2d(32, 32, 3, 1, 1) 47 | self.residual = nn.Conv2d(32, IMAGE_CHANNELS, 1) 48 | 49 | def forward(self, fingerprint, image): 50 | fingerprint = relu(self.secret_dense(fingerprint)) 51 | fingerprint = fingerprint.view((-1, self.IMAGE_CHANNELS, 64, 64)) 52 | fingerprint_enlarged = self.fingerprint_upsample(fingerprint) 53 | inputs = torch.cat([fingerprint_enlarged, image], dim=1) 54 | conv1 = relu(self.conv1(inputs)) 55 | conv2 = relu(self.conv2(conv1)) 56 | conv3 = relu(self.conv3(conv2)) 57 | conv4 = relu(self.conv4(conv3)) 58 | conv5 = relu(self.conv5(conv4)) 59 | up6 = relu(self.up6(self.pad6(self.upsample6(conv5)))) 60 | merge6 = torch.cat([conv4, up6], dim=1) 61 | conv6 = relu(self.conv6(merge6)) 62 | up7 = relu(self.up7(self.pad7(self.upsample7(conv6)))) 63 | merge7 = torch.cat([conv3, up7], dim=1) 64 | conv7 = 
relu(self.conv7(merge7)) 65 | up8 = relu(self.up8(self.pad8(self.upsample8(conv7)))) 66 | merge8 = torch.cat([conv2, up8], dim=1) 67 | conv8 = relu(self.conv8(merge8)) 68 | up9 = relu(self.up9(self.pad9(self.upsample9(conv8)))) 69 | merge9 = torch.cat([conv1, up9, inputs], dim=1) 70 | conv9 = relu(self.conv9(merge9)) 71 | conv10 = relu(self.conv10(conv9)) 72 | residual = self.residual(conv10) 73 | if not self.return_residual: 74 | residual = sigmoid(residual) 75 | return residual 76 | 77 | 78 | class StegaStampDecoder(nn.Module): 79 | def __init__(self, resolution=32, IMAGE_CHANNELS=1, fingerprint_size=1): 80 | super(StegaStampDecoder, self).__init__() 81 | self.resolution = resolution 82 | self.IMAGE_CHANNELS = IMAGE_CHANNELS 83 | self.fingerprint_size=fingerprint_size 84 | self.decoder = nn.Sequential( 85 | nn.Conv2d(IMAGE_CHANNELS, 32, (3, 3), 2, 1), # 16 86 | nn.ReLU(), 87 | nn.Conv2d(32, 32, 3, 1, 1), 88 | nn.ReLU(), 89 | nn.Conv2d(32, 64, 3, 2, 1), # 8 90 | nn.ReLU(), 91 | nn.Conv2d(64, 64, 3, 1, 1), 92 | nn.ReLU(), 93 | nn.Conv2d(64, 64, 3, 2, 1), # 4 94 | nn.ReLU(), 95 | nn.Conv2d(64, 128, 3, 2, 1), # 2 96 | nn.ReLU(), 97 | nn.Conv2d(128, 128, (3, 3), 2, 1), 98 | nn.ReLU(), 99 | ) 100 | self.dense = nn.Sequential( 101 | nn.Linear(resolution * resolution * 128 // 32 // 32, 512), 102 | nn.ReLU(), 103 | nn.Linear(512, fingerprint_size), 104 | ) 105 | 106 | def forward(self, image): 107 | x = self.decoder(image) 108 | x = x.view(-1, self.resolution * self.resolution * 128 // 32 // 32) 109 | return self.dense(x) 110 | -------------------------------------------------------------------------------- /guided-diffusion/scripts/classifier_sample.py: -------------------------------------------------------------------------------- 1 | """ 2 | Like image_sample.py, but use a noisy image classifier to guide the sampling 3 | process towards more realistic images. 
4 | """ 5 | 6 | import argparse 7 | import os 8 | 9 | import numpy as np 10 | import torch as th 11 | import torch.distributed as dist 12 | import torch.nn.functional as F 13 | 14 | from guided_diffusion import dist_util, logger 15 | from guided_diffusion.script_util import ( 16 | NUM_CLASSES, 17 | model_and_diffusion_defaults, 18 | classifier_defaults, 19 | create_model_and_diffusion, 20 | create_classifier, 21 | add_dict_to_argparser, 22 | args_to_dict, 23 | ) 24 | 25 | 26 | def main(): 27 | args = create_argparser().parse_args() 28 | 29 | dist_util.setup_dist() 30 | logger.configure() 31 | 32 | logger.log("creating model and diffusion...") 33 | model, diffusion = create_model_and_diffusion( 34 | **args_to_dict(args, model_and_diffusion_defaults().keys()) 35 | ) 36 | model.load_state_dict( 37 | dist_util.load_state_dict(args.model_path, map_location="cpu") 38 | ) 39 | model.to(dist_util.dev()) 40 | if args.use_fp16: 41 | model.convert_to_fp16() 42 | model.eval() 43 | 44 | logger.log("loading classifier...") 45 | classifier = create_classifier(**args_to_dict(args, classifier_defaults().keys())) 46 | classifier.load_state_dict( 47 | dist_util.load_state_dict(args.classifier_path, map_location="cpu") 48 | ) 49 | classifier.to(dist_util.dev()) 50 | if args.classifier_use_fp16: 51 | classifier.convert_to_fp16() 52 | classifier.eval() 53 | 54 | def cond_fn(x, t, y=None): 55 | assert y is not None 56 | with th.enable_grad(): 57 | x_in = x.detach().requires_grad_(True) 58 | logits = classifier(x_in, t) 59 | log_probs = F.log_softmax(logits, dim=-1) 60 | selected = log_probs[range(len(logits)), y.view(-1)] 61 | return th.autograd.grad(selected.sum(), x_in)[0] * args.classifier_scale 62 | 63 | def model_fn(x, t, y=None): 64 | assert y is not None 65 | return model(x, t, y if args.class_cond else None) 66 | 67 | logger.log("sampling...") 68 | all_images = [] 69 | all_labels = [] 70 | while len(all_images) * args.batch_size < args.num_samples: 71 | model_kwargs = {} 72 | 
classes = th.randint( 73 | low=0, high=NUM_CLASSES, size=(args.batch_size,), device=dist_util.dev() 74 | ) 75 | model_kwargs["y"] = classes 76 | sample_fn = ( 77 | diffusion.p_sample_loop if not args.use_ddim else diffusion.ddim_sample_loop 78 | ) 79 | sample = sample_fn( 80 | model_fn, 81 | (args.batch_size, 3, args.image_size, args.image_size), 82 | clip_denoised=args.clip_denoised, 83 | model_kwargs=model_kwargs, 84 | cond_fn=cond_fn, 85 | device=dist_util.dev(), 86 | ) 87 | sample = ((sample + 1) * 127.5).clamp(0, 255).to(th.uint8) 88 | sample = sample.permute(0, 2, 3, 1) 89 | sample = sample.contiguous() 90 | 91 | gathered_samples = [th.zeros_like(sample) for _ in range(dist.get_world_size())] 92 | dist.all_gather(gathered_samples, sample) # gather not supported with NCCL 93 | all_images.extend([sample.cpu().numpy() for sample in gathered_samples]) 94 | gathered_labels = [th.zeros_like(classes) for _ in range(dist.get_world_size())] 95 | dist.all_gather(gathered_labels, classes) 96 | all_labels.extend([labels.cpu().numpy() for labels in gathered_labels]) 97 | logger.log(f"created {len(all_images) * args.batch_size} samples") 98 | 99 | arr = np.concatenate(all_images, axis=0) 100 | arr = arr[: args.num_samples] 101 | label_arr = np.concatenate(all_labels, axis=0) 102 | label_arr = label_arr[: args.num_samples] 103 | if dist.get_rank() == 0: 104 | shape_str = "x".join([str(x) for x in arr.shape]) 105 | out_path = os.path.join(logger.get_dir(), f"samples_{shape_str}.npz") 106 | logger.log(f"saving to {out_path}") 107 | np.savez(out_path, arr, label_arr) 108 | 109 | dist.barrier() 110 | logger.log("sampling complete") 111 | 112 | 113 | def create_argparser(): 114 | defaults = dict( 115 | clip_denoised=True, 116 | num_samples=10000, 117 | batch_size=16, 118 | use_ddim=False, 119 | model_path="", 120 | classifier_path="", 121 | classifier_scale=1.0, 122 | ) 123 | defaults.update(model_and_diffusion_defaults()) 124 | defaults.update(classifier_defaults()) 125 | 
parser = argparse.ArgumentParser() 126 | add_dict_to_argparser(parser, defaults) 127 | return parser 128 | 129 | 130 | if __name__ == "__main__": 131 | main() 132 | -------------------------------------------------------------------------------- /guided-diffusion/scripts/image_nll.py: -------------------------------------------------------------------------------- 1 | """ 2 | Approximate the bits/dimension for an image model. 3 | """ 4 | 5 | import argparse 6 | import os 7 | 8 | import numpy as np 9 | import torch.distributed as dist 10 | 11 | from guided_diffusion import dist_util, logger 12 | from guided_diffusion.image_datasets import load_data 13 | from guided_diffusion.script_util import ( 14 | model_and_diffusion_defaults, 15 | create_model_and_diffusion, 16 | add_dict_to_argparser, 17 | args_to_dict, 18 | ) 19 | 20 | 21 | def main(): 22 | args = create_argparser().parse_args() 23 | 24 | dist_util.setup_dist() 25 | logger.configure() 26 | 27 | logger.log("creating model and diffusion...") 28 | model, diffusion = create_model_and_diffusion( 29 | **args_to_dict(args, model_and_diffusion_defaults().keys()) 30 | ) 31 | model.load_state_dict( 32 | dist_util.load_state_dict(args.model_path, map_location="cpu") 33 | ) 34 | model.to(dist_util.dev()) 35 | model.eval() 36 | 37 | logger.log("creating data loader...") 38 | data = load_data( 39 | data_dir=args.data_dir, 40 | batch_size=args.batch_size, 41 | image_size=args.image_size, 42 | class_cond=args.class_cond, 43 | deterministic=True, 44 | ) 45 | 46 | logger.log("evaluating...") 47 | run_bpd_evaluation(model, diffusion, data, args.num_samples, args.clip_denoised) 48 | 49 | 50 | def run_bpd_evaluation(model, diffusion, data, num_samples, clip_denoised): 51 | all_bpd = [] 52 | all_metrics = {"vb": [], "mse": [], "xstart_mse": []} 53 | num_complete = 0 54 | while num_complete < num_samples: 55 | batch, model_kwargs = next(data) 56 | batch = batch.to(dist_util.dev()) 57 | model_kwargs = {k: v.to(dist_util.dev()) for k, 
v in model_kwargs.items()} 58 | minibatch_metrics = diffusion.calc_bpd_loop( 59 | model, batch, clip_denoised=clip_denoised, model_kwargs=model_kwargs 60 | ) 61 | 62 | for key, term_list in all_metrics.items(): 63 | terms = minibatch_metrics[key].mean(dim=0) / dist.get_world_size() 64 | dist.all_reduce(terms) 65 | term_list.append(terms.detach().cpu().numpy()) 66 | 67 | total_bpd = minibatch_metrics["total_bpd"] 68 | total_bpd = total_bpd.mean() / dist.get_world_size() 69 | dist.all_reduce(total_bpd) 70 | all_bpd.append(total_bpd.item()) 71 | num_complete += dist.get_world_size() * batch.shape[0] 72 | 73 | logger.log(f"done {num_complete} samples: bpd={np.mean(all_bpd)}") 74 | 75 | if dist.get_rank() == 0: 76 | for name, terms in all_metrics.items(): 77 | out_path = os.path.join(logger.get_dir(), f"{name}_terms.npz") 78 | logger.log(f"saving {name} terms to {out_path}") 79 | np.savez(out_path, np.mean(np.stack(terms), axis=0)) 80 | 81 | dist.barrier() 82 | logger.log("evaluation complete") 83 | 84 | 85 | def create_argparser(): 86 | defaults = dict( 87 | data_dir="", clip_denoised=True, num_samples=1000, batch_size=1, model_path="" 88 | ) 89 | defaults.update(model_and_diffusion_defaults()) 90 | parser = argparse.ArgumentParser() 91 | add_dict_to_argparser(parser, defaults) 92 | return parser 93 | 94 | 95 | if __name__ == "__main__": 96 | main() 97 | -------------------------------------------------------------------------------- /guided-diffusion/scripts/image_sample.py: -------------------------------------------------------------------------------- 1 | """ 2 | Generate a large batch of image samples from a model and save them as a large 3 | numpy array. This can be used to produce samples for FID evaluation. 
4 | """ 5 | 6 | import argparse 7 | import sys 8 | import os 9 | dir_path = os.path.dirname(os.path.realpath(__file__)) 10 | parent_dir_path = os.path.abspath(os.path.join(dir_path, os.pardir)) 11 | sys.path.insert(0, parent_dir_path) 12 | 13 | import numpy as np 14 | import torch as th 15 | import torch.distributed as dist 16 | from PIL import Image 17 | 18 | from guided_diffusion import dist_util, logger 19 | from guided_diffusion.script_util import ( 20 | NUM_CLASSES, 21 | model_and_diffusion_defaults, 22 | create_model_and_diffusion, 23 | add_dict_to_argparser, 24 | args_to_dict, 25 | ) 26 | 27 | 28 | 29 | 30 | def save_images(images, output_path): 31 | for i in range(images.shape[0]): 32 | image_array = np.uint8(images[i]) 33 | image = Image.fromarray(image_array) 34 | image.save(os.path.join(output_path, f'image_{i}.png')) 35 | 36 | def main(): 37 | args = create_argparser().parse_args() 38 | 39 | dist_util.setup_dist() 40 | logger.configure() 41 | 42 | logger.log("creating model and diffusion...") 43 | 44 | if not os.path.exists(f'watermark_pool/{args.wm_length}_1e4.npy'): 45 | os.makedirs('watermark_pool', exist_ok=True) 46 | np.save(f'watermark_pool/{args.wm_length}_1e4.npy', np.random.randint(0, 2, size=(int(1e4), args.wm_length))) 47 | np.save(f'watermark_pool/{args.wm_length}_1e5.npy', np.random.randint(0, 2, size=(int(1e5), args.wm_length))) 48 | np.save(f'watermark_pool/{args.wm_length}_1e6.npy', np.random.randint(0, 2, size=(int(1e6), args.wm_length))) 49 | 50 | keys = np.load(f'watermark_pool/{args.wm_length}_1e4.npy')[:1000] 51 | 52 | model, diffusion = create_model_and_diffusion( 53 | **args_to_dict(args, model_and_diffusion_defaults().keys()), wm_length=args.wm_length 54 | ) 55 | model.load_state_dict( 56 | dist_util.load_state_dict(args.model_path, map_location="cpu") 57 | ) 58 | model.to(dist_util.dev()) 59 | if args.use_fp16: 60 | model.convert_to_fp16() 61 | model.eval() 62 | 63 | logger.log("sampling...") 64 | all_images = [] 65 | 66 | # 
while len(all_images) * args.batch_size < args.num_samples: 67 | for idx, key in enumerate(keys): 68 | if args.wm_length < 0: 69 | key = None 70 | else: 71 | key = th.from_numpy(key).to(dist_util.dev()).float() 72 | key = key.repeat(args.batch_size, 1) 73 | 74 | model_kwargs = {} 75 | if args.class_cond: 76 | classes = th.randint( 77 | low=0, high=NUM_CLASSES, size=(args.batch_size,), device=dist_util.dev() 78 | ) 79 | model_kwargs["y"] = classes 80 | sample_fn = ( 81 | diffusion.p_sample_loop if not args.use_ddim else diffusion.ddim_sample_loop 82 | ) 83 | sample = sample_fn( 84 | model, 85 | (args.batch_size, 3, args.image_size, args.image_size), 86 | clip_denoised=args.clip_denoised, 87 | model_kwargs=model_kwargs, 88 | key=key, 89 | ) 90 | sample = ((sample + 1) * 127.5).clamp(0, 255).to(th.uint8) 91 | sample = sample.permute(0, 2, 3, 1) 92 | sample = sample.contiguous() 93 | 94 | gathered_samples = [th.zeros_like(sample) for _ in range(dist.get_world_size())] 95 | dist.all_gather(gathered_samples, sample) # gather not supported with NCCL 96 | # all_images.extend([sample.cpu().numpy() for sample in gathered_samples]) 97 | all_images = [sample.cpu().numpy() for sample in gathered_samples] 98 | if args.class_cond: 99 | gathered_labels = [ 100 | th.zeros_like(classes) for _ in range(dist.get_world_size()) 101 | ] 102 | dist.all_gather(gathered_labels, classes) 103 | 104 | logger.log(f"created {len(all_images) * args.batch_size} samples") 105 | 106 | 107 | arr = np.concatenate(all_images, axis=0) 108 | arr = arr[: args.num_samples] 109 | os.makedirs(os.path.join(args.output_path, f'{idx}/'), exist_ok=True) 110 | save_images(arr, os.path.join(args.output_path, f'{idx}/')) 111 | 112 | logger.log(f"saving to {os.path.join(args.output_path, f'{idx}/')}") 113 | 114 | 115 | dist.barrier() 116 | logger.log("sampling complete") 117 | 118 | 119 | def create_argparser(): 120 | defaults = dict( 121 | clip_denoised=True, 122 | num_samples=10, 123 | batch_size=16, 124 | 
use_ddim=False, 125 | model_path="", 126 | output_path='saved_images/', 127 | wm_length=48, 128 | ) 129 | defaults.update(model_and_diffusion_defaults()) 130 | parser = argparse.ArgumentParser() 131 | add_dict_to_argparser(parser, defaults) 132 | return parser 133 | 134 | 135 | if __name__ == "__main__": 136 | main() 137 | -------------------------------------------------------------------------------- /guided-diffusion/scripts/image_train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Train a diffusion model on images. 3 | """ 4 | 5 | import argparse 6 | import sys 7 | import os 8 | dir_path = os.path.dirname(os.path.realpath(__file__)) 9 | parent_dir_path = os.path.abspath(os.path.join(dir_path, os.pardir)) 10 | sys.path.insert(0, parent_dir_path) 11 | 12 | from guided_diffusion import dist_util, logger 13 | from guided_diffusion.image_datasets import load_data 14 | from guided_diffusion.resample import create_named_schedule_sampler 15 | from guided_diffusion.script_util import ( 16 | model_and_diffusion_defaults, 17 | create_model_and_diffusion, 18 | args_to_dict, 19 | add_dict_to_argparser, 20 | ) 21 | from guided_diffusion.train_util import TrainLoop 22 | from guided_diffusion.stega_model import StegaStampDecoder 23 | import torch as th 24 | 25 | def main(): 26 | args = create_argparser().parse_args() 27 | 28 | dist_util.setup_dist() 29 | logger.configure() 30 | 31 | logger.log("creating model and diffusion...") 32 | # Change model architecture 33 | 34 | model, diffusion = create_model_and_diffusion( 35 | **args_to_dict(args, model_and_diffusion_defaults().keys()), wm_length=args.wm_length 36 | ) 37 | 38 | # Create the original model architecture 39 | ori_model, _ = create_model_and_diffusion( 40 | **args_to_dict(args, model_and_diffusion_defaults().keys()), wm_length=0 41 | ) 42 | for param in ori_model.parameters(): 43 | param.requires_grad = False 44 | 45 | ori_model.eval() 46 | ori_model.to(dist_util.dev()) 
47 | 48 | model.to(dist_util.dev()) 49 | schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion) 50 | 51 | # Define watermark decoder 52 | if args.wm_length > 0 and isinstance(args.wm_length, int): 53 | wm_decoder = StegaStampDecoder( 54 | args.image_size, 55 | 3, 56 | args.wm_length, 57 | ) 58 | # wm_decoder.load_state_dict(th.load(args.wm_decoder_path, map_location='cpu')).eval() 59 | wm_decoder.to(dist_util.dev()) 60 | else: 61 | wm_decoder = None 62 | 63 | logger.log("creating data loader...") 64 | data = load_data( 65 | data_dir=args.data_dir, 66 | batch_size=args.batch_size, 67 | image_size=args.image_size, 68 | class_cond=args.class_cond, 69 | ) 70 | 71 | logger.log("training...") 72 | TrainLoop( 73 | model=model, 74 | diffusion=diffusion, 75 | data=data, 76 | batch_size=args.batch_size, 77 | microbatch=args.microbatch, 78 | lr=args.lr, 79 | ema_rate=args.ema_rate, 80 | log_interval=args.log_interval, 81 | save_interval=args.save_interval, 82 | resume_checkpoint=args.resume_checkpoint, 83 | use_fp16=args.use_fp16, 84 | fp16_scale_growth=args.fp16_scale_growth, 85 | schedule_sampler=schedule_sampler, 86 | weight_decay=args.weight_decay, 87 | lr_anneal_steps=args.lr_anneal_steps, 88 | ori_model=ori_model, 89 | wm_length=args.wm_length, 90 | alpha=args.alpha, 91 | threshold=args.threshold, 92 | wm_decoder=wm_decoder 93 | ).run_loop() 94 | 95 | 96 | def create_argparser(): 97 | defaults = dict( 98 | data_dir="", 99 | schedule_sampler="uniform", 100 | lr=1e-4, 101 | weight_decay=0.0, 102 | lr_anneal_steps=0, 103 | batch_size=1, 104 | microbatch=-1, # -1 disables microbatches 105 | ema_rate="0.9999", # comma-separated list of EMA values 106 | log_interval=10, 107 | save_interval=10000, 108 | resume_checkpoint="", 109 | use_fp16=False, 110 | fp16_scale_growth=1e-3, 111 | wm_length=48, 112 | alpha=0.4, 113 | threshold=400, 114 | wm_decoder_path='./', 115 | ) 116 | defaults.update(model_and_diffusion_defaults()) 117 | parser = 
argparse.ArgumentParser() 118 | add_dict_to_argparser(parser, defaults) 119 | return parser 120 | 121 | 122 | if __name__ == "__main__": 123 | main() 124 | -------------------------------------------------------------------------------- /guided-diffusion/scripts/super_res_sample.py: -------------------------------------------------------------------------------- 1 | """ 2 | Generate a large batch of samples from a super resolution model, given a batch 3 | of samples from a regular model from image_sample.py. 4 | """ 5 | 6 | import argparse 7 | import os 8 | 9 | import blobfile as bf 10 | import numpy as np 11 | import torch as th 12 | import torch.distributed as dist 13 | 14 | from guided_diffusion import dist_util, logger 15 | from guided_diffusion.script_util import ( 16 | sr_model_and_diffusion_defaults, 17 | sr_create_model_and_diffusion, 18 | args_to_dict, 19 | add_dict_to_argparser, 20 | ) 21 | 22 | 23 | def main(): 24 | args = create_argparser().parse_args() 25 | 26 | dist_util.setup_dist() 27 | logger.configure() 28 | 29 | logger.log("creating model...") 30 | model, diffusion = sr_create_model_and_diffusion( 31 | **args_to_dict(args, sr_model_and_diffusion_defaults().keys()) 32 | ) 33 | model.load_state_dict( 34 | dist_util.load_state_dict(args.model_path, map_location="cpu") 35 | ) 36 | model.to(dist_util.dev()) 37 | if args.use_fp16: 38 | model.convert_to_fp16() 39 | model.eval() 40 | 41 | logger.log("loading data...") 42 | data = load_data_for_worker(args.base_samples, args.batch_size, args.class_cond) 43 | 44 | logger.log("creating samples...") 45 | all_images = [] 46 | while len(all_images) * args.batch_size < args.num_samples: 47 | model_kwargs = next(data) 48 | model_kwargs = {k: v.to(dist_util.dev()) for k, v in model_kwargs.items()} 49 | sample = diffusion.p_sample_loop( 50 | model, 51 | (args.batch_size, 3, args.large_size, args.large_size), 52 | clip_denoised=args.clip_denoised, 53 | model_kwargs=model_kwargs, 54 | ) 55 | sample = ((sample + 1) * 
127.5).clamp(0, 255).to(th.uint8) 56 | sample = sample.permute(0, 2, 3, 1) 57 | sample = sample.contiguous() 58 | 59 | all_samples = [th.zeros_like(sample) for _ in range(dist.get_world_size())] 60 | dist.all_gather(all_samples, sample) # gather not supported with NCCL 61 | for sample in all_samples: 62 | all_images.append(sample.cpu().numpy()) 63 | logger.log(f"created {len(all_images) * args.batch_size} samples") 64 | 65 | arr = np.concatenate(all_images, axis=0) 66 | arr = arr[: args.num_samples] 67 | if dist.get_rank() == 0: 68 | shape_str = "x".join([str(x) for x in arr.shape]) 69 | out_path = os.path.join(logger.get_dir(), f"samples_{shape_str}.npz") 70 | logger.log(f"saving to {out_path}") 71 | np.savez(out_path, arr) 72 | 73 | dist.barrier() 74 | logger.log("sampling complete") 75 | 76 | 77 | def load_data_for_worker(base_samples, batch_size, class_cond): 78 | with bf.BlobFile(base_samples, "rb") as f: 79 | obj = np.load(f) 80 | image_arr = obj["arr_0"] 81 | if class_cond: 82 | label_arr = obj["arr_1"] 83 | rank = dist.get_rank() 84 | num_ranks = dist.get_world_size() 85 | buffer = [] 86 | label_buffer = [] 87 | while True: 88 | for i in range(rank, len(image_arr), num_ranks): 89 | buffer.append(image_arr[i]) 90 | if class_cond: 91 | label_buffer.append(label_arr[i]) 92 | if len(buffer) == batch_size: 93 | batch = th.from_numpy(np.stack(buffer)).float() 94 | batch = batch / 127.5 - 1.0 95 | batch = batch.permute(0, 3, 1, 2) 96 | res = dict(low_res=batch) 97 | if class_cond: 98 | res["y"] = th.from_numpy(np.stack(label_buffer)) 99 | yield res 100 | buffer, label_buffer = [], [] 101 | 102 | 103 | def create_argparser(): 104 | defaults = dict( 105 | clip_denoised=True, 106 | num_samples=10000, 107 | batch_size=16, 108 | use_ddim=False, 109 | base_samples="", 110 | model_path="", 111 | ) 112 | defaults.update(sr_model_and_diffusion_defaults()) 113 | parser = argparse.ArgumentParser() 114 | add_dict_to_argparser(parser, defaults) 115 | return parser 116 | 117 | 
118 | if __name__ == "__main__": 119 | main() 120 | -------------------------------------------------------------------------------- /guided-diffusion/scripts/super_res_train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Train a super-resolution model. 3 | """ 4 | 5 | import argparse 6 | 7 | import torch.nn.functional as F 8 | 9 | from guided_diffusion import dist_util, logger 10 | from guided_diffusion.image_datasets import load_data 11 | from guided_diffusion.resample import create_named_schedule_sampler 12 | from guided_diffusion.script_util import ( 13 | sr_model_and_diffusion_defaults, 14 | sr_create_model_and_diffusion, 15 | args_to_dict, 16 | add_dict_to_argparser, 17 | ) 18 | from guided_diffusion.train_util import TrainLoop 19 | 20 | 21 | def main(): 22 | args = create_argparser().parse_args() 23 | 24 | dist_util.setup_dist() 25 | logger.configure() 26 | 27 | logger.log("creating model...") 28 | model, diffusion = sr_create_model_and_diffusion( 29 | **args_to_dict(args, sr_model_and_diffusion_defaults().keys()) 30 | ) 31 | model.to(dist_util.dev()) 32 | schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion) 33 | 34 | logger.log("creating data loader...") 35 | data = load_superres_data( 36 | args.data_dir, 37 | args.batch_size, 38 | large_size=args.large_size, 39 | small_size=args.small_size, 40 | class_cond=args.class_cond, 41 | ) 42 | 43 | logger.log("training...") 44 | TrainLoop( 45 | model=model, 46 | diffusion=diffusion, 47 | data=data, 48 | batch_size=args.batch_size, 49 | microbatch=args.microbatch, 50 | lr=args.lr, 51 | ema_rate=args.ema_rate, 52 | log_interval=args.log_interval, 53 | save_interval=args.save_interval, 54 | resume_checkpoint=args.resume_checkpoint, 55 | use_fp16=args.use_fp16, 56 | fp16_scale_growth=args.fp16_scale_growth, 57 | schedule_sampler=schedule_sampler, 58 | weight_decay=args.weight_decay, 59 | lr_anneal_steps=args.lr_anneal_steps, 60 | 
).run_loop() 61 | 62 | 63 | def load_superres_data(data_dir, batch_size, large_size, small_size, class_cond=False): 64 | data = load_data( 65 | data_dir=data_dir, 66 | batch_size=batch_size, 67 | image_size=large_size, 68 | class_cond=class_cond, 69 | ) 70 | for large_batch, model_kwargs in data: 71 | model_kwargs["low_res"] = F.interpolate(large_batch, small_size, mode="area") 72 | yield large_batch, model_kwargs 73 | 74 | 75 | def create_argparser(): 76 | defaults = dict( 77 | data_dir="", 78 | schedule_sampler="uniform", 79 | lr=1e-4, 80 | weight_decay=0.0, 81 | lr_anneal_steps=0, 82 | batch_size=1, 83 | microbatch=-1, 84 | ema_rate="0.9999", 85 | log_interval=10, 86 | save_interval=10000, 87 | resume_checkpoint="", 88 | use_fp16=False, 89 | fp16_scale_growth=1e-3, 90 | ) 91 | defaults.update(sr_model_and_diffusion_defaults()) 92 | parser = argparse.ArgumentParser() 93 | add_dict_to_argparser(parser, defaults) 94 | return parser 95 | 96 | 97 | if __name__ == "__main__": 98 | main() 99 | -------------------------------------------------------------------------------- /guided-diffusion/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup( 4 | name="guided-diffusion", 5 | py_modules=["guided_diffusion"], 6 | install_requires=["blobfile>=1.0.5", "torch", "tqdm"], 7 | ) 8 | -------------------------------------------------------------------------------- /guided-diffusion/train.sh: -------------------------------------------------------------------------------- 1 | MODEL_FLAGS="--wm_length 48 --attention_resolutions 32,16,8 --class_cond False --image_size 256 --num_channels 256 --learn_sigma True --num_head_channels 64 --num_res_blocks 2 --resblock_updown True" 2 | DIFFUSION_FLAGS="--diffusion_steps 1000 --noise_schedule linear" 3 | TRAIN_FLAGS="--lr 1e-4 --batch_size 4" 4 | NUM_GPUS=1 5 | mpiexec -n $NUM_GPUS python scripts/image_train.py --alpha 0.4 --threshold 400 --wm_decoder_path 
./ --data_dir ../../../../data/imagenet/val --resume_checkpoint models/256x256_diffusion_uncond.pt $MODEL_FLAGS $DIFFUSION_FLAGS $TRAIN_FLAGS 6 | 7 | -------------------------------------------------------------------------------- /pics/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmin2000/WaDiff/33e657de73c3e75e82ea66bcec14012a2f4fcb6a/pics/framework.png -------------------------------------------------------------------------------- /stable-diffusion/configs/autoencoder/autoencoder_kl_16x16x16.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-6 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: "val/rec_loss" 6 | embed_dim: 16 7 | lossconfig: 8 | target: ldm.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 0.000001 12 | disc_weight: 0.5 13 | 14 | ddconfig: 15 | double_z: True 16 | z_channels: 16 17 | resolution: 256 18 | in_channels: 3 19 | out_ch: 3 20 | ch: 128 21 | ch_mult: [ 1,1,2,2,4] # num_down = len(ch_mult)-1 22 | num_res_blocks: 2 23 | attn_resolutions: [16] 24 | dropout: 0.0 25 | 26 | 27 | data: 28 | target: main.DataModuleFromConfig 29 | params: 30 | batch_size: 12 31 | wrap: True 32 | train: 33 | target: ldm.data.imagenet.ImageNetSRTrain 34 | params: 35 | size: 256 36 | degradation: pil_nearest 37 | validation: 38 | target: ldm.data.imagenet.ImageNetSRValidation 39 | params: 40 | size: 256 41 | degradation: pil_nearest 42 | 43 | lightning: 44 | callbacks: 45 | image_logger: 46 | target: main.ImageLogger 47 | params: 48 | batch_frequency: 1000 49 | max_images: 8 50 | increase_log_steps: True 51 | 52 | trainer: 53 | benchmark: True 54 | accumulate_grad_batches: 2 55 | -------------------------------------------------------------------------------- /stable-diffusion/configs/autoencoder/autoencoder_kl_32x32x4.yaml: 
-------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-6 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: "val/rec_loss" 6 | embed_dim: 4 7 | lossconfig: 8 | target: ldm.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 0.000001 12 | disc_weight: 0.5 13 | 14 | ddconfig: 15 | double_z: True 16 | z_channels: 4 17 | resolution: 256 18 | in_channels: 3 19 | out_ch: 3 20 | ch: 128 21 | ch_mult: [ 1,2,4,4 ] # num_down = len(ch_mult)-1 22 | num_res_blocks: 2 23 | attn_resolutions: [ ] 24 | dropout: 0.0 25 | 26 | data: 27 | target: main.DataModuleFromConfig 28 | params: 29 | batch_size: 12 30 | wrap: True 31 | train: 32 | target: ldm.data.imagenet.ImageNetSRTrain 33 | params: 34 | size: 256 35 | degradation: pil_nearest 36 | validation: 37 | target: ldm.data.imagenet.ImageNetSRValidation 38 | params: 39 | size: 256 40 | degradation: pil_nearest 41 | 42 | lightning: 43 | callbacks: 44 | image_logger: 45 | target: main.ImageLogger 46 | params: 47 | batch_frequency: 1000 48 | max_images: 8 49 | increase_log_steps: True 50 | 51 | trainer: 52 | benchmark: True 53 | accumulate_grad_batches: 2 54 | -------------------------------------------------------------------------------- /stable-diffusion/configs/autoencoder/autoencoder_kl_64x64x3.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-6 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: "val/rec_loss" 6 | embed_dim: 3 7 | lossconfig: 8 | target: ldm.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 0.000001 12 | disc_weight: 0.5 13 | 14 | ddconfig: 15 | double_z: True 16 | z_channels: 3 17 | resolution: 256 18 | in_channels: 3 19 | out_ch: 3 20 | ch: 128 21 | ch_mult: [ 1,2,4 ] # num_down = len(ch_mult)-1 22 | num_res_blocks: 2 23 | attn_resolutions: [ ] 24 | 
dropout: 0.0 25 | 26 | 27 | data: 28 | target: main.DataModuleFromConfig 29 | params: 30 | batch_size: 12 31 | wrap: True 32 | train: 33 | target: ldm.data.imagenet.ImageNetSRTrain 34 | params: 35 | size: 256 36 | degradation: pil_nearest 37 | validation: 38 | target: ldm.data.imagenet.ImageNetSRValidation 39 | params: 40 | size: 256 41 | degradation: pil_nearest 42 | 43 | lightning: 44 | callbacks: 45 | image_logger: 46 | target: main.ImageLogger 47 | params: 48 | batch_frequency: 1000 49 | max_images: 8 50 | increase_log_steps: True 51 | 52 | trainer: 53 | benchmark: True 54 | accumulate_grad_batches: 2 55 | -------------------------------------------------------------------------------- /stable-diffusion/configs/autoencoder/autoencoder_kl_8x8x64.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-6 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: "val/rec_loss" 6 | embed_dim: 64 7 | lossconfig: 8 | target: ldm.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 0.000001 12 | disc_weight: 0.5 13 | 14 | ddconfig: 15 | double_z: True 16 | z_channels: 64 17 | resolution: 256 18 | in_channels: 3 19 | out_ch: 3 20 | ch: 128 21 | ch_mult: [ 1,1,2,2,4,4] # num_down = len(ch_mult)-1 22 | num_res_blocks: 2 23 | attn_resolutions: [16,8] 24 | dropout: 0.0 25 | 26 | data: 27 | target: main.DataModuleFromConfig 28 | params: 29 | batch_size: 12 30 | wrap: True 31 | train: 32 | target: ldm.data.imagenet.ImageNetSRTrain 33 | params: 34 | size: 256 35 | degradation: pil_nearest 36 | validation: 37 | target: ldm.data.imagenet.ImageNetSRValidation 38 | params: 39 | size: 256 40 | degradation: pil_nearest 41 | 42 | lightning: 43 | callbacks: 44 | image_logger: 45 | target: main.ImageLogger 46 | params: 47 | batch_frequency: 1000 48 | max_images: 8 49 | increase_log_steps: True 50 | 51 | trainer: 52 | benchmark: True 53 | accumulate_grad_batches: 2 
54 | -------------------------------------------------------------------------------- /stable-diffusion/configs/latent-diffusion/celebahq-ldm-vq-4.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 2.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | image_size: 64 12 | channels: 3 13 | monitor: val/loss_simple_ema 14 | 15 | unet_config: 16 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 17 | params: 18 | image_size: 64 19 | in_channels: 3 20 | out_channels: 3 21 | model_channels: 224 22 | attention_resolutions: 23 | # note: this isn\t actually the resolution but 24 | # the downsampling factor, i.e. this corresnponds to 25 | # attention on spatial resolution 8,16,32, as the 26 | # spatial reolution of the latents is 64 for f4 27 | - 8 28 | - 4 29 | - 2 30 | num_res_blocks: 2 31 | channel_mult: 32 | - 1 33 | - 2 34 | - 3 35 | - 4 36 | num_head_channels: 32 37 | first_stage_config: 38 | target: ldm.models.autoencoder.VQModelInterface 39 | params: 40 | embed_dim: 3 41 | n_embed: 8192 42 | ckpt_path: models/first_stage_models/vq-f4/model.ckpt 43 | ddconfig: 44 | double_z: false 45 | z_channels: 3 46 | resolution: 256 47 | in_channels: 3 48 | out_ch: 3 49 | ch: 128 50 | ch_mult: 51 | - 1 52 | - 2 53 | - 4 54 | num_res_blocks: 2 55 | attn_resolutions: [] 56 | dropout: 0.0 57 | lossconfig: 58 | target: torch.nn.Identity 59 | cond_stage_config: __is_unconditional__ 60 | data: 61 | target: main.DataModuleFromConfig 62 | params: 63 | batch_size: 48 64 | num_workers: 5 65 | wrap: false 66 | train: 67 | target: taming.data.faceshq.CelebAHQTrain 68 | params: 69 | size: 256 70 | validation: 71 | target: taming.data.faceshq.CelebAHQValidation 72 | params: 73 | size: 256 74 | 75 | 76 | lightning: 77 | callbacks: 78 | image_logger: 
79 | target: main.ImageLogger 80 | params: 81 | batch_frequency: 5000 82 | max_images: 8 83 | increase_log_steps: False 84 | 85 | trainer: 86 | benchmark: True -------------------------------------------------------------------------------- /stable-diffusion/configs/latent-diffusion/cin-ldm-vq-f8.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: class_label 12 | image_size: 32 13 | channels: 4 14 | cond_stage_trainable: true 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | unet_config: 18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 19 | params: 20 | image_size: 32 21 | in_channels: 4 22 | out_channels: 4 23 | model_channels: 256 24 | attention_resolutions: 25 | #note: this isn\t actually the resolution but 26 | # the downsampling factor, i.e. 
this corresnponds to 27 | # attention on spatial resolution 8,16,32, as the 28 | # spatial reolution of the latents is 32 for f8 29 | - 4 30 | - 2 31 | - 1 32 | num_res_blocks: 2 33 | channel_mult: 34 | - 1 35 | - 2 36 | - 4 37 | num_head_channels: 32 38 | use_spatial_transformer: true 39 | transformer_depth: 1 40 | context_dim: 512 41 | first_stage_config: 42 | target: ldm.models.autoencoder.VQModelInterface 43 | params: 44 | embed_dim: 4 45 | n_embed: 16384 46 | ckpt_path: configs/first_stage_models/vq-f8/model.yaml 47 | ddconfig: 48 | double_z: false 49 | z_channels: 4 50 | resolution: 256 51 | in_channels: 3 52 | out_ch: 3 53 | ch: 128 54 | ch_mult: 55 | - 1 56 | - 2 57 | - 2 58 | - 4 59 | num_res_blocks: 2 60 | attn_resolutions: 61 | - 32 62 | dropout: 0.0 63 | lossconfig: 64 | target: torch.nn.Identity 65 | cond_stage_config: 66 | target: ldm.modules.encoders.modules.ClassEmbedder 67 | params: 68 | embed_dim: 512 69 | key: class_label 70 | data: 71 | target: main.DataModuleFromConfig 72 | params: 73 | batch_size: 64 74 | num_workers: 12 75 | wrap: false 76 | train: 77 | target: ldm.data.imagenet.ImageNetTrain 78 | params: 79 | config: 80 | size: 256 81 | validation: 82 | target: ldm.data.imagenet.ImageNetValidation 83 | params: 84 | config: 85 | size: 256 86 | 87 | 88 | lightning: 89 | callbacks: 90 | image_logger: 91 | target: main.ImageLogger 92 | params: 93 | batch_frequency: 5000 94 | max_images: 8 95 | increase_log_steps: False 96 | 97 | trainer: 98 | benchmark: True -------------------------------------------------------------------------------- /stable-diffusion/configs/latent-diffusion/cin256-v2.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 0.0001 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: 
class_label 12 | image_size: 64 13 | channels: 3 14 | cond_stage_trainable: true 15 | conditioning_key: crossattn 16 | monitor: val/loss 17 | use_ema: False 18 | 19 | unet_config: 20 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 21 | params: 22 | image_size: 64 23 | in_channels: 3 24 | out_channels: 3 25 | model_channels: 192 26 | attention_resolutions: 27 | - 8 28 | - 4 29 | - 2 30 | num_res_blocks: 2 31 | channel_mult: 32 | - 1 33 | - 2 34 | - 3 35 | - 5 36 | num_heads: 1 37 | use_spatial_transformer: true 38 | transformer_depth: 1 39 | context_dim: 512 40 | 41 | first_stage_config: 42 | target: ldm.models.autoencoder.VQModelInterface 43 | params: 44 | embed_dim: 3 45 | n_embed: 8192 46 | ddconfig: 47 | double_z: false 48 | z_channels: 3 49 | resolution: 256 50 | in_channels: 3 51 | out_ch: 3 52 | ch: 128 53 | ch_mult: 54 | - 1 55 | - 2 56 | - 4 57 | num_res_blocks: 2 58 | attn_resolutions: [] 59 | dropout: 0.0 60 | lossconfig: 61 | target: torch.nn.Identity 62 | 63 | cond_stage_config: 64 | target: ldm.modules.encoders.modules.ClassEmbedder 65 | params: 66 | n_classes: 1001 67 | embed_dim: 512 68 | key: class_label 69 | -------------------------------------------------------------------------------- /stable-diffusion/configs/latent-diffusion/ffhq-ldm-vq-4.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 2.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | image_size: 64 12 | channels: 3 13 | monitor: val/loss_simple_ema 14 | unet_config: 15 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 16 | params: 17 | image_size: 64 18 | in_channels: 3 19 | out_channels: 3 20 | model_channels: 224 21 | attention_resolutions: 22 | # note: this isn\t actually the resolution but 23 | # the downsampling factor, 
i.e. this corresnponds to 24 | # attention on spatial resolution 8,16,32, as the 25 | # spatial reolution of the latents is 64 for f4 26 | - 8 27 | - 4 28 | - 2 29 | num_res_blocks: 2 30 | channel_mult: 31 | - 1 32 | - 2 33 | - 3 34 | - 4 35 | num_head_channels: 32 36 | first_stage_config: 37 | target: ldm.models.autoencoder.VQModelInterface 38 | params: 39 | embed_dim: 3 40 | n_embed: 8192 41 | ckpt_path: configs/first_stage_models/vq-f4/model.yaml 42 | ddconfig: 43 | double_z: false 44 | z_channels: 3 45 | resolution: 256 46 | in_channels: 3 47 | out_ch: 3 48 | ch: 128 49 | ch_mult: 50 | - 1 51 | - 2 52 | - 4 53 | num_res_blocks: 2 54 | attn_resolutions: [] 55 | dropout: 0.0 56 | lossconfig: 57 | target: torch.nn.Identity 58 | cond_stage_config: __is_unconditional__ 59 | data: 60 | target: main.DataModuleFromConfig 61 | params: 62 | batch_size: 42 63 | num_workers: 5 64 | wrap: false 65 | train: 66 | target: taming.data.faceshq.FFHQTrain 67 | params: 68 | size: 256 69 | validation: 70 | target: taming.data.faceshq.FFHQValidation 71 | params: 72 | size: 256 73 | 74 | 75 | lightning: 76 | callbacks: 77 | image_logger: 78 | target: main.ImageLogger 79 | params: 80 | batch_frequency: 5000 81 | max_images: 8 82 | increase_log_steps: False 83 | 84 | trainer: 85 | benchmark: True -------------------------------------------------------------------------------- /stable-diffusion/configs/latent-diffusion/lsun_bedrooms-ldm-vq-4.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 2.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | image_size: 64 12 | channels: 3 13 | monitor: val/loss_simple_ema 14 | unet_config: 15 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 16 | params: 17 | image_size: 64 18 | in_channels: 3 19 | 
out_channels: 3 20 | model_channels: 224 21 | attention_resolutions: 22 | # note: this isn\t actually the resolution but 23 | # the downsampling factor, i.e. this corresnponds to 24 | # attention on spatial resolution 8,16,32, as the 25 | # spatial reolution of the latents is 64 for f4 26 | - 8 27 | - 4 28 | - 2 29 | num_res_blocks: 2 30 | channel_mult: 31 | - 1 32 | - 2 33 | - 3 34 | - 4 35 | num_head_channels: 32 36 | first_stage_config: 37 | target: ldm.models.autoencoder.VQModelInterface 38 | params: 39 | ckpt_path: configs/first_stage_models/vq-f4/model.yaml 40 | embed_dim: 3 41 | n_embed: 8192 42 | ddconfig: 43 | double_z: false 44 | z_channels: 3 45 | resolution: 256 46 | in_channels: 3 47 | out_ch: 3 48 | ch: 128 49 | ch_mult: 50 | - 1 51 | - 2 52 | - 4 53 | num_res_blocks: 2 54 | attn_resolutions: [] 55 | dropout: 0.0 56 | lossconfig: 57 | target: torch.nn.Identity 58 | cond_stage_config: __is_unconditional__ 59 | data: 60 | target: main.DataModuleFromConfig 61 | params: 62 | batch_size: 48 63 | num_workers: 5 64 | wrap: false 65 | train: 66 | target: ldm.data.lsun.LSUNBedroomsTrain 67 | params: 68 | size: 256 69 | validation: 70 | target: ldm.data.lsun.LSUNBedroomsValidation 71 | params: 72 | size: 256 73 | 74 | 75 | lightning: 76 | callbacks: 77 | image_logger: 78 | target: main.ImageLogger 79 | params: 80 | batch_frequency: 5000 81 | max_images: 8 82 | increase_log_steps: False 83 | 84 | trainer: 85 | benchmark: True -------------------------------------------------------------------------------- /stable-diffusion/configs/latent-diffusion/lsun_churches-ldm-kl-8.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 5.0e-5 # set to target_lr by starting main.py with '--scale_lr False' 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0155 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | loss_type: l1 11 | 
first_stage_key: "image" 12 | cond_stage_key: "image" 13 | image_size: 32 14 | channels: 4 15 | cond_stage_trainable: False 16 | concat_mode: False 17 | scale_by_std: True 18 | monitor: 'val/loss_simple_ema' 19 | 20 | scheduler_config: # 10000 warmup steps 21 | target: ldm.lr_scheduler.LambdaLinearScheduler 22 | params: 23 | warm_up_steps: [10000] 24 | cycle_lengths: [10000000000000] 25 | f_start: [1.e-6] 26 | f_max: [1.] 27 | f_min: [ 1.] 28 | 29 | unet_config: 30 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 31 | params: 32 | image_size: 32 33 | in_channels: 4 34 | out_channels: 4 35 | model_channels: 192 36 | attention_resolutions: [ 1, 2, 4, 8 ] # 32, 16, 8, 4 37 | num_res_blocks: 2 38 | channel_mult: [ 1,2,2,4,4 ] # 32, 16, 8, 4, 2 39 | num_heads: 8 40 | use_scale_shift_norm: True 41 | resblock_updown: True 42 | 43 | first_stage_config: 44 | target: ldm.models.autoencoder.AutoencoderKL 45 | params: 46 | embed_dim: 4 47 | monitor: "val/rec_loss" 48 | ckpt_path: "models/first_stage_models/kl-f8/model.ckpt" 49 | ddconfig: 50 | double_z: True 51 | z_channels: 4 52 | resolution: 256 53 | in_channels: 3 54 | out_ch: 3 55 | ch: 128 56 | ch_mult: [ 1,2,4,4 ] # num_down = len(ch_mult)-1 57 | num_res_blocks: 2 58 | attn_resolutions: [ ] 59 | dropout: 0.0 60 | lossconfig: 61 | target: torch.nn.Identity 62 | 63 | cond_stage_config: "__is_unconditional__" 64 | 65 | data: 66 | target: main.DataModuleFromConfig 67 | params: 68 | batch_size: 96 69 | num_workers: 5 70 | wrap: False 71 | train: 72 | target: ldm.data.lsun.LSUNChurchesTrain 73 | params: 74 | size: 256 75 | validation: 76 | target: ldm.data.lsun.LSUNChurchesValidation 77 | params: 78 | size: 256 79 | 80 | lightning: 81 | callbacks: 82 | image_logger: 83 | target: main.ImageLogger 84 | params: 85 | batch_frequency: 5000 86 | max_images: 8 87 | increase_log_steps: False 88 | 89 | 90 | trainer: 91 | benchmark: True -------------------------------------------------------------------------------- 
/stable-diffusion/configs/latent-diffusion/txt2img-1p4B-eval.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 5.0e-05 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.012 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: caption 12 | image_size: 32 13 | channels: 4 14 | cond_stage_trainable: true 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | scale_factor: 0.18215 18 | use_ema: False 19 | 20 | unet_config: 21 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 22 | params: 23 | image_size: 32 24 | in_channels: 4 25 | out_channels: 4 26 | model_channels: 320 27 | attention_resolutions: 28 | - 4 29 | - 2 30 | - 1 31 | num_res_blocks: 2 32 | channel_mult: 33 | - 1 34 | - 2 35 | - 4 36 | - 4 37 | num_heads: 8 38 | use_spatial_transformer: true 39 | transformer_depth: 1 40 | context_dim: 1280 41 | use_checkpoint: true 42 | legacy: False 43 | 44 | first_stage_config: 45 | target: ldm.models.autoencoder.AutoencoderKL 46 | params: 47 | embed_dim: 4 48 | monitor: val/rec_loss 49 | ddconfig: 50 | double_z: true 51 | z_channels: 4 52 | resolution: 256 53 | in_channels: 3 54 | out_ch: 3 55 | ch: 128 56 | ch_mult: 57 | - 1 58 | - 2 59 | - 4 60 | - 4 61 | num_res_blocks: 2 62 | attn_resolutions: [] 63 | dropout: 0.0 64 | lossconfig: 65 | target: torch.nn.Identity 66 | 67 | cond_stage_config: 68 | target: ldm.modules.encoders.modules.BERTEmbedder 69 | params: 70 | n_embed: 1280 71 | n_layer: 32 72 | -------------------------------------------------------------------------------- /stable-diffusion/configs/retrieval-augmented-diffusion/768x768.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 0.0001 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 
0.0015 6 | linear_end: 0.015 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: jpg 11 | cond_stage_key: nix 12 | image_size: 48 13 | channels: 16 14 | cond_stage_trainable: false 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | scale_by_std: false 18 | scale_factor: 0.22765929 19 | unet_config: 20 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 21 | params: 22 | image_size: 48 23 | in_channels: 16 24 | out_channels: 16 25 | model_channels: 448 26 | attention_resolutions: 27 | - 4 28 | - 2 29 | - 1 30 | num_res_blocks: 2 31 | channel_mult: 32 | - 1 33 | - 2 34 | - 3 35 | - 4 36 | use_scale_shift_norm: false 37 | resblock_updown: false 38 | num_head_channels: 32 39 | use_spatial_transformer: true 40 | transformer_depth: 1 41 | context_dim: 768 42 | use_checkpoint: true 43 | first_stage_config: 44 | target: ldm.models.autoencoder.AutoencoderKL 45 | params: 46 | monitor: val/rec_loss 47 | embed_dim: 16 48 | ddconfig: 49 | double_z: true 50 | z_channels: 16 51 | resolution: 256 52 | in_channels: 3 53 | out_ch: 3 54 | ch: 128 55 | ch_mult: 56 | - 1 57 | - 1 58 | - 2 59 | - 2 60 | - 4 61 | num_res_blocks: 2 62 | attn_resolutions: 63 | - 16 64 | dropout: 0.0 65 | lossconfig: 66 | target: torch.nn.Identity 67 | cond_stage_config: 68 | target: torch.nn.Identity -------------------------------------------------------------------------------- /stable-diffusion/configs/stable-diffusion/v1-inference.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-04 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.0120 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: "jpg" 11 | cond_stage_key: "txt" 12 | image_size: 64 13 | channels: 4 14 | cond_stage_trainable: false # Note: different from the one we trained before 15 | conditioning_key: 
crossattn 16 | monitor: val/loss_simple_ema 17 | scale_factor: 0.18215 18 | use_ema: False 19 | 20 | scheduler_config: # 10000 warmup steps 21 | target: ldm.lr_scheduler.LambdaLinearScheduler 22 | params: 23 | warm_up_steps: [ 10000 ] 24 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases 25 | f_start: [ 1.e-6 ] 26 | f_max: [ 1. ] 27 | f_min: [ 1. ] 28 | 29 | unet_config: 30 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 31 | params: 32 | image_size: 32 # unused 33 | in_channels: 4 34 | out_channels: 4 35 | model_channels: 320 36 | attention_resolutions: [ 4, 2, 1 ] 37 | num_res_blocks: 2 38 | channel_mult: [ 1, 2, 4, 4 ] 39 | num_heads: 8 40 | use_spatial_transformer: True 41 | transformer_depth: 1 42 | context_dim: 768 43 | use_checkpoint: True 44 | legacy: False 45 | 46 | first_stage_config: 47 | target: ldm.models.autoencoder.AutoencoderKL 48 | params: 49 | embed_dim: 4 50 | monitor: val/rec_loss 51 | ddconfig: 52 | double_z: true 53 | z_channels: 4 54 | resolution: 256 55 | in_channels: 3 56 | out_ch: 3 57 | ch: 128 58 | ch_mult: 59 | - 1 60 | - 2 61 | - 4 62 | - 4 63 | num_res_blocks: 2 64 | attn_resolutions: [] 65 | dropout: 0.0 66 | lossconfig: 67 | target: torch.nn.Identity 68 | 69 | cond_stage_config: 70 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder 71 | -------------------------------------------------------------------------------- /stable-diffusion/environment.yaml: -------------------------------------------------------------------------------- 1 | name: ldm 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - python=3.8.5 7 | - pip=20.3 8 | - cudatoolkit=11.3 9 | - pytorch=1.11.0 10 | - torchvision=0.12.0 11 | - numpy=1.19.2 12 | - pip: 13 | - albumentations==0.4.3 14 | - diffusers 15 | - opencv-python==4.1.2.30 16 | - pudb==2019.2 17 | - invisible-watermark 18 | - imageio==2.9.0 19 | - imageio-ffmpeg==0.4.2 20 | - pytorch-lightning==1.4.2 21 | - omegaconf==2.1.1 22 | - 
test-tube>=0.7.5 23 | - streamlit>=0.73.1 24 | - einops==0.3.0 25 | - torch-fidelity==0.3.0 26 | - transformers==4.19.2 27 | - torchmetrics==0.6.0 28 | - kornia==0.6 29 | - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers 30 | - -e git+https://github.com/openai/CLIP.git@main#egg=clip 31 | - -e . 32 | -------------------------------------------------------------------------------- /stable-diffusion/ldm/__pycache__/util.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmin2000/WaDiff/33e657de73c3e75e82ea66bcec14012a2f4fcb6a/stable-diffusion/ldm/__pycache__/util.cpython-310.pyc -------------------------------------------------------------------------------- /stable-diffusion/ldm/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmin2000/WaDiff/33e657de73c3e75e82ea66bcec14012a2f4fcb6a/stable-diffusion/ldm/data/__init__.py -------------------------------------------------------------------------------- /stable-diffusion/ldm/data/base.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset 3 | 4 | 5 | class Txt2ImgIterableBaseDataset(IterableDataset): 6 | ''' 7 | Define an interface to make the IterableDatasets for text2img data chainable 8 | ''' 9 | def __init__(self, num_records=0, valid_ids=None, size=256): 10 | super().__init__() 11 | self.num_records = num_records 12 | self.valid_ids = valid_ids 13 | self.sample_ids = valid_ids 14 | self.size = size 15 | 16 | print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.') 17 | 18 | def __len__(self): 19 | return self.num_records 20 | 21 | @abstractmethod 22 | def __iter__(self): 23 | pass 
-------------------------------------------------------------------------------- /stable-diffusion/ldm/data/lsun.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import PIL 4 | from PIL import Image 5 | from torch.utils.data import Dataset 6 | from torchvision import transforms 7 | 8 | 9 | class LSUNBase(Dataset): 10 | def __init__(self, 11 | txt_file, 12 | data_root, 13 | size=None, 14 | interpolation="bicubic", 15 | flip_p=0.5 16 | ): 17 | self.data_paths = txt_file 18 | self.data_root = data_root 19 | with open(self.data_paths, "r") as f: 20 | self.image_paths = f.read().splitlines() 21 | self._length = len(self.image_paths) 22 | self.labels = { 23 | "relative_file_path_": [l for l in self.image_paths], 24 | "file_path_": [os.path.join(self.data_root, l) 25 | for l in self.image_paths], 26 | } 27 | 28 | self.size = size 29 | self.interpolation = {"linear": PIL.Image.LINEAR, 30 | "bilinear": PIL.Image.BILINEAR, 31 | "bicubic": PIL.Image.BICUBIC, 32 | "lanczos": PIL.Image.LANCZOS, 33 | }[interpolation] 34 | self.flip = transforms.RandomHorizontalFlip(p=flip_p) 35 | 36 | def __len__(self): 37 | return self._length 38 | 39 | def __getitem__(self, i): 40 | example = dict((k, self.labels[k][i]) for k in self.labels) 41 | image = Image.open(example["file_path_"]) 42 | if not image.mode == "RGB": 43 | image = image.convert("RGB") 44 | 45 | # default to score-sde preprocessing 46 | img = np.array(image).astype(np.uint8) 47 | crop = min(img.shape[0], img.shape[1]) 48 | h, w, = img.shape[0], img.shape[1] 49 | img = img[(h - crop) // 2:(h + crop) // 2, 50 | (w - crop) // 2:(w + crop) // 2] 51 | 52 | image = Image.fromarray(img) 53 | if self.size is not None: 54 | image = image.resize((self.size, self.size), resample=self.interpolation) 55 | 56 | image = self.flip(image) 57 | image = np.array(image).astype(np.uint8) 58 | example["image"] = (image / 127.5 - 1.0).astype(np.float32) 59 | return example 
60 | 61 | 62 | class LSUNChurchesTrain(LSUNBase): 63 | def __init__(self, **kwargs): 64 | super().__init__(txt_file="data/lsun/church_outdoor_train.txt", data_root="data/lsun/churches", **kwargs) 65 | 66 | 67 | class LSUNChurchesValidation(LSUNBase): 68 | def __init__(self, flip_p=0., **kwargs): 69 | super().__init__(txt_file="data/lsun/church_outdoor_val.txt", data_root="data/lsun/churches", 70 | flip_p=flip_p, **kwargs) 71 | 72 | 73 | class LSUNBedroomsTrain(LSUNBase): 74 | def __init__(self, **kwargs): 75 | super().__init__(txt_file="data/lsun/bedrooms_train.txt", data_root="data/lsun/bedrooms", **kwargs) 76 | 77 | 78 | class LSUNBedroomsValidation(LSUNBase): 79 | def __init__(self, flip_p=0.0, **kwargs): 80 | super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms", 81 | flip_p=flip_p, **kwargs) 82 | 83 | 84 | class LSUNCatsTrain(LSUNBase): 85 | def __init__(self, **kwargs): 86 | super().__init__(txt_file="data/lsun/cat_train.txt", data_root="data/lsun/cats", **kwargs) 87 | 88 | 89 | class LSUNCatsValidation(LSUNBase): 90 | def __init__(self, flip_p=0., **kwargs): 91 | super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats", 92 | flip_p=flip_p, **kwargs) 93 | -------------------------------------------------------------------------------- /stable-diffusion/ldm/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class LambdaWarmUpCosineScheduler: 5 | """ 6 | note: use with a base_lr of 1.0 7 | """ 8 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): 9 | self.lr_warm_up_steps = warm_up_steps 10 | self.lr_start = lr_start 11 | self.lr_min = lr_min 12 | self.lr_max = lr_max 13 | self.lr_max_decay_steps = max_decay_steps 14 | self.last_lr = 0. 
15 | self.verbosity_interval = verbosity_interval 16 | 17 | def schedule(self, n, **kwargs): 18 | if self.verbosity_interval > 0: 19 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") 20 | if n < self.lr_warm_up_steps: 21 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start 22 | self.last_lr = lr 23 | return lr 24 | else: 25 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) 26 | t = min(t, 1.0) 27 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( 28 | 1 + np.cos(t * np.pi)) 29 | self.last_lr = lr 30 | return lr 31 | 32 | def __call__(self, n, **kwargs): 33 | return self.schedule(n,**kwargs) 34 | 35 | 36 | class LambdaWarmUpCosineScheduler2: 37 | """ 38 | supports repeated iterations, configurable via lists 39 | note: use with a base_lr of 1.0. 40 | """ 41 | def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0): 42 | assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths) 43 | self.lr_warm_up_steps = warm_up_steps 44 | self.f_start = f_start 45 | self.f_min = f_min 46 | self.f_max = f_max 47 | self.cycle_lengths = cycle_lengths 48 | self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) 49 | self.last_f = 0. 
50 | self.verbosity_interval = verbosity_interval 51 | 52 | def find_in_interval(self, n): 53 | interval = 0 54 | for cl in self.cum_cycles[1:]: 55 | if n <= cl: 56 | return interval 57 | interval += 1 58 | 59 | def schedule(self, n, **kwargs): 60 | cycle = self.find_in_interval(n) 61 | n = n - self.cum_cycles[cycle] 62 | if self.verbosity_interval > 0: 63 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 64 | f"current cycle {cycle}") 65 | if n < self.lr_warm_up_steps[cycle]: 66 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 67 | self.last_f = f 68 | return f 69 | else: 70 | t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) 71 | t = min(t, 1.0) 72 | f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( 73 | 1 + np.cos(t * np.pi)) 74 | self.last_f = f 75 | return f 76 | 77 | def __call__(self, n, **kwargs): 78 | return self.schedule(n, **kwargs) 79 | 80 | 81 | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): 82 | 83 | def schedule(self, n, **kwargs): 84 | cycle = self.find_in_interval(n) 85 | n = n - self.cum_cycles[cycle] 86 | if self.verbosity_interval > 0: 87 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 88 | f"current cycle {cycle}") 89 | 90 | if n < self.lr_warm_up_steps[cycle]: 91 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 92 | self.last_f = f 93 | return f 94 | else: 95 | f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle]) 96 | self.last_f = f 97 | return f 98 | 99 | -------------------------------------------------------------------------------- /stable-diffusion/ldm/models/__pycache__/autoencoder.cpython-310.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmin2000/WaDiff/33e657de73c3e75e82ea66bcec14012a2f4fcb6a/stable-diffusion/ldm/models/__pycache__/autoencoder.cpython-310.pyc -------------------------------------------------------------------------------- /stable-diffusion/ldm/models/diffusion/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmin2000/WaDiff/33e657de73c3e75e82ea66bcec14012a2f4fcb6a/stable-diffusion/ldm/models/diffusion/__init__.py -------------------------------------------------------------------------------- /stable-diffusion/ldm/models/diffusion/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmin2000/WaDiff/33e657de73c3e75e82ea66bcec14012a2f4fcb6a/stable-diffusion/ldm/models/diffusion/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /stable-diffusion/ldm/models/diffusion/__pycache__/ddim.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmin2000/WaDiff/33e657de73c3e75e82ea66bcec14012a2f4fcb6a/stable-diffusion/ldm/models/diffusion/__pycache__/ddim.cpython-310.pyc -------------------------------------------------------------------------------- /stable-diffusion/ldm/models/diffusion/__pycache__/ddpm.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmin2000/WaDiff/33e657de73c3e75e82ea66bcec14012a2f4fcb6a/stable-diffusion/ldm/models/diffusion/__pycache__/ddpm.cpython-310.pyc -------------------------------------------------------------------------------- /stable-diffusion/ldm/models/diffusion/__pycache__/plms.cpython-310.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmin2000/WaDiff/33e657de73c3e75e82ea66bcec14012a2f4fcb6a/stable-diffusion/ldm/models/diffusion/__pycache__/plms.cpython-310.pyc -------------------------------------------------------------------------------- /stable-diffusion/ldm/models/diffusion/dpm_solver/__init__.py: -------------------------------------------------------------------------------- 1 | from .sampler import DPMSolverSampler -------------------------------------------------------------------------------- /stable-diffusion/ldm/models/diffusion/dpm_solver/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmin2000/WaDiff/33e657de73c3e75e82ea66bcec14012a2f4fcb6a/stable-diffusion/ldm/models/diffusion/dpm_solver/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /stable-diffusion/ldm/models/diffusion/dpm_solver/__pycache__/dpm_solver.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmin2000/WaDiff/33e657de73c3e75e82ea66bcec14012a2f4fcb6a/stable-diffusion/ldm/models/diffusion/dpm_solver/__pycache__/dpm_solver.cpython-310.pyc -------------------------------------------------------------------------------- /stable-diffusion/ldm/models/diffusion/dpm_solver/__pycache__/sampler.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmin2000/WaDiff/33e657de73c3e75e82ea66bcec14012a2f4fcb6a/stable-diffusion/ldm/models/diffusion/dpm_solver/__pycache__/sampler.cpython-310.pyc -------------------------------------------------------------------------------- /stable-diffusion/ldm/models/diffusion/dpm_solver/sampler.py: 
-------------------------------------------------------------------------------- 1 | """SAMPLING ONLY.""" 2 | 3 | import torch 4 | 5 | from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver 6 | 7 | 8 | class DPMSolverSampler(object): 9 | def __init__(self, model, **kwargs): 10 | super().__init__() 11 | self.model = model 12 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device) 13 | self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod)) 14 | 15 | def register_buffer(self, name, attr): 16 | if type(attr) == torch.Tensor: 17 | if attr.device != torch.device("cuda"): 18 | attr = attr.to(torch.device("cuda")) 19 | setattr(self, name, attr) 20 | 21 | @torch.no_grad() 22 | def sample(self, 23 | S, 24 | batch_size, 25 | shape, 26 | conditioning=None, 27 | callback=None, 28 | normals_sequence=None, 29 | img_callback=None, 30 | quantize_x0=False, 31 | eta=0., 32 | mask=None, 33 | x0=None, 34 | temperature=1., 35 | noise_dropout=0., 36 | score_corrector=None, 37 | corrector_kwargs=None, 38 | verbose=True, 39 | x_T=None, 40 | log_every_t=100, 41 | unconditional_guidance_scale=1., 42 | unconditional_conditioning=None, 43 | # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 
44 | **kwargs 45 | ): 46 | if conditioning is not None: 47 | if isinstance(conditioning, dict): 48 | cbs = conditioning[list(conditioning.keys())[0]].shape[0] 49 | if cbs != batch_size: 50 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") 51 | else: 52 | if conditioning.shape[0] != batch_size: 53 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") 54 | 55 | # sampling 56 | C, H, W = shape 57 | size = (batch_size, C, H, W) 58 | 59 | # print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}') 60 | 61 | device = self.model.betas.device 62 | if x_T is None: 63 | img = torch.randn(size, device=device) 64 | else: 65 | img = x_T 66 | 67 | ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod) 68 | 69 | model_fn = model_wrapper( 70 | lambda x, t, c: self.model.apply_model(x, t, c), 71 | ns, 72 | model_type="noise", 73 | guidance_type="classifier-free", 74 | condition=conditioning, 75 | unconditional_condition=unconditional_conditioning, 76 | guidance_scale=unconditional_guidance_scale, 77 | ) 78 | 79 | dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False) 80 | x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True) 81 | 82 | return x.to(device), None 83 | -------------------------------------------------------------------------------- /stable-diffusion/ldm/modules/__pycache__/attention.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmin2000/WaDiff/33e657de73c3e75e82ea66bcec14012a2f4fcb6a/stable-diffusion/ldm/modules/__pycache__/attention.cpython-310.pyc -------------------------------------------------------------------------------- /stable-diffusion/ldm/modules/__pycache__/ema.cpython-310.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/rmin2000/WaDiff/33e657de73c3e75e82ea66bcec14012a2f4fcb6a/stable-diffusion/ldm/modules/__pycache__/ema.cpython-310.pyc -------------------------------------------------------------------------------- /stable-diffusion/ldm/modules/__pycache__/x_transformer.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmin2000/WaDiff/33e657de73c3e75e82ea66bcec14012a2f4fcb6a/stable-diffusion/ldm/modules/__pycache__/x_transformer.cpython-310.pyc -------------------------------------------------------------------------------- /stable-diffusion/ldm/modules/diffusionmodules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmin2000/WaDiff/33e657de73c3e75e82ea66bcec14012a2f4fcb6a/stable-diffusion/ldm/modules/diffusionmodules/__init__.py -------------------------------------------------------------------------------- /stable-diffusion/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmin2000/WaDiff/33e657de73c3e75e82ea66bcec14012a2f4fcb6a/stable-diffusion/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /stable-diffusion/ldm/modules/diffusionmodules/__pycache__/model.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmin2000/WaDiff/33e657de73c3e75e82ea66bcec14012a2f4fcb6a/stable-diffusion/ldm/modules/diffusionmodules/__pycache__/model.cpython-310.pyc -------------------------------------------------------------------------------- /stable-diffusion/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-310.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmin2000/WaDiff/33e657de73c3e75e82ea66bcec14012a2f4fcb6a/stable-diffusion/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-310.pyc -------------------------------------------------------------------------------- /stable-diffusion/ldm/modules/diffusionmodules/__pycache__/util.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmin2000/WaDiff/33e657de73c3e75e82ea66bcec14012a2f4fcb6a/stable-diffusion/ldm/modules/diffusionmodules/__pycache__/util.cpython-310.pyc -------------------------------------------------------------------------------- /stable-diffusion/ldm/modules/distributions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmin2000/WaDiff/33e657de73c3e75e82ea66bcec14012a2f4fcb6a/stable-diffusion/ldm/modules/distributions/__init__.py -------------------------------------------------------------------------------- /stable-diffusion/ldm/modules/distributions/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmin2000/WaDiff/33e657de73c3e75e82ea66bcec14012a2f4fcb6a/stable-diffusion/ldm/modules/distributions/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /stable-diffusion/ldm/modules/distributions/__pycache__/distributions.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmin2000/WaDiff/33e657de73c3e75e82ea66bcec14012a2f4fcb6a/stable-diffusion/ldm/modules/distributions/__pycache__/distributions.cpython-310.pyc -------------------------------------------------------------------------------- 
/stable-diffusion/ldm/modules/distributions/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 34 | 35 | def sample(self): 36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) 37 | return x 38 | 39 | def kl(self, other=None): 40 | if self.deterministic: 41 | return torch.Tensor([0.]) 42 | else: 43 | if other is None: 44 | return 0.5 * torch.sum(torch.pow(self.mean, 2) 45 | + self.var - 1.0 - self.logvar, 46 | dim=[1, 2, 3]) 47 | else: 48 | return 0.5 * torch.sum( 49 | torch.pow(self.mean - other.mean, 2) / other.var 50 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 51 | dim=[1, 2, 3]) 52 | 53 | def nll(self, sample, dims=[1,2,3]): 54 | if self.deterministic: 55 | return torch.Tensor([0.]) 56 | logtwopi = np.log(2.0 * np.pi) 57 | return 0.5 * torch.sum( 58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 59 | dim=dims) 60 | 61 | def mode(self): 62 | return self.mean 63 | 64 | 65 | def normal_kl(mean1, logvar1, mean2, logvar2): 66 | """ 
67 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 68 | Compute the KL divergence between two gaussians. 69 | Shapes are automatically broadcasted, so batches can be compared to 70 | scalars, among other use cases. 71 | """ 72 | tensor = None 73 | for obj in (mean1, logvar1, mean2, logvar2): 74 | if isinstance(obj, torch.Tensor): 75 | tensor = obj 76 | break 77 | assert tensor is not None, "at least one argument must be a Tensor" 78 | 79 | # Force variances to be Tensors. Broadcasting helps convert scalars to 80 | # Tensors, but it does not work for torch.exp(). 81 | logvar1, logvar2 = [ 82 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 83 | for x in (logvar1, logvar2) 84 | ] 85 | 86 | return 0.5 * ( 87 | -1.0 88 | + logvar2 89 | - logvar1 90 | + torch.exp(logvar1 - logvar2) 91 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 92 | ) 93 | -------------------------------------------------------------------------------- /stable-diffusion/ldm/modules/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError('Decay must be between 0 and 1') 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates 14 | else torch.tensor(-1,dtype=torch.int)) 15 | 16 | for name, p in model.named_parameters(): 17 | if p.requires_grad: 18 | #remove as '.'-character is not allowed in buffers 19 | s_name = name.replace('.','') 20 | self.m_name2s_name.update({name:s_name}) 21 | self.register_buffer(s_name,p.clone().detach().data) 22 | 23 | self.collected_params = [] 24 | 25 | def forward(self,model): 
26 | decay = self.decay 27 | 28 | if self.num_updates >= 0: 29 | self.num_updates += 1 30 | decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates)) 31 | 32 | one_minus_decay = 1.0 - decay 33 | 34 | with torch.no_grad(): 35 | m_param = dict(model.named_parameters()) 36 | shadow_params = dict(self.named_buffers()) 37 | 38 | for key in m_param: 39 | if m_param[key].requires_grad: 40 | sname = self.m_name2s_name[key] 41 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 42 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) 43 | else: 44 | assert not key in self.m_name2s_name 45 | 46 | def copy_to(self, model): 47 | m_param = dict(model.named_parameters()) 48 | shadow_params = dict(self.named_buffers()) 49 | for key in m_param: 50 | if m_param[key].requires_grad: 51 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 52 | else: 53 | assert not key in self.m_name2s_name 54 | 55 | def store(self, parameters): 56 | """ 57 | Save the current parameters for restoring later. 58 | Args: 59 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 60 | temporarily stored. 61 | """ 62 | self.collected_params = [param.clone() for param in parameters] 63 | 64 | def restore(self, parameters): 65 | """ 66 | Restore the parameters stored with the `store` method. 67 | Useful to validate the model with EMA parameters without affecting the 68 | original optimization process. Store the parameters before the 69 | `copy_to` method. After validation (or model saving), use this to 70 | restore the former parameters. 71 | Args: 72 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 73 | updated with the stored parameters. 
74 | """ 75 | for c_param, param in zip(self.collected_params, parameters): 76 | param.data.copy_(c_param.data) 77 | -------------------------------------------------------------------------------- /stable-diffusion/ldm/modules/encoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmin2000/WaDiff/33e657de73c3e75e82ea66bcec14012a2f4fcb6a/stable-diffusion/ldm/modules/encoders/__init__.py -------------------------------------------------------------------------------- /stable-diffusion/ldm/modules/encoders/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmin2000/WaDiff/33e657de73c3e75e82ea66bcec14012a2f4fcb6a/stable-diffusion/ldm/modules/encoders/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /stable-diffusion/ldm/modules/encoders/__pycache__/modules.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmin2000/WaDiff/33e657de73c3e75e82ea66bcec14012a2f4fcb6a/stable-diffusion/ldm/modules/encoders/__pycache__/modules.cpython-310.pyc -------------------------------------------------------------------------------- /stable-diffusion/ldm/modules/image_degradation/__init__.py: -------------------------------------------------------------------------------- 1 | from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr 2 | from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light 3 | -------------------------------------------------------------------------------- /stable-diffusion/ldm/modules/image_degradation/utils/test.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/rmin2000/WaDiff/33e657de73c3e75e82ea66bcec14012a2f4fcb6a/stable-diffusion/ldm/modules/image_degradation/utils/test.png -------------------------------------------------------------------------------- /stable-diffusion/ldm/modules/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator -------------------------------------------------------------------------------- /stable-diffusion/ldm/modules/losses/contperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no? 5 | 6 | 7 | class LPIPSWithDiscriminator(nn.Module): 8 | def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0, 9 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, 10 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, 11 | disc_loss="hinge"): 12 | 13 | super().__init__() 14 | assert disc_loss in ["hinge", "vanilla"] 15 | self.kl_weight = kl_weight 16 | self.pixel_weight = pixelloss_weight 17 | self.perceptual_loss = LPIPS().eval() 18 | self.perceptual_weight = perceptual_weight 19 | # output log variance 20 | self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init) 21 | 22 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, 23 | n_layers=disc_num_layers, 24 | use_actnorm=use_actnorm 25 | ).apply(weights_init) 26 | self.discriminator_iter_start = disc_start 27 | self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss 28 | self.disc_factor = disc_factor 29 | self.discriminator_weight = disc_weight 30 | self.disc_conditional = disc_conditional 31 | 32 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 33 | if last_layer is not None: 34 | nll_grads = 
torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 35 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 36 | else: 37 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] 38 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] 39 | 40 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 41 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 42 | d_weight = d_weight * self.discriminator_weight 43 | return d_weight 44 | 45 | def forward(self, inputs, reconstructions, posteriors, optimizer_idx, 46 | global_step, last_layer=None, cond=None, split="train", 47 | weights=None): 48 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 49 | if self.perceptual_weight > 0: 50 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) 51 | rec_loss = rec_loss + self.perceptual_weight * p_loss 52 | 53 | nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar 54 | weighted_nll_loss = nll_loss 55 | if weights is not None: 56 | weighted_nll_loss = weights*nll_loss 57 | weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] 58 | nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 59 | kl_loss = posteriors.kl() 60 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] 61 | 62 | # now the GAN part 63 | if optimizer_idx == 0: 64 | # generator update 65 | if cond is None: 66 | assert not self.disc_conditional 67 | logits_fake = self.discriminator(reconstructions.contiguous()) 68 | else: 69 | assert self.disc_conditional 70 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) 71 | g_loss = -torch.mean(logits_fake) 72 | 73 | if self.disc_factor > 0.0: 74 | try: 75 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) 76 | except RuntimeError: 77 | assert not self.training 78 | d_weight = torch.tensor(0.0) 79 | else: 80 | d_weight = 
torch.tensor(0.0) 81 | 82 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 83 | loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss 84 | 85 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(), 86 | "{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(), 87 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 88 | "{}/d_weight".format(split): d_weight.detach(), 89 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 90 | "{}/g_loss".format(split): g_loss.detach().mean(), 91 | } 92 | return loss, log 93 | 94 | if optimizer_idx == 1: 95 | # second pass for discriminator update 96 | if cond is None: 97 | logits_real = self.discriminator(inputs.contiguous().detach()) 98 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 99 | else: 100 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) 101 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) 102 | 103 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 104 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 105 | 106 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), 107 | "{}/logits_real".format(split): logits_real.detach().mean(), 108 | "{}/logits_fake".format(split): logits_fake.detach().mean() 109 | } 110 | return d_loss, log 111 | 112 | -------------------------------------------------------------------------------- /stable-diffusion/ldm/util.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | import torch 4 | import numpy as np 5 | from collections import abc 6 | from einops import rearrange 7 | from functools import partial 8 | 9 | import 
multiprocessing as mp 10 | from threading import Thread 11 | from queue import Queue 12 | 13 | from inspect import isfunction 14 | from PIL import Image, ImageDraw, ImageFont 15 | 16 | 17 | def log_txt_as_img(wh, xc, size=10): 18 | # wh a tuple of (width, height) 19 | # xc a list of captions to plot 20 | b = len(xc) 21 | txts = list() 22 | for bi in range(b): 23 | txt = Image.new("RGB", wh, color="white") 24 | draw = ImageDraw.Draw(txt) 25 | font = ImageFont.truetype('data/DejaVuSans.ttf', size=size) 26 | nc = int(40 * (wh[0] / 256)) 27 | lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc)) 28 | 29 | try: 30 | draw.text((0, 0), lines, fill="black", font=font) 31 | except UnicodeEncodeError: 32 | print("Cant encode string for logging. Skipping.") 33 | 34 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 35 | txts.append(txt) 36 | txts = np.stack(txts) 37 | txts = torch.tensor(txts) 38 | return txts 39 | 40 | 41 | def ismap(x): 42 | if not isinstance(x, torch.Tensor): 43 | return False 44 | return (len(x.shape) == 4) and (x.shape[1] > 3) 45 | 46 | 47 | def isimage(x): 48 | if not isinstance(x, torch.Tensor): 49 | return False 50 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) 51 | 52 | 53 | def exists(x): 54 | return x is not None 55 | 56 | 57 | def default(val, d): 58 | if exists(val): 59 | return val 60 | return d() if isfunction(d) else d 61 | 62 | 63 | def mean_flat(tensor): 64 | """ 65 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 66 | Take the mean over all non-batch dimensions. 
67 | """ 68 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 69 | 70 | 71 | def count_params(model, verbose=False): 72 | total_params = sum(p.numel() for p in model.parameters()) 73 | if verbose: 74 | print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") 75 | return total_params 76 | 77 | 78 | def instantiate_from_config(config): 79 | if not "target" in config: 80 | if config == '__is_first_stage__': 81 | return None 82 | elif config == "__is_unconditional__": 83 | return None 84 | raise KeyError("Expected key `target` to instantiate.") 85 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 86 | 87 | 88 | def get_obj_from_str(string, reload=False): 89 | module, cls = string.rsplit(".", 1) 90 | if reload: 91 | module_imp = importlib.import_module(module) 92 | importlib.reload(module_imp) 93 | return getattr(importlib.import_module(module, package=None), cls) 94 | 95 | 96 | def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False): 97 | # create dummy dataset instance 98 | 99 | # run prefetching 100 | if idx_to_fn: 101 | res = func(data, worker_id=idx) 102 | else: 103 | res = func(data) 104 | Q.put([idx, res]) 105 | Q.put("Done") 106 | 107 | 108 | def parallel_data_prefetch( 109 | func: callable, data, n_proc, target_data_type="ndarray", cpu_intensive=True, use_worker_id=False 110 | ): 111 | # if target_data_type not in ["ndarray", "list"]: 112 | # raise ValueError( 113 | # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray." 114 | # ) 115 | if isinstance(data, np.ndarray) and target_data_type == "list": 116 | raise ValueError("list expected but function got ndarray.") 117 | elif isinstance(data, abc.Iterable): 118 | if isinstance(data, dict): 119 | print( 120 | f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.' 
121 | ) 122 | data = list(data.values()) 123 | if target_data_type == "ndarray": 124 | data = np.asarray(data) 125 | else: 126 | data = list(data) 127 | else: 128 | raise TypeError( 129 | f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}." 130 | ) 131 | 132 | if cpu_intensive: 133 | Q = mp.Queue(1000) 134 | proc = mp.Process 135 | else: 136 | Q = Queue(1000) 137 | proc = Thread 138 | # spawn processes 139 | if target_data_type == "ndarray": 140 | arguments = [ 141 | [func, Q, part, i, use_worker_id] 142 | for i, part in enumerate(np.array_split(data, n_proc)) 143 | ] 144 | else: 145 | step = ( 146 | int(len(data) / n_proc + 1) 147 | if len(data) % n_proc != 0 148 | else int(len(data) / n_proc) 149 | ) 150 | arguments = [ 151 | [func, Q, part, i, use_worker_id] 152 | for i, part in enumerate( 153 | [data[i: i + step] for i in range(0, len(data), step)] 154 | ) 155 | ] 156 | processes = [] 157 | for i in range(n_proc): 158 | p = proc(target=_do_parallel_data_prefetch, args=arguments[i]) 159 | processes += [p] 160 | 161 | # start processes 162 | print(f"Start prefetching...") 163 | import time 164 | 165 | start = time.time() 166 | gather_res = [[] for _ in range(n_proc)] 167 | try: 168 | for p in processes: 169 | p.start() 170 | 171 | k = 0 172 | while k < n_proc: 173 | # get result 174 | res = Q.get() 175 | if res == "Done": 176 | k += 1 177 | else: 178 | gather_res[res[0]] = res[1] 179 | 180 | except Exception as e: 181 | print("Exception: ", e) 182 | for p in processes: 183 | p.terminate() 184 | 185 | raise e 186 | finally: 187 | for p in processes: 188 | p.join() 189 | print(f"Prefetching complete. 
[{time.time() - start} sec.]") 190 | 191 | if target_data_type == 'ndarray': 192 | if not isinstance(gather_res[0], np.ndarray): 193 | return np.concatenate([np.asarray(r) for r in gather_res], axis=0) 194 | 195 | # order outputs 196 | return np.concatenate(gather_res, axis=0) 197 | elif target_data_type == 'list': 198 | out = [] 199 | for r in gather_res: 200 | out.extend(r) 201 | return out 202 | else: 203 | return gather_res 204 | -------------------------------------------------------------------------------- /stable-diffusion/models/first_stage_models/kl-f16/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: val/rec_loss 6 | embed_dim: 16 7 | lossconfig: 8 | target: ldm.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 1.0e-06 12 | disc_weight: 0.5 13 | ddconfig: 14 | double_z: true 15 | z_channels: 16 16 | resolution: 256 17 | in_channels: 3 18 | out_ch: 3 19 | ch: 128 20 | ch_mult: 21 | - 1 22 | - 1 23 | - 2 24 | - 2 25 | - 4 26 | num_res_blocks: 2 27 | attn_resolutions: 28 | - 16 29 | dropout: 0.0 30 | data: 31 | target: main.DataModuleFromConfig 32 | params: 33 | batch_size: 6 34 | wrap: true 35 | train: 36 | target: ldm.data.openimages.FullOpenImagesTrain 37 | params: 38 | size: 384 39 | crop_size: 256 40 | validation: 41 | target: ldm.data.openimages.FullOpenImagesValidation 42 | params: 43 | size: 384 44 | crop_size: 256 45 | -------------------------------------------------------------------------------- /stable-diffusion/models/first_stage_models/kl-f32/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: val/rec_loss 6 | embed_dim: 64 7 | lossconfig: 8 | target: 
ldm.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 1.0e-06 12 | disc_weight: 0.5 13 | ddconfig: 14 | double_z: true 15 | z_channels: 64 16 | resolution: 256 17 | in_channels: 3 18 | out_ch: 3 19 | ch: 128 20 | ch_mult: 21 | - 1 22 | - 1 23 | - 2 24 | - 2 25 | - 4 26 | - 4 27 | num_res_blocks: 2 28 | attn_resolutions: 29 | - 16 30 | - 8 31 | dropout: 0.0 32 | data: 33 | target: main.DataModuleFromConfig 34 | params: 35 | batch_size: 6 36 | wrap: true 37 | train: 38 | target: ldm.data.openimages.FullOpenImagesTrain 39 | params: 40 | size: 384 41 | crop_size: 256 42 | validation: 43 | target: ldm.data.openimages.FullOpenImagesValidation 44 | params: 45 | size: 384 46 | crop_size: 256 47 | -------------------------------------------------------------------------------- /stable-diffusion/models/first_stage_models/kl-f4/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: val/rec_loss 6 | embed_dim: 3 7 | lossconfig: 8 | target: ldm.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 1.0e-06 12 | disc_weight: 0.5 13 | ddconfig: 14 | double_z: true 15 | z_channels: 3 16 | resolution: 256 17 | in_channels: 3 18 | out_ch: 3 19 | ch: 128 20 | ch_mult: 21 | - 1 22 | - 2 23 | - 4 24 | num_res_blocks: 2 25 | attn_resolutions: [] 26 | dropout: 0.0 27 | data: 28 | target: main.DataModuleFromConfig 29 | params: 30 | batch_size: 10 31 | wrap: true 32 | train: 33 | target: ldm.data.openimages.FullOpenImagesTrain 34 | params: 35 | size: 384 36 | crop_size: 256 37 | validation: 38 | target: ldm.data.openimages.FullOpenImagesValidation 39 | params: 40 | size: 384 41 | crop_size: 256 42 | -------------------------------------------------------------------------------- /stable-diffusion/models/first_stage_models/kl-f8/config.yaml: 
-------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: val/rec_loss 6 | embed_dim: 4 7 | lossconfig: 8 | target: ldm.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 1.0e-06 12 | disc_weight: 0.5 13 | ddconfig: 14 | double_z: true 15 | z_channels: 4 16 | resolution: 256 17 | in_channels: 3 18 | out_ch: 3 19 | ch: 128 20 | ch_mult: 21 | - 1 22 | - 2 23 | - 4 24 | - 4 25 | num_res_blocks: 2 26 | attn_resolutions: [] 27 | dropout: 0.0 28 | data: 29 | target: main.DataModuleFromConfig 30 | params: 31 | batch_size: 4 32 | wrap: true 33 | train: 34 | target: ldm.data.openimages.FullOpenImagesTrain 35 | params: 36 | size: 384 37 | crop_size: 256 38 | validation: 39 | target: ldm.data.openimages.FullOpenImagesValidation 40 | params: 41 | size: 384 42 | crop_size: 256 43 | -------------------------------------------------------------------------------- /stable-diffusion/models/first_stage_models/vq-f16/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: ldm.models.autoencoder.VQModel 4 | params: 5 | embed_dim: 8 6 | n_embed: 16384 7 | ddconfig: 8 | double_z: false 9 | z_channels: 8 10 | resolution: 256 11 | in_channels: 3 12 | out_ch: 3 13 | ch: 128 14 | ch_mult: 15 | - 1 16 | - 1 17 | - 2 18 | - 2 19 | - 4 20 | num_res_blocks: 2 21 | attn_resolutions: 22 | - 16 23 | dropout: 0.0 24 | lossconfig: 25 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator 26 | params: 27 | disc_conditional: false 28 | disc_in_channels: 3 29 | disc_start: 250001 30 | disc_weight: 0.75 31 | disc_num_layers: 2 32 | codebook_weight: 1.0 33 | 34 | data: 35 | target: main.DataModuleFromConfig 36 | params: 37 | batch_size: 14 38 | num_workers: 20 39 | wrap: true 40 | train: 41 | target: 
ldm.data.openimages.FullOpenImagesTrain 42 | params: 43 | size: 384 44 | crop_size: 256 45 | validation: 46 | target: ldm.data.openimages.FullOpenImagesValidation 47 | params: 48 | size: 384 49 | crop_size: 256 50 | -------------------------------------------------------------------------------- /stable-diffusion/models/first_stage_models/vq-f4-noattn/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: ldm.models.autoencoder.VQModel 4 | params: 5 | embed_dim: 3 6 | n_embed: 8192 7 | monitor: val/rec_loss 8 | 9 | ddconfig: 10 | attn_type: none 11 | double_z: false 12 | z_channels: 3 13 | resolution: 256 14 | in_channels: 3 15 | out_ch: 3 16 | ch: 128 17 | ch_mult: 18 | - 1 19 | - 2 20 | - 4 21 | num_res_blocks: 2 22 | attn_resolutions: [] 23 | dropout: 0.0 24 | lossconfig: 25 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator 26 | params: 27 | disc_conditional: false 28 | disc_in_channels: 3 29 | disc_start: 11 30 | disc_weight: 0.75 31 | codebook_weight: 1.0 32 | 33 | data: 34 | target: main.DataModuleFromConfig 35 | params: 36 | batch_size: 8 37 | num_workers: 12 38 | wrap: true 39 | train: 40 | target: ldm.data.openimages.FullOpenImagesTrain 41 | params: 42 | crop_size: 256 43 | validation: 44 | target: ldm.data.openimages.FullOpenImagesValidation 45 | params: 46 | crop_size: 256 47 | -------------------------------------------------------------------------------- /stable-diffusion/models/first_stage_models/vq-f4/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: ldm.models.autoencoder.VQModel 4 | params: 5 | embed_dim: 3 6 | n_embed: 8192 7 | monitor: val/rec_loss 8 | 9 | ddconfig: 10 | double_z: false 11 | z_channels: 3 12 | resolution: 256 13 | in_channels: 3 14 | out_ch: 3 15 | ch: 128 16 | ch_mult: 17 | - 1 18 | - 2 19 | - 4 20 | num_res_blocks: 2 
21 | attn_resolutions: [] 22 | dropout: 0.0 23 | lossconfig: 24 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator 25 | params: 26 | disc_conditional: false 27 | disc_in_channels: 3 28 | disc_start: 0 29 | disc_weight: 0.75 30 | codebook_weight: 1.0 31 | 32 | data: 33 | target: main.DataModuleFromConfig 34 | params: 35 | batch_size: 8 36 | num_workers: 16 37 | wrap: true 38 | train: 39 | target: ldm.data.openimages.FullOpenImagesTrain 40 | params: 41 | crop_size: 256 42 | validation: 43 | target: ldm.data.openimages.FullOpenImagesValidation 44 | params: 45 | crop_size: 256 46 | -------------------------------------------------------------------------------- /stable-diffusion/models/first_stage_models/vq-f8-n256/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: ldm.models.autoencoder.VQModel 4 | params: 5 | embed_dim: 4 6 | n_embed: 256 7 | monitor: val/rec_loss 8 | ddconfig: 9 | double_z: false 10 | z_channels: 4 11 | resolution: 256 12 | in_channels: 3 13 | out_ch: 3 14 | ch: 128 15 | ch_mult: 16 | - 1 17 | - 2 18 | - 2 19 | - 4 20 | num_res_blocks: 2 21 | attn_resolutions: 22 | - 32 23 | dropout: 0.0 24 | lossconfig: 25 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator 26 | params: 27 | disc_conditional: false 28 | disc_in_channels: 3 29 | disc_start: 250001 30 | disc_weight: 0.75 31 | codebook_weight: 1.0 32 | 33 | data: 34 | target: main.DataModuleFromConfig 35 | params: 36 | batch_size: 10 37 | num_workers: 20 38 | wrap: true 39 | train: 40 | target: ldm.data.openimages.FullOpenImagesTrain 41 | params: 42 | size: 384 43 | crop_size: 256 44 | validation: 45 | target: ldm.data.openimages.FullOpenImagesValidation 46 | params: 47 | size: 384 48 | crop_size: 256 49 | -------------------------------------------------------------------------------- /stable-diffusion/models/first_stage_models/vq-f8/config.yaml: 
-------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: ldm.models.autoencoder.VQModel 4 | params: 5 | embed_dim: 4 6 | n_embed: 16384 7 | monitor: val/rec_loss 8 | ddconfig: 9 | double_z: false 10 | z_channels: 4 11 | resolution: 256 12 | in_channels: 3 13 | out_ch: 3 14 | ch: 128 15 | ch_mult: 16 | - 1 17 | - 2 18 | - 2 19 | - 4 20 | num_res_blocks: 2 21 | attn_resolutions: 22 | - 32 23 | dropout: 0.0 24 | lossconfig: 25 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator 26 | params: 27 | disc_conditional: false 28 | disc_in_channels: 3 29 | disc_num_layers: 2 30 | disc_start: 1 31 | disc_weight: 0.6 32 | codebook_weight: 1.0 33 | data: 34 | target: main.DataModuleFromConfig 35 | params: 36 | batch_size: 10 37 | num_workers: 20 38 | wrap: true 39 | train: 40 | target: ldm.data.openimages.FullOpenImagesTrain 41 | params: 42 | size: 384 43 | crop_size: 256 44 | validation: 45 | target: ldm.data.openimages.FullOpenImagesValidation 46 | params: 47 | size: 384 48 | crop_size: 256 49 | -------------------------------------------------------------------------------- /stable-diffusion/models/ldm/bsr_sr/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0155 7 | log_every_t: 100 8 | timesteps: 1000 9 | loss_type: l2 10 | first_stage_key: image 11 | cond_stage_key: LR_image 12 | image_size: 64 13 | channels: 3 14 | concat_mode: true 15 | cond_stage_trainable: false 16 | unet_config: 17 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 18 | params: 19 | image_size: 64 20 | in_channels: 6 21 | out_channels: 3 22 | model_channels: 160 23 | attention_resolutions: 24 | - 16 25 | - 8 26 | num_res_blocks: 2 27 | channel_mult: 28 | - 1 29 | - 2 30 | - 2 31 | - 4 32 | 
num_head_channels: 32 33 | first_stage_config: 34 | target: ldm.models.autoencoder.VQModelInterface 35 | params: 36 | embed_dim: 3 37 | n_embed: 8192 38 | monitor: val/rec_loss 39 | ddconfig: 40 | double_z: false 41 | z_channels: 3 42 | resolution: 256 43 | in_channels: 3 44 | out_ch: 3 45 | ch: 128 46 | ch_mult: 47 | - 1 48 | - 2 49 | - 4 50 | num_res_blocks: 2 51 | attn_resolutions: [] 52 | dropout: 0.0 53 | lossconfig: 54 | target: torch.nn.Identity 55 | cond_stage_config: 56 | target: torch.nn.Identity 57 | data: 58 | target: main.DataModuleFromConfig 59 | params: 60 | batch_size: 64 61 | wrap: false 62 | num_workers: 12 63 | train: 64 | target: ldm.data.openimages.SuperresOpenImagesAdvancedTrain 65 | params: 66 | size: 256 67 | degradation: bsrgan_light 68 | downscale_f: 4 69 | min_crop_f: 0.5 70 | max_crop_f: 1.0 71 | random_crop: true 72 | validation: 73 | target: ldm.data.openimages.SuperresOpenImagesAdvancedValidation 74 | params: 75 | size: 256 76 | degradation: bsrgan_light 77 | downscale_f: 4 78 | min_crop_f: 0.5 79 | max_crop_f: 1.0 80 | random_crop: true 81 | -------------------------------------------------------------------------------- /stable-diffusion/models/ldm/celeba256/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 2.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: class_label 12 | image_size: 64 13 | channels: 3 14 | cond_stage_trainable: false 15 | concat_mode: false 16 | monitor: val/loss 17 | unet_config: 18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 19 | params: 20 | image_size: 64 21 | in_channels: 3 22 | out_channels: 3 23 | model_channels: 224 24 | attention_resolutions: 25 | - 8 26 | - 4 27 | - 2 28 | num_res_blocks: 2 29 | channel_mult: 30 | - 1 31 | - 
2 32 | - 3 33 | - 4 34 | num_head_channels: 32 35 | first_stage_config: 36 | target: ldm.models.autoencoder.VQModelInterface 37 | params: 38 | embed_dim: 3 39 | n_embed: 8192 40 | ddconfig: 41 | double_z: false 42 | z_channels: 3 43 | resolution: 256 44 | in_channels: 3 45 | out_ch: 3 46 | ch: 128 47 | ch_mult: 48 | - 1 49 | - 2 50 | - 4 51 | num_res_blocks: 2 52 | attn_resolutions: [] 53 | dropout: 0.0 54 | lossconfig: 55 | target: torch.nn.Identity 56 | cond_stage_config: __is_unconditional__ 57 | data: 58 | target: main.DataModuleFromConfig 59 | params: 60 | batch_size: 48 61 | num_workers: 5 62 | wrap: false 63 | train: 64 | target: ldm.data.faceshq.CelebAHQTrain 65 | params: 66 | size: 256 67 | validation: 68 | target: ldm.data.faceshq.CelebAHQValidation 69 | params: 70 | size: 256 71 | -------------------------------------------------------------------------------- /stable-diffusion/models/ldm/cin256/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: class_label 12 | image_size: 32 13 | channels: 4 14 | cond_stage_trainable: true 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | unet_config: 18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 19 | params: 20 | image_size: 32 21 | in_channels: 4 22 | out_channels: 4 23 | model_channels: 256 24 | attention_resolutions: 25 | - 4 26 | - 2 27 | - 1 28 | num_res_blocks: 2 29 | channel_mult: 30 | - 1 31 | - 2 32 | - 4 33 | num_head_channels: 32 34 | use_spatial_transformer: true 35 | transformer_depth: 1 36 | context_dim: 512 37 | first_stage_config: 38 | target: ldm.models.autoencoder.VQModelInterface 39 | params: 40 | embed_dim: 4 41 | n_embed: 16384 42 | ddconfig: 43 | 
double_z: false 44 | z_channels: 4 45 | resolution: 256 46 | in_channels: 3 47 | out_ch: 3 48 | ch: 128 49 | ch_mult: 50 | - 1 51 | - 2 52 | - 2 53 | - 4 54 | num_res_blocks: 2 55 | attn_resolutions: 56 | - 32 57 | dropout: 0.0 58 | lossconfig: 59 | target: torch.nn.Identity 60 | cond_stage_config: 61 | target: ldm.modules.encoders.modules.ClassEmbedder 62 | params: 63 | embed_dim: 512 64 | key: class_label 65 | data: 66 | target: main.DataModuleFromConfig 67 | params: 68 | batch_size: 64 69 | num_workers: 12 70 | wrap: false 71 | train: 72 | target: ldm.data.imagenet.ImageNetTrain 73 | params: 74 | config: 75 | size: 256 76 | validation: 77 | target: ldm.data.imagenet.ImageNetValidation 78 | params: 79 | config: 80 | size: 256 81 | -------------------------------------------------------------------------------- /stable-diffusion/models/ldm/ffhq256/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 2.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: class_label 12 | image_size: 64 13 | channels: 3 14 | cond_stage_trainable: false 15 | concat_mode: false 16 | monitor: val/loss 17 | unet_config: 18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 19 | params: 20 | image_size: 64 21 | in_channels: 3 22 | out_channels: 3 23 | model_channels: 224 24 | attention_resolutions: 25 | - 8 26 | - 4 27 | - 2 28 | num_res_blocks: 2 29 | channel_mult: 30 | - 1 31 | - 2 32 | - 3 33 | - 4 34 | num_head_channels: 32 35 | first_stage_config: 36 | target: ldm.models.autoencoder.VQModelInterface 37 | params: 38 | embed_dim: 3 39 | n_embed: 8192 40 | ddconfig: 41 | double_z: false 42 | z_channels: 3 43 | resolution: 256 44 | in_channels: 3 45 | out_ch: 3 46 | ch: 128 47 | ch_mult: 48 | - 1 49 | - 2 50 | - 4 51 | 
num_res_blocks: 2 52 | attn_resolutions: [] 53 | dropout: 0.0 54 | lossconfig: 55 | target: torch.nn.Identity 56 | cond_stage_config: __is_unconditional__ 57 | data: 58 | target: main.DataModuleFromConfig 59 | params: 60 | batch_size: 42 61 | num_workers: 5 62 | wrap: false 63 | train: 64 | target: ldm.data.faceshq.FFHQTrain 65 | params: 66 | size: 256 67 | validation: 68 | target: ldm.data.faceshq.FFHQValidation 69 | params: 70 | size: 256 71 | -------------------------------------------------------------------------------- /stable-diffusion/models/ldm/inpainting_big/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0205 7 | log_every_t: 100 8 | timesteps: 1000 9 | loss_type: l1 10 | first_stage_key: image 11 | cond_stage_key: masked_image 12 | image_size: 64 13 | channels: 3 14 | concat_mode: true 15 | monitor: val/loss 16 | scheduler_config: 17 | target: ldm.lr_scheduler.LambdaWarmUpCosineScheduler 18 | params: 19 | verbosity_interval: 0 20 | warm_up_steps: 1000 21 | max_decay_steps: 50000 22 | lr_start: 0.001 23 | lr_max: 0.1 24 | lr_min: 0.0001 25 | unet_config: 26 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 27 | params: 28 | image_size: 64 29 | in_channels: 7 30 | out_channels: 3 31 | model_channels: 256 32 | attention_resolutions: 33 | - 8 34 | - 4 35 | - 2 36 | num_res_blocks: 2 37 | channel_mult: 38 | - 1 39 | - 2 40 | - 3 41 | - 4 42 | num_heads: 8 43 | resblock_updown: true 44 | first_stage_config: 45 | target: ldm.models.autoencoder.VQModelInterface 46 | params: 47 | embed_dim: 3 48 | n_embed: 8192 49 | monitor: val/rec_loss 50 | ddconfig: 51 | attn_type: none 52 | double_z: false 53 | z_channels: 3 54 | resolution: 256 55 | in_channels: 3 56 | out_ch: 3 57 | ch: 128 58 | ch_mult: 59 | - 1 60 | - 2 61 | - 4 62 | num_res_blocks: 2 63 | 
attn_resolutions: [] 64 | dropout: 0.0 65 | lossconfig: 66 | target: ldm.modules.losses.contperceptual.DummyLoss 67 | cond_stage_config: __is_first_stage__ 68 | -------------------------------------------------------------------------------- /stable-diffusion/models/ldm/layout2img-openimages256/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 2.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0205 7 | log_every_t: 100 8 | timesteps: 1000 9 | loss_type: l1 10 | first_stage_key: image 11 | cond_stage_key: coordinates_bbox 12 | image_size: 64 13 | channels: 3 14 | conditioning_key: crossattn 15 | cond_stage_trainable: true 16 | unet_config: 17 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 18 | params: 19 | image_size: 64 20 | in_channels: 3 21 | out_channels: 3 22 | model_channels: 128 23 | attention_resolutions: 24 | - 8 25 | - 4 26 | - 2 27 | num_res_blocks: 2 28 | channel_mult: 29 | - 1 30 | - 2 31 | - 3 32 | - 4 33 | num_head_channels: 32 34 | use_spatial_transformer: true 35 | transformer_depth: 3 36 | context_dim: 512 37 | first_stage_config: 38 | target: ldm.models.autoencoder.VQModelInterface 39 | params: 40 | embed_dim: 3 41 | n_embed: 8192 42 | monitor: val/rec_loss 43 | ddconfig: 44 | double_z: false 45 | z_channels: 3 46 | resolution: 256 47 | in_channels: 3 48 | out_ch: 3 49 | ch: 128 50 | ch_mult: 51 | - 1 52 | - 2 53 | - 4 54 | num_res_blocks: 2 55 | attn_resolutions: [] 56 | dropout: 0.0 57 | lossconfig: 58 | target: torch.nn.Identity 59 | cond_stage_config: 60 | target: ldm.modules.encoders.modules.BERTEmbedder 61 | params: 62 | n_embed: 512 63 | n_layer: 16 64 | vocab_size: 8192 65 | max_seq_len: 92 66 | use_tokenizer: false 67 | monitor: val/loss_simple_ema 68 | data: 69 | target: main.DataModuleFromConfig 70 | params: 71 | batch_size: 24 72 | wrap: false 73 | num_workers: 10 74 | train: 75 
| target: ldm.data.openimages.OpenImagesBBoxTrain 76 | params: 77 | size: 256 78 | validation: 79 | target: ldm.data.openimages.OpenImagesBBoxValidation 80 | params: 81 | size: 256 82 | -------------------------------------------------------------------------------- /stable-diffusion/models/ldm/lsun_beds256/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 2.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: class_label 12 | image_size: 64 13 | channels: 3 14 | cond_stage_trainable: false 15 | concat_mode: false 16 | monitor: val/loss 17 | unet_config: 18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 19 | params: 20 | image_size: 64 21 | in_channels: 3 22 | out_channels: 3 23 | model_channels: 224 24 | attention_resolutions: 25 | - 8 26 | - 4 27 | - 2 28 | num_res_blocks: 2 29 | channel_mult: 30 | - 1 31 | - 2 32 | - 3 33 | - 4 34 | num_head_channels: 32 35 | first_stage_config: 36 | target: ldm.models.autoencoder.VQModelInterface 37 | params: 38 | embed_dim: 3 39 | n_embed: 8192 40 | ddconfig: 41 | double_z: false 42 | z_channels: 3 43 | resolution: 256 44 | in_channels: 3 45 | out_ch: 3 46 | ch: 128 47 | ch_mult: 48 | - 1 49 | - 2 50 | - 4 51 | num_res_blocks: 2 52 | attn_resolutions: [] 53 | dropout: 0.0 54 | lossconfig: 55 | target: torch.nn.Identity 56 | cond_stage_config: __is_unconditional__ 57 | data: 58 | target: main.DataModuleFromConfig 59 | params: 60 | batch_size: 48 61 | num_workers: 5 62 | wrap: false 63 | train: 64 | target: ldm.data.lsun.LSUNBedroomsTrain 65 | params: 66 | size: 256 67 | validation: 68 | target: ldm.data.lsun.LSUNBedroomsValidation 69 | params: 70 | size: 256 71 | -------------------------------------------------------------------------------- 
/stable-diffusion/models/ldm/lsun_churches256/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 5.0e-05 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0155 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | loss_type: l1 11 | first_stage_key: image 12 | cond_stage_key: image 13 | image_size: 32 14 | channels: 4 15 | cond_stage_trainable: false 16 | concat_mode: false 17 | scale_by_std: true 18 | monitor: val/loss_simple_ema 19 | scheduler_config: 20 | target: ldm.lr_scheduler.LambdaLinearScheduler 21 | params: 22 | warm_up_steps: 23 | - 10000 24 | cycle_lengths: 25 | - 10000000000000 26 | f_start: 27 | - 1.0e-06 28 | f_max: 29 | - 1.0 30 | f_min: 31 | - 1.0 32 | unet_config: 33 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 34 | params: 35 | image_size: 32 36 | in_channels: 4 37 | out_channels: 4 38 | model_channels: 192 39 | attention_resolutions: 40 | - 1 41 | - 2 42 | - 4 43 | - 8 44 | num_res_blocks: 2 45 | channel_mult: 46 | - 1 47 | - 2 48 | - 2 49 | - 4 50 | - 4 51 | num_heads: 8 52 | use_scale_shift_norm: true 53 | resblock_updown: true 54 | first_stage_config: 55 | target: ldm.models.autoencoder.AutoencoderKL 56 | params: 57 | embed_dim: 4 58 | monitor: val/rec_loss 59 | ddconfig: 60 | double_z: true 61 | z_channels: 4 62 | resolution: 256 63 | in_channels: 3 64 | out_ch: 3 65 | ch: 128 66 | ch_mult: 67 | - 1 68 | - 2 69 | - 4 70 | - 4 71 | num_res_blocks: 2 72 | attn_resolutions: [] 73 | dropout: 0.0 74 | lossconfig: 75 | target: torch.nn.Identity 76 | 77 | cond_stage_config: '__is_unconditional__' 78 | 79 | data: 80 | target: main.DataModuleFromConfig 81 | params: 82 | batch_size: 96 83 | num_workers: 5 84 | wrap: false 85 | train: 86 | target: ldm.data.lsun.LSUNChurchesTrain 87 | params: 88 | size: 256 89 | validation: 90 | target: ldm.data.lsun.LSUNChurchesValidation 91 | params: 
92 | size: 256 93 | -------------------------------------------------------------------------------- /stable-diffusion/models/ldm/semantic_synthesis256/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0205 7 | log_every_t: 100 8 | timesteps: 1000 9 | loss_type: l1 10 | first_stage_key: image 11 | cond_stage_key: segmentation 12 | image_size: 64 13 | channels: 3 14 | concat_mode: true 15 | cond_stage_trainable: true 16 | unet_config: 17 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 18 | params: 19 | image_size: 64 20 | in_channels: 6 21 | out_channels: 3 22 | model_channels: 128 23 | attention_resolutions: 24 | - 32 25 | - 16 26 | - 8 27 | num_res_blocks: 2 28 | channel_mult: 29 | - 1 30 | - 4 31 | - 8 32 | num_heads: 8 33 | first_stage_config: 34 | target: ldm.models.autoencoder.VQModelInterface 35 | params: 36 | embed_dim: 3 37 | n_embed: 8192 38 | ddconfig: 39 | double_z: false 40 | z_channels: 3 41 | resolution: 256 42 | in_channels: 3 43 | out_ch: 3 44 | ch: 128 45 | ch_mult: 46 | - 1 47 | - 2 48 | - 4 49 | num_res_blocks: 2 50 | attn_resolutions: [] 51 | dropout: 0.0 52 | lossconfig: 53 | target: torch.nn.Identity 54 | cond_stage_config: 55 | target: ldm.modules.encoders.modules.SpatialRescaler 56 | params: 57 | n_stages: 2 58 | in_channels: 182 59 | out_channels: 3 60 | -------------------------------------------------------------------------------- /stable-diffusion/models/ldm/semantic_synthesis512/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0205 7 | log_every_t: 100 8 | timesteps: 1000 9 | loss_type: l1 10 | first_stage_key: image 11 | cond_stage_key: 
segmentation 12 | image_size: 128 13 | channels: 3 14 | concat_mode: true 15 | cond_stage_trainable: true 16 | unet_config: 17 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 18 | params: 19 | image_size: 128 20 | in_channels: 6 21 | out_channels: 3 22 | model_channels: 128 23 | attention_resolutions: 24 | - 32 25 | - 16 26 | - 8 27 | num_res_blocks: 2 28 | channel_mult: 29 | - 1 30 | - 4 31 | - 8 32 | num_heads: 8 33 | first_stage_config: 34 | target: ldm.models.autoencoder.VQModelInterface 35 | params: 36 | embed_dim: 3 37 | n_embed: 8192 38 | monitor: val/rec_loss 39 | ddconfig: 40 | double_z: false 41 | z_channels: 3 42 | resolution: 256 43 | in_channels: 3 44 | out_ch: 3 45 | ch: 128 46 | ch_mult: 47 | - 1 48 | - 2 49 | - 4 50 | num_res_blocks: 2 51 | attn_resolutions: [] 52 | dropout: 0.0 53 | lossconfig: 54 | target: torch.nn.Identity 55 | cond_stage_config: 56 | target: ldm.modules.encoders.modules.SpatialRescaler 57 | params: 58 | n_stages: 2 59 | in_channels: 182 60 | out_channels: 3 61 | data: 62 | target: main.DataModuleFromConfig 63 | params: 64 | batch_size: 8 65 | wrap: false 66 | num_workers: 10 67 | train: 68 | target: ldm.data.landscapes.RFWTrain 69 | params: 70 | size: 768 71 | crop_size: 512 72 | segmentation_to_float32: true 73 | validation: 74 | target: ldm.data.landscapes.RFWValidation 75 | params: 76 | size: 768 77 | crop_size: 512 78 | segmentation_to_float32: true 79 | -------------------------------------------------------------------------------- /stable-diffusion/models/ldm/text2img256/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 2.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: caption 12 | image_size: 64 13 | channels: 3 14 | cond_stage_trainable: true 15 
| conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | unet_config: 18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 19 | params: 20 | image_size: 64 21 | in_channels: 3 22 | out_channels: 3 23 | model_channels: 192 24 | attention_resolutions: 25 | - 8 26 | - 4 27 | - 2 28 | num_res_blocks: 2 29 | channel_mult: 30 | - 1 31 | - 2 32 | - 3 33 | - 5 34 | num_head_channels: 32 35 | use_spatial_transformer: true 36 | transformer_depth: 1 37 | context_dim: 640 38 | first_stage_config: 39 | target: ldm.models.autoencoder.VQModelInterface 40 | params: 41 | embed_dim: 3 42 | n_embed: 8192 43 | ddconfig: 44 | double_z: false 45 | z_channels: 3 46 | resolution: 256 47 | in_channels: 3 48 | out_ch: 3 49 | ch: 128 50 | ch_mult: 51 | - 1 52 | - 2 53 | - 4 54 | num_res_blocks: 2 55 | attn_resolutions: [] 56 | dropout: 0.0 57 | lossconfig: 58 | target: torch.nn.Identity 59 | cond_stage_config: 60 | target: ldm.modules.encoders.modules.BERTEmbedder 61 | params: 62 | n_embed: 640 63 | n_layer: 32 64 | data: 65 | target: main.DataModuleFromConfig 66 | params: 67 | batch_size: 28 68 | num_workers: 10 69 | wrap: false 70 | train: 71 | target: ldm.data.previews.pytorch_dataset.PreviewsTrain 72 | params: 73 | size: 256 74 | validation: 75 | target: ldm.data.previews.pytorch_dataset.PreviewsValidation 76 | params: 77 | size: 256 78 | -------------------------------------------------------------------------------- /stable-diffusion/outputs/txt2img-samples/samples/00000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmin2000/WaDiff/33e657de73c3e75e82ea66bcec14012a2f4fcb6a/stable-diffusion/outputs/txt2img-samples/samples/00000.png -------------------------------------------------------------------------------- /stable-diffusion/outputs/txt2img-samples/samples/00001.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/rmin2000/WaDiff/33e657de73c3e75e82ea66bcec14012a2f4fcb6a/stable-diffusion/outputs/txt2img-samples/samples/00001.png -------------------------------------------------------------------------------- /stable-diffusion/outputs/txt2img-samples/samples/00002.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmin2000/WaDiff/33e657de73c3e75e82ea66bcec14012a2f4fcb6a/stable-diffusion/outputs/txt2img-samples/samples/00002.png -------------------------------------------------------------------------------- /stable-diffusion/outputs/txt2img-samples/samples/00003.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmin2000/WaDiff/33e657de73c3e75e82ea66bcec14012a2f4fcb6a/stable-diffusion/outputs/txt2img-samples/samples/00003.png -------------------------------------------------------------------------------- /stable-diffusion/outputs/txt2img-samples/samples/00004.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmin2000/WaDiff/33e657de73c3e75e82ea66bcec14012a2f4fcb6a/stable-diffusion/outputs/txt2img-samples/samples/00004.png -------------------------------------------------------------------------------- /stable-diffusion/outputs/txt2img-samples/samples/00005.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rmin2000/WaDiff/33e657de73c3e75e82ea66bcec14012a2f4fcb6a/stable-diffusion/outputs/txt2img-samples/samples/00005.png -------------------------------------------------------------------------------- /stable-diffusion/scripts/download_first_stages.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | wget -O models/first_stage_models/kl-f4/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f4.zip 3 | wget -O 
models/first_stage_models/kl-f8/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f8.zip 4 | wget -O models/first_stage_models/kl-f16/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f16.zip 5 | wget -O models/first_stage_models/kl-f32/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f32.zip 6 | wget -O models/first_stage_models/vq-f4/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f4.zip 7 | wget -O models/first_stage_models/vq-f4-noattn/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f4-noattn.zip 8 | wget -O models/first_stage_models/vq-f8/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f8.zip 9 | wget -O models/first_stage_models/vq-f8-n256/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f8-n256.zip 10 | wget -O models/first_stage_models/vq-f16/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f16.zip 11 | 12 | 13 | 14 | cd models/first_stage_models/kl-f4 15 | unzip -o model.zip 16 | 17 | cd ../kl-f8 18 | unzip -o model.zip 19 | 20 | cd ../kl-f16 21 | unzip -o model.zip 22 | 23 | cd ../kl-f32 24 | unzip -o model.zip 25 | 26 | cd ../vq-f4 27 | unzip -o model.zip 28 | 29 | cd ../vq-f4-noattn 30 | unzip -o model.zip 31 | 32 | cd ../vq-f8 33 | unzip -o model.zip 34 | 35 | cd ../vq-f8-n256 36 | unzip -o model.zip 37 | 38 | cd ../vq-f16 39 | unzip -o model.zip 40 | 41 | cd ../.. 
-------------------------------------------------------------------------------- /stable-diffusion/scripts/download_models.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | wget -O models/ldm/celeba256/celeba-256.zip https://ommer-lab.com/files/latent-diffusion/celeba.zip 3 | wget -O models/ldm/ffhq256/ffhq-256.zip https://ommer-lab.com/files/latent-diffusion/ffhq.zip 4 | wget -O models/ldm/lsun_churches256/lsun_churches-256.zip https://ommer-lab.com/files/latent-diffusion/lsun_churches.zip 5 | wget -O models/ldm/lsun_beds256/lsun_beds-256.zip https://ommer-lab.com/files/latent-diffusion/lsun_bedrooms.zip 6 | wget -O models/ldm/text2img256/model.zip https://ommer-lab.com/files/latent-diffusion/text2img.zip 7 | wget -O models/ldm/cin256/model.zip https://ommer-lab.com/files/latent-diffusion/cin.zip 8 | wget -O models/ldm/semantic_synthesis512/model.zip https://ommer-lab.com/files/latent-diffusion/semantic_synthesis.zip 9 | wget -O models/ldm/semantic_synthesis256/model.zip https://ommer-lab.com/files/latent-diffusion/semantic_synthesis256.zip 10 | wget -O models/ldm/bsr_sr/model.zip https://ommer-lab.com/files/latent-diffusion/sr_bsr.zip 11 | wget -O models/ldm/layout2img-openimages256/model.zip https://ommer-lab.com/files/latent-diffusion/layout2img_model.zip 12 | wget -O models/ldm/inpainting_big/model.zip https://ommer-lab.com/files/latent-diffusion/inpainting_big.zip 13 | 14 | 15 | 16 | cd models/ldm/celeba256 17 | unzip -o celeba-256.zip 18 | 19 | cd ../ffhq256 20 | unzip -o ffhq-256.zip 21 | 22 | cd ../lsun_churches256 23 | unzip -o lsun_churches-256.zip 24 | 25 | cd ../lsun_beds256 26 | unzip -o lsun_beds-256.zip 27 | 28 | cd ../text2img256 29 | unzip -o model.zip 30 | 31 | cd ../cin256 32 | unzip -o model.zip 33 | 34 | cd ../semantic_synthesis512 35 | unzip -o model.zip 36 | 37 | cd ../semantic_synthesis256 38 | unzip -o model.zip 39 | 40 | cd ../bsr_sr 41 | unzip -o model.zip 42 | 43 | cd 
../layout2img-openimages256 44 | unzip -o model.zip 45 | 46 | cd ../inpainting_big 47 | unzip -o model.zip 48 | 49 | cd ../.. 50 | -------------------------------------------------------------------------------- /stable-diffusion/scripts/inpaint.py: -------------------------------------------------------------------------------- 1 | import argparse, os, sys, glob 2 | from omegaconf import OmegaConf 3 | from PIL import Image 4 | from tqdm import tqdm 5 | import numpy as np 6 | import torch 7 | from main import instantiate_from_config 8 | from ldm.models.diffusion.ddim import DDIMSampler 9 | 10 | 11 | def make_batch(image, mask, device): 12 | image = np.array(Image.open(image).convert("RGB")) 13 | image = image.astype(np.float32)/255.0 14 | image = image[None].transpose(0,3,1,2) 15 | image = torch.from_numpy(image) 16 | 17 | mask = np.array(Image.open(mask).convert("L")) 18 | mask = mask.astype(np.float32)/255.0 19 | mask = mask[None,None] 20 | mask[mask < 0.5] = 0 21 | mask[mask >= 0.5] = 1 22 | mask = torch.from_numpy(mask) 23 | 24 | masked_image = (1-mask)*image 25 | 26 | batch = {"image": image, "mask": mask, "masked_image": masked_image} 27 | for k in batch: 28 | batch[k] = batch[k].to(device=device) 29 | batch[k] = batch[k]*2.0-1.0 30 | return batch 31 | 32 | 33 | if __name__ == "__main__": 34 | parser = argparse.ArgumentParser() 35 | parser.add_argument( 36 | "--indir", 37 | type=str, 38 | nargs="?", 39 | help="dir containing image-mask pairs (`example.png` and `example_mask.png`)", 40 | ) 41 | parser.add_argument( 42 | "--outdir", 43 | type=str, 44 | nargs="?", 45 | help="dir to write results to", 46 | ) 47 | parser.add_argument( 48 | "--steps", 49 | type=int, 50 | default=50, 51 | help="number of ddim sampling steps", 52 | ) 53 | opt = parser.parse_args() 54 | 55 | masks = sorted(glob.glob(os.path.join(opt.indir, "*_mask.png"))) 56 | images = [x.replace("_mask.png", ".png") for x in masks] 57 | print(f"Found {len(masks)} inputs.") 58 | 59 | config = 
OmegaConf.load("models/ldm/inpainting_big/config.yaml") 60 | model = instantiate_from_config(config.model) 61 | model.load_state_dict(torch.load("models/ldm/inpainting_big/last.ckpt")["state_dict"], 62 | strict=False) 63 | 64 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 65 | model = model.to(device) 66 | sampler = DDIMSampler(model) 67 | 68 | os.makedirs(opt.outdir, exist_ok=True) 69 | with torch.no_grad(): 70 | with model.ema_scope(): 71 | for image, mask in tqdm(zip(images, masks)): 72 | outpath = os.path.join(opt.outdir, os.path.split(image)[1]) 73 | batch = make_batch(image, mask, device=device) 74 | 75 | # encode masked image and concat downsampled mask 76 | c = model.cond_stage_model.encode(batch["masked_image"]) 77 | cc = torch.nn.functional.interpolate(batch["mask"], 78 | size=c.shape[-2:]) 79 | c = torch.cat((c, cc), dim=1) 80 | 81 | shape = (c.shape[1]-1,)+c.shape[2:] 82 | samples_ddim, _ = sampler.sample(S=opt.steps, 83 | conditioning=c, 84 | batch_size=c.shape[0], 85 | shape=shape, 86 | verbose=False) 87 | x_samples_ddim = model.decode_first_stage(samples_ddim) 88 | 89 | image = torch.clamp((batch["image"]+1.0)/2.0, 90 | min=0.0, max=1.0) 91 | mask = torch.clamp((batch["mask"]+1.0)/2.0, 92 | min=0.0, max=1.0) 93 | predicted_image = torch.clamp((x_samples_ddim+1.0)/2.0, 94 | min=0.0, max=1.0) 95 | 96 | inpainted = (1-mask)*image+mask*predicted_image 97 | inpainted = inpainted.cpu().numpy().transpose(0,2,3,1)[0]*255 98 | Image.fromarray(inpainted.astype(np.uint8)).save(outpath) 99 | -------------------------------------------------------------------------------- /stable-diffusion/scripts/tests/test_watermark.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import fire 3 | from imwatermark import WatermarkDecoder 4 | 5 | 6 | def testit(img_path): 7 | bgr = cv2.imread(img_path) 8 | decoder = WatermarkDecoder('bytes', 136) 9 | watermark = decoder.decode(bgr, 
'dwtDct') 10 | try: 11 | dec = watermark.decode('utf-8') 12 | except: 13 | dec = "null" 14 | print(dec) 15 | 16 | 17 | if __name__ == "__main__": 18 | fire.Fire(testit) -------------------------------------------------------------------------------- /stable-diffusion/scripts/train_searcher.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import numpy as np 3 | import scann 4 | import argparse 5 | import glob 6 | from multiprocessing import cpu_count 7 | from tqdm import tqdm 8 | 9 | from ldm.util import parallel_data_prefetch 10 | 11 | 12 | def search_bruteforce(searcher): 13 | return searcher.score_brute_force().build() 14 | 15 | 16 | def search_partioned_ah(searcher, dims_per_block, aiq_threshold, reorder_k, 17 | partioning_trainsize, num_leaves, num_leaves_to_search): 18 | return searcher.tree(num_leaves=num_leaves, 19 | num_leaves_to_search=num_leaves_to_search, 20 | training_sample_size=partioning_trainsize). \ 21 | score_ah(dims_per_block, anisotropic_quantization_threshold=aiq_threshold).reorder(reorder_k).build() 22 | 23 | 24 | def search_ah(searcher, dims_per_block, aiq_threshold, reorder_k): 25 | return searcher.score_ah(dims_per_block, anisotropic_quantization_threshold=aiq_threshold).reorder( 26 | reorder_k).build() 27 | 28 | def load_datapool(dpath): 29 | 30 | 31 | def load_single_file(saved_embeddings): 32 | compressed = np.load(saved_embeddings) 33 | database = {key: compressed[key] for key in compressed.files} 34 | return database 35 | 36 | def load_multi_files(data_archive): 37 | database = {key: [] for key in data_archive[0].files} 38 | for d in tqdm(data_archive, desc=f'Loading datapool from {len(data_archive)} individual files.'): 39 | for key in d.files: 40 | database[key].append(d[key]) 41 | 42 | return database 43 | 44 | print(f'Load saved patch embedding from "{dpath}"') 45 | file_content = glob.glob(os.path.join(dpath, '*.npz')) 46 | 47 | if len(file_content) == 1: 48 | data_pool = 
load_single_file(file_content[0]) 49 | elif len(file_content) > 1: 50 | data = [np.load(f) for f in file_content] 51 | prefetched_data = parallel_data_prefetch(load_multi_files, data, 52 | n_proc=min(len(data), cpu_count()), target_data_type='dict') 53 | 54 | data_pool = {key: np.concatenate([od[key] for od in prefetched_data], axis=1)[0] for key in prefetched_data[0].keys()} 55 | else: 56 | raise ValueError(f'No npz-files in specified path "{dpath}" is this directory existing?') 57 | 58 | print(f'Finished loading of retrieval database of length {data_pool["embedding"].shape[0]}.') 59 | return data_pool 60 | 61 | 62 | def train_searcher(opt, 63 | metric='dot_product', 64 | partioning_trainsize=None, 65 | reorder_k=None, 66 | # todo tune 67 | aiq_thld=0.2, 68 | dims_per_block=2, 69 | num_leaves=None, 70 | num_leaves_to_search=None,): 71 | 72 | data_pool = load_datapool(opt.database) 73 | k = opt.knn 74 | 75 | if not reorder_k: 76 | reorder_k = 2 * k 77 | 78 | # normalize 79 | # embeddings = 80 | searcher = scann.scann_ops_pybind.builder(data_pool['embedding'] / np.linalg.norm(data_pool['embedding'], axis=1)[:, np.newaxis], k, metric) 81 | pool_size = data_pool['embedding'].shape[0] 82 | 83 | print(*(['#'] * 100)) 84 | print('Initializing scaNN searcher with the following values:') 85 | print(f'k: {k}') 86 | print(f'metric: {metric}') 87 | print(f'reorder_k: {reorder_k}') 88 | print(f'anisotropic_quantization_threshold: {aiq_thld}') 89 | print(f'dims_per_block: {dims_per_block}') 90 | print(*(['#'] * 100)) 91 | print('Start training searcher....') 92 | print(f'N samples in pool is {pool_size}') 93 | 94 | # this reflects the recommended design choices proposed at 95 | # https://github.com/google-research/google-research/blob/aca5f2e44e301af172590bb8e65711f0c9ee0cfd/scann/docs/algorithms.md 96 | if pool_size < 2e4: 97 | print('Using brute force search.') 98 | searcher = search_bruteforce(searcher) 99 | elif 2e4 <= pool_size and pool_size < 1e5: 100 | print('Using 
asymmetric hashing search and reordering.') 101 | searcher = search_ah(searcher, dims_per_block, aiq_thld, reorder_k) 102 | else: 103 | print('Using using partioning, asymmetric hashing search and reordering.') 104 | 105 | if not partioning_trainsize: 106 | partioning_trainsize = data_pool['embedding'].shape[0] // 10 107 | if not num_leaves: 108 | num_leaves = int(np.sqrt(pool_size)) 109 | 110 | if not num_leaves_to_search: 111 | num_leaves_to_search = max(num_leaves // 20, 1) 112 | 113 | print('Partitioning params:') 114 | print(f'num_leaves: {num_leaves}') 115 | print(f'num_leaves_to_search: {num_leaves_to_search}') 116 | # self.searcher = self.search_ah(searcher, dims_per_block, aiq_thld, reorder_k) 117 | searcher = search_partioned_ah(searcher, dims_per_block, aiq_thld, reorder_k, 118 | partioning_trainsize, num_leaves, num_leaves_to_search) 119 | 120 | print('Finish training searcher') 121 | searcher_savedir = opt.target_path 122 | os.makedirs(searcher_savedir, exist_ok=True) 123 | searcher.serialize(searcher_savedir) 124 | print(f'Saved trained searcher under "{searcher_savedir}"') 125 | 126 | if __name__ == '__main__': 127 | sys.path.append(os.getcwd()) 128 | parser = argparse.ArgumentParser() 129 | parser.add_argument('--database', 130 | '-d', 131 | default='data/rdm/retrieval_databases/openimages', 132 | type=str, 133 | help='path to folder containing the clip feature of the database') 134 | parser.add_argument('--target_path', 135 | '-t', 136 | default='data/rdm/searchers/openimages', 137 | type=str, 138 | help='path to the target folder where the searcher shall be stored.') 139 | parser.add_argument('--knn', 140 | '-k', 141 | default=20, 142 | type=int, 143 | help='number of nearest neighbors, for which the searcher shall be optimized') 144 | 145 | opt, _ = parser.parse_known_args() 146 | 147 | train_searcher(opt,) -------------------------------------------------------------------------------- /stable-diffusion/setup.py: 
-------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='latent-diffusion', 5 | version='0.0.1', 6 | description='', 7 | packages=find_packages(), 8 | install_requires=[ 9 | 'torch', 10 | 'numpy', 11 | 'tqdm', 12 | ], 13 | ) -------------------------------------------------------------------------------- /trace.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import glob 3 | import argparse 4 | from StegaStamp import models 5 | from torchvision import transforms 6 | from PIL import Image 7 | import numpy as np 8 | 9 | 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument( 12 | "--image_path", type=str, default='./guided-diffusion/saved_images/', help="Directory of watermark examples." 13 | ) 14 | parser.add_argument( 15 | "--bit_length", type=int, default=48, help="Length of watermark bits." 16 | ) 17 | parser.add_argument( 18 | "--model_type", type=str, default='imagenet', choices=['imagenet', 'stable'], help="ImageNet Diffusion or Stable Diffusion." 19 | ) 20 | parser.add_argument( 21 | "--checkpoint", type=str, default='./', help="Checkpoint of the watermark decoder." 22 | ) 23 | parser.add_argument( 24 | "--device", type=int, default=0, help="GPU index." 25 | ) 26 | parser.add_argument( 27 | "--detection_thre", type=float, default=0.8, help="Bit threshold for detection." 
28 | ) 29 | 30 | 31 | args = parser.parse_args() 32 | 33 | 34 | 35 | def trace(image_path, decoder, detection_thre, device): 36 | 37 | count_pool_1e4, count_pool_1e5, count_pool_1e6, count_detection = 0, 0, 0, 0 38 | decoder.to(args.device) 39 | 40 | # Load pre-defined watermarks and watermarked images 41 | user_pool_1e4 = np.load(f'watermark_pool/{bit_length}_1e4.npy') 42 | user_pool_1e5 = np.load(f'watermark_pool/{bit_length}_1e5.npy') 43 | user_pool_1e6 = np.load(f'watermark_pool/{bit_length}_1e6.npy') 44 | 45 | image_path_list = glob.glob(image_path + '*/*.png') 46 | 47 | 48 | for path in image_path_list: 49 | img = transforms.ToTensor()(Image.open(path)).to(device) 50 | user_index = int(path.split('/')[-2]) 51 | 52 | fingerprints_predicted = (decoder(img) > 0).float().cpu().numpy() 53 | 54 | if 1 - np.abs(fingerprints_predicted - user_pool_1e4[user_index]).sum(0) / len(fingerprints_predicted) > detection_thre 55 | count_detection += 1 56 | if np.argmin(np.abs(fingerprints_predicted - user_pool_1e4).sum(0)) == user_index: 57 | count_pool_1e4 += 1 58 | if np.argmin(np.abs(fingerprints_predicted - user_pool_1e5).sum(0)) == user_index: 59 | count_pool_1e5 += 1 60 | if np.argmin(np.abs(fingerprints_predicted - user_pool_1e6).sum(0)) == user_index: 61 | count_pool_1e6 += 1 62 | 63 | return {'trace1e4': count_pool_1e4/1e4, 'trace1e5': count_pool_1e5/1e5, 'trace1e6': count_pool_1e6/1e6, 'detection_acc': count_detection/len(image_path_list)} 64 | 65 | 66 | 67 | 68 | if __name__ == "__main__": 69 | 70 | # Create watermark decoder 71 | decoder = models.StegaStampDecoder( 72 | 256 if args.model_type == 'imagenet' else 512, 73 | 3, 74 | args.bit_length, 75 | ) 76 | decoder.load_state_dict(torch.load(args.checkpoint, map_location='cpu')) 77 | 78 | # Perform tracing 79 | result = trace(args.image_path, decoder, args.detection_thre, args.device) 80 | print(result['trace1e4'], result['trace1e5'], result['trace1e6'], result['detection_acc']) 81 | 
--------------------------------------------------------------------------------