├── data └── horse.jpg ├── assets └── teaser.png ├── requirements.txt ├── config_pnp.yaml ├── License ├── README.md ├── pnp.py ├── pnp_utils.py └── preprocess.py /data/horse.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalGeyer/pnp-diffusers/HEAD/data/horse.jpg -------------------------------------------------------------------------------- /assets/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichalGeyer/pnp-diffusers/HEAD/assets/teaser.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==2.0.1 2 | torchvision==0.15.2 3 | numpy==1.25.0 4 | diffusers==0.17.1 5 | xformers==0.0.20 6 | transformers==4.30.2 7 | accelerate==0.20.3 8 | -------------------------------------------------------------------------------- /config_pnp.yaml: -------------------------------------------------------------------------------- 1 | # general 2 | seed: 1 3 | device: 'cuda' 4 | output_path: 'PNP-results/horse' 5 | 6 | # data 7 | image_path: 'data/horse.jpg' 8 | latents_path: 'latents_forward' 9 | 10 | # diffusion 11 | sd_version: '2.1' 12 | guidance_scale: 7.5 13 | n_timesteps: 50 14 | prompt: a photo of a pink toy horse on the beach 15 | negative_prompt: ugly, blurry, black, low res, unrealistic 16 | 17 | # pnp injection thresholds, ∈ [0, 1] 18 | pnp_attn_t: 0.5 19 | pnp_f_t: 0.8 20 | -------------------------------------------------------------------------------- /License: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Narek Tumanyan, Michal Geyer, Shai Bagon, Tali Dekel 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Plug-and-Play Diffusion Features for Text-Driven Image-to-Image Translation (CVPR 2023) 2 | 3 | ## [Project Page] 4 | 5 | [![arXiv](https://img.shields.io/badge/arXiv-PnP-b31b1b.svg)](https://arxiv.org/abs/2211.12572) [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/hysts/PnP-diffusion-features) [![TI2I](https://img.shields.io/badge/benchmarks-TI2I-blue)](https://www.dropbox.com/sh/8giw0uhfekft47h/AAAF1frwakVsQocKczZZSX6La?dl=0) 6 | 7 | ![teaser](assets/teaser.png) 8 | 9 | 10 | **To use plug-and-play diffusion features, please follow these steps:** 11 | 12 | 1. [Setup](#setup) 13 | 2. [Latent extraction](#latent-extraction) 14 | 3. [Running PnP](#running-pnp) 15 | 16 | 17 | ## Setup 18 | 19 | Create the environment and install the dependencies by running: 20 | 21 | ``` 22 | conda create -n pnp-diffusers python=3.9 23 | conda activate pnp-diffusers 24 | pip install -r requirements.txt 25 | ``` 26 | 27 | 28 | ## Latent Extraction 29 | 30 | We first compute the intermediate noisy latents of the structure guidance image. To do that, run: 31 | 32 | ``` 33 | python preprocess.py --data_path <data_path> --inversion_prompt <inversion_prompt> 34 | ``` 35 | 36 | where `<inversion_prompt>` should describe the content of the guidance image. The intermediate noisy latents will be saved under the path `latents_forward/<image_name>`, where `<image_name>` is the filename of the provided guidance image. 37 | 38 | 39 | ## Running PnP 40 | 41 | Run the following command to apply PnP to the structure guidance image: 42 | 43 | ``` 44 | python pnp.py --config_path <config_path> 45 | ``` 46 | 47 | where `<config_path>` is a path to a YAML config file. The config includes fields for providing the guidance image path, the PnP output path, the translation prompt, the guidance scale, the PnP feature and self-attention injection thresholds, and additional hyperparameters. See an example config in `config_pnp.yaml`.
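For reference, below is a minimal programmatic sketch of the same step, mirroring what `pnp.py` does when invoked from the command line. It assumes `preprocess.py` has already been run on `data/horse.jpg`, so that `latents_forward/horse/` exists:

```
import os
import yaml

from pnp import PNP
from pnp_utils import seed_everything

# Load the example config shipped with the repo; any field can be overridden here,
# e.g. the translation prompt or the injection thresholds.
with open("config_pnp.yaml") as f:
    config = yaml.safe_load(f)

os.makedirs(config["output_path"], exist_ok=True)
seed_everything(config["seed"])

pnp = PNP(config)   # loads Stable Diffusion and the precomputed noisy latent
pnp.run_pnp()       # saves the edited image under config["output_path"]
```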
48 | 49 | 50 | ## Citation 51 | ``` 52 | @InProceedings{Tumanyan_2023_CVPR, 53 | author = {Tumanyan, Narek and Geyer, Michal and Bagon, Shai and Dekel, Tali}, 54 | title = {Plug-and-Play Diffusion Features for Text-Driven Image-to-Image Translation}, 55 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 56 | month = {June}, 57 | year = {2023}, 58 | pages = {1921-1930} 59 | } 60 | ``` 61 | -------------------------------------------------------------------------------- /pnp.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | from pathlib import Path 4 | import torch 5 | import torch.nn as nn 6 | import torchvision.transforms as T 7 | import argparse 8 | from PIL import Image 9 | import yaml 10 | from tqdm import tqdm 11 | from transformers import logging 12 | from diffusers import DDIMScheduler, StableDiffusionPipeline 13 | 14 | from pnp_utils import * 15 | 16 | # suppress partial model loading warning 17 | logging.set_verbosity_error() 18 | 19 | class PNP(nn.Module): 20 | def __init__(self, config): 21 | super().__init__() 22 | self.config = config 23 | self.device = config["device"] 24 | sd_version = config["sd_version"] 25 | 26 | if sd_version == '2.1': 27 | model_key = "stabilityai/stable-diffusion-2-1-base" 28 | elif sd_version == '2.0': 29 | model_key = "stabilityai/stable-diffusion-2-base" 30 | elif sd_version == '1.5': 31 | model_key = "runwayml/stable-diffusion-v1-5" 32 | else: 33 | raise ValueError(f'Stable-diffusion version {sd_version} not supported.') 34 | 35 | # Create SD models 36 | print('Loading SD model') 37 | 38 | pipe = StableDiffusionPipeline.from_pretrained(model_key, torch_dtype=torch.float16).to("cuda") 39 | pipe.enable_xformers_memory_efficient_attention() 40 | 41 | self.vae = pipe.vae 42 | self.tokenizer = pipe.tokenizer 43 | self.text_encoder = pipe.text_encoder 44 | self.unet = pipe.unet 45 | 46 | self.scheduler = DDIMScheduler.from_pretrained(model_key, subfolder="scheduler") 47 | self.scheduler.set_timesteps(config["n_timesteps"], device=self.device) 48 | print('SD model loaded') 49 | 50 | # load image 51 | self.image, self.eps = self.get_data() 52 | 53 | self.text_embeds = self.get_text_embeds(config["prompt"], config["negative_prompt"]) 54 | self.pnp_guidance_embeds = self.get_text_embeds("", "").chunk(2)[0] 55 | 56 | 57 | @torch.no_grad() 58 | def get_text_embeds(self, prompt, negative_prompt, batch_size=1): 59 | # Tokenize text and get embeddings 60 | text_input = self.tokenizer(prompt, padding='max_length', max_length=self.tokenizer.model_max_length, 61 | truncation=True, return_tensors='pt') 62 | text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0] 63 | 64 | # Do the same for unconditional embeddings 65 | uncond_input = self.tokenizer(negative_prompt, padding='max_length', max_length=self.tokenizer.model_max_length, 66 | return_tensors='pt') 67 | 68 | uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0] 69 | 70 | # Cat for final embeddings 71 | text_embeddings = torch.cat([uncond_embeddings] * batch_size + [text_embeddings] * batch_size) 72 | return text_embeddings 73 | 74 | @torch.no_grad() 75 | def decode_latent(self, latent): 76 | with torch.autocast(device_type='cuda', dtype=torch.float32): 77 | latent = 1 / 0.18215 * latent 78 | img = self.vae.decode(latent).sample 79 | img = (img / 2 + 0.5).clamp(0, 1) 80 | return img 81 | 82 | @torch.autocast(device_type='cuda', 
dtype=torch.float32) 83 | def get_data(self): 84 | # load image 85 | image = Image.open(self.config["image_path"]).convert('RGB') 86 | image = image.resize((512, 512), resample=Image.Resampling.LANCZOS) 87 | image = T.ToTensor()(image).to(self.device) 88 | # get noise 89 | latents_path = os.path.join(self.config["latents_path"], os.path.splitext(os.path.basename(self.config["image_path"]))[0], f'noisy_latents_{self.scheduler.timesteps[0]}.pt') 90 | noisy_latent = torch.load(latents_path).to(self.device) 91 | return image, noisy_latent 92 | 93 | @torch.no_grad() 94 | def denoise_step(self, x, t): 95 | # register the time step and features in pnp injection modules 96 | source_latents = load_source_latents_t(t, os.path.join(self.config["latents_path"], os.path.splitext(os.path.basename(self.config["image_path"]))[0])) 97 | latent_model_input = torch.cat([source_latents] + ([x] * 2)) 98 | 99 | register_time(self, t.item()) 100 | 101 | # compute text embeddings 102 | text_embed_input = torch.cat([self.pnp_guidance_embeds, self.text_embeds], dim=0) 103 | 104 | # apply the denoising network 105 | noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embed_input)['sample'] 106 | 107 | # perform guidance 108 | _, noise_pred_uncond, noise_pred_cond = noise_pred.chunk(3) 109 | noise_pred = noise_pred_uncond + self.config["guidance_scale"] * (noise_pred_cond - noise_pred_uncond) 110 | 111 | # compute the denoising step with the reference model 112 | denoised_latent = self.scheduler.step(noise_pred, t, x)['prev_sample'] 113 | return denoised_latent 114 | 115 | def init_pnp(self, conv_injection_t, qk_injection_t): 116 | self.qk_injection_timesteps = self.scheduler.timesteps[:qk_injection_t] if qk_injection_t >= 0 else [] 117 | self.conv_injection_timesteps = self.scheduler.timesteps[:conv_injection_t] if conv_injection_t >= 0 else [] 118 | register_attention_control_efficient(self, self.qk_injection_timesteps) 119 | register_conv_control_efficient(self, self.conv_injection_timesteps) 120 | 121 | def run_pnp(self): 122 | pnp_f_t = int(self.config["n_timesteps"] * self.config["pnp_f_t"]) 123 | pnp_attn_t = int(self.config["n_timesteps"] * self.config["pnp_attn_t"]) 124 | self.init_pnp(conv_injection_t=pnp_f_t, qk_injection_t=pnp_attn_t) 125 | edited_img = self.sample_loop(self.eps) 126 | 127 | def sample_loop(self, x): 128 | with torch.autocast(device_type='cuda', dtype=torch.float32): 129 | for i, t in enumerate(tqdm(self.scheduler.timesteps, desc="Sampling")): 130 | x = self.denoise_step(x, t) 131 | 132 | decoded_latent = self.decode_latent(x) 133 | T.ToPILImage()(decoded_latent[0]).save(f'{self.config["output_path"]}/output-{self.config["prompt"]}.png') 134 | 135 | return decoded_latent 136 | 137 | 138 | if __name__ == '__main__': 139 | parser = argparse.ArgumentParser() 140 | parser.add_argument('--config_path', type=str, default='pnp-configs/config-horse.yaml') 141 | opt = parser.parse_args() 142 | with open(opt.config_path, "r") as f: 143 | config = yaml.safe_load(f) 144 | os.makedirs(config["output_path"], exist_ok=True) 145 | with open(os.path.join(config["output_path"], "config.yaml"), "w") as f: 146 | yaml.dump(config, f) 147 | 148 | seed_everything(config["seed"]) 149 | print(config) 150 | pnp = PNP(config) 151 | pnp.run_pnp() -------------------------------------------------------------------------------- /pnp_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import random 4 | import numpy as np 5 | 
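# Note on how the utilities below are used by pnp.py:
#
# - Each denoising step runs the UNet on a batch of three latents, laid out as
#   [source, unconditional, conditional]; the injection functions copy the source
#   features/attention (the first third of the batch) into the other two thirds.
#
# - `injection_schedule` is the list of DDIM timesteps at which injection is active.
#   pnp.py builds it as `scheduler.timesteps[:int(n_timesteps * threshold)]`, so with
#   the example config (n_timesteps=50, pnp_f_t=0.8, pnp_attn_t=0.5) the residual-block
#   features are injected during the first 40 denoising steps and the self-attention
#   queries/keys during the first 25 steps.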
6 | def seed_everything(seed): 7 | torch.manual_seed(seed) 8 | torch.cuda.manual_seed(seed) 9 | random.seed(seed) 10 | np.random.seed(seed) 11 | 12 | def register_time(model, t): 13 | conv_module = model.unet.up_blocks[1].resnets[1] 14 | setattr(conv_module, 't', t) 15 | down_res_dict = {0: [0, 1], 1: [0, 1], 2: [0, 1]} 16 | up_res_dict = {1: [0, 1, 2], 2: [0, 1, 2], 3: [0, 1, 2]} 17 | for res in up_res_dict: 18 | for block in up_res_dict[res]: 19 | module = model.unet.up_blocks[res].attentions[block].transformer_blocks[0].attn1 20 | setattr(module, 't', t) 21 | for res in down_res_dict: 22 | for block in down_res_dict[res]: 23 | module = model.unet.down_blocks[res].attentions[block].transformer_blocks[0].attn1 24 | setattr(module, 't', t) 25 | module = model.unet.mid_block.attentions[0].transformer_blocks[0].attn1 26 | setattr(module, 't', t) 27 | 28 | 29 | def load_source_latents_t(t, latents_path): 30 | latents_t_path = os.path.join(latents_path, f'noisy_latents_{t}.pt') 31 | assert os.path.exists(latents_t_path), f'Missing latents at t {t} path {latents_t_path}' 32 | latents = torch.load(latents_t_path) 33 | return latents 34 | 35 | def register_attention_control_efficient(model, injection_schedule): 36 | def sa_forward(self): 37 | to_out = self.to_out 38 | if type(to_out) is torch.nn.modules.container.ModuleList: 39 | to_out = self.to_out[0] 40 | else: 41 | to_out = self.to_out 42 | 43 | def forward(x, encoder_hidden_states=None, attention_mask=None): 44 | batch_size, sequence_length, dim = x.shape 45 | h = self.heads 46 | 47 | is_cross = encoder_hidden_states is not None 48 | encoder_hidden_states = encoder_hidden_states if is_cross else x 49 | if not is_cross and self.injection_schedule is not None and ( 50 | self.t in self.injection_schedule or self.t == 1000): 51 | q = self.to_q(x) 52 | k = self.to_k(encoder_hidden_states) 53 | 54 | source_batch_size = int(q.shape[0] // 3) 55 | # inject unconditional 56 | q[source_batch_size:2 * source_batch_size] = q[:source_batch_size] 57 | k[source_batch_size:2 * source_batch_size] = k[:source_batch_size] 58 | # inject conditional 59 | q[2 * source_batch_size:] = q[:source_batch_size] 60 | k[2 * source_batch_size:] = k[:source_batch_size] 61 | 62 | q = self.head_to_batch_dim(q) 63 | k = self.head_to_batch_dim(k) 64 | else: 65 | q = self.to_q(x) 66 | k = self.to_k(encoder_hidden_states) 67 | q = self.head_to_batch_dim(q) 68 | k = self.head_to_batch_dim(k) 69 | 70 | v = self.to_v(encoder_hidden_states) 71 | v = self.head_to_batch_dim(v) 72 | 73 | sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale 74 | 75 | if attention_mask is not None: 76 | attention_mask = attention_mask.reshape(batch_size, -1) 77 | max_neg_value = -torch.finfo(sim.dtype).max 78 | attention_mask = attention_mask[:, None, :].repeat(h, 1, 1) 79 | sim.masked_fill_(~attention_mask, max_neg_value) 80 | 81 | # attention, what we cannot get enough of 82 | attn = sim.softmax(dim=-1) 83 | out = torch.einsum("b i j, b j d -> b i d", attn, v) 84 | out = self.batch_to_head_dim(out) 85 | 86 | return to_out(out) 87 | 88 | return forward 89 | 90 | res_dict = {1: [1, 2], 2: [0, 1, 2], 3: [0, 1, 2]} # we are injecting attention in blocks 4 - 11 of the decoder, so not in the first block of the lowest resolution 91 | for res in res_dict: 92 | for block in res_dict[res]: 93 | module = model.unet.up_blocks[res].attentions[block].transformer_blocks[0].attn1 94 | module.forward = sa_forward(module) 95 | setattr(module, 'injection_schedule', injection_schedule) 96 | 97 | 98 | def 
register_conv_control_efficient(model, injection_schedule): 99 | def conv_forward(self): 100 | def forward(input_tensor, temb): 101 | hidden_states = input_tensor 102 | 103 | hidden_states = self.norm1(hidden_states) 104 | hidden_states = self.nonlinearity(hidden_states) 105 | 106 | if self.upsample is not None: 107 | # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984 108 | if hidden_states.shape[0] >= 64: 109 | input_tensor = input_tensor.contiguous() 110 | hidden_states = hidden_states.contiguous() 111 | input_tensor = self.upsample(input_tensor) 112 | hidden_states = self.upsample(hidden_states) 113 | elif self.downsample is not None: 114 | input_tensor = self.downsample(input_tensor) 115 | hidden_states = self.downsample(hidden_states) 116 | 117 | hidden_states = self.conv1(hidden_states) 118 | 119 | if temb is not None: 120 | temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None] 121 | 122 | if temb is not None and self.time_embedding_norm == "default": 123 | hidden_states = hidden_states + temb 124 | 125 | hidden_states = self.norm2(hidden_states) 126 | 127 | if temb is not None and self.time_embedding_norm == "scale_shift": 128 | scale, shift = torch.chunk(temb, 2, dim=1) 129 | hidden_states = hidden_states * (1 + scale) + shift 130 | 131 | hidden_states = self.nonlinearity(hidden_states) 132 | 133 | hidden_states = self.dropout(hidden_states) 134 | hidden_states = self.conv2(hidden_states) 135 | if self.injection_schedule is not None and (self.t in self.injection_schedule or self.t == 1000): 136 | source_batch_size = int(hidden_states.shape[0] // 3) 137 | # inject unconditional 138 | hidden_states[source_batch_size:2 * source_batch_size] = hidden_states[:source_batch_size] 139 | # inject conditional 140 | hidden_states[2 * source_batch_size:] = hidden_states[:source_batch_size] 141 | 142 | if self.conv_shortcut is not None: 143 | input_tensor = self.conv_shortcut(input_tensor) 144 | 145 | output_tensor = (input_tensor + hidden_states) / self.output_scale_factor 146 | 147 | return output_tensor 148 | 149 | return forward 150 | 151 | conv_module = model.unet.up_blocks[1].resnets[1] 152 | conv_module.forward = conv_forward(conv_module) 153 | setattr(conv_module, 'injection_schedule', injection_schedule) -------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | from transformers import CLIPTextModel, CLIPTokenizer, logging 2 | from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler 3 | 4 | # suppress partial model loading warning 5 | logging.set_verbosity_error() 6 | 7 | import os 8 | from PIL import Image 9 | from tqdm import tqdm, trange 10 | import torch 11 | import torch.nn as nn 12 | import argparse 13 | from pathlib import Path 14 | from pnp_utils import * 15 | import torchvision.transforms as T 16 | 17 | 18 | def get_timesteps(scheduler, num_inference_steps, strength, device): 19 | # get the original timestep using init_timestep 20 | init_timestep = min(int(num_inference_steps * strength), num_inference_steps) 21 | 22 | t_start = max(num_inference_steps - init_timestep, 0) 23 | timesteps = scheduler.timesteps[t_start:] 24 | 25 | return timesteps, num_inference_steps - t_start 26 | 27 | 28 | class Preprocess(nn.Module): 29 | def __init__(self, device, sd_version='2.0', hf_key=None): 30 | super().__init__() 31 | 32 | self.device = device 33 | self.sd_version = sd_version 
34 | self.use_depth = False 35 | 36 | print(f'[INFO] loading stable diffusion...') 37 | if hf_key is not None: 38 | print(f'[INFO] using hugging face custom model key: {hf_key}') 39 | model_key = hf_key 40 | elif self.sd_version == '2.1': 41 | model_key = "stabilityai/stable-diffusion-2-1-base" 42 | elif self.sd_version == '2.0': 43 | model_key = "stabilityai/stable-diffusion-2-base" 44 | elif self.sd_version == '1.5': 45 | model_key = "runwayml/stable-diffusion-v1-5" 46 | elif self.sd_version == 'depth': 47 | model_key = "stabilityai/stable-diffusion-2-depth" 48 | self.use_depth = True 49 | else: 50 | raise ValueError(f'Stable-diffusion version {self.sd_version} not supported.') 51 | 52 | # Create model 53 | self.vae = AutoencoderKL.from_pretrained(model_key, subfolder="vae", revision="fp16", 54 | torch_dtype=torch.float16).to(self.device) 55 | self.tokenizer = CLIPTokenizer.from_pretrained(model_key, subfolder="tokenizer") 56 | self.text_encoder = CLIPTextModel.from_pretrained(model_key, subfolder="text_encoder", revision="fp16", 57 | torch_dtype=torch.float16).to(self.device) 58 | self.unet = UNet2DConditionModel.from_pretrained(model_key, subfolder="unet", revision="fp16", 59 | torch_dtype=torch.float16).to(self.device) 60 | self.scheduler = DDIMScheduler.from_pretrained(model_key, subfolder="scheduler") 61 | print(f'[INFO] loaded stable diffusion!') 62 | 63 | self.inversion_func = self.ddim_inversion 64 | 65 | @torch.no_grad() 66 | def get_text_embeds(self, prompt, negative_prompt, device="cuda"): 67 | text_input = self.tokenizer(prompt, padding='max_length', max_length=self.tokenizer.model_max_length, 68 | truncation=True, return_tensors='pt') 69 | text_embeddings = self.text_encoder(text_input.input_ids.to(device))[0] 70 | uncond_input = self.tokenizer(negative_prompt, padding='max_length', max_length=self.tokenizer.model_max_length, 71 | return_tensors='pt') 72 | uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(device))[0] 73 | text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) 74 | return text_embeddings 75 | 76 | @torch.no_grad() 77 | def decode_latents(self, latents): 78 | with torch.autocast(device_type='cuda', dtype=torch.float32): 79 | latents = 1 / 0.18215 * latents 80 | imgs = self.vae.decode(latents).sample 81 | imgs = (imgs / 2 + 0.5).clamp(0, 1) 82 | return imgs 83 | 84 | def load_img(self, image_path): 85 | image_pil = T.Resize(512)(Image.open(image_path).convert("RGB")) 86 | image = T.ToTensor()(image_pil).unsqueeze(0).to(device) 87 | return image 88 | 89 | @torch.no_grad() 90 | def encode_imgs(self, imgs): 91 | with torch.autocast(device_type='cuda', dtype=torch.float32): 92 | imgs = 2 * imgs - 1 93 | posterior = self.vae.encode(imgs).latent_dist 94 | latents = posterior.mean * 0.18215 95 | return latents 96 | 97 | @torch.no_grad() 98 | def ddim_inversion(self, cond, latent, save_path, save_latents=True, 99 | timesteps_to_save=None): 100 | timesteps = reversed(self.scheduler.timesteps) 101 | with torch.autocast(device_type='cuda', dtype=torch.float32): 102 | for i, t in enumerate(tqdm(timesteps)): 103 | cond_batch = cond.repeat(latent.shape[0], 1, 1) 104 | 105 | alpha_prod_t = self.scheduler.alphas_cumprod[t] 106 | alpha_prod_t_prev = ( 107 | self.scheduler.alphas_cumprod[timesteps[i - 1]] 108 | if i > 0 else self.scheduler.final_alpha_cumprod 109 | ) 110 | 111 | mu = alpha_prod_t ** 0.5 112 | mu_prev = alpha_prod_t_prev ** 0.5 113 | sigma = (1 - alpha_prod_t) ** 0.5 114 | sigma_prev = (1 - alpha_prod_t_prev) ** 0.5 115 | 116 | eps = 
self.unet(latent, t, encoder_hidden_states=cond_batch).sample 117 | 118 | pred_x0 = (latent - sigma_prev * eps) / mu_prev 119 | latent = mu * pred_x0 + sigma * eps 120 | if save_latents: 121 | torch.save(latent, os.path.join(save_path, f'noisy_latents_{t}.pt')) 122 | torch.save(latent, os.path.join(save_path, f'noisy_latents_{t}.pt')) 123 | return latent 124 | 125 | @torch.no_grad() 126 | def ddim_sample(self, x, cond, save_path, save_latents=False, timesteps_to_save=None): 127 | timesteps = self.scheduler.timesteps 128 | with torch.autocast(device_type='cuda', dtype=torch.float32): 129 | for i, t in enumerate(tqdm(timesteps)): 130 | cond_batch = cond.repeat(x.shape[0], 1, 1) 131 | alpha_prod_t = self.scheduler.alphas_cumprod[t] 132 | alpha_prod_t_prev = ( 133 | self.scheduler.alphas_cumprod[timesteps[i + 1]] 134 | if i < len(timesteps) - 1 135 | else self.scheduler.final_alpha_cumprod 136 | ) 137 | mu = alpha_prod_t ** 0.5 138 | sigma = (1 - alpha_prod_t) ** 0.5 139 | mu_prev = alpha_prod_t_prev ** 0.5 140 | sigma_prev = (1 - alpha_prod_t_prev) ** 0.5 141 | 142 | eps = self.unet(x, t, encoder_hidden_states=cond_batch).sample 143 | 144 | pred_x0 = (x - sigma * eps) / mu 145 | x = mu_prev * pred_x0 + sigma_prev * eps 146 | 147 | if save_latents: 148 | torch.save(x, os.path.join(save_path, f'noisy_latents_{t}.pt')) 149 | return x 150 | 151 | @torch.no_grad() 152 | def extract_latents(self, num_steps, data_path, save_path, timesteps_to_save, 153 | inversion_prompt='', extract_reverse=False): 154 | self.scheduler.set_timesteps(num_steps) 155 | 156 | cond = self.get_text_embeds(inversion_prompt, "")[1].unsqueeze(0) 157 | image = self.load_img(data_path) 158 | latent = self.encode_imgs(image) 159 | 160 | inverted_x = self.inversion_func(cond, latent, save_path, save_latents=not extract_reverse, 161 | timesteps_to_save=timesteps_to_save) 162 | latent_reconstruction = self.ddim_sample(inverted_x, cond, save_path, save_latents=extract_reverse, 163 | timesteps_to_save=timesteps_to_save) 164 | rgb_reconstruction = self.decode_latents(latent_reconstruction) 165 | 166 | return rgb_reconstruction # , latent_reconstruction 167 | 168 | 169 | def run(opt): 170 | # timesteps to save 171 | if opt.sd_version == '2.1': 172 | model_key = "stabilityai/stable-diffusion-2-1-base" 173 | elif opt.sd_version == '2.0': 174 | model_key = "stabilityai/stable-diffusion-2-base" 175 | elif opt.sd_version == '1.5': 176 | model_key = "runwayml/stable-diffusion-v1-5" 177 | elif opt.sd_version == 'depth': 178 | model_key = "stabilityai/stable-diffusion-2-depth" 179 | toy_scheduler = DDIMScheduler.from_pretrained(model_key, subfolder="scheduler") 180 | toy_scheduler.set_timesteps(opt.save_steps) 181 | timesteps_to_save, num_inference_steps = get_timesteps(toy_scheduler, num_inference_steps=opt.save_steps, 182 | strength=1.0, 183 | device=device) 184 | 185 | seed_everything(opt.seed) 186 | 187 | extraction_path_prefix = "_reverse" if opt.extract_reverse else "_forward" 188 | save_path = os.path.join(opt.save_dir + extraction_path_prefix, os.path.splitext(os.path.basename(opt.data_path))[0]) 189 | os.makedirs(save_path, exist_ok=True) 190 | 191 | model = Preprocess(device, sd_version=opt.sd_version, hf_key=None) 192 | recon_image = model.extract_latents(data_path=opt.data_path, 193 | num_steps=opt.steps, 194 | save_path=save_path, 195 | timesteps_to_save=timesteps_to_save, 196 | inversion_prompt=opt.inversion_prompt, 197 | extract_reverse=opt.extract_reverse) 198 | 199 | T.ToPILImage()(recon_image[0]).save(os.path.join(save_path, 
f'recon.jpg')) 200 | 201 | 202 | if __name__ == "__main__": 203 | device = 'cuda' 204 | parser = argparse.ArgumentParser() 205 | parser.add_argument('--data_path', type=str, 206 | default='data/horse.jpg') 207 | parser.add_argument('--save_dir', type=str, default='latents') 208 | parser.add_argument('--sd_version', type=str, default='2.1', choices=['1.5', '2.0', '2.1'], 209 | help="stable diffusion version") 210 | parser.add_argument('--seed', type=int, default=1) 211 | parser.add_argument('--steps', type=int, default=999) 212 | parser.add_argument('--save-steps', type=int, default=1000) 213 | parser.add_argument('--inversion_prompt', type=str, default='') 214 | parser.add_argument('--extract-reverse', default=False, action='store_true', help="extract features during the denoising process") 215 | opt = parser.parse_args() 216 | run(opt) 217 | --------------------------------------------------------------------------------
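A small self-contained sketch of the DDIM update used in `preprocess.py`, with toy tensors and a fixed noise prediction standing in for the UNet: one inversion step (as in `ddim_inversion`) followed by the matching sampling step (as in `ddim_sample`) recovers the input exactly when the same `eps` is reused. This is the consistency that DDIM inversion relies on; in practice the UNet's `eps` differs slightly between the two passes, so the reconstruction saved as `recon.jpg` is only approximate.

```
import torch

# Toy stand-ins: in preprocess.py these come from scheduler.alphas_cumprod and the UNet.
alpha_prod_t, alpha_prod_t_prev = torch.tensor(0.60), torch.tensor(0.75)
mu, sigma = alpha_prod_t ** 0.5, (1 - alpha_prod_t) ** 0.5
mu_prev, sigma_prev = alpha_prod_t_prev ** 0.5, (1 - alpha_prod_t_prev) ** 0.5

x_prev = torch.randn(1, 4, 64, 64)   # latent at the less-noisy level
eps = torch.randn_like(x_prev)       # pretend this is the UNet's noise prediction

# Inversion step: move from the less-noisy level to the noisier level t.
pred_x0 = (x_prev - sigma_prev * eps) / mu_prev
x_t = mu * pred_x0 + sigma * eps

# Sampling step: move back down with the same eps.
pred_x0_back = (x_t - sigma * eps) / mu
x_prev_back = mu_prev * pred_x0_back + sigma_prev * eps

print(torch.allclose(x_prev, x_prev_back, atol=1e-5))  # True: the two steps are exact inverses
```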