├── README.md
├── StableDiffusion_exps.py
├── StableDiff_toy_celebA.py
├── Diffusion_training_demo.py
├── net_models.py
├── StableDiff_UNet_unittest.py
└── StableDiff_UNet_model.py

/README.md:
--------------------------------------------------------------------------------
1 | # Diffusion From Scratch
2 | 
3 | Binxu Wang (binxu_wang@hms.harvard.edu)
4 | 
5 | A tutorial on Stable Diffusion models for the ML from Scratch seminar series at Harvard.
6 | ![](https://scholar.harvard.edu/sites/scholar.harvard.edu/files/styles/os_files_large/public/binxuw/files/diffusion_proc1.gif?m=1667441103&itok=y1BDYFl1)
7 | * [Homepage](https://scholar.harvard.edu/binxuw/classes/machine-learning-scratch/materials/stable-diffusion-scratch)
8 | * [Tutorial Slides](https://scholar.harvard.edu/files/binxuw/files/stable_diffusion_a_tutorial.pdf)
9 | 
10 | This tiny, self-contained code base allows you to:
11 | * Rebuild the Stable Diffusion model in a single Python script.
12 | * Train your own toy version of Stable Diffusion on classic datasets such as MNIST and CelebA.
13 | ![](https://scholar.harvard.edu/sites/scholar.harvard.edu/files/styles/os_files_xxlarge/public/binxuw/files/stablediffusion_overview.jpg?m=1667438590&itok=n2gM0Xba)
14 | 
15 | ## Colab notebooks
16 | * Play with Stable Diffusion and inspect the internal architecture of the models. [Open in Colab](https://colab.research.google.com/drive/1TvOlX2_l4pCBOKjDI672JcMm4q68sKrA?usp=sharing)
17 | * Build your own Stable Diffusion UNet model from scratch in a notebook (in < 300 lines of code!). [Open in Colab](https://colab.research.google.com/drive/1mm67_irYu3qU3hnfzqK5yQC38Fd5UFam?usp=sharing)
18 | * [Self-contained script](https://github.com/Animadversio/DiffusionFromScratch/blob/master/StableDiff_UNet_model.py)
19 | * [Unit tests](https://github.com/Animadversio/DiffusionFromScratch/blob/master/StableDiff_UNet_unittest.py)
20 | * Build a diffusion model (with UNet + cross-attention) and train it to generate MNIST images based on a "text prompt". [Open in Colab (exercise)](https://colab.research.google.com/drive/1Y5wr91g5jmpCDiX-RLfWL1eSBWoSuLqO?usp=sharing) [Open in Colab (answer)](https://colab.research.google.com/drive/1_MEFfBdOI06GAuANrs1b8L-BBLn3x-ZJ?usp=sharing)
21 | 
22 | 
23 | ## Demo Outputs
24 | [![Demo video](https://img.youtube.com/vi/SmY7vMNen2w/0.jpg)](https://www.youtube.com/watch?v=SmY7vMNen2w)
25 | 
26 | [Music video](https://github.com/nateraw/stable-diffusion-videos) generated with Stable Diffusion.
27 | 
--------------------------------------------------------------------------------

/StableDiffusion_exps.py:
--------------------------------------------------------------------------------
1 | """Experimenting with StableDiffusion and our version. """
2 | 
3 | import torch
4 | from torch import autocast
5 | from diffusers import StableDiffusionPipeline
6 | import matplotlib.pyplot as plt
7 | 
8 | 
9 | def plt_show_image(image):
10 |     plt.figure(figsize=(8, 8))
11 |     plt.imshow(image)
12 |     plt.axis("off")
13 |     plt.tight_layout()
14 |     plt.show()
15 | 
16 | 
17 | def recursive_print(module, prefix="", depth=0, deepest=3):
18 |     """Simulate print(module) for torch.nn.Modules,
19 |     but with depth control. Print to the `deepest` level.
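    e.g. `recursive_print(pipe.unet, deepest=2)` prints two levels of the UNet;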
`deepest=0` means no print 20 | """ 21 | if depth >= deepest: 22 | return 23 | for name, child in module.named_children(): 24 | if len([*child.named_children()]) == 0: 25 | print(f"{prefix}({name}): {child}") 26 | else: 27 | print(f"{prefix}({name}): {type(child).__name__}") 28 | recursive_print(child, prefix + " ", depth + 1, deepest) 29 | 30 | #%% 31 | 32 | pipe = StableDiffusionPipeline.from_pretrained( 33 | "CompVis/stable-diffusion-v1-4", 34 | use_auth_token=True 35 | ).to("cuda") 36 | def dummy_checker(images, **kwargs): return images, False 37 | pipe.safety_checker = dummy_checker 38 | #%% 39 | recursive_print(pipe.unet, deepest=2) 40 | #%% Text to 41 | # prompt = "a photo of an ballerina riding a horse on mars" 42 | prompt = "A ballerina riding a Harley Motorcycle, CG Art" 43 | with autocast("cuda"): 44 | image = pipe(prompt)["sample"][0] 45 | 46 | image.save("astronaut_rides_horse.png") 47 | plt_show_image(image) 48 | #%% Loading in our own model! 49 | from StableDiff_UNet_model import UNet_SD, load_pipe_into_our_UNet 50 | myunet = UNet_SD() 51 | original_unet = pipe.unet.cpu() 52 | load_pipe_into_our_UNet(myunet, original_unet) 53 | pipe.unet = myunet.cuda() 54 | #%% 55 | torch.save(myunet.state_dict(), "/home/binxuwang/DL_Projects/SDfromScratch/ourUNet.pth") 56 | 57 | 58 | #%% Saving images during diffusion process using callback 59 | 60 | latents_reservoir = [] 61 | @torch.no_grad() 62 | def plot_show_callback(i, t, latents): 63 | latents_reservoir.append(latents.detach().cpu()) 64 | latents = 1 / 0.18215 * latents 65 | image = pipe.vae.decode(latents).sample 66 | image = (image / 2 + 0.5).clamp(0, 1) 67 | image = image.cpu().permute(0, 2, 3, 1).float().numpy() 68 | plt_show_image(image[0]) 69 | plt.imsave(f"/home/binxuwang/DL_Projects/SDfromScratch/diffproc/sample_{i:02d}.png", image[0]) 70 | 71 | latents_reservoir = [] 72 | @torch.no_grad() 73 | def save_latents(i, t, latents): 74 | latents_reservoir.append(latents.detach().cpu()) 75 | #%% 76 | # prompt = "A ballerina dancing on a high ground in the starry night" 77 | # prompt = "A cute cat running on the grass in the style of Monet" 78 | prompt = "A ballerina chasing her cat running on the grass in the style of Monet" 79 | prompt = "A kitty cat dressed like Lincoln, old timey style" 80 | with autocast("cuda"): 81 | image = pipe(prompt, callback=plot_show_callback)["sample"][0] # plot_show_callback 82 | 83 | image.save("cat_Lincoln.png") 84 | plt_show_image(image) 85 | #%% 86 | len(latents_reservoir) 87 | plt_show_image(latents_reservoir[-1][0, [0, 1, 2,], :].permute(1, 2, 0).cpu().numpy() / 1.6 + 0.4) 88 | #%% Visualize architecture 89 | 90 | #%% Full unets 91 | recursive_print(pipe.unet, deepest=3) 92 | #%% 93 | recursive_print(pipe.vae, deepest=3) 94 | #%% Down blocks 95 | recursive_print(pipe.unet.down_blocks, deepest=4) 96 | #%% Up blocks 97 | recursive_print(pipe.unet.up_blocks, deepest=4) 98 | #%% 99 | torch.save(pipe.unet.state_dict(), "/home/binxuwang/DL_Projects/SDfromScratch/SD_unet.pt",) 100 | torch.save(pipe.vae.state_dict(), "/home/binxuwang/DL_Projects/SDfromScratch/SD_vae.pt",) 101 | #%% 102 | SD_unet = torch.load("/home/binxuwang/DL_Projects/SDfromScratch/SD_unet.pt") 103 | #%% 104 | # https://github.com/CompVis/stable-diffusion/blob/main/configs/stable-diffusion/v1-inference.yaml -------------------------------------------------------------------------------- /StableDiff_toy_celebA.py: -------------------------------------------------------------------------------- 1 | #%% 2 | import torch 3 | import 
functools 4 | from tqdm import tqdm, trange 5 | import torch.multiprocessing 6 | from tqdm import tqdm 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | torch.multiprocessing.set_sharing_strategy('file_system') 10 | #%% 11 | from torch.utils.data import DataLoader, TensorDataset 12 | from torchvision.datasets import CelebA 13 | from torchvision.transforms import ToTensor, CenterCrop, Resize, Compose, Normalize 14 | 15 | 16 | tfm = Compose([ 17 | Resize(32), 18 | CenterCrop(32), 19 | ToTensor(), 20 | Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 21 | ]) 22 | dataset_rsz = CelebA("/home/binxuwang/Datasets", target_type=["attr"], 23 | transform=tfm, download=False) # ,"identity" 24 | #%% 25 | dataloader = DataLoader(dataset_rsz, batch_size=64, num_workers=8, shuffle=False) 26 | x_col = [] 27 | y_col = [] 28 | for xs, ys in tqdm(dataloader): 29 | x_col.append(xs) 30 | y_col.append(ys) 31 | x_col = torch.concat(x_col, dim=0) 32 | y_col = torch.concat(y_col, dim=0) 33 | print(x_col.shape) 34 | print(y_col.shape) 35 | 36 | nantoken = 40 37 | maxlen = (y_col.sum(dim=1)).max() 38 | yseq_data = torch.ones(y_col.size(0), maxlen, dtype=int).fill_(nantoken) 39 | 40 | saved_dataset = TensorDataset(x_col, yseq_data) 41 | #%% 42 | import math 43 | from torch.optim import Adam 44 | from torch.optim.lr_scheduler import MultiplicativeLR, LambdaLR 45 | device = 'cuda' 46 | 47 | def marginal_prob_std(t, sigma): 48 | t = torch.tensor(t, device=device) 49 | return torch.sqrt((sigma ** (2 * t) - 1.) / 2. / math.log(sigma)) 50 | 51 | 52 | def diffusion_coeff(t, sigma): 53 | return torch.tensor(sigma ** t, device=device) 54 | 55 | 56 | sigma = 25.0 # @param {'type':'number'} 57 | marginal_prob_std_fn = functools.partial(marginal_prob_std, sigma=sigma) 58 | diffusion_coeff_fn = functools.partial(diffusion_coeff, sigma=sigma) 59 | #% 60 | #@title Training Loss function 61 | def loss_fn_cond(model, x, y, marginal_prob_std, eps=1e-5): 62 | """The loss function for training score-based generative models. 63 | 64 | Args: 65 | model: A PyTorch model instance that represents a 66 | time-dependent score-based model. 67 | x: A mini-batch of training data. 68 | marginal_prob_std: A function that gives the standard deviation of 69 | the perturbation kernel. 70 | eps: A tolerance value for numerical stability. 71 | """ 72 | random_t = torch.rand(x.shape[0], device=x.device) * (1. - eps) + eps 73 | z = torch.randn_like(x) 74 | std = marginal_prob_std(random_t) 75 | perturbed_x = x + z * std[:, None, None, None] 76 | score = model(perturbed_x, random_t, cond=y, output_dict=False) 77 | loss = torch.mean(torch.sum((score * std[:, None, None, None] + z)**2, dim=(1,2,3))) 78 | return loss 79 | 80 | #% 81 | def train_score_model(score_model, cond_embed, dataset, lr, n_epochs, batch_size, ckpt_name, 82 | marginal_prob_std_fn=marginal_prob_std_fn, 83 | lr_scheduler_fn=lambda epoch: max(0.2, 0.98 ** epoch), 84 | device="cuda", 85 | callback=None): # resume=False, 86 | data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0) 87 | optimizer = Adam([*score_model.parameters(), *cond_embed.parameters()], lr=lr) 88 | scheduler = LambdaLR(optimizer, lr_lambda=lr_scheduler_fn) 89 | tqdm_epoch = trange(n_epochs) 90 | for epoch in tqdm_epoch: 91 | score_model.train() 92 | avg_loss = 0. 
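        # Running sums over the epoch: each batch's loss is weighted by its batch
        # size below, so avg_loss / num_items is the per-sample loss even when the
        # last batch is smaller than `batch_size`.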
93 | num_items = 0 94 | batch_tqdm = tqdm(data_loader) 95 | for x, y in batch_tqdm: 96 | x = x.to(device) 97 | y_emb = cond_embed(y.to(device)) 98 | loss = loss_fn_cond(score_model, x, y_emb, marginal_prob_std_fn) 99 | optimizer.zero_grad() 100 | loss.backward() 101 | optimizer.step() 102 | avg_loss += loss.item() * x.shape[0] 103 | num_items += x.shape[0] 104 | batch_tqdm.set_description("Epoch %d, loss %.4f" % (epoch, avg_loss / num_items)) 105 | scheduler.step() 106 | lr_current = scheduler.get_last_lr()[0] 107 | print('{} Average Loss: {:5f} lr {:.1e}'.format(epoch, avg_loss / num_items, lr_current)) 108 | # Print the averaged training loss so far. 109 | tqdm_epoch.set_description('Average Loss: {:5f}'.format(avg_loss / num_items)) 110 | # Update the checkpoint after each epoch of training. 111 | torch.save(score_model.state_dict(), f'/home/binxuwang/DL_Projects/SDfromScratch/ckpt_{ckpt_name}.pth') 112 | torch.save(cond_embed.state_dict(), 113 | f'/home/binxuwang/DL_Projects/SDfromScratch/ckpt_{ckpt_name}_cond_embed.pth') 114 | if callback is not None: 115 | score_model.eval() 116 | callback(score_model, epoch, ckpt_name) 117 | #%% 118 | def Euler_Maruyama_sampler(score_model, 119 | marginal_prob_std, 120 | diffusion_coeff, 121 | batch_size=64, 122 | x_shape=(1, 28, 28), 123 | num_steps=500, 124 | device='cuda', 125 | eps=1e-3, 126 | y=None): 127 | """Generate samples from score-based models with the Euler-Maruyama solver. 128 | 129 | Args: 130 | score_model: A PyTorch model that represents the time-dependent score-based model. 131 | marginal_prob_std: A function that gives the standard deviation of 132 | the perturbation kernel. 133 | diffusion_coeff: A function that gives the diffusion coefficient of the SDE. 134 | batch_size: The number of samplers to generate by calling this function once. 135 | num_steps: The number of sampling steps. 136 | Equivalent to the number of discretized time steps. 137 | device: 'cuda' for running on GPUs, and 'cpu' for running on CPUs. 138 | eps: The smallest time step for numerical stability. 139 | 140 | Returns: 141 | Samples. 142 | """ 143 | t = torch.ones(batch_size, device=device) 144 | init_x = torch.randn(batch_size, *x_shape, device=device) \ 145 | * marginal_prob_std(t)[:, None, None, None] 146 | time_steps = torch.linspace(1., eps, num_steps, device=device) 147 | step_size = time_steps[0] - time_steps[1] 148 | x = init_x 149 | with torch.no_grad(): 150 | for time_step in tqdm(time_steps): 151 | batch_time_step = torch.ones(batch_size, device=device) * time_step 152 | g = diffusion_coeff(batch_time_step) 153 | mean_x = x + (g ** 2)[:, None, None, None] * score_model(x, batch_time_step, cond=y, output_dict=False) * step_size 154 | x = mean_x + torch.sqrt(step_size) * g[:, None, None, None] * torch.randn_like(x) 155 | # Do not include any noise in the last sampling step. 
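        # (returning mean_x rather than x below drops the Brownian term of the very
        # last Euler-Maruyama update, so the final sample is not perturbed by one
        # extra step of noise)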
156 | return mean_x 157 | #%% 158 | import matplotlib.pyplot as plt 159 | from torchvision.utils import make_grid 160 | def save_sample_callback(score_model, epocs, ckpt_name): 161 | sample_batch_size = 64 162 | num_steps = 250 163 | y_samp = yseq_data[:sample_batch_size, :] 164 | y_emb = cond_embed(y_samp.cuda()) 165 | sampler = Euler_Maruyama_sampler 166 | samples = sampler(score_model, 167 | marginal_prob_std_fn, 168 | diffusion_coeff_fn, 169 | sample_batch_size, 170 | x_shape=(3, 32, 32), 171 | num_steps=num_steps, 172 | device=device, 173 | y=y_emb, ) 174 | denormalize = Normalize([-0.485/0.229, -0.456/0.224, -0.406/0.225], 175 | [1/0.229, 1/0.224, 1/0.225]) 176 | samples = denormalize(samples).clamp(0.0, 1.0) 177 | sample_grid = make_grid(samples, nrow=int(math.sqrt(sample_batch_size))) 178 | 179 | plt.figure(figsize=(8, 8)) 180 | plt.axis('off') 181 | plt.imshow(sample_grid.permute(1, 2, 0).cpu(), vmin=0., vmax=1.) 182 | plt.tight_layout() 183 | plt.savefig(f"/home/binxuwang/DL_Projects/SDfromScratch/samples_{ckpt_name}_{epocs}.png") 184 | plt.show() 185 | #%% 186 | from StableDiff_UNet_model import UNet_SD, load_pipe_into_our_UNet 187 | #%% UNet without latent space no VAE 188 | unet_face = UNet_SD(in_channels=3, 189 | base_channels=128, 190 | time_emb_dim=256, 191 | context_dim=256, 192 | multipliers=(1, 1, 2), 193 | attn_levels=(1, 2, ), 194 | nResAttn_block=1, 195 | ) 196 | cond_embed = nn.Embedding(40 + 1, 256, padding_idx=40).cuda() 197 | #%% 198 | torch.save(unet_face.state_dict(), "/home/binxuwang/DL_Projects/SDfromScratch/SD_unet_face.pt",) 199 | #%% 200 | unet_face(torch.randn(1, 3, 64, 64).cuda(), time_steps=torch.rand(1).cuda(), 201 | cond=torch.randn(1, 20, 256).cuda(), 202 | output_dict=False) 203 | #%% 204 | #%% 205 | train_score_model(unet_face, cond_embed, saved_dataset, 206 | lr=1.5e-4, n_epochs=100, batch_size=256, 207 | ckpt_name="unet_SD_face", device=device, 208 | callback=save_sample_callback) 209 | 210 | #%% 211 | 212 | 213 | save_sample_callback(unet_face, 0, "unet_SD_face") 214 | #%% 215 | torch.save(cond_embed.state_dict(), f'/home/binxuwang/DL_Projects/SDfromScratch/ckpt_{"unet_SD_face"}_cond_embed.pth') 216 | -------------------------------------------------------------------------------- /Diffusion_training_demo.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import functools 3 | from tqdm import tqdm, trange 4 | import torch.multiprocessing 5 | torch.multiprocessing.set_sharing_strategy('file_system') 6 | 7 | # @title Diffusion constant and noise strength 8 | device = 'cuda' # @param ['cuda', 'cpu'] {'type':'string'} 9 | 10 | def marginal_prob_std(t, sigma): 11 | """Compute the mean and standard deviation of $p_{0t}(x(t) | x(0))$. 12 | 13 | Args: 14 | t: A vector of time steps. 15 | sigma: The $\sigma$ in our SDE. 16 | 17 | Returns: 18 | The standard deviation. 19 | """ 20 | t = torch.tensor(t, device=device) 21 | return torch.sqrt((sigma ** (2 * t) - 1.) / 2. / np.log(sigma)) 22 | 23 | 24 | def diffusion_coeff(t, sigma): 25 | """Compute the diffusion coefficient of our SDE. 26 | 27 | Args: 28 | t: A vector of time steps. 29 | sigma: The $\sigma$ in our SDE. 30 | 31 | Returns: 32 | The vector of diffusion coefficients. 
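    Here the coefficient has the closed form g(t) = sigma**t, which is consistent
    with the perturbation-kernel standard deviation
    sqrt((sigma**(2*t) - 1) / (2 * log(sigma))) returned by marginal_prob_std.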
33 | """ 34 | return torch.tensor(sigma ** t, device=device) 35 | 36 | 37 | sigma = 25.0 # @param {'type':'number'} 38 | marginal_prob_std_fn = functools.partial(marginal_prob_std, sigma=sigma) 39 | diffusion_coeff_fn = functools.partial(diffusion_coeff, sigma=sigma) 40 | #%% 41 | #@title Training Loss function 42 | def loss_fn_cond(model, x, y, marginal_prob_std, eps=1e-5): 43 | """The loss function for training score-based generative models. 44 | 45 | Args: 46 | model: A PyTorch model instance that represents a 47 | time-dependent score-based model. 48 | x: A mini-batch of training data. 49 | marginal_prob_std: A function that gives the standard deviation of 50 | the perturbation kernel. 51 | eps: A tolerance value for numerical stability. 52 | """ 53 | random_t = torch.rand(x.shape[0], device=x.device) * (1. - eps) + eps 54 | z = torch.randn_like(x) 55 | std = marginal_prob_std(random_t) 56 | perturbed_x = x + z * std[:, None, None, None] 57 | score = model(perturbed_x, random_t, y=y) 58 | loss = torch.mean(torch.sum((score * std[:, None, None, None] + z)**2, dim=(1,2,3))) 59 | return loss 60 | #%% 61 | # @title Diffusion Model Sampler 62 | def Euler_Maruyama_sampler(score_model, 63 | marginal_prob_std, 64 | diffusion_coeff, 65 | batch_size=64, 66 | x_shape=(1, 28, 28), 67 | num_steps=500, 68 | device='cuda', 69 | eps=1e-3, 70 | y=None): 71 | """Generate samples from score-based models with the Euler-Maruyama solver. 72 | 73 | Args: 74 | score_model: A PyTorch model that represents the time-dependent score-based model. 75 | marginal_prob_std: A function that gives the standard deviation of 76 | the perturbation kernel. 77 | diffusion_coeff: A function that gives the diffusion coefficient of the SDE. 78 | batch_size: The number of samplers to generate by calling this function once. 79 | num_steps: The number of sampling steps. 80 | Equivalent to the number of discretized time steps. 81 | device: 'cuda' for running on GPUs, and 'cpu' for running on CPUs. 82 | eps: The smallest time step for numerical stability. 83 | 84 | Returns: 85 | Samples. 86 | """ 87 | t = torch.ones(batch_size, device=device) 88 | init_x = torch.randn(batch_size, *x_shape, device=device) \ 89 | * marginal_prob_std(t)[:, None, None, None] 90 | time_steps = torch.linspace(1., eps, num_steps, device=device) 91 | step_size = time_steps[0] - time_steps[1] 92 | x = init_x 93 | with torch.no_grad(): 94 | for time_step in tqdm(time_steps): 95 | batch_time_step = torch.ones(batch_size, device=device) * time_step 96 | g = diffusion_coeff(batch_time_step) 97 | mean_x = x + (g ** 2)[:, None, None, None] * score_model(x, batch_time_step, y=y) * step_size 98 | x = mean_x + torch.sqrt(step_size) * g[:, None, None, None] * torch.randn_like(x) 99 | # Do not include any noise in the last sampling step. 
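        # Each iteration above performs one Euler-Maruyama step of the reverse-time
        # dynamics, x <- x + g(t)**2 * score(x, t) * dt + g(t) * sqrt(dt) * z,
        # integrated from t = 1 down to t = eps; the final mean update (without the
        # added noise) is what gets returned.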
100 | return mean_x 101 | 102 | #%% 103 | import torch 104 | import torch.nn as nn 105 | import torch.nn.functional as F 106 | import numpy as np 107 | import functools 108 | import math 109 | from torch.optim import Adam 110 | from torch.utils.data import DataLoader 111 | import torchvision.transforms as transforms 112 | from torchvision.datasets import MNIST 113 | import tqdm 114 | from tqdm.notebook import trange, tqdm 115 | from torch.optim.lr_scheduler import MultiplicativeLR, LambdaLR 116 | 117 | import matplotlib.pyplot as plt 118 | 119 | from torchvision.utils import make_grid 120 | from einops import rearrange 121 | 122 | #@title Diffusion Trainer 123 | def train_score_model(score_model, dataset, lr, n_epochs, batch_size, ckpt_name, 124 | marginal_prob_std_fn=marginal_prob_std_fn, 125 | lr_scheduler_fn=lambda epoch: max(0.2, 0.98 ** epoch), 126 | device="cuda", 127 | callback=None): # resume=False, 128 | data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0) 129 | 130 | optimizer = Adam(score_model.parameters(), lr=lr) 131 | scheduler = LambdaLR(optimizer, lr_lambda=lr_scheduler_fn) 132 | tqdm_epoch = trange(n_epochs) 133 | for epoch in tqdm_epoch: 134 | score_model.train() 135 | avg_loss = 0. 136 | num_items = 0 137 | for x, y in tqdm(data_loader): 138 | x = x.to(device) 139 | loss = loss_fn_cond(score_model, x, y, marginal_prob_std_fn) 140 | optimizer.zero_grad() 141 | loss.backward() 142 | optimizer.step() 143 | avg_loss += loss.item() * x.shape[0] 144 | num_items += x.shape[0] 145 | scheduler.step() 146 | lr_current = scheduler.get_last_lr()[0] 147 | print('{} Average Loss: {:5f} lr {:.1e}'.format(epoch, avg_loss / num_items, lr_current)) 148 | # Print the averaged training loss so far. 149 | tqdm_epoch.set_description('Average Loss: {:5f}'.format(avg_loss / num_items)) 150 | # Update the checkpoint after each epoch of training. 
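        # Note that the same checkpoint file is overwritten every epoch, so only the
        # latest weights survive; include `epoch` in the filename if per-epoch
        # snapshots are wanted.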
151 | torch.save(score_model.state_dict(), f'ckpt_{ckpt_name}.pth') 152 | if callback is not None: 153 | score_model.eval() 154 | callback(score_model, epoch, ckpt_name) 155 | 156 | #%% 157 | 158 | #%% 159 | from tqdm import tqdm 160 | from torch.utils.data import DataLoader, TensorDataset 161 | from torchvision.datasets import CelebA 162 | from torchvision.transforms import ToTensor, CenterCrop, Resize, Compose, Normalize 163 | 164 | tfm = Compose([ 165 | Resize(32), 166 | CenterCrop(32), 167 | ToTensor(), 168 | Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 169 | ]) 170 | dataset_rsz = CelebA("/home/binxuwang/Datasets", target_type=["attr"], 171 | transform=tfm, download=False) # ,"identity" 172 | #%% 173 | # def preprocess_dataset(dataset_rsz, ): 174 | dataloader = DataLoader(dataset_rsz, batch_size=64, num_workers=8, shuffle=False) 175 | x_col = [] 176 | y_col = [] 177 | for xs, ys in tqdm(dataloader): 178 | x_col.append(xs) 179 | y_col.append(ys) 180 | x_col = torch.concat(x_col, dim=0) 181 | y_col = torch.concat(y_col, dim=0) 182 | print(x_col.shape) 183 | print(y_col.shape) 184 | 185 | nantoken = 40 186 | maxlen = (y_col.sum(dim=1)).max() 187 | yseq_data = torch.ones(y_col.size(0), maxlen, dtype=int).fill_(nantoken) 188 | 189 | saved_dataset = TensorDataset(x_col, yseq_data) 190 | # return saved_dataset 191 | #%% 192 | import matplotlib.pyplot as plt 193 | 194 | def save_sample_callback(score_model, epocs, ckpt_name): 195 | sample_batch_size = 64 196 | num_steps = 250 197 | y_samp = yseq_data[:sample_batch_size, :] 198 | sampler = Euler_Maruyama_sampler 199 | samples = sampler(score_model, 200 | marginal_prob_std_fn, 201 | diffusion_coeff_fn, 202 | sample_batch_size, 203 | x_shape=(3, 32, 32), 204 | num_steps=num_steps, 205 | device=device, 206 | y=y_samp, ) 207 | denormalize = Normalize([-0.485/0.229, -0.456/0.224, -0.406/0.225], 208 | [1/0.229, 1/0.224, 1/0.225]) 209 | samples = denormalize(samples).clamp(0.0, 1.0) 210 | sample_grid = make_grid(samples, nrow=int(np.sqrt(sample_batch_size))) 211 | 212 | plt.figure(figsize=(8, 8)) 213 | plt.axis('off') 214 | plt.imshow(sample_grid.permute(1, 2, 0).cpu(), vmin=0., vmax=1.) 215 | plt.tight_layout() 216 | plt.savefig(f"samples_{ckpt_name}_{epocs}.png") 217 | plt.show() 218 | #%% 219 | #@title Training model 220 | 221 | # continue_training = False #@param {type:"boolean"} 222 | # if not continue_training: 223 | # print("initilize new score model...") 224 | score_model = torch.nn.DataParallel( 225 | UNet_Tranformer_attrb(marginal_prob_std=marginal_prob_std_fn)) 226 | score_model = score_model.to(device) 227 | 228 | n_epochs = 50 229 | batch_size = 1024 230 | lr = 10e-4 231 | train_score_model(score_model, saved_dataset, lr, n_epochs, batch_size, 232 | "Unet-tfmer_pad", device="cuda", callback=save_sample_callback) 233 | #%% 234 | 235 | #%% 236 | n_epochs = 50 237 | batch_size = 1048 238 | lr = 2e-4 239 | train_score_model(score_model, saved_dataset, lr, n_epochs, batch_size, 240 | "Unet-tfmer", device="cuda", callback=save_sample_callback) 241 | #%% 242 | 243 | sample_batch_size = 64 #@param {'type':'integer'} 244 | num_steps = 250 #@param {'type':'integer'} 245 | sampler = Euler_Maruyama_sampler #@param ['Euler_Maruyama_sampler', 'pc_sampler', 'ode_sampler'] {'type': 'raw'} 246 | # score_model.eval() 247 | ## Generate samples using the specified sampler. 
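# Editorial sketch (an assumption, not code from this repo): as constructed above,
# `yseq_data` is filled entirely with the padding token `nantoken`, so the
# conditioning sequences passed as `y` carry no attribute information.  One
# plausible way to pack each sample's active CelebA attribute indices into its
# row, right-padded with `nantoken`, would be:
#
#   for i in range(y_col.size(0)):
#       idx = torch.nonzero(y_col[i], as_tuple=False).flatten()  # active attributes
#       yseq_data[i, :len(idx)] = idx
#
# which matches the nn.Embedding(nAttr + 1, text_dim, padding_idx=nAttr) lookup
# used inside UNet_Tranformer_attrb.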
248 | samples = sampler(score_model, 249 | marginal_prob_std_fn, 250 | diffusion_coeff_fn, 251 | sample_batch_size, 252 | x_shape=(3, 32, 32), 253 | num_steps=num_steps, 254 | device=device, 255 | y=yseq_data[:sample_batch_size,:], ) 256 | 257 | ## Sample visualization. 258 | denormalize = Normalize([-0.485/0.229, -0.456/0.224, -0.406/0.225], 259 | [1/0.229, 1/0.224, 1/0.225]) 260 | 261 | samples = denormalize(samples).clamp(0.0, 1.0) 262 | sample_grid = make_grid(samples, nrow=int(np.sqrt(sample_batch_size))) 263 | 264 | plt.figure(figsize=(8, 8)) 265 | plt.axis('off') 266 | plt.imshow(sample_grid.permute(1, 2, 0).cpu(), vmin=0., vmax=1.) 267 | plt.tight_layout() 268 | plt.show() 269 | 270 | -------------------------------------------------------------------------------- /net_models.py: -------------------------------------------------------------------------------- 1 | 2 | import math 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from einops import rearrange 7 | 8 | 9 | class GaussianFourierProjection(nn.Module): 10 | """Gaussian random features for encoding time steps.""" 11 | def __init__(self, embed_dim, scale=30.): 12 | super().__init__() 13 | # Randomly sample weights during initialization. These weights are fixed 14 | # during optimization and are not trainable. 15 | self.W = nn.Parameter(torch.randn(embed_dim // 2) * scale, requires_grad=False) 16 | def forward(self, x): 17 | x_proj = x[:, None] * self.W[None, :] * 2 * math.pi 18 | return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1) 19 | 20 | 21 | class Dense(nn.Module): 22 | """A fully connected layer that reshapes outputs to feature maps. 23 | Allow time repr to input additively from the side of a convolution layer. 24 | """ 25 | def __init__(self, input_dim, output_dim): 26 | super().__init__() 27 | self.dense = nn.Linear(input_dim, output_dim) 28 | def forward(self, x): 29 | return self.dense(x)[..., None, None] 30 | 31 | 32 | class CrossAttention(nn.Module): 33 | """General implementation of Cross & Self Attention""" 34 | def __init__(self, embed_dim, hidden_dim, context_dim=None, num_heads=1, ): 35 | super(CrossAttention, self).__init__() 36 | self.hidden_dim = hidden_dim 37 | self.context_dim = context_dim 38 | self.embed_dim = embed_dim 39 | self.query = nn.Linear(hidden_dim, embed_dim, bias=False) 40 | if context_dim is None: 41 | # Self Attention 42 | self.key = nn.Linear(hidden_dim, embed_dim, bias=False) 43 | self.value = nn.Linear(hidden_dim, hidden_dim, bias=False) 44 | self.self_attn = True 45 | else: 46 | # Cross Attention 47 | self.key = nn.Linear(context_dim, embed_dim, bias=False) 48 | self.value = nn.Linear(context_dim, hidden_dim, bias=False) 49 | self.self_attn = False 50 | # self.query = nn.Conv1d(hidden_dim, embed_dim, 1, bias=False) 51 | # if context_dim is None: 52 | # self.key = nn.Conv1d(hidden_dim, embed_dim, 1, bias=False) 53 | # self.value = nn.Conv1d(hidden_dim, hidden_dim, 1, bias=False) 54 | # self.self_attn = True 55 | # else: 56 | # self.key = nn.Conv1d(context_dim, embed_dim, 1, bias=False) 57 | # self.value = nn.Conv1d(context_dim, hidden_dim, 1, bias=False) 58 | # self.self_attn = False 59 | 60 | def forward(self, tokens, context=None): 61 | Q = self.query(tokens) 62 | K = self.key(tokens) if self.self_attn else self.key(context) 63 | V = self.value(tokens) if self.self_attn else self.value(context) 64 | # if self.self_attn: 65 | # print(Q.shape, K.shape, V.shape) 66 | scoremats = torch.einsum("BTH,BSH->BTS", Q, K) 67 | attnmats = 
F.softmax(scoremats / math.sqrt(self.embed_dim), dim=-1) 68 | # print(scoremats.shape, attnmats.shape, ) 69 | ctx_vecs = torch.einsum("BTS,BSH->BTH", attnmats, V) 70 | return ctx_vecs 71 | 72 | 73 | class TransformerBlock(nn.Module): 74 | def __init__(self, hidden_dim, context_dim): 75 | super(TransformerBlock, self).__init__() 76 | self.attn_self = CrossAttention(hidden_dim, hidden_dim, ) 77 | self.attn_cross = CrossAttention(hidden_dim, hidden_dim, context_dim) 78 | 79 | self.norm1 = nn.LayerNorm(hidden_dim) 80 | self.norm2 = nn.LayerNorm(hidden_dim) 81 | self.norm3 = nn.LayerNorm(hidden_dim) 82 | self.ffn = nn.Sequential( 83 | nn.Linear(hidden_dim, 3 * hidden_dim), 84 | nn.GELU(), 85 | nn.Linear(3 * hidden_dim, hidden_dim) 86 | ) 87 | 88 | 89 | def forward(self, x, context=None): 90 | x = self.attn_self(self.norm1(x)) + x 91 | x = self.attn_cross(self.norm2(x), context=context) + x 92 | x = self.ffn(self.norm3(x)) + x 93 | return x 94 | 95 | 96 | class SpatialTransformer(nn.Module): 97 | def __init__(self, hidden_dim, context_dim): 98 | super(SpatialTransformer, self).__init__() 99 | self.transformer = TransformerBlock(hidden_dim, context_dim) 100 | 101 | def forward(self, x, context=None): 102 | b, c, h, w = x.shape 103 | x_in = x 104 | # context = rearrange(context, "b c T -> b T c") 105 | x = rearrange(x, "b c h w->b (h w) c") 106 | x = self.transformer(x, context) 107 | x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) 108 | return x + x_in 109 | 110 | 111 | class UNet_Tranformer_attrb(nn.Module): 112 | """A time-dependent score-based model built upon U-Net architecture.""" 113 | 114 | def __init__(self, marginal_prob_std, channels=[32, 64, 128, 256], embed_dim=256, 115 | text_dim=256, nAttr=40): 116 | """Initialize a time-dependent score-based network. 117 | 118 | Args: 119 | marginal_prob_std: A function that takes time t and gives the standard 120 | deviation of the perturbation kernel p_{0t}(x(t) | x(0)). 121 | channels: The number of channels for feature maps of each resolution. 122 | embed_dim: The dimensionality of Gaussian random feature embeddings. 
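      text_dim: The dimension of the attribute ("text") token embeddings that the
        SpatialTransformer blocks cross-attend to.
      nAttr: The number of CelebA attributes; index nAttr is reserved as the
        padding token of `cond_embed` (hence the nAttr + 1 embedding rows).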
123 | """ 124 | super().__init__() 125 | # Gaussian random feature embedding layer for time 126 | self.embed = nn.Sequential(GaussianFourierProjection(embed_dim=embed_dim), 127 | nn.Linear(embed_dim, embed_dim)) 128 | # Encoding layers where the resolution decreases 129 | self.conv1 = nn.Conv2d(3, channels[0], 3, stride=1, bias=False, ) 130 | self.dense1 = Dense(embed_dim, channels[0]) 131 | self.gnorm1 = nn.GroupNorm(4, num_channels=channels[0]) 132 | self.conv2 = nn.Conv2d(channels[0], channels[1], 3, stride=2, bias=False, ) 133 | self.dense2 = Dense(embed_dim, channels[1]) 134 | self.gnorm2 = nn.GroupNorm(32, num_channels=channels[1]) 135 | self.conv3 = nn.Conv2d(channels[1], channels[2], 3, stride=2, bias=False, ) 136 | self.dense3 = Dense(embed_dim, channels[2]) 137 | self.gnorm3 = nn.GroupNorm(32, num_channels=channels[2]) 138 | self.attn3 = SpatialTransformer(channels[2], text_dim) 139 | self.conv4 = nn.Conv2d(channels[2], channels[3], 3, stride=2, bias=False, ) 140 | self.dense4 = Dense(embed_dim, channels[3]) 141 | self.gnorm4 = nn.GroupNorm(32, num_channels=channels[3]) 142 | self.attn4 = SpatialTransformer(channels[3], text_dim) 143 | 144 | # Decoding layers where the resolution increases 145 | self.tconv4 = nn.ConvTranspose2d(channels[3], channels[2], 3, stride=2, bias=False, output_padding=1) 146 | self.dense5 = Dense(embed_dim, channels[2]) 147 | self.tgnorm4 = nn.GroupNorm(32, num_channels=channels[2]) 148 | self.attn5 = SpatialTransformer(channels[2], text_dim) 149 | self.tconv3 = nn.ConvTranspose2d(channels[2], channels[1], 3, stride=2, bias=False, 150 | output_padding=1) # , output_padding=1) # + channels[2] 151 | self.dense6 = Dense(embed_dim, channels[1]) 152 | self.tgnorm3 = nn.GroupNorm(32, num_channels=channels[1]) 153 | self.tconv2 = nn.ConvTranspose2d(channels[1], channels[0], 3, stride=2, bias=False, 154 | output_padding=1) # , output_padding=1) # + channels[1] 155 | self.dense7 = Dense(embed_dim, channels[0]) 156 | self.tgnorm2 = nn.GroupNorm(32, num_channels=channels[0]) 157 | self.tconv1 = nn.ConvTranspose2d(channels[0], 3, 3, stride=1, ) # + channels[0] 158 | 159 | # The swish activation function 160 | self.act = nn.SiLU() # lambda x: x * torch.sigmoid(x) 161 | self.marginal_prob_std = marginal_prob_std 162 | self.cond_embed = nn.Embedding(nAttr + 1, text_dim, 163 | padding_idx=nAttr) # +1 for the padding index 164 | 165 | def forward(self, x, t, y=None): 166 | # Obtain the Gaussian random feature embedding for t 167 | embed = self.act(self.embed(t)) 168 | y_embed = self.cond_embed(y) # .unsqueeze(1) 169 | # Encoding path 170 | h1 = self.conv1(x) + self.dense1(embed) 171 | ## Incorporate information from t 172 | ## Group normalization 173 | h1 = self.act(self.gnorm1(h1)) 174 | h2 = self.conv2(h1) + self.dense2(embed) 175 | h2 = self.act(self.gnorm2(h2)) 176 | h3 = self.conv3(h2) + self.dense3(embed) 177 | h3 = self.act(self.gnorm3(h3)) 178 | # h3 = self.attn3(h3, y_embed) 179 | h4 = self.conv4(h3) + self.dense4(embed) 180 | h4 = self.act(self.gnorm4(h4)) 181 | h4 = self.attn4(h4, y_embed) 182 | 183 | # Decoding path 184 | h = self.tconv4(h4) + self.dense5(embed) 185 | ## Skip connection from the encoding path 186 | h = self.act(self.tgnorm4(h)) 187 | # h = self.attn5(h, y_embed) 188 | h = self.tconv3(h + h3) + self.dense6(embed) 189 | h = self.act(self.tgnorm3(h)) 190 | h = self.tconv2(h + h2) + self.dense7(embed) 191 | h = self.act(self.tgnorm2(h)) 192 | h = self.tconv1(h + h1) 193 | 194 | # Normalize output 195 | h = h / self.marginal_prob_std(t)[:, None, 
None, None] 196 | return h 197 | 198 | 199 | class ResBlock(nn.Module): 200 | def __init__(self, in_chan, out_chan, stride=1, downsample=None): 201 | super().__init__() 202 | self.conv1 = nn.Conv2d(in_chan, in_chan, 3, stride=1, padding=1) 203 | self.conv2 = nn.Conv2d(in_chan, in_chan, 3, stride=1, padding=1) 204 | self.conv3 = nn.Conv2d(in_chan, in_chan, 3, stride=1, padding=1) 205 | 206 | 207 | 208 | class UNet_Tranformer_ResBlk_attrb(nn.Module): 209 | """A time-dependent score-based model built upon U-Net architecture.""" 210 | 211 | def __init__(self, marginal_prob_std, channels=[32, 64, 128, 256], embed_dim=256, 212 | text_dim=256, nAttr=40): 213 | """Initialize a time-dependent score-based network. 214 | 215 | Args: 216 | marginal_prob_std: A function that takes time t and gives the standard 217 | deviation of the perturbation kernel p_{0t}(x(t) | x(0)). 218 | channels: The number of channels for feature maps of each resolution. 219 | embed_dim: The dimensionality of Gaussian random feature embeddings. 220 | """ 221 | super().__init__() 222 | # Gaussian random feature embedding layer for time 223 | self.embed = nn.Sequential(GaussianFourierProjection(embed_dim=embed_dim), 224 | nn.Linear(embed_dim, embed_dim)) 225 | # Encoding layers where the resolution decreases 226 | self.conv1 = nn.Conv2d(3, channels[0], 3, stride=1, bias=False, ) 227 | self.dense1 = Dense(embed_dim, channels[0]) 228 | self.gnorm1 = nn.GroupNorm(4, num_channels=channels[0]) 229 | self.conv2 = nn.Conv2d(channels[0], channels[1], 3, stride=2, bias=False, ) 230 | self.dense2 = Dense(embed_dim, channels[1]) 231 | self.gnorm2 = nn.GroupNorm(32, num_channels=channels[1]) 232 | self.conv3 = nn.Conv2d(channels[1], channels[2], 3, stride=2, bias=False, ) 233 | self.dense3 = Dense(embed_dim, channels[2]) 234 | self.gnorm3 = nn.GroupNorm(32, num_channels=channels[2]) 235 | self.attn3 = SpatialTransformer(channels[2], text_dim) 236 | self.conv4 = nn.Conv2d(channels[2], channels[3], 3, stride=2, bias=False, ) 237 | self.dense4 = Dense(embed_dim, channels[3]) 238 | self.gnorm4 = nn.GroupNorm(32, num_channels=channels[3]) 239 | self.attn4 = SpatialTransformer(channels[3], text_dim) 240 | 241 | # Decoding layers where the resolution increases 242 | self.tconv4 = nn.ConvTranspose2d(channels[3], channels[2], 3, stride=2, bias=False, output_padding=1) 243 | self.dense5 = Dense(embed_dim, channels[2]) 244 | self.tgnorm4 = nn.GroupNorm(32, num_channels=channels[2]) 245 | self.attn5 = SpatialTransformer(channels[2], text_dim) 246 | self.tconv3 = nn.ConvTranspose2d(channels[2], channels[1], 3, stride=2, bias=False, 247 | output_padding=1) # , output_padding=1) # + channels[2] 248 | self.dense6 = Dense(embed_dim, channels[1]) 249 | self.tgnorm3 = nn.GroupNorm(32, num_channels=channels[1]) 250 | self.tconv2 = nn.ConvTranspose2d(channels[1], channels[0], 3, stride=2, bias=False, 251 | output_padding=1) # , output_padding=1) # + channels[1] 252 | self.dense7 = Dense(embed_dim, channels[0]) 253 | self.tgnorm2 = nn.GroupNorm(32, num_channels=channels[0]) 254 | self.tconv1 = nn.ConvTranspose2d(channels[0], 3, 3, stride=1, ) # + channels[0] 255 | 256 | # The swish activation function 257 | self.act = nn.SiLU() # lambda x: x * torch.sigmoid(x) 258 | self.marginal_prob_std = marginal_prob_std 259 | self.cond_embed = nn.Embedding(nAttr + 1, text_dim, 260 | padding_idx=nAttr) # +1 for the padding index 261 | 262 | def forward(self, x, t, y=None): 263 | # Obtain the Gaussian random feature embedding for t 264 | embed = self.act(self.embed(t)) 265 | 
y_embed = self.cond_embed(y) # .unsqueeze(1) 266 | # Encoding path 267 | h1 = self.conv1(x) + self.dense1(embed) 268 | ## Incorporate information from t 269 | ## Group normalization 270 | h1 = self.act(self.gnorm1(h1)) 271 | h2 = self.conv2(h1) + self.dense2(embed) 272 | h2 = self.act(self.gnorm2(h2)) 273 | h3 = self.conv3(h2) + self.dense3(embed) 274 | h3 = self.act(self.gnorm3(h3)) 275 | # h3 = self.attn3(h3, y_embed) 276 | h4 = self.conv4(h3) + self.dense4(embed) 277 | h4 = self.act(self.gnorm4(h4)) 278 | h4 = self.attn4(h4, y_embed) 279 | 280 | # Decoding path 281 | h = self.tconv4(h4) + self.dense5(embed) 282 | ## Skip connection from the encoding path 283 | h = self.act(self.tgnorm4(h)) 284 | # h = self.attn5(h, y_embed) 285 | h = self.tconv3(h + h3) + self.dense6(embed) 286 | h = self.act(self.tgnorm3(h)) 287 | h = self.tconv2(h + h2) + self.dense7(embed) 288 | h = self.act(self.tgnorm2(h)) 289 | h = self.tconv1(h + h1) 290 | 291 | # Normalize output 292 | h = h / self.marginal_prob_std(t)[:, None, None, None] 293 | return h -------------------------------------------------------------------------------- /StableDiff_UNet_unittest.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test out the key modules for Stable Diffusion 3 | - ResBlock 4 | - UpSample 5 | - DownSample 6 | - CrossAttention 7 | - BasicTransformer (self, cross, FFN) 8 | - Spatial Transformer 9 | """ 10 | 11 | import torch 12 | import torch.nn as nn 13 | from diffusers import StableDiffusionPipeline 14 | from StableDiff_UNet_model import * 15 | 16 | pipe = StableDiffusionPipeline.from_pretrained( 17 | "CompVis/stable-diffusion-v1-4", 18 | use_auth_token=True 19 | ).to("cuda") 20 | #%% test ResBlock Implementation 21 | tmp_blk = ResBlock(320, 1280).cuda() 22 | std_blk = pipe.unet.down_blocks[0].resnets[0] 23 | SD = std_blk.state_dict() 24 | tmp_blk.load_state_dict(SD) 25 | lat_tmp = torch.randn(3, 320, 32, 32) 26 | temb = torch.randn(3, 1280) 27 | with torch.no_grad(): 28 | out = pipe.unet.down_blocks[0].resnets[0](lat_tmp.cuda(),temb.cuda()) 29 | out2 = tmp_blk(lat_tmp.cuda(), temb.cuda()) 30 | 31 | assert torch.allclose(out2, out) 32 | 33 | #%% test downsampler 34 | tmpdsp = DownSample(320).cuda() 35 | stddsp = pipe.unet.down_blocks[0].downsamplers[0] 36 | SD = stddsp.state_dict() 37 | tmpdsp.load_state_dict(SD) 38 | lat_tmp = torch.randn(3, 320, 32, 32) 39 | with torch.no_grad(): 40 | out = stddsp(lat_tmp.cuda()) 41 | out2 = tmpdsp(lat_tmp.cuda()) 42 | 43 | assert torch.allclose(out2, out) 44 | 45 | #%% test upsampler 46 | tmpusp = UpSample(1280).cuda() 47 | stdusp = pipe.unet.up_blocks[1].upsamplers[0] 48 | SD = stdusp.state_dict() 49 | tmpusp.load_state_dict(SD) 50 | lat_tmp = torch.randn(3, 1280, 32, 32) 51 | with torch.no_grad(): 52 | out = stdusp(lat_tmp.cuda()) 53 | out2 = tmpusp(lat_tmp.cuda()) 54 | 55 | assert torch.allclose(out2, out) 56 | 57 | 58 | #%% test SpatialTransformer Implementation 59 | # Check self attention 60 | tmpSattn = CrossAttention(320, 320, context_dim=None, num_heads=8).cuda() 61 | stdSattn = pipe.unet.down_blocks[0].attentions[0].transformer_blocks[0].attn1 62 | tmpSattn.load_state_dict(stdSattn.state_dict()) # checked 63 | with torch.no_grad(): 64 | lat_tmp = torch.randn(3, 32, 320) 65 | out = stdSattn(lat_tmp.cuda()) 66 | out2 = tmpSattn(lat_tmp.cuda()) 67 | assert torch.allclose(out2, out) # False 68 | 69 | #%% 70 | # Check Cross attention 71 | tmpXattn = CrossAttention(320, 320, context_dim=768, num_heads=8).cuda() 72 | stdXattn = 
pipe.unet.down_blocks[0].attentions[0].transformer_blocks[0].attn2 73 | tmpXattn.load_state_dict(stdXattn.state_dict()) # checked 74 | with torch.no_grad(): 75 | lat_tmp = torch.randn(3, 32, 320) 76 | ctx_tmp = torch.randn(3, 5, 768) 77 | out = stdXattn(lat_tmp.cuda(), ctx_tmp.cuda()) 78 | out2 = tmpXattn(lat_tmp.cuda(), ctx_tmp.cuda()) 79 | assert torch.allclose(out2, out) # False 80 | 81 | #%% test TransformerBlock Implementation 82 | tmpTfmer = TransformerBlock(320, context_dim=768, num_heads=8).cuda() 83 | stdTfmer = pipe.unet.down_blocks[0].attentions[0].transformer_blocks[0] 84 | tmpTfmer.load_state_dict(stdTfmer.state_dict()) # checked 85 | with torch.no_grad(): 86 | lat_tmp = torch.randn(3, 32, 320) 87 | ctx_tmp = torch.randn(3, 5, 768) 88 | out = tmpTfmer(lat_tmp.cuda(), ctx_tmp.cuda()) 89 | out2 = stdTfmer(lat_tmp.cuda(), ctx_tmp.cuda()) 90 | assert torch.allclose(out2, out) # False 91 | 92 | 93 | #%% test SpatialTransformer Implementation 94 | tmpSpTfmer = SpatialTransformer(320, context_dim=768, num_heads=8).cuda() 95 | stdSpTfmer = pipe.unet.down_blocks[0].attentions[0] 96 | tmpSpTfmer.load_state_dict(stdSpTfmer.state_dict()) # checked 97 | with torch.no_grad(): 98 | lat_tmp = torch.randn(3, 320, 8, 8) 99 | ctx_tmp = torch.randn(3, 5, 768) 100 | out = tmpSpTfmer(lat_tmp.cuda(), ctx_tmp.cuda()) 101 | out2 = stdSpTfmer(lat_tmp.cuda(), ctx_tmp.cuda()) 102 | assert torch.allclose(out2, out) # False 103 | 104 | #%% test UNet downblocks 105 | tmpUNet = UNet_SD() 106 | load_pipe_into_our_UNet(tmpUNet, pipe) 107 | 108 | #%% 109 | tmpUNet.output[0].load_state_dict(pipe.unet.conv_norm_out.state_dict()) 110 | tmpUNet.output[2].load_state_dict(pipe.unet.conv_out.state_dict()) 111 | tmpUNet.conv_in.load_state_dict(pipe.unet.conv_in.state_dict()) 112 | tmpUNet.time_embedding.load_state_dict(pipe.unet.time_embedding.state_dict()) 113 | 114 | # Loading the down blocks 115 | tmpUNet.down_blocks[0][0].load_state_dict(pipe.unet.down_blocks[0].resnets[0].state_dict()) 116 | tmpUNet.down_blocks[0][1].load_state_dict(pipe.unet.down_blocks[0].attentions[0].state_dict()) 117 | tmpUNet.down_blocks[1][0].load_state_dict(pipe.unet.down_blocks[0].resnets[1].state_dict()) 118 | tmpUNet.down_blocks[1][1].load_state_dict(pipe.unet.down_blocks[0].attentions[1].state_dict()) 119 | tmpUNet.down_blocks[2][0].load_state_dict(pipe.unet.down_blocks[0].downsamplers[0].state_dict()) 120 | 121 | tmpUNet.down_blocks[3][0].load_state_dict(pipe.unet.down_blocks[1].resnets[0].state_dict()) 122 | tmpUNet.down_blocks[3][1].load_state_dict(pipe.unet.down_blocks[1].attentions[0].state_dict()) 123 | tmpUNet.down_blocks[4][0].load_state_dict(pipe.unet.down_blocks[1].resnets[1].state_dict()) 124 | tmpUNet.down_blocks[4][1].load_state_dict(pipe.unet.down_blocks[1].attentions[1].state_dict()) 125 | tmpUNet.down_blocks[5][0].load_state_dict(pipe.unet.down_blocks[1].downsamplers[0].state_dict()) 126 | 127 | tmpUNet.down_blocks[6][0].load_state_dict(pipe.unet.down_blocks[2].resnets[0].state_dict()) 128 | tmpUNet.down_blocks[6][1].load_state_dict(pipe.unet.down_blocks[2].attentions[0].state_dict()) 129 | tmpUNet.down_blocks[7][0].load_state_dict(pipe.unet.down_blocks[2].resnets[1].state_dict()) 130 | tmpUNet.down_blocks[7][1].load_state_dict(pipe.unet.down_blocks[2].attentions[1].state_dict()) 131 | tmpUNet.down_blocks[8][0].load_state_dict(pipe.unet.down_blocks[2].downsamplers[0].state_dict()) 132 | 133 | tmpUNet.down_blocks[9][0].load_state_dict(pipe.unet.down_blocks[3].resnets[0].state_dict()) 134 | 
tmpUNet.down_blocks[10][0].load_state_dict(pipe.unet.down_blocks[3].resnets[1].state_dict()) 135 | 136 | # Loading the middle blocks 137 | tmpUNet.mid_block[0].load_state_dict(pipe.unet.mid_block.resnets[0].state_dict()) 138 | tmpUNet.mid_block[1].load_state_dict(pipe.unet.mid_block.attentions[0].state_dict()) 139 | tmpUNet.mid_block[2].load_state_dict(pipe.unet.mid_block.resnets[1].state_dict()) 140 | 141 | #%% Loading the up blocks 142 | # upblock 0 143 | tmpUNet.up_blocks[0][0].load_state_dict(pipe.unet.up_blocks[0].resnets[0].state_dict()) 144 | tmpUNet.up_blocks[1][0].load_state_dict(pipe.unet.up_blocks[0].resnets[1].state_dict()) 145 | tmpUNet.up_blocks[2][0].load_state_dict(pipe.unet.up_blocks[0].resnets[2].state_dict()) 146 | tmpUNet.up_blocks[2][1].load_state_dict(pipe.unet.up_blocks[0].upsamplers[0].state_dict()) 147 | #% upblock 1 148 | tmpUNet.up_blocks[3][0].load_state_dict(pipe.unet.up_blocks[1].resnets[0].state_dict()) 149 | tmpUNet.up_blocks[3][1].load_state_dict(pipe.unet.up_blocks[1].attentions[0].state_dict()) 150 | tmpUNet.up_blocks[4][0].load_state_dict(pipe.unet.up_blocks[1].resnets[1].state_dict()) 151 | tmpUNet.up_blocks[4][1].load_state_dict(pipe.unet.up_blocks[1].attentions[1].state_dict()) 152 | tmpUNet.up_blocks[5][0].load_state_dict(pipe.unet.up_blocks[1].resnets[2].state_dict()) 153 | tmpUNet.up_blocks[5][1].load_state_dict(pipe.unet.up_blocks[1].attentions[2].state_dict()) 154 | tmpUNet.up_blocks[5][2].load_state_dict(pipe.unet.up_blocks[1].upsamplers[0].state_dict()) 155 | #% upblock 2 156 | tmpUNet.up_blocks[6][0].load_state_dict(pipe.unet.up_blocks[2].resnets[0].state_dict()) 157 | tmpUNet.up_blocks[6][1].load_state_dict(pipe.unet.up_blocks[2].attentions[0].state_dict()) 158 | tmpUNet.up_blocks[7][0].load_state_dict(pipe.unet.up_blocks[2].resnets[1].state_dict()) 159 | tmpUNet.up_blocks[7][1].load_state_dict(pipe.unet.up_blocks[2].attentions[1].state_dict()) 160 | tmpUNet.up_blocks[8][0].load_state_dict(pipe.unet.up_blocks[2].resnets[2].state_dict()) 161 | tmpUNet.up_blocks[8][1].load_state_dict(pipe.unet.up_blocks[2].attentions[2].state_dict()) 162 | tmpUNet.up_blocks[8][2].load_state_dict(pipe.unet.up_blocks[2].upsamplers[0].state_dict()) 163 | #% upblock 3 164 | tmpUNet.up_blocks[9][0].load_state_dict(pipe.unet.up_blocks[3].resnets[0].state_dict()) 165 | tmpUNet.up_blocks[9][1].load_state_dict(pipe.unet.up_blocks[3].attentions[0].state_dict()) 166 | tmpUNet.up_blocks[10][0].load_state_dict(pipe.unet.up_blocks[3].resnets[1].state_dict()) 167 | tmpUNet.up_blocks[10][1].load_state_dict(pipe.unet.up_blocks[3].attentions[1].state_dict()) 168 | tmpUNet.up_blocks[11][0].load_state_dict(pipe.unet.up_blocks[3].resnets[2].state_dict()) 169 | tmpUNet.up_blocks[11][1].load_state_dict(pipe.unet.up_blocks[3].attentions[2].state_dict()) 170 | #%% 171 | 172 | #%% Check entire UNet, very small difference 173 | tmpUNet.cuda().eval() 174 | pipe.unet.eval() 175 | with torch.no_grad(): 176 | lat_x = torch.randn(3, 4, 32, 32).cuda() 177 | ctx_tmp = torch.randn(3, 5, 768).cuda() 178 | t_emb_tmp = torch.rand(3, ).cuda() 179 | out = tmpUNet(lat_x, t_emb_tmp, ctx_tmp) 180 | out2 = pipe.unet(lat_x, t_emb_tmp, ctx_tmp) 181 | sample = out2.sample 182 | 183 | print((sample - out).max(), (sample - out).min()) # 0.0008 -0.0008 184 | assert torch.allclose(sample, out) # False 185 | #%% checked all downward blocks [Checked] 186 | tmpUNet.mid_block.cuda() 187 | tmpUNet.down_blocks.cuda() 188 | with torch.no_grad(): 189 | lat_tmp = torch.randn(3, 320, 32, 32) 190 | ctx_tmp = 
torch.randn(3, 5, 768) 191 | t_emb_tmp = torch.randn(3, 1280) 192 | out = tmpUNet.down_blocks[0:3](lat_tmp.cuda(), t_emb_tmp.cuda(), ctx_tmp.cuda()) 193 | out2 = pipe.unet.down_blocks[0](lat_tmp.cuda(), t_emb_tmp.cuda(), ctx_tmp.cuda()) 194 | assert torch.allclose(out2[0], out) # False 195 | 196 | #%% checked all downward and middle blocks [Checked] 197 | with torch.no_grad(): 198 | lat_tmp = torch.randn(3, 320, 64, 64) 199 | ctx_tmp = torch.randn(3, 5, 768) 200 | t_emb_tmp = torch.randn(3, 1280) 201 | # our implementation 202 | downout = tmpUNet.down_blocks(lat_tmp.cuda(), t_emb_tmp.cuda(), ctx_tmp.cuda()) 203 | out = tmpUNet.mid_block(downout, t_emb_tmp.cuda(), ctx_tmp.cuda()) 204 | # standard implementation 205 | hidden = lat_tmp.cuda() 206 | for i in range(3): 207 | hidden, out_col = pipe.unet.down_blocks[i](hidden, t_emb_tmp.cuda(), ctx_tmp.cuda()) 208 | downout2, out_col = pipe.unet.down_blocks[-1](hidden, t_emb_tmp.cuda(), ) 209 | out2 = pipe.unet.mid_block(downout2, t_emb_tmp.cuda(), ctx_tmp.cuda()) 210 | assert torch.allclose(out2, out) # False 211 | assert torch.allclose(downout2, downout) # False 212 | 213 | #%% checked all downward blocks [Checked, exact] 214 | tmpUNet.mid_block.cuda() 215 | tmpUNet.down_blocks.cuda() 216 | tmpUNet.up_blocks.cuda() 217 | with torch.no_grad(): 218 | lat_tmp = torch.randn(2, 320, 32, 32).cuda() 219 | ctx_tmp = torch.randn(2, 5, 768).cuda() 220 | t_emb_tmp = torch.randn(2, 1280).cuda() 221 | # our implementation 222 | hidden = lat_tmp.cuda() 223 | down_x_cache = [hidden] 224 | for module in tmpUNet.down_blocks: 225 | hidden = module(hidden, t_emb_tmp, ctx_tmp) 226 | down_x_cache.append(hidden) 227 | out = tmpUNet.mid_block(hidden, t_emb_tmp, ctx_tmp) 228 | 229 | # Hugginface standard implementation 230 | hidden = lat_tmp.cuda() 231 | out_cache = (hidden, ) 232 | for i in range(3): 233 | hidden, out_col = pipe.unet.down_blocks[i](hidden, t_emb_tmp, ctx_tmp) 234 | out_cache = out_cache + out_col 235 | downout2, out_col = pipe.unet.down_blocks[-1](hidden, t_emb_tmp,) 236 | out_cache = out_cache + out_col 237 | out2 = pipe.unet.mid_block(downout2, t_emb_tmp, ctx_tmp) 238 | 239 | assert torch.allclose(out2, out) # False 240 | for x1, x2 in zip(down_x_cache, out_cache): 241 | assert torch.allclose(x1, x2) # False 242 | #%% checked all downward and middle and upward blocks [Checked, not exactly same!] 
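# The end-to-end outputs in this cell agree only up to roughly 1e-3 (see the
# max/min difference printed at the end); a discrepancy of this size most likely
# reflects floating-point and operation-ordering differences between the two
# attention implementations rather than an error in the copied weights.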
243 | tmpUNet.mid_block.cuda() 244 | tmpUNet.down_blocks.cuda() 245 | tmpUNet.up_blocks.cuda() 246 | tmpUNet.eval() 247 | pipe.unet.eval() 248 | with torch.no_grad(): 249 | lat_tmp = torch.randn(2, 320, 32, 32).cuda() 250 | ctx_tmp = torch.randn(2, 5, 768).cuda() 251 | t_emb_tmp = torch.randn(2, 1280).cuda() 252 | # our implementation 253 | hidden = lat_tmp.cuda() 254 | down_x_cache = [hidden] 255 | for module in tmpUNet.down_blocks: 256 | hidden = module(hidden, t_emb_tmp, ctx_tmp) 257 | down_x_cache.append(hidden) 258 | out = tmpUNet.mid_block(hidden, t_emb_tmp, ctx_tmp) 259 | for module in tmpUNet.up_blocks[:]: 260 | out = module(torch.cat((out, down_x_cache.pop()), dim=1), t_emb_tmp, ctx_tmp) 261 | 262 | # Hugginface standard implementation 263 | hidden = lat_tmp.cuda() 264 | out_cache = (hidden, ) 265 | for i in range(3): 266 | hidden, out_col = pipe.unet.down_blocks[i](hidden, t_emb_tmp, ctx_tmp) 267 | out_cache = out_cache + out_col 268 | downout2, out_col = pipe.unet.down_blocks[-1](hidden, t_emb_tmp,) 269 | out_cache = out_cache + out_col 270 | out2 = pipe.unet.mid_block(downout2, t_emb_tmp, ctx_tmp) 271 | out2 = pipe.unet.up_blocks[0](hidden_states=out2, temb=t_emb_tmp, 272 | res_hidden_states_tuple=out_cache[-3:], 273 | ) 274 | out_cache = out_cache[:-3] 275 | for i in range(1, 4): 276 | out2 = pipe.unet.up_blocks[i](hidden_states=out2, temb=t_emb_tmp, 277 | res_hidden_states_tuple=out_cache[-3:], 278 | encoder_hidden_states=ctx_tmp, 279 | ) 280 | out_cache = out_cache[:-3] 281 | # for i, upsample_block in enumerate(pipe.unet.up_blocks): 282 | # is_final_block = i == len(pipe.unet.up_blocks) - 1 283 | # 284 | # res_samples = out_cache[-len(upsample_block.resnets):] 285 | # out_cache = out_cache[: -len(upsample_block.resnets)] 286 | # 287 | # if hasattr(upsample_block, "attentions") and upsample_block.attentions is not None: 288 | # out2 = upsample_block( 289 | # hidden_states=out2, temb=t_emb_tmp, 290 | # res_hidden_states_tuple=res_samples, 291 | # encoder_hidden_states=ctx_tmp, 292 | # ) 293 | # else: 294 | # out2 = upsample_block( 295 | # hidden_states=out2, temb=t_emb_tmp, 296 | # res_hidden_states_tuple=res_samples, 297 | # ) 298 | 299 | print((out2 - out).max(), (out2 - out).min()) 300 | assert torch.allclose(out2, out) # False 301 | #%% 302 | 303 | #%% Older version 304 | # Load in down blocks 305 | tmpDown = tmpUNet.down_blocks[0].cuda() 306 | stdDown = pipe.unet.down_blocks[0] 307 | tmpDown[0].load_state_dict(stdDown.resnets[0].state_dict()) 308 | tmpDown[1].load_state_dict(stdDown.attentions[0].state_dict()) 309 | tmpDown[2].load_state_dict(stdDown.resnets[1].state_dict()) 310 | tmpDown[3].load_state_dict(stdDown.attentions[1].state_dict()) 311 | tmpDown[4].load_state_dict(stdDown.downsamplers[0].state_dict()) 312 | tmpDown = tmpUNet.down_blocks[1].cuda() 313 | stdDown = pipe.unet.down_blocks[1] 314 | tmpDown[0].load_state_dict(stdDown.resnets[0].state_dict()) 315 | tmpDown[1].load_state_dict(stdDown.attentions[0].state_dict()) 316 | tmpDown[2].load_state_dict(stdDown.resnets[1].state_dict()) 317 | tmpDown[3].load_state_dict(stdDown.attentions[1].state_dict()) 318 | tmpDown[4].load_state_dict(stdDown.downsamplers[0].state_dict()) 319 | tmpDown = tmpUNet.down_blocks[2].cuda() 320 | stdDown = pipe.unet.down_blocks[2] 321 | tmpDown[0].load_state_dict(stdDown.resnets[0].state_dict()) 322 | tmpDown[1].load_state_dict(stdDown.attentions[0].state_dict()) 323 | tmpDown[2].load_state_dict(stdDown.resnets[1].state_dict()) 324 | 
tmpDown[3].load_state_dict(stdDown.attentions[1].state_dict()) 325 | tmpDown[4].load_state_dict(stdDown.downsamplers[0].state_dict()) 326 | tmpDown = tmpUNet.down_blocks[3].cuda() 327 | stdDown = pipe.unet.down_blocks[3] 328 | tmpDown[0].load_state_dict(stdDown.resnets[0].state_dict()) 329 | tmpDown[1].load_state_dict(stdDown.resnets[1].state_dict()) 330 | # Load in middle blocks 331 | stdMid = pipe.unet.mid_block 332 | tmpUNet.mid_block[0].load_state_dict(stdMid.resnets[0].state_dict()) 333 | tmpUNet.mid_block[1].load_state_dict(stdMid.attentions[0].state_dict()) 334 | tmpUNet.mid_block[2].load_state_dict(stdMid.resnets[1].state_dict()) 335 | #%% Check down sample blocks 336 | with torch.no_grad(): 337 | lat_tmp = torch.randn(3, 320, 32, 32) 338 | ctx_tmp = torch.randn(3, 5, 768) 339 | t_emb_tmp = torch.randn(3, 1280) 340 | out = tmpUNet.down_blocks[0](lat_tmp.cuda(), t_emb_tmp.cuda(), ctx_tmp.cuda()) 341 | out2 = pipe.unet.down_blocks[0](lat_tmp.cuda(), t_emb_tmp.cuda(), ctx_tmp.cuda()) 342 | assert torch.allclose(out2[0], out) # False 343 | #%% Check down sample blocks 344 | tmpUNet.down_blocks[1].cuda() 345 | with torch.no_grad(): 346 | lat_tmp = torch.randn(3, 320, 16, 16) 347 | ctx_tmp = torch.randn(3, 5, 768) 348 | t_emb_tmp = torch.randn(3, 1280) 349 | out = tmpUNet.down_blocks[1](lat_tmp.cuda(), t_emb_tmp.cuda(), ctx_tmp.cuda()) 350 | out2 = pipe.unet.down_blocks[1](lat_tmp.cuda(), t_emb_tmp.cuda(), ctx_tmp.cuda()) 351 | assert torch.allclose(out2[0], out) # False 352 | #%% Down block 2 353 | tmpUNet.down_blocks[2].cuda() 354 | with torch.no_grad(): 355 | lat_tmp = torch.randn(3, 640, 8, 8) 356 | ctx_tmp = torch.randn(3, 5, 768) 357 | t_emb_tmp = torch.randn(3, 1280) 358 | out = tmpUNet.down_blocks[2](lat_tmp.cuda(), t_emb_tmp.cuda(), ctx_tmp.cuda()) 359 | out2 = pipe.unet.down_blocks[2](lat_tmp.cuda(), t_emb_tmp.cuda(), ctx_tmp.cuda()) 360 | assert torch.allclose(out2[0], out) # False 361 | #%% Down block 3 362 | tmpUNet.down_blocks[3].cuda() 363 | with torch.no_grad(): 364 | lat_tmp = torch.randn(3, 1280, 8, 8) 365 | ctx_tmp = torch.randn(3, 5, 768) 366 | t_emb_tmp = torch.randn(3, 1280) 367 | out = tmpUNet.down_blocks[3](lat_tmp.cuda(), t_emb_tmp.cuda(), ctx_tmp.cuda()) 368 | out2 = pipe.unet.down_blocks[3](lat_tmp.cuda(), t_emb_tmp.cuda()) 369 | assert torch.allclose(out2[0], out) # False 370 | #%% Check middle blocks 371 | stdMid = pipe.unet.mid_block 372 | tmpUNet.mid_block.cuda() 373 | with torch.no_grad(): 374 | lat_tmp = torch.randn(3, 1280, 8, 8) 375 | t_emb_tmp = torch.randn(3, 1280) 376 | ctx_tmp = torch.randn(3, 5, 768) 377 | out = tmpUNet.mid_block(lat_tmp.cuda(), t_emb_tmp.cuda(), ctx_tmp.cuda()) 378 | out2 = stdMid(lat_tmp.cuda(), t_emb_tmp.cuda(), ctx_tmp.cuda()) 379 | assert torch.allclose(out2, out) # False 380 | #%% test UNet in total 381 | 382 | -------------------------------------------------------------------------------- /StableDiff_UNet_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Reconstructing UNet in Stable Diffusion from scratch in ~ 300 lines of codes 3 | 4 | """ 5 | import math 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from einops import rearrange 10 | from collections import OrderedDict 11 | from easydict import EasyDict as edict 12 | 13 | 14 | class UNet_SD(nn.Module): 15 | 16 | def __init__(self, in_channels=4, 17 | base_channels=320, 18 | time_emb_dim=1280, 19 | context_dim=768, 20 | multipliers=(1, 2, 4, 4), 21 | attn_levels=(0, 1, 2), 22 | 
nResAttn_block=2, 23 | cat_unet=True): 24 | super().__init__() 25 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 26 | self.in_channels = in_channels 27 | self.out_channels = in_channels 28 | base_channels = base_channels 29 | time_emb_dim = time_emb_dim 30 | context_dim = context_dim 31 | multipliers = multipliers 32 | nlevel = len(multipliers) 33 | self.base_channels = base_channels 34 | # attn_levels = [0, 1, 2] 35 | level_channels = [base_channels * mult for mult in multipliers] 36 | # Transform time into embedding 37 | self.time_embedding = nn.Sequential(OrderedDict({ 38 | "linear_1": nn.Linear(base_channels, time_emb_dim, bias=True), 39 | "act": nn.SiLU(), 40 | "linear_2": nn.Linear(time_emb_dim, time_emb_dim, bias=True), 41 | }) 42 | ) # 2 layer MLP 43 | self.conv_in = nn.Conv2d(self.in_channels, base_channels, 3, stride=1, padding=1) 44 | 45 | # Tensor Downsample blocks 46 | nResAttn_block = nResAttn_block 47 | self.down_blocks = TimeModulatedSequential() # nn.ModuleList() 48 | self.down_blocks_channels = [base_channels] 49 | cur_chan = base_channels 50 | for i in range(nlevel): 51 | for j in range(nResAttn_block): 52 | res_attn_sandwich = TimeModulatedSequential() 53 | # input_chan of first ResBlock is different from the rest. 54 | res_attn_sandwich.append(ResBlock(in_channel=cur_chan, time_emb_dim=time_emb_dim, out_channel=level_channels[i])) 55 | if i in attn_levels: 56 | # add attention except for the last level 57 | res_attn_sandwich.append(SpatialTransformer(level_channels[i], context_dim=context_dim)) 58 | cur_chan = level_channels[i] 59 | self.down_blocks.append(res_attn_sandwich) 60 | self.down_blocks_channels.append(cur_chan) 61 | # res_attn_sandwich.append(DownSample(level_channels[i])) 62 | if not i == nlevel - 1: 63 | self.down_blocks.append(TimeModulatedSequential(DownSample(level_channels[i]))) 64 | self.down_blocks_channels.append(cur_chan) 65 | 66 | self.mid_block = TimeModulatedSequential( 67 | ResBlock(cur_chan, time_emb_dim), 68 | SpatialTransformer(cur_chan, context_dim=context_dim), 69 | ResBlock(cur_chan, time_emb_dim), 70 | ) 71 | 72 | # Tensor Upsample blocks 73 | self.up_blocks = nn.ModuleList() # TimeModulatedSequential() # 74 | for i in reversed(range(nlevel)): 75 | for j in range(nResAttn_block + 1): 76 | res_attn_sandwich = TimeModulatedSequential() 77 | res_attn_sandwich.append(ResBlock(in_channel=cur_chan + self.down_blocks_channels.pop(), 78 | time_emb_dim=time_emb_dim, out_channel=level_channels[i])) 79 | if i in attn_levels: 80 | res_attn_sandwich.append(SpatialTransformer(level_channels[i], context_dim=context_dim)) 81 | cur_chan = level_channels[i] 82 | if j == nResAttn_block and i != 0: 83 | res_attn_sandwich.append(UpSample(level_channels[i])) 84 | self.up_blocks.append(res_attn_sandwich) 85 | # Read out from tensor to latent space 86 | self.output = nn.Sequential( 87 | nn.GroupNorm(32, base_channels, ), 88 | nn.SiLU(), 89 | nn.Conv2d(base_channels, self.out_channels, 3, padding=1), 90 | ) 91 | self.to(self.device) 92 | def time_proj(self, time_steps, max_period: int = 10000): 93 | if time_steps.ndim == 0: 94 | time_steps = time_steps.unsqueeze(0) 95 | half = self.base_channels // 2 96 | frequencies = torch.exp(- math.log(max_period) 97 | * torch.arange(start=0, end=half, dtype=torch.float32) / half 98 | ).to(device=time_steps.device) 99 | angles = time_steps[:, None].float() * frequencies[None, :] 100 | return torch.cat([torch.cos(angles), torch.sin(angles)], dim=-1) 101 | 102 | def forward(self, x, time_steps, 
cond=None, encoder_hidden_states=None, output_dict=True): 103 | if cond is None and encoder_hidden_states is not None: 104 | cond = encoder_hidden_states 105 | t_emb = self.time_proj(time_steps) 106 | t_emb = self.time_embedding(t_emb) 107 | x = self.conv_in(x) 108 | down_x_cache = [x] 109 | for module in self.down_blocks: 110 | x = module(x, t_emb, cond) 111 | down_x_cache.append(x) 112 | x = self.mid_block(x, t_emb, cond) 113 | for module in self.up_blocks: 114 | x = module(torch.cat((x, down_x_cache.pop()), dim=1), t_emb, cond) 115 | x = self.output(x) 116 | if output_dict: 117 | return edict(sample=x) 118 | else: 119 | return x 120 | 121 | # Modified Container. Modify the nn.Sequential to time modulated Sequential 122 | class TimeModulatedSequential(nn.Sequential): 123 | """ Modify the nn.Sequential to time modulated Sequential """ 124 | def forward(self, x, t_emb, cond=None): 125 | for module in self: 126 | if isinstance(module, TimeModulatedSequential): 127 | x = module(x, t_emb, cond) 128 | elif isinstance(module, ResBlock): 129 | # For certain layers, add the time modulation. 130 | x = module(x, t_emb) 131 | elif isinstance(module, SpatialTransformer): 132 | # For certain layers, add the class conditioning. 133 | x = module(x, cond=cond) 134 | else: 135 | x = module(x) 136 | 137 | return x 138 | 139 | 140 | # backbone, Residual Block (Checked) 141 | class ResBlock(nn.Module): 142 | def __init__(self, in_channel, time_emb_dim, out_channel=None, ): 143 | super().__init__() 144 | if out_channel is None: 145 | out_channel = in_channel 146 | self.norm1 = nn.GroupNorm(32, in_channel, eps=1e-05, affine=True) 147 | self.conv1 = nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=1, padding=1) 148 | self.time_emb_proj = nn.Linear(in_features=time_emb_dim, out_features=out_channel, bias=True) 149 | self.norm2 = nn.GroupNorm(32, out_channel, eps=1e-05, affine=True) 150 | self.dropout = nn.Dropout(p=0.0, inplace=False) 151 | self.conv2 = nn.Conv2d(out_channel, out_channel, kernel_size=3, stride=1, padding=1) 152 | self.nonlinearity = nn.SiLU() 153 | if in_channel == out_channel: 154 | self.conv_shortcut = nn.Identity() 155 | else: 156 | self.conv_shortcut = nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=1) 157 | 158 | def forward(self, x, t_emb, cond=None): 159 | # Input conv 160 | h = self.norm1(x) 161 | h = self.nonlinearity(h) 162 | h = self.conv1(h) 163 | # Time modulation 164 | if t_emb is not None: 165 | t_hidden = self.time_emb_proj(self.nonlinearity(t_emb)) 166 | h = h + t_hidden[:, :, None, None] 167 | # Output conv 168 | h = self.norm2(h) 169 | h = self.nonlinearity(h) 170 | h = self.dropout(h) 171 | h = self.conv2(h) 172 | # Skip connection 173 | return h + self.conv_shortcut(x) 174 | 175 | 176 | # UpSampling (Checked) 177 | class UpSample(nn.Module): 178 | def __init__(self, channel, scale_factor=2, mode='nearest'): 179 | super(UpSample, self).__init__() 180 | self.scale_factor = scale_factor 181 | self.mode = mode 182 | self.conv = nn.Conv2d(channel, channel, kernel_size=3, stride=1, padding=1, ) 183 | 184 | def forward(self, x): 185 | x = F.interpolate(x, scale_factor=self.scale_factor, mode=self.mode) 186 | return self.conv(x) 187 | 188 | 189 | # DownSampling (Checked) 190 | class DownSample(nn.Module): 191 | def __init__(self, channel, ): 192 | super(DownSample, self).__init__() 193 | self.conv = nn.Conv2d(channel, channel, kernel_size=3, stride=2, padding=1, ) 194 | 195 | def forward(self, x): 196 | return self.conv(x) # F.interpolate(x, 
scale_factor=1/self.scale_factor, mode=self.mode) 197 | 198 | 199 | # Transformer layers 200 | # Self and Cross Attention mechanism (Checked) 201 | class CrossAttention(nn.Module): 202 | """General implementation of Cross & Self Attention multi-head 203 | """ 204 | def __init__(self, embed_dim, hidden_dim, context_dim=None, num_heads=8, ): 205 | super(CrossAttention, self).__init__() 206 | self.hidden_dim = hidden_dim 207 | self.context_dim = context_dim 208 | self.embed_dim = embed_dim 209 | self.num_heads = num_heads 210 | self.head_dim = embed_dim // num_heads 211 | self.to_q = nn.Linear(hidden_dim, embed_dim, bias=False) 212 | if context_dim is None: 213 | # Self Attention 214 | self.to_k = nn.Linear(hidden_dim, embed_dim, bias=False) 215 | self.to_v = nn.Linear(hidden_dim, embed_dim, bias=False) 216 | self.self_attn = True 217 | else: 218 | # Cross Attention 219 | self.to_k = nn.Linear(context_dim, embed_dim, bias=False) 220 | self.to_v = nn.Linear(context_dim, embed_dim, bias=False) 221 | self.self_attn = False 222 | self.to_out = nn.Sequential( 223 | nn.Linear(embed_dim, hidden_dim, bias=True) 224 | ) # this could be omitted 225 | 226 | def forward(self, tokens, context=None): 227 | Q = self.to_q(tokens) 228 | K = self.to_k(tokens) if self.self_attn else self.to_k(context) 229 | V = self.to_v(tokens) if self.self_attn else self.to_v(context) 230 | # print(Q.shape, K.shape, V.shape) 231 | # transform heads onto batch dimension 232 | Q = rearrange(Q, 'B T (H D) -> (B H) T D', H=self.num_heads, D=self.head_dim) 233 | K = rearrange(K, 'B T (H D) -> (B H) T D', H=self.num_heads, D=self.head_dim) 234 | V = rearrange(V, 'B T (H D) -> (B H) T D', H=self.num_heads, D=self.head_dim) 235 | # print(Q.shape, K.shape, V.shape) 236 | scoremats = torch.einsum("BTD,BSD->BTS", Q, K) 237 | attnmats = F.softmax(scoremats / math.sqrt(self.head_dim), dim=-1) 238 | # print(scoremats.shape, attnmats.shape, ) 239 | ctx_vecs = torch.einsum("BTS,BSD->BTD", attnmats, V) 240 | # split the heads transform back to hidden. 241 | ctx_vecs = rearrange(ctx_vecs, '(B H) T D -> B T (H D)', H=self.num_heads, D=self.head_dim) 242 | # TODO: note this `to_out` is also a linear layer, could be in principle merged into the to_value layer. 243 | return self.to_out(ctx_vecs) 244 | 245 | 246 | class TransformerBlock(nn.Module): 247 | def __init__(self, hidden_dim, context_dim, num_heads=8): 248 | super(TransformerBlock, self).__init__() 249 | self.attn1 = CrossAttention(hidden_dim, hidden_dim, num_heads=num_heads) # self attention 250 | self.attn2 = CrossAttention(hidden_dim, hidden_dim, context_dim, num_heads=num_heads) # cross attention 251 | 252 | self.norm1 = nn.LayerNorm(hidden_dim) 253 | self.norm2 = nn.LayerNorm(hidden_dim) 254 | self.norm3 = nn.LayerNorm(hidden_dim) 255 | # to be compatible with Diffuser, could simplify. 
256 | self.ff = FeedForward_GEGLU(hidden_dim, ) 257 | # self.ff = nn.Sequential( 258 | # nn.Linear(hidden_dim, 3 * hidden_dim), 259 | # nn.GELU(), 260 | # nn.Linear(3 * hidden_dim, hidden_dim) 261 | # ) 262 | 263 | def forward(self, x, context=None): 264 | x = self.attn1(self.norm1(x)) + x 265 | x = self.attn2(self.norm2(x), context=context) + x 266 | x = self.ff(self.norm3(x)) + x 267 | return x 268 | 269 | 270 | class GEGLU_proj(nn.Module): 271 | def __init__(self, in_dim, out_dim): 272 | super(GEGLU_proj, self).__init__() 273 | self.proj = nn.Linear(in_dim, 2 * out_dim) 274 | 275 | def forward(self, x): 276 | x = self.proj(x) 277 | x, gates = x.chunk(2, dim=-1) 278 | return x * F.gelu(gates) 279 | 280 | 281 | class FeedForward_GEGLU(nn.Module): 282 | # https://github.com/huggingface/diffusers/blob/95414bd6bf9bb34a312a7c55f10ba9b379f33890/src/diffusers/models/attention.py#L339 283 | # A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202. 284 | def __init__(self, hidden_dim, mult=4): 285 | super(FeedForward_GEGLU, self).__init__() 286 | self.net = nn.Sequential( 287 | GEGLU_proj(hidden_dim, mult * hidden_dim), 288 | nn.Dropout(0.0), 289 | nn.Linear(mult * hidden_dim, hidden_dim) 290 | ) # to be compatible with Diffuser, could simplify. 291 | 292 | def forward(self, x, ): 293 | return self.net(x) 294 | 295 | 296 | class SpatialTransformer(nn.Module): 297 | def __init__(self, hidden_dim, context_dim, num_heads=8): 298 | super(SpatialTransformer, self).__init__() 299 | self.norm = nn.GroupNorm(32, hidden_dim, eps=1e-6, affine=True) 300 | self.proj_in = nn.Conv2d(hidden_dim, hidden_dim, kernel_size=1, stride=1, padding=0) 301 | # self.transformer = TransformerBlock(hidden_dim, context_dim, num_heads=8) 302 | self.transformer_blocks = nn.Sequential( 303 | TransformerBlock(hidden_dim, context_dim, num_heads=num_heads) 304 | ) # to be compatible with Diffuser, could simplify. 305 | self.proj_out = nn.Conv2d(hidden_dim, hidden_dim, kernel_size=1, stride=1, padding=0) 306 | 307 | def forward(self, x, cond=None): 308 | b, c, h, w = x.shape 309 | x_in = x 310 | # context = rearrange(context, "b c T -> b T c") 311 | x = self.proj_in(self.norm(x)) 312 | x = rearrange(x, "b c h w->b (h w) c") 313 | x = self.transformer_blocks[0](x, cond) 314 | x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) 315 | return self.proj_out(x) + x_in 316 | 317 | 318 | def load_pipe_into_our_UNet(myUNet, pipe_unet): 319 | # load the pretrained weights from the pipe into our UNet. 320 | # Loading input and output layers.
321 | myUNet.output[0].load_state_dict(pipe_unet.conv_norm_out.state_dict()) 322 | myUNet.output[2].load_state_dict(pipe_unet.conv_out.state_dict()) 323 | myUNet.conv_in.load_state_dict(pipe_unet.conv_in.state_dict()) 324 | myUNet.time_embedding.load_state_dict(pipe_unet.time_embedding.state_dict()) 325 | #% Loading the down blocks 326 | myUNet.down_blocks[0][0].load_state_dict(pipe_unet.down_blocks[0].resnets[0].state_dict()) 327 | myUNet.down_blocks[0][1].load_state_dict(pipe_unet.down_blocks[0].attentions[0].state_dict()) 328 | myUNet.down_blocks[1][0].load_state_dict(pipe_unet.down_blocks[0].resnets[1].state_dict()) 329 | myUNet.down_blocks[1][1].load_state_dict(pipe_unet.down_blocks[0].attentions[1].state_dict()) 330 | myUNet.down_blocks[2][0].load_state_dict(pipe_unet.down_blocks[0].downsamplers[0].state_dict()) 331 | 332 | myUNet.down_blocks[3][0].load_state_dict(pipe_unet.down_blocks[1].resnets[0].state_dict()) 333 | myUNet.down_blocks[3][1].load_state_dict(pipe_unet.down_blocks[1].attentions[0].state_dict()) 334 | myUNet.down_blocks[4][0].load_state_dict(pipe_unet.down_blocks[1].resnets[1].state_dict()) 335 | myUNet.down_blocks[4][1].load_state_dict(pipe_unet.down_blocks[1].attentions[1].state_dict()) 336 | myUNet.down_blocks[5][0].load_state_dict(pipe_unet.down_blocks[1].downsamplers[0].state_dict()) 337 | 338 | myUNet.down_blocks[6][0].load_state_dict(pipe_unet.down_blocks[2].resnets[0].state_dict()) 339 | myUNet.down_blocks[6][1].load_state_dict(pipe_unet.down_blocks[2].attentions[0].state_dict()) 340 | myUNet.down_blocks[7][0].load_state_dict(pipe_unet.down_blocks[2].resnets[1].state_dict()) 341 | myUNet.down_blocks[7][1].load_state_dict(pipe_unet.down_blocks[2].attentions[1].state_dict()) 342 | myUNet.down_blocks[8][0].load_state_dict(pipe_unet.down_blocks[2].downsamplers[0].state_dict()) 343 | 344 | myUNet.down_blocks[9][0].load_state_dict(pipe_unet.down_blocks[3].resnets[0].state_dict()) 345 | myUNet.down_blocks[10][0].load_state_dict(pipe_unet.down_blocks[3].resnets[1].state_dict()) 346 | 347 | #% Loading the middle blocks 348 | myUNet.mid_block[0].load_state_dict(pipe_unet.mid_block.resnets[0].state_dict()) 349 | myUNet.mid_block[1].load_state_dict(pipe_unet.mid_block.attentions[0].state_dict()) 350 | myUNet.mid_block[2].load_state_dict(pipe_unet.mid_block.resnets[1].state_dict()) 351 | # % Loading the up blocks 352 | # upblock 0 353 | myUNet.up_blocks[0][0].load_state_dict(pipe_unet.up_blocks[0].resnets[0].state_dict()) 354 | myUNet.up_blocks[1][0].load_state_dict(pipe_unet.up_blocks[0].resnets[1].state_dict()) 355 | myUNet.up_blocks[2][0].load_state_dict(pipe_unet.up_blocks[0].resnets[2].state_dict()) 356 | myUNet.up_blocks[2][1].load_state_dict(pipe_unet.up_blocks[0].upsamplers[0].state_dict()) 357 | # % upblock 1 358 | myUNet.up_blocks[3][0].load_state_dict(pipe_unet.up_blocks[1].resnets[0].state_dict()) 359 | myUNet.up_blocks[3][1].load_state_dict(pipe_unet.up_blocks[1].attentions[0].state_dict()) 360 | myUNet.up_blocks[4][0].load_state_dict(pipe_unet.up_blocks[1].resnets[1].state_dict()) 361 | myUNet.up_blocks[4][1].load_state_dict(pipe_unet.up_blocks[1].attentions[1].state_dict()) 362 | myUNet.up_blocks[5][0].load_state_dict(pipe_unet.up_blocks[1].resnets[2].state_dict()) 363 | myUNet.up_blocks[5][1].load_state_dict(pipe_unet.up_blocks[1].attentions[2].state_dict()) 364 | myUNet.up_blocks[5][2].load_state_dict(pipe_unet.up_blocks[1].upsamplers[0].state_dict()) 365 | # % upblock 2 366 | 
myUNet.up_blocks[6][0].load_state_dict(pipe_unet.up_blocks[2].resnets[0].state_dict()) 367 | myUNet.up_blocks[6][1].load_state_dict(pipe_unet.up_blocks[2].attentions[0].state_dict()) 368 | myUNet.up_blocks[7][0].load_state_dict(pipe_unet.up_blocks[2].resnets[1].state_dict()) 369 | myUNet.up_blocks[7][1].load_state_dict(pipe_unet.up_blocks[2].attentions[1].state_dict()) 370 | myUNet.up_blocks[8][0].load_state_dict(pipe_unet.up_blocks[2].resnets[2].state_dict()) 371 | myUNet.up_blocks[8][1].load_state_dict(pipe_unet.up_blocks[2].attentions[2].state_dict()) 372 | myUNet.up_blocks[8][2].load_state_dict(pipe_unet.up_blocks[2].upsamplers[0].state_dict()) 373 | # % upblock 3 374 | myUNet.up_blocks[9][0].load_state_dict(pipe_unet.up_blocks[3].resnets[0].state_dict()) 375 | myUNet.up_blocks[9][1].load_state_dict(pipe_unet.up_blocks[3].attentions[0].state_dict()) 376 | myUNet.up_blocks[10][0].load_state_dict(pipe_unet.up_blocks[3].resnets[1].state_dict()) 377 | myUNet.up_blocks[10][1].load_state_dict(pipe_unet.up_blocks[3].attentions[1].state_dict()) 378 | myUNet.up_blocks[11][0].load_state_dict(pipe_unet.up_blocks[3].resnets[2].state_dict()) 379 | myUNet.up_blocks[11][1].load_state_dict(pipe_unet.up_blocks[3].attentions[2].state_dict()) --------------------------------------------------------------------------------
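A minimal end-to-end smoke test of the reconstructed UNet_SD (a sketch, not one of the repository files): it mirrors the random-tensor checks in StableDiff_UNet_unittest.py, but runs the full forward pass from latents to the 4-channel output. The 64x64 latent size, the batch of two, and the toy 5-token context are illustrative assumptions; only the 4 input channels and the 768-dim context dimension come from the model defaults.

import torch
from StableDiff_UNet_model import UNet_SD

unet = UNet_SD()  # moves itself to self.device (CUDA if available) inside __init__
unet.eval()
with torch.no_grad():
    lat = torch.randn(2, 4, 64, 64, device=unet.device)   # noisy latents; 64x64 is an assumed SD v1-style size
    ctx = torch.randn(2, 5, 768, device=unet.device)      # stand-in text-encoder hidden states (toy 5 tokens)
    t = torch.tensor([10, 500], device=unet.device)       # integer diffusion timesteps, one per sample
    out = unet(lat, t, encoder_hidden_states=ctx)         # returns edict(sample=...) with output_dict=True (default)
print(out.sample.shape)  # expected: torch.Size([2, 4, 64, 64])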