├── README.md
├── StegaStamp
│   ├── models.py
│   ├── train.py
│   ├── train.sh
│   └── utils_img.py
├── guided-diffusion
│   ├── LICENSE
│   ├── datasets
│   │   ├── README.md
│   │   └── lsun_bedroom.py
│   ├── evaluations
│   │   ├── README.md
│   │   ├── evaluator.py
│   │   └── requirements.txt
│   ├── generate.sh
│   ├── guided_diffusion
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-310.pyc
│   │   │   ├── dist_util.cpython-310.pyc
│   │   │   ├── fp16_util.cpython-310.pyc
│   │   │   ├── gaussian_diffusion.cpython-310.pyc
│   │   │   ├── image_datasets.cpython-310.pyc
│   │   │   ├── logger.cpython-310.pyc
│   │   │   ├── losses.cpython-310.pyc
│   │   │   ├── nn.cpython-310.pyc
│   │   │   ├── resample.cpython-310.pyc
│   │   │   ├── respace.cpython-310.pyc
│   │   │   ├── script_util.cpython-310.pyc
│   │   │   ├── stega_model.cpython-310.pyc
│   │   │   ├── train_util.cpython-310.pyc
│   │   │   └── unet.cpython-310.pyc
│   │   ├── dist_util.py
│   │   ├── fp16_util.py
│   │   ├── gaussian_diffusion.py
│   │   ├── image_datasets.py
│   │   ├── logger.py
│   │   ├── losses.py
│   │   ├── nn.py
│   │   ├── resample.py
│   │   ├── respace.py
│   │   ├── script_util.py
│   │   ├── stega_model.py
│   │   ├── train_util.py
│   │   └── unet.py
│   ├── scripts
│   │   ├── classifier_sample.py
│   │   ├── classifier_train.py
│   │   ├── image_nll.py
│   │   ├── image_sample.py
│   │   ├── image_train.py
│   │   ├── super_res_sample.py
│   │   └── super_res_train.py
│   ├── setup.py
│   └── train.sh
├── pics
│   └── framework.png
├── stable-diffusion
│   ├── LICENSE
│   ├── Stable_Diffusion_v1_Model_Card.md
│   ├── configs
│   │   ├── autoencoder
│   │   │   ├── autoencoder_kl_16x16x16.yaml
│   │   │   ├── autoencoder_kl_32x32x4.yaml
│   │   │   ├── autoencoder_kl_64x64x3.yaml
│   │   │   └── autoencoder_kl_8x8x64.yaml
│   │   ├── latent-diffusion
│   │   │   ├── celebahq-ldm-vq-4.yaml
│   │   │   ├── cin-ldm-vq-f8.yaml
│   │   │   ├── cin256-v2.yaml
│   │   │   ├── ffhq-ldm-vq-4.yaml
│   │   │   ├── lsun_bedrooms-ldm-vq-4.yaml
│   │   │   ├── lsun_churches-ldm-kl-8.yaml
│   │   │   └── txt2img-1p4B-eval.yaml
│   │   ├── retrieval-augmented-diffusion
│   │   │   └── 768x768.yaml
│   │   └── stable-diffusion
│   │       └── v1-inference.yaml
│   ├── environment.yaml
│   ├── ldm
│   │   ├── __pycache__
│   │   │   └── util.cpython-310.pyc
│   │   ├── data
│   │   │   ├── __init__.py
│   │   │   ├── base.py
│   │   │   ├── imagenet.py
│   │   │   └── lsun.py
│   │   ├── lr_scheduler.py
│   │   ├── models
│   │   │   ├── __pycache__
│   │   │   │   └── autoencoder.cpython-310.pyc
│   │   │   ├── autoencoder.py
│   │   │   └── diffusion
│   │   │       ├── __init__.py
│   │   │       ├── __pycache__
│   │   │       │   ├── __init__.cpython-310.pyc
│   │   │       │   ├── ddim.cpython-310.pyc
│   │   │       │   ├── ddpm.cpython-310.pyc
│   │   │       │   └── plms.cpython-310.pyc
│   │   │       ├── classifier.py
│   │   │       ├── ddim.py
│   │   │       ├── ddpm.py
│   │   │       ├── dpm_solver
│   │   │       │   ├── __init__.py
│   │   │       │   ├── __pycache__
│   │   │       │   │   ├── __init__.cpython-310.pyc
│   │   │       │   │   ├── dpm_solver.cpython-310.pyc
│   │   │       │   │   └── sampler.cpython-310.pyc
│   │   │       │   ├── dpm_solver.py
│   │   │       │   └── sampler.py
│   │   │       └── plms.py
│   │   ├── modules
│   │   │   ├── __pycache__
│   │   │   │   ├── attention.cpython-310.pyc
│   │   │   │   ├── ema.cpython-310.pyc
│   │   │   │   └── x_transformer.cpython-310.pyc
│   │   │   ├── attention.py
│   │   │   ├── diffusionmodules
│   │   │   │   ├── __init__.py
│   │   │   │   ├── __pycache__
│   │   │   │   │   ├── __init__.cpython-310.pyc
│   │   │   │   │   ├── model.cpython-310.pyc
│   │   │   │   │   ├── openaimodel.cpython-310.pyc
│   │   │   │   │   └── util.cpython-310.pyc
│   │   │   │   ├── model.py
│   │   │   │   ├── openaimodel.py
│   │   │   │   └── util.py
│   │   │   ├── distributions
│   │   │   │   ├── __init__.py
│   │   │   │   ├── __pycache__
│   │   │   │   │   ├── __init__.cpython-310.pyc
│   │   │   │   │   └── distributions.cpython-310.pyc
│   │   │   │   └── distributions.py
│   │   │   ├── ema.py
│   │   │   ├── encoders
│   │   │   │   ├── __init__.py
│   │   │   │   ├── __pycache__
│   │   │   │   │   ├── __init__.cpython-310.pyc
│   │   │   │   │   └── modules.cpython-310.pyc
│   │   │   │   └── modules.py
│   │   │   ├── image_degradation
│   │   │   │   ├── __init__.py
│   │   │   │   ├── bsrgan.py
│   │   │   │   ├── bsrgan_light.py
│   │   │   │   ├── utils
│   │   │   │   │   └── test.png
│   │   │   │   └── utils_image.py
│   │   │   ├── losses
│   │   │   │   ├── __init__.py
│   │   │   │   ├── contperceptual.py
│   │   │   │   └── vqperceptual.py
│   │   │   └── x_transformer.py
│   │   └── util.py
│   ├── main.py
│   ├── models
│   │   ├── first_stage_models
│   │   │   ├── kl-f16
│   │   │   │   └── config.yaml
│   │   │   ├── kl-f32
│   │   │   │   └── config.yaml
│   │   │   ├── kl-f4
│   │   │   │   └── config.yaml
│   │   │   ├── kl-f8
│   │   │   │   └── config.yaml
│   │   │   ├── vq-f16
│   │   │   │   └── config.yaml
│   │   │   ├── vq-f4-noattn
│   │   │   │   └── config.yaml
│   │   │   ├── vq-f4
│   │   │   │   └── config.yaml
│   │   │   ├── vq-f8-n256
│   │   │   │   └── config.yaml
│   │   │   └── vq-f8
│   │   │       └── config.yaml
│   │   └── ldm
│   │       ├── bsr_sr
│   │       │   └── config.yaml
│   │       ├── celeba256
│   │       │   └── config.yaml
│   │       ├── cin256
│   │       │   └── config.yaml
│   │       ├── ffhq256
│   │       │   └── config.yaml
│   │       ├── inpainting_big
│   │       │   └── config.yaml
│   │       ├── layout2img-openimages256
│   │       │   └── config.yaml
│   │       ├── lsun_beds256
│   │       │   └── config.yaml
│   │       ├── lsun_churches256
│   │       │   └── config.yaml
│   │       ├── semantic_synthesis256
│   │       │   └── config.yaml
│   │       ├── semantic_synthesis512
│   │       │   └── config.yaml
│   │       └── text2img256
│   │           └── config.yaml
│   ├── notebook_helpers.py
│   ├── outputs
│   │   └── txt2img-samples
│   │       └── samples
│   │           ├── 00000.png
│   │           ├── 00001.png
│   │           ├── 00002.png
│   │           ├── 00003.png
│   │           ├── 00004.png
│   │           └── 00005.png
│   ├── scripts
│   │   ├── download_first_stages.sh
│   │   ├── download_models.sh
│   │   ├── img2img.py
│   │   ├── inpaint.py
│   │   ├── knn2img.py
│   │   ├── latent_imagenet_diffusion.ipynb
│   │   ├── sample_diffusion.py
│   │   ├── tests
│   │   │   └── test_watermark.py
│   │   ├── train_searcher.py
│   │   └── txt2img.py
│   └── setup.py
└── trace.py
/README.md:
--------------------------------------------------------------------------------
1 | ### A Watermark-Conditioned Diffusion Model for IP Protection (ECCV 2024)
2 | This code is the official implementation of [A Watermark-Conditioned Diffusion Model for IP Protection](https://arxiv.org/abs/2403.10893).
3 |
4 | ----
5 |

6 |
7 | ### Abstract
8 |
9 | The ethical need to protect AI-generated content has been a significant concern in recent years. While existing watermarking strategies have demonstrated success in detecting synthetic content (detection), there has been limited exploration in identifying the users responsible for generating these outputs from a single model (owner identification). In this paper, we focus on both practical scenarios and propose a unified watermarking framework for content copyright protection within the context of diffusion models. Specifically, we consider two parties: the model provider, who grants public access to a diffusion model via an API, and the users, who can solely query the model API and generate images in a black-box manner. Our task is to embed hidden information into the generated contents, which facilitates further detection and owner identification. To tackle this challenge, we propose a Watermark-conditioned Diffusion model called WaDiff, which manipulates the watermark as a conditioned input and incorporates fingerprinting into the generation process. All the generative outputs from our WaDiff carry user-specific information, which can be recovered by an image extractor and further facilitate forensic identification. Extensive experiments are conducted on two popular diffusion models, and we demonstrate that our method is effective and robust in both the detection and owner identification tasks. Meanwhile, our watermarking framework only exerts a negligible impact on the original generation and is more stealthy and efficient in comparison to existing watermarking strategies.
10 |
11 | ### Setup
12 | To configure the environment, you can refer to [WatermarkDM](https://github.com/yunqing-me/WatermarkDM) for training the StegaStamp decoder, [guided-diffusion](https://github.com/openai/guided-diffusion) for fine-tuning the ImageNet diffusion model, and [stable-diffusion](https://github.com/CompVis/stable-diffusion) for fine-tuning Stable Diffusion.
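
As a reference only, a minimal sketch of one possible setup is shown below; it assumes conda is available and that the bundled ```guided-diffusion``` package is installed in editable mode. The linked repositories remain the authoritative setup instructions.
```cmd
# Sketch only: the "ldm" environment name comes from stable-diffusion/environment.yaml
conda env create -f stable-diffusion/environment.yaml
conda activate ldm
pip install -e guided-diffusion/
```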
13 |
14 | ### Pipeline
15 | #### Step 1: Pre-train Watermark Decoder
16 |
17 | First, you need to pre-train the watermark encoder and decoder jointly. Go to the [StegaStamp](StegaStamp) folder and simply run:
18 | ```cmd
19 | cd StegaStamp
20 | sh train.sh
21 | ```
22 | Note that running the script as-is may fail because you need to specify the path to your training data via ```--data_dir```. You can also customize your experiments by adjusting hyperparameters such as the number of watermark bits ```--bit_length```, the image resolution ```--image_resolution```, the number of training epochs ```--num_epochs```, and the GPU device ```--cuda```, as in the example below.
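
For instance, a customized run could look like the following; the data path and flag values here are illustrative placeholders, not the settings used in the paper:
```cmd
# Placeholder path and values; adjust to your own dataset and hardware
python train.py --data_dir /path/to/your/train_data \
    --bit_length 64 --image_resolution 256 --num_epochs 100 --cuda 1
```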
23 |
24 | #### Step 2: Fine-tune Diffusion Model
25 | Once pre-training is finished, you can use the watermark decoder to guide the fine-tuning of the diffusion model. For the ImageNet diffusion model, run the following commands:
26 | ```cmd
27 | cd ../guided-diffusion
28 | sh train.sh
29 | ```
30 | Before running the script, you need to configure it properly, i.e., set the path of the pre-trained decoder checkpoint ```--wm_decoder_path``` (from Step 1), the path of the training data ```--data_dir``` (usually the same as in Step 1), the number of watermark bits ```--wm_length```, the balance parameter $\alpha$ ```--alpha```, and the time threshold $\tau$ ```--threshold```. In addition, you need to download the pre-trained diffusion model [checkpoint](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/256x256_diffusion_uncond.pt) and put it into the ```models/``` folder.
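
As a sketch, the WaDiff-specific flags appended to the training command in ```train.sh``` might be set as follows; the checkpoint path, data path, $\alpha$, and $\tau$ values are illustrative placeholders, not recommended settings:
```cmd
# Placeholder values; pass these alongside the existing flags in train.sh
WADIFF_FLAGS="--wm_decoder_path ../StegaStamp/checkpoints/decoder.pth --data_dir /path/to/train_data --wm_length 48 --alpha 0.4 --threshold 400"
```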
31 |
32 | #### Step 3: Generate Watermarked Images
33 | After the fine-tuning step, you can use the watermark-conditioned diffusion model to generate watermarked images with the following command:
34 | ```cmd
35 | sh generate.sh
36 | ```
37 | All generated images are saved in a default folder named ```saved_images``` (specified by ```--output_path```) and are organized into individual subfolders indexed by the ID of the embedded watermark. You can also change ```--batch_size``` to adjust the number of images generated within each subfolder.
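
For example, to write 16 images per watermark ID into a custom folder, you could adjust ```generate.sh``` roughly as follows (illustrative values; ```MODEL_FLAGS``` stays as defined in the script):
```cmd
SAMPLE_FLAGS="--batch_size 16 --num_samples 32 --timestep_respacing 100 --use_ddim True"
python scripts/image_sample.py $MODEL_FLAGS --output_path my_watermarked_images/ --model_path models/256x256_diffusion_uncond.pt $SAMPLE_FLAGS
```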
38 |
39 | #### Step 4: Source Identification
40 | Finally, run the following command to perform tracing:
41 | ```cmd
42 | cd ..
43 | python trace.py
44 | ```
45 | Note that ```--image_path``` indicates where the watermarked images are saved; it should be consistent with the ```--output_path``` specified in Step 3.
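
For example, assuming you kept the default output folder from Step 3 and run from the repository root (so the folder lives under ```guided-diffusion/```), the call would look like:
```cmd
# Sketch; other trace.py options keep their defaults
python trace.py --image_path guided-diffusion/saved_images/
```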
46 |
47 | The code for Stable Diffusion is not finished yet.
48 |
49 | ### Citation
50 | ```
51 | @article{min2024watermark,
52 | title={A watermark-conditioned diffusion model for ip protection},
53 | author={Min, Rui and Li, Sen and Chen, Hongyang and Cheng, Minhao},
54 | journal={arXiv preprint arXiv:2403.10893},
55 | year={2024}
56 | }
57 | ```
58 |
59 | #### Our code is heavily built upon [WatermarkDM](https://github.com/yunqing-me/WatermarkDM), [guided-diffusion](https://github.com/openai/guided-diffusion), and [stable-diffusion](https://github.com/CompVis/stable-diffusion).
60 |
--------------------------------------------------------------------------------
/StegaStamp/models.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch import nn
4 | from torch.nn.functional import relu, sigmoid
5 |
6 |
7 | class StegaStampEncoder(nn.Module):
8 | def __init__(
9 | self,
10 | resolution=32,
11 | IMAGE_CHANNELS=1,
12 | fingerprint_size=100,
13 | return_residual=False,
14 | ):
15 | super(StegaStampEncoder, self).__init__()
16 | self.fingerprint_size = fingerprint_size
17 | self.IMAGE_CHANNELS = IMAGE_CHANNELS
18 | self.return_residual = return_residual
19 | self.secret_dense = nn.Linear(self.fingerprint_size, 64 * 64 * IMAGE_CHANNELS)
20 |
21 | log_resolution = int(math.log(resolution, 2))
22 | assert resolution == 2 ** log_resolution, f"Image resolution must be a power of 2, got {resolution}."
23 |
24 | self.fingerprint_upsample = nn.Upsample(scale_factor=(2**(log_resolution-6), 2**(log_resolution-6)))
25 | self.conv1 = nn.Conv2d(2 * IMAGE_CHANNELS, 32, 3, 1, 1)
26 | self.conv2 = nn.Conv2d(32, 32, 3, 2, 1)
27 | self.conv3 = nn.Conv2d(32, 64, 3, 2, 1)
28 | self.conv4 = nn.Conv2d(64, 128, 3, 2, 1)
29 | self.conv5 = nn.Conv2d(128, 256, 3, 2, 1)
30 | self.pad6 = nn.ZeroPad2d((0, 1, 0, 1))
31 | self.up6 = nn.Conv2d(256, 128, 2, 1)
32 | self.upsample6 = nn.Upsample(scale_factor=(2, 2))
33 | self.conv6 = nn.Conv2d(128 + 128, 128, 3, 1, 1)
34 | self.pad7 = nn.ZeroPad2d((0, 1, 0, 1))
35 | self.up7 = nn.Conv2d(128, 64, 2, 1)
36 | self.upsample7 = nn.Upsample(scale_factor=(2, 2))
37 | self.conv7 = nn.Conv2d(64 + 64, 64, 3, 1, 1)
38 | self.pad8 = nn.ZeroPad2d((0, 1, 0, 1))
39 | self.up8 = nn.Conv2d(64, 32, 2, 1)
40 | self.upsample8 = nn.Upsample(scale_factor=(2, 2))
41 | self.conv8 = nn.Conv2d(32 + 32, 32, 3, 1, 1)
42 | self.pad9 = nn.ZeroPad2d((0, 1, 0, 1))
43 | self.up9 = nn.Conv2d(32, 32, 2, 1)
44 | self.upsample9 = nn.Upsample(scale_factor=(2, 2))
45 | self.conv9 = nn.Conv2d(32 + 32 + 2 * IMAGE_CHANNELS, 32, 3, 1, 1)
46 | self.conv10 = nn.Conv2d(32, 32, 3, 1, 1)
47 | self.residual = nn.Conv2d(32, IMAGE_CHANNELS, 1)
48 |
49 | def forward(self, fingerprint, image):
50 | fingerprint = relu(self.secret_dense(fingerprint))
51 | fingerprint = fingerprint.view((-1, self.IMAGE_CHANNELS, 64, 64))
52 | fingerprint_enlarged = self.fingerprint_upsample(fingerprint)
53 | inputs = torch.cat([fingerprint_enlarged, image], dim=1)
54 | conv1 = relu(self.conv1(inputs))
55 | conv2 = relu(self.conv2(conv1))
56 | conv3 = relu(self.conv3(conv2))
57 | conv4 = relu(self.conv4(conv3))
58 | conv5 = relu(self.conv5(conv4))
59 | up6 = relu(self.up6(self.pad6(self.upsample6(conv5))))
60 | merge6 = torch.cat([conv4, up6], dim=1)
61 | conv6 = relu(self.conv6(merge6))
62 | up7 = relu(self.up7(self.pad7(self.upsample7(conv6))))
63 | merge7 = torch.cat([conv3, up7], dim=1)
64 | conv7 = relu(self.conv7(merge7))
65 | up8 = relu(self.up8(self.pad8(self.upsample8(conv7))))
66 | merge8 = torch.cat([conv2, up8], dim=1)
67 | conv8 = relu(self.conv8(merge8))
68 | up9 = relu(self.up9(self.pad9(self.upsample9(conv8))))
69 | merge9 = torch.cat([conv1, up9, inputs], dim=1)
70 | conv9 = relu(self.conv9(merge9))
71 | conv10 = relu(self.conv10(conv9))
72 | residual = self.residual(conv10)
73 | if not self.return_residual:
74 | residual = sigmoid(residual)
75 | return residual
76 |
77 |
78 | class StegaStampDecoder(nn.Module):
79 | def __init__(self, resolution=32, IMAGE_CHANNELS=1, fingerprint_size=1):
80 | super(StegaStampDecoder, self).__init__()
81 | self.resolution = resolution
82 | self.IMAGE_CHANNELS = IMAGE_CHANNELS
83 | self.decoder = nn.Sequential(
84 | nn.Conv2d(IMAGE_CHANNELS, 32, (3, 3), 2, 1), # 16
85 | nn.ReLU(),
86 | nn.Conv2d(32, 32, 3, 1, 1),
87 | nn.ReLU(),
88 | nn.Conv2d(32, 64, 3, 2, 1), # 8
89 | nn.ReLU(),
90 | nn.Conv2d(64, 64, 3, 1, 1),
91 | nn.ReLU(),
92 | nn.Conv2d(64, 64, 3, 2, 1), # 4
93 | nn.ReLU(),
94 | nn.Conv2d(64, 128, 3, 2, 1), # 2
95 | nn.ReLU(),
96 | nn.Conv2d(128, 128, (3, 3), 2, 1),
97 | nn.ReLU(),
98 | )
99 | self.dense = nn.Sequential(
100 | nn.Linear(resolution * resolution * 128 // 32 // 32, 512),
101 | nn.ReLU(),
102 | nn.Linear(512, fingerprint_size),
103 | )
104 |
105 | def forward(self, image):
106 | x = self.decoder(image)
107 | x = x.view(-1, self.resolution * self.resolution * 128 // 32 // 32)
108 | return self.dense(x)
109 |
--------------------------------------------------------------------------------
/StegaStamp/train.sh:
--------------------------------------------------------------------------------
1 | python train.py --data_dir ./data/coco/train2014 \
2 | --bit_length 48 --image_resolution 512 --num_epochs 100 --cuda 0
3 |
--------------------------------------------------------------------------------
/StegaStamp/utils_img.py:
--------------------------------------------------------------------------------
1 | # This script is modified from "https://github.com/facebookresearch/stable_signature"
2 |
3 | import os
4 | import numpy as np
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | from torch.autograd.variable import Variable
10 | from torchvision import transforms
11 | from torchvision.transforms import functional
12 | from augly.image import functional as aug_functional
13 |
14 | import kornia.augmentation as K
15 |
16 | from PIL import Image
17 |
18 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19 |
20 | NORMALIZE_IMAGENET = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
21 | default_transform = transforms.Compose([
22 | transforms.ToTensor(),
23 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
24 | ])
25 | image_mean = torch.Tensor([0.485, 0.456, 0.406]).view(-1, 1, 1)
26 | image_std = torch.Tensor([0.229, 0.224, 0.225]).view(-1, 1, 1)
27 |
28 | normalize_rgb = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
29 | unnormalize_rgb = transforms.Normalize(mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225], std=[1/0.229, 1/0.224, 1/0.225])
30 | normalize_yuv = transforms.Normalize(mean=[0.5, 0, 0], std=[0.5, 1, 1])
31 | unnormalize_yuv = transforms.Normalize(mean=[-0.5/0.5, 0, 0], std=[1/0.5, 1/1, 1/1])
32 |
33 |
34 | def normalize_img(x):
35 | """ Normalize image to approx. [-1,1] """
36 | return (x - image_mean.to(x.device)) / image_std.to(x.device)
37 |
38 | def unnormalize_img(x):
39 | """ Unnormalize image to [0,1] """
40 | return (x * image_std.to(x.device)) + image_mean.to(x.device)
41 |
42 | def round_pixel(x):
43 | """
44 | Round pixel values to nearest integer.
45 | Args:
46 | x: Image tensor with values approx. between [-1,1]
47 | Returns:
48 | y: Rounded image tensor with values approx. between [-1,1]
49 | """
50 | x_pixel = 255 * unnormalize_img(x)
51 | y = torch.round(x_pixel).clamp(0, 255)
52 | y = normalize_img(y/255.0)
53 | return y
54 |
55 | def clamp_pixel(x):
56 | """
57 | Clamp pixel values to 0 255.
58 | Args:
59 | x: Image tensor with values approx. between [-1,1]
60 | Returns:
61 | y: Rounded image tensor with values approx. between [-1,1]
62 | """
63 | x_pixel = 255 * unnormalize_img(x)
64 | y = x_pixel.clamp(0, 255)
65 | y = normalize_img(y/255.0)
66 | return y
67 |
68 | def project_linf(x, y, radius):
69 | """
70 | Clamp x so that Linf(x,y)<=radius
71 | Args:
72 | x: Image tensor with values approx. between [-1,1]
73 | y: Image tensor with values approx. between [-1,1], ex: original image
74 | radius: Radius of Linf ball for the images in pixel space [0, 255]
75 | """
76 | delta = x - y
77 | delta = 255 * (delta * image_std.to(x.device))
78 | delta = torch.clamp(delta, -radius, radius)
79 | delta = (delta / 255.0) / image_std.to(x.device)
80 | return y + delta
81 |
82 | def psnr(x, y):
83 | """
84 | Return PSNR
85 | Args:
86 | x: Image tensor with values approx. between [-1,1]
87 | y: Image tensor with values approx. between [-1,1], ex: original image
88 | """
89 | delta = x - y
90 | delta = 255 * (delta * image_std.to(x.device))
91 | delta = delta.reshape(-1, x.shape[-3], x.shape[-2], x.shape[-1]) # BxCxHxW
92 | psnr = 20*np.log10(255) - 10*torch.log10(torch.mean(delta**2, dim=(1,2,3))) # B
93 | return psnr
94 |
95 | def center_crop(x, scale):
96 | """ Perform center crop such that the target area of the crop is at a given scale
97 | Args:
98 | x: PIL image
99 | scale: target area scale
100 | """
101 | scale = np.sqrt(scale)
102 | new_edges_size = [int(s*scale) for s in x.shape[-2:]][::-1]
103 |
104 | # left = int(x.size[0]/2-new_edges_size[0]/2)
105 | # upper = int(x.size[1]/2-new_edges_size[1]/2)
106 | # right = left + new_edges_size[0]
107 | # lower = upper + new_edges_size[1]
108 |
109 | # return x.crop((left, upper, right, lower))
110 | x = functional.center_crop(x, new_edges_size)
111 | return x
112 |
113 | def resize(x, scale):
114 | """ Perform center crop such that the target area of the crop is at a given scale
115 | Args:
116 | x: PIL image
117 | scale: target area scale
118 | """
119 | scale = np.sqrt(scale)
120 | new_edges_size = [int(s*scale) for s in x.shape[-2:]][::-1]
121 | return functional.resize(x, new_edges_size)
122 |
123 | def rotate(x, angle):
124 | """ Rotate image by angle
125 | Args:
126 | x: image (PIl or tensor)
127 | angle: angle in degrees
128 | """
129 | return functional.rotate(x, angle)
130 |
131 | def adjust_brightness(x, brightness_factor):
132 | """ Adjust brightness of an image
133 | Args:
134 | x: PIL image
135 | brightness_factor: brightness factor
136 | """
137 | return normalize_img(functional.adjust_brightness(unnormalize_img(x), brightness_factor))
138 |
139 | def adjust_contrast(x, contrast_factor):
140 | """ Adjust constrast of an image
141 | Args:
142 | x: PIL image
143 | contrast_factor: contrast factor
144 | """
145 | return normalize_img(functional.adjust_contrast(unnormalize_img(x), contrast_factor))
146 |
147 | def jpeg_compress(x, quality_factor):
148 | """ Apply jpeg compression to image
149 | Args:
150 | x: Tensor image
151 | quality_factor: quality factor
152 | """
153 | to_pil = transforms.ToPILImage()
154 | to_tensor = transforms.ToTensor()
155 | img_aug = torch.zeros_like(x, device=x.device)
156 | x = unnormalize_img(x)
157 | for ii,img in enumerate(x):
158 | pil_img = to_pil(img)
159 | img_aug[ii] = to_tensor(aug_functional.encoding_quality(pil_img, quality=quality_factor))
160 | return normalize_img(img_aug)
161 |
162 | def gaussian_blur(x, sigma=1):
163 | """ Add gaussian blur to image
164 | Args:
165 | x: Tensor image
166 | sigma: sigma of gaussian kernel
167 | """
168 | x = unnormalize_img(x)
169 | x = functional.gaussian_blur(x, sigma=sigma, kernel_size=21)
170 | x = normalize_img(x)
171 | return x
172 |
--------------------------------------------------------------------------------
/guided-diffusion/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 OpenAI
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/guided-diffusion/datasets/README.md:
--------------------------------------------------------------------------------
1 | # Downloading datasets
2 |
3 | This directory includes instructions and scripts for downloading ImageNet and LSUN bedrooms for use in this codebase.
4 |
5 | ## Class-conditional ImageNet
6 |
7 | For our class-conditional models, we use the official ILSVRC2012 dataset with manual center cropping and downsampling. To obtain this dataset, navigate to [this page on image-net.org](http://www.image-net.org/challenges/LSVRC/2012/downloads) and sign in (or create an account if you do not already have one). Then click on the link reading "Training images (Task 1 & 2)". This is a 138GB tar file containing 1000 sub-tar files, one per class.
8 |
9 | Once the file is downloaded, extract it and look inside. You should see 1000 `.tar` files. You need to extract each of these, which may be impractical to do by hand on your operating system. To automate the process on a Unix-based system, you can `cd` into the directory and run this short shell script:
10 |
11 | ```
12 | for file in *.tar; do tar xf "$file"; rm "$file"; done
13 | ```
14 |
15 | This will extract and remove each tar file in turn.
16 |
17 | Once all of the images have been extracted, the resulting directory should be usable as a data directory (the `--data_dir` argument for the training script). The filenames should all start with WNID (class ids) followed by underscores, like `n01440764_2708.JPEG`. Conveniently (but not by accident) this is how the automated data-loader expects to discover class labels.
18 |
19 | ## LSUN bedroom
20 |
21 | To download and pre-process LSUN bedroom, clone [fyu/lsun](https://github.com/fyu/lsun) on GitHub and run their download script `python3 download.py bedroom`. The result will be an "lmdb" database named like `bedroom_train_lmdb`. You can pass this to our [lsun_bedroom.py](lsun_bedroom.py) script like so:
22 |
23 | ```
24 | python lsun_bedroom.py bedroom_train_lmdb lsun_train_output_dir
25 | ```
26 |
27 | This creates a directory called `lsun_train_output_dir`. This directory can be passed to the training scripts via the `--data_dir` argument.
28 |
--------------------------------------------------------------------------------
/guided-diffusion/datasets/lsun_bedroom.py:
--------------------------------------------------------------------------------
1 | """
2 | Convert an LSUN lmdb database into a directory of images.
3 | """
4 |
5 | import argparse
6 | import io
7 | import os
8 |
9 | from PIL import Image
10 | import lmdb
11 | import numpy as np
12 |
13 |
14 | def read_images(lmdb_path, image_size):
15 | env = lmdb.open(lmdb_path, map_size=1099511627776, max_readers=100, readonly=True)
16 | with env.begin(write=False) as transaction:
17 | cursor = transaction.cursor()
18 | for _, webp_data in cursor:
19 | img = Image.open(io.BytesIO(webp_data))
20 | width, height = img.size
21 | scale = image_size / min(width, height)
22 | img = img.resize(
23 | (int(round(scale * width)), int(round(scale * height))),
24 | resample=Image.BOX,
25 | )
26 | arr = np.array(img)
27 | h, w, _ = arr.shape
28 | h_off = (h - image_size) // 2
29 | w_off = (w - image_size) // 2
30 | arr = arr[h_off : h_off + image_size, w_off : w_off + image_size]
31 | yield arr
32 |
33 |
34 | def dump_images(out_dir, images, prefix):
35 | if not os.path.exists(out_dir):
36 | os.mkdir(out_dir)
37 | for i, img in enumerate(images):
38 | Image.fromarray(img).save(os.path.join(out_dir, f"{prefix}_{i:07d}.png"))
39 |
40 |
41 | def main():
42 | parser = argparse.ArgumentParser()
43 | parser.add_argument("--image-size", help="new image size", type=int, default=256)
44 | parser.add_argument("--prefix", help="class name", type=str, default="bedroom")
45 | parser.add_argument("lmdb_path", help="path to an LSUN lmdb database")
46 | parser.add_argument("out_dir", help="path to output directory")
47 | args = parser.parse_args()
48 |
49 | images = read_images(args.lmdb_path, args.image_size)
50 | dump_images(args.out_dir, images, args.prefix)
51 |
52 |
53 | if __name__ == "__main__":
54 | main()
55 |
--------------------------------------------------------------------------------
/guided-diffusion/evaluations/README.md:
--------------------------------------------------------------------------------
1 | # Evaluations
2 |
3 | To compare different generative models, we use FID, sFID, Precision, Recall, and Inception Score. These metrics can all be calculated using batches of samples, which we store in `.npz` (numpy) files.
4 |
5 | # Download batches
6 |
7 | We provide pre-computed sample batches for the reference datasets, our diffusion models, and several baselines we compare against. These are all stored in `.npz` format.
8 |
9 | Reference dataset batches contain pre-computed statistics over the whole dataset, as well as 10,000 images for computing Precision and Recall. All other batches contain 50,000 images which can be used to compute statistics and Precision/Recall.
10 |
11 | Here are links to download all of the sample and reference batches:
12 |
13 | * LSUN
14 | * LSUN bedroom: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/bedroom/VIRTUAL_lsun_bedroom256.npz)
15 | * [ADM (dropout)](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/bedroom/admnet_dropout_lsun_bedroom.npz)
16 | * [DDPM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/bedroom/ddpm_lsun_bedroom.npz)
17 | * [IDDPM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/bedroom/iddpm_lsun_bedroom.npz)
18 | * [StyleGAN](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/bedroom/stylegan_lsun_bedroom.npz)
19 | * LSUN cat: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/cat/VIRTUAL_lsun_cat256.npz)
20 | * [ADM (dropout)](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/cat/admnet_dropout_lsun_cat.npz)
21 | * [StyleGAN2](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/cat/stylegan2_lsun_cat.npz)
22 | * LSUN horse: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/horse/VIRTUAL_lsun_horse256.npz)
23 | * [ADM (dropout)](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/horse/admnet_dropout_lsun_horse.npz)
24 | * [ADM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/horse/admnet_lsun_horse.npz)
25 |
26 | * ImageNet
27 | * ImageNet 64x64: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/64/VIRTUAL_imagenet64_labeled.npz)
28 | * [ADM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/64/admnet_imagenet64.npz)
29 | * [IDDPM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/64/iddpm_imagenet64.npz)
30 | * [BigGAN](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/64/biggan_deep_imagenet64.npz)
31 | * ImageNet 128x128: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/128/VIRTUAL_imagenet128_labeled.npz)
32 | * [ADM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/128/admnet_imagenet128.npz)
33 | * [ADM-G](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/128/admnet_guided_imagenet128.npz)
34 | * [ADM-G, 25 steps](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/128/admnet_guided_25step_imagenet128.npz)
35 | * [BigGAN-deep (trunc=1.0)](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/128/biggan_deep_trunc1_imagenet128.npz)
36 | * ImageNet 256x256: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/VIRTUAL_imagenet256_labeled.npz)
37 | * [ADM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/admnet_imagenet256.npz)
38 | * [ADM-G](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/admnet_guided_imagenet256.npz)
39 | * [ADM-G, 25 step](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/admnet_guided_25step_imagenet256.npz)
40 | * [ADM-G + ADM-U](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/admnet_guided_upsampled_imagenet256.npz)
41 | * [ADM-U](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/admnet_upsampled_imagenet256.npz)
42 | * [BigGAN-deep (trunc=1.0)](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/biggan_deep_trunc1_imagenet256.npz)
43 | * ImageNet 512x512: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/VIRTUAL_imagenet512.npz)
44 | * [ADM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/admnet_imagenet512.npz)
45 | * [ADM-G](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/admnet_guided_imagenet512.npz)
46 | * [ADM-G, 25 step](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/admnet_guided_25step_imagenet512.npz)
47 | * [ADM-G + ADM-U](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/admnet_guided_upsampled_imagenet512.npz)
48 | * [ADM-U](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/admnet_upsampled_imagenet512.npz)
49 | * [BigGAN-deep (trunc=1.0)](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/biggan_deep_trunc1_imagenet512.npz)
50 |
51 | # Run evaluations
52 |
53 | First, generate or download a batch of samples and download the corresponding reference batch for the given dataset. For this example, we'll use ImageNet 256x256, so the reference batch is `VIRTUAL_imagenet256_labeled.npz` and we can use the sample batch `admnet_guided_upsampled_imagenet256.npz`.
54 |
55 | Next, run the `evaluator.py` script. The requirements of this script can be found in [requirements.txt](requirements.txt). Pass two arguments to the script: the reference batch and the sample batch. The script will download the InceptionV3 model used for evaluations into the current working directory (if it is not already present). This file is roughly 100MB.
56 |
57 | The output of the script will look something like this, where the first `...` is a bunch of verbose TensorFlow logging:
58 |
59 | ```
60 | $ python evaluator.py VIRTUAL_imagenet256_labeled.npz admnet_guided_upsampled_imagenet256.npz
61 | ...
62 | computing reference batch activations...
63 | computing/reading reference batch statistics...
64 | computing sample batch activations...
65 | computing/reading sample batch statistics...
66 | Computing evaluations...
67 | Inception Score: 215.8370361328125
68 | FID: 3.9425574129223264
69 | sFID: 6.140433703346162
70 | Precision: 0.8265
71 | Recall: 0.5309
72 | ```
73 |
--------------------------------------------------------------------------------
/guided-diffusion/evaluations/requirements.txt:
--------------------------------------------------------------------------------
1 | tensorflow-gpu>=2.0
2 | scipy
3 | requests
4 | tqdm
--------------------------------------------------------------------------------
/guided-diffusion/generate.sh:
--------------------------------------------------------------------------------
1 | MODEL_FLAGS="--wm_length 48 --attention_resolutions 32,16,8 --class_cond False --diffusion_steps 1000 --image_size 256 --learn_sigma True --noise_schedule linear --num_channels 256 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 True --use_scale_shift_norm True"
2 | SAMPLE_FLAGS="--batch_size 4 --num_samples 8 --timestep_respacing 100 --use_ddim True"
3 | python scripts/image_sample.py $MODEL_FLAGS --output_path saved_images/ --model_path models/256x256_diffusion_uncond.pt $SAMPLE_FLAGS
--------------------------------------------------------------------------------
/guided-diffusion/guided_diffusion/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Codebase for "Improved Denoising Diffusion Probabilistic Models".
3 | """
4 |
--------------------------------------------------------------------------------
/guided-diffusion/guided_diffusion/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rmin2000/WaDiff/33e657de73c3e75e82ea66bcec14012a2f4fcb6a/guided-diffusion/guided_diffusion/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/guided-diffusion/guided_diffusion/__pycache__/dist_util.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rmin2000/WaDiff/33e657de73c3e75e82ea66bcec14012a2f4fcb6a/guided-diffusion/guided_diffusion/__pycache__/dist_util.cpython-310.pyc
--------------------------------------------------------------------------------
/guided-diffusion/guided_diffusion/__pycache__/fp16_util.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rmin2000/WaDiff/33e657de73c3e75e82ea66bcec14012a2f4fcb6a/guided-diffusion/guided_diffusion/__pycache__/fp16_util.cpython-310.pyc
--------------------------------------------------------------------------------
/guided-diffusion/guided_diffusion/__pycache__/gaussian_diffusion.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rmin2000/WaDiff/33e657de73c3e75e82ea66bcec14012a2f4fcb6a/guided-diffusion/guided_diffusion/__pycache__/gaussian_diffusion.cpython-310.pyc
--------------------------------------------------------------------------------
/guided-diffusion/guided_diffusion/__pycache__/image_datasets.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rmin2000/WaDiff/33e657de73c3e75e82ea66bcec14012a2f4fcb6a/guided-diffusion/guided_diffusion/__pycache__/image_datasets.cpython-310.pyc
--------------------------------------------------------------------------------
/guided-diffusion/guided_diffusion/__pycache__/logger.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rmin2000/WaDiff/33e657de73c3e75e82ea66bcec14012a2f4fcb6a/guided-diffusion/guided_diffusion/__pycache__/logger.cpython-310.pyc
--------------------------------------------------------------------------------
/guided-diffusion/guided_diffusion/__pycache__/losses.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rmin2000/WaDiff/33e657de73c3e75e82ea66bcec14012a2f4fcb6a/guided-diffusion/guided_diffusion/__pycache__/losses.cpython-310.pyc
--------------------------------------------------------------------------------
/guided-diffusion/guided_diffusion/__pycache__/nn.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rmin2000/WaDiff/33e657de73c3e75e82ea66bcec14012a2f4fcb6a/guided-diffusion/guided_diffusion/__pycache__/nn.cpython-310.pyc
--------------------------------------------------------------------------------
/guided-diffusion/guided_diffusion/__pycache__/resample.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rmin2000/WaDiff/33e657de73c3e75e82ea66bcec14012a2f4fcb6a/guided-diffusion/guided_diffusion/__pycache__/resample.cpython-310.pyc
--------------------------------------------------------------------------------
/guided-diffusion/guided_diffusion/__pycache__/respace.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rmin2000/WaDiff/33e657de73c3e75e82ea66bcec14012a2f4fcb6a/guided-diffusion/guided_diffusion/__pycache__/respace.cpython-310.pyc
--------------------------------------------------------------------------------
/guided-diffusion/guided_diffusion/__pycache__/script_util.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rmin2000/WaDiff/33e657de73c3e75e82ea66bcec14012a2f4fcb6a/guided-diffusion/guided_diffusion/__pycache__/script_util.cpython-310.pyc
--------------------------------------------------------------------------------
/guided-diffusion/guided_diffusion/__pycache__/stega_model.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rmin2000/WaDiff/33e657de73c3e75e82ea66bcec14012a2f4fcb6a/guided-diffusion/guided_diffusion/__pycache__/stega_model.cpython-310.pyc
--------------------------------------------------------------------------------
/guided-diffusion/guided_diffusion/__pycache__/train_util.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rmin2000/WaDiff/33e657de73c3e75e82ea66bcec14012a2f4fcb6a/guided-diffusion/guided_diffusion/__pycache__/train_util.cpython-310.pyc
--------------------------------------------------------------------------------
/guided-diffusion/guided_diffusion/__pycache__/unet.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rmin2000/WaDiff/33e657de73c3e75e82ea66bcec14012a2f4fcb6a/guided-diffusion/guided_diffusion/__pycache__/unet.cpython-310.pyc
--------------------------------------------------------------------------------
/guided-diffusion/guided_diffusion/dist_util.py:
--------------------------------------------------------------------------------
1 | """
2 | Helpers for distributed training.
3 | """
4 |
5 | import io
6 | import os
7 | import socket
8 |
9 | import blobfile as bf
10 | from mpi4py import MPI
11 | import torch as th
12 | import torch.distributed as dist
13 |
14 | # Change this to reflect your cluster layout.
15 | # The GPU for a given rank is (rank % GPUS_PER_NODE).
16 | GPUS_PER_NODE = 8
17 |
18 | SETUP_RETRY_COUNT = 3
19 |
20 |
21 | def setup_dist():
22 | """
23 | Setup a distributed process group.
24 | """
25 | if dist.is_initialized():
26 | return
27 | os.environ["CUDA_VISIBLE_DEVICES"] = f"{MPI.COMM_WORLD.Get_rank() % GPUS_PER_NODE}"
28 |
29 | comm = MPI.COMM_WORLD
30 | backend = "gloo" if not th.cuda.is_available() else "nccl"
31 |
32 | if backend == "gloo":
33 | hostname = "localhost"
34 | else:
35 | hostname = socket.gethostbyname(socket.getfqdn())
36 | os.environ["MASTER_ADDR"] = comm.bcast(hostname, root=0)
37 | os.environ["RANK"] = str(comm.rank)
38 | os.environ["WORLD_SIZE"] = str(comm.size)
39 |
40 | port = comm.bcast(_find_free_port(), root=0)
41 | os.environ["MASTER_PORT"] = str(port)
42 | dist.init_process_group(backend=backend, init_method="env://")
43 |
44 |
45 | def dev():
46 | """
47 | Get the device to use for torch.distributed.
48 | """
49 | if th.cuda.is_available():
50 | return th.device(f"cuda")
51 | return th.device("cpu")
52 |
53 |
54 | def load_state_dict(path, **kwargs):
55 | """
56 | Load a PyTorch file without redundant fetches across MPI ranks.
57 | """
58 | chunk_size = 2 ** 30 # MPI has a relatively small size limit
59 | if MPI.COMM_WORLD.Get_rank() == 0:
60 | with bf.BlobFile(path, "rb") as f:
61 | data = f.read()
62 | num_chunks = len(data) // chunk_size
63 | if len(data) % chunk_size:
64 | num_chunks += 1
65 | MPI.COMM_WORLD.bcast(num_chunks)
66 | for i in range(0, len(data), chunk_size):
67 | MPI.COMM_WORLD.bcast(data[i : i + chunk_size])
68 | else:
69 | num_chunks = MPI.COMM_WORLD.bcast(None)
70 | data = bytes()
71 | for _ in range(num_chunks):
72 | data += MPI.COMM_WORLD.bcast(None)
73 |
74 | return th.load(io.BytesIO(data), **kwargs)
75 |
76 |
77 | def sync_params(params):
78 | """
79 | Synchronize a sequence of Tensors across ranks from rank 0.
80 | """
81 | for p in params:
82 | with th.no_grad():
83 | dist.broadcast(p, 0)
84 |
85 |
86 | def _find_free_port():
87 | try:
88 | s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
89 | s.bind(("", 0))
90 | s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
91 | return s.getsockname()[1]
92 | finally:
93 | s.close()
94 |
--------------------------------------------------------------------------------
/guided-diffusion/guided_diffusion/image_datasets.py:
--------------------------------------------------------------------------------
1 | import math
2 | import random
3 |
4 | from PIL import Image
5 | import blobfile as bf
6 | from mpi4py import MPI
7 | import numpy as np
8 | from torch.utils.data import DataLoader, Dataset
9 |
10 |
11 | def load_data(
12 | *,
13 | data_dir,
14 | batch_size,
15 | image_size,
16 | class_cond=False,
17 | deterministic=False,
18 | random_crop=False,
19 | random_flip=True,
20 | ):
21 | """
22 | For a dataset, create a generator over (images, kwargs) pairs.
23 |
24 | Each images is an NCHW float tensor, and the kwargs dict contains zero or
25 | more keys, each of which map to a batched Tensor of their own.
26 | The kwargs dict can be used for class labels, in which case the key is "y"
27 | and the values are integer tensors of class labels.
28 |
29 | :param data_dir: a dataset directory.
30 | :param batch_size: the batch size of each returned pair.
31 | :param image_size: the size to which images are resized.
32 | :param class_cond: if True, include a "y" key in returned dicts for class
33 | label. If classes are not available and this is true, an
34 | exception will be raised.
35 | :param deterministic: if True, yield results in a deterministic order.
36 | :param random_crop: if True, randomly crop the images for augmentation.
37 | :param random_flip: if True, randomly flip the images for augmentation.
38 | """
39 | if not data_dir:
40 | raise ValueError("unspecified data directory")
41 | all_files = _list_image_files_recursively(data_dir)
42 | classes = None
43 | if class_cond:
44 | # Assume classes are the first part of the filename,
45 | # before an underscore.
46 | class_names = [bf.basename(path).split("_")[0] for path in all_files]
47 | sorted_classes = {x: i for i, x in enumerate(sorted(set(class_names)))}
48 | classes = [sorted_classes[x] for x in class_names]
49 | dataset = ImageDataset(
50 | image_size,
51 | all_files,
52 | classes=classes,
53 | shard=MPI.COMM_WORLD.Get_rank(),
54 | num_shards=MPI.COMM_WORLD.Get_size(),
55 | random_crop=random_crop,
56 | random_flip=random_flip,
57 | )
58 | if deterministic:
59 | loader = DataLoader(
60 | dataset, batch_size=batch_size, shuffle=False, num_workers=1, drop_last=True
61 | )
62 | else:
63 | loader = DataLoader(
64 | dataset, batch_size=batch_size, shuffle=True, num_workers=1, drop_last=True
65 | )
66 | while True:
67 | yield from loader
68 |
69 |
70 | def _list_image_files_recursively(data_dir):
71 | results = []
72 | for entry in sorted(bf.listdir(data_dir)):
73 | full_path = bf.join(data_dir, entry)
74 | ext = entry.split(".")[-1]
75 | if "." in entry and ext.lower() in ["jpg", "jpeg", "png", "gif"]:
76 | results.append(full_path)
77 | elif bf.isdir(full_path):
78 | results.extend(_list_image_files_recursively(full_path))
79 | return results
80 |
81 |
82 | class ImageDataset(Dataset):
83 | def __init__(
84 | self,
85 | resolution,
86 | image_paths,
87 | classes=None,
88 | shard=0,
89 | num_shards=1,
90 | random_crop=False,
91 | random_flip=True,
92 | ):
93 | super().__init__()
94 | self.resolution = resolution
95 | self.local_images = image_paths[shard:][::num_shards]
96 | self.local_classes = None if classes is None else classes[shard:][::num_shards]
97 | self.random_crop = random_crop
98 | self.random_flip = random_flip
99 |
100 | def __len__(self):
101 | return len(self.local_images)
102 |
103 | def __getitem__(self, idx):
104 | path = self.local_images[idx]
105 | with bf.BlobFile(path, "rb") as f:
106 | pil_image = Image.open(f)
107 | pil_image.load()
108 | pil_image = pil_image.convert("RGB")
109 |
110 | if self.random_crop:
111 | arr = random_crop_arr(pil_image, self.resolution)
112 | else:
113 | arr = center_crop_arr(pil_image, self.resolution)
114 |
115 | if self.random_flip and random.random() < 0.5:
116 | arr = arr[:, ::-1]
117 |
118 | arr = arr.astype(np.float32) / 127.5 - 1
119 |
120 | out_dict = {}
121 | if self.local_classes is not None:
122 | out_dict["y"] = np.array(self.local_classes[idx], dtype=np.int64)
123 | return np.transpose(arr, [2, 0, 1]), out_dict
124 |
125 |
126 | def center_crop_arr(pil_image, image_size):
127 | # We are not on a new enough PIL to support the `reducing_gap`
128 | # argument, which uses BOX downsampling at powers of two first.
129 | # Thus, we do it by hand to improve downsample quality.
130 | while min(*pil_image.size) >= 2 * image_size:
131 | pil_image = pil_image.resize(
132 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX
133 | )
134 |
135 | scale = image_size / min(*pil_image.size)
136 | pil_image = pil_image.resize(
137 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
138 | )
139 |
140 | arr = np.array(pil_image)
141 | crop_y = (arr.shape[0] - image_size) // 2
142 | crop_x = (arr.shape[1] - image_size) // 2
143 | return arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size]
144 |
145 |
146 | def random_crop_arr(pil_image, image_size, min_crop_frac=0.8, max_crop_frac=1.0):
147 | min_smaller_dim_size = math.ceil(image_size / max_crop_frac)
148 | max_smaller_dim_size = math.ceil(image_size / min_crop_frac)
149 | smaller_dim_size = random.randrange(min_smaller_dim_size, max_smaller_dim_size + 1)
150 |
151 | # We are not on a new enough PIL to support the `reducing_gap`
152 | # argument, which uses BOX downsampling at powers of two first.
153 | # Thus, we do it by hand to improve downsample quality.
154 | while min(*pil_image.size) >= 2 * smaller_dim_size:
155 | pil_image = pil_image.resize(
156 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX
157 | )
158 |
159 | scale = smaller_dim_size / min(*pil_image.size)
160 | pil_image = pil_image.resize(
161 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
162 | )
163 |
164 | arr = np.array(pil_image)
165 | crop_y = random.randrange(arr.shape[0] - image_size + 1)
166 | crop_x = random.randrange(arr.shape[1] - image_size + 1)
167 | return arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size]
168 |
--------------------------------------------------------------------------------
/guided-diffusion/guided_diffusion/losses.py:
--------------------------------------------------------------------------------
1 | """
2 | Helpers for various likelihood-based losses. These are ported from the original
3 | Ho et al. diffusion models codebase:
4 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/utils.py
5 | """
6 |
7 | import numpy as np
8 |
9 | import torch as th
10 |
11 |
12 | def normal_kl(mean1, logvar1, mean2, logvar2):
13 | """
14 | Compute the KL divergence between two gaussians.
15 |
16 | Shapes are automatically broadcasted, so batches can be compared to
17 | scalars, among other use cases.
18 | """
19 | tensor = None
20 | for obj in (mean1, logvar1, mean2, logvar2):
21 | if isinstance(obj, th.Tensor):
22 | tensor = obj
23 | break
24 | assert tensor is not None, "at least one argument must be a Tensor"
25 |
26 | # Force variances to be Tensors. Broadcasting helps convert scalars to
27 | # Tensors, but it does not work for th.exp().
28 | logvar1, logvar2 = [
29 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
30 | for x in (logvar1, logvar2)
31 | ]
32 |
33 | return 0.5 * (
34 | -1.0
35 | + logvar2
36 | - logvar1
37 | + th.exp(logvar1 - logvar2)
38 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
39 | )
40 |
41 |
42 | def approx_standard_normal_cdf(x):
43 | """
44 | A fast approximation of the cumulative distribution function of the
45 | standard normal.
46 | """
47 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
48 |
49 |
50 | def discretized_gaussian_log_likelihood(x, *, means, log_scales):
51 | """
52 | Compute the log-likelihood of a Gaussian distribution discretizing to a
53 | given image.
54 |
55 | :param x: the target images. It is assumed that this was uint8 values,
56 | rescaled to the range [-1, 1].
57 | :param means: the Gaussian mean Tensor.
58 | :param log_scales: the Gaussian log stddev Tensor.
59 | :return: a tensor like x of log probabilities (in nats).
60 | """
61 | assert x.shape == means.shape == log_scales.shape
62 | centered_x = x - means
63 | inv_stdv = th.exp(-log_scales)
64 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
65 | cdf_plus = approx_standard_normal_cdf(plus_in)
66 | min_in = inv_stdv * (centered_x - 1.0 / 255.0)
67 | cdf_min = approx_standard_normal_cdf(min_in)
68 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
69 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
70 | cdf_delta = cdf_plus - cdf_min
71 | log_probs = th.where(
72 | x < -0.999,
73 | log_cdf_plus,
74 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
75 | )
76 | assert log_probs.shape == x.shape
77 | return log_probs
78 |
--------------------------------------------------------------------------------
/guided-diffusion/guided_diffusion/nn.py:
--------------------------------------------------------------------------------
1 | """
2 | Various utilities for neural networks.
3 | """
4 |
5 | import math
6 |
7 | import torch as th
8 | import torch.nn as nn
9 |
10 |
11 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
12 | class SiLU(nn.Module):
13 | def forward(self, x):
14 | return x * th.sigmoid(x)
15 |
16 |
17 | class GroupNorm32(nn.GroupNorm):
18 | def forward(self, x):
19 | return super().forward(x.float()).type(x.dtype)
20 |
21 |
22 | def conv_nd(dims, *args, **kwargs):
23 | """
24 | Create a 1D, 2D, or 3D convolution module.
25 | """
26 | if dims == 1:
27 | return nn.Conv1d(*args, **kwargs)
28 | elif dims == 2:
29 | return nn.Conv2d(*args, **kwargs)
30 | elif dims == 3:
31 | return nn.Conv3d(*args, **kwargs)
32 | raise ValueError(f"unsupported dimensions: {dims}")
33 |
34 |
35 | def linear(*args, **kwargs):
36 | """
37 | Create a linear module.
38 | """
39 | return nn.Linear(*args, **kwargs)
40 |
41 |
42 | def avg_pool_nd(dims, *args, **kwargs):
43 | """
44 | Create a 1D, 2D, or 3D average pooling module.
45 | """
46 | if dims == 1:
47 | return nn.AvgPool1d(*args, **kwargs)
48 | elif dims == 2:
49 | return nn.AvgPool2d(*args, **kwargs)
50 | elif dims == 3:
51 | return nn.AvgPool3d(*args, **kwargs)
52 | raise ValueError(f"unsupported dimensions: {dims}")
53 |
54 |
55 | def update_ema(target_params, source_params, rate=0.99):
56 | """
57 | Update target parameters to be closer to those of source parameters using
58 | an exponential moving average.
59 |
60 | :param target_params: the target parameter sequence.
61 | :param source_params: the source parameter sequence.
62 | :param rate: the EMA rate (closer to 1 means slower).
63 | """
64 | for targ, src in zip(target_params, source_params):
65 | targ.detach().mul_(rate).add_(src, alpha=1 - rate)
66 |
67 |
68 | def zero_module(module):
69 | """
70 | Zero out the parameters of a module and return it.
71 | """
72 | for p in module.parameters():
73 | p.detach().zero_()
74 | return module
75 |
76 |
77 | def scale_module(module, scale):
78 | """
79 | Scale the parameters of a module and return it.
80 | """
81 | for p in module.parameters():
82 | p.detach().mul_(scale)
83 | return module
84 |
85 |
86 | def mean_flat(tensor):
87 | """
88 | Take the mean over all non-batch dimensions.
89 | """
90 | return tensor.mean(dim=list(range(1, len(tensor.shape))))
91 |
92 |
93 | def normalization(channels):
94 | """
95 | Make a standard normalization layer.
96 |
97 | :param channels: number of input channels.
98 | :return: an nn.Module for normalization.
99 | """
100 | return GroupNorm32(32, channels)
101 |
102 |
103 | def timestep_embedding(timesteps, dim, max_period=10000):
104 | """
105 | Create sinusoidal timestep embeddings.
106 |
107 | :param timesteps: a 1-D Tensor of N indices, one per batch element.
108 | These may be fractional.
109 | :param dim: the dimension of the output.
110 | :param max_period: controls the minimum frequency of the embeddings.
111 | :return: an [N x dim] Tensor of positional embeddings.
112 | """
113 | half = dim // 2
114 | freqs = th.exp(
115 | -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half
116 | ).to(device=timesteps.device)
117 | args = timesteps[:, None].float() * freqs[None]
118 | embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
119 | if dim % 2:
120 | embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)
121 | return embedding
122 |
123 |
124 | def checkpoint(func, inputs, params, flag):
125 | """
126 | Evaluate a function without caching intermediate activations, allowing for
127 | reduced memory at the expense of extra compute in the backward pass.
128 |
129 | :param func: the function to evaluate.
130 | :param inputs: the argument sequence to pass to `func`.
131 | :param params: a sequence of parameters `func` depends on but does not
132 | explicitly take as arguments.
133 | :param flag: if False, disable gradient checkpointing.
134 | """
135 | if flag:
136 | args = tuple(inputs) + tuple(params)
137 | return CheckpointFunction.apply(func, len(inputs), *args)
138 | else:
139 | return func(*inputs)
140 |
141 |
142 | class CheckpointFunction(th.autograd.Function):
143 | @staticmethod
144 | def forward(ctx, run_function, length, *args):
145 | ctx.run_function = run_function
146 | ctx.input_tensors = list(args[:length])
147 | ctx.input_params = list(args[length:])
148 | with th.no_grad():
149 | output_tensors = ctx.run_function(*ctx.input_tensors)
150 | return output_tensors
151 |
152 | @staticmethod
153 | def backward(ctx, *output_grads):
154 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
155 | with th.enable_grad():
156 | # Fixes a bug where the first op in run_function modifies the
157 | # Tensor storage in place, which is not allowed for detach()'d
158 | # Tensors.
159 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
160 | output_tensors = ctx.run_function(*shallow_copies)
161 | input_grads = th.autograd.grad(
162 | output_tensors,
163 | ctx.input_tensors + ctx.input_params,
164 | output_grads,
165 | allow_unused=True,
166 | )
167 | del ctx.input_tensors
168 | del ctx.input_params
169 | del output_tensors
170 | return (None, None) + input_grads
171 |
--------------------------------------------------------------------------------
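A minimal usage sketch for the nn.py helpers above (not part of the repository files; it assumes the `guided_diffusion` package is importable, e.g. after installing it with the setup.py listed below):

import torch as th
from guided_diffusion.nn import timestep_embedding, update_ema

# Sinusoidal embeddings: one [dim]-sized cos/sin feature vector per timestep index.
t = th.arange(0, 8)
emb = timestep_embedding(t, dim=128)
assert emb.shape == (8, 128)

# EMA update: every target parameter moves a small step towards its source.
target = [th.zeros(3)]
source = [th.ones(3)]
update_ema(target, source, rate=0.99)   # target is now 0.01 everywhere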
/guided-diffusion/guided_diffusion/resample.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 |
3 | import numpy as np
4 | import torch as th
5 | import torch.distributed as dist
6 |
7 |
8 | def create_named_schedule_sampler(name, diffusion):
9 | """
10 | Create a ScheduleSampler from a library of pre-defined samplers.
11 |
12 | :param name: the name of the sampler.
13 | :param diffusion: the diffusion object to sample for.
14 | """
15 | if name == "uniform":
16 | return UniformSampler(diffusion)
17 | elif name == "loss-second-moment":
18 | return LossSecondMomentResampler(diffusion)
19 | else:
20 | raise NotImplementedError(f"unknown schedule sampler: {name}")
21 |
22 |
23 | class ScheduleSampler(ABC):
24 | """
25 | A distribution over timesteps in the diffusion process, intended to reduce
26 | variance of the objective.
27 |
28 | By default, samplers perform unbiased importance sampling, in which the
29 | objective's mean is unchanged.
30 | However, subclasses may override sample() to change how the resampled
31 | terms are reweighted, allowing for actual changes in the objective.
32 | """
33 |
34 | @abstractmethod
35 | def weights(self):
36 | """
37 | Get a numpy array of weights, one per diffusion step.
38 |
39 | The weights needn't be normalized, but must be positive.
40 | """
41 |
42 | def sample(self, batch_size, device):
43 | """
44 | Importance-sample timesteps for a batch.
45 |
46 |         :param batch_size: the number of timesteps to sample (one per batch element).
47 |         :param device: the torch device to place the sampled timesteps on.
48 | :return: a tuple (timesteps, weights):
49 | - timesteps: a tensor of timestep indices.
50 | - weights: a tensor of weights to scale the resulting losses.
51 | """
52 | w = self.weights()
53 | p = w / np.sum(w)
54 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
55 | indices = th.from_numpy(indices_np).long().to(device)
56 | weights_np = 1 / (len(p) * p[indices_np])
57 | weights = th.from_numpy(weights_np).float().to(device)
58 | return indices, weights
59 |
60 |
61 | class UniformSampler(ScheduleSampler):
62 | def __init__(self, diffusion):
63 | self.diffusion = diffusion
64 | self._weights = np.ones([diffusion.num_timesteps])
65 |
66 | def weights(self):
67 | return self._weights
68 |
69 |
70 | class LossAwareSampler(ScheduleSampler):
71 | def update_with_local_losses(self, local_ts, local_losses):
72 | """
73 | Update the reweighting using losses from a model.
74 |
75 | Call this method from each rank with a batch of timesteps and the
76 | corresponding losses for each of those timesteps.
77 | This method will perform synchronization to make sure all of the ranks
78 | maintain the exact same reweighting.
79 |
80 | :param local_ts: an integer Tensor of timesteps.
81 | :param local_losses: a 1D Tensor of losses.
82 | """
83 | batch_sizes = [
84 | th.tensor([0], dtype=th.int32, device=local_ts.device)
85 | for _ in range(dist.get_world_size())
86 | ]
87 | dist.all_gather(
88 | batch_sizes,
89 | th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
90 | )
91 |
92 | # Pad all_gather batches to be the maximum batch size.
93 | batch_sizes = [x.item() for x in batch_sizes]
94 | max_bs = max(batch_sizes)
95 |
96 | timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes]
97 | loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes]
98 | dist.all_gather(timestep_batches, local_ts)
99 | dist.all_gather(loss_batches, local_losses)
100 | timesteps = [
101 | x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]
102 | ]
103 | losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
104 | self.update_with_all_losses(timesteps, losses)
105 |
106 | @abstractmethod
107 | def update_with_all_losses(self, ts, losses):
108 | """
109 | Update the reweighting using losses from a model.
110 |
111 | Sub-classes should override this method to update the reweighting
112 | using losses from the model.
113 |
114 | This method directly updates the reweighting without synchronizing
115 | between workers. It is called by update_with_local_losses from all
116 | ranks with identical arguments. Thus, it should have deterministic
117 | behavior to maintain state across workers.
118 |
119 | :param ts: a list of int timesteps.
120 | :param losses: a list of float losses, one per timestep.
121 | """
122 |
123 |
124 | class LossSecondMomentResampler(LossAwareSampler):
125 | def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
126 | self.diffusion = diffusion
127 | self.history_per_term = history_per_term
128 | self.uniform_prob = uniform_prob
129 | self._loss_history = np.zeros(
130 | [diffusion.num_timesteps, history_per_term], dtype=np.float64
131 | )
132 |         self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int64)
133 |
134 | def weights(self):
135 | if not self._warmed_up():
136 | return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
137 | weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
138 | weights /= np.sum(weights)
139 | weights *= 1 - self.uniform_prob
140 | weights += self.uniform_prob / len(weights)
141 | return weights
142 |
143 | def update_with_all_losses(self, ts, losses):
144 | for t, loss in zip(ts, losses):
145 | if self._loss_counts[t] == self.history_per_term:
146 | # Shift out the oldest loss term.
147 | self._loss_history[t, :-1] = self._loss_history[t, 1:]
148 | self._loss_history[t, -1] = loss
149 | else:
150 | self._loss_history[t, self._loss_counts[t]] = loss
151 | self._loss_counts[t] += 1
152 |
153 | def _warmed_up(self):
154 | return (self._loss_counts == self.history_per_term).all()
155 |
--------------------------------------------------------------------------------
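A sketch of how a schedule sampler is consumed in a training step; `_ToyDiffusion` below is a stand-in that only provides the `num_timesteps` attribute the samplers actually read:

import torch as th
from guided_diffusion.resample import create_named_schedule_sampler

class _ToyDiffusion:
    num_timesteps = 1000   # the only attribute UniformSampler needs

sampler = create_named_schedule_sampler("uniform", _ToyDiffusion())
t, weights = sampler.sample(batch_size=4, device=th.device("cpu"))

# The weights rescale per-sample losses so the weighted mean stays an unbiased
# estimate of the objective under the sampled timesteps.
per_sample_loss = th.rand(4)   # placeholder losses, one per sampled timestep
loss = (per_sample_loss * weights).mean()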
/guided-diffusion/guided_diffusion/respace.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch as th
3 |
4 | from .gaussian_diffusion import GaussianDiffusion
5 |
6 |
7 | def space_timesteps(num_timesteps, section_counts):
8 | """
9 | Create a list of timesteps to use from an original diffusion process,
10 | given the number of timesteps we want to take from equally-sized portions
11 | of the original process.
12 |
13 |     For example, if there are 300 timesteps and the section counts are [10,15,20]
14 | then the first 100 timesteps are strided to be 10 timesteps, the second 100
15 | are strided to be 15 timesteps, and the final 100 are strided to be 20.
16 |
17 | If the stride is a string starting with "ddim", then the fixed striding
18 | from the DDIM paper is used, and only one section is allowed.
19 |
20 | :param num_timesteps: the number of diffusion steps in the original
21 | process to divide up.
22 | :param section_counts: either a list of numbers, or a string containing
23 | comma-separated numbers, indicating the step count
24 | per section. As a special case, use "ddimN" where N
25 | is a number of steps to use the striding from the
26 | DDIM paper.
27 | :return: a set of diffusion steps from the original process to use.
28 | """
29 | if isinstance(section_counts, str):
30 | if section_counts.startswith("ddim"):
31 | desired_count = int(section_counts[len("ddim") :])
32 | for i in range(1, num_timesteps):
33 | if len(range(0, num_timesteps, i)) == desired_count:
34 | return set(range(0, num_timesteps, i))
35 | raise ValueError(
36 |                 f"cannot create exactly {desired_count} steps with an integer stride"
37 | )
38 | section_counts = [int(x) for x in section_counts.split(",")]
39 | size_per = num_timesteps // len(section_counts)
40 | extra = num_timesteps % len(section_counts)
41 | start_idx = 0
42 | all_steps = []
43 | for i, section_count in enumerate(section_counts):
44 | size = size_per + (1 if i < extra else 0)
45 | if size < section_count:
46 | raise ValueError(
47 | f"cannot divide section of {size} steps into {section_count}"
48 | )
49 | if section_count <= 1:
50 | frac_stride = 1
51 | else:
52 | frac_stride = (size - 1) / (section_count - 1)
53 | cur_idx = 0.0
54 | taken_steps = []
55 | for _ in range(section_count):
56 | taken_steps.append(start_idx + round(cur_idx))
57 | cur_idx += frac_stride
58 | all_steps += taken_steps
59 | start_idx += size
60 | return set(all_steps)
61 |
62 |
63 | class SpacedDiffusion(GaussianDiffusion):
64 | """
65 | A diffusion process which can skip steps in a base diffusion process.
66 |
67 | :param use_timesteps: a collection (sequence or set) of timesteps from the
68 | original diffusion process to retain.
69 | :param kwargs: the kwargs to create the base diffusion process.
70 | """
71 |
72 | def __init__(self, use_timesteps, **kwargs):
73 | self.use_timesteps = set(use_timesteps)
74 | self.timestep_map = []
75 | self.original_num_steps = len(kwargs["betas"])
76 |
77 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
78 | last_alpha_cumprod = 1.0
79 | new_betas = []
80 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
81 | if i in self.use_timesteps:
82 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
83 | last_alpha_cumprod = alpha_cumprod
84 | self.timestep_map.append(i)
85 | kwargs["betas"] = np.array(new_betas)
86 | super().__init__(**kwargs)
87 |
88 | def p_mean_variance(
89 | self, model, *args, **kwargs
90 | ): # pylint: disable=signature-differs
91 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
92 |
93 | def training_losses(
94 | self, model, *args, **kwargs
95 | ): # pylint: disable=signature-differs
96 | return super().training_losses(self._wrap_model(model), *args, **kwargs)
97 |
98 | def condition_mean(self, cond_fn, *args, **kwargs):
99 | return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
100 |
101 | def condition_score(self, cond_fn, *args, **kwargs):
102 | return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
103 |
104 | def _wrap_model(self, model):
105 | if isinstance(model, _WrappedModel):
106 | return model
107 | return _WrappedModel(
108 | model, self.timestep_map, self.rescale_timesteps, self.original_num_steps
109 | )
110 |
111 | def _scale_timesteps(self, t):
112 | # Scaling is done by the wrapped model.
113 | return t
114 |
115 |
116 | class _WrappedModel:
117 | def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps):
118 | self.model = model
119 | self.timestep_map = timestep_map
120 | self.rescale_timesteps = rescale_timesteps
121 | self.original_num_steps = original_num_steps
122 |
123 | def __call__(self, x, ts, **kwargs):
124 | map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
125 | new_ts = map_tensor[ts]
126 | if self.rescale_timesteps:
127 | new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
128 | return self.model(x, new_ts, **kwargs)
129 |
--------------------------------------------------------------------------------
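Two typical respacings of a 1000-step process, illustrating the `space_timesteps` behavior documented above:

from guided_diffusion.respace import space_timesteps

steps = space_timesteps(1000, "250")           # one section, evenly strided to 250 steps
assert len(steps) == 250

ddim_steps = space_timesteps(1000, "ddim50")   # DDIM-style fixed stride (here 20)
assert len(ddim_steps) == 50 and 0 in ddim_steps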
/guided-diffusion/guided_diffusion/stega_model.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch import nn
4 | from torch.nn.functional import relu, sigmoid
5 |
6 |
7 | class StegaStampEncoder(nn.Module):
8 | def __init__(
9 | self,
10 | resolution=32,
11 | IMAGE_CHANNELS=1,
12 | fingerprint_size=100,
13 | return_residual=False,
14 | ):
15 | super(StegaStampEncoder, self).__init__()
16 | self.fingerprint_size = fingerprint_size
17 | self.IMAGE_CHANNELS = IMAGE_CHANNELS
18 | self.return_residual = return_residual
19 | self.secret_dense = nn.Linear(self.fingerprint_size, 64 * 64 * IMAGE_CHANNELS)
20 |
21 | log_resolution = int(math.log(resolution, 2))
22 | assert resolution == 2 ** log_resolution, f"Image resolution must be a power of 2, got {resolution}."
23 |
24 | self.fingerprint_upsample = nn.Upsample(scale_factor=(2**(log_resolution-6), 2**(log_resolution-6)))
25 | self.conv1 = nn.Conv2d(2 * IMAGE_CHANNELS, 32, 3, 1, 1)
26 | self.conv2 = nn.Conv2d(32, 32, 3, 2, 1)
27 | self.conv3 = nn.Conv2d(32, 64, 3, 2, 1)
28 | self.conv4 = nn.Conv2d(64, 128, 3, 2, 1)
29 | self.conv5 = nn.Conv2d(128, 256, 3, 2, 1)
30 | self.pad6 = nn.ZeroPad2d((0, 1, 0, 1))
31 | self.up6 = nn.Conv2d(256, 128, 2, 1)
32 | self.upsample6 = nn.Upsample(scale_factor=(2, 2))
33 | self.conv6 = nn.Conv2d(128 + 128, 128, 3, 1, 1)
34 | self.pad7 = nn.ZeroPad2d((0, 1, 0, 1))
35 | self.up7 = nn.Conv2d(128, 64, 2, 1)
36 | self.upsample7 = nn.Upsample(scale_factor=(2, 2))
37 | self.conv7 = nn.Conv2d(64 + 64, 64, 3, 1, 1)
38 | self.pad8 = nn.ZeroPad2d((0, 1, 0, 1))
39 | self.up8 = nn.Conv2d(64, 32, 2, 1)
40 | self.upsample8 = nn.Upsample(scale_factor=(2, 2))
41 | self.conv8 = nn.Conv2d(32 + 32, 32, 3, 1, 1)
42 | self.pad9 = nn.ZeroPad2d((0, 1, 0, 1))
43 | self.up9 = nn.Conv2d(32, 32, 2, 1)
44 | self.upsample9 = nn.Upsample(scale_factor=(2, 2))
45 | self.conv9 = nn.Conv2d(32 + 32 + 2 * IMAGE_CHANNELS, 32, 3, 1, 1)
46 | self.conv10 = nn.Conv2d(32, 32, 3, 1, 1)
47 | self.residual = nn.Conv2d(32, IMAGE_CHANNELS, 1)
48 |
49 | def forward(self, fingerprint, image):
50 | fingerprint = relu(self.secret_dense(fingerprint))
51 | fingerprint = fingerprint.view((-1, self.IMAGE_CHANNELS, 64, 64))
52 | fingerprint_enlarged = self.fingerprint_upsample(fingerprint)
53 | inputs = torch.cat([fingerprint_enlarged, image], dim=1)
54 | conv1 = relu(self.conv1(inputs))
55 | conv2 = relu(self.conv2(conv1))
56 | conv3 = relu(self.conv3(conv2))
57 | conv4 = relu(self.conv4(conv3))
58 | conv5 = relu(self.conv5(conv4))
59 | up6 = relu(self.up6(self.pad6(self.upsample6(conv5))))
60 | merge6 = torch.cat([conv4, up6], dim=1)
61 | conv6 = relu(self.conv6(merge6))
62 | up7 = relu(self.up7(self.pad7(self.upsample7(conv6))))
63 | merge7 = torch.cat([conv3, up7], dim=1)
64 | conv7 = relu(self.conv7(merge7))
65 | up8 = relu(self.up8(self.pad8(self.upsample8(conv7))))
66 | merge8 = torch.cat([conv2, up8], dim=1)
67 | conv8 = relu(self.conv8(merge8))
68 | up9 = relu(self.up9(self.pad9(self.upsample9(conv8))))
69 | merge9 = torch.cat([conv1, up9, inputs], dim=1)
70 | conv9 = relu(self.conv9(merge9))
71 | conv10 = relu(self.conv10(conv9))
72 | residual = self.residual(conv10)
73 | if not self.return_residual:
74 | residual = sigmoid(residual)
75 | return residual
76 |
77 |
78 | class StegaStampDecoder(nn.Module):
79 | def __init__(self, resolution=32, IMAGE_CHANNELS=1, fingerprint_size=1):
80 | super(StegaStampDecoder, self).__init__()
81 | self.resolution = resolution
82 | self.IMAGE_CHANNELS = IMAGE_CHANNELS
83 |         self.fingerprint_size = fingerprint_size
84 | self.decoder = nn.Sequential(
85 | nn.Conv2d(IMAGE_CHANNELS, 32, (3, 3), 2, 1), # 16
86 | nn.ReLU(),
87 | nn.Conv2d(32, 32, 3, 1, 1),
88 | nn.ReLU(),
89 | nn.Conv2d(32, 64, 3, 2, 1), # 8
90 | nn.ReLU(),
91 | nn.Conv2d(64, 64, 3, 1, 1),
92 | nn.ReLU(),
93 | nn.Conv2d(64, 64, 3, 2, 1), # 4
94 | nn.ReLU(),
95 | nn.Conv2d(64, 128, 3, 2, 1), # 2
96 | nn.ReLU(),
97 | nn.Conv2d(128, 128, (3, 3), 2, 1),
98 | nn.ReLU(),
99 | )
100 | self.dense = nn.Sequential(
101 | nn.Linear(resolution * resolution * 128 // 32 // 32, 512),
102 | nn.ReLU(),
103 | nn.Linear(512, fingerprint_size),
104 | )
105 |
106 | def forward(self, image):
107 | x = self.decoder(image)
108 | x = x.view(-1, self.resolution * self.resolution * 128 // 32 // 32)
109 | return self.dense(x)
110 |
--------------------------------------------------------------------------------
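A shape-level sketch of how the encoder and decoder above fit together; the weights are random here, so the recovered bits are meaningless until the models are trained:

import torch
from guided_diffusion.stega_model import StegaStampEncoder, StegaStampDecoder

resolution, channels, bits = 128, 3, 48
encoder = StegaStampEncoder(resolution, channels, fingerprint_size=bits)
decoder = StegaStampDecoder(resolution, channels, fingerprint_size=bits)

image = torch.rand(2, channels, resolution, resolution)   # batch of 2 images in [0, 1]
fingerprint = torch.randint(0, 2, (2, bits)).float()      # two random 48-bit keys

stamped = encoder(fingerprint, image)   # [2, 3, 128, 128], sigmoid output
logits = decoder(stamped)               # [2, 48] fingerprint logits
recovered = (logits > 0).long()         # threshold back to bits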
/guided-diffusion/scripts/classifier_sample.py:
--------------------------------------------------------------------------------
1 | """
2 | Like image_sample.py, but use a noisy image classifier to guide the sampling
3 | process towards more realistic images.
4 | """
5 |
6 | import argparse
7 | import os
8 |
9 | import numpy as np
10 | import torch as th
11 | import torch.distributed as dist
12 | import torch.nn.functional as F
13 |
14 | from guided_diffusion import dist_util, logger
15 | from guided_diffusion.script_util import (
16 | NUM_CLASSES,
17 | model_and_diffusion_defaults,
18 | classifier_defaults,
19 | create_model_and_diffusion,
20 | create_classifier,
21 | add_dict_to_argparser,
22 | args_to_dict,
23 | )
24 |
25 |
26 | def main():
27 | args = create_argparser().parse_args()
28 |
29 | dist_util.setup_dist()
30 | logger.configure()
31 |
32 | logger.log("creating model and diffusion...")
33 | model, diffusion = create_model_and_diffusion(
34 | **args_to_dict(args, model_and_diffusion_defaults().keys())
35 | )
36 | model.load_state_dict(
37 | dist_util.load_state_dict(args.model_path, map_location="cpu")
38 | )
39 | model.to(dist_util.dev())
40 | if args.use_fp16:
41 | model.convert_to_fp16()
42 | model.eval()
43 |
44 | logger.log("loading classifier...")
45 | classifier = create_classifier(**args_to_dict(args, classifier_defaults().keys()))
46 | classifier.load_state_dict(
47 | dist_util.load_state_dict(args.classifier_path, map_location="cpu")
48 | )
49 | classifier.to(dist_util.dev())
50 | if args.classifier_use_fp16:
51 | classifier.convert_to_fp16()
52 | classifier.eval()
53 |
54 | def cond_fn(x, t, y=None):
55 | assert y is not None
56 | with th.enable_grad():
57 | x_in = x.detach().requires_grad_(True)
58 | logits = classifier(x_in, t)
59 | log_probs = F.log_softmax(logits, dim=-1)
60 | selected = log_probs[range(len(logits)), y.view(-1)]
61 | return th.autograd.grad(selected.sum(), x_in)[0] * args.classifier_scale
62 |
63 | def model_fn(x, t, y=None):
64 | assert y is not None
65 | return model(x, t, y if args.class_cond else None)
66 |
67 | logger.log("sampling...")
68 | all_images = []
69 | all_labels = []
70 | while len(all_images) * args.batch_size < args.num_samples:
71 | model_kwargs = {}
72 | classes = th.randint(
73 | low=0, high=NUM_CLASSES, size=(args.batch_size,), device=dist_util.dev()
74 | )
75 | model_kwargs["y"] = classes
76 | sample_fn = (
77 | diffusion.p_sample_loop if not args.use_ddim else diffusion.ddim_sample_loop
78 | )
79 | sample = sample_fn(
80 | model_fn,
81 | (args.batch_size, 3, args.image_size, args.image_size),
82 | clip_denoised=args.clip_denoised,
83 | model_kwargs=model_kwargs,
84 | cond_fn=cond_fn,
85 | device=dist_util.dev(),
86 | )
87 | sample = ((sample + 1) * 127.5).clamp(0, 255).to(th.uint8)
88 | sample = sample.permute(0, 2, 3, 1)
89 | sample = sample.contiguous()
90 |
91 | gathered_samples = [th.zeros_like(sample) for _ in range(dist.get_world_size())]
92 | dist.all_gather(gathered_samples, sample) # gather not supported with NCCL
93 | all_images.extend([sample.cpu().numpy() for sample in gathered_samples])
94 | gathered_labels = [th.zeros_like(classes) for _ in range(dist.get_world_size())]
95 | dist.all_gather(gathered_labels, classes)
96 | all_labels.extend([labels.cpu().numpy() for labels in gathered_labels])
97 | logger.log(f"created {len(all_images) * args.batch_size} samples")
98 |
99 | arr = np.concatenate(all_images, axis=0)
100 | arr = arr[: args.num_samples]
101 | label_arr = np.concatenate(all_labels, axis=0)
102 | label_arr = label_arr[: args.num_samples]
103 | if dist.get_rank() == 0:
104 | shape_str = "x".join([str(x) for x in arr.shape])
105 | out_path = os.path.join(logger.get_dir(), f"samples_{shape_str}.npz")
106 | logger.log(f"saving to {out_path}")
107 | np.savez(out_path, arr, label_arr)
108 |
109 | dist.barrier()
110 | logger.log("sampling complete")
111 |
112 |
113 | def create_argparser():
114 | defaults = dict(
115 | clip_denoised=True,
116 | num_samples=10000,
117 | batch_size=16,
118 | use_ddim=False,
119 | model_path="",
120 | classifier_path="",
121 | classifier_scale=1.0,
122 | )
123 | defaults.update(model_and_diffusion_defaults())
124 | defaults.update(classifier_defaults())
125 | parser = argparse.ArgumentParser()
126 | add_dict_to_argparser(parser, defaults)
127 | return parser
128 |
129 |
130 | if __name__ == "__main__":
131 | main()
132 |
--------------------------------------------------------------------------------
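`cond_fn` above returns the scaled gradient of the classifier's log-probability for the requested class; during sampling, the diffusion process adds this gradient (scaled by the predicted variance) to the predicted mean. A toy stand-in with a hypothetical tiny classifier, mirroring the same gradient computation on flattened inputs:

import torch as th
import torch.nn.functional as F

classifier = th.nn.Linear(3 * 8 * 8, 10)   # hypothetical tiny classifier

def toy_cond_fn(x, y, scale=1.0):
    with th.enable_grad():
        x_in = x.detach().requires_grad_(True)
        log_probs = F.log_softmax(classifier(x_in.flatten(1)), dim=-1)
        selected = log_probs[range(len(log_probs)), y]
        return th.autograd.grad(selected.sum(), x_in)[0] * scale

x = th.randn(4, 3, 8, 8)
y = th.randint(0, 10, (4,))
grad = toy_cond_fn(x, y, scale=1.0)        # same shape as x, points towards class y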
/guided-diffusion/scripts/image_nll.py:
--------------------------------------------------------------------------------
1 | """
2 | Approximate the bits/dimension for an image model.
3 | """
4 |
5 | import argparse
6 | import os
7 |
8 | import numpy as np
9 | import torch.distributed as dist
10 |
11 | from guided_diffusion import dist_util, logger
12 | from guided_diffusion.image_datasets import load_data
13 | from guided_diffusion.script_util import (
14 | model_and_diffusion_defaults,
15 | create_model_and_diffusion,
16 | add_dict_to_argparser,
17 | args_to_dict,
18 | )
19 |
20 |
21 | def main():
22 | args = create_argparser().parse_args()
23 |
24 | dist_util.setup_dist()
25 | logger.configure()
26 |
27 | logger.log("creating model and diffusion...")
28 | model, diffusion = create_model_and_diffusion(
29 | **args_to_dict(args, model_and_diffusion_defaults().keys())
30 | )
31 | model.load_state_dict(
32 | dist_util.load_state_dict(args.model_path, map_location="cpu")
33 | )
34 | model.to(dist_util.dev())
35 | model.eval()
36 |
37 | logger.log("creating data loader...")
38 | data = load_data(
39 | data_dir=args.data_dir,
40 | batch_size=args.batch_size,
41 | image_size=args.image_size,
42 | class_cond=args.class_cond,
43 | deterministic=True,
44 | )
45 |
46 | logger.log("evaluating...")
47 | run_bpd_evaluation(model, diffusion, data, args.num_samples, args.clip_denoised)
48 |
49 |
50 | def run_bpd_evaluation(model, diffusion, data, num_samples, clip_denoised):
51 | all_bpd = []
52 | all_metrics = {"vb": [], "mse": [], "xstart_mse": []}
53 | num_complete = 0
54 | while num_complete < num_samples:
55 | batch, model_kwargs = next(data)
56 | batch = batch.to(dist_util.dev())
57 | model_kwargs = {k: v.to(dist_util.dev()) for k, v in model_kwargs.items()}
58 | minibatch_metrics = diffusion.calc_bpd_loop(
59 | model, batch, clip_denoised=clip_denoised, model_kwargs=model_kwargs
60 | )
61 |
62 | for key, term_list in all_metrics.items():
63 | terms = minibatch_metrics[key].mean(dim=0) / dist.get_world_size()
64 | dist.all_reduce(terms)
65 | term_list.append(terms.detach().cpu().numpy())
66 |
67 | total_bpd = minibatch_metrics["total_bpd"]
68 | total_bpd = total_bpd.mean() / dist.get_world_size()
69 | dist.all_reduce(total_bpd)
70 | all_bpd.append(total_bpd.item())
71 | num_complete += dist.get_world_size() * batch.shape[0]
72 |
73 | logger.log(f"done {num_complete} samples: bpd={np.mean(all_bpd)}")
74 |
75 | if dist.get_rank() == 0:
76 | for name, terms in all_metrics.items():
77 | out_path = os.path.join(logger.get_dir(), f"{name}_terms.npz")
78 | logger.log(f"saving {name} terms to {out_path}")
79 | np.savez(out_path, np.mean(np.stack(terms), axis=0))
80 |
81 | dist.barrier()
82 | logger.log("evaluation complete")
83 |
84 |
85 | def create_argparser():
86 | defaults = dict(
87 | data_dir="", clip_denoised=True, num_samples=1000, batch_size=1, model_path=""
88 | )
89 | defaults.update(model_and_diffusion_defaults())
90 | parser = argparse.ArgumentParser()
91 | add_dict_to_argparser(parser, defaults)
92 | return parser
93 |
94 |
95 | if __name__ == "__main__":
96 | main()
97 |
--------------------------------------------------------------------------------
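For background, "bits per dimension" converts a per-image negative log-likelihood in nats into bits per pixel value by dividing by the number of dimensions and by ln 2; an illustrative conversion helper (not used by the script, which already reports bpd):

import math

def nats_to_bits_per_dim(nll_nats: float, image_shape=(3, 256, 256)) -> float:
    num_dims = math.prod(image_shape)
    return nll_nats / (math.log(2) * num_dims)

print(f"{nats_to_bits_per_dim(500_000.0):.3f} bpd")   # ~3.67 bpd for a 3x256x256 image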
/guided-diffusion/scripts/image_sample.py:
--------------------------------------------------------------------------------
1 | """
2 | Generate a large batch of image samples from a model and save them as a large
3 | numpy array. This can be used to produce samples for FID evaluation.
4 | """
5 |
6 | import argparse
7 | import sys
8 | import os
9 | dir_path = os.path.dirname(os.path.realpath(__file__))
10 | parent_dir_path = os.path.abspath(os.path.join(dir_path, os.pardir))
11 | sys.path.insert(0, parent_dir_path)
12 |
13 | import numpy as np
14 | import torch as th
15 | import torch.distributed as dist
16 | from PIL import Image
17 |
18 | from guided_diffusion import dist_util, logger
19 | from guided_diffusion.script_util import (
20 | NUM_CLASSES,
21 | model_and_diffusion_defaults,
22 | create_model_and_diffusion,
23 | add_dict_to_argparser,
24 | args_to_dict,
25 | )
26 |
27 |
28 |
29 |
30 | def save_images(images, output_path):
31 | for i in range(images.shape[0]):
32 | image_array = np.uint8(images[i])
33 | image = Image.fromarray(image_array)
34 | image.save(os.path.join(output_path, f'image_{i}.png'))
35 |
36 | def main():
37 | args = create_argparser().parse_args()
38 |
39 | dist_util.setup_dist()
40 | logger.configure()
41 |
42 | logger.log("creating model and diffusion...")
43 |
44 | if not os.path.exists(f'watermark_pool/{args.wm_length}_1e4.npy'):
45 | os.makedirs('watermark_pool', exist_ok=True)
46 | np.save(f'watermark_pool/{args.wm_length}_1e4.npy', np.random.randint(0, 2, size=(int(1e4), args.wm_length)))
47 | np.save(f'watermark_pool/{args.wm_length}_1e5.npy', np.random.randint(0, 2, size=(int(1e5), args.wm_length)))
48 | np.save(f'watermark_pool/{args.wm_length}_1e6.npy', np.random.randint(0, 2, size=(int(1e6), args.wm_length)))
49 |
50 | keys = np.load(f'watermark_pool/{args.wm_length}_1e4.npy')[:1000]
51 |
52 | model, diffusion = create_model_and_diffusion(
53 | **args_to_dict(args, model_and_diffusion_defaults().keys()), wm_length=args.wm_length
54 | )
55 | model.load_state_dict(
56 | dist_util.load_state_dict(args.model_path, map_location="cpu")
57 | )
58 | model.to(dist_util.dev())
59 | if args.use_fp16:
60 | model.convert_to_fp16()
61 | model.eval()
62 |
63 | logger.log("sampling...")
64 | all_images = []
65 |
66 | # while len(all_images) * args.batch_size < args.num_samples:
67 | for idx, key in enumerate(keys):
68 | if args.wm_length < 0:
69 | key = None
70 | else:
71 | key = th.from_numpy(key).to(dist_util.dev()).float()
72 | key = key.repeat(args.batch_size, 1)
73 |
74 | model_kwargs = {}
75 | if args.class_cond:
76 | classes = th.randint(
77 | low=0, high=NUM_CLASSES, size=(args.batch_size,), device=dist_util.dev()
78 | )
79 | model_kwargs["y"] = classes
80 | sample_fn = (
81 | diffusion.p_sample_loop if not args.use_ddim else diffusion.ddim_sample_loop
82 | )
83 | sample = sample_fn(
84 | model,
85 | (args.batch_size, 3, args.image_size, args.image_size),
86 | clip_denoised=args.clip_denoised,
87 | model_kwargs=model_kwargs,
88 | key=key,
89 | )
90 | sample = ((sample + 1) * 127.5).clamp(0, 255).to(th.uint8)
91 | sample = sample.permute(0, 2, 3, 1)
92 | sample = sample.contiguous()
93 |
94 | gathered_samples = [th.zeros_like(sample) for _ in range(dist.get_world_size())]
95 | dist.all_gather(gathered_samples, sample) # gather not supported with NCCL
96 | # all_images.extend([sample.cpu().numpy() for sample in gathered_samples])
97 | all_images = [sample.cpu().numpy() for sample in gathered_samples]
98 | if args.class_cond:
99 | gathered_labels = [
100 | th.zeros_like(classes) for _ in range(dist.get_world_size())
101 | ]
102 | dist.all_gather(gathered_labels, classes)
103 |
104 | logger.log(f"created {len(all_images) * args.batch_size} samples")
105 |
106 |
107 | arr = np.concatenate(all_images, axis=0)
108 | arr = arr[: args.num_samples]
109 | os.makedirs(os.path.join(args.output_path, f'{idx}/'), exist_ok=True)
110 | save_images(arr, os.path.join(args.output_path, f'{idx}/'))
111 |
112 | logger.log(f"saving to {os.path.join(args.output_path, f'{idx}/')}")
113 |
114 |
115 | dist.barrier()
116 | logger.log("sampling complete")
117 |
118 |
119 | def create_argparser():
120 | defaults = dict(
121 | clip_denoised=True,
122 | num_samples=10,
123 | batch_size=16,
124 | use_ddim=False,
125 | model_path="",
126 | output_path='saved_images/',
127 | wm_length=48,
128 | )
129 | defaults.update(model_and_diffusion_defaults())
130 | parser = argparse.ArgumentParser()
131 | add_dict_to_argparser(parser, defaults)
132 | return parser
133 |
134 |
135 | if __name__ == "__main__":
136 | main()
137 |
--------------------------------------------------------------------------------
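A hedged sketch of the matching step that would follow sampling: decode bits from a generated image with a trained `StegaStampDecoder` and pick the watermark-pool key with the highest bit accuracy. The decoder checkpoint path is a placeholder and the random image stands in for a real sample:

import numpy as np
import torch
from guided_diffusion.stega_model import StegaStampDecoder

wm_length, image_size = 48, 256
pool = np.load(f"watermark_pool/{wm_length}_1e4.npy")[:1000]   # candidate keys

decoder = StegaStampDecoder(image_size, 3, wm_length)
# decoder.load_state_dict(torch.load("path/to/wm_decoder.pt", map_location="cpu"))
decoder.eval()

image = torch.rand(1, 3, image_size, image_size)               # stand-in for a generated sample
with torch.no_grad():
    bits = (decoder(image) > 0).long().numpy()[0]              # decoded 48-bit message

bit_acc = (pool == bits).mean(axis=1)                          # accuracy against every key
best = int(bit_acc.argmax())
print(f"best-matching key index: {best}, bit accuracy: {bit_acc[best]:.2f}")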
/guided-diffusion/scripts/image_train.py:
--------------------------------------------------------------------------------
1 | """
2 | Train a diffusion model on images.
3 | """
4 |
5 | import argparse
6 | import sys
7 | import os
8 | dir_path = os.path.dirname(os.path.realpath(__file__))
9 | parent_dir_path = os.path.abspath(os.path.join(dir_path, os.pardir))
10 | sys.path.insert(0, parent_dir_path)
11 |
12 | from guided_diffusion import dist_util, logger
13 | from guided_diffusion.image_datasets import load_data
14 | from guided_diffusion.resample import create_named_schedule_sampler
15 | from guided_diffusion.script_util import (
16 | model_and_diffusion_defaults,
17 | create_model_and_diffusion,
18 | args_to_dict,
19 | add_dict_to_argparser,
20 | )
21 | from guided_diffusion.train_util import TrainLoop
22 | from guided_diffusion.stega_model import StegaStampDecoder
23 | import torch as th
24 |
25 | def main():
26 | args = create_argparser().parse_args()
27 |
28 | dist_util.setup_dist()
29 | logger.configure()
30 |
31 | logger.log("creating model and diffusion...")
32 | # Change model architecture
33 |
34 | model, diffusion = create_model_and_diffusion(
35 | **args_to_dict(args, model_and_diffusion_defaults().keys()), wm_length=args.wm_length
36 | )
37 |
38 | # Create the original model architecture
39 | ori_model, _ = create_model_and_diffusion(
40 | **args_to_dict(args, model_and_diffusion_defaults().keys()), wm_length=0
41 | )
42 | for param in ori_model.parameters():
43 | param.requires_grad = False
44 |
45 | ori_model.eval()
46 | ori_model.to(dist_util.dev())
47 |
48 | model.to(dist_util.dev())
49 | schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion)
50 |
51 | # Define watermark decoder
52 | if args.wm_length > 0 and isinstance(args.wm_length, int):
53 | wm_decoder = StegaStampDecoder(
54 | args.image_size,
55 | 3,
56 | args.wm_length,
57 | )
58 | # wm_decoder.load_state_dict(th.load(args.wm_decoder_path, map_location='cpu')).eval()
59 | wm_decoder.to(dist_util.dev())
60 | else:
61 | wm_decoder = None
62 |
63 | logger.log("creating data loader...")
64 | data = load_data(
65 | data_dir=args.data_dir,
66 | batch_size=args.batch_size,
67 | image_size=args.image_size,
68 | class_cond=args.class_cond,
69 | )
70 |
71 | logger.log("training...")
72 | TrainLoop(
73 | model=model,
74 | diffusion=diffusion,
75 | data=data,
76 | batch_size=args.batch_size,
77 | microbatch=args.microbatch,
78 | lr=args.lr,
79 | ema_rate=args.ema_rate,
80 | log_interval=args.log_interval,
81 | save_interval=args.save_interval,
82 | resume_checkpoint=args.resume_checkpoint,
83 | use_fp16=args.use_fp16,
84 | fp16_scale_growth=args.fp16_scale_growth,
85 | schedule_sampler=schedule_sampler,
86 | weight_decay=args.weight_decay,
87 | lr_anneal_steps=args.lr_anneal_steps,
88 | ori_model=ori_model,
89 | wm_length=args.wm_length,
90 | alpha=args.alpha,
91 | threshold=args.threshold,
92 | wm_decoder=wm_decoder
93 | ).run_loop()
94 |
95 |
96 | def create_argparser():
97 | defaults = dict(
98 | data_dir="",
99 | schedule_sampler="uniform",
100 | lr=1e-4,
101 | weight_decay=0.0,
102 | lr_anneal_steps=0,
103 | batch_size=1,
104 | microbatch=-1, # -1 disables microbatches
105 | ema_rate="0.9999", # comma-separated list of EMA values
106 | log_interval=10,
107 | save_interval=10000,
108 | resume_checkpoint="",
109 | use_fp16=False,
110 | fp16_scale_growth=1e-3,
111 | wm_length=48,
112 | alpha=0.4,
113 | threshold=400,
114 | wm_decoder_path='./',
115 | )
116 | defaults.update(model_and_diffusion_defaults())
117 | parser = argparse.ArgumentParser()
118 | add_dict_to_argparser(parser, defaults)
119 | return parser
120 |
121 |
122 | if __name__ == "__main__":
123 | main()
124 |
--------------------------------------------------------------------------------
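All of the defaults above become command-line flags through `add_dict_to_argparser` from `guided_diffusion/script_util.py`; in the upstream helper, types are inferred from the defaults and booleans are parsed from "True"/"False" strings, so an illustrative round trip looks like:

import argparse
from guided_diffusion.script_util import add_dict_to_argparser

parser = argparse.ArgumentParser()
add_dict_to_argparser(parser, dict(wm_length=48, alpha=0.4, use_fp16=False))
args = parser.parse_args(["--wm_length", "64", "--use_fp16", "True"])
assert args.wm_length == 64 and args.alpha == 0.4 and args.use_fp16 is True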
/guided-diffusion/scripts/super_res_sample.py:
--------------------------------------------------------------------------------
1 | """
2 | Generate a large batch of samples from a super resolution model, given a batch
3 | of samples from a regular model from image_sample.py.
4 | """
5 |
6 | import argparse
7 | import os
8 |
9 | import blobfile as bf
10 | import numpy as np
11 | import torch as th
12 | import torch.distributed as dist
13 |
14 | from guided_diffusion import dist_util, logger
15 | from guided_diffusion.script_util import (
16 | sr_model_and_diffusion_defaults,
17 | sr_create_model_and_diffusion,
18 | args_to_dict,
19 | add_dict_to_argparser,
20 | )
21 |
22 |
23 | def main():
24 | args = create_argparser().parse_args()
25 |
26 | dist_util.setup_dist()
27 | logger.configure()
28 |
29 | logger.log("creating model...")
30 | model, diffusion = sr_create_model_and_diffusion(
31 | **args_to_dict(args, sr_model_and_diffusion_defaults().keys())
32 | )
33 | model.load_state_dict(
34 | dist_util.load_state_dict(args.model_path, map_location="cpu")
35 | )
36 | model.to(dist_util.dev())
37 | if args.use_fp16:
38 | model.convert_to_fp16()
39 | model.eval()
40 |
41 | logger.log("loading data...")
42 | data = load_data_for_worker(args.base_samples, args.batch_size, args.class_cond)
43 |
44 | logger.log("creating samples...")
45 | all_images = []
46 | while len(all_images) * args.batch_size < args.num_samples:
47 | model_kwargs = next(data)
48 | model_kwargs = {k: v.to(dist_util.dev()) for k, v in model_kwargs.items()}
49 | sample = diffusion.p_sample_loop(
50 | model,
51 | (args.batch_size, 3, args.large_size, args.large_size),
52 | clip_denoised=args.clip_denoised,
53 | model_kwargs=model_kwargs,
54 | )
55 | sample = ((sample + 1) * 127.5).clamp(0, 255).to(th.uint8)
56 | sample = sample.permute(0, 2, 3, 1)
57 | sample = sample.contiguous()
58 |
59 | all_samples = [th.zeros_like(sample) for _ in range(dist.get_world_size())]
60 | dist.all_gather(all_samples, sample) # gather not supported with NCCL
61 | for sample in all_samples:
62 | all_images.append(sample.cpu().numpy())
63 | logger.log(f"created {len(all_images) * args.batch_size} samples")
64 |
65 | arr = np.concatenate(all_images, axis=0)
66 | arr = arr[: args.num_samples]
67 | if dist.get_rank() == 0:
68 | shape_str = "x".join([str(x) for x in arr.shape])
69 | out_path = os.path.join(logger.get_dir(), f"samples_{shape_str}.npz")
70 | logger.log(f"saving to {out_path}")
71 | np.savez(out_path, arr)
72 |
73 | dist.barrier()
74 | logger.log("sampling complete")
75 |
76 |
77 | def load_data_for_worker(base_samples, batch_size, class_cond):
78 | with bf.BlobFile(base_samples, "rb") as f:
79 | obj = np.load(f)
80 | image_arr = obj["arr_0"]
81 | if class_cond:
82 | label_arr = obj["arr_1"]
83 | rank = dist.get_rank()
84 | num_ranks = dist.get_world_size()
85 | buffer = []
86 | label_buffer = []
87 | while True:
88 | for i in range(rank, len(image_arr), num_ranks):
89 | buffer.append(image_arr[i])
90 | if class_cond:
91 | label_buffer.append(label_arr[i])
92 | if len(buffer) == batch_size:
93 | batch = th.from_numpy(np.stack(buffer)).float()
94 | batch = batch / 127.5 - 1.0
95 | batch = batch.permute(0, 3, 1, 2)
96 | res = dict(low_res=batch)
97 | if class_cond:
98 | res["y"] = th.from_numpy(np.stack(label_buffer))
99 | yield res
100 | buffer, label_buffer = [], []
101 |
102 |
103 | def create_argparser():
104 | defaults = dict(
105 | clip_denoised=True,
106 | num_samples=10000,
107 | batch_size=16,
108 | use_ddim=False,
109 | base_samples="",
110 | model_path="",
111 | )
112 | defaults.update(sr_model_and_diffusion_defaults())
113 | parser = argparse.ArgumentParser()
114 | add_dict_to_argparser(parser, defaults)
115 | return parser
116 |
117 |
118 | if __name__ == "__main__":
119 | main()
120 |
--------------------------------------------------------------------------------
/guided-diffusion/scripts/super_res_train.py:
--------------------------------------------------------------------------------
1 | """
2 | Train a super-resolution model.
3 | """
4 |
5 | import argparse
6 |
7 | import torch.nn.functional as F
8 |
9 | from guided_diffusion import dist_util, logger
10 | from guided_diffusion.image_datasets import load_data
11 | from guided_diffusion.resample import create_named_schedule_sampler
12 | from guided_diffusion.script_util import (
13 | sr_model_and_diffusion_defaults,
14 | sr_create_model_and_diffusion,
15 | args_to_dict,
16 | add_dict_to_argparser,
17 | )
18 | from guided_diffusion.train_util import TrainLoop
19 |
20 |
21 | def main():
22 | args = create_argparser().parse_args()
23 |
24 | dist_util.setup_dist()
25 | logger.configure()
26 |
27 | logger.log("creating model...")
28 | model, diffusion = sr_create_model_and_diffusion(
29 | **args_to_dict(args, sr_model_and_diffusion_defaults().keys())
30 | )
31 | model.to(dist_util.dev())
32 | schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion)
33 |
34 | logger.log("creating data loader...")
35 | data = load_superres_data(
36 | args.data_dir,
37 | args.batch_size,
38 | large_size=args.large_size,
39 | small_size=args.small_size,
40 | class_cond=args.class_cond,
41 | )
42 |
43 | logger.log("training...")
44 | TrainLoop(
45 | model=model,
46 | diffusion=diffusion,
47 | data=data,
48 | batch_size=args.batch_size,
49 | microbatch=args.microbatch,
50 | lr=args.lr,
51 | ema_rate=args.ema_rate,
52 | log_interval=args.log_interval,
53 | save_interval=args.save_interval,
54 | resume_checkpoint=args.resume_checkpoint,
55 | use_fp16=args.use_fp16,
56 | fp16_scale_growth=args.fp16_scale_growth,
57 | schedule_sampler=schedule_sampler,
58 | weight_decay=args.weight_decay,
59 | lr_anneal_steps=args.lr_anneal_steps,
60 | ).run_loop()
61 |
62 |
63 | def load_superres_data(data_dir, batch_size, large_size, small_size, class_cond=False):
64 | data = load_data(
65 | data_dir=data_dir,
66 | batch_size=batch_size,
67 | image_size=large_size,
68 | class_cond=class_cond,
69 | )
70 | for large_batch, model_kwargs in data:
71 | model_kwargs["low_res"] = F.interpolate(large_batch, small_size, mode="area")
72 | yield large_batch, model_kwargs
73 |
74 |
75 | def create_argparser():
76 | defaults = dict(
77 | data_dir="",
78 | schedule_sampler="uniform",
79 | lr=1e-4,
80 | weight_decay=0.0,
81 | lr_anneal_steps=0,
82 | batch_size=1,
83 | microbatch=-1,
84 | ema_rate="0.9999",
85 | log_interval=10,
86 | save_interval=10000,
87 | resume_checkpoint="",
88 | use_fp16=False,
89 | fp16_scale_growth=1e-3,
90 | )
91 | defaults.update(sr_model_and_diffusion_defaults())
92 | parser = argparse.ArgumentParser()
93 | add_dict_to_argparser(parser, defaults)
94 | return parser
95 |
96 |
97 | if __name__ == "__main__":
98 | main()
99 |
--------------------------------------------------------------------------------
/guided-diffusion/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 |
3 | setup(
4 | name="guided-diffusion",
5 | py_modules=["guided_diffusion"],
6 | install_requires=["blobfile>=1.0.5", "torch", "tqdm"],
7 | )
8 |
--------------------------------------------------------------------------------
/guided-diffusion/train.sh:
--------------------------------------------------------------------------------
1 | MODEL_FLAGS="--wm_length 48 --attention_resolutions 32,16,8 --class_cond False --image_size 256 --num_channels 256 --learn_sigma True --num_head_channels 64 --num_res_blocks 2 --resblock_updown True"
2 | DIFFUSION_FLAGS="--diffusion_steps 1000 --noise_schedule linear"
3 | TRAIN_FLAGS="--lr 1e-4 --batch_size 4"
4 | NUM_GPUS=1
5 | mpiexec -n $NUM_GPUS python scripts/image_train.py --alpha 0.4 --threshold 400 --wm_decoder_path ./ --data_dir ../../../../data/imagenet/val --resume_checkpoint models/256x256_diffusion_uncond.pt $MODEL_FLAGS $DIFFUSION_FLAGS $TRAIN_FLAGS
6 |
7 |
--------------------------------------------------------------------------------
/pics/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rmin2000/WaDiff/33e657de73c3e75e82ea66bcec14012a2f4fcb6a/pics/framework.png
--------------------------------------------------------------------------------
/stable-diffusion/configs/autoencoder/autoencoder_kl_16x16x16.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-6
3 | target: ldm.models.autoencoder.AutoencoderKL
4 | params:
5 | monitor: "val/rec_loss"
6 | embed_dim: 16
7 | lossconfig:
8 | target: ldm.modules.losses.LPIPSWithDiscriminator
9 | params:
10 | disc_start: 50001
11 | kl_weight: 0.000001
12 | disc_weight: 0.5
13 |
14 | ddconfig:
15 | double_z: True
16 | z_channels: 16
17 | resolution: 256
18 | in_channels: 3
19 | out_ch: 3
20 | ch: 128
21 | ch_mult: [ 1,1,2,2,4] # num_down = len(ch_mult)-1
22 | num_res_blocks: 2
23 | attn_resolutions: [16]
24 | dropout: 0.0
25 |
26 |
27 | data:
28 | target: main.DataModuleFromConfig
29 | params:
30 | batch_size: 12
31 | wrap: True
32 | train:
33 | target: ldm.data.imagenet.ImageNetSRTrain
34 | params:
35 | size: 256
36 | degradation: pil_nearest
37 | validation:
38 | target: ldm.data.imagenet.ImageNetSRValidation
39 | params:
40 | size: 256
41 | degradation: pil_nearest
42 |
43 | lightning:
44 | callbacks:
45 | image_logger:
46 | target: main.ImageLogger
47 | params:
48 | batch_frequency: 1000
49 | max_images: 8
50 | increase_log_steps: True
51 |
52 | trainer:
53 | benchmark: True
54 | accumulate_grad_batches: 2
55 |
--------------------------------------------------------------------------------
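Configs like the one above are consumed by instantiating the `target` class with its `params`; the Lightning entry point (`main.py`, also referenced by the `data:` and `lightning:` sections) normally does this. A direct instantiation sketch, run from the stable-diffusion directory and assuming the loss dependencies are installed:

from omegaconf import OmegaConf
from ldm.util import instantiate_from_config

config = OmegaConf.load("configs/autoencoder/autoencoder_kl_16x16x16.yaml")
autoencoder = instantiate_from_config(config.model)   # ldm.models.autoencoder.AutoencoderKL
print(type(autoencoder).__name__)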
/stable-diffusion/configs/autoencoder/autoencoder_kl_32x32x4.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-6
3 | target: ldm.models.autoencoder.AutoencoderKL
4 | params:
5 | monitor: "val/rec_loss"
6 | embed_dim: 4
7 | lossconfig:
8 | target: ldm.modules.losses.LPIPSWithDiscriminator
9 | params:
10 | disc_start: 50001
11 | kl_weight: 0.000001
12 | disc_weight: 0.5
13 |
14 | ddconfig:
15 | double_z: True
16 | z_channels: 4
17 | resolution: 256
18 | in_channels: 3
19 | out_ch: 3
20 | ch: 128
21 | ch_mult: [ 1,2,4,4 ] # num_down = len(ch_mult)-1
22 | num_res_blocks: 2
23 | attn_resolutions: [ ]
24 | dropout: 0.0
25 |
26 | data:
27 | target: main.DataModuleFromConfig
28 | params:
29 | batch_size: 12
30 | wrap: True
31 | train:
32 | target: ldm.data.imagenet.ImageNetSRTrain
33 | params:
34 | size: 256
35 | degradation: pil_nearest
36 | validation:
37 | target: ldm.data.imagenet.ImageNetSRValidation
38 | params:
39 | size: 256
40 | degradation: pil_nearest
41 |
42 | lightning:
43 | callbacks:
44 | image_logger:
45 | target: main.ImageLogger
46 | params:
47 | batch_frequency: 1000
48 | max_images: 8
49 | increase_log_steps: True
50 |
51 | trainer:
52 | benchmark: True
53 | accumulate_grad_batches: 2
54 |
--------------------------------------------------------------------------------
/stable-diffusion/configs/autoencoder/autoencoder_kl_64x64x3.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-6
3 | target: ldm.models.autoencoder.AutoencoderKL
4 | params:
5 | monitor: "val/rec_loss"
6 | embed_dim: 3
7 | lossconfig:
8 | target: ldm.modules.losses.LPIPSWithDiscriminator
9 | params:
10 | disc_start: 50001
11 | kl_weight: 0.000001
12 | disc_weight: 0.5
13 |
14 | ddconfig:
15 | double_z: True
16 | z_channels: 3
17 | resolution: 256
18 | in_channels: 3
19 | out_ch: 3
20 | ch: 128
21 | ch_mult: [ 1,2,4 ] # num_down = len(ch_mult)-1
22 | num_res_blocks: 2
23 | attn_resolutions: [ ]
24 | dropout: 0.0
25 |
26 |
27 | data:
28 | target: main.DataModuleFromConfig
29 | params:
30 | batch_size: 12
31 | wrap: True
32 | train:
33 | target: ldm.data.imagenet.ImageNetSRTrain
34 | params:
35 | size: 256
36 | degradation: pil_nearest
37 | validation:
38 | target: ldm.data.imagenet.ImageNetSRValidation
39 | params:
40 | size: 256
41 | degradation: pil_nearest
42 |
43 | lightning:
44 | callbacks:
45 | image_logger:
46 | target: main.ImageLogger
47 | params:
48 | batch_frequency: 1000
49 | max_images: 8
50 | increase_log_steps: True
51 |
52 | trainer:
53 | benchmark: True
54 | accumulate_grad_batches: 2
55 |
--------------------------------------------------------------------------------
/stable-diffusion/configs/autoencoder/autoencoder_kl_8x8x64.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-6
3 | target: ldm.models.autoencoder.AutoencoderKL
4 | params:
5 | monitor: "val/rec_loss"
6 | embed_dim: 64
7 | lossconfig:
8 | target: ldm.modules.losses.LPIPSWithDiscriminator
9 | params:
10 | disc_start: 50001
11 | kl_weight: 0.000001
12 | disc_weight: 0.5
13 |
14 | ddconfig:
15 | double_z: True
16 | z_channels: 64
17 | resolution: 256
18 | in_channels: 3
19 | out_ch: 3
20 | ch: 128
21 | ch_mult: [ 1,1,2,2,4,4] # num_down = len(ch_mult)-1
22 | num_res_blocks: 2
23 | attn_resolutions: [16,8]
24 | dropout: 0.0
25 |
26 | data:
27 | target: main.DataModuleFromConfig
28 | params:
29 | batch_size: 12
30 | wrap: True
31 | train:
32 | target: ldm.data.imagenet.ImageNetSRTrain
33 | params:
34 | size: 256
35 | degradation: pil_nearest
36 | validation:
37 | target: ldm.data.imagenet.ImageNetSRValidation
38 | params:
39 | size: 256
40 | degradation: pil_nearest
41 |
42 | lightning:
43 | callbacks:
44 | image_logger:
45 | target: main.ImageLogger
46 | params:
47 | batch_frequency: 1000
48 | max_images: 8
49 | increase_log_steps: True
50 |
51 | trainer:
52 | benchmark: True
53 | accumulate_grad_batches: 2
54 |
--------------------------------------------------------------------------------
/stable-diffusion/configs/latent-diffusion/celebahq-ldm-vq-4.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 2.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0195
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | image_size: 64
12 | channels: 3
13 | monitor: val/loss_simple_ema
14 |
15 | unet_config:
16 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
17 | params:
18 | image_size: 64
19 | in_channels: 3
20 | out_channels: 3
21 | model_channels: 224
22 | attention_resolutions:
23 |         # note: this isn't actually the resolution but
24 |         # the downsampling factor, i.e. this corresponds to
25 |         # attention on spatial resolution 8,16,32, as the
26 |         # spatial resolution of the latents is 64 for f4
27 | - 8
28 | - 4
29 | - 2
30 | num_res_blocks: 2
31 | channel_mult:
32 | - 1
33 | - 2
34 | - 3
35 | - 4
36 | num_head_channels: 32
37 | first_stage_config:
38 | target: ldm.models.autoencoder.VQModelInterface
39 | params:
40 | embed_dim: 3
41 | n_embed: 8192
42 | ckpt_path: models/first_stage_models/vq-f4/model.ckpt
43 | ddconfig:
44 | double_z: false
45 | z_channels: 3
46 | resolution: 256
47 | in_channels: 3
48 | out_ch: 3
49 | ch: 128
50 | ch_mult:
51 | - 1
52 | - 2
53 | - 4
54 | num_res_blocks: 2
55 | attn_resolutions: []
56 | dropout: 0.0
57 | lossconfig:
58 | target: torch.nn.Identity
59 | cond_stage_config: __is_unconditional__
60 | data:
61 | target: main.DataModuleFromConfig
62 | params:
63 | batch_size: 48
64 | num_workers: 5
65 | wrap: false
66 | train:
67 | target: taming.data.faceshq.CelebAHQTrain
68 | params:
69 | size: 256
70 | validation:
71 | target: taming.data.faceshq.CelebAHQValidation
72 | params:
73 | size: 256
74 |
75 |
76 | lightning:
77 | callbacks:
78 | image_logger:
79 | target: main.ImageLogger
80 | params:
81 | batch_frequency: 5000
82 | max_images: 8
83 | increase_log_steps: False
84 |
85 | trainer:
86 | benchmark: True
--------------------------------------------------------------------------------
/stable-diffusion/configs/latent-diffusion/cin-ldm-vq-f8.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0195
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | cond_stage_key: class_label
12 | image_size: 32
13 | channels: 4
14 | cond_stage_trainable: true
15 | conditioning_key: crossattn
16 | monitor: val/loss_simple_ema
17 | unet_config:
18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
19 | params:
20 | image_size: 32
21 | in_channels: 4
22 | out_channels: 4
23 | model_channels: 256
24 | attention_resolutions:
25 |         # note: this isn't actually the resolution but
26 |         # the downsampling factor, i.e. this corresponds to
27 |         # attention on spatial resolution 8,16,32, as the
28 |         # spatial resolution of the latents is 32 for f8
29 | - 4
30 | - 2
31 | - 1
32 | num_res_blocks: 2
33 | channel_mult:
34 | - 1
35 | - 2
36 | - 4
37 | num_head_channels: 32
38 | use_spatial_transformer: true
39 | transformer_depth: 1
40 | context_dim: 512
41 | first_stage_config:
42 | target: ldm.models.autoencoder.VQModelInterface
43 | params:
44 | embed_dim: 4
45 | n_embed: 16384
46 | ckpt_path: configs/first_stage_models/vq-f8/model.yaml
47 | ddconfig:
48 | double_z: false
49 | z_channels: 4
50 | resolution: 256
51 | in_channels: 3
52 | out_ch: 3
53 | ch: 128
54 | ch_mult:
55 | - 1
56 | - 2
57 | - 2
58 | - 4
59 | num_res_blocks: 2
60 | attn_resolutions:
61 | - 32
62 | dropout: 0.0
63 | lossconfig:
64 | target: torch.nn.Identity
65 | cond_stage_config:
66 | target: ldm.modules.encoders.modules.ClassEmbedder
67 | params:
68 | embed_dim: 512
69 | key: class_label
70 | data:
71 | target: main.DataModuleFromConfig
72 | params:
73 | batch_size: 64
74 | num_workers: 12
75 | wrap: false
76 | train:
77 | target: ldm.data.imagenet.ImageNetTrain
78 | params:
79 | config:
80 | size: 256
81 | validation:
82 | target: ldm.data.imagenet.ImageNetValidation
83 | params:
84 | config:
85 | size: 256
86 |
87 |
88 | lightning:
89 | callbacks:
90 | image_logger:
91 | target: main.ImageLogger
92 | params:
93 | batch_frequency: 5000
94 | max_images: 8
95 | increase_log_steps: False
96 |
97 | trainer:
98 | benchmark: True
--------------------------------------------------------------------------------
/stable-diffusion/configs/latent-diffusion/cin256-v2.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 0.0001
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0195
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | cond_stage_key: class_label
12 | image_size: 64
13 | channels: 3
14 | cond_stage_trainable: true
15 | conditioning_key: crossattn
16 | monitor: val/loss
17 | use_ema: False
18 |
19 | unet_config:
20 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
21 | params:
22 | image_size: 64
23 | in_channels: 3
24 | out_channels: 3
25 | model_channels: 192
26 | attention_resolutions:
27 | - 8
28 | - 4
29 | - 2
30 | num_res_blocks: 2
31 | channel_mult:
32 | - 1
33 | - 2
34 | - 3
35 | - 5
36 | num_heads: 1
37 | use_spatial_transformer: true
38 | transformer_depth: 1
39 | context_dim: 512
40 |
41 | first_stage_config:
42 | target: ldm.models.autoencoder.VQModelInterface
43 | params:
44 | embed_dim: 3
45 | n_embed: 8192
46 | ddconfig:
47 | double_z: false
48 | z_channels: 3
49 | resolution: 256
50 | in_channels: 3
51 | out_ch: 3
52 | ch: 128
53 | ch_mult:
54 | - 1
55 | - 2
56 | - 4
57 | num_res_blocks: 2
58 | attn_resolutions: []
59 | dropout: 0.0
60 | lossconfig:
61 | target: torch.nn.Identity
62 |
63 | cond_stage_config:
64 | target: ldm.modules.encoders.modules.ClassEmbedder
65 | params:
66 | n_classes: 1001
67 | embed_dim: 512
68 | key: class_label
69 |
--------------------------------------------------------------------------------
/stable-diffusion/configs/latent-diffusion/ffhq-ldm-vq-4.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 2.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0195
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | image_size: 64
12 | channels: 3
13 | monitor: val/loss_simple_ema
14 | unet_config:
15 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
16 | params:
17 | image_size: 64
18 | in_channels: 3
19 | out_channels: 3
20 | model_channels: 224
21 | attention_resolutions:
22 |         # note: this isn't actually the resolution but
23 |         # the downsampling factor, i.e. this corresponds to
24 |         # attention on spatial resolution 8,16,32, as the
25 |         # spatial resolution of the latents is 64 for f4
26 | - 8
27 | - 4
28 | - 2
29 | num_res_blocks: 2
30 | channel_mult:
31 | - 1
32 | - 2
33 | - 3
34 | - 4
35 | num_head_channels: 32
36 | first_stage_config:
37 | target: ldm.models.autoencoder.VQModelInterface
38 | params:
39 | embed_dim: 3
40 | n_embed: 8192
41 | ckpt_path: configs/first_stage_models/vq-f4/model.yaml
42 | ddconfig:
43 | double_z: false
44 | z_channels: 3
45 | resolution: 256
46 | in_channels: 3
47 | out_ch: 3
48 | ch: 128
49 | ch_mult:
50 | - 1
51 | - 2
52 | - 4
53 | num_res_blocks: 2
54 | attn_resolutions: []
55 | dropout: 0.0
56 | lossconfig:
57 | target: torch.nn.Identity
58 | cond_stage_config: __is_unconditional__
59 | data:
60 | target: main.DataModuleFromConfig
61 | params:
62 | batch_size: 42
63 | num_workers: 5
64 | wrap: false
65 | train:
66 | target: taming.data.faceshq.FFHQTrain
67 | params:
68 | size: 256
69 | validation:
70 | target: taming.data.faceshq.FFHQValidation
71 | params:
72 | size: 256
73 |
74 |
75 | lightning:
76 | callbacks:
77 | image_logger:
78 | target: main.ImageLogger
79 | params:
80 | batch_frequency: 5000
81 | max_images: 8
82 | increase_log_steps: False
83 |
84 | trainer:
85 | benchmark: True
--------------------------------------------------------------------------------
/stable-diffusion/configs/latent-diffusion/lsun_bedrooms-ldm-vq-4.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 2.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0195
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | image_size: 64
12 | channels: 3
13 | monitor: val/loss_simple_ema
14 | unet_config:
15 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
16 | params:
17 | image_size: 64
18 | in_channels: 3
19 | out_channels: 3
20 | model_channels: 224
21 | attention_resolutions:
22 |         # note: this isn't actually the resolution but
23 |         # the downsampling factor, i.e. this corresponds to
24 |         # attention on spatial resolution 8,16,32, as the
25 |         # spatial resolution of the latents is 64 for f4
26 | - 8
27 | - 4
28 | - 2
29 | num_res_blocks: 2
30 | channel_mult:
31 | - 1
32 | - 2
33 | - 3
34 | - 4
35 | num_head_channels: 32
36 | first_stage_config:
37 | target: ldm.models.autoencoder.VQModelInterface
38 | params:
39 | ckpt_path: configs/first_stage_models/vq-f4/model.yaml
40 | embed_dim: 3
41 | n_embed: 8192
42 | ddconfig:
43 | double_z: false
44 | z_channels: 3
45 | resolution: 256
46 | in_channels: 3
47 | out_ch: 3
48 | ch: 128
49 | ch_mult:
50 | - 1
51 | - 2
52 | - 4
53 | num_res_blocks: 2
54 | attn_resolutions: []
55 | dropout: 0.0
56 | lossconfig:
57 | target: torch.nn.Identity
58 | cond_stage_config: __is_unconditional__
59 | data:
60 | target: main.DataModuleFromConfig
61 | params:
62 | batch_size: 48
63 | num_workers: 5
64 | wrap: false
65 | train:
66 | target: ldm.data.lsun.LSUNBedroomsTrain
67 | params:
68 | size: 256
69 | validation:
70 | target: ldm.data.lsun.LSUNBedroomsValidation
71 | params:
72 | size: 256
73 |
74 |
75 | lightning:
76 | callbacks:
77 | image_logger:
78 | target: main.ImageLogger
79 | params:
80 | batch_frequency: 5000
81 | max_images: 8
82 | increase_log_steps: False
83 |
84 | trainer:
85 | benchmark: True
--------------------------------------------------------------------------------
/stable-diffusion/configs/latent-diffusion/lsun_churches-ldm-kl-8.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 5.0e-5 # set to target_lr by starting main.py with '--scale_lr False'
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0155
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | loss_type: l1
11 | first_stage_key: "image"
12 | cond_stage_key: "image"
13 | image_size: 32
14 | channels: 4
15 | cond_stage_trainable: False
16 | concat_mode: False
17 | scale_by_std: True
18 | monitor: 'val/loss_simple_ema'
19 |
20 | scheduler_config: # 10000 warmup steps
21 | target: ldm.lr_scheduler.LambdaLinearScheduler
22 | params:
23 | warm_up_steps: [10000]
24 | cycle_lengths: [10000000000000]
25 | f_start: [1.e-6]
26 | f_max: [1.]
27 | f_min: [ 1.]
28 |
29 | unet_config:
30 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
31 | params:
32 | image_size: 32
33 | in_channels: 4
34 | out_channels: 4
35 | model_channels: 192
36 | attention_resolutions: [ 1, 2, 4, 8 ] # 32, 16, 8, 4
37 | num_res_blocks: 2
38 | channel_mult: [ 1,2,2,4,4 ] # 32, 16, 8, 4, 2
39 | num_heads: 8
40 | use_scale_shift_norm: True
41 | resblock_updown: True
42 |
43 | first_stage_config:
44 | target: ldm.models.autoencoder.AutoencoderKL
45 | params:
46 | embed_dim: 4
47 | monitor: "val/rec_loss"
48 | ckpt_path: "models/first_stage_models/kl-f8/model.ckpt"
49 | ddconfig:
50 | double_z: True
51 | z_channels: 4
52 | resolution: 256
53 | in_channels: 3
54 | out_ch: 3
55 | ch: 128
56 | ch_mult: [ 1,2,4,4 ] # num_down = len(ch_mult)-1
57 | num_res_blocks: 2
58 | attn_resolutions: [ ]
59 | dropout: 0.0
60 | lossconfig:
61 | target: torch.nn.Identity
62 |
63 | cond_stage_config: "__is_unconditional__"
64 |
65 | data:
66 | target: main.DataModuleFromConfig
67 | params:
68 | batch_size: 96
69 | num_workers: 5
70 | wrap: False
71 | train:
72 | target: ldm.data.lsun.LSUNChurchesTrain
73 | params:
74 | size: 256
75 | validation:
76 | target: ldm.data.lsun.LSUNChurchesValidation
77 | params:
78 | size: 256
79 |
80 | lightning:
81 | callbacks:
82 | image_logger:
83 | target: main.ImageLogger
84 | params:
85 | batch_frequency: 5000
86 | max_images: 8
87 | increase_log_steps: False
88 |
89 |
90 | trainer:
91 | benchmark: True
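
The comment on base_learning_rate refers to the learning-rate scaling that main.py applies unless '--scale_lr False' is passed. A small sketch of that rule as it is commonly implemented in the CompVis training entry point; the exact formula and the GPU count are assumptions here, not shown in this dump:

    base_lr = 5.0e-5                 # base_learning_rate from this config
    batch_size = 96                  # data.params.batch_size from this config
    n_gpus = 8                       # assumption: number of GPUs used for training
    accumulate_grad_batches = 1      # assumption: no gradient accumulation

    # with --scale_lr left at its default, the effective lr is scaled by the global batch size;
    # with --scale_lr False, training runs at base_lr directly, i.e. the "target_lr" in the comment
    scaled_lr = accumulate_grad_batches * n_gpus * batch_size * base_lr
    print(scaled_lr)                 # 0.0384 for these assumed values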
--------------------------------------------------------------------------------
/stable-diffusion/configs/latent-diffusion/txt2img-1p4B-eval.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 5.0e-05
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.00085
6 | linear_end: 0.012
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | cond_stage_key: caption
12 | image_size: 32
13 | channels: 4
14 | cond_stage_trainable: true
15 | conditioning_key: crossattn
16 | monitor: val/loss_simple_ema
17 | scale_factor: 0.18215
18 | use_ema: False
19 |
20 | unet_config:
21 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
22 | params:
23 | image_size: 32
24 | in_channels: 4
25 | out_channels: 4
26 | model_channels: 320
27 | attention_resolutions:
28 | - 4
29 | - 2
30 | - 1
31 | num_res_blocks: 2
32 | channel_mult:
33 | - 1
34 | - 2
35 | - 4
36 | - 4
37 | num_heads: 8
38 | use_spatial_transformer: true
39 | transformer_depth: 1
40 | context_dim: 1280
41 | use_checkpoint: true
42 | legacy: False
43 |
44 | first_stage_config:
45 | target: ldm.models.autoencoder.AutoencoderKL
46 | params:
47 | embed_dim: 4
48 | monitor: val/rec_loss
49 | ddconfig:
50 | double_z: true
51 | z_channels: 4
52 | resolution: 256
53 | in_channels: 3
54 | out_ch: 3
55 | ch: 128
56 | ch_mult:
57 | - 1
58 | - 2
59 | - 4
60 | - 4
61 | num_res_blocks: 2
62 | attn_resolutions: []
63 | dropout: 0.0
64 | lossconfig:
65 | target: torch.nn.Identity
66 |
67 | cond_stage_config:
68 | target: ldm.modules.encoders.modules.BERTEmbedder
69 | params:
70 | n_embed: 1280
71 | n_layer: 32
72 |
--------------------------------------------------------------------------------
/stable-diffusion/configs/retrieval-augmented-diffusion/768x768.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 0.0001
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.015
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: jpg
11 | cond_stage_key: nix
12 | image_size: 48
13 | channels: 16
14 | cond_stage_trainable: false
15 | conditioning_key: crossattn
16 | monitor: val/loss_simple_ema
17 | scale_by_std: false
18 | scale_factor: 0.22765929
19 | unet_config:
20 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
21 | params:
22 | image_size: 48
23 | in_channels: 16
24 | out_channels: 16
25 | model_channels: 448
26 | attention_resolutions:
27 | - 4
28 | - 2
29 | - 1
30 | num_res_blocks: 2
31 | channel_mult:
32 | - 1
33 | - 2
34 | - 3
35 | - 4
36 | use_scale_shift_norm: false
37 | resblock_updown: false
38 | num_head_channels: 32
39 | use_spatial_transformer: true
40 | transformer_depth: 1
41 | context_dim: 768
42 | use_checkpoint: true
43 | first_stage_config:
44 | target: ldm.models.autoencoder.AutoencoderKL
45 | params:
46 | monitor: val/rec_loss
47 | embed_dim: 16
48 | ddconfig:
49 | double_z: true
50 | z_channels: 16
51 | resolution: 256
52 | in_channels: 3
53 | out_ch: 3
54 | ch: 128
55 | ch_mult:
56 | - 1
57 | - 1
58 | - 2
59 | - 2
60 | - 4
61 | num_res_blocks: 2
62 | attn_resolutions:
63 | - 16
64 | dropout: 0.0
65 | lossconfig:
66 | target: torch.nn.Identity
67 | cond_stage_config:
68 | target: torch.nn.Identity
--------------------------------------------------------------------------------
/stable-diffusion/configs/stable-diffusion/v1-inference.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-04
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.00085
6 | linear_end: 0.0120
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: "jpg"
11 | cond_stage_key: "txt"
12 | image_size: 64
13 | channels: 4
14 | cond_stage_trainable: false # Note: different from the one we trained before
15 | conditioning_key: crossattn
16 | monitor: val/loss_simple_ema
17 | scale_factor: 0.18215
18 | use_ema: False
19 |
20 | scheduler_config: # 10000 warmup steps
21 | target: ldm.lr_scheduler.LambdaLinearScheduler
22 | params:
23 | warm_up_steps: [ 10000 ]
24 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
25 | f_start: [ 1.e-6 ]
26 | f_max: [ 1. ]
27 | f_min: [ 1. ]
28 |
29 | unet_config:
30 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
31 | params:
32 | image_size: 32 # unused
33 | in_channels: 4
34 | out_channels: 4
35 | model_channels: 320
36 | attention_resolutions: [ 4, 2, 1 ]
37 | num_res_blocks: 2
38 | channel_mult: [ 1, 2, 4, 4 ]
39 | num_heads: 8
40 | use_spatial_transformer: True
41 | transformer_depth: 1
42 | context_dim: 768
43 | use_checkpoint: True
44 | legacy: False
45 |
46 | first_stage_config:
47 | target: ldm.models.autoencoder.AutoencoderKL
48 | params:
49 | embed_dim: 4
50 | monitor: val/rec_loss
51 | ddconfig:
52 | double_z: true
53 | z_channels: 4
54 | resolution: 256
55 | in_channels: 3
56 | out_ch: 3
57 | ch: 128
58 | ch_mult:
59 | - 1
60 | - 2
61 | - 4
62 | - 4
63 | num_res_blocks: 2
64 | attn_resolutions: []
65 | dropout: 0.0
66 | lossconfig:
67 | target: torch.nn.Identity
68 |
69 | cond_stage_config:
70 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
71 |
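
For reference, a minimal sketch of how a config like this one is typically consumed: parse it with omegaconf (pinned in environment.yaml) and build the model object with instantiate_from_config from ldm/util.py shown further below. Loading checkpoint weights is a separate step and is omitted here.

    from omegaconf import OmegaConf
    from ldm.util import instantiate_from_config

    # parse the YAML into a nested config object
    config = OmegaConf.load("configs/stable-diffusion/v1-inference.yaml")

    # builds ldm.models.diffusion.ddpm.LatentDiffusion with the params above,
    # including the nested unet_config / first_stage_config / cond_stage_config
    model = instantiate_from_config(config.model)
    print(type(model).__name__)      # LatentDiffusion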
--------------------------------------------------------------------------------
/stable-diffusion/environment.yaml:
--------------------------------------------------------------------------------
1 | name: ldm
2 | channels:
3 | - pytorch
4 | - defaults
5 | dependencies:
6 | - python=3.8.5
7 | - pip=20.3
8 | - cudatoolkit=11.3
9 | - pytorch=1.11.0
10 | - torchvision=0.12.0
11 | - numpy=1.19.2
12 | - pip:
13 | - albumentations==0.4.3
14 | - diffusers
15 | - opencv-python==4.1.2.30
16 | - pudb==2019.2
17 | - invisible-watermark
18 | - imageio==2.9.0
19 | - imageio-ffmpeg==0.4.2
20 | - pytorch-lightning==1.4.2
21 | - omegaconf==2.1.1
22 | - test-tube>=0.7.5
23 | - streamlit>=0.73.1
24 | - einops==0.3.0
25 | - torch-fidelity==0.3.0
26 | - transformers==4.19.2
27 | - torchmetrics==0.6.0
28 | - kornia==0.6
29 | - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
30 | - -e git+https://github.com/openai/CLIP.git@main#egg=clip
31 | - -e .
32 |
--------------------------------------------------------------------------------
/stable-diffusion/ldm/__pycache__/util.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rmin2000/WaDiff/33e657de73c3e75e82ea66bcec14012a2f4fcb6a/stable-diffusion/ldm/__pycache__/util.cpython-310.pyc
--------------------------------------------------------------------------------
/stable-diffusion/ldm/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rmin2000/WaDiff/33e657de73c3e75e82ea66bcec14012a2f4fcb6a/stable-diffusion/ldm/data/__init__.py
--------------------------------------------------------------------------------
/stable-diffusion/ldm/data/base.py:
--------------------------------------------------------------------------------
1 | from abc import abstractmethod
2 | from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset
3 |
4 |
5 | class Txt2ImgIterableBaseDataset(IterableDataset):
6 | '''
7 | Define an interface to make the IterableDatasets for text2img data chainable
8 | '''
9 | def __init__(self, num_records=0, valid_ids=None, size=256):
10 | super().__init__()
11 | self.num_records = num_records
12 | self.valid_ids = valid_ids
13 | self.sample_ids = valid_ids
14 | self.size = size
15 |
16 | print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.')
17 |
18 | def __len__(self):
19 | return self.num_records
20 |
21 | @abstractmethod
22 | def __iter__(self):
23 | pass
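
A minimal sketch of a concrete subclass, to illustrate the interface this base class defines; the dataset below is hypothetical and only yields synthetic records with assumed "image"/"caption" keys.

    import numpy as np
    from ldm.data.base import Txt2ImgIterableBaseDataset

    class DummyTxt2ImgDataset(Txt2ImgIterableBaseDataset):
        """Hypothetical subclass: yields random images with a fixed caption."""

        def __iter__(self):
            for i in range(self.num_records):
                yield {
                    "image": np.random.rand(self.size, self.size, 3).astype(np.float32),
                    "caption": f"dummy caption {i}",
                }

    ds = DummyTxt2ImgDataset(num_records=4, size=64)    # __init__ prints the dataset length
    sample = next(iter(ds))
    print(sample["image"].shape, sample["caption"])     # (64, 64, 3) dummy caption 0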
--------------------------------------------------------------------------------
/stable-diffusion/ldm/data/lsun.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import PIL
4 | from PIL import Image
5 | from torch.utils.data import Dataset
6 | from torchvision import transforms
7 |
8 |
9 | class LSUNBase(Dataset):
10 | def __init__(self,
11 | txt_file,
12 | data_root,
13 | size=None,
14 | interpolation="bicubic",
15 | flip_p=0.5
16 | ):
17 | self.data_paths = txt_file
18 | self.data_root = data_root
19 | with open(self.data_paths, "r") as f:
20 | self.image_paths = f.read().splitlines()
21 | self._length = len(self.image_paths)
22 | self.labels = {
23 | "relative_file_path_": [l for l in self.image_paths],
24 | "file_path_": [os.path.join(self.data_root, l)
25 | for l in self.image_paths],
26 | }
27 |
28 | self.size = size
29 | self.interpolation = {"linear": PIL.Image.LINEAR,
30 | "bilinear": PIL.Image.BILINEAR,
31 | "bicubic": PIL.Image.BICUBIC,
32 | "lanczos": PIL.Image.LANCZOS,
33 | }[interpolation]
34 | self.flip = transforms.RandomHorizontalFlip(p=flip_p)
35 |
36 | def __len__(self):
37 | return self._length
38 |
39 | def __getitem__(self, i):
40 | example = dict((k, self.labels[k][i]) for k in self.labels)
41 | image = Image.open(example["file_path_"])
42 | if not image.mode == "RGB":
43 | image = image.convert("RGB")
44 |
45 | # default to score-sde preprocessing
46 | img = np.array(image).astype(np.uint8)
47 | crop = min(img.shape[0], img.shape[1])
48 | h, w, = img.shape[0], img.shape[1]
49 | img = img[(h - crop) // 2:(h + crop) // 2,
50 | (w - crop) // 2:(w + crop) // 2]
51 |
52 | image = Image.fromarray(img)
53 | if self.size is not None:
54 | image = image.resize((self.size, self.size), resample=self.interpolation)
55 |
56 | image = self.flip(image)
57 | image = np.array(image).astype(np.uint8)
58 | example["image"] = (image / 127.5 - 1.0).astype(np.float32)
59 | return example
60 |
61 |
62 | class LSUNChurchesTrain(LSUNBase):
63 | def __init__(self, **kwargs):
64 | super().__init__(txt_file="data/lsun/church_outdoor_train.txt", data_root="data/lsun/churches", **kwargs)
65 |
66 |
67 | class LSUNChurchesValidation(LSUNBase):
68 | def __init__(self, flip_p=0., **kwargs):
69 | super().__init__(txt_file="data/lsun/church_outdoor_val.txt", data_root="data/lsun/churches",
70 | flip_p=flip_p, **kwargs)
71 |
72 |
73 | class LSUNBedroomsTrain(LSUNBase):
74 | def __init__(self, **kwargs):
75 | super().__init__(txt_file="data/lsun/bedrooms_train.txt", data_root="data/lsun/bedrooms", **kwargs)
76 |
77 |
78 | class LSUNBedroomsValidation(LSUNBase):
79 | def __init__(self, flip_p=0.0, **kwargs):
80 | super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms",
81 | flip_p=flip_p, **kwargs)
82 |
83 |
84 | class LSUNCatsTrain(LSUNBase):
85 | def __init__(self, **kwargs):
86 | super().__init__(txt_file="data/lsun/cat_train.txt", data_root="data/lsun/cats", **kwargs)
87 |
88 |
89 | class LSUNCatsValidation(LSUNBase):
90 | def __init__(self, flip_p=0., **kwargs):
91 | super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats",
92 | flip_p=flip_p, **kwargs)
93 |
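
A short usage sketch, assuming the LSUN file lists and image folders have been prepared under data/lsun/ as the hard-coded paths above expect:

    from ldm.data.lsun import LSUNChurchesTrain

    # size=256 matches data.params.train.params.size in lsun_churches-ldm-kl-8.yaml
    dataset = LSUNChurchesTrain(size=256)
    example = dataset[0]

    # each example carries the relative/absolute file paths plus the image itself,
    # center-cropped, resized to 256x256, randomly flipped and scaled to [-1, 1]
    print(example["image"].shape, example["image"].dtype)   # (256, 256, 3) float32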
--------------------------------------------------------------------------------
/stable-diffusion/ldm/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | class LambdaWarmUpCosineScheduler:
5 | """
6 | note: use with a base_lr of 1.0
7 | """
8 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
9 | self.lr_warm_up_steps = warm_up_steps
10 | self.lr_start = lr_start
11 | self.lr_min = lr_min
12 | self.lr_max = lr_max
13 | self.lr_max_decay_steps = max_decay_steps
14 | self.last_lr = 0.
15 | self.verbosity_interval = verbosity_interval
16 |
17 | def schedule(self, n, **kwargs):
18 | if self.verbosity_interval > 0:
19 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
20 | if n < self.lr_warm_up_steps:
21 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
22 | self.last_lr = lr
23 | return lr
24 | else:
25 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
26 | t = min(t, 1.0)
27 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
28 | 1 + np.cos(t * np.pi))
29 | self.last_lr = lr
30 | return lr
31 |
32 | def __call__(self, n, **kwargs):
33 | return self.schedule(n,**kwargs)
34 |
35 |
36 | class LambdaWarmUpCosineScheduler2:
37 | """
38 | supports repeated iterations, configurable via lists
39 | note: use with a base_lr of 1.0.
40 | """
41 | def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):
42 | assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
43 | self.lr_warm_up_steps = warm_up_steps
44 | self.f_start = f_start
45 | self.f_min = f_min
46 | self.f_max = f_max
47 | self.cycle_lengths = cycle_lengths
48 | self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
49 | self.last_f = 0.
50 | self.verbosity_interval = verbosity_interval
51 |
52 | def find_in_interval(self, n):
53 | interval = 0
54 | for cl in self.cum_cycles[1:]:
55 | if n <= cl:
56 | return interval
57 | interval += 1
58 |
59 | def schedule(self, n, **kwargs):
60 | cycle = self.find_in_interval(n)
61 | n = n - self.cum_cycles[cycle]
62 | if self.verbosity_interval > 0:
63 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
64 | f"current cycle {cycle}")
65 | if n < self.lr_warm_up_steps[cycle]:
66 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
67 | self.last_f = f
68 | return f
69 | else:
70 | t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])
71 | t = min(t, 1.0)
72 | f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
73 | 1 + np.cos(t * np.pi))
74 | self.last_f = f
75 | return f
76 |
77 | def __call__(self, n, **kwargs):
78 | return self.schedule(n, **kwargs)
79 |
80 |
81 | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
82 |
83 | def schedule(self, n, **kwargs):
84 | cycle = self.find_in_interval(n)
85 | n = n - self.cum_cycles[cycle]
86 | if self.verbosity_interval > 0:
87 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
88 | f"current cycle {cycle}")
89 |
90 | if n < self.lr_warm_up_steps[cycle]:
91 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
92 | self.last_f = f
93 | return f
94 | else:
95 | f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle])
96 | self.last_f = f
97 | return f
98 |
99 |
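
A small sketch wiring LambdaLinearScheduler up with the values used in lsun_churches-ldm-kl-8.yaml, to show what the multiplier looks like during and after the 10k warm-up steps:

    from ldm.lr_scheduler import LambdaLinearScheduler

    sched = LambdaLinearScheduler(
        warm_up_steps=[10000],
        f_min=[1.0],
        f_max=[1.0],
        f_start=[1e-6],
        cycle_lengths=[10000000000000],
    )

    # linear ramp from ~1e-6 to 1.0 over the first 10k steps, then a flat 1.0 here
    # because f_min == f_max; the multiplier is applied on top of base_learning_rate
    for step in [0, 5000, 10000, 100000]:
        print(step, sched(step))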
--------------------------------------------------------------------------------
/stable-diffusion/ldm/models/__pycache__/autoencoder.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rmin2000/WaDiff/33e657de73c3e75e82ea66bcec14012a2f4fcb6a/stable-diffusion/ldm/models/__pycache__/autoencoder.cpython-310.pyc
--------------------------------------------------------------------------------
/stable-diffusion/ldm/models/diffusion/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rmin2000/WaDiff/33e657de73c3e75e82ea66bcec14012a2f4fcb6a/stable-diffusion/ldm/models/diffusion/__init__.py
--------------------------------------------------------------------------------
/stable-diffusion/ldm/models/diffusion/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rmin2000/WaDiff/33e657de73c3e75e82ea66bcec14012a2f4fcb6a/stable-diffusion/ldm/models/diffusion/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/stable-diffusion/ldm/models/diffusion/__pycache__/ddim.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rmin2000/WaDiff/33e657de73c3e75e82ea66bcec14012a2f4fcb6a/stable-diffusion/ldm/models/diffusion/__pycache__/ddim.cpython-310.pyc
--------------------------------------------------------------------------------
/stable-diffusion/ldm/models/diffusion/__pycache__/ddpm.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rmin2000/WaDiff/33e657de73c3e75e82ea66bcec14012a2f4fcb6a/stable-diffusion/ldm/models/diffusion/__pycache__/ddpm.cpython-310.pyc
--------------------------------------------------------------------------------
/stable-diffusion/ldm/models/diffusion/__pycache__/plms.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rmin2000/WaDiff/33e657de73c3e75e82ea66bcec14012a2f4fcb6a/stable-diffusion/ldm/models/diffusion/__pycache__/plms.cpython-310.pyc
--------------------------------------------------------------------------------
/stable-diffusion/ldm/models/diffusion/dpm_solver/__init__.py:
--------------------------------------------------------------------------------
1 | from .sampler import DPMSolverSampler
--------------------------------------------------------------------------------
/stable-diffusion/ldm/models/diffusion/dpm_solver/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rmin2000/WaDiff/33e657de73c3e75e82ea66bcec14012a2f4fcb6a/stable-diffusion/ldm/models/diffusion/dpm_solver/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/stable-diffusion/ldm/models/diffusion/dpm_solver/__pycache__/dpm_solver.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rmin2000/WaDiff/33e657de73c3e75e82ea66bcec14012a2f4fcb6a/stable-diffusion/ldm/models/diffusion/dpm_solver/__pycache__/dpm_solver.cpython-310.pyc
--------------------------------------------------------------------------------
/stable-diffusion/ldm/models/diffusion/dpm_solver/__pycache__/sampler.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rmin2000/WaDiff/33e657de73c3e75e82ea66bcec14012a2f4fcb6a/stable-diffusion/ldm/models/diffusion/dpm_solver/__pycache__/sampler.cpython-310.pyc
--------------------------------------------------------------------------------
/stable-diffusion/ldm/models/diffusion/dpm_solver/sampler.py:
--------------------------------------------------------------------------------
1 | """SAMPLING ONLY."""
2 |
3 | import torch
4 |
5 | from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver
6 |
7 |
8 | class DPMSolverSampler(object):
9 | def __init__(self, model, **kwargs):
10 | super().__init__()
11 | self.model = model
12 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device)
13 | self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod))
14 |
15 | def register_buffer(self, name, attr):
16 | if type(attr) == torch.Tensor:
17 | if attr.device != torch.device("cuda"):
18 | attr = attr.to(torch.device("cuda"))
19 | setattr(self, name, attr)
20 |
21 | @torch.no_grad()
22 | def sample(self,
23 | S,
24 | batch_size,
25 | shape,
26 | conditioning=None,
27 | callback=None,
28 | normals_sequence=None,
29 | img_callback=None,
30 | quantize_x0=False,
31 | eta=0.,
32 | mask=None,
33 | x0=None,
34 | temperature=1.,
35 | noise_dropout=0.,
36 | score_corrector=None,
37 | corrector_kwargs=None,
38 | verbose=True,
39 | x_T=None,
40 | log_every_t=100,
41 | unconditional_guidance_scale=1.,
42 | unconditional_conditioning=None,
43 | # this has to come in the same format as the conditioning, e.g. as encoded tokens, ...
44 | **kwargs
45 | ):
46 | if conditioning is not None:
47 | if isinstance(conditioning, dict):
48 | cbs = conditioning[list(conditioning.keys())[0]].shape[0]
49 | if cbs != batch_size:
50 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
51 | else:
52 | if conditioning.shape[0] != batch_size:
53 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
54 |
55 | # sampling
56 | C, H, W = shape
57 | size = (batch_size, C, H, W)
58 |
59 | # print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}')
60 |
61 | device = self.model.betas.device
62 | if x_T is None:
63 | img = torch.randn(size, device=device)
64 | else:
65 | img = x_T
66 |
67 | ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod)
68 |
69 | model_fn = model_wrapper(
70 | lambda x, t, c: self.model.apply_model(x, t, c),
71 | ns,
72 | model_type="noise",
73 | guidance_type="classifier-free",
74 | condition=conditioning,
75 | unconditional_condition=unconditional_conditioning,
76 | guidance_scale=unconditional_guidance_scale,
77 | )
78 |
79 | dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False)
80 | x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True)
81 |
82 | return x.to(device), None
83 |
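
A usage sketch of the sampler, written as a helper function since it assumes `model` is an already-loaded LatentDiffusion instance on GPU (e.g. built from one of the configs above); the get_learned_conditioning / decode_first_stage calls rely on the usual helpers in ldm/models/diffusion/ddpm.py, which are not shown in this dump.

    import torch
    from ldm.models.diffusion.dpm_solver import DPMSolverSampler

    def sample_with_dpm_solver(model, prompt, steps=25, batch_size=2, guidance_scale=7.5):
        """`model` is assumed to be a LatentDiffusion instance with weights loaded, on CUDA."""
        sampler = DPMSolverSampler(model)
        with torch.no_grad():
            cond = model.get_learned_conditioning(batch_size * [prompt])
            uncond = model.get_learned_conditioning(batch_size * [""])
            # shape is the latent (C, H, W), e.g. 4 x 64 x 64 for the kl-f8 models at 512px
            samples, _ = sampler.sample(
                S=steps,
                batch_size=batch_size,
                shape=(4, 64, 64),
                conditioning=cond,
                unconditional_conditioning=uncond,
                unconditional_guidance_scale=guidance_scale,
            )
            return model.decode_first_stage(samples)   # back to image space, roughly [-1, 1]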
--------------------------------------------------------------------------------
/stable-diffusion/ldm/modules/__pycache__/attention.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rmin2000/WaDiff/33e657de73c3e75e82ea66bcec14012a2f4fcb6a/stable-diffusion/ldm/modules/__pycache__/attention.cpython-310.pyc
--------------------------------------------------------------------------------
/stable-diffusion/ldm/modules/__pycache__/ema.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rmin2000/WaDiff/33e657de73c3e75e82ea66bcec14012a2f4fcb6a/stable-diffusion/ldm/modules/__pycache__/ema.cpython-310.pyc
--------------------------------------------------------------------------------
/stable-diffusion/ldm/modules/__pycache__/x_transformer.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rmin2000/WaDiff/33e657de73c3e75e82ea66bcec14012a2f4fcb6a/stable-diffusion/ldm/modules/__pycache__/x_transformer.cpython-310.pyc
--------------------------------------------------------------------------------
/stable-diffusion/ldm/modules/diffusionmodules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rmin2000/WaDiff/33e657de73c3e75e82ea66bcec14012a2f4fcb6a/stable-diffusion/ldm/modules/diffusionmodules/__init__.py
--------------------------------------------------------------------------------
/stable-diffusion/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rmin2000/WaDiff/33e657de73c3e75e82ea66bcec14012a2f4fcb6a/stable-diffusion/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/stable-diffusion/ldm/modules/diffusionmodules/__pycache__/model.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rmin2000/WaDiff/33e657de73c3e75e82ea66bcec14012a2f4fcb6a/stable-diffusion/ldm/modules/diffusionmodules/__pycache__/model.cpython-310.pyc
--------------------------------------------------------------------------------
/stable-diffusion/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rmin2000/WaDiff/33e657de73c3e75e82ea66bcec14012a2f4fcb6a/stable-diffusion/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-310.pyc
--------------------------------------------------------------------------------
/stable-diffusion/ldm/modules/diffusionmodules/__pycache__/util.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rmin2000/WaDiff/33e657de73c3e75e82ea66bcec14012a2f4fcb6a/stable-diffusion/ldm/modules/diffusionmodules/__pycache__/util.cpython-310.pyc
--------------------------------------------------------------------------------
/stable-diffusion/ldm/modules/distributions/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rmin2000/WaDiff/33e657de73c3e75e82ea66bcec14012a2f4fcb6a/stable-diffusion/ldm/modules/distributions/__init__.py
--------------------------------------------------------------------------------
/stable-diffusion/ldm/modules/distributions/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rmin2000/WaDiff/33e657de73c3e75e82ea66bcec14012a2f4fcb6a/stable-diffusion/ldm/modules/distributions/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/stable-diffusion/ldm/modules/distributions/__pycache__/distributions.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rmin2000/WaDiff/33e657de73c3e75e82ea66bcec14012a2f4fcb6a/stable-diffusion/ldm/modules/distributions/__pycache__/distributions.cpython-310.pyc
--------------------------------------------------------------------------------
/stable-diffusion/ldm/modules/distributions/distributions.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | class AbstractDistribution:
6 | def sample(self):
7 | raise NotImplementedError()
8 |
9 | def mode(self):
10 | raise NotImplementedError()
11 |
12 |
13 | class DiracDistribution(AbstractDistribution):
14 | def __init__(self, value):
15 | self.value = value
16 |
17 | def sample(self):
18 | return self.value
19 |
20 | def mode(self):
21 | return self.value
22 |
23 |
24 | class DiagonalGaussianDistribution(object):
25 | def __init__(self, parameters, deterministic=False):
26 | self.parameters = parameters
27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
29 | self.deterministic = deterministic
30 | self.std = torch.exp(0.5 * self.logvar)
31 | self.var = torch.exp(self.logvar)
32 | if self.deterministic:
33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
34 |
35 | def sample(self):
36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
37 | return x
38 |
39 | def kl(self, other=None):
40 | if self.deterministic:
41 | return torch.Tensor([0.])
42 | else:
43 | if other is None:
44 | return 0.5 * torch.sum(torch.pow(self.mean, 2)
45 | + self.var - 1.0 - self.logvar,
46 | dim=[1, 2, 3])
47 | else:
48 | return 0.5 * torch.sum(
49 | torch.pow(self.mean - other.mean, 2) / other.var
50 | + self.var / other.var - 1.0 - self.logvar + other.logvar,
51 | dim=[1, 2, 3])
52 |
53 | def nll(self, sample, dims=[1,2,3]):
54 | if self.deterministic:
55 | return torch.Tensor([0.])
56 | logtwopi = np.log(2.0 * np.pi)
57 | return 0.5 * torch.sum(
58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
59 | dim=dims)
60 |
61 | def mode(self):
62 | return self.mean
63 |
64 |
65 | def normal_kl(mean1, logvar1, mean2, logvar2):
66 | """
67 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
68 | Compute the KL divergence between two gaussians.
69 | Shapes are automatically broadcasted, so batches can be compared to
70 | scalars, among other use cases.
71 | """
72 | tensor = None
73 | for obj in (mean1, logvar1, mean2, logvar2):
74 | if isinstance(obj, torch.Tensor):
75 | tensor = obj
76 | break
77 | assert tensor is not None, "at least one argument must be a Tensor"
78 |
79 | # Force variances to be Tensors. Broadcasting helps convert scalars to
80 | # Tensors, but it does not work for torch.exp().
81 | logvar1, logvar2 = [
82 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
83 | for x in (logvar1, logvar2)
84 | ]
85 |
86 | return 0.5 * (
87 | -1.0
88 | + logvar2
89 | - logvar1
90 | + torch.exp(logvar1 - logvar2)
91 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
92 | )
93 |
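
A small self-contained sketch of DiagonalGaussianDistribution, which is how the KL-regularized autoencoders represent their posteriors: the input tensor packs mean and log-variance along the channel dimension.

    import torch
    from ldm.modules.distributions.distributions import DiagonalGaussianDistribution

    # 2 images, 4 latent channels, 8x8 spatial: channels 0-3 are the mean, 4-7 the logvar
    params = torch.randn(2, 8, 8, 8)
    posterior = DiagonalGaussianDistribution(params)

    z = posterior.sample()          # reparameterized sample, shape (2, 4, 8, 8)
    kl = posterior.kl()             # KL against a standard normal, one value per batch element
    print(z.shape, kl.shape)        # torch.Size([2, 4, 8, 8]) torch.Size([2])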
--------------------------------------------------------------------------------
/stable-diffusion/ldm/modules/ema.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | class LitEma(nn.Module):
6 | def __init__(self, model, decay=0.9999, use_num_upates=True):
7 | super().__init__()
8 | if decay < 0.0 or decay > 1.0:
9 | raise ValueError('Decay must be between 0 and 1')
10 |
11 | self.m_name2s_name = {}
12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
13 | self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates
14 | else torch.tensor(-1,dtype=torch.int))
15 |
16 | for name, p in model.named_parameters():
17 | if p.requires_grad:
18 | #remove as '.'-character is not allowed in buffers
19 | s_name = name.replace('.','')
20 | self.m_name2s_name.update({name:s_name})
21 | self.register_buffer(s_name,p.clone().detach().data)
22 |
23 | self.collected_params = []
24 |
25 | def forward(self,model):
26 | decay = self.decay
27 |
28 | if self.num_updates >= 0:
29 | self.num_updates += 1
30 | decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates))
31 |
32 | one_minus_decay = 1.0 - decay
33 |
34 | with torch.no_grad():
35 | m_param = dict(model.named_parameters())
36 | shadow_params = dict(self.named_buffers())
37 |
38 | for key in m_param:
39 | if m_param[key].requires_grad:
40 | sname = self.m_name2s_name[key]
41 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
42 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
43 | else:
44 | assert not key in self.m_name2s_name
45 |
46 | def copy_to(self, model):
47 | m_param = dict(model.named_parameters())
48 | shadow_params = dict(self.named_buffers())
49 | for key in m_param:
50 | if m_param[key].requires_grad:
51 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
52 | else:
53 | assert not key in self.m_name2s_name
54 |
55 | def store(self, parameters):
56 | """
57 | Save the current parameters for restoring later.
58 | Args:
59 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
60 | temporarily stored.
61 | """
62 | self.collected_params = [param.clone() for param in parameters]
63 |
64 | def restore(self, parameters):
65 | """
66 | Restore the parameters stored with the `store` method.
67 | Useful to validate the model with EMA parameters without affecting the
68 | original optimization process. Store the parameters before the
69 | `copy_to` method. After validation (or model saving), use this to
70 | restore the former parameters.
71 | Args:
72 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
73 | updated with the stored parameters.
74 | """
75 | for c_param, param in zip(self.collected_params, parameters):
76 | param.data.copy_(c_param.data)
77 |
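
A compact sketch of the intended LitEma life-cycle, following the store/copy_to/restore docstrings above; the tiny linear model and the fake training step are placeholders.

    import torch
    from torch import nn
    from ldm.modules.ema import LitEma

    model = nn.Linear(4, 4)                     # placeholder model
    ema = LitEma(model, decay=0.9999)

    # --- training: update the shadow weights after each optimizer step ---
    model.weight.data.add_(0.01)                # stand-in for an optimizer step
    ema(model)                                  # EMA update of the registered buffers

    # --- evaluation: swap in the EMA weights, then restore the live ones ---
    ema.store(model.parameters())               # stash current (non-EMA) parameters
    ema.copy_to(model)                          # evaluate / save checkpoints with EMA weights here
    ema.restore(model.parameters())             # back to the live training weights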
--------------------------------------------------------------------------------
/stable-diffusion/ldm/modules/encoders/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rmin2000/WaDiff/33e657de73c3e75e82ea66bcec14012a2f4fcb6a/stable-diffusion/ldm/modules/encoders/__init__.py
--------------------------------------------------------------------------------
/stable-diffusion/ldm/modules/encoders/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rmin2000/WaDiff/33e657de73c3e75e82ea66bcec14012a2f4fcb6a/stable-diffusion/ldm/modules/encoders/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/stable-diffusion/ldm/modules/encoders/__pycache__/modules.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rmin2000/WaDiff/33e657de73c3e75e82ea66bcec14012a2f4fcb6a/stable-diffusion/ldm/modules/encoders/__pycache__/modules.cpython-310.pyc
--------------------------------------------------------------------------------
/stable-diffusion/ldm/modules/image_degradation/__init__.py:
--------------------------------------------------------------------------------
1 | from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr
2 | from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light
3 |
--------------------------------------------------------------------------------
/stable-diffusion/ldm/modules/image_degradation/utils/test.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rmin2000/WaDiff/33e657de73c3e75e82ea66bcec14012a2f4fcb6a/stable-diffusion/ldm/modules/image_degradation/utils/test.png
--------------------------------------------------------------------------------
/stable-diffusion/ldm/modules/losses/__init__.py:
--------------------------------------------------------------------------------
1 | from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator
--------------------------------------------------------------------------------
/stable-diffusion/ldm/modules/losses/contperceptual.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no?
5 |
6 |
7 | class LPIPSWithDiscriminator(nn.Module):
8 | def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0,
9 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
10 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
11 | disc_loss="hinge"):
12 |
13 | super().__init__()
14 | assert disc_loss in ["hinge", "vanilla"]
15 | self.kl_weight = kl_weight
16 | self.pixel_weight = pixelloss_weight
17 | self.perceptual_loss = LPIPS().eval()
18 | self.perceptual_weight = perceptual_weight
19 | # output log variance
20 | self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
21 |
22 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
23 | n_layers=disc_num_layers,
24 | use_actnorm=use_actnorm
25 | ).apply(weights_init)
26 | self.discriminator_iter_start = disc_start
27 | self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
28 | self.disc_factor = disc_factor
29 | self.discriminator_weight = disc_weight
30 | self.disc_conditional = disc_conditional
31 |
32 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
33 | if last_layer is not None:
34 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
35 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
36 | else:
37 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
38 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
39 |
40 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
41 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
42 | d_weight = d_weight * self.discriminator_weight
43 | return d_weight
44 |
45 | def forward(self, inputs, reconstructions, posteriors, optimizer_idx,
46 | global_step, last_layer=None, cond=None, split="train",
47 | weights=None):
48 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
49 | if self.perceptual_weight > 0:
50 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
51 | rec_loss = rec_loss + self.perceptual_weight * p_loss
52 |
53 | nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
54 | weighted_nll_loss = nll_loss
55 | if weights is not None:
56 | weighted_nll_loss = weights*nll_loss
57 | weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
58 | nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
59 | kl_loss = posteriors.kl()
60 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
61 |
62 | # now the GAN part
63 | if optimizer_idx == 0:
64 | # generator update
65 | if cond is None:
66 | assert not self.disc_conditional
67 | logits_fake = self.discriminator(reconstructions.contiguous())
68 | else:
69 | assert self.disc_conditional
70 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
71 | g_loss = -torch.mean(logits_fake)
72 |
73 | if self.disc_factor > 0.0:
74 | try:
75 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
76 | except RuntimeError:
77 | assert not self.training
78 | d_weight = torch.tensor(0.0)
79 | else:
80 | d_weight = torch.tensor(0.0)
81 |
82 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
83 | loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss
84 |
85 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(),
86 | "{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(),
87 | "{}/rec_loss".format(split): rec_loss.detach().mean(),
88 | "{}/d_weight".format(split): d_weight.detach(),
89 | "{}/disc_factor".format(split): torch.tensor(disc_factor),
90 | "{}/g_loss".format(split): g_loss.detach().mean(),
91 | }
92 | return loss, log
93 |
94 | if optimizer_idx == 1:
95 | # second pass for discriminator update
96 | if cond is None:
97 | logits_real = self.discriminator(inputs.contiguous().detach())
98 | logits_fake = self.discriminator(reconstructions.contiguous().detach())
99 | else:
100 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
101 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
102 |
103 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
104 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
105 |
106 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
107 | "{}/logits_real".format(split): logits_real.detach().mean(),
108 | "{}/logits_fake".format(split): logits_fake.detach().mean()
109 | }
110 | return d_loss, log
111 |
112 |
--------------------------------------------------------------------------------
/stable-diffusion/ldm/util.py:
--------------------------------------------------------------------------------
1 | import importlib
2 |
3 | import torch
4 | import numpy as np
5 | from collections import abc
6 | from einops import rearrange
7 | from functools import partial
8 |
9 | import multiprocessing as mp
10 | from threading import Thread
11 | from queue import Queue
12 |
13 | from inspect import isfunction
14 | from PIL import Image, ImageDraw, ImageFont
15 |
16 |
17 | def log_txt_as_img(wh, xc, size=10):
18 | # wh a tuple of (width, height)
19 | # xc a list of captions to plot
20 | b = len(xc)
21 | txts = list()
22 | for bi in range(b):
23 | txt = Image.new("RGB", wh, color="white")
24 | draw = ImageDraw.Draw(txt)
25 | font = ImageFont.truetype('data/DejaVuSans.ttf', size=size)
26 | nc = int(40 * (wh[0] / 256))
27 | lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))
28 |
29 | try:
30 | draw.text((0, 0), lines, fill="black", font=font)
31 | except UnicodeEncodeError:
32 | print("Cant encode string for logging. Skipping.")
33 |
34 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
35 | txts.append(txt)
36 | txts = np.stack(txts)
37 | txts = torch.tensor(txts)
38 | return txts
39 |
40 |
41 | def ismap(x):
42 | if not isinstance(x, torch.Tensor):
43 | return False
44 | return (len(x.shape) == 4) and (x.shape[1] > 3)
45 |
46 |
47 | def isimage(x):
48 | if not isinstance(x, torch.Tensor):
49 | return False
50 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
51 |
52 |
53 | def exists(x):
54 | return x is not None
55 |
56 |
57 | def default(val, d):
58 | if exists(val):
59 | return val
60 | return d() if isfunction(d) else d
61 |
62 |
63 | def mean_flat(tensor):
64 | """
65 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
66 | Take the mean over all non-batch dimensions.
67 | """
68 | return tensor.mean(dim=list(range(1, len(tensor.shape))))
69 |
70 |
71 | def count_params(model, verbose=False):
72 | total_params = sum(p.numel() for p in model.parameters())
73 | if verbose:
74 | print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
75 | return total_params
76 |
77 |
78 | def instantiate_from_config(config):
79 | if not "target" in config:
80 | if config == '__is_first_stage__':
81 | return None
82 | elif config == "__is_unconditional__":
83 | return None
84 | raise KeyError("Expected key `target` to instantiate.")
85 | return get_obj_from_str(config["target"])(**config.get("params", dict()))
86 |
87 |
88 | def get_obj_from_str(string, reload=False):
89 | module, cls = string.rsplit(".", 1)
90 | if reload:
91 | module_imp = importlib.import_module(module)
92 | importlib.reload(module_imp)
93 | return getattr(importlib.import_module(module, package=None), cls)
94 |
95 |
96 | def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False):
97 | # create dummy dataset instance
98 |
99 | # run prefetching
100 | if idx_to_fn:
101 | res = func(data, worker_id=idx)
102 | else:
103 | res = func(data)
104 | Q.put([idx, res])
105 | Q.put("Done")
106 |
107 |
108 | def parallel_data_prefetch(
109 | func: callable, data, n_proc, target_data_type="ndarray", cpu_intensive=True, use_worker_id=False
110 | ):
111 | # if target_data_type not in ["ndarray", "list"]:
112 | # raise ValueError(
113 | # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray."
114 | # )
115 | if isinstance(data, np.ndarray) and target_data_type == "list":
116 | raise ValueError("list expected but function got ndarray.")
117 | elif isinstance(data, abc.Iterable):
118 | if isinstance(data, dict):
119 | print(
120 | f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'
121 | )
122 | data = list(data.values())
123 | if target_data_type == "ndarray":
124 | data = np.asarray(data)
125 | else:
126 | data = list(data)
127 | else:
128 | raise TypeError(
129 | f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}."
130 | )
131 |
132 | if cpu_intensive:
133 | Q = mp.Queue(1000)
134 | proc = mp.Process
135 | else:
136 | Q = Queue(1000)
137 | proc = Thread
138 | # spawn processes
139 | if target_data_type == "ndarray":
140 | arguments = [
141 | [func, Q, part, i, use_worker_id]
142 | for i, part in enumerate(np.array_split(data, n_proc))
143 | ]
144 | else:
145 | step = (
146 | int(len(data) / n_proc + 1)
147 | if len(data) % n_proc != 0
148 | else int(len(data) / n_proc)
149 | )
150 | arguments = [
151 | [func, Q, part, i, use_worker_id]
152 | for i, part in enumerate(
153 | [data[i: i + step] for i in range(0, len(data), step)]
154 | )
155 | ]
156 | processes = []
157 | for i in range(n_proc):
158 | p = proc(target=_do_parallel_data_prefetch, args=arguments[i])
159 | processes += [p]
160 |
161 | # start processes
162 | print(f"Start prefetching...")
163 | import time
164 |
165 | start = time.time()
166 | gather_res = [[] for _ in range(n_proc)]
167 | try:
168 | for p in processes:
169 | p.start()
170 |
171 | k = 0
172 | while k < n_proc:
173 | # get result
174 | res = Q.get()
175 | if res == "Done":
176 | k += 1
177 | else:
178 | gather_res[res[0]] = res[1]
179 |
180 | except Exception as e:
181 | print("Exception: ", e)
182 | for p in processes:
183 | p.terminate()
184 |
185 | raise e
186 | finally:
187 | for p in processes:
188 | p.join()
189 | print(f"Prefetching complete. [{time.time() - start} sec.]")
190 |
191 | if target_data_type == 'ndarray':
192 | if not isinstance(gather_res[0], np.ndarray):
193 | return np.concatenate([np.asarray(r) for r in gather_res], axis=0)
194 |
195 | # order outputs
196 | return np.concatenate(gather_res, axis=0)
197 | elif target_data_type == 'list':
198 | out = []
199 | for r in gather_res:
200 | out.extend(r)
201 | return out
202 | else:
203 | return gather_res
204 |
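
A minimal example of parallel_data_prefetch defined above, applying a trivial function over a list split across worker processes; the squaring function is just a placeholder for real per-chunk work.

    from ldm.util import parallel_data_prefetch

    def square_all(chunk):
        # placeholder work done by each worker on its slice of the data
        return [x * x for x in chunk]

    if __name__ == "__main__":
        result = parallel_data_prefetch(square_all, list(range(10)), n_proc=2, target_data_type="list")
        print(result)   # [0, 1, 4, 9, 16, 25, 36, 49, 64, 81]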
--------------------------------------------------------------------------------
/stable-diffusion/models/first_stage_models/kl-f16/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-06
3 | target: ldm.models.autoencoder.AutoencoderKL
4 | params:
5 | monitor: val/rec_loss
6 | embed_dim: 16
7 | lossconfig:
8 | target: ldm.modules.losses.LPIPSWithDiscriminator
9 | params:
10 | disc_start: 50001
11 | kl_weight: 1.0e-06
12 | disc_weight: 0.5
13 | ddconfig:
14 | double_z: true
15 | z_channels: 16
16 | resolution: 256
17 | in_channels: 3
18 | out_ch: 3
19 | ch: 128
20 | ch_mult:
21 | - 1
22 | - 1
23 | - 2
24 | - 2
25 | - 4
26 | num_res_blocks: 2
27 | attn_resolutions:
28 | - 16
29 | dropout: 0.0
30 | data:
31 | target: main.DataModuleFromConfig
32 | params:
33 | batch_size: 6
34 | wrap: true
35 | train:
36 | target: ldm.data.openimages.FullOpenImagesTrain
37 | params:
38 | size: 384
39 | crop_size: 256
40 | validation:
41 | target: ldm.data.openimages.FullOpenImagesValidation
42 | params:
43 | size: 384
44 | crop_size: 256
45 |
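
As the `# num_down = len(ch_mult)-1` comment in lsun_churches-ldm-kl-8.yaml notes, the downsampling factor in the first-stage model names (kl-f16, kl-f32, ...) follows directly from ch_mult. A quick sketch using this kl-f16 config's values:

    ch_mult = [1, 1, 2, 2, 4]          # ch_mult from the kl-f16 config above
    num_down = len(ch_mult) - 1        # number of downsampling stages
    f = 2 ** num_down                  # overall spatial downsampling factor
    print(num_down, f)                 # 4 16  -> a 256px input becomes a 16x16x16 latent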
--------------------------------------------------------------------------------
/stable-diffusion/models/first_stage_models/kl-f32/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-06
3 | target: ldm.models.autoencoder.AutoencoderKL
4 | params:
5 | monitor: val/rec_loss
6 | embed_dim: 64
7 | lossconfig:
8 | target: ldm.modules.losses.LPIPSWithDiscriminator
9 | params:
10 | disc_start: 50001
11 | kl_weight: 1.0e-06
12 | disc_weight: 0.5
13 | ddconfig:
14 | double_z: true
15 | z_channels: 64
16 | resolution: 256
17 | in_channels: 3
18 | out_ch: 3
19 | ch: 128
20 | ch_mult:
21 | - 1
22 | - 1
23 | - 2
24 | - 2
25 | - 4
26 | - 4
27 | num_res_blocks: 2
28 | attn_resolutions:
29 | - 16
30 | - 8
31 | dropout: 0.0
32 | data:
33 | target: main.DataModuleFromConfig
34 | params:
35 | batch_size: 6
36 | wrap: true
37 | train:
38 | target: ldm.data.openimages.FullOpenImagesTrain
39 | params:
40 | size: 384
41 | crop_size: 256
42 | validation:
43 | target: ldm.data.openimages.FullOpenImagesValidation
44 | params:
45 | size: 384
46 | crop_size: 256
47 |
--------------------------------------------------------------------------------
/stable-diffusion/models/first_stage_models/kl-f4/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-06
3 | target: ldm.models.autoencoder.AutoencoderKL
4 | params:
5 | monitor: val/rec_loss
6 | embed_dim: 3
7 | lossconfig:
8 | target: ldm.modules.losses.LPIPSWithDiscriminator
9 | params:
10 | disc_start: 50001
11 | kl_weight: 1.0e-06
12 | disc_weight: 0.5
13 | ddconfig:
14 | double_z: true
15 | z_channels: 3
16 | resolution: 256
17 | in_channels: 3
18 | out_ch: 3
19 | ch: 128
20 | ch_mult:
21 | - 1
22 | - 2
23 | - 4
24 | num_res_blocks: 2
25 | attn_resolutions: []
26 | dropout: 0.0
27 | data:
28 | target: main.DataModuleFromConfig
29 | params:
30 | batch_size: 10
31 | wrap: true
32 | train:
33 | target: ldm.data.openimages.FullOpenImagesTrain
34 | params:
35 | size: 384
36 | crop_size: 256
37 | validation:
38 | target: ldm.data.openimages.FullOpenImagesValidation
39 | params:
40 | size: 384
41 | crop_size: 256
42 |
--------------------------------------------------------------------------------
/stable-diffusion/models/first_stage_models/kl-f8/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-06
3 | target: ldm.models.autoencoder.AutoencoderKL
4 | params:
5 | monitor: val/rec_loss
6 | embed_dim: 4
7 | lossconfig:
8 | target: ldm.modules.losses.LPIPSWithDiscriminator
9 | params:
10 | disc_start: 50001
11 | kl_weight: 1.0e-06
12 | disc_weight: 0.5
13 | ddconfig:
14 | double_z: true
15 | z_channels: 4
16 | resolution: 256
17 | in_channels: 3
18 | out_ch: 3
19 | ch: 128
20 | ch_mult:
21 | - 1
22 | - 2
23 | - 4
24 | - 4
25 | num_res_blocks: 2
26 | attn_resolutions: []
27 | dropout: 0.0
28 | data:
29 | target: main.DataModuleFromConfig
30 | params:
31 | batch_size: 4
32 | wrap: true
33 | train:
34 | target: ldm.data.openimages.FullOpenImagesTrain
35 | params:
36 | size: 384
37 | crop_size: 256
38 | validation:
39 | target: ldm.data.openimages.FullOpenImagesValidation
40 | params:
41 | size: 384
42 | crop_size: 256
43 |
--------------------------------------------------------------------------------
/stable-diffusion/models/first_stage_models/vq-f16/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-06
3 | target: ldm.models.autoencoder.VQModel
4 | params:
5 | embed_dim: 8
6 | n_embed: 16384
7 | ddconfig:
8 | double_z: false
9 | z_channels: 8
10 | resolution: 256
11 | in_channels: 3
12 | out_ch: 3
13 | ch: 128
14 | ch_mult:
15 | - 1
16 | - 1
17 | - 2
18 | - 2
19 | - 4
20 | num_res_blocks: 2
21 | attn_resolutions:
22 | - 16
23 | dropout: 0.0
24 | lossconfig:
25 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
26 | params:
27 | disc_conditional: false
28 | disc_in_channels: 3
29 | disc_start: 250001
30 | disc_weight: 0.75
31 | disc_num_layers: 2
32 | codebook_weight: 1.0
33 |
34 | data:
35 | target: main.DataModuleFromConfig
36 | params:
37 | batch_size: 14
38 | num_workers: 20
39 | wrap: true
40 | train:
41 | target: ldm.data.openimages.FullOpenImagesTrain
42 | params:
43 | size: 384
44 | crop_size: 256
45 | validation:
46 | target: ldm.data.openimages.FullOpenImagesValidation
47 | params:
48 | size: 384
49 | crop_size: 256
50 |
--------------------------------------------------------------------------------
/stable-diffusion/models/first_stage_models/vq-f4-noattn/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-06
3 | target: ldm.models.autoencoder.VQModel
4 | params:
5 | embed_dim: 3
6 | n_embed: 8192
7 | monitor: val/rec_loss
8 |
9 | ddconfig:
10 | attn_type: none
11 | double_z: false
12 | z_channels: 3
13 | resolution: 256
14 | in_channels: 3
15 | out_ch: 3
16 | ch: 128
17 | ch_mult:
18 | - 1
19 | - 2
20 | - 4
21 | num_res_blocks: 2
22 | attn_resolutions: []
23 | dropout: 0.0
24 | lossconfig:
25 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
26 | params:
27 | disc_conditional: false
28 | disc_in_channels: 3
29 | disc_start: 11
30 | disc_weight: 0.75
31 | codebook_weight: 1.0
32 |
33 | data:
34 | target: main.DataModuleFromConfig
35 | params:
36 | batch_size: 8
37 | num_workers: 12
38 | wrap: true
39 | train:
40 | target: ldm.data.openimages.FullOpenImagesTrain
41 | params:
42 | crop_size: 256
43 | validation:
44 | target: ldm.data.openimages.FullOpenImagesValidation
45 | params:
46 | crop_size: 256
47 |
--------------------------------------------------------------------------------
/stable-diffusion/models/first_stage_models/vq-f4/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-06
3 | target: ldm.models.autoencoder.VQModel
4 | params:
5 | embed_dim: 3
6 | n_embed: 8192
7 | monitor: val/rec_loss
8 |
9 | ddconfig:
10 | double_z: false
11 | z_channels: 3
12 | resolution: 256
13 | in_channels: 3
14 | out_ch: 3
15 | ch: 128
16 | ch_mult:
17 | - 1
18 | - 2
19 | - 4
20 | num_res_blocks: 2
21 | attn_resolutions: []
22 | dropout: 0.0
23 | lossconfig:
24 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
25 | params:
26 | disc_conditional: false
27 | disc_in_channels: 3
28 | disc_start: 0
29 | disc_weight: 0.75
30 | codebook_weight: 1.0
31 |
32 | data:
33 | target: main.DataModuleFromConfig
34 | params:
35 | batch_size: 8
36 | num_workers: 16
37 | wrap: true
38 | train:
39 | target: ldm.data.openimages.FullOpenImagesTrain
40 | params:
41 | crop_size: 256
42 | validation:
43 | target: ldm.data.openimages.FullOpenImagesValidation
44 | params:
45 | crop_size: 256
46 |
--------------------------------------------------------------------------------
/stable-diffusion/models/first_stage_models/vq-f8-n256/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-06
3 | target: ldm.models.autoencoder.VQModel
4 | params:
5 | embed_dim: 4
6 | n_embed: 256
7 | monitor: val/rec_loss
8 | ddconfig:
9 | double_z: false
10 | z_channels: 4
11 | resolution: 256
12 | in_channels: 3
13 | out_ch: 3
14 | ch: 128
15 | ch_mult:
16 | - 1
17 | - 2
18 | - 2
19 | - 4
20 | num_res_blocks: 2
21 | attn_resolutions:
22 | - 32
23 | dropout: 0.0
24 | lossconfig:
25 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
26 | params:
27 | disc_conditional: false
28 | disc_in_channels: 3
29 | disc_start: 250001
30 | disc_weight: 0.75
31 | codebook_weight: 1.0
32 |
33 | data:
34 | target: main.DataModuleFromConfig
35 | params:
36 | batch_size: 10
37 | num_workers: 20
38 | wrap: true
39 | train:
40 | target: ldm.data.openimages.FullOpenImagesTrain
41 | params:
42 | size: 384
43 | crop_size: 256
44 | validation:
45 | target: ldm.data.openimages.FullOpenImagesValidation
46 | params:
47 | size: 384
48 | crop_size: 256
49 |
--------------------------------------------------------------------------------
/stable-diffusion/models/first_stage_models/vq-f8/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-06
3 | target: ldm.models.autoencoder.VQModel
4 | params:
5 | embed_dim: 4
6 | n_embed: 16384
7 | monitor: val/rec_loss
8 | ddconfig:
9 | double_z: false
10 | z_channels: 4
11 | resolution: 256
12 | in_channels: 3
13 | out_ch: 3
14 | ch: 128
15 | ch_mult:
16 | - 1
17 | - 2
18 | - 2
19 | - 4
20 | num_res_blocks: 2
21 | attn_resolutions:
22 | - 32
23 | dropout: 0.0
24 | lossconfig:
25 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
26 | params:
27 | disc_conditional: false
28 | disc_in_channels: 3
29 | disc_num_layers: 2
30 | disc_start: 1
31 | disc_weight: 0.6
32 | codebook_weight: 1.0
33 | data:
34 | target: main.DataModuleFromConfig
35 | params:
36 | batch_size: 10
37 | num_workers: 20
38 | wrap: true
39 | train:
40 | target: ldm.data.openimages.FullOpenImagesTrain
41 | params:
42 | size: 384
43 | crop_size: 256
44 | validation:
45 | target: ldm.data.openimages.FullOpenImagesValidation
46 | params:
47 | size: 384
48 | crop_size: 256
49 |
--------------------------------------------------------------------------------
/stable-diffusion/models/ldm/bsr_sr/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0155
7 | log_every_t: 100
8 | timesteps: 1000
9 | loss_type: l2
10 | first_stage_key: image
11 | cond_stage_key: LR_image
12 | image_size: 64
13 | channels: 3
14 | concat_mode: true
15 | cond_stage_trainable: false
16 | unet_config:
17 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
18 | params:
19 | image_size: 64
20 | in_channels: 6
21 | out_channels: 3
22 | model_channels: 160
23 | attention_resolutions:
24 | - 16
25 | - 8
26 | num_res_blocks: 2
27 | channel_mult:
28 | - 1
29 | - 2
30 | - 2
31 | - 4
32 | num_head_channels: 32
33 | first_stage_config:
34 | target: ldm.models.autoencoder.VQModelInterface
35 | params:
36 | embed_dim: 3
37 | n_embed: 8192
38 | monitor: val/rec_loss
39 | ddconfig:
40 | double_z: false
41 | z_channels: 3
42 | resolution: 256
43 | in_channels: 3
44 | out_ch: 3
45 | ch: 128
46 | ch_mult:
47 | - 1
48 | - 2
49 | - 4
50 | num_res_blocks: 2
51 | attn_resolutions: []
52 | dropout: 0.0
53 | lossconfig:
54 | target: torch.nn.Identity
55 | cond_stage_config:
56 | target: torch.nn.Identity
57 | data:
58 | target: main.DataModuleFromConfig
59 | params:
60 | batch_size: 64
61 | wrap: false
62 | num_workers: 12
63 | train:
64 | target: ldm.data.openimages.SuperresOpenImagesAdvancedTrain
65 | params:
66 | size: 256
67 | degradation: bsrgan_light
68 | downscale_f: 4
69 | min_crop_f: 0.5
70 | max_crop_f: 1.0
71 | random_crop: true
72 | validation:
73 | target: ldm.data.openimages.SuperresOpenImagesAdvancedValidation
74 | params:
75 | size: 256
76 | degradation: bsrgan_light
77 | downscale_f: 4
78 | min_crop_f: 0.5
79 | max_crop_f: 1.0
80 | random_crop: true
81 |
--------------------------------------------------------------------------------
/stable-diffusion/models/ldm/celeba256/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 2.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0195
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | cond_stage_key: class_label
12 | image_size: 64
13 | channels: 3
14 | cond_stage_trainable: false
15 | concat_mode: false
16 | monitor: val/loss
17 | unet_config:
18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
19 | params:
20 | image_size: 64
21 | in_channels: 3
22 | out_channels: 3
23 | model_channels: 224
24 | attention_resolutions:
25 | - 8
26 | - 4
27 | - 2
28 | num_res_blocks: 2
29 | channel_mult:
30 | - 1
31 | - 2
32 | - 3
33 | - 4
34 | num_head_channels: 32
35 | first_stage_config:
36 | target: ldm.models.autoencoder.VQModelInterface
37 | params:
38 | embed_dim: 3
39 | n_embed: 8192
40 | ddconfig:
41 | double_z: false
42 | z_channels: 3
43 | resolution: 256
44 | in_channels: 3
45 | out_ch: 3
46 | ch: 128
47 | ch_mult:
48 | - 1
49 | - 2
50 | - 4
51 | num_res_blocks: 2
52 | attn_resolutions: []
53 | dropout: 0.0
54 | lossconfig:
55 | target: torch.nn.Identity
56 | cond_stage_config: __is_unconditional__
57 | data:
58 | target: main.DataModuleFromConfig
59 | params:
60 | batch_size: 48
61 | num_workers: 5
62 | wrap: false
63 | train:
64 | target: ldm.data.faceshq.CelebAHQTrain
65 | params:
66 | size: 256
67 | validation:
68 | target: ldm.data.faceshq.CelebAHQValidation
69 | params:
70 | size: 256
71 |
--------------------------------------------------------------------------------
/stable-diffusion/models/ldm/cin256/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0195
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | cond_stage_key: class_label
12 | image_size: 32
13 | channels: 4
14 | cond_stage_trainable: true
15 | conditioning_key: crossattn
16 | monitor: val/loss_simple_ema
17 | unet_config:
18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
19 | params:
20 | image_size: 32
21 | in_channels: 4
22 | out_channels: 4
23 | model_channels: 256
24 | attention_resolutions:
25 | - 4
26 | - 2
27 | - 1
28 | num_res_blocks: 2
29 | channel_mult:
30 | - 1
31 | - 2
32 | - 4
33 | num_head_channels: 32
34 | use_spatial_transformer: true
35 | transformer_depth: 1
36 | context_dim: 512
37 | first_stage_config:
38 | target: ldm.models.autoencoder.VQModelInterface
39 | params:
40 | embed_dim: 4
41 | n_embed: 16384
42 | ddconfig:
43 | double_z: false
44 | z_channels: 4
45 | resolution: 256
46 | in_channels: 3
47 | out_ch: 3
48 | ch: 128
49 | ch_mult:
50 | - 1
51 | - 2
52 | - 2
53 | - 4
54 | num_res_blocks: 2
55 | attn_resolutions:
56 | - 32
57 | dropout: 0.0
58 | lossconfig:
59 | target: torch.nn.Identity
60 | cond_stage_config:
61 | target: ldm.modules.encoders.modules.ClassEmbedder
62 | params:
63 | embed_dim: 512
64 | key: class_label
65 | data:
66 | target: main.DataModuleFromConfig
67 | params:
68 | batch_size: 64
69 | num_workers: 12
70 | wrap: false
71 | train:
72 | target: ldm.data.imagenet.ImageNetTrain
73 | params:
74 | config:
75 | size: 256
76 | validation:
77 | target: ldm.data.imagenet.ImageNetValidation
78 | params:
79 | config:
80 | size: 256
81 |
--------------------------------------------------------------------------------
/stable-diffusion/models/ldm/ffhq256/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 2.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0195
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | cond_stage_key: class_label
12 | image_size: 64
13 | channels: 3
14 | cond_stage_trainable: false
15 | concat_mode: false
16 | monitor: val/loss
17 | unet_config:
18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
19 | params:
20 | image_size: 64
21 | in_channels: 3
22 | out_channels: 3
23 | model_channels: 224
24 | attention_resolutions:
25 | - 8
26 | - 4
27 | - 2
28 | num_res_blocks: 2
29 | channel_mult:
30 | - 1
31 | - 2
32 | - 3
33 | - 4
34 | num_head_channels: 32
35 | first_stage_config:
36 | target: ldm.models.autoencoder.VQModelInterface
37 | params:
38 | embed_dim: 3
39 | n_embed: 8192
40 | ddconfig:
41 | double_z: false
42 | z_channels: 3
43 | resolution: 256
44 | in_channels: 3
45 | out_ch: 3
46 | ch: 128
47 | ch_mult:
48 | - 1
49 | - 2
50 | - 4
51 | num_res_blocks: 2
52 | attn_resolutions: []
53 | dropout: 0.0
54 | lossconfig:
55 | target: torch.nn.Identity
56 | cond_stage_config: __is_unconditional__
57 | data:
58 | target: main.DataModuleFromConfig
59 | params:
60 | batch_size: 42
61 | num_workers: 5
62 | wrap: false
63 | train:
64 | target: ldm.data.faceshq.FFHQTrain
65 | params:
66 | size: 256
67 | validation:
68 | target: ldm.data.faceshq.FFHQValidation
69 | params:
70 | size: 256
71 |
--------------------------------------------------------------------------------
/stable-diffusion/models/ldm/inpainting_big/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0205
7 | log_every_t: 100
8 | timesteps: 1000
9 | loss_type: l1
10 | first_stage_key: image
11 | cond_stage_key: masked_image
12 | image_size: 64
13 | channels: 3
14 | concat_mode: true
15 | monitor: val/loss
16 | scheduler_config:
17 | target: ldm.lr_scheduler.LambdaWarmUpCosineScheduler
18 | params:
19 | verbosity_interval: 0
20 | warm_up_steps: 1000
21 | max_decay_steps: 50000
22 | lr_start: 0.001
23 | lr_max: 0.1
24 | lr_min: 0.0001
25 | unet_config:
26 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
27 | params:
28 | image_size: 64
29 | in_channels: 7
30 | out_channels: 3
31 | model_channels: 256
32 | attention_resolutions:
33 | - 8
34 | - 4
35 | - 2
36 | num_res_blocks: 2
37 | channel_mult:
38 | - 1
39 | - 2
40 | - 3
41 | - 4
42 | num_heads: 8
43 | resblock_updown: true
44 | first_stage_config:
45 | target: ldm.models.autoencoder.VQModelInterface
46 | params:
47 | embed_dim: 3
48 | n_embed: 8192
49 | monitor: val/rec_loss
50 | ddconfig:
51 | attn_type: none
52 | double_z: false
53 | z_channels: 3
54 | resolution: 256
55 | in_channels: 3
56 | out_ch: 3
57 | ch: 128
58 | ch_mult:
59 | - 1
60 | - 2
61 | - 4
62 | num_res_blocks: 2
63 | attn_resolutions: []
64 | dropout: 0.0
65 | lossconfig:
66 | target: ldm.modules.losses.contperceptual.DummyLoss
67 | cond_stage_config: __is_first_stage__
68 |
--------------------------------------------------------------------------------
/stable-diffusion/models/ldm/layout2img-openimages256/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 2.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0205
7 | log_every_t: 100
8 | timesteps: 1000
9 | loss_type: l1
10 | first_stage_key: image
11 | cond_stage_key: coordinates_bbox
12 | image_size: 64
13 | channels: 3
14 | conditioning_key: crossattn
15 | cond_stage_trainable: true
16 | unet_config:
17 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
18 | params:
19 | image_size: 64
20 | in_channels: 3
21 | out_channels: 3
22 | model_channels: 128
23 | attention_resolutions:
24 | - 8
25 | - 4
26 | - 2
27 | num_res_blocks: 2
28 | channel_mult:
29 | - 1
30 | - 2
31 | - 3
32 | - 4
33 | num_head_channels: 32
34 | use_spatial_transformer: true
35 | transformer_depth: 3
36 | context_dim: 512
37 | first_stage_config:
38 | target: ldm.models.autoencoder.VQModelInterface
39 | params:
40 | embed_dim: 3
41 | n_embed: 8192
42 | monitor: val/rec_loss
43 | ddconfig:
44 | double_z: false
45 | z_channels: 3
46 | resolution: 256
47 | in_channels: 3
48 | out_ch: 3
49 | ch: 128
50 | ch_mult:
51 | - 1
52 | - 2
53 | - 4
54 | num_res_blocks: 2
55 | attn_resolutions: []
56 | dropout: 0.0
57 | lossconfig:
58 | target: torch.nn.Identity
59 | cond_stage_config:
60 | target: ldm.modules.encoders.modules.BERTEmbedder
61 | params:
62 | n_embed: 512
63 | n_layer: 16
64 | vocab_size: 8192
65 | max_seq_len: 92
66 | use_tokenizer: false
67 | monitor: val/loss_simple_ema
68 | data:
69 | target: main.DataModuleFromConfig
70 | params:
71 | batch_size: 24
72 | wrap: false
73 | num_workers: 10
74 | train:
75 | target: ldm.data.openimages.OpenImagesBBoxTrain
76 | params:
77 | size: 256
78 | validation:
79 | target: ldm.data.openimages.OpenImagesBBoxValidation
80 | params:
81 | size: 256
82 |
--------------------------------------------------------------------------------
/stable-diffusion/models/ldm/lsun_beds256/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 2.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0195
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | cond_stage_key: class_label
12 | image_size: 64
13 | channels: 3
14 | cond_stage_trainable: false
15 | concat_mode: false
16 | monitor: val/loss
17 | unet_config:
18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
19 | params:
20 | image_size: 64
21 | in_channels: 3
22 | out_channels: 3
23 | model_channels: 224
24 | attention_resolutions:
25 | - 8
26 | - 4
27 | - 2
28 | num_res_blocks: 2
29 | channel_mult:
30 | - 1
31 | - 2
32 | - 3
33 | - 4
34 | num_head_channels: 32
35 | first_stage_config:
36 | target: ldm.models.autoencoder.VQModelInterface
37 | params:
38 | embed_dim: 3
39 | n_embed: 8192
40 | ddconfig:
41 | double_z: false
42 | z_channels: 3
43 | resolution: 256
44 | in_channels: 3
45 | out_ch: 3
46 | ch: 128
47 | ch_mult:
48 | - 1
49 | - 2
50 | - 4
51 | num_res_blocks: 2
52 | attn_resolutions: []
53 | dropout: 0.0
54 | lossconfig:
55 | target: torch.nn.Identity
56 | cond_stage_config: __is_unconditional__
57 | data:
58 | target: main.DataModuleFromConfig
59 | params:
60 | batch_size: 48
61 | num_workers: 5
62 | wrap: false
63 | train:
64 | target: ldm.data.lsun.LSUNBedroomsTrain
65 | params:
66 | size: 256
67 | validation:
68 | target: ldm.data.lsun.LSUNBedroomsValidation
69 | params:
70 | size: 256
71 |
--------------------------------------------------------------------------------
/stable-diffusion/models/ldm/lsun_churches256/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 5.0e-05
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0155
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | loss_type: l1
11 | first_stage_key: image
12 | cond_stage_key: image
13 | image_size: 32
14 | channels: 4
15 | cond_stage_trainable: false
16 | concat_mode: false
17 | scale_by_std: true
18 | monitor: val/loss_simple_ema
19 | scheduler_config:
20 | target: ldm.lr_scheduler.LambdaLinearScheduler
21 | params:
22 | warm_up_steps:
23 | - 10000
24 | cycle_lengths:
25 | - 10000000000000
26 | f_start:
27 | - 1.0e-06
28 | f_max:
29 | - 1.0
30 | f_min:
31 | - 1.0
32 | unet_config:
33 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
34 | params:
35 | image_size: 32
36 | in_channels: 4
37 | out_channels: 4
38 | model_channels: 192
39 | attention_resolutions:
40 | - 1
41 | - 2
42 | - 4
43 | - 8
44 | num_res_blocks: 2
45 | channel_mult:
46 | - 1
47 | - 2
48 | - 2
49 | - 4
50 | - 4
51 | num_heads: 8
52 | use_scale_shift_norm: true
53 | resblock_updown: true
54 | first_stage_config:
55 | target: ldm.models.autoencoder.AutoencoderKL
56 | params:
57 | embed_dim: 4
58 | monitor: val/rec_loss
59 | ddconfig:
60 | double_z: true
61 | z_channels: 4
62 | resolution: 256
63 | in_channels: 3
64 | out_ch: 3
65 | ch: 128
66 | ch_mult:
67 | - 1
68 | - 2
69 | - 4
70 | - 4
71 | num_res_blocks: 2
72 | attn_resolutions: []
73 | dropout: 0.0
74 | lossconfig:
75 | target: torch.nn.Identity
76 |
77 | cond_stage_config: '__is_unconditional__'
78 |
79 | data:
80 | target: main.DataModuleFromConfig
81 | params:
82 | batch_size: 96
83 | num_workers: 5
84 | wrap: false
85 | train:
86 | target: ldm.data.lsun.LSUNChurchesTrain
87 | params:
88 | size: 256
89 | validation:
90 | target: ldm.data.lsun.LSUNChurchesValidation
91 | params:
92 | size: 256
93 |
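Note (not part of the repository): a minimal sketch of unconditional DDIM sampling from the LSUN-Churches model configured above, assuming its weights have been unpacked to model.ckpt by scripts/download_models.sh and a CUDA device is available.

    # Minimal sketch under the assumptions stated above.
    import torch
    from omegaconf import OmegaConf
    from ldm.util import instantiate_from_config
    from ldm.models.diffusion.ddim import DDIMSampler

    config = OmegaConf.load("models/ldm/lsun_churches256/config.yaml")
    model = instantiate_from_config(config.model)
    sd = torch.load("models/ldm/lsun_churches256/model.ckpt", map_location="cpu")["state_dict"]
    model.load_state_dict(sd, strict=False)
    model = model.cuda().eval()

    sampler = DDIMSampler(model)
    with torch.no_grad(), model.ema_scope():
        # shape is (channels, image_size, image_size) in latent space: 4 x 32 x 32 here
        samples, _ = sampler.sample(S=200, batch_size=4, shape=[4, 32, 32], eta=0.0, verbose=False)
        imgs = model.decode_first_stage(samples)    # back to pixel space, roughly in [-1, 1]
        imgs = torch.clamp((imgs + 1.0) / 2.0, min=0.0, max=1.0)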
--------------------------------------------------------------------------------
/stable-diffusion/models/ldm/semantic_synthesis256/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0205
7 | log_every_t: 100
8 | timesteps: 1000
9 | loss_type: l1
10 | first_stage_key: image
11 | cond_stage_key: segmentation
12 | image_size: 64
13 | channels: 3
14 | concat_mode: true
15 | cond_stage_trainable: true
16 | unet_config:
17 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
18 | params:
19 | image_size: 64
20 | in_channels: 6
21 | out_channels: 3
22 | model_channels: 128
23 | attention_resolutions:
24 | - 32
25 | - 16
26 | - 8
27 | num_res_blocks: 2
28 | channel_mult:
29 | - 1
30 | - 4
31 | - 8
32 | num_heads: 8
33 | first_stage_config:
34 | target: ldm.models.autoencoder.VQModelInterface
35 | params:
36 | embed_dim: 3
37 | n_embed: 8192
38 | ddconfig:
39 | double_z: false
40 | z_channels: 3
41 | resolution: 256
42 | in_channels: 3
43 | out_ch: 3
44 | ch: 128
45 | ch_mult:
46 | - 1
47 | - 2
48 | - 4
49 | num_res_blocks: 2
50 | attn_resolutions: []
51 | dropout: 0.0
52 | lossconfig:
53 | target: torch.nn.Identity
54 | cond_stage_config:
55 | target: ldm.modules.encoders.modules.SpatialRescaler
56 | params:
57 | n_stages: 2
58 | in_channels: 182
59 | out_channels: 3
60 |
--------------------------------------------------------------------------------
/stable-diffusion/models/ldm/semantic_synthesis512/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0205
7 | log_every_t: 100
8 | timesteps: 1000
9 | loss_type: l1
10 | first_stage_key: image
11 | cond_stage_key: segmentation
12 | image_size: 128
13 | channels: 3
14 | concat_mode: true
15 | cond_stage_trainable: true
16 | unet_config:
17 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
18 | params:
19 | image_size: 128
20 | in_channels: 6
21 | out_channels: 3
22 | model_channels: 128
23 | attention_resolutions:
24 | - 32
25 | - 16
26 | - 8
27 | num_res_blocks: 2
28 | channel_mult:
29 | - 1
30 | - 4
31 | - 8
32 | num_heads: 8
33 | first_stage_config:
34 | target: ldm.models.autoencoder.VQModelInterface
35 | params:
36 | embed_dim: 3
37 | n_embed: 8192
38 | monitor: val/rec_loss
39 | ddconfig:
40 | double_z: false
41 | z_channels: 3
42 | resolution: 256
43 | in_channels: 3
44 | out_ch: 3
45 | ch: 128
46 | ch_mult:
47 | - 1
48 | - 2
49 | - 4
50 | num_res_blocks: 2
51 | attn_resolutions: []
52 | dropout: 0.0
53 | lossconfig:
54 | target: torch.nn.Identity
55 | cond_stage_config:
56 | target: ldm.modules.encoders.modules.SpatialRescaler
57 | params:
58 | n_stages: 2
59 | in_channels: 182
60 | out_channels: 3
61 | data:
62 | target: main.DataModuleFromConfig
63 | params:
64 | batch_size: 8
65 | wrap: false
66 | num_workers: 10
67 | train:
68 | target: ldm.data.landscapes.RFWTrain
69 | params:
70 | size: 768
71 | crop_size: 512
72 | segmentation_to_float32: true
73 | validation:
74 | target: ldm.data.landscapes.RFWValidation
75 | params:
76 | size: 768
77 | crop_size: 512
78 | segmentation_to_float32: true
79 |
--------------------------------------------------------------------------------
/stable-diffusion/models/ldm/text2img256/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 2.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0195
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | cond_stage_key: caption
12 | image_size: 64
13 | channels: 3
14 | cond_stage_trainable: true
15 | conditioning_key: crossattn
16 | monitor: val/loss_simple_ema
17 | unet_config:
18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
19 | params:
20 | image_size: 64
21 | in_channels: 3
22 | out_channels: 3
23 | model_channels: 192
24 | attention_resolutions:
25 | - 8
26 | - 4
27 | - 2
28 | num_res_blocks: 2
29 | channel_mult:
30 | - 1
31 | - 2
32 | - 3
33 | - 5
34 | num_head_channels: 32
35 | use_spatial_transformer: true
36 | transformer_depth: 1
37 | context_dim: 640
38 | first_stage_config:
39 | target: ldm.models.autoencoder.VQModelInterface
40 | params:
41 | embed_dim: 3
42 | n_embed: 8192
43 | ddconfig:
44 | double_z: false
45 | z_channels: 3
46 | resolution: 256
47 | in_channels: 3
48 | out_ch: 3
49 | ch: 128
50 | ch_mult:
51 | - 1
52 | - 2
53 | - 4
54 | num_res_blocks: 2
55 | attn_resolutions: []
56 | dropout: 0.0
57 | lossconfig:
58 | target: torch.nn.Identity
59 | cond_stage_config:
60 | target: ldm.modules.encoders.modules.BERTEmbedder
61 | params:
62 | n_embed: 640
63 | n_layer: 32
64 | data:
65 | target: main.DataModuleFromConfig
66 | params:
67 | batch_size: 28
68 | num_workers: 10
69 | wrap: false
70 | train:
71 | target: ldm.data.previews.pytorch_dataset.PreviewsTrain
72 | params:
73 | size: 256
74 | validation:
75 | target: ldm.data.previews.pytorch_dataset.PreviewsValidation
76 | params:
77 | size: 256
78 |
--------------------------------------------------------------------------------
/stable-diffusion/outputs/txt2img-samples/samples/00000.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rmin2000/WaDiff/33e657de73c3e75e82ea66bcec14012a2f4fcb6a/stable-diffusion/outputs/txt2img-samples/samples/00000.png
--------------------------------------------------------------------------------
/stable-diffusion/outputs/txt2img-samples/samples/00001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rmin2000/WaDiff/33e657de73c3e75e82ea66bcec14012a2f4fcb6a/stable-diffusion/outputs/txt2img-samples/samples/00001.png
--------------------------------------------------------------------------------
/stable-diffusion/outputs/txt2img-samples/samples/00002.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rmin2000/WaDiff/33e657de73c3e75e82ea66bcec14012a2f4fcb6a/stable-diffusion/outputs/txt2img-samples/samples/00002.png
--------------------------------------------------------------------------------
/stable-diffusion/outputs/txt2img-samples/samples/00003.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rmin2000/WaDiff/33e657de73c3e75e82ea66bcec14012a2f4fcb6a/stable-diffusion/outputs/txt2img-samples/samples/00003.png
--------------------------------------------------------------------------------
/stable-diffusion/outputs/txt2img-samples/samples/00004.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rmin2000/WaDiff/33e657de73c3e75e82ea66bcec14012a2f4fcb6a/stable-diffusion/outputs/txt2img-samples/samples/00004.png
--------------------------------------------------------------------------------
/stable-diffusion/outputs/txt2img-samples/samples/00005.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rmin2000/WaDiff/33e657de73c3e75e82ea66bcec14012a2f4fcb6a/stable-diffusion/outputs/txt2img-samples/samples/00005.png
--------------------------------------------------------------------------------
/stable-diffusion/scripts/download_first_stages.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | wget -O models/first_stage_models/kl-f4/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f4.zip
3 | wget -O models/first_stage_models/kl-f8/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f8.zip
4 | wget -O models/first_stage_models/kl-f16/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f16.zip
5 | wget -O models/first_stage_models/kl-f32/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f32.zip
6 | wget -O models/first_stage_models/vq-f4/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f4.zip
7 | wget -O models/first_stage_models/vq-f4-noattn/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f4-noattn.zip
8 | wget -O models/first_stage_models/vq-f8/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f8.zip
9 | wget -O models/first_stage_models/vq-f8-n256/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f8-n256.zip
10 | wget -O models/first_stage_models/vq-f16/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f16.zip
11 |
12 |
13 |
14 | cd models/first_stage_models/kl-f4
15 | unzip -o model.zip
16 |
17 | cd ../kl-f8
18 | unzip -o model.zip
19 |
20 | cd ../kl-f16
21 | unzip -o model.zip
22 |
23 | cd ../kl-f32
24 | unzip -o model.zip
25 |
26 | cd ../vq-f4
27 | unzip -o model.zip
28 |
29 | cd ../vq-f4-noattn
30 | unzip -o model.zip
31 |
32 | cd ../vq-f8
33 | unzip -o model.zip
34 |
35 | cd ../vq-f8-n256
36 | unzip -o model.zip
37 |
38 | cd ../vq-f16
39 | unzip -o model.zip
40 |
41 | cd ../..
--------------------------------------------------------------------------------
/stable-diffusion/scripts/download_models.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | wget -O models/ldm/celeba256/celeba-256.zip https://ommer-lab.com/files/latent-diffusion/celeba.zip
3 | wget -O models/ldm/ffhq256/ffhq-256.zip https://ommer-lab.com/files/latent-diffusion/ffhq.zip
4 | wget -O models/ldm/lsun_churches256/lsun_churches-256.zip https://ommer-lab.com/files/latent-diffusion/lsun_churches.zip
5 | wget -O models/ldm/lsun_beds256/lsun_beds-256.zip https://ommer-lab.com/files/latent-diffusion/lsun_bedrooms.zip
6 | wget -O models/ldm/text2img256/model.zip https://ommer-lab.com/files/latent-diffusion/text2img.zip
7 | wget -O models/ldm/cin256/model.zip https://ommer-lab.com/files/latent-diffusion/cin.zip
8 | wget -O models/ldm/semantic_synthesis512/model.zip https://ommer-lab.com/files/latent-diffusion/semantic_synthesis.zip
9 | wget -O models/ldm/semantic_synthesis256/model.zip https://ommer-lab.com/files/latent-diffusion/semantic_synthesis256.zip
10 | wget -O models/ldm/bsr_sr/model.zip https://ommer-lab.com/files/latent-diffusion/sr_bsr.zip
11 | wget -O models/ldm/layout2img-openimages256/model.zip https://ommer-lab.com/files/latent-diffusion/layout2img_model.zip
12 | wget -O models/ldm/inpainting_big/model.zip https://ommer-lab.com/files/latent-diffusion/inpainting_big.zip
13 |
14 |
15 |
16 | cd models/ldm/celeba256
17 | unzip -o celeba-256.zip
18 |
19 | cd ../ffhq256
20 | unzip -o ffhq-256.zip
21 |
22 | cd ../lsun_churches256
23 | unzip -o lsun_churches-256.zip
24 |
25 | cd ../lsun_beds256
26 | unzip -o lsun_beds-256.zip
27 |
28 | cd ../text2img256
29 | unzip -o model.zip
30 |
31 | cd ../cin256
32 | unzip -o model.zip
33 |
34 | cd ../semantic_synthesis512
35 | unzip -o model.zip
36 |
37 | cd ../semantic_synthesis256
38 | unzip -o model.zip
39 |
40 | cd ../bsr_sr
41 | unzip -o model.zip
42 |
43 | cd ../layout2img-openimages256
44 | unzip -o model.zip
45 |
46 | cd ../inpainting_big
47 | unzip -o model.zip
48 |
49 | cd ../..
50 |
--------------------------------------------------------------------------------
/stable-diffusion/scripts/inpaint.py:
--------------------------------------------------------------------------------
1 | import argparse, os, sys, glob
2 | from omegaconf import OmegaConf
3 | from PIL import Image
4 | from tqdm import tqdm
5 | import numpy as np
6 | import torch
7 | from main import instantiate_from_config
8 | from ldm.models.diffusion.ddim import DDIMSampler
9 |
10 |
11 | def make_batch(image, mask, device):
12 | image = np.array(Image.open(image).convert("RGB"))
13 | image = image.astype(np.float32)/255.0
14 | image = image[None].transpose(0,3,1,2)
15 | image = torch.from_numpy(image)
16 |
17 | mask = np.array(Image.open(mask).convert("L"))
18 | mask = mask.astype(np.float32)/255.0
19 | mask = mask[None,None]
20 | mask[mask < 0.5] = 0
21 | mask[mask >= 0.5] = 1
22 | mask = torch.from_numpy(mask)
23 |
24 | masked_image = (1-mask)*image
25 |
26 | batch = {"image": image, "mask": mask, "masked_image": masked_image}
27 | for k in batch:
28 | batch[k] = batch[k].to(device=device)
29 | batch[k] = batch[k]*2.0-1.0
30 | return batch
31 |
32 |
33 | if __name__ == "__main__":
34 | parser = argparse.ArgumentParser()
35 | parser.add_argument(
36 | "--indir",
37 | type=str,
38 | nargs="?",
39 | help="dir containing image-mask pairs (`example.png` and `example_mask.png`)",
40 | )
41 | parser.add_argument(
42 | "--outdir",
43 | type=str,
44 | nargs="?",
45 | help="dir to write results to",
46 | )
47 | parser.add_argument(
48 | "--steps",
49 | type=int,
50 | default=50,
51 | help="number of ddim sampling steps",
52 | )
53 | opt = parser.parse_args()
54 |
55 | masks = sorted(glob.glob(os.path.join(opt.indir, "*_mask.png")))
56 | images = [x.replace("_mask.png", ".png") for x in masks]
57 | print(f"Found {len(masks)} inputs.")
58 |
59 | config = OmegaConf.load("models/ldm/inpainting_big/config.yaml")
60 | model = instantiate_from_config(config.model)
61 | model.load_state_dict(torch.load("models/ldm/inpainting_big/last.ckpt")["state_dict"],
62 | strict=False)
63 |
64 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
65 | model = model.to(device)
66 | sampler = DDIMSampler(model)
67 |
68 | os.makedirs(opt.outdir, exist_ok=True)
69 | with torch.no_grad():
70 | with model.ema_scope():
71 | for image, mask in tqdm(zip(images, masks)):
72 | outpath = os.path.join(opt.outdir, os.path.split(image)[1])
73 | batch = make_batch(image, mask, device=device)
74 |
75 | # encode masked image and concat downsampled mask
76 | c = model.cond_stage_model.encode(batch["masked_image"])
77 | cc = torch.nn.functional.interpolate(batch["mask"],
78 | size=c.shape[-2:])
79 | c = torch.cat((c, cc), dim=1)
80 |
81 | shape = (c.shape[1]-1,)+c.shape[2:]
82 | samples_ddim, _ = sampler.sample(S=opt.steps,
83 | conditioning=c,
84 | batch_size=c.shape[0],
85 | shape=shape,
86 | verbose=False)
87 | x_samples_ddim = model.decode_first_stage(samples_ddim)
88 |
89 | image = torch.clamp((batch["image"]+1.0)/2.0,
90 | min=0.0, max=1.0)
91 | mask = torch.clamp((batch["mask"]+1.0)/2.0,
92 | min=0.0, max=1.0)
93 | predicted_image = torch.clamp((x_samples_ddim+1.0)/2.0,
94 | min=0.0, max=1.0)
95 |
96 | inpainted = (1-mask)*image+mask*predicted_image
97 | inpainted = inpainted.cpu().numpy().transpose(0,2,3,1)[0]*255
98 | Image.fromarray(inpainted.astype(np.uint8)).save(outpath)
99 |
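Note (not part of the repository): inpaint.py expects --indir to contain image/mask pairs named name.png and name_mask.png, with white mask pixels marking the region to fill. A minimal sketch that prepares such a pair; the paths and image size are illustrative.

    # Minimal sketch; paths and image size are illustrative.
    import os
    import numpy as np
    from PIL import Image

    os.makedirs("inpaint_inputs", exist_ok=True)
    h = w = 512
    img = (np.random.rand(h, w, 3) * 255).astype(np.uint8)      # stand-in for a real photo
    mask = np.zeros((h, w), dtype=np.uint8)
    mask[128:384, 128:384] = 255                                 # white = region to be filled
    Image.fromarray(img).save("inpaint_inputs/example.png")
    Image.fromarray(mask).save("inpaint_inputs/example_mask.png")
    # then, with the inpainting_big weights in place:
    #   python scripts/inpaint.py --indir inpaint_inputs --outdir inpaint_outputs --steps 50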
--------------------------------------------------------------------------------
/stable-diffusion/scripts/tests/test_watermark.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import fire
3 | from imwatermark import WatermarkDecoder
4 |
5 |
6 | def testit(img_path):
7 | bgr = cv2.imread(img_path)
8 | decoder = WatermarkDecoder('bytes', 136)
9 | watermark = decoder.decode(bgr, 'dwtDct')
10 | try:
11 | dec = watermark.decode('utf-8')
12 | except:
13 | dec = "null"
14 | print(dec)
15 |
16 |
17 | if __name__ == "__main__":
18 | fire.Fire(testit)
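Note (not part of the repository): a minimal sketch of the encoding counterpart to this decoder test, using the same invisible-watermark package; the input/output file names are illustrative, and the 17-byte payload matches the 136-bit decoder above.

    # Minimal sketch; file names are illustrative.
    import cv2
    from imwatermark import WatermarkEncoder

    bgr = cv2.imread("input.png")
    encoder = WatermarkEncoder()
    encoder.set_watermark('bytes', "StableDiffusionV1".encode('utf-8'))  # 17 bytes = 136 bits
    bgr_marked = encoder.encode(bgr, 'dwtDct')
    cv2.imwrite("input_wm.png", bgr_marked)
    # python scripts/tests/test_watermark.py input_wm.png   # should print the payload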
--------------------------------------------------------------------------------
/stable-diffusion/scripts/train_searcher.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | import numpy as np
3 | import scann
4 | import argparse
5 | import glob
6 | from multiprocessing import cpu_count
7 | from tqdm import tqdm
8 |
9 | from ldm.util import parallel_data_prefetch
10 |
11 |
12 | def search_bruteforce(searcher):
13 | return searcher.score_brute_force().build()
14 |
15 |
16 | def search_partioned_ah(searcher, dims_per_block, aiq_threshold, reorder_k,
17 | partioning_trainsize, num_leaves, num_leaves_to_search):
18 | return searcher.tree(num_leaves=num_leaves,
19 | num_leaves_to_search=num_leaves_to_search,
20 | training_sample_size=partioning_trainsize). \
21 | score_ah(dims_per_block, anisotropic_quantization_threshold=aiq_threshold).reorder(reorder_k).build()
22 |
23 |
24 | def search_ah(searcher, dims_per_block, aiq_threshold, reorder_k):
25 | return searcher.score_ah(dims_per_block, anisotropic_quantization_threshold=aiq_threshold).reorder(
26 | reorder_k).build()
27 |
28 | def load_datapool(dpath):
29 |
30 |
31 | def load_single_file(saved_embeddings):
32 | compressed = np.load(saved_embeddings)
33 | database = {key: compressed[key] for key in compressed.files}
34 | return database
35 |
36 | def load_multi_files(data_archive):
37 | database = {key: [] for key in data_archive[0].files}
38 | for d in tqdm(data_archive, desc=f'Loading datapool from {len(data_archive)} individual files.'):
39 | for key in d.files:
40 | database[key].append(d[key])
41 |
42 | return database
43 |
44 | print(f'Load saved patch embedding from "{dpath}"')
45 | file_content = glob.glob(os.path.join(dpath, '*.npz'))
46 |
47 | if len(file_content) == 1:
48 | data_pool = load_single_file(file_content[0])
49 | elif len(file_content) > 1:
50 | data = [np.load(f) for f in file_content]
51 | prefetched_data = parallel_data_prefetch(load_multi_files, data,
52 | n_proc=min(len(data), cpu_count()), target_data_type='dict')
53 |
54 | data_pool = {key: np.concatenate([od[key] for od in prefetched_data], axis=1)[0] for key in prefetched_data[0].keys()}
55 | else:
56 |         raise ValueError(f'No npz-files found in the specified path "{dpath}". Does this directory exist?')
57 |
58 | print(f'Finished loading of retrieval database of length {data_pool["embedding"].shape[0]}.')
59 | return data_pool
60 |
61 |
62 | def train_searcher(opt,
63 | metric='dot_product',
64 | partioning_trainsize=None,
65 | reorder_k=None,
66 | # todo tune
67 | aiq_thld=0.2,
68 | dims_per_block=2,
69 | num_leaves=None,
70 | num_leaves_to_search=None,):
71 |
72 | data_pool = load_datapool(opt.database)
73 | k = opt.knn
74 |
75 | if not reorder_k:
76 | reorder_k = 2 * k
77 |
78 | # normalize
79 | # embeddings =
80 | searcher = scann.scann_ops_pybind.builder(data_pool['embedding'] / np.linalg.norm(data_pool['embedding'], axis=1)[:, np.newaxis], k, metric)
81 | pool_size = data_pool['embedding'].shape[0]
82 |
83 | print(*(['#'] * 100))
84 | print('Initializing scaNN searcher with the following values:')
85 | print(f'k: {k}')
86 | print(f'metric: {metric}')
87 | print(f'reorder_k: {reorder_k}')
88 | print(f'anisotropic_quantization_threshold: {aiq_thld}')
89 | print(f'dims_per_block: {dims_per_block}')
90 | print(*(['#'] * 100))
91 | print('Start training searcher....')
92 | print(f'N samples in pool is {pool_size}')
93 |
94 | # this reflects the recommended design choices proposed at
95 | # https://github.com/google-research/google-research/blob/aca5f2e44e301af172590bb8e65711f0c9ee0cfd/scann/docs/algorithms.md
96 | if pool_size < 2e4:
97 | print('Using brute force search.')
98 | searcher = search_bruteforce(searcher)
99 | elif 2e4 <= pool_size and pool_size < 1e5:
100 | print('Using asymmetric hashing search and reordering.')
101 | searcher = search_ah(searcher, dims_per_block, aiq_thld, reorder_k)
102 | else:
104 |         print('Using partitioning, asymmetric hashing search and reordering.')
104 |
105 | if not partioning_trainsize:
106 | partioning_trainsize = data_pool['embedding'].shape[0] // 10
107 | if not num_leaves:
108 | num_leaves = int(np.sqrt(pool_size))
109 |
110 | if not num_leaves_to_search:
111 | num_leaves_to_search = max(num_leaves // 20, 1)
112 |
113 | print('Partitioning params:')
114 | print(f'num_leaves: {num_leaves}')
115 | print(f'num_leaves_to_search: {num_leaves_to_search}')
116 | # self.searcher = self.search_ah(searcher, dims_per_block, aiq_thld, reorder_k)
117 | searcher = search_partioned_ah(searcher, dims_per_block, aiq_thld, reorder_k,
118 | partioning_trainsize, num_leaves, num_leaves_to_search)
119 |
120 |     print('Finished training searcher.')
121 | searcher_savedir = opt.target_path
122 | os.makedirs(searcher_savedir, exist_ok=True)
123 | searcher.serialize(searcher_savedir)
124 | print(f'Saved trained searcher under "{searcher_savedir}"')
125 |
126 | if __name__ == '__main__':
127 | sys.path.append(os.getcwd())
128 | parser = argparse.ArgumentParser()
129 | parser.add_argument('--database',
130 | '-d',
131 | default='data/rdm/retrieval_databases/openimages',
132 | type=str,
133 | help='path to folder containing the clip feature of the database')
134 | parser.add_argument('--target_path',
135 | '-t',
136 | default='data/rdm/searchers/openimages',
137 | type=str,
138 | help='path to the target folder where the searcher shall be stored.')
139 | parser.add_argument('--knn',
140 | '-k',
141 | default=20,
142 | type=int,
143 | help='number of nearest neighbors, for which the searcher shall be optimized')
144 |
145 | opt, _ = parser.parse_known_args()
146 |
147 | train_searcher(opt,)
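Note (not part of the repository): a minimal sketch of loading and querying a searcher serialized by this script; the query embeddings are random placeholders, and their dimensionality (512 here) is an assumption that must match the database the searcher was trained on.

    # Minimal sketch under the assumptions stated above.
    import numpy as np
    import scann

    searcher = scann.scann_ops_pybind.load_searcher('data/rdm/searchers/openimages')
    queries = np.random.randn(8, 512).astype(np.float32)        # placeholder query embeddings
    queries /= np.linalg.norm(queries, axis=1, keepdims=True)   # same normalization as the pool
    neighbors, distances = searcher.search_batched(queries, final_num_neighbors=20)
    print(neighbors.shape, distances.shape)                     # (8, 20) each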
--------------------------------------------------------------------------------
/stable-diffusion/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name='latent-diffusion',
5 | version='0.0.1',
6 | description='',
7 | packages=find_packages(),
8 | install_requires=[
9 | 'torch',
10 | 'numpy',
11 | 'tqdm',
12 | ],
13 | )
--------------------------------------------------------------------------------
/trace.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import glob
3 | import argparse
4 | from StegaStamp import models
5 | from torchvision import transforms
6 | from PIL import Image
7 | import numpy as np
8 |
9 |
10 | parser = argparse.ArgumentParser()
11 | parser.add_argument(
12 | "--image_path", type=str, default='./guided-diffusion/saved_images/', help="Directory of watermark examples."
13 | )
14 | parser.add_argument(
15 | "--bit_length", type=int, default=48, help="Length of watermark bits."
16 | )
17 | parser.add_argument(
18 | "--model_type", type=str, default='imagenet', choices=['imagenet', 'stable'], help="ImageNet Diffusion or Stable Diffusion."
19 | )
20 | parser.add_argument(
21 | "--checkpoint", type=str, default='./', help="Checkpoint of the watermark decoder."
22 | )
23 | parser.add_argument(
24 | "--device", type=int, default=0, help="GPU index."
25 | )
26 | parser.add_argument(
27 | "--detection_thre", type=float, default=0.8, help="Bit threshold for detection."
28 | )
29 |
30 |
31 | args = parser.parse_args()
32 |
33 |
34 |
35 | def trace(image_path, decoder, detection_thre, device):
36 |
37 | count_pool_1e4, count_pool_1e5, count_pool_1e6, count_detection = 0, 0, 0, 0
38 |     decoder = decoder.to(device).eval()
39 |
40 | # Load pre-defined watermarks and watermarked images
41 |     user_pool_1e4 = np.load(f'watermark_pool/{args.bit_length}_1e4.npy')
42 |     user_pool_1e5 = np.load(f'watermark_pool/{args.bit_length}_1e5.npy')
43 |     user_pool_1e6 = np.load(f'watermark_pool/{args.bit_length}_1e6.npy')
44 |
45 | image_path_list = glob.glob(image_path + '*/*.png')
46 |
47 |
48 | for path in image_path_list:
49 |         img = transforms.ToTensor()(Image.open(path).convert('RGB')).to(device)
50 |         user_index = int(path.split('/')[-2])
51 | 
52 |         fingerprints_predicted = (decoder(img.unsqueeze(0))[0] > 0).float().cpu().numpy()
53 |
54 |         if 1 - np.abs(fingerprints_predicted - user_pool_1e4[user_index]).sum() / len(fingerprints_predicted) > detection_thre:
55 | count_detection += 1
56 |         if np.argmin(np.abs(fingerprints_predicted - user_pool_1e4).sum(1)) == user_index:
57 | count_pool_1e4 += 1
58 |         if np.argmin(np.abs(fingerprints_predicted - user_pool_1e5).sum(1)) == user_index:
59 | count_pool_1e5 += 1
60 |         if np.argmin(np.abs(fingerprints_predicted - user_pool_1e6).sum(1)) == user_index:
61 | count_pool_1e6 += 1
62 |
63 |     return {'trace1e4': count_pool_1e4/len(image_path_list), 'trace1e5': count_pool_1e5/len(image_path_list), 'trace1e6': count_pool_1e6/len(image_path_list), 'detection_acc': count_detection/len(image_path_list)}
64 |
65 |
66 |
67 |
68 | if __name__ == "__main__":
69 |
70 | # Create watermark decoder
71 | decoder = models.StegaStampDecoder(
72 | 256 if args.model_type == 'imagenet' else 512,
73 | 3,
74 | args.bit_length,
75 | )
76 | decoder.load_state_dict(torch.load(args.checkpoint, map_location='cpu'))
77 |
78 | # Perform tracing
79 | result = trace(args.image_path, decoder, args.detection_thre, args.device)
80 | print(result['trace1e4'], result['trace1e5'], result['trace1e6'], result['detection_acc'])
81 |
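Note (not part of the repository): trace.py loads pre-computed pools from watermark_pool/{bit_length}_1e4.npy and so on. A minimal sketch of one way such pools could be generated, assuming each pool is a (num_users, bit_length) array of 0/1 bits; the actual pools must match the watermarks embedded at generation time.

    # Minimal sketch; the (num_users, bit_length) 0/1 layout is an assumption.
    import os
    import numpy as np

    bit_length = 48
    rng = np.random.default_rng(0)
    os.makedirs("watermark_pool", exist_ok=True)
    for name, num_users in [("1e4", 10_000), ("1e5", 100_000), ("1e6", 1_000_000)]:
        pool = rng.integers(0, 2, size=(num_users, bit_length)).astype(np.float32)
        np.save(f"watermark_pool/{bit_length}_{name}.npy", pool)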
--------------------------------------------------------------------------------