├── data
│   └── horse.jpg
├── assets
│   └── teaser.png
├── requirements.txt
├── config_pnp.yaml
├── License
├── README.md
├── pnp.py
├── pnp_utils.py
└── preprocess.py
/data/horse.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MichalGeyer/pnp-diffusers/HEAD/data/horse.jpg
--------------------------------------------------------------------------------
/assets/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MichalGeyer/pnp-diffusers/HEAD/assets/teaser.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==2.0.1
2 | torchvision==0.15.2
3 | numpy==1.25.0
4 | diffusers==0.17.1
5 | xformers==0.0.20
6 | transformers==4.30.2
7 | accelerate==0.20.3
8 |
--------------------------------------------------------------------------------
/config_pnp.yaml:
--------------------------------------------------------------------------------
1 | # general
2 | seed: 1
3 | device: 'cuda'
4 | output_path: 'PNP-results/horse'
5 |
6 | # data
7 | image_path: 'data/horse.jpg'
8 | latents_path: 'latents_forward'
9 |
10 | # diffusion
11 | sd_version: '2.1'
12 | guidance_scale: 7.5
13 | n_timesteps: 50
14 | prompt: a photo of a pink toy horse on the beach
15 | negative_prompt: ugly, blurry, black, low res, unrealistic
16 |
17 | # pnp injection thresholds ∈ [0, 1], given as a fraction of the n_timesteps sampling steps
18 | pnp_attn_t: 0.5  # inject self-attention Q/K for the first 50% of steps
19 | pnp_f_t: 0.8     # inject decoder conv features for the first 80% of steps
20 |
--------------------------------------------------------------------------------
/License:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Narek Tumanyan, Michal Geyer, Shai Bagon, Tali Dekel
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Plug-and-Play Diffusion Features for Text-Driven Image-to-Image Translation (CVPR 2023)
2 |
3 | ## [Project Page](https://pnp-diffusion.github.io/)
4 |
5 | [arXiv](https://arxiv.org/abs/2211.12572) · [Hugging Face demo](https://huggingface.co/spaces/hysts/PnP-diffusion-features) · [Dropbox](https://www.dropbox.com/sh/8giw0uhfekft47h/AAAF1frwakVsQocKczZZSX6La?dl=0)
6 |
7 | ![teaser](assets/teaser.png)
8 |
9 |
10 | **To apply plug-and-play diffusion features, follow these steps:**
11 |
12 | 1. [Setup](#setup)
13 | 2. [Latent extraction](#latent-extraction)
14 | 3. [Running PnP](#running-pnp)
15 |
16 |
17 | ## Setup
18 |
19 | Create the environment and install the dependencies by running:
20 |
21 | ```
22 | conda create -n pnp-diffusers python=3.9
23 | conda activate pnp-diffusers
24 | pip install -r requirements.txt
25 | ```
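
After installation, a quick sanity check (illustrative, not part of the repo) confirms that the pinned PyTorch build sees the GPU and that xformers is importable:

```
# environment sanity check; assumes a CUDA-capable GPU is available
import torch
import xformers  # pinned in requirements.txt

print(torch.__version__)          # expected: 2.0.1
print(torch.cuda.is_available())  # should print True for PnP to run
```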
26 |
27 |
28 | ## Latent Extraction
29 |
30 | We first compute the intermediate noisy latents of the structure guidance image. To do that, run:
31 |
32 | ```
33 | python preprocess.py --data_path <data_path> --inversion_prompt <inversion_prompt>
34 | ```
35 |
36 | where `<inversion_prompt>` should describe the content of the guidance image. The intermediate noisy latents will be saved under the path `latents_forward/<image_name>`, where `<image_name>` is the filename (without extension) of the provided guidance image.
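
For example, to invert the bundled horse image (the inversion prompt below is only illustrative; use one that matches your image):

```
python preprocess.py --data_path data/horse.jpg --inversion_prompt "a photo of a horse"
```

This writes the noisy latents to `latents_forward/horse`, which is the `latents_path` the example config expects.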
37 |
38 |
39 | ## Running PnP
40 |
41 | Run the following command for applying PnP on the structure guidance image:
42 |
43 | ```
44 | python pnp.py --config_path <config_path>
45 | ```
46 |
47 | where `<config_path>` is the path to a yaml config file. The config includes fields for the guidance image path, the PnP output path, the translation prompt, the guidance scale, the PnP feature and self-attention injection thresholds, and additional hyperparameters. See the example config in `config_pnp.yaml`.
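
For the bundled example, this is simply:

```
python pnp.py --config_path config_pnp.yaml
```

With the provided config, the translated image is saved under `PNP-results/horse` as `output-<prompt>.png`.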
48 |
49 |
50 | ## Citation
51 | ```
52 | @InProceedings{Tumanyan_2023_CVPR,
53 | author = {Tumanyan, Narek and Geyer, Michal and Bagon, Shai and Dekel, Tali},
54 | title = {Plug-and-Play Diffusion Features for Text-Driven Image-to-Image Translation},
55 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
56 | month = {June},
57 | year = {2023},
58 | pages = {1921-1930}
59 | }
60 | ```
61 |
--------------------------------------------------------------------------------
/pnp.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 | from pathlib import Path
4 | import torch
5 | import torch.nn as nn
6 | import torchvision.transforms as T
7 | import argparse
8 | from PIL import Image
9 | import yaml
10 | from tqdm import tqdm
11 | from transformers import logging
12 | from diffusers import DDIMScheduler, StableDiffusionPipeline
13 |
14 | from pnp_utils import *
15 |
16 | # suppress partial model loading warning
17 | logging.set_verbosity_error()
18 |
19 | class PNP(nn.Module):
20 | def __init__(self, config):
21 | super().__init__()
22 | self.config = config
23 | self.device = config["device"]
24 | sd_version = config["sd_version"]
25 |
26 | if sd_version == '2.1':
27 | model_key = "stabilityai/stable-diffusion-2-1-base"
28 | elif sd_version == '2.0':
29 | model_key = "stabilityai/stable-diffusion-2-base"
30 | elif sd_version == '1.5':
31 | model_key = "runwayml/stable-diffusion-v1-5"
32 | else:
33 | raise ValueError(f'Stable-diffusion version {sd_version} not supported.')
34 |
35 | # Create SD models
36 | print('Loading SD model')
37 |
38 | pipe = StableDiffusionPipeline.from_pretrained(model_key, torch_dtype=torch.float16).to(self.device)
39 | pipe.enable_xformers_memory_efficient_attention()
40 |
41 | self.vae = pipe.vae
42 | self.tokenizer = pipe.tokenizer
43 | self.text_encoder = pipe.text_encoder
44 | self.unet = pipe.unet
45 |
46 | self.scheduler = DDIMScheduler.from_pretrained(model_key, subfolder="scheduler")
47 | self.scheduler.set_timesteps(config["n_timesteps"], device=self.device)
48 | print('SD model loaded')
49 |
50 | # load image
51 | self.image, self.eps = self.get_data()
52 |
53 | self.text_embeds = self.get_text_embeds(config["prompt"], config["negative_prompt"])
54 | self.pnp_guidance_embeds = self.get_text_embeds("", "").chunk(2)[0]
55 |
56 |
57 | @torch.no_grad()
58 | def get_text_embeds(self, prompt, negative_prompt, batch_size=1):
59 | # Tokenize text and get embeddings
60 | text_input = self.tokenizer(prompt, padding='max_length', max_length=self.tokenizer.model_max_length,
61 | truncation=True, return_tensors='pt')
62 | text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
63 |
64 | # Do the same for unconditional embeddings
65 | uncond_input = self.tokenizer(negative_prompt, padding='max_length', max_length=self.tokenizer.model_max_length,
66 | return_tensors='pt')
67 |
68 | uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
69 |
70 | # Cat for final embeddings
71 | text_embeddings = torch.cat([uncond_embeddings] * batch_size + [text_embeddings] * batch_size)
72 | return text_embeddings
73 |
74 | @torch.no_grad()
75 | def decode_latent(self, latent):
76 | with torch.autocast(device_type='cuda', dtype=torch.float32):
77 | latent = 1 / 0.18215 * latent
78 | img = self.vae.decode(latent).sample
79 | img = (img / 2 + 0.5).clamp(0, 1)
80 | return img
81 |
82 | @torch.autocast(device_type='cuda', dtype=torch.float32)
83 | def get_data(self):
84 | # load image
85 | image = Image.open(self.config["image_path"]).convert('RGB')
86 | image = image.resize((512, 512), resample=Image.Resampling.LANCZOS)
87 | image = T.ToTensor()(image).to(self.device)
88 | # get noise
89 | latents_path = os.path.join(self.config["latents_path"], os.path.splitext(os.path.basename(self.config["image_path"]))[0], f'noisy_latents_{self.scheduler.timesteps[0]}.pt')
90 | noisy_latent = torch.load(latents_path).to(self.device)
91 | return image, noisy_latent
92 |
93 | @torch.no_grad()
94 | def denoise_step(self, x, t):
95 | # register the time step and features in pnp injection modules
96 | source_latents = load_source_latents_t(t, os.path.join(self.config["latents_path"], os.path.splitext(os.path.basename(self.config["image_path"]))[0]))
97 | latent_model_input = torch.cat([source_latents] + ([x] * 2))  # batch layout: [source, uncond, cond]
98 |
99 | register_time(self, t.item())
100 |
101 | # compute text embeddings
102 | text_embed_input = torch.cat([self.pnp_guidance_embeds, self.text_embeds], dim=0)
103 |
104 | # apply the denoising network
105 | noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embed_input)['sample']
106 |
107 | # perform guidance
108 | _, noise_pred_uncond, noise_pred_cond = noise_pred.chunk(3)  # the source prediction is discarded
109 | noise_pred = noise_pred_uncond + self.config["guidance_scale"] * (noise_pred_cond - noise_pred_uncond)
110 |
111 | # compute the denoising step with the reference model
112 | denoised_latent = self.scheduler.step(noise_pred, t, x)['prev_sample']
113 | return denoised_latent
114 |
115 | def init_pnp(self, conv_injection_t, qk_injection_t):
116 | self.qk_injection_timesteps = self.scheduler.timesteps[:qk_injection_t] if qk_injection_t >= 0 else []
117 | self.conv_injection_timesteps = self.scheduler.timesteps[:conv_injection_t] if conv_injection_t >= 0 else []
118 | register_attention_control_efficient(self, self.qk_injection_timesteps)
119 | register_conv_control_efficient(self, self.conv_injection_timesteps)
120 |
121 | def run_pnp(self):
122 | pnp_f_t = int(self.config["n_timesteps"] * self.config["pnp_f_t"])
123 | pnp_attn_t = int(self.config["n_timesteps"] * self.config["pnp_attn_t"])
124 | self.init_pnp(conv_injection_t=pnp_f_t, qk_injection_t=pnp_attn_t)
125 | edited_img = self.sample_loop(self.eps)
126 |
127 | def sample_loop(self, x):
128 | with torch.autocast(device_type='cuda', dtype=torch.float32):
129 | for i, t in enumerate(tqdm(self.scheduler.timesteps, desc="Sampling")):
130 | x = self.denoise_step(x, t)
131 |
132 | decoded_latent = self.decode_latent(x)
133 | T.ToPILImage()(decoded_latent[0]).save(f'{self.config["output_path"]}/output-{self.config["prompt"]}.png')
134 |
135 | return decoded_latent
136 |
137 |
138 | if __name__ == '__main__':
139 | parser = argparse.ArgumentParser()
140 | parser.add_argument('--config_path', type=str, default='config_pnp.yaml')
141 | opt = parser.parse_args()
142 | with open(opt.config_path, "r") as f:
143 | config = yaml.safe_load(f)
144 | os.makedirs(config["output_path"], exist_ok=True)
145 | with open(os.path.join(config["output_path"], "config.yaml"), "w") as f:
146 | yaml.dump(config, f)
147 |
148 | seed_everything(config["seed"])
149 | print(config)
150 | pnp = PNP(config)
151 | pnp.run_pnp()
--------------------------------------------------------------------------------
/pnp_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import random
4 | import numpy as np
5 |
6 | def seed_everything(seed):
7 | torch.manual_seed(seed)
8 | torch.cuda.manual_seed(seed)
9 | random.seed(seed)
10 | np.random.seed(seed)
11 |
12 | def register_time(model, t):
13 | conv_module = model.unet.up_blocks[1].resnets[1]
14 | setattr(conv_module, 't', t)
15 | down_res_dict = {0: [0, 1], 1: [0, 1], 2: [0, 1]}
16 | up_res_dict = {1: [0, 1, 2], 2: [0, 1, 2], 3: [0, 1, 2]}
17 | for res in up_res_dict:
18 | for block in up_res_dict[res]:
19 | module = model.unet.up_blocks[res].attentions[block].transformer_blocks[0].attn1
20 | setattr(module, 't', t)
21 | for res in down_res_dict:
22 | for block in down_res_dict[res]:
23 | module = model.unet.down_blocks[res].attentions[block].transformer_blocks[0].attn1
24 | setattr(module, 't', t)
25 | module = model.unet.mid_block.attentions[0].transformer_blocks[0].attn1
26 | setattr(module, 't', t)
27 |
28 |
29 | def load_source_latents_t(t, latents_path):
30 | latents_t_path = os.path.join(latents_path, f'noisy_latents_{t}.pt')
31 | assert os.path.exists(latents_t_path), f'Missing latents at t {t} path {latents_t_path}'
32 | latents = torch.load(latents_t_path)
33 | return latents
34 |
35 | def register_attention_control_efficient(model, injection_schedule):
36 | def sa_forward(self):
37 | # use the first module of to_out (the Linear projection) when it is a ModuleList
38 | if isinstance(self.to_out, torch.nn.ModuleList):
39 | to_out = self.to_out[0]
40 | else:
41 | to_out = self.to_out
42 |
43 | def forward(x, encoder_hidden_states=None, attention_mask=None):
44 | batch_size, sequence_length, dim = x.shape
45 | h = self.heads
46 |
47 | is_cross = encoder_hidden_states is not None
48 | encoder_hidden_states = encoder_hidden_states if is_cross else x
49 | if not is_cross and self.injection_schedule is not None and (
50 | self.t in self.injection_schedule or self.t == 1000):
51 | q = self.to_q(x)
52 | k = self.to_k(encoder_hidden_states)
53 |
54 | source_batch_size = int(q.shape[0] // 3)
55 | # inject unconditional
56 | q[source_batch_size:2 * source_batch_size] = q[:source_batch_size]
57 | k[source_batch_size:2 * source_batch_size] = k[:source_batch_size]
58 | # inject conditional
59 | q[2 * source_batch_size:] = q[:source_batch_size]
60 | k[2 * source_batch_size:] = k[:source_batch_size]
61 |
62 | q = self.head_to_batch_dim(q)
63 | k = self.head_to_batch_dim(k)
64 | else:
65 | q = self.to_q(x)
66 | k = self.to_k(encoder_hidden_states)
67 | q = self.head_to_batch_dim(q)
68 | k = self.head_to_batch_dim(k)
69 |
70 | v = self.to_v(encoder_hidden_states)
71 | v = self.head_to_batch_dim(v)
72 |
73 | sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
74 |
75 | if attention_mask is not None:
76 | attention_mask = attention_mask.reshape(batch_size, -1)
77 | max_neg_value = -torch.finfo(sim.dtype).max
78 | attention_mask = attention_mask[:, None, :].repeat(h, 1, 1)
79 | sim.masked_fill_(~attention_mask, max_neg_value)
80 |
81 | # attention, what we cannot get enough of
82 | attn = sim.softmax(dim=-1)
83 | out = torch.einsum("b i j, b j d -> b i d", attn, v)
84 | out = self.batch_to_head_dim(out)
85 |
86 | return to_out(out)
87 |
88 | return forward
89 |
90 | res_dict = {1: [1, 2], 2: [0, 1, 2], 3: [0, 1, 2]} # we are injecting attention in blocks 4 - 11 of the decoder, so not in the first block of the lowest resolution
91 | for res in res_dict:
92 | for block in res_dict[res]:
93 | module = model.unet.up_blocks[res].attentions[block].transformer_blocks[0].attn1
94 | module.forward = sa_forward(module)
95 | setattr(module, 'injection_schedule', injection_schedule)
96 |
97 |
98 | def register_conv_control_efficient(model, injection_schedule):
99 | def conv_forward(self):
100 | def forward(input_tensor, temb):
101 | hidden_states = input_tensor
102 |
103 | hidden_states = self.norm1(hidden_states)
104 | hidden_states = self.nonlinearity(hidden_states)
105 |
106 | if self.upsample is not None:
107 | # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
108 | if hidden_states.shape[0] >= 64:
109 | input_tensor = input_tensor.contiguous()
110 | hidden_states = hidden_states.contiguous()
111 | input_tensor = self.upsample(input_tensor)
112 | hidden_states = self.upsample(hidden_states)
113 | elif self.downsample is not None:
114 | input_tensor = self.downsample(input_tensor)
115 | hidden_states = self.downsample(hidden_states)
116 |
117 | hidden_states = self.conv1(hidden_states)
118 |
119 | if temb is not None:
120 | temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
121 |
122 | if temb is not None and self.time_embedding_norm == "default":
123 | hidden_states = hidden_states + temb
124 |
125 | hidden_states = self.norm2(hidden_states)
126 |
127 | if temb is not None and self.time_embedding_norm == "scale_shift":
128 | scale, shift = torch.chunk(temb, 2, dim=1)
129 | hidden_states = hidden_states * (1 + scale) + shift
130 |
131 | hidden_states = self.nonlinearity(hidden_states)
132 |
133 | hidden_states = self.dropout(hidden_states)
134 | hidden_states = self.conv2(hidden_states)
135 | if self.injection_schedule is not None and (self.t in self.injection_schedule or self.t == 1000):
136 | source_batch_size = int(hidden_states.shape[0] // 3)
137 | # inject unconditional
138 | hidden_states[source_batch_size:2 * source_batch_size] = hidden_states[:source_batch_size]
139 | # inject conditional
140 | hidden_states[2 * source_batch_size:] = hidden_states[:source_batch_size]
141 |
142 | if self.conv_shortcut is not None:
143 | input_tensor = self.conv_shortcut(input_tensor)
144 |
145 | output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
146 |
147 | return output_tensor
148 |
149 | return forward
150 |
151 | conv_module = model.unet.up_blocks[1].resnets[1]
152 | conv_module.forward = conv_forward(conv_module)
153 | setattr(conv_module, 'injection_schedule', injection_schedule)
--------------------------------------------------------------------------------
/preprocess.py:
--------------------------------------------------------------------------------
1 | from transformers import CLIPTextModel, CLIPTokenizer, logging
2 | from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler
3 |
4 | # suppress partial model loading warning
5 | logging.set_verbosity_error()
6 |
7 | import os
8 | from PIL import Image
9 | from tqdm import tqdm, trange
10 | import torch
11 | import torch.nn as nn
12 | import argparse
13 | from pathlib import Path
14 | from pnp_utils import *
15 | import torchvision.transforms as T
16 |
17 |
18 | def get_timesteps(scheduler, num_inference_steps, strength, device):
19 | # get the original timestep using init_timestep
20 | init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
21 |
22 | t_start = max(num_inference_steps - init_timestep, 0)
23 | timesteps = scheduler.timesteps[t_start:]
24 |
25 | return timesteps, num_inference_steps - t_start
26 |
27 |
28 | class Preprocess(nn.Module):
29 | def __init__(self, device, sd_version='2.0', hf_key=None):
30 | super().__init__()
31 |
32 | self.device = device
33 | self.sd_version = sd_version
34 | self.use_depth = False
35 |
36 | print(f'[INFO] loading stable diffusion...')
37 | if hf_key is not None:
38 | print(f'[INFO] using hugging face custom model key: {hf_key}')
39 | model_key = hf_key
40 | elif self.sd_version == '2.1':
41 | model_key = "stabilityai/stable-diffusion-2-1-base"
42 | elif self.sd_version == '2.0':
43 | model_key = "stabilityai/stable-diffusion-2-base"
44 | elif self.sd_version == '1.5':
45 | model_key = "runwayml/stable-diffusion-v1-5"
46 | elif self.sd_version == 'depth':
47 | model_key = "stabilityai/stable-diffusion-2-depth"
48 | self.use_depth = True
49 | else:
50 | raise ValueError(f'Stable-diffusion version {self.sd_version} not supported.')
51 |
52 | # Create model
53 | self.vae = AutoencoderKL.from_pretrained(model_key, subfolder="vae", revision="fp16",
54 | torch_dtype=torch.float16).to(self.device)
55 | self.tokenizer = CLIPTokenizer.from_pretrained(model_key, subfolder="tokenizer")
56 | self.text_encoder = CLIPTextModel.from_pretrained(model_key, subfolder="text_encoder", revision="fp16",
57 | torch_dtype=torch.float16).to(self.device)
58 | self.unet = UNet2DConditionModel.from_pretrained(model_key, subfolder="unet", revision="fp16",
59 | torch_dtype=torch.float16).to(self.device)
60 | self.scheduler = DDIMScheduler.from_pretrained(model_key, subfolder="scheduler")
61 | print(f'[INFO] loaded stable diffusion!')
62 |
63 | self.inversion_func = self.ddim_inversion
64 |
65 | @torch.no_grad()
66 | def get_text_embeds(self, prompt, negative_prompt, device="cuda"):
67 | text_input = self.tokenizer(prompt, padding='max_length', max_length=self.tokenizer.model_max_length,
68 | truncation=True, return_tensors='pt')
69 | text_embeddings = self.text_encoder(text_input.input_ids.to(device))[0]
70 | uncond_input = self.tokenizer(negative_prompt, padding='max_length', max_length=self.tokenizer.model_max_length,
71 | return_tensors='pt')
72 | uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(device))[0]
73 | text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
74 | return text_embeddings
75 |
76 | @torch.no_grad()
77 | def decode_latents(self, latents):
78 | with torch.autocast(device_type='cuda', dtype=torch.float32):
79 | latents = 1 / 0.18215 * latents
80 | imgs = self.vae.decode(latents).sample
81 | imgs = (imgs / 2 + 0.5).clamp(0, 1)
82 | return imgs
83 |
84 | def load_img(self, image_path):
85 | image_pil = T.Resize(512)(Image.open(image_path).convert("RGB"))
86 | image = T.ToTensor()(image_pil).unsqueeze(0).to(self.device)
87 | return image
88 |
89 | @torch.no_grad()
90 | def encode_imgs(self, imgs):
91 | with torch.autocast(device_type='cuda', dtype=torch.float32):
92 | imgs = 2 * imgs - 1
93 | posterior = self.vae.encode(imgs).latent_dist
94 | latents = posterior.mean * 0.18215
95 | return latents
96 |
97 | @torch.no_grad()
98 | def ddim_inversion(self, cond, latent, save_path, save_latents=True,
99 | timesteps_to_save=None):
100 | timesteps = reversed(self.scheduler.timesteps)
101 | with torch.autocast(device_type='cuda', dtype=torch.float32):
102 | for i, t in enumerate(tqdm(timesteps)):
103 | cond_batch = cond.repeat(latent.shape[0], 1, 1)
104 |
105 | alpha_prod_t = self.scheduler.alphas_cumprod[t]
106 | alpha_prod_t_prev = (
107 | self.scheduler.alphas_cumprod[timesteps[i - 1]]
108 | if i > 0 else self.scheduler.final_alpha_cumprod
109 | )
110 |
111 | mu = alpha_prod_t ** 0.5
112 | mu_prev = alpha_prod_t_prev ** 0.5
113 | sigma = (1 - alpha_prod_t) ** 0.5
114 | sigma_prev = (1 - alpha_prod_t_prev) ** 0.5
115 |
116 | eps = self.unet(latent, t, encoder_hidden_states=cond_batch).sample
117 |
118 | pred_x0 = (latent - sigma_prev * eps) / mu_prev
119 | latent = mu * pred_x0 + sigma * eps
120 | if save_latents:
121 | torch.save(latent, os.path.join(save_path, f'noisy_latents_{t}.pt'))
122 | torch.save(latent, os.path.join(save_path, f'noisy_latents_{t}.pt'))  # after the loop: always save the final noisy latent, the starting point for PnP sampling
123 | return latent
124 |
125 | @torch.no_grad()
126 | def ddim_sample(self, x, cond, save_path, save_latents=False, timesteps_to_save=None):
127 | timesteps = self.scheduler.timesteps
128 | with torch.autocast(device_type='cuda', dtype=torch.float32):
129 | for i, t in enumerate(tqdm(timesteps)):
130 | cond_batch = cond.repeat(x.shape[0], 1, 1)
131 | alpha_prod_t = self.scheduler.alphas_cumprod[t]
132 | alpha_prod_t_prev = (
133 | self.scheduler.alphas_cumprod[timesteps[i + 1]]
134 | if i < len(timesteps) - 1
135 | else self.scheduler.final_alpha_cumprod
136 | )
137 | mu = alpha_prod_t ** 0.5
138 | sigma = (1 - alpha_prod_t) ** 0.5
139 | mu_prev = alpha_prod_t_prev ** 0.5
140 | sigma_prev = (1 - alpha_prod_t_prev) ** 0.5
141 |
142 | eps = self.unet(x, t, encoder_hidden_states=cond_batch).sample
143 |
144 | pred_x0 = (x - sigma * eps) / mu
145 | x = mu_prev * pred_x0 + sigma_prev * eps
146 |
147 | if save_latents:
148 | torch.save(x, os.path.join(save_path, f'noisy_latents_{t}.pt'))
149 | return x
150 |
151 | @torch.no_grad()
152 | def extract_latents(self, num_steps, data_path, save_path, timesteps_to_save,
153 | inversion_prompt='', extract_reverse=False):
154 | self.scheduler.set_timesteps(num_steps)
155 |
156 | cond = self.get_text_embeds(inversion_prompt, "")[1].unsqueeze(0)
157 | image = self.load_img(data_path)
158 | latent = self.encode_imgs(image)
159 |
160 | inverted_x = self.inversion_func(cond, latent, save_path, save_latents=not extract_reverse,
161 | timesteps_to_save=timesteps_to_save)
162 | latent_reconstruction = self.ddim_sample(inverted_x, cond, save_path, save_latents=extract_reverse,
163 | timesteps_to_save=timesteps_to_save)
164 | rgb_reconstruction = self.decode_latents(latent_reconstruction)
165 |
166 | return rgb_reconstruction # , latent_reconstruction
167 |
168 |
169 | def run(opt):
170 | # timesteps to save
171 | if opt.sd_version == '2.1':
172 | model_key = "stabilityai/stable-diffusion-2-1-base"
173 | elif opt.sd_version == '2.0':
174 | model_key = "stabilityai/stable-diffusion-2-base"
175 | elif opt.sd_version == '1.5':
176 | model_key = "runwayml/stable-diffusion-v1-5"
177 | elif opt.sd_version == 'depth':
178 | model_key = "stabilityai/stable-diffusion-2-depth"
179 | toy_scheduler = DDIMScheduler.from_pretrained(model_key, subfolder="scheduler")
180 | toy_scheduler.set_timesteps(opt.save_steps)
181 | timesteps_to_save, num_inference_steps = get_timesteps(toy_scheduler, num_inference_steps=opt.save_steps,
182 | strength=1.0,
183 | device=device)
184 |
185 | seed_everything(opt.seed)
186 |
187 | extraction_path_prefix = "_reverse" if opt.extract_reverse else "_forward"
188 | save_path = os.path.join(opt.save_dir + extraction_path_prefix, os.path.splitext(os.path.basename(opt.data_path))[0])
189 | os.makedirs(save_path, exist_ok=True)
190 |
191 | model = Preprocess(device, sd_version=opt.sd_version, hf_key=None)
192 | recon_image = model.extract_latents(data_path=opt.data_path,
193 | num_steps=opt.steps,
194 | save_path=save_path,
195 | timesteps_to_save=timesteps_to_save,
196 | inversion_prompt=opt.inversion_prompt,
197 | extract_reverse=opt.extract_reverse)
198 |
199 | T.ToPILImage()(recon_image[0]).save(os.path.join(save_path, 'recon.jpg'))
200 |
201 |
202 | if __name__ == "__main__":
203 | device = 'cuda'
204 | parser = argparse.ArgumentParser()
205 | parser.add_argument('--data_path', type=str,
206 | default='data/horse.jpg')
207 | parser.add_argument('--save_dir', type=str, default='latents')
208 | parser.add_argument('--sd_version', type=str, default='2.1', choices=['1.5', '2.0', '2.1', 'depth'],
209 | help="stable diffusion version")
210 | parser.add_argument('--seed', type=int, default=1)
211 | parser.add_argument('--steps', type=int, default=999)
212 | parser.add_argument('--save-steps', type=int, default=1000)
213 | parser.add_argument('--inversion_prompt', type=str, default='')
214 | parser.add_argument('--extract-reverse', default=False, action='store_true', help="extract features during the denoising process")
215 | opt = parser.parse_args()
216 | run(opt)
217 |
--------------------------------------------------------------------------------