├── .gitignore ├── LICENSE ├── README.md ├── ddm_inversion ├── ddim_inversion.py ├── inversion_utils.py └── utils.py ├── example_images ├── horse_mud.jpg └── sketch_cat.jpg ├── imgs └── teaser.jpg ├── main_run.py ├── prompt_to_prompt ├── LICENSE ├── README.md ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-38.pyc │ ├── ddim_inversion.cpython-38.pyc │ ├── inversion_utils.cpython-38.pyc │ ├── inversion_utils_vova.cpython-38.pyc │ ├── ptp_classes.cpython-38.pyc │ ├── ptp_utils.cpython-38.pyc │ ├── seq_aligner.cpython-38.pyc │ └── utils.cpython-38.pyc ├── contributing.md ├── data │ └── horse.jpg ├── docs │ ├── null_text_teaser.png │ └── teaser.png ├── example_images │ └── gnochi_mirror.jpeg ├── ptp_classes.py ├── ptp_utils.py ├── requirements.txt └── seq_aligner.py ├── requirements.txt └── test.yaml /.gitignore: -------------------------------------------------------------------------------- 1 | .venv 2 | results_100/ 3 | prompt_to_prompt/__pycache__/utils.cpython-38.pyc 4 | prompt_to_prompt/__pycache__/* 5 | *.pyc 6 | results/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 inbarhub 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | [![Python 3.8](https://img.shields.io/badge/python-3.812+-blue)](https://www.python.org/downloads/release/python-38/) 3 | [![torch](https://img.shields.io/badge/torch-2.0.0+-green)](https://pytorch.org/) 4 | 5 | 6 | # DDPM inversion, CVPR 2024 7 | 8 | [Project page](https://inbarhub.github.io/DDPM_inversion/) | [Arxiv](https://arxiv.org/abs/2304.06140) | [Supplementary materials](https://inbarhub.github.io/DDPM_inversion/resources/inversion_supp.pdf) | [Hugging Face Demo](https://huggingface.co/spaces/LinoyTsaban/edit_friendly_ddpm_inversion) 9 | ### Official pytorch implementation of the paper:
"An Edit Friendly DDPM Noise Space: Inversion and Manipulations" 10 | #### Inbar Huberman-Spiegelglas, Vladimir Kulikov and Tomer Michaeli 11 |
12 | 13 | ![](imgs/teaser.jpg) 14 | Our inversion can be used for text-based **editing of real images**, either by itself or in combination with other editing methods. 15 | Due to the stochastic nature of our method, we can generate **diverse outputs**, a feature that is not naturally available with methods relying on DDIM inversion. 16 | 17 | In this repository, we support editing using our inversion, prompt-to-prompt (p2p) + our inversion, ddim, or [p2p](https://github.com/google/prompt-to-prompt) (with ddim inversion).<br>
18 | **our inversion**: our ddpm inversion followed by generating an image conditioned on the target prompt. 19 | 20 | **prompt-to-prompt (p2p) + our inversion**: p2p method using our ddpm inversion. 21 | 22 | **ddim**: ddim inversion followed by generating an image conditioned on the target prompt. 23 | 24 | **p2p**: p2p method using ddim inversion (original paper). 25 | 26 | ## Table of Contents 27 | * [Requirements](#Requirements) 28 | * [Repository Structure](#Repository-Structure) 29 | * [Algorithm Inputs and Parameters](#Algorithm-Inputs-and-Parameters) 30 | * [Usage Example](#Usage-Example) 31 | 32 | * [Citation](#Citation) 33 | 34 | ## Requirements 35 | 36 | ``` 37 | python -m pip install -r requirements.txt 38 | ``` 39 | This code was tested with python 3.8 and torch 2.0.0. 40 | 41 | ## Repository Structure 42 | ``` 43 | ├── ddm_inversion - folder contains inversions in order to work on real images: ddim inversion as well as ddpm inversion (our method). 44 | ├── example_images - folder of input images to be edited 45 | ├── imgs - images used in this repository readme.md file 46 | ├── prompt_to_prompt - p2p code 47 | ├── main_run.py - main python file for real image editing 48 | └── test.yaml - yaml file contains images and prompts to test on 49 | ``` 50 | 51 | A folder named 'results' will be automatically created and all the results will be saved to this folder. We also add a timestamp to the saved images in this folder. 52 | 53 | ## Algorithm Inputs and Parameters 54 | Method's inputs: 55 | ``` 56 | init_img - the path to the input images 57 | source_prompt - a prompt describing the input image 58 | target_prompts - the edit prompt (creates several images if multiple prompts are given) 59 | ``` 60 | These three inputs are supplied through a YAML file (please use the provided 'test.yaml' file as a reference). 61 | 62 |
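For concreteness, here is a minimal sketch of how one such entry looks to `main_run.py`. The keys mirror the three inputs listed above, but the prompts are illustrative placeholders rather than the contents of the provided test.yaml; note that `main_run.py` prepends '.' to `init_img`, so the path is written relative to the repository root with a leading slash.

```
import yaml

# Illustrative entry only -- the keys follow the three inputs described above;
# the prompts are made-up placeholders, not the ones shipped in test.yaml.
entries = yaml.safe_load("""
- init_img: /example_images/horse_mud.jpg
  source_prompt: a photo of a horse standing in the mud
  target_prompts:
    - a photo of a horse standing in the snow
    - a photo of a zebra standing in the mud
""")

for entry in entries:
    print(entry['init_img'])               # path to the input image
    print(entry.get('source_prompt', ''))  # optional; defaults to ""
    print(entry['target_prompts'])         # one edited image per target prompt
```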
63 | Method's parameters are: 64 | 65 | ``` 66 | skip - controlling the adherence to the input image 67 | cfg_tar - classifier free guidance strengths 68 | ``` 69 | These two parameters have default values, as descibed in the paper. 70 | 71 | ## Usage Example 72 | ``` 73 | python3 main_run.py --mode="our_inv" --dataset_yaml="test.yaml" --skip=36 --cfg_tar=15 74 | python3 main_run.py --mode="p2pinv" --dataset_yaml="test.yaml" --skip=12 --cfg_tar=9 75 | 76 | ``` 77 | The ```mode``` argument can also be: ```ddim``` or ```p2p```. 78 | 79 | In ```our_inv``` and ```p2pinv``` modes we suggest to play around with ```skip``` in the range [0,40] and ```cfg_tar``` in the range [7,18]. 80 | 81 | **p2pinv and p2p**: 82 | Note that you can play with the cross-and self-attention via ```--xa``` and ```--sa``` arguments. We suggest to set them to (0.6,0.2) and (0.8,0.4) for p2pinv and p2p respectively. 83 | 84 | **ddim and p2p**: 85 | ```skip``` is overwritten to be 0. 86 | 87 | 95 | 96 | You can edit the test.yaml file to load your image and choose the desired prompts. 97 | 98 | 103 | 104 | ## Citation 105 | If you use this code for your research, please cite our paper: 106 | ``` 107 | @inproceedings{huberman2024edit, 108 | title={An edit friendly {DDPM} noise space: Inversion and manipulations}, 109 | author={Huberman-Spiegelglas, Inbar and Kulikov, Vladimir and Michaeli, Tomer}, 110 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 111 | pages={12469--12478}, 112 | year={2024} 113 | } 114 | ``` -------------------------------------------------------------------------------- /ddm_inversion/ddim_inversion.py: -------------------------------------------------------------------------------- 1 | 2 | from ddm_inversion.inversion_utils import encode_text 3 | from typing import Union 4 | import torch 5 | import numpy as np 6 | from tqdm import tqdm 7 | 8 | def next_step(model, model_output: Union[torch.FloatTensor, np.ndarray], timestep: int, sample: Union[torch.FloatTensor, np.ndarray]): 9 | timestep, next_timestep = min(timestep - model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps, 999), timestep 10 | alpha_prod_t = model.scheduler.alphas_cumprod[timestep] if timestep >= 0 else model.scheduler.final_alpha_cumprod 11 | alpha_prod_t_next = model.scheduler.alphas_cumprod[next_timestep] 12 | beta_prod_t = 1 - alpha_prod_t 13 | next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5 14 | next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output 15 | next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction 16 | return next_sample 17 | 18 | def get_noise_pred(model, latent, t, context, cfg_scale): 19 | latents_input = torch.cat([latent] * 2) 20 | noise_pred = model.unet(latents_input, t, encoder_hidden_states=context)["sample"] 21 | noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2) 22 | noise_pred = noise_pred_uncond + cfg_scale * (noise_prediction_text - noise_pred_uncond) 23 | # latents = next_step(model, noise_pred, t, latent) 24 | return noise_pred 25 | 26 | @torch.no_grad() 27 | def ddim_loop(model, w0, prompt, cfg_scale): 28 | # uncond_embeddings, cond_embeddings = self.context.chunk(2) 29 | # all_latent = [latent] 30 | text_embedding = encode_text(model, prompt) 31 | uncond_embedding = encode_text(model, "") 32 | context = torch.cat([uncond_embedding, text_embedding]) 33 | latent = w0.clone().detach() 34 | for i in 
tqdm(range(model.scheduler.num_inference_steps)): 35 | t = model.scheduler.timesteps[len(model.scheduler.timesteps) - i - 1] 36 | noise_pred = get_noise_pred(model, latent, t, context, cfg_scale) 37 | latent = next_step(model, noise_pred, t, latent) 38 | # all_latent.append(latent) 39 | return latent 40 | 41 | @torch.no_grad() 42 | def ddim_inversion(model, w0, prompt, cfg_scale): 43 | wT = ddim_loop(model, w0, prompt, cfg_scale) 44 | return wT -------------------------------------------------------------------------------- /ddm_inversion/inversion_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | from tqdm import tqdm 4 | 5 | def load_real_image(folder = "data/", img_name = None, idx = 0, img_size=512, device='cuda'): 6 | from ddm_inversion.utils import pil_to_tensor 7 | from PIL import Image 8 | from glob import glob 9 | if img_name is not None: 10 | path = os.path.join(folder, img_name) 11 | else: 12 | path = glob(folder + "*")[idx] 13 | 14 | img = Image.open(path).resize((img_size, 15 | img_size)) 16 | 17 | img = pil_to_tensor(img).to(device) 18 | 19 | if img.shape[1]== 4: 20 | img = img[:,:3,:,:] 21 | return img 22 | 23 | def mu_tilde(model, xt,x0, timestep): 24 | "mu_tilde(x_t, x_0) DDPM paper eq. 7" 25 | prev_timestep = timestep - model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps 26 | alpha_prod_t_prev = model.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else model.scheduler.final_alpha_cumprod 27 | alpha_t = model.scheduler.alphas[timestep] 28 | beta_t = 1 - alpha_t 29 | alpha_bar = model.scheduler.alphas_cumprod[timestep] 30 | return ((alpha_prod_t_prev ** 0.5 * beta_t) / (1-alpha_bar)) * x0 + ((alpha_t**0.5 *(1-alpha_prod_t_prev)) / (1- alpha_bar))*xt 31 | 32 | def sample_xts_from_x0(model, x0, num_inference_steps=50): 33 | """ 34 | Samples from P(x_1:T|x_0) 35 | """ 36 | # torch.manual_seed(43256465436) 37 | alpha_bar = model.scheduler.alphas_cumprod 38 | sqrt_one_minus_alpha_bar = (1-alpha_bar) ** 0.5 39 | alphas = model.scheduler.alphas 40 | betas = 1 - alphas 41 | variance_noise_shape = ( 42 | num_inference_steps, 43 | model.unet.in_channels, 44 | model.unet.sample_size, 45 | model.unet.sample_size) 46 | 47 | timesteps = model.scheduler.timesteps.to(model.device) 48 | t_to_idx = {int(v):k for k,v in enumerate(timesteps)} 49 | xts = torch.zeros((num_inference_steps+1,model.unet.in_channels, model.unet.sample_size, model.unet.sample_size)).to(x0.device) 50 | xts[0] = x0 51 | for t in reversed(timesteps): 52 | idx = num_inference_steps-t_to_idx[int(t)] 53 | xts[idx] = x0 * (alpha_bar[t] ** 0.5) + torch.randn_like(x0) * sqrt_one_minus_alpha_bar[t] 54 | 55 | 56 | return xts 57 | 58 | 59 | def encode_text(model, prompts): 60 | text_input = model.tokenizer( 61 | prompts, 62 | padding="max_length", 63 | max_length=model.tokenizer.model_max_length, 64 | truncation=True, 65 | return_tensors="pt", 66 | ) 67 | with torch.no_grad(): 68 | text_encoding = model.text_encoder(text_input.input_ids.to(model.device))[0] 69 | return text_encoding 70 | 71 | def forward_step(model, model_output, timestep, sample): 72 | next_timestep = min(model.scheduler.config.num_train_timesteps - 2, 73 | timestep + model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps) 74 | 75 | # 2. 
compute alphas, betas 76 | alpha_prod_t = model.scheduler.alphas_cumprod[timestep] 77 | # alpha_prod_t_next = self.scheduler.alphas_cumprod[next_timestep] if next_ltimestep >= 0 else self.scheduler.final_alpha_cumprod 78 | 79 | beta_prod_t = 1 - alpha_prod_t 80 | 81 | # 3. compute predicted original sample from predicted noise also called 82 | # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf 83 | pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) 84 | 85 | # 5. TODO: simple noising implementatiom 86 | next_sample = model.scheduler.add_noise(pred_original_sample, 87 | model_output, 88 | torch.LongTensor([next_timestep])) 89 | return next_sample 90 | 91 | 92 | def get_variance(model, timestep): #, prev_timestep): 93 | prev_timestep = timestep - model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps 94 | alpha_prod_t = model.scheduler.alphas_cumprod[timestep] 95 | alpha_prod_t_prev = model.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else model.scheduler.final_alpha_cumprod 96 | beta_prod_t = 1 - alpha_prod_t 97 | beta_prod_t_prev = 1 - alpha_prod_t_prev 98 | variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev) 99 | return variance 100 | 101 | def inversion_forward_process(model, x0, 102 | etas = None, 103 | prog_bar = False, 104 | prompt = "", 105 | cfg_scale = 3.5, 106 | num_inference_steps=50, eps = None): 107 | 108 | if not prompt=="": 109 | text_embeddings = encode_text(model, prompt) 110 | uncond_embedding = encode_text(model, "") 111 | timesteps = model.scheduler.timesteps.to(model.device) 112 | variance_noise_shape = ( 113 | num_inference_steps, 114 | model.unet.in_channels, 115 | model.unet.sample_size, 116 | model.unet.sample_size) 117 | if etas is None or (type(etas) in [int, float] and etas == 0): 118 | eta_is_zero = True 119 | zs = None 120 | else: 121 | eta_is_zero = False 122 | if type(etas) in [int, float]: etas = [etas]*model.scheduler.num_inference_steps 123 | xts = sample_xts_from_x0(model, x0, num_inference_steps=num_inference_steps) 124 | alpha_bar = model.scheduler.alphas_cumprod 125 | zs = torch.zeros(size=variance_noise_shape, device=model.device) 126 | t_to_idx = {int(v):k for k,v in enumerate(timesteps)} 127 | xt = x0 128 | # op = tqdm(reversed(timesteps)) if prog_bar else reversed(timesteps) 129 | op = tqdm(timesteps) if prog_bar else timesteps 130 | 131 | for t in op: 132 | # idx = t_to_idx[int(t)] 133 | idx = num_inference_steps-t_to_idx[int(t)]-1 134 | # 1. predict noise residual 135 | if not eta_is_zero: 136 | xt = xts[idx+1][None] 137 | # xt = xts_cycle[idx+1][None] 138 | 139 | with torch.no_grad(): 140 | out = model.unet.forward(xt, timestep = t, encoder_hidden_states = uncond_embedding) 141 | if not prompt=="": 142 | cond_out = model.unet.forward(xt, timestep=t, encoder_hidden_states = text_embeddings) 143 | 144 | if not prompt=="": 145 | ## classifier free guidance 146 | noise_pred = out.sample + cfg_scale * (cond_out.sample - out.sample) 147 | else: 148 | noise_pred = out.sample 149 | if eta_is_zero: 150 | # 2. 
compute more noisy image and set x_t -> x_t+1 151 | xt = forward_step(model, noise_pred, t, xt) 152 | 153 | else: 154 | # xtm1 = xts[idx+1][None] 155 | xtm1 = xts[idx][None] 156 | # pred of x0 157 | pred_original_sample = (xt - (1-alpha_bar[t]) ** 0.5 * noise_pred ) / alpha_bar[t] ** 0.5 158 | 159 | # direction to xt 160 | prev_timestep = t - model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps 161 | alpha_prod_t_prev = model.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else model.scheduler.final_alpha_cumprod 162 | 163 | variance = get_variance(model, t) 164 | pred_sample_direction = (1 - alpha_prod_t_prev - etas[idx] * variance ) ** (0.5) * noise_pred 165 | 166 | mu_xt = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction 167 | 168 | z = (xtm1 - mu_xt ) / ( etas[idx] * variance ** 0.5 ) 169 | zs[idx] = z 170 | 171 | # correction to avoid error accumulation 172 | xtm1 = mu_xt + ( etas[idx] * variance ** 0.5 )*z 173 | xts[idx] = xtm1 174 | 175 | if not zs is None: 176 | zs[0] = torch.zeros_like(zs[0]) 177 | 178 | return xt, zs, xts 179 | 180 | 181 | def reverse_step(model, model_output, timestep, sample, eta = 0, variance_noise=None): 182 | # 1. get previous step value (=t-1) 183 | prev_timestep = timestep - model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps 184 | # 2. compute alphas, betas 185 | alpha_prod_t = model.scheduler.alphas_cumprod[timestep] 186 | alpha_prod_t_prev = model.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else model.scheduler.final_alpha_cumprod 187 | beta_prod_t = 1 - alpha_prod_t 188 | # 3. compute predicted original sample from predicted noise also called 189 | # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf 190 | pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) 191 | # 5. compute variance: "sigma_t(η)" -> see formula (16) 192 | # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) 193 | # variance = self.scheduler._get_variance(timestep, prev_timestep) 194 | variance = get_variance(model, timestep) #, prev_timestep) 195 | std_dev_t = eta * variance ** (0.5) 196 | # Take care of asymetric reverse process (asyrp) 197 | model_output_direction = model_output 198 | # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf 199 | # pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output_direction 200 | pred_sample_direction = (1 - alpha_prod_t_prev - eta * variance) ** (0.5) * model_output_direction 201 | # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf 202 | prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction 203 | # 8. 
Add noice if eta > 0 204 | if eta > 0: 205 | if variance_noise is None: 206 | variance_noise = torch.randn(model_output.shape, device=model.device) 207 | sigma_z = eta * variance ** (0.5) * variance_noise 208 | prev_sample = prev_sample + sigma_z 209 | 210 | return prev_sample 211 | 212 | def inversion_reverse_process(model, 213 | xT, 214 | etas = 0, 215 | prompts = "", 216 | cfg_scales = None, 217 | prog_bar = False, 218 | zs = None, 219 | controller=None, 220 | asyrp = False): 221 | 222 | batch_size = len(prompts) 223 | 224 | cfg_scales_tensor = torch.Tensor(cfg_scales).view(-1,1,1,1).to(model.device) 225 | 226 | text_embeddings = encode_text(model, prompts) 227 | uncond_embedding = encode_text(model, [""] * batch_size) 228 | 229 | if etas is None: etas = 0 230 | if type(etas) in [int, float]: etas = [etas]*model.scheduler.num_inference_steps 231 | assert len(etas) == model.scheduler.num_inference_steps 232 | timesteps = model.scheduler.timesteps.to(model.device) 233 | 234 | xt = xT.expand(batch_size, -1, -1, -1) 235 | op = tqdm(timesteps[-zs.shape[0]:]) if prog_bar else timesteps[-zs.shape[0]:] 236 | 237 | t_to_idx = {int(v):k for k,v in enumerate(timesteps[-zs.shape[0]:])} 238 | 239 | for t in op: 240 | idx = model.scheduler.num_inference_steps-t_to_idx[int(t)]-(model.scheduler.num_inference_steps-zs.shape[0]+1) 241 | ## Unconditional embedding 242 | with torch.no_grad(): 243 | uncond_out = model.unet.forward(xt, timestep = t, 244 | encoder_hidden_states = uncond_embedding) 245 | 246 | ## Conditional embedding 247 | if prompts: 248 | with torch.no_grad(): 249 | cond_out = model.unet.forward(xt, timestep = t, 250 | encoder_hidden_states = text_embeddings) 251 | 252 | 253 | z = zs[idx] if not zs is None else None 254 | z = z.expand(batch_size, -1, -1, -1) 255 | if prompts: 256 | ## classifier free guidance 257 | noise_pred = uncond_out.sample + cfg_scales_tensor * (cond_out.sample - uncond_out.sample) 258 | else: 259 | noise_pred = uncond_out.sample 260 | # 2. 
compute less noisy image and set x_t -> x_t-1 261 | xt = reverse_step(model, noise_pred, t, xt, eta = etas[idx], variance_noise = z) 262 | if controller is not None: 263 | xt = controller.step_callback(xt) 264 | return xt, zs 265 | 266 | 267 | -------------------------------------------------------------------------------- /ddm_inversion/utils.py: -------------------------------------------------------------------------------- 1 | import PIL 2 | from PIL import Image, ImageDraw ,ImageFont 3 | from matplotlib import pyplot as plt 4 | import torchvision.transforms as T 5 | import os 6 | import torch 7 | import yaml 8 | 9 | def show_torch_img(img): 10 | img = to_np_image(img) 11 | plt.imshow(img) 12 | plt.axis("off") 13 | 14 | def to_np_image(all_images): 15 | all_images = (all_images.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8).cpu().numpy()[0] 16 | return all_images 17 | 18 | def tensor_to_pil(tensor_imgs): 19 | if type(tensor_imgs) == list: 20 | tensor_imgs = torch.cat(tensor_imgs) 21 | tensor_imgs = (tensor_imgs / 2 + 0.5).clamp(0, 1) 22 | to_pil = T.ToPILImage() 23 | pil_imgs = [to_pil(img) for img in tensor_imgs] 24 | return pil_imgs 25 | 26 | def pil_to_tensor(pil_imgs): 27 | to_torch = T.ToTensor() 28 | if type(pil_imgs) == PIL.Image.Image: 29 | tensor_imgs = to_torch(pil_imgs).unsqueeze(0)*2-1 30 | elif type(pil_imgs) == list: 31 | tensor_imgs = torch.cat([to_torch(pil_imgs).unsqueeze(0)*2-1 for img in pil_imgs]).to(device) 32 | else: 33 | raise Exception("Input need to be PIL.Image or list of PIL.Image") 34 | return tensor_imgs 35 | 36 | 37 | ## TODO implement this 38 | # n = 10 39 | # num_rows = 4 40 | # num_col = n // num_rows 41 | # num_col = num_col + 1 if n % num_rows else num_col 42 | # num_col 43 | def add_margin(pil_img, top = 0, right = 0, bottom = 0, 44 | left = 0, color = (255,255,255)): 45 | width, height = pil_img.size 46 | new_width = width + right + left 47 | new_height = height + top + bottom 48 | result = Image.new(pil_img.mode, (new_width, new_height), color) 49 | 50 | result.paste(pil_img, (left, top)) 51 | return result 52 | 53 | def image_grid(imgs, rows = 1, cols = None, 54 | size = None, 55 | titles = None, text_pos = (0, 0)): 56 | if type(imgs) == list and type(imgs[0]) == torch.Tensor: 57 | imgs = torch.cat(imgs) 58 | if type(imgs) == torch.Tensor: 59 | imgs = tensor_to_pil(imgs) 60 | 61 | if not size is None: 62 | imgs = [img.resize((size,size)) for img in imgs] 63 | if cols is None: 64 | cols = len(imgs) 65 | assert len(imgs) >= rows*cols 66 | 67 | top=20 68 | w, h = imgs[0].size 69 | delta = 0 70 | if len(imgs)> 1 and not imgs[1].size[1] == h: 71 | delta = top 72 | h = imgs[1].size[1] 73 | if not titles is None: 74 | font = ImageFont.truetype("/usr/share/fonts/truetype/freefont/FreeMono.ttf", 75 | size = 20, encoding="unic") 76 | h = top + h 77 | grid = Image.new('RGB', size=(cols*w, rows*h+delta)) 78 | for i, img in enumerate(imgs): 79 | 80 | if not titles is None: 81 | img = add_margin(img, top = top, bottom = 0,left=0) 82 | draw = ImageDraw.Draw(img) 83 | draw.text(text_pos, titles[i],(0,0,0), 84 | font = font) 85 | if not delta == 0 and i > 0: 86 | grid.paste(img, box=(i%cols*w, i//cols*h+delta)) 87 | else: 88 | grid.paste(img, box=(i%cols*w, i//cols*h)) 89 | 90 | return grid 91 | 92 | 93 | """ 94 | input_folder - dataset folder 95 | """ 96 | def load_dataset(input_folder): 97 | # full_file_names = glob.glob(input_folder) 98 | # class_names = [x[0] for x in os.walk(input_folder)] 99 | class_names = next(os.walk(input_folder))[1] 
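# class_names holds the immediate sub-folders of input_folder (one folder per class); hidden entries are filtered out below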
100 | class_names[:] = [d for d in class_names if not d[0] == '.'] 101 | file_names=[] 102 | for class_name in class_names: 103 | cur_path = os.path.join(input_folder, class_name) 104 | filenames = next(os.walk(cur_path), (None, None, []))[2] 105 | filenames = [f for f in filenames if not f[0] == '.'] 106 | file_names.append(filenames) 107 | return class_names, file_names 108 | 109 | 110 | def dataset_from_yaml(yaml_location): 111 | with open(yaml_location, 'r') as stream: 112 | data_loaded = yaml.safe_load(stream) 113 | 114 | return data_loaded -------------------------------------------------------------------------------- /example_images/horse_mud.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/inbarhub/DDPM_inversion/58fc881772d34f0c24be9e34725d213731aa009d/example_images/horse_mud.jpg -------------------------------------------------------------------------------- /example_images/sketch_cat.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/inbarhub/DDPM_inversion/58fc881772d34f0c24be9e34725d213731aa009d/example_images/sketch_cat.jpg -------------------------------------------------------------------------------- /imgs/teaser.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/inbarhub/DDPM_inversion/58fc881772d34f0c24be9e34725d213731aa009d/imgs/teaser.jpg -------------------------------------------------------------------------------- /main_run.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from diffusers import StableDiffusionPipeline 3 | from diffusers import DDIMScheduler 4 | import os 5 | from prompt_to_prompt.ptp_classes import AttentionStore, AttentionReplace, AttentionRefine, EmptyControl,load_512 6 | from prompt_to_prompt.ptp_utils import register_attention_control, text2image_ldm_stable, view_images 7 | from ddm_inversion.inversion_utils import inversion_forward_process, inversion_reverse_process 8 | from ddm_inversion.utils import image_grid,dataset_from_yaml 9 | 10 | from torch import autocast, inference_mode 11 | from ddm_inversion.ddim_inversion import ddim_inversion 12 | 13 | import calendar 14 | import time 15 | 16 | if __name__ == "__main__": 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument("--device_num", type=int, default=0) 19 | parser.add_argument("--cfg_src", type=float, default=3.5) 20 | parser.add_argument("--cfg_tar", type=float, default=15) 21 | parser.add_argument("--num_diffusion_steps", type=int, default=100) 22 | parser.add_argument("--dataset_yaml", default="test.yaml") 23 | parser.add_argument("--eta", type=float, default=1) 24 | parser.add_argument("--mode", default="our_inv", help="modes: our_inv,p2pinv,p2pddim,ddim") 25 | parser.add_argument("--skip", type=int, default=36) 26 | parser.add_argument("--xa", type=float, default=0.6) 27 | parser.add_argument("--sa", type=float, default=0.2) 28 | 29 | args = parser.parse_args() 30 | full_data = dataset_from_yaml(args.dataset_yaml) 31 | 32 | # create scheduler 33 | # load diffusion model 34 | model_id = "CompVis/stable-diffusion-v1-4" 35 | # model_id = "stable_diff_local" # load local save of model (for internet problems) 36 | 37 | device = f"cuda:{args.device_num}" 38 | 39 | cfg_scale_src = args.cfg_src 40 | cfg_scale_tar_list = [args.cfg_tar] 41 | eta = args.eta # = 1 42 | skip_zs = [args.skip] 43 | xa_sa_string = 
f'_xa_{args.xa}_sa{args.sa}_' if args.mode=='p2pinv' else '_' 44 | 45 | current_GMT = time.gmtime() 46 | time_stamp = calendar.timegm(current_GMT) 47 | 48 | # load/reload model: 49 | ldm_stable = StableDiffusionPipeline.from_pretrained(model_id).to(device) 50 | 51 | for i in range(len(full_data)): 52 | current_image_data = full_data[i] 53 | image_path = current_image_data['init_img'] 54 | image_path = '.' + image_path 55 | image_folder = image_path.split('/')[1] # after '.' 56 | prompt_src = current_image_data.get('source_prompt', "") # default empty string 57 | prompt_tar_list = current_image_data['target_prompts'] 58 | 59 | if args.mode=="p2pddim" or args.mode=="ddim": 60 | scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False) 61 | ldm_stable.scheduler = scheduler 62 | else: 63 | ldm_stable.scheduler = DDIMScheduler.from_config(model_id, subfolder = "scheduler") 64 | 65 | ldm_stable.scheduler.set_timesteps(args.num_diffusion_steps) 66 | 67 | # load image 68 | offsets=(0,0,0,0) 69 | x0 = load_512(image_path, *offsets, device) 70 | 71 | # vae encode image 72 | with autocast("cuda"), inference_mode(): 73 | w0 = (ldm_stable.vae.encode(x0).latent_dist.mode() * 0.18215).float() 74 | 75 | # find Zs and wts - forward process 76 | if args.mode=="p2pddim" or args.mode=="ddim": 77 | wT = ddim_inversion(ldm_stable, w0, prompt_src, cfg_scale_src) 78 | else: 79 | wt, zs, wts = inversion_forward_process(ldm_stable, w0, etas=eta, prompt=prompt_src, cfg_scale=cfg_scale_src, prog_bar=True, num_inference_steps=args.num_diffusion_steps) 80 | 81 | # iterate over decoder prompts 82 | for k in range(len(prompt_tar_list)): 83 | prompt_tar = prompt_tar_list[k] 84 | save_path = os.path.join(f'./results/', args.mode+xa_sa_string+str(time_stamp), image_path.split(sep='.')[0], 'src_' + prompt_src.replace(" ", "_"), 'dec_' + prompt_tar.replace(" ", "_")) 85 | os.makedirs(save_path, exist_ok=True) 86 | 87 | # Check if number of words in encoder and decoder text are equal 88 | src_tar_len_eq = (len(prompt_src.split(" ")) == len(prompt_tar.split(" "))) 89 | 90 | for cfg_scale_tar in cfg_scale_tar_list: 91 | for skip in skip_zs: 92 | if args.mode=="our_inv": 93 | # reverse process (via Zs and wT) 94 | controller = AttentionStore() 95 | register_attention_control(ldm_stable, controller) 96 | w0, _ = inversion_reverse_process(ldm_stable, xT=wts[args.num_diffusion_steps-skip], etas=eta, prompts=[prompt_tar], cfg_scales=[cfg_scale_tar], prog_bar=True, zs=zs[:(args.num_diffusion_steps-skip)], controller=controller) 97 | 98 | elif args.mode=="p2pinv": 99 | # inversion with attention replace 100 | cfg_scale_list = [cfg_scale_src, cfg_scale_tar] 101 | prompts = [prompt_src, prompt_tar] 102 | if src_tar_len_eq: 103 | controller = AttentionReplace(prompts, args.num_diffusion_steps, cross_replace_steps=args.xa, self_replace_steps=args.sa, model=ldm_stable) 104 | else: 105 | # Should use Refine for target prompts with different number of tokens 106 | controller = AttentionRefine(prompts, args.num_diffusion_steps, cross_replace_steps=args.xa, self_replace_steps=args.sa, model=ldm_stable) 107 | 108 | register_attention_control(ldm_stable, controller) 109 | w0, _ = inversion_reverse_process(ldm_stable, xT=wts[args.num_diffusion_steps-skip], etas=eta, prompts=prompts, cfg_scales=cfg_scale_list, prog_bar=True, zs=zs[:(args.num_diffusion_steps-skip)], controller=controller) 110 | w0 = w0[1].unsqueeze(0) 111 | 112 | elif args.mode=="p2pddim" or 
args.mode=="ddim": 113 | # only z=0 114 | if skip != 0: 115 | continue 116 | prompts = [prompt_src, prompt_tar] 117 | if args.mode=="p2pddim": 118 | if src_tar_len_eq: 119 | controller = AttentionReplace(prompts, args.num_diffusion_steps, cross_replace_steps=.8, self_replace_steps=0.4, model=ldm_stable) 120 | # Should use Refine for target prompts with different number of tokens 121 | else: 122 | controller = AttentionRefine(prompts, args.num_diffusion_steps, cross_replace_steps=.8, self_replace_steps=0.4, model=ldm_stable) 123 | else: 124 | controller = EmptyControl() 125 | 126 | register_attention_control(ldm_stable, controller) 127 | # perform ddim inversion 128 | cfg_scale_list = [cfg_scale_src, cfg_scale_tar] 129 | w0, latent = text2image_ldm_stable(ldm_stable, prompts, controller, args.num_diffusion_steps, cfg_scale_list, None, wT) 130 | w0 = w0[1:2] 131 | else: 132 | raise NotImplementedError 133 | 134 | # vae decode image 135 | with autocast("cuda"), inference_mode(): 136 | x0_dec = ldm_stable.vae.decode(1 / 0.18215 * w0).sample 137 | if x0_dec.dim()<4: 138 | x0_dec = x0_dec[None,:,:,:] 139 | img = image_grid(x0_dec) 140 | 141 | # same output 142 | current_GMT = time.gmtime() 143 | time_stamp_name = calendar.timegm(current_GMT) 144 | image_name_png = f'cfg_d_{cfg_scale_tar}_' + f'skip_{skip}_{time_stamp_name}' + ".png" 145 | 146 | save_full_path = os.path.join(save_path, image_name_png) 147 | img.save(save_full_path) -------------------------------------------------------------------------------- /prompt_to_prompt/LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 
40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 
203 | -------------------------------------------------------------------------------- /prompt_to_prompt/README.md: -------------------------------------------------------------------------------- 1 | # Prompt-to-Prompt 2 | 3 | > *Latent Diffusion* and *Stable Diffusion* Implementation 4 | 5 | ## :partying_face: ***New:*** :partying_face: Code for Null-Text Inversion is now provided [here](#null-text-inversion-for-editing-real-images) 6 | 7 | 8 | ![teaser](docs/teaser.png) 9 | ### [Project Page](https://prompt-to-prompt.github.io)   [Paper](https://prompt-to-prompt.github.io/ptp_files/Prompt-to-Prompt_preprint.pdf) 10 | 11 | 12 | ## Setup 13 | 14 | This code was tested with Python 3.8, [Pytorch](https://pytorch.org/) 1.11 using pre-trained models through [huggingface / diffusers](https://github.com/huggingface/diffusers#readme). 15 | Specifically, we implemented our method over [Latent Diffusion](https://huggingface.co/CompVis/ldm-text2im-large-256) and [Stable Diffusion](https://huggingface.co/CompVis/stable-diffusion-v1-4). 16 | Additional required packages are listed in the requirements file. 17 | The code was tested on a Tesla V100 16GB but should work on other cards with at least **12GB** VRAM. 18 | 19 | ## Quickstart 20 | 21 | In order to get started, we recommend taking a look at our notebooks: [**prompt-to-prompt_ldm**][p2p-ldm] and [**prompt-to-prompt_stable**][p2p-stable]. The notebooks contain end-to-end examples of usage of prompt-to-prompt on top of *Latent Diffusion* and *Stable Diffusion* respectively. Take a look at these notebooks to learn how to use the different types of prompt edits and understand the API. 22 | 23 | ## Prompt Edits 24 | 25 | In our notebooks, we perform our main logic by implementing the abstract class `AttentionControl` object, of the following form: 26 | 27 | ``` python 28 | class AttentionControl(abc.ABC): 29 | @abc.abstractmethod 30 | def forward (self, attn, is_cross: bool, place_in_unet: str): 31 | raise NotImplementedError 32 | ``` 33 | 34 | The `forward` method is called in each attention layer of the diffusion model during the image generation, and we use it to modify the weights of the attention. Our method (See Section 3 of our [paper](https://arxiv.org/abs/2208.01626)) edits images with the procedure above, and each different prompt edit type modifies the weights of the attention in a different manner. 35 | 36 | The general flow of our code is as follows, with variations based on the attention control type: 37 | 38 | ``` python 39 | prompts = ["A painting of a squirrel eating a burger", ...] 40 | controller = AttentionControl(prompts, ...) 41 | run_and_display(prompts, controller, ...) 42 | ``` 43 | 44 | ### Replacement 45 | In this case, the user swaps tokens of the original prompt with others, e.g., the editing the prompt `"A painting of a squirrel eating a burger"` to `"A painting of a squirrel eating a lasagna"` or `"A painting of a lion eating a burger"`. For this we define the class `AttentionReplace`. 46 | 47 | ### Refinement 48 | In this case, the user adds new tokens to the prompt, e.g., editing the prompt `"A painting of a squirrel eating a burger"` to `"A watercolor painting of a squirrel eating a burger"`. For this we define the class `AttentionEditRefine`. 49 | 50 | ### Re-weight 51 | In this case, the user changes the weight of certain tokens in the prompt, e.g., for the prompt `"A photo of a poppy field at night"`, strengthen or weaken the extent to which the word `night` affects the resulting image. 
For this we define the class `AttentionReweight`. 52 | 53 | 54 | ## Attention Control Options 55 | * `cross_replace_steps`: specifies the fraction of steps to edit the cross attention maps. Can also be set to a dictionary `[str:float]` which specifies fractions for different words in the prompt. 56 | * `self_replace_steps`: specifies the fraction of steps to replace the self attention maps. 57 | * `local_blend` (optional): `LocalBlend` object which is used to make local edits. `LocalBlend` is initialized with the words from each prompt that correspond with the region in the image we want to edit. 58 | * `equalizer`: used for attention Re-weighting only. A vector of coefficients to multiply each cross-attention weight 59 | 60 | ## Citation 61 | 62 | ``` bibtex 63 | @article{hertz2022prompt, 64 | title = {Prompt-to-Prompt Image Editing with Cross Attention Control}, 65 | author = {Hertz, Amir and Mokady, Ron and Tenenbaum, Jay and Aberman, Kfir and Pritch, Yael and Cohen-Or, Daniel}, 66 | journal = {arXiv preprint arXiv:2208.01626}, 67 | year = {2022}, 68 | } 69 | ``` 70 | 71 | # Null-Text Inversion for Editing Real Images 72 | 73 | ### [Project Page](https://null-text-inversion.github.io/)   [Paper](https://arxiv.org/abs/2211.09794) 74 | 75 | 76 | 77 | Null-text inversion enables intuitive text-based editing of **real images** with the Stable Diffusion model. We use an initial DDIM inversion as an anchor for our optimization which only tunes the null-text embedding used in classifier-free guidance. 78 | 79 | 80 | ![teaser](docs/null_text_teaser.png) 81 | 82 | ## Editing Real Images 83 | 84 | Prompt-to-Prompt editing of real images by first using Null-text inversion is provided in this [**Notebooke**][null_text]. 85 | 86 | 87 | ``` bibtex 88 | @article{mokady2022null, 89 | title={Null-text Inversion for Editing Real Images using Guided Diffusion Models}, 90 | author={Mokady, Ron and Hertz, Amir and Aberman, Kfir and Pritch, Yael and Cohen-Or, Daniel}, 91 | journal={arXiv preprint arXiv:2211.09794}, 92 | year={2022} 93 | } 94 | ``` 95 | 96 | 97 | ## Disclaimer 98 | 99 | This is not an officially supported Google product. 
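In this repository, the controllers described above are driven from `main_run.py` at the repository root rather than from the notebooks. The sketch below mirrors that wiring; the prompts and step count are placeholders, and it assumes the local `ptp_classes` / `ptp_utils` modules together with a diffusers Stable Diffusion pipeline:

``` python
from diffusers import StableDiffusionPipeline
from prompt_to_prompt.ptp_classes import AttentionReplace, AttentionRefine
from prompt_to_prompt.ptp_utils import register_attention_control

# Placeholder prompts: a one-word swap suits AttentionReplace; prompts with a
# different number of words should use AttentionRefine instead.
prompts = ["A painting of a squirrel eating a burger",
           "A painting of a lion eating a burger"]
num_steps = 50

ldm_stable = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4").to("cuda")

if len(prompts[0].split()) == len(prompts[1].split()):
    controller = AttentionReplace(prompts, num_steps,
                                  cross_replace_steps=0.8,
                                  self_replace_steps=0.4,
                                  model=ldm_stable)
else:
    controller = AttentionRefine(prompts, num_steps,
                                 cross_replace_steps=0.8,
                                 self_replace_steps=0.4,
                                 model=ldm_stable)

# Hook the controller into the attention layers of the UNet; from here on,
# each denoising step routes its attention maps through controller.forward.
register_attention_control(ldm_stable, controller)
```

The choice between `AttentionReplace` and `AttentionRefine`, and the 0.8 / 0.4 cross- and self-attention fractions, follow the defaults used for the p2p mode in `main_run.py`.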
100 | 101 | [p2p-ldm]: prompt-to-prompt_ldm.ipynb 102 | [p2p-stable]: prompt-to-prompt_stable.ipynb 103 | [null_text]: null_text_w_ptp.ipynb 104 | -------------------------------------------------------------------------------- /prompt_to_prompt/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/inbarhub/DDPM_inversion/58fc881772d34f0c24be9e34725d213731aa009d/prompt_to_prompt/__init__.py -------------------------------------------------------------------------------- /prompt_to_prompt/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/inbarhub/DDPM_inversion/58fc881772d34f0c24be9e34725d213731aa009d/prompt_to_prompt/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /prompt_to_prompt/__pycache__/ddim_inversion.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/inbarhub/DDPM_inversion/58fc881772d34f0c24be9e34725d213731aa009d/prompt_to_prompt/__pycache__/ddim_inversion.cpython-38.pyc -------------------------------------------------------------------------------- /prompt_to_prompt/__pycache__/inversion_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/inbarhub/DDPM_inversion/58fc881772d34f0c24be9e34725d213731aa009d/prompt_to_prompt/__pycache__/inversion_utils.cpython-38.pyc -------------------------------------------------------------------------------- /prompt_to_prompt/__pycache__/inversion_utils_vova.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/inbarhub/DDPM_inversion/58fc881772d34f0c24be9e34725d213731aa009d/prompt_to_prompt/__pycache__/inversion_utils_vova.cpython-38.pyc -------------------------------------------------------------------------------- /prompt_to_prompt/__pycache__/ptp_classes.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/inbarhub/DDPM_inversion/58fc881772d34f0c24be9e34725d213731aa009d/prompt_to_prompt/__pycache__/ptp_classes.cpython-38.pyc -------------------------------------------------------------------------------- /prompt_to_prompt/__pycache__/ptp_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/inbarhub/DDPM_inversion/58fc881772d34f0c24be9e34725d213731aa009d/prompt_to_prompt/__pycache__/ptp_utils.cpython-38.pyc -------------------------------------------------------------------------------- /prompt_to_prompt/__pycache__/seq_aligner.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/inbarhub/DDPM_inversion/58fc881772d34f0c24be9e34725d213731aa009d/prompt_to_prompt/__pycache__/seq_aligner.cpython-38.pyc -------------------------------------------------------------------------------- /prompt_to_prompt/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/inbarhub/DDPM_inversion/58fc881772d34f0c24be9e34725d213731aa009d/prompt_to_prompt/__pycache__/utils.cpython-38.pyc 
-------------------------------------------------------------------------------- /prompt_to_prompt/contributing.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | We'd love to accept your patches and contributions to this project. There are 4 | just a few small guidelines you need to follow. 5 | 6 | ## Contributor License Agreement 7 | 8 | Contributions to this project must be accompanied by a Contributor License 9 | Agreement. You (or your employer) retain the copyright to your contribution; 10 | this simply gives us permission to use and redistribute your contributions as 11 | part of the project. Head over to to see 12 | your current agreements on file or to sign a new one. 13 | 14 | You generally only need to submit a CLA once, so if you've already submitted one 15 | (even if it was for a different project), you probably don't need to do it 16 | again. 17 | 18 | ## Code Reviews 19 | 20 | All submissions, including submissions by project members, require review. We 21 | use GitHub pull requests for this purpose. Consult 22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 23 | information on using pull requests. 24 | 25 | ## Community Guidelines 26 | 27 | This project follows [Google's Open Source Community 28 | Guidelines](https://opensource.google/conduct/). 29 | -------------------------------------------------------------------------------- /prompt_to_prompt/data/horse.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/inbarhub/DDPM_inversion/58fc881772d34f0c24be9e34725d213731aa009d/prompt_to_prompt/data/horse.jpg -------------------------------------------------------------------------------- /prompt_to_prompt/docs/null_text_teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/inbarhub/DDPM_inversion/58fc881772d34f0c24be9e34725d213731aa009d/prompt_to_prompt/docs/null_text_teaser.png -------------------------------------------------------------------------------- /prompt_to_prompt/docs/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/inbarhub/DDPM_inversion/58fc881772d34f0c24be9e34725d213731aa009d/prompt_to_prompt/docs/teaser.png -------------------------------------------------------------------------------- /prompt_to_prompt/example_images/gnochi_mirror.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/inbarhub/DDPM_inversion/58fc881772d34f0c24be9e34725d213731aa009d/prompt_to_prompt/example_images/gnochi_mirror.jpeg -------------------------------------------------------------------------------- /prompt_to_prompt/ptp_classes.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code was originally taken from 3 | https://github.com/google/prompt-to-prompt 4 | """ 5 | 6 | 7 | LOW_RESOURCE = True 8 | MAX_NUM_WORDS = 77 9 | 10 | 11 | from typing import Optional, Union, Tuple, List, Callable, Dict 12 | import prompt_to_prompt.ptp_utils as ptp_utils 13 | import prompt_to_prompt.seq_aligner as seq_aligner 14 | import torch 15 | import torch.nn.functional as nnf 16 | import abc 17 | import numpy as np 18 | 19 | 20 | class LocalBlend: 21 | 22 | def __call__(self, x_t, attention_store): 23 | k = 1 24 | maps = attention_store["down_cross"][2:4] + 
attention_store["up_cross"][:3] 25 | maps = [item.reshape(self.alpha_layers.shape[0], -1, 1, 16, 16, MAX_NUM_WORDS) for item in maps] 26 | maps = torch.cat(maps, dim=1) 27 | maps = (maps * self.alpha_layers).sum(-1).mean(1) 28 | mask = nnf.max_pool2d(maps, (k * 2 + 1, k * 2 +1), (1, 1), padding=(k, k)) 29 | mask = nnf.interpolate(mask, size=(x_t.shape[2:])) 30 | mask = mask / mask.max(2, keepdims=True)[0].max(3, keepdims=True)[0] 31 | mask = mask.gt(self.threshold) 32 | mask = (mask[:1] + mask[1:]).float() 33 | x_t = x_t[:1] + mask * (x_t - x_t[:1]) 34 | return x_t 35 | 36 | def __init__(self, prompts: List[str], words: [List[List[str]]], threshold=.3, device=None, tokenizer=None): 37 | alpha_layers = torch.zeros(len(prompts), 1, 1, 1, 1, MAX_NUM_WORDS) 38 | for i, (prompt, words_) in enumerate(zip(prompts, words)): 39 | if type(words_) is str: 40 | words_ = [words_] 41 | for word in words_: 42 | ind = ptp_utils.get_word_inds(prompt, word, tokenizer) 43 | alpha_layers[i, :, :, :, :, ind] = 1 44 | self.alpha_layers = alpha_layers.to(device) 45 | self.threshold = threshold 46 | 47 | 48 | class AttentionControl(abc.ABC): 49 | 50 | def step_callback(self, x_t): 51 | return x_t 52 | 53 | def between_steps(self): 54 | return 55 | 56 | @property 57 | def num_uncond_att_layers(self): 58 | return self.num_att_layers if LOW_RESOURCE else 0 59 | 60 | @abc.abstractmethod 61 | def forward (self, attn, is_cross: bool, place_in_unet: str): 62 | raise NotImplementedError 63 | 64 | def __call__(self, attn, is_cross: bool, place_in_unet: str): 65 | if self.cur_att_layer >= self.num_uncond_att_layers: 66 | if LOW_RESOURCE: 67 | attn = self.forward(attn, is_cross, place_in_unet) 68 | else: 69 | h = attn.shape[0] 70 | attn[h // 2:] = self.forward(attn[h // 2:], is_cross, place_in_unet) 71 | self.cur_att_layer += 1 72 | if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers: 73 | self.cur_att_layer = 0 74 | self.cur_step += 1 75 | self.between_steps() 76 | return attn 77 | 78 | def reset(self): 79 | self.cur_step = 0 80 | self.cur_att_layer = 0 81 | 82 | def __init__(self): 83 | self.cur_step = 0 84 | self.num_att_layers = -1 85 | self.cur_att_layer = 0 86 | 87 | class EmptyControl(AttentionControl): 88 | 89 | def forward (self, attn, is_cross: bool, place_in_unet: str): 90 | return attn 91 | 92 | 93 | class AttentionStore(AttentionControl): 94 | 95 | @staticmethod 96 | def get_empty_store(): 97 | return {"down_cross": [], "mid_cross": [], "up_cross": [], 98 | "down_self": [], "mid_self": [], "up_self": []} 99 | 100 | def forward(self, attn, is_cross: bool, place_in_unet: str): 101 | key = f"{place_in_unet}_{'cross' if is_cross else 'self'}" 102 | if attn.shape[1] <= 32 ** 2: # avoid memory overhead 103 | self.step_store[key].append(attn) 104 | return attn 105 | 106 | def between_steps(self): 107 | if len(self.attention_store) == 0: 108 | self.attention_store = self.step_store 109 | else: 110 | for key in self.attention_store: 111 | for i in range(len(self.attention_store[key])): 112 | self.attention_store[key][i] += self.step_store[key][i] 113 | self.step_store = self.get_empty_store() 114 | 115 | def get_average_attention(self): 116 | average_attention = {key: [item / self.cur_step for item in self.attention_store[key]] for key in self.attention_store} 117 | return average_attention 118 | 119 | 120 | def reset(self): 121 | super(AttentionStore, self).reset() 122 | self.step_store = self.get_empty_store() 123 | self.attention_store = {} 124 | 125 | def __init__(self): 126 | 
super(AttentionStore, self).__init__() 127 | self.step_store = self.get_empty_store() 128 | self.attention_store = {} 129 | 130 | 131 | class AttentionControlEdit(AttentionStore, abc.ABC): 132 | 133 | def step_callback(self, x_t): 134 | if self.local_blend is not None: 135 | x_t = self.local_blend(x_t, self.attention_store) 136 | return x_t 137 | 138 | def replace_self_attention(self, attn_base, att_replace): 139 | if att_replace.shape[2] <= 16 ** 2: 140 | return attn_base.unsqueeze(0).expand(att_replace.shape[0], *attn_base.shape) 141 | else: 142 | return att_replace 143 | 144 | @abc.abstractmethod 145 | def replace_cross_attention(self, attn_base, att_replace): 146 | raise NotImplementedError 147 | 148 | def forward(self, attn, is_cross: bool, place_in_unet: str): 149 | super(AttentionControlEdit, self).forward(attn, is_cross, place_in_unet) 150 | if is_cross or (self.num_self_replace[0] <= self.cur_step < self.num_self_replace[1]): 151 | h = attn.shape[0] // (self.batch_size) 152 | attn = attn.reshape(self.batch_size, h, *attn.shape[1:]) 153 | attn_base, attn_repalce = attn[0], attn[1:] 154 | if is_cross: 155 | alpha_words = self.cross_replace_alpha[self.cur_step] 156 | attn_repalce_new = self.replace_cross_attention(attn_base, attn_repalce) * alpha_words + (1 - alpha_words) * attn_repalce 157 | attn[1:] = attn_repalce_new 158 | else: 159 | attn[1:] = self.replace_self_attention(attn_base, attn_repalce) 160 | attn = attn.reshape(self.batch_size * h, *attn.shape[2:]) 161 | return attn 162 | 163 | def __init__(self, prompts, num_steps: int, 164 | cross_replace_steps: Union[float, Tuple[float, float], Dict[str, Tuple[float, float]]], 165 | self_replace_steps: Union[float, Tuple[float, float]], 166 | local_blend: Optional[LocalBlend], 167 | device=None, 168 | tokenizer=None): 169 | super(AttentionControlEdit, self).__init__() 170 | self.batch_size = len(prompts) 171 | self.cross_replace_alpha = ptp_utils.get_time_words_attention_alpha(prompts, num_steps, cross_replace_steps, tokenizer).to(device) 172 | if type(self_replace_steps) is float: 173 | self_replace_steps = 0, self_replace_steps 174 | self.num_self_replace = int(num_steps * self_replace_steps[0]), int(num_steps * self_replace_steps[1]) 175 | self.local_blend = local_blend 176 | 177 | class AttentionReplace(AttentionControlEdit): 178 | 179 | def replace_cross_attention(self, attn_base, att_replace): 180 | return torch.einsum('hpw,bwn->bhpn', attn_base, self.mapper) 181 | 182 | def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float, 183 | local_blend: Optional[LocalBlend] = None, model=None): 184 | super(AttentionReplace, self).__init__(prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend, device=model.device) 185 | self.mapper = seq_aligner.get_replacement_mapper(prompts, model.tokenizer).to(model.device) 186 | 187 | 188 | class AttentionRefine(AttentionControlEdit): 189 | 190 | def replace_cross_attention(self, attn_base, att_replace): 191 | attn_base_replace = attn_base[:, :, self.mapper].permute(2, 0, 1, 3) 192 | attn_replace = attn_base_replace * self.alphas + att_replace * (1 - self.alphas) 193 | return attn_replace 194 | 195 | def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float, 196 | local_blend: Optional[LocalBlend] = None, model=None): 197 | super(AttentionRefine, self).__init__(prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend, device=model.device) 198 | self.mapper, alphas = 
seq_aligner.get_refinement_mapper(prompts, model.tokenizer) 199 | self.mapper, alphas = self.mapper.to(model.device), alphas.to(model.device) 200 | self.alphas = alphas.reshape(alphas.shape[0], 1, 1, alphas.shape[1]) 201 | 202 | 203 | class AttentionReweight(AttentionControlEdit): 204 | 205 | def replace_cross_attention(self, attn_base, att_replace): 206 | if self.prev_controller is not None: 207 | attn_base = self.prev_controller.replace_cross_attention(attn_base, att_replace) 208 | attn_replace = attn_base[None, :, :, :] * self.equalizer[:, None, None, :] 209 | return attn_replace 210 | 211 | def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float, equalizer, 212 | local_blend: Optional[LocalBlend] = None, controller: Optional[AttentionControlEdit] = None, device=None, tokenizer=None): 213 | super(AttentionReweight, self).__init__(prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend) 214 | self.equalizer = equalizer.to(device) 215 | self.prev_controller = controller 216 | 217 | 218 | def get_equalizer(text: str, word_select: Union[int, Tuple[int, ...]], values: Union[List[float], 219 | Tuple[float, ...]], tokenizer=None): 220 | if type(word_select) is int or type(word_select) is str: 221 | word_select = (word_select,) 222 | equalizer = torch.ones(len(values), 77) 223 | values = torch.tensor(values, dtype=torch.float32) 224 | for word in word_select: 225 | inds = ptp_utils.get_word_inds(text, word, tokenizer) 226 | equalizer[:, inds] = values 227 | return equalizer 228 | 229 | from PIL import Image 230 | 231 | def aggregate_attention(attention_store: AttentionStore, res: int, from_where: List[str], is_cross: bool, select: int, prompts=None): 232 | out = [] 233 | attention_maps = attention_store.get_average_attention() 234 | num_pixels = res ** 2 235 | for location in from_where: 236 | for item in attention_maps[f"{location}_{'cross' if is_cross else 'self'}"]: 237 | if item.shape[1] == num_pixels: 238 | cross_maps = item.reshape(len(prompts), -1, res, res, item.shape[-1])[select] 239 | out.append(cross_maps) 240 | out = torch.cat(out, dim=0) 241 | out = out.sum(0) / out.shape[0] 242 | return out.cpu() 243 | 244 | 245 | def show_cross_attention(attention_store: AttentionStore, res: int, from_where: List[str], select: int = 0, prompts=None, tokenizer=None): 246 | tokens = tokenizer.encode(prompts[select]) 247 | decoder = tokenizer.decode 248 | attention_maps = aggregate_attention(attention_store, res, from_where, True, select, prompts) 249 | images = [] 250 | for i in range(len(tokens)): 251 | image = attention_maps[:, :, i] 252 | image = 255 * image / image.max() 253 | image = image.unsqueeze(-1).expand(*image.shape, 3) 254 | image = image.numpy().astype(np.uint8) 255 | image = np.array(Image.fromarray(image).resize((256, 256))) 256 | image = ptp_utils.text_under_image(image, decoder(int(tokens[i]))) 257 | images.append(image) 258 | return(ptp_utils.view_images(np.stack(images, axis=0))) 259 | 260 | 261 | def show_self_attention_comp(attention_store: AttentionStore, res: int, from_where: List[str], 262 | max_com=10, select: int = 0): 263 | attention_maps = aggregate_attention(attention_store, res, from_where, False, select).numpy().reshape((res ** 2, res ** 2)) 264 | u, s, vh = np.linalg.svd(attention_maps - np.mean(attention_maps, axis=1, keepdims=True)) 265 | images = [] 266 | for i in range(max_com): 267 | image = vh[i].reshape(res, res) 268 | image = image - image.min() 269 | image = 255 * image / image.max() 270 | image = 
np.repeat(np.expand_dims(image, axis=2), 3, axis=2).astype(np.uint8) 271 | image = Image.fromarray(image).resize((256, 256)) 272 | image = np.array(image) 273 | images.append(image) 274 | ptp_utils.view_images(np.concatenate(images, axis=1)) 275 | 276 | def run_and_display(model, prompts, controller, latent=None, run_baseline=False, generator=None): 277 | if run_baseline: 278 | print("w.o. prompt-to-prompt") 279 | images, latent = run_and_display(model, prompts, EmptyControl(), latent=latent, run_baseline=False, generator=generator) 280 | print("with prompt-to-prompt") 281 | images, x_t = ptp_utils.text2image_ldm_stable(model, prompts, controller, latent=latent, generator=generator)  # completed per the upstream prompt-to-prompt helper (exact arguments are an assumption); note this fork's text2image_ldm_stable returns latents, not decoded images 282 | return images, x_t 283 | 284 | def load_512(image_path, left=0, right=0, top=0, bottom=0, device=None): 285 | if type(image_path) is str: 286 | image = np.array(Image.open(image_path).convert('RGB'))[:, :, :3] 287 | else: 288 | image = image_path 289 | h, w, c = image.shape 290 | left = min(left, w-1) 291 | right = min(right, w - left - 1) 292 | top = min(top, h - left - 1) 293 | bottom = min(bottom, h - top - 1) 294 | image = image[top:h-bottom, left:w-right] 295 | h, w, c = image.shape 296 | if h < w: 297 | offset = (w - h) // 2 298 | image = image[:, offset:offset + h] 299 | elif w < h: 300 | offset = (h - w) // 2 301 | image = image[offset:offset + w] 302 | image = np.array(Image.fromarray(image).resize((512, 512))) 303 | image = torch.from_numpy(image).float() / 127.5 - 1 304 | image = image.permute(2, 0, 1).unsqueeze(0).to(device) 305 | 306 | return image -------------------------------------------------------------------------------- /prompt_to_prompt/ptp_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code was originally taken from 3 | https://github.com/google/prompt-to-prompt 4 | """ 5 | 6 | # Copyright 2022 Google LLC 7 | # 8 | # Licensed under the Apache License, Version 2.0 (the "License"); 9 | # you may not use this file except in compliance with the License. 10 | # You may obtain a copy of the License at 11 | # 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | # 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, 16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | # See the License for the specific language governing permissions and 18 | # limitations under the License.
19 | 20 | 21 | 22 | import numpy as np 23 | import torch 24 | from PIL import Image, ImageDraw, ImageFont 25 | import cv2 26 | from typing import Optional, Union, Tuple, List, Callable, Dict 27 | # from IPython.display import display 28 | from tqdm import tqdm 29 | 30 | 31 | def text_under_image(image: np.ndarray, text: str, text_color: Tuple[int, int, int] = (0, 0, 0)): 32 | h, w, c = image.shape 33 | offset = int(h * .2) 34 | img = np.ones((h + offset, w, c), dtype=np.uint8) * 255 35 | font = cv2.FONT_HERSHEY_SIMPLEX 36 | # font = ImageFont.truetype("/usr/share/fonts/truetype/noto/NotoMono-Regular.ttf", font_size) 37 | img[:h] = image 38 | textsize = cv2.getTextSize(text, font, 1, 2)[0] 39 | text_x, text_y = (w - textsize[0]) // 2, h + offset - textsize[1] // 2 40 | cv2.putText(img, text, (text_x, text_y ), font, 1, text_color, 2) 41 | return img 42 | 43 | 44 | def view_images(images, num_rows=1, offset_ratio=0.02): 45 | if type(images) is list: 46 | num_empty = len(images) % num_rows 47 | elif images.ndim == 4: 48 | num_empty = images.shape[0] % num_rows 49 | else: 50 | images = [images] 51 | num_empty = 0 52 | 53 | empty_images = np.ones(images[0].shape, dtype=np.uint8) * 255 54 | images = [image.astype(np.uint8) for image in images] + [empty_images] * num_empty 55 | num_items = len(images) 56 | 57 | h, w, c = images[0].shape 58 | offset = int(h * offset_ratio) 59 | num_cols = num_items // num_rows 60 | image_ = np.ones((h * num_rows + offset * (num_rows - 1), 61 | w * num_cols + offset * (num_cols - 1), 3), dtype=np.uint8) * 255 62 | for i in range(num_rows): 63 | for j in range(num_cols): 64 | image_[i * (h + offset): i * (h + offset) + h:, j * (w + offset): j * (w + offset) + w] = images[ 65 | i * num_cols + j] 66 | 67 | pil_img = Image.fromarray(image_) 68 | # display(pil_img) 69 | return pil_img 70 | 71 | 72 | 73 | def diffusion_step(model, controller, latents, context, t, guidance_scale, low_resource=False): 74 | if low_resource: 75 | noise_pred_uncond = model.unet(latents, t, encoder_hidden_states=context[0])["sample"] 76 | noise_prediction_text = model.unet(latents, t, encoder_hidden_states=context[1])["sample"] 77 | else: 78 | latents_input = torch.cat([latents] * 2) 79 | noise_pred = model.unet(latents_input, t, encoder_hidden_states=context)["sample"] 80 | noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2) 81 | cfg_scales_tensor = torch.Tensor(guidance_scale).view(-1,1,1,1).to(model.device) 82 | noise_pred = noise_pred_uncond + cfg_scales_tensor * (noise_prediction_text - noise_pred_uncond) 83 | latents = model.scheduler.step(noise_pred, t, latents)["prev_sample"] 84 | latents = controller.step_callback(latents) 85 | return latents 86 | 87 | 88 | def latent2image(vae, latents): 89 | latents = 1 / 0.18215 * latents 90 | image = vae.decode(latents)['sample'] 91 | image = (image / 2 + 0.5).clamp(0, 1) 92 | image = image.cpu().permute(0, 2, 3, 1).numpy() 93 | image = (image * 255).astype(np.uint8) 94 | return image 95 | 96 | 97 | def init_latent(latent, model, height, width, generator, batch_size): 98 | if latent is None: 99 | latent = torch.randn( 100 | (1, model.unet.in_channels, height // 8, width // 8), 101 | generator=generator, 102 | ) 103 | latents = latent.expand(batch_size, model.unet.in_channels, height // 8, width // 8).to(model.device) 104 | return latent, latents 105 | 106 | 107 | @torch.no_grad() 108 | def text2image_ldm( 109 | model, 110 | prompt: List[str], 111 | controller, 112 | num_inference_steps: int = 50, 113 | guidance_scale: Optional[float] = 
7., 114 | generator: Optional[torch.Generator] = None, 115 | latent: Optional[torch.FloatTensor] = None, 116 | ): 117 | register_attention_control(model, controller) 118 | height = width = 256 119 | batch_size = len(prompt) 120 | 121 | uncond_input = model.tokenizer([""] * batch_size, padding="max_length", max_length=77, return_tensors="pt") 122 | uncond_embeddings = model.bert(uncond_input.input_ids.to(model.device))[0] 123 | 124 | text_input = model.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt") 125 | text_embeddings = model.bert(text_input.input_ids.to(model.device))[0] 126 | latent, latents = init_latent(latent, model, height, width, generator, batch_size) 127 | context = torch.cat([uncond_embeddings, text_embeddings]) 128 | 129 | model.scheduler.set_timesteps(num_inference_steps) 130 | for t in tqdm(model.scheduler.timesteps): 131 | latents = diffusion_step(model, controller, latents, context, t, guidance_scale) 132 | 133 | image = latent2image(model.vqvae, latents) 134 | 135 | return image, latent 136 | 137 | 138 | @torch.no_grad() 139 | def text2image_ldm_stable( 140 | model, 141 | prompt: List[str], 142 | controller, 143 | num_inference_steps: int = 50, 144 | guidance_scale: float = 7.5, 145 | generator: Optional[torch.Generator] = None, 146 | latent: Optional[torch.FloatTensor] = None, 147 | restored_wt = None, 148 | restored_zs = None, 149 | low_resource: bool = False, 150 | ): 151 | register_attention_control(model, controller) 152 | height = width = 512 153 | batch_size = len(prompt) 154 | 155 | text_input = model.tokenizer( 156 | prompt, 157 | padding="max_length", 158 | max_length=model.tokenizer.model_max_length, 159 | truncation=True, 160 | return_tensors="pt", 161 | ) 162 | text_embeddings = model.text_encoder(text_input.input_ids.to(model.device))[0] 163 | max_length = text_input.input_ids.shape[-1] 164 | uncond_input = model.tokenizer( 165 | [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt" 166 | ) 167 | uncond_embeddings = model.text_encoder(uncond_input.input_ids.to(model.device))[0] 168 | 169 | context = [uncond_embeddings, text_embeddings] 170 | if not low_resource: 171 | context = torch.cat(context) 172 | latent, latents = init_latent(latent, model, height, width, generator, batch_size) 173 | 174 | # set timesteps 175 | # extra_set_kwargs = {"offset": 1} 176 | model.scheduler.set_timesteps(num_inference_steps)#, **extra_set_kwargs) 177 | for t in tqdm(model.scheduler.timesteps): 178 | latents = diffusion_step(model, controller, latents, context, t, guidance_scale, low_resource) 179 | 180 | # image = latent2image(model.vae, latents) 181 | 182 | return latents, latent 183 | 184 | 185 | def register_attention_control(model, controller): 186 | def ca_forward(self, place_in_unet): 187 | to_out = self.to_out 188 | if type(to_out) is torch.nn.modules.container.ModuleList: 189 | to_out = self.to_out[0] 190 | else: 191 | to_out = self.to_out 192 | 193 | def forward(x, context=None, mask=None): 194 | batch_size, sequence_length, dim = x.shape 195 | h = self.heads 196 | q = self.to_q(x) 197 | is_cross = context is not None 198 | context = context if is_cross else x 199 | k = self.to_k(context) 200 | v = self.to_v(context) 201 | q = self.reshape_heads_to_batch_dim(q) 202 | k = self.reshape_heads_to_batch_dim(k) 203 | v = self.reshape_heads_to_batch_dim(v) 204 | 205 | sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale 206 | 207 | if mask is not None: 208 | mask = mask.reshape(batch_size, -1) 209 | 
max_neg_value = -torch.finfo(sim.dtype).max 210 | mask = mask[:, None, :].repeat(h, 1, 1) 211 | sim.masked_fill_(~mask, max_neg_value) 212 | 213 | # attention, what we cannot get enough of 214 | attn = sim.softmax(dim=-1) 215 | attn = controller(attn, is_cross, place_in_unet) 216 | out = torch.einsum("b i j, b j d -> b i d", attn, v) 217 | out = self.reshape_batch_dim_to_heads(out) 218 | return to_out(out) 219 | 220 | return forward 221 | 222 | class DummyController: 223 | 224 | def __call__(self, *args): 225 | return args[0] 226 | 227 | def __init__(self): 228 | self.num_att_layers = 0 229 | 230 | if controller is None: 231 | controller = DummyController() 232 | 233 | def register_recr(net_, count, place_in_unet): 234 | if net_.__class__.__name__ == 'CrossAttention': 235 | net_.forward = ca_forward(net_, place_in_unet) 236 | return count + 1 237 | elif hasattr(net_, 'children'): 238 | for net__ in net_.children(): 239 | count = register_recr(net__, count, place_in_unet) 240 | return count 241 | 242 | cross_att_count = 0 243 | sub_nets = model.unet.named_children() 244 | for net in sub_nets: 245 | if "down" in net[0]: 246 | cross_att_count += register_recr(net[1], 0, "down") 247 | elif "up" in net[0]: 248 | cross_att_count += register_recr(net[1], 0, "up") 249 | elif "mid" in net[0]: 250 | cross_att_count += register_recr(net[1], 0, "mid") 251 | 252 | controller.num_att_layers = cross_att_count 253 | 254 | 255 | def get_word_inds(text: str, word_place: int, tokenizer): 256 | split_text = text.split(" ") 257 | if type(word_place) is str: 258 | word_place = [i for i, word in enumerate(split_text) if word_place == word] 259 | elif type(word_place) is int: 260 | word_place = [word_place] 261 | out = [] 262 | if len(word_place) > 0: 263 | words_encode = [tokenizer.decode([item]).strip("#") for item in tokenizer.encode(text)][1:-1] 264 | cur_len, ptr = 0, 0 265 | 266 | for i in range(len(words_encode)): 267 | cur_len += len(words_encode[i]) 268 | if ptr in word_place: 269 | out.append(i + 1) 270 | if cur_len >= len(split_text[ptr]): 271 | ptr += 1 272 | cur_len = 0 273 | return np.array(out) 274 | 275 | 276 | def update_alpha_time_word(alpha, bounds: Union[float, Tuple[float, float]], prompt_ind: int, 277 | word_inds: Optional[torch.Tensor]=None): 278 | if type(bounds) is float: 279 | bounds = 0, bounds 280 | start, end = int(bounds[0] * alpha.shape[0]), int(bounds[1] * alpha.shape[0]) 281 | if word_inds is None: 282 | word_inds = torch.arange(alpha.shape[2]) 283 | alpha[: start, prompt_ind, word_inds] = 0 284 | alpha[start: end, prompt_ind, word_inds] = 1 285 | alpha[end:, prompt_ind, word_inds] = 0 286 | return alpha 287 | 288 | 289 | def get_time_words_attention_alpha(prompts, num_steps, 290 | cross_replace_steps: Union[float, Dict[str, Tuple[float, float]]], 291 | tokenizer, max_num_words=77): 292 | if type(cross_replace_steps) is not dict: 293 | cross_replace_steps = {"default_": cross_replace_steps} 294 | if "default_" not in cross_replace_steps: 295 | cross_replace_steps["default_"] = (0., 1.) 
296 | alpha_time_words = torch.zeros(num_steps + 1, len(prompts) - 1, max_num_words) 297 | for i in range(len(prompts) - 1): 298 | alpha_time_words = update_alpha_time_word(alpha_time_words, cross_replace_steps["default_"], 299 | i) 300 | for key, item in cross_replace_steps.items(): 301 | if key != "default_": 302 | inds = [get_word_inds(prompts[i], key, tokenizer) for i in range(1, len(prompts))] 303 | for i, ind in enumerate(inds): 304 | if len(ind) > 0: 305 | alpha_time_words = update_alpha_time_word(alpha_time_words, item, i, ind) 306 | alpha_time_words = alpha_time_words.reshape(num_steps + 1, len(prompts) - 1, 1, 1, max_num_words) 307 | return alpha_time_words 308 | -------------------------------------------------------------------------------- /prompt_to_prompt/requirements.txt: -------------------------------------------------------------------------------- 1 | diffusers==0.8.0 2 | transformers 3 | ftfy 4 | opencv-python 5 | ipywidgets -------------------------------------------------------------------------------- /prompt_to_prompt/seq_aligner.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code was originally taken from 3 | https://github.com/google/prompt-to-prompt 4 | """ 5 | 6 | # Copyright 2022 Google LLC 7 | # 8 | # Licensed under the Apache License, Version 2.0 (the "License"); 9 | # you may not use this file except in compliance with the License. 10 | # You may obtain a copy of the License at 11 | # 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | # 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, 16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | # See the License for the specific language governing permissions and 18 | # limitations under the License. 
19 | import torch 20 | import numpy as np 21 | 22 | 23 | class ScoreParams: 24 | 25 | def __init__(self, gap, match, mismatch): 26 | self.gap = gap 27 | self.match = match 28 | self.mismatch = mismatch 29 | 30 | def mis_match_char(self, x, y): 31 | if x != y: 32 | return self.mismatch 33 | else: 34 | return self.match 35 | 36 | 37 | def get_matrix(size_x, size_y, gap): 38 | matrix = [] 39 | for i in range(len(size_x) + 1): 40 | sub_matrix = [] 41 | for j in range(len(size_y) + 1): 42 | sub_matrix.append(0) 43 | matrix.append(sub_matrix) 44 | for j in range(1, len(size_y) + 1): 45 | matrix[0][j] = j*gap 46 | for i in range(1, len(size_x) + 1): 47 | matrix[i][0] = i*gap 48 | return matrix 49 | 50 | 51 | def get_matrix(size_x, size_y, gap): 52 | matrix = np.zeros((size_x + 1, size_y + 1), dtype=np.int32) 53 | matrix[0, 1:] = (np.arange(size_y) + 1) * gap 54 | matrix[1:, 0] = (np.arange(size_x) + 1) * gap 55 | return matrix 56 | 57 | 58 | def get_traceback_matrix(size_x, size_y): 59 | matrix = np.zeros((size_x + 1, size_y +1), dtype=np.int32) 60 | matrix[0, 1:] = 1 61 | matrix[1:, 0] = 2 62 | matrix[0, 0] = 4 63 | return matrix 64 | 65 | 66 | def global_align(x, y, score): 67 | matrix = get_matrix(len(x), len(y), score.gap) 68 | trace_back = get_traceback_matrix(len(x), len(y)) 69 | for i in range(1, len(x) + 1): 70 | for j in range(1, len(y) + 1): 71 | left = matrix[i, j - 1] + score.gap 72 | up = matrix[i - 1, j] + score.gap 73 | diag = matrix[i - 1, j - 1] + score.mis_match_char(x[i - 1], y[j - 1]) 74 | matrix[i, j] = max(left, up, diag) 75 | if matrix[i, j] == left: 76 | trace_back[i, j] = 1 77 | elif matrix[i, j] == up: 78 | trace_back[i, j] = 2 79 | else: 80 | trace_back[i, j] = 3 81 | return matrix, trace_back 82 | 83 | 84 | def get_aligned_sequences(x, y, trace_back): 85 | x_seq = [] 86 | y_seq = [] 87 | i = len(x) 88 | j = len(y) 89 | mapper_y_to_x = [] 90 | while i > 0 or j > 0: 91 | if trace_back[i, j] == 3: 92 | x_seq.append(x[i-1]) 93 | y_seq.append(y[j-1]) 94 | i = i-1 95 | j = j-1 96 | mapper_y_to_x.append((j, i)) 97 | elif trace_back[i][j] == 1: 98 | x_seq.append('-') 99 | y_seq.append(y[j-1]) 100 | j = j-1 101 | mapper_y_to_x.append((j, -1)) 102 | elif trace_back[i][j] == 2: 103 | x_seq.append(x[i-1]) 104 | y_seq.append('-') 105 | i = i-1 106 | elif trace_back[i][j] == 4: 107 | break 108 | mapper_y_to_x.reverse() 109 | return x_seq, y_seq, torch.tensor(mapper_y_to_x, dtype=torch.int64) 110 | 111 | 112 | def get_mapper(x: str, y: str, tokenizer, max_len=77): 113 | x_seq = tokenizer.encode(x) 114 | y_seq = tokenizer.encode(y) 115 | score = ScoreParams(0, 1, -1) 116 | matrix, trace_back = global_align(x_seq, y_seq, score) 117 | mapper_base = get_aligned_sequences(x_seq, y_seq, trace_back)[-1] 118 | alphas = torch.ones(max_len) 119 | alphas[: mapper_base.shape[0]] = mapper_base[:, 1].ne(-1).float() 120 | mapper = torch.zeros(max_len, dtype=torch.int64) 121 | mapper[:mapper_base.shape[0]] = mapper_base[:, 1] 122 | mapper[mapper_base.shape[0]:] = len(y_seq) + torch.arange(max_len - len(y_seq)) 123 | return mapper, alphas 124 | 125 | 126 | def get_refinement_mapper(prompts, tokenizer, max_len=77): 127 | x_seq = prompts[0] 128 | mappers, alphas = [], [] 129 | for i in range(1, len(prompts)): 130 | mapper, alpha = get_mapper(x_seq, prompts[i], tokenizer, max_len) 131 | mappers.append(mapper) 132 | alphas.append(alpha) 133 | return torch.stack(mappers), torch.stack(alphas) 134 | 135 | 136 | def get_word_inds(text: str, word_place: int, tokenizer): 137 | split_text = text.split(" ") 138 
| if type(word_place) is str: 139 | word_place = [i for i, word in enumerate(split_text) if word_place == word] 140 | elif type(word_place) is int: 141 | word_place = [word_place] 142 | out = [] 143 | if len(word_place) > 0: 144 | words_encode = [tokenizer.decode([item]).strip("#") for item in tokenizer.encode(text)][1:-1] 145 | cur_len, ptr = 0, 0 146 | 147 | for i in range(len(words_encode)): 148 | cur_len += len(words_encode[i]) 149 | if ptr in word_place: 150 | out.append(i + 1) 151 | if cur_len >= len(split_text[ptr]): 152 | ptr += 1 153 | cur_len = 0 154 | return np.array(out) 155 | 156 | 157 | def get_replacement_mapper_(x: str, y: str, tokenizer, max_len=77): 158 | words_x = x.split(' ') 159 | words_y = y.split(' ') 160 | if len(words_x) != len(words_y): 161 | raise ValueError(f"attention replacement edit can only be applied on prompts with the same length" 162 | f" but prompt A has {len(words_x)} words and prompt B has {len(words_y)} words.") 163 | inds_replace = [i for i in range(len(words_y)) if words_y[i] != words_x[i]] 164 | inds_source = [get_word_inds(x, i, tokenizer) for i in inds_replace] 165 | inds_target = [get_word_inds(y, i, tokenizer) for i in inds_replace] 166 | mapper = np.zeros((max_len, max_len)) 167 | i = j = 0 168 | cur_inds = 0 169 | while i < max_len and j < max_len: 170 | if cur_inds < len(inds_source) and inds_source[cur_inds][0] == i: 171 | inds_source_, inds_target_ = inds_source[cur_inds], inds_target[cur_inds] 172 | if len(inds_source_) == len(inds_target_): 173 | mapper[inds_source_, inds_target_] = 1 174 | else: 175 | ratio = 1 / len(inds_target_) 176 | for i_t in inds_target_: 177 | mapper[inds_source_, i_t] = ratio 178 | cur_inds += 1 179 | i += len(inds_source_) 180 | j += len(inds_target_) 181 | elif cur_inds < len(inds_source): 182 | mapper[i, j] = 1 183 | i += 1 184 | j += 1 185 | else: 186 | mapper[j, j] = 1 187 | i += 1 188 | j += 1 189 | 190 | return torch.from_numpy(mapper).float() 191 | 192 | 193 | 194 | def get_replacement_mapper(prompts, tokenizer, max_len=77): 195 | x_seq = prompts[0] 196 | mappers = [] 197 | for i in range(1, len(prompts)): 198 | mapper = get_replacement_mapper_(x_seq, prompts[i], tokenizer, max_len) 199 | mappers.append(mapper) 200 | return torch.stack(mappers) 201 | 202 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.18.0 2 | certifi==2022.12.7 3 | charset-normalizer==3.1.0 4 | cmake==3.26.3 5 | contourpy==1.0.7 6 | cycler==0.11.0 7 | diffusers==0.8.0 8 | filelock==3.12.0 9 | fonttools==4.39.3 10 | huggingface-hub==0.13.4 11 | idna==3.4 12 | importlib-metadata==6.5.0 13 | importlib-resources==5.12.0 14 | Jinja2==3.1.2 15 | kiwisolver==1.4.4 16 | lit==16.0.1 17 | MarkupSafe==2.1.2 18 | matplotlib==3.7.1 19 | mpmath==1.3.0 20 | networkx==3.1 21 | numpy==1.24.2 22 | nvidia-cublas-cu11==11.10.3.66 23 | nvidia-cuda-cupti-cu11==11.7.101 24 | nvidia-cuda-nvrtc-cu11==11.7.99 25 | nvidia-cuda-runtime-cu11==11.7.99 26 | nvidia-cudnn-cu11==8.5.0.96 27 | nvidia-cufft-cu11==10.9.0.58 28 | nvidia-curand-cu11==10.2.10.91 29 | nvidia-cusolver-cu11==11.4.0.1 30 | nvidia-cusparse-cu11==11.7.4.91 31 | nvidia-nccl-cu11==2.14.3 32 | nvidia-nvtx-cu11==11.7.91 33 | opencv-python==4.7.0.72 34 | packaging==23.1 35 | Pillow==9.5.0 36 | pkg_resources==0.0.0 37 | psutil==5.9.5 38 | pyparsing==3.0.9 39 | python-dateutil==2.8.2 40 | PyYAML==6.0 41 | regex==2023.3.23 42 | requests==2.28.2 43 | 
six==1.16.0 44 | sympy==1.11.1 45 | tokenizers==0.13.3 46 | torch==2.0.0 47 | torchaudio==2.0.1 48 | torchvision==0.15.1 49 | tqdm==4.65.0 50 | transformers==4.28.1 51 | triton==2.0.0 52 | typing_extensions==4.5.0 53 | urllib3==1.26.15 54 | zipp==3.15.0 55 | -------------------------------------------------------------------------------- /test.yaml: -------------------------------------------------------------------------------- 1 | - 2 | init_img: /example_images/horse_mud.jpg 3 | source_prompt: a photo of a horse in the mud 4 | 5 | target_prompts: 6 | - a photo of a horse in the snow 7 | - a photo of a zebra in the snow 8 | - a photo of a zebra in the mud 9 | 10 | - 11 | init_img: /example_images/sketch_cat.jpg 12 | source_prompt: a sketch of a cat 13 | 14 | target_prompts: 15 | - a sculpture of a cat 16 | --------------------------------------------------------------------------------
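
Below is a minimal usage sketch showing how the vendored prompt-to-prompt pieces above fit together: it builds an `AttentionReplace` controller over a source/target prompt pair (mirroring the horse-to-zebra case in `test.yaml`) and runs it through `text2image_ldm_stable` from `prompt_to_prompt/ptp_utils.py`. This is only an illustration, not the repository's `main_run.py` entry point; the Stable Diffusion checkpoint id, the prompts, the per-prompt guidance scales, and the output filename are assumptions, and because this fork's `text2image_ldm_stable` returns latents, they are decoded with `latent2image` before viewing.

```python
# Illustrative sketch only -- not this repo's main_run.py. Assumes a diffusers
# StableDiffusionPipeline checkpoint ("CompVis/stable-diffusion-v1-4") and the
# vendored prompt_to_prompt modules shown above.
import torch
from diffusers import StableDiffusionPipeline

from prompt_to_prompt.ptp_classes import AttentionReplace
from prompt_to_prompt.ptp_utils import text2image_ldm_stable, latent2image, view_images

device = "cuda" if torch.cuda.is_available() else "cpu"
model = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4").to(device)

# Word-for-word replacement edit: both prompts must have the same number of words.
prompts = ["a photo of a horse in the mud",
           "a photo of a zebra in the mud"]

controller = AttentionReplace(prompts, num_steps=50,
                              cross_replace_steps=0.8,  # inject source cross-attention for 80% of the steps
                              self_replace_steps=0.4,   # inject source self-attention for 40% of the steps
                              model=model)

# diffusion_step() builds a per-prompt CFG tensor from guidance_scale, so a
# sequence (one scale per prompt) is expected rather than a single float.
# low_resource=True matches LOW_RESOURCE=True in ptp_classes.py, so the
# controller edits attention only on the conditional U-Net pass.
latents, _ = text2image_ldm_stable(model, prompts, controller,
                                   num_inference_steps=50,
                                   guidance_scale=[7.5, 7.5],
                                   low_resource=True)

images = latent2image(model.vae, latents)  # decode, since this fork returns latents
grid = view_images(images)                 # returns a PIL image (display() is commented out here)
grid.save("p2p_edit.png")
```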