├── .gitignore
├── LICENSE
├── README.md
├── ddm_inversion
│   ├── ddim_inversion.py
│   ├── inversion_utils.py
│   └── utils.py
├── example_images
│   ├── horse_mud.jpg
│   └── sketch_cat.jpg
├── imgs
│   └── teaser.jpg
├── main_run.py
├── prompt_to_prompt
│   ├── LICENSE
│   ├── README.md
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-38.pyc
│   │   ├── ddim_inversion.cpython-38.pyc
│   │   ├── inversion_utils.cpython-38.pyc
│   │   ├── inversion_utils_vova.cpython-38.pyc
│   │   ├── ptp_classes.cpython-38.pyc
│   │   ├── ptp_utils.cpython-38.pyc
│   │   ├── seq_aligner.cpython-38.pyc
│   │   └── utils.cpython-38.pyc
│   ├── contributing.md
│   ├── data
│   │   └── horse.jpg
│   ├── docs
│   │   ├── null_text_teaser.png
│   │   └── teaser.png
│   ├── example_images
│   │   └── gnochi_mirror.jpeg
│   ├── ptp_classes.py
│   ├── ptp_utils.py
│   ├── requirements.txt
│   └── seq_aligner.py
├── requirements.txt
└── test.yaml
/.gitignore:
--------------------------------------------------------------------------------
1 | .venv
2 | results_100/
3 | prompt_to_prompt/__pycache__/utils.cpython-38.pyc
4 | prompt_to_prompt/__pycache__/*
5 | *.pyc
6 | results/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 inbarhub
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | [Python 3.8](https://www.python.org/downloads/release/python-38/)
3 | [PyTorch](https://pytorch.org/)
4 |
5 |
6 | # DDPM inversion, CVPR 2024
7 |
8 | [Project page](https://inbarhub.github.io/DDPM_inversion/) | [Arxiv](https://arxiv.org/abs/2304.06140) | [Supplementary materials](https://inbarhub.github.io/DDPM_inversion/resources/inversion_supp.pdf) | [Hugging Face Demo](https://huggingface.co/spaces/LinoyTsaban/edit_friendly_ddpm_inversion)
9 | ### Official PyTorch implementation of the paper: "An Edit Friendly DDPM Noise Space: Inversion and Manipulations"
10 | #### Inbar Huberman-Spiegelglas, Vladimir Kulikov and Tomer Michaeli
11 |
12 |
13 | 
14 | Our inversion can be used for text-based **editing of real images**, either by itself or in combination with other editing methods.
15 | Due to the stochastic nature of our method, we can generate **diverse outputs**, a feature that is not naturally available with methods that rely on DDIM inversion.
16 |
17 | In this repository we support editing real images using our inversion, prompt-to-prompt (p2p) + our inversion, ddim inversion, or [p2p](https://github.com/google/prompt-to-prompt) (with ddim inversion).
18 | **our inversion**: our ddpm inversion followed by generating an image conditioned on the target prompt.
19 |
20 | **prompt-to-prompt (p2p) + our inversion**: p2p method using our ddpm inversion.
21 |
22 | **ddim**: ddim inversion followed by generating an image conditioned on the target prompt.
23 |
24 | **p2p**: p2p method using ddim inversion (original paper).
25 |
26 | ## Table of Contents
27 | * [Requirements](#Requirements)
28 | * [Repository Structure](#Repository-Structure)
29 | * [Algorithm Inputs and Parameters](#Algorithm-Inputs-and-Parameters)
30 | * [Usage Example](#Usage-Example)
31 |
32 | * [Citation](#Citation)
33 |
34 | ## Requirements
35 |
36 | ```
37 | python -m pip install -r requirements.txt
38 | ```
39 | This code was tested with python 3.8 and torch 2.0.0.
40 |
41 | ## Repository Structure
42 | ```
43 | ├── ddm_inversion - folder containing the inversion code for real images: ddim inversion as well as ddpm inversion (our method).
44 | ├── example_images - folder of input images to be edited
45 | ├── imgs - images used in this repository readme.md file
46 | ├── prompt_to_prompt - p2p code
47 | ├── main_run.py - main python file for real image editing
48 | └── test.yaml - yaml file containing the images and prompts to test on
49 | ```
50 |
51 | A folder named 'results' will be created automatically and all results will be saved to it. A timestamp is also added to the saved image names in this folder.
52 |
53 | ## Algorithm Inputs and Parameters
54 | The method's inputs are:
55 | ```
56 | init_img - the path to the input image
57 | source_prompt - a prompt describing the input image
58 | target_prompts - the edit prompt(s) (several images are created if multiple prompts are given)
59 | ```
60 | These three inputs are supplied through a YAML file (please use the provided 'test.yaml' file as a reference).
61 |
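For reference, a single entry in such a YAML file could look like the following (a minimal sketch: the keys match the inputs listed above and the list-of-entries structure expected by main_run.py; the image path and prompts are only illustrative):
```
- init_img: /example_images/horse_mud.jpg
  source_prompt: a photo of a horse in the mud
  target_prompts:
    - a photo of a horse in the snow
    - a photo of a zebra in the mud
```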
62 |
63 | Method's parameters are:
64 |
65 | ```
66 | skip - controls the adherence to the input image
67 | cfg_tar - classifier-free guidance strength for the target prompt
68 | ```
69 | These two parameters have default values, as described in the paper.
70 |
71 | ## Usage Example
72 | ```
73 | python3 main_run.py --mode="our_inv" --dataset_yaml="test.yaml" --skip=36 --cfg_tar=15
74 | python3 main_run.py --mode="p2pinv" --dataset_yaml="test.yaml" --skip=12 --cfg_tar=9
75 |
76 | ```
77 | The ```mode``` argument can also be ```ddim``` or ```p2pddim``` (p2p with ddim inversion).
78 |
79 | In the ```our_inv``` and ```p2pinv``` modes we suggest playing with ```skip``` in the range [0,40] and ```cfg_tar``` in the range [7,18].
80 |
81 | **p2pinv and p2p**:
82 | Note that you can control the cross- and self-attention injection via the ```--xa``` and ```--sa``` arguments. We suggest setting them to (0.6,0.2) for p2pinv and to (0.8,0.4) for p2p.
83 |
84 | **ddim and p2p**:
85 | ```skip``` is overwritten to be 0.
86 |
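For example, one possible ```p2pddim``` invocation (the argument names follow main_run.py; the guidance value is only illustrative) is:
```
python3 main_run.py --mode="p2pddim" --dataset_yaml="test.yaml" --skip=0 --cfg_tar=9
```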
87 |
95 |
96 | You can edit the test.yaml file to load your image and choose the desired prompts.
97 |
98 |
103 |
104 | ## Citation
105 | If you use this code for your research, please cite our paper:
106 | ```
107 | @inproceedings{huberman2024edit,
108 | title={An edit friendly {DDPM} noise space: Inversion and manipulations},
109 | author={Huberman-Spiegelglas, Inbar and Kulikov, Vladimir and Michaeli, Tomer},
110 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
111 | pages={12469--12478},
112 | year={2024}
113 | }
114 | ```
--------------------------------------------------------------------------------
/ddm_inversion/ddim_inversion.py:
--------------------------------------------------------------------------------
1 |
2 | from ddm_inversion.inversion_utils import encode_text
3 | from typing import Union
4 | import torch
5 | import numpy as np
6 | from tqdm import tqdm
7 |
8 | def next_step(model, model_output: Union[torch.FloatTensor, np.ndarray], timestep: int, sample: Union[torch.FloatTensor, np.ndarray]):
9 | timestep, next_timestep = min(timestep - model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps, 999), timestep
10 | alpha_prod_t = model.scheduler.alphas_cumprod[timestep] if timestep >= 0 else model.scheduler.final_alpha_cumprod
11 | alpha_prod_t_next = model.scheduler.alphas_cumprod[next_timestep]
12 | beta_prod_t = 1 - alpha_prod_t
13 | next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
14 | next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
15 | next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction
16 | return next_sample
17 |
18 | def get_noise_pred(model, latent, t, context, cfg_scale):
19 | latents_input = torch.cat([latent] * 2)
20 | noise_pred = model.unet(latents_input, t, encoder_hidden_states=context)["sample"]
21 | noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
22 | noise_pred = noise_pred_uncond + cfg_scale * (noise_prediction_text - noise_pred_uncond)
23 | # latents = next_step(model, noise_pred, t, latent)
24 | return noise_pred
25 |
26 | @torch.no_grad()
27 | def ddim_loop(model, w0, prompt, cfg_scale):
28 | # uncond_embeddings, cond_embeddings = self.context.chunk(2)
29 | # all_latent = [latent]
30 | text_embedding = encode_text(model, prompt)
31 | uncond_embedding = encode_text(model, "")
32 | context = torch.cat([uncond_embedding, text_embedding])
33 | latent = w0.clone().detach()
34 | for i in tqdm(range(model.scheduler.num_inference_steps)):
35 | t = model.scheduler.timesteps[len(model.scheduler.timesteps) - i - 1]
36 | noise_pred = get_noise_pred(model, latent, t, context, cfg_scale)
37 | latent = next_step(model, noise_pred, t, latent)
38 | # all_latent.append(latent)
39 | return latent
40 |
41 | @torch.no_grad()
42 | def ddim_inversion(model, w0, prompt, cfg_scale):
43 | wT = ddim_loop(model, w0, prompt, cfg_scale)
44 | return wT
--------------------------------------------------------------------------------
/ddm_inversion/inversion_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | from tqdm import tqdm
4 |
5 | def load_real_image(folder = "data/", img_name = None, idx = 0, img_size=512, device='cuda'):
6 | from ddm_inversion.utils import pil_to_tensor
7 | from PIL import Image
8 | from glob import glob
9 | if img_name is not None:
10 | path = os.path.join(folder, img_name)
11 | else:
12 | path = glob(folder + "*")[idx]
13 |
14 | img = Image.open(path).resize((img_size,
15 | img_size))
16 |
17 | img = pil_to_tensor(img).to(device)
18 |
19 | if img.shape[1]== 4:
20 | img = img[:,:3,:,:]
21 | return img
22 |
23 | def mu_tilde(model, xt,x0, timestep):
24 | "mu_tilde(x_t, x_0) DDPM paper eq. 7"
25 | prev_timestep = timestep - model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps
26 | alpha_prod_t_prev = model.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else model.scheduler.final_alpha_cumprod
27 | alpha_t = model.scheduler.alphas[timestep]
28 | beta_t = 1 - alpha_t
29 | alpha_bar = model.scheduler.alphas_cumprod[timestep]
30 | return ((alpha_prod_t_prev ** 0.5 * beta_t) / (1-alpha_bar)) * x0 + ((alpha_t**0.5 *(1-alpha_prod_t_prev)) / (1- alpha_bar))*xt
31 |
32 | def sample_xts_from_x0(model, x0, num_inference_steps=50):
33 | """
34 | Samples from P(x_1:T|x_0)
35 | """
36 | # torch.manual_seed(43256465436)
37 | alpha_bar = model.scheduler.alphas_cumprod
38 | sqrt_one_minus_alpha_bar = (1-alpha_bar) ** 0.5
39 | alphas = model.scheduler.alphas
40 | betas = 1 - alphas
41 | variance_noise_shape = (
42 | num_inference_steps,
43 | model.unet.in_channels,
44 | model.unet.sample_size,
45 | model.unet.sample_size)
46 |
47 | timesteps = model.scheduler.timesteps.to(model.device)
48 | t_to_idx = {int(v):k for k,v in enumerate(timesteps)}
49 | xts = torch.zeros((num_inference_steps+1,model.unet.in_channels, model.unet.sample_size, model.unet.sample_size)).to(x0.device)
50 | xts[0] = x0
51 | for t in reversed(timesteps):
52 | idx = num_inference_steps-t_to_idx[int(t)]
53 | xts[idx] = x0 * (alpha_bar[t] ** 0.5) + torch.randn_like(x0) * sqrt_one_minus_alpha_bar[t]
54 |
55 |
56 | return xts
57 |
58 |
59 | def encode_text(model, prompts):
60 | text_input = model.tokenizer(
61 | prompts,
62 | padding="max_length",
63 | max_length=model.tokenizer.model_max_length,
64 | truncation=True,
65 | return_tensors="pt",
66 | )
67 | with torch.no_grad():
68 | text_encoding = model.text_encoder(text_input.input_ids.to(model.device))[0]
69 | return text_encoding
70 |
71 | def forward_step(model, model_output, timestep, sample):
72 | next_timestep = min(model.scheduler.config.num_train_timesteps - 2,
73 | timestep + model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps)
74 |
75 | # 2. compute alphas, betas
76 | alpha_prod_t = model.scheduler.alphas_cumprod[timestep]
77 | # alpha_prod_t_next = self.scheduler.alphas_cumprod[next_timestep] if next_ltimestep >= 0 else self.scheduler.final_alpha_cumprod
78 |
79 | beta_prod_t = 1 - alpha_prod_t
80 |
81 | # 3. compute predicted original sample from predicted noise also called
82 | # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
83 | pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
84 |
85 |     # 5. TODO: simple noising implementation
86 | next_sample = model.scheduler.add_noise(pred_original_sample,
87 | model_output,
88 | torch.LongTensor([next_timestep]))
89 | return next_sample
90 |
91 |
92 | def get_variance(model, timestep): #, prev_timestep):
93 | prev_timestep = timestep - model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps
94 | alpha_prod_t = model.scheduler.alphas_cumprod[timestep]
95 | alpha_prod_t_prev = model.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else model.scheduler.final_alpha_cumprod
96 | beta_prod_t = 1 - alpha_prod_t
97 | beta_prod_t_prev = 1 - alpha_prod_t_prev
98 | variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
99 | return variance
100 |
101 | def inversion_forward_process(model, x0,
102 | etas = None,
103 | prog_bar = False,
104 | prompt = "",
105 | cfg_scale = 3.5,
106 | num_inference_steps=50, eps = None):
107 |
108 | if not prompt=="":
109 | text_embeddings = encode_text(model, prompt)
110 | uncond_embedding = encode_text(model, "")
111 | timesteps = model.scheduler.timesteps.to(model.device)
112 | variance_noise_shape = (
113 | num_inference_steps,
114 | model.unet.in_channels,
115 | model.unet.sample_size,
116 | model.unet.sample_size)
117 | if etas is None or (type(etas) in [int, float] and etas == 0):
118 | eta_is_zero = True
119 | zs = None
120 | else:
121 | eta_is_zero = False
122 | if type(etas) in [int, float]: etas = [etas]*model.scheduler.num_inference_steps
123 | xts = sample_xts_from_x0(model, x0, num_inference_steps=num_inference_steps)
124 | alpha_bar = model.scheduler.alphas_cumprod
125 | zs = torch.zeros(size=variance_noise_shape, device=model.device)
126 | t_to_idx = {int(v):k for k,v in enumerate(timesteps)}
127 | xt = x0
128 | # op = tqdm(reversed(timesteps)) if prog_bar else reversed(timesteps)
129 | op = tqdm(timesteps) if prog_bar else timesteps
130 |
131 | for t in op:
132 | # idx = t_to_idx[int(t)]
133 | idx = num_inference_steps-t_to_idx[int(t)]-1
134 | # 1. predict noise residual
135 | if not eta_is_zero:
136 | xt = xts[idx+1][None]
137 | # xt = xts_cycle[idx+1][None]
138 |
139 | with torch.no_grad():
140 | out = model.unet.forward(xt, timestep = t, encoder_hidden_states = uncond_embedding)
141 | if not prompt=="":
142 | cond_out = model.unet.forward(xt, timestep=t, encoder_hidden_states = text_embeddings)
143 |
144 | if not prompt=="":
145 | ## classifier free guidance
146 | noise_pred = out.sample + cfg_scale * (cond_out.sample - out.sample)
147 | else:
148 | noise_pred = out.sample
149 | if eta_is_zero:
150 | # 2. compute more noisy image and set x_t -> x_t+1
151 | xt = forward_step(model, noise_pred, t, xt)
152 |
153 | else:
154 | # xtm1 = xts[idx+1][None]
155 | xtm1 = xts[idx][None]
156 | # pred of x0
157 | pred_original_sample = (xt - (1-alpha_bar[t]) ** 0.5 * noise_pred ) / alpha_bar[t] ** 0.5
158 |
159 | # direction to xt
160 | prev_timestep = t - model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps
161 | alpha_prod_t_prev = model.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else model.scheduler.final_alpha_cumprod
162 |
163 | variance = get_variance(model, t)
164 | pred_sample_direction = (1 - alpha_prod_t_prev - etas[idx] * variance ) ** (0.5) * noise_pred
165 |
166 | mu_xt = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
167 |
168 | z = (xtm1 - mu_xt ) / ( etas[idx] * variance ** 0.5 )
169 | zs[idx] = z
170 |
171 | # correction to avoid error accumulation
172 | xtm1 = mu_xt + ( etas[idx] * variance ** 0.5 )*z
173 | xts[idx] = xtm1
174 |
175 | if not zs is None:
176 | zs[0] = torch.zeros_like(zs[0])
177 |
178 | return xt, zs, xts
179 |
180 |
181 | def reverse_step(model, model_output, timestep, sample, eta = 0, variance_noise=None):
182 | # 1. get previous step value (=t-1)
183 | prev_timestep = timestep - model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps
184 | # 2. compute alphas, betas
185 | alpha_prod_t = model.scheduler.alphas_cumprod[timestep]
186 | alpha_prod_t_prev = model.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else model.scheduler.final_alpha_cumprod
187 | beta_prod_t = 1 - alpha_prod_t
188 | # 3. compute predicted original sample from predicted noise also called
189 | # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
190 | pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
191 | # 5. compute variance: "sigma_t(η)" -> see formula (16)
192 | # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
193 | # variance = self.scheduler._get_variance(timestep, prev_timestep)
194 | variance = get_variance(model, timestep) #, prev_timestep)
195 | std_dev_t = eta * variance ** (0.5)
196 |     # Take care of asymmetric reverse process (asyrp)
197 | model_output_direction = model_output
198 | # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
199 | # pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output_direction
200 | pred_sample_direction = (1 - alpha_prod_t_prev - eta * variance) ** (0.5) * model_output_direction
201 | # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
202 | prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
203 |     # 8. Add noise if eta > 0
204 | if eta > 0:
205 | if variance_noise is None:
206 | variance_noise = torch.randn(model_output.shape, device=model.device)
207 | sigma_z = eta * variance ** (0.5) * variance_noise
208 | prev_sample = prev_sample + sigma_z
209 |
210 | return prev_sample
211 |
212 | def inversion_reverse_process(model,
213 | xT,
214 | etas = 0,
215 | prompts = "",
216 | cfg_scales = None,
217 | prog_bar = False,
218 | zs = None,
219 | controller=None,
220 | asyrp = False):
221 |
222 | batch_size = len(prompts)
223 |
224 | cfg_scales_tensor = torch.Tensor(cfg_scales).view(-1,1,1,1).to(model.device)
225 |
226 | text_embeddings = encode_text(model, prompts)
227 | uncond_embedding = encode_text(model, [""] * batch_size)
228 |
229 | if etas is None: etas = 0
230 | if type(etas) in [int, float]: etas = [etas]*model.scheduler.num_inference_steps
231 | assert len(etas) == model.scheduler.num_inference_steps
232 | timesteps = model.scheduler.timesteps.to(model.device)
233 |
234 | xt = xT.expand(batch_size, -1, -1, -1)
235 | op = tqdm(timesteps[-zs.shape[0]:]) if prog_bar else timesteps[-zs.shape[0]:]
236 |
237 | t_to_idx = {int(v):k for k,v in enumerate(timesteps[-zs.shape[0]:])}
238 |
239 | for t in op:
240 | idx = model.scheduler.num_inference_steps-t_to_idx[int(t)]-(model.scheduler.num_inference_steps-zs.shape[0]+1)
241 | ## Unconditional embedding
242 | with torch.no_grad():
243 | uncond_out = model.unet.forward(xt, timestep = t,
244 | encoder_hidden_states = uncond_embedding)
245 |
246 | ## Conditional embedding
247 | if prompts:
248 | with torch.no_grad():
249 | cond_out = model.unet.forward(xt, timestep = t,
250 | encoder_hidden_states = text_embeddings)
251 |
252 |
253 | z = zs[idx] if not zs is None else None
254 | z = z.expand(batch_size, -1, -1, -1)
255 | if prompts:
256 | ## classifier free guidance
257 | noise_pred = uncond_out.sample + cfg_scales_tensor * (cond_out.sample - uncond_out.sample)
258 | else:
259 | noise_pred = uncond_out.sample
260 | # 2. compute less noisy image and set x_t -> x_t-1
261 | xt = reverse_step(model, noise_pred, t, xt, eta = etas[idx], variance_noise = z)
262 | if controller is not None:
263 | xt = controller.step_callback(xt)
264 | return xt, zs
265 |
266 |
267 |
--------------------------------------------------------------------------------
/ddm_inversion/utils.py:
--------------------------------------------------------------------------------
1 | import PIL
2 | from PIL import Image, ImageDraw ,ImageFont
3 | from matplotlib import pyplot as plt
4 | import torchvision.transforms as T
5 | import os
6 | import torch
7 | import yaml
8 |
9 | def show_torch_img(img):
10 | img = to_np_image(img)
11 | plt.imshow(img)
12 | plt.axis("off")
13 |
14 | def to_np_image(all_images):
15 | all_images = (all_images.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8).cpu().numpy()[0]
16 | return all_images
17 |
18 | def tensor_to_pil(tensor_imgs):
19 | if type(tensor_imgs) == list:
20 | tensor_imgs = torch.cat(tensor_imgs)
21 | tensor_imgs = (tensor_imgs / 2 + 0.5).clamp(0, 1)
22 | to_pil = T.ToPILImage()
23 | pil_imgs = [to_pil(img) for img in tensor_imgs]
24 | return pil_imgs
25 |
26 | def pil_to_tensor(pil_imgs):
27 | to_torch = T.ToTensor()
28 | if type(pil_imgs) == PIL.Image.Image:
29 | tensor_imgs = to_torch(pil_imgs).unsqueeze(0)*2-1
30 |     elif type(pil_imgs) == list:
31 |         tensor_imgs = torch.cat([to_torch(img).unsqueeze(0)*2-1 for img in pil_imgs])
32 |     else:
33 |         raise Exception("Input needs to be a PIL.Image or a list of PIL.Image")
34 | return tensor_imgs
35 |
36 |
37 | ## TODO implement this
38 | # n = 10
39 | # num_rows = 4
40 | # num_col = n // num_rows
41 | # num_col = num_col + 1 if n % num_rows else num_col
42 | # num_col
43 | def add_margin(pil_img, top = 0, right = 0, bottom = 0,
44 | left = 0, color = (255,255,255)):
45 | width, height = pil_img.size
46 | new_width = width + right + left
47 | new_height = height + top + bottom
48 | result = Image.new(pil_img.mode, (new_width, new_height), color)
49 |
50 | result.paste(pil_img, (left, top))
51 | return result
52 |
53 | def image_grid(imgs, rows = 1, cols = None,
54 | size = None,
55 | titles = None, text_pos = (0, 0)):
56 | if type(imgs) == list and type(imgs[0]) == torch.Tensor:
57 | imgs = torch.cat(imgs)
58 | if type(imgs) == torch.Tensor:
59 | imgs = tensor_to_pil(imgs)
60 |
61 | if not size is None:
62 | imgs = [img.resize((size,size)) for img in imgs]
63 | if cols is None:
64 | cols = len(imgs)
65 | assert len(imgs) >= rows*cols
66 |
67 | top=20
68 | w, h = imgs[0].size
69 | delta = 0
70 | if len(imgs)> 1 and not imgs[1].size[1] == h:
71 | delta = top
72 | h = imgs[1].size[1]
73 | if not titles is None:
74 | font = ImageFont.truetype("/usr/share/fonts/truetype/freefont/FreeMono.ttf",
75 | size = 20, encoding="unic")
76 | h = top + h
77 | grid = Image.new('RGB', size=(cols*w, rows*h+delta))
78 | for i, img in enumerate(imgs):
79 |
80 | if not titles is None:
81 | img = add_margin(img, top = top, bottom = 0,left=0)
82 | draw = ImageDraw.Draw(img)
83 | draw.text(text_pos, titles[i],(0,0,0),
84 | font = font)
85 | if not delta == 0 and i > 0:
86 | grid.paste(img, box=(i%cols*w, i//cols*h+delta))
87 | else:
88 | grid.paste(img, box=(i%cols*w, i//cols*h))
89 |
90 | return grid
91 |
92 |
93 | """
94 | input_folder - dataset folder
95 | """
96 | def load_dataset(input_folder):
97 | # full_file_names = glob.glob(input_folder)
98 | # class_names = [x[0] for x in os.walk(input_folder)]
99 | class_names = next(os.walk(input_folder))[1]
100 | class_names[:] = [d for d in class_names if not d[0] == '.']
101 | file_names=[]
102 | for class_name in class_names:
103 | cur_path = os.path.join(input_folder, class_name)
104 | filenames = next(os.walk(cur_path), (None, None, []))[2]
105 | filenames = [f for f in filenames if not f[0] == '.']
106 | file_names.append(filenames)
107 | return class_names, file_names
108 |
109 |
110 | def dataset_from_yaml(yaml_location):
111 | with open(yaml_location, 'r') as stream:
112 | data_loaded = yaml.safe_load(stream)
113 |
114 | return data_loaded
--------------------------------------------------------------------------------
/example_images/horse_mud.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/inbarhub/DDPM_inversion/58fc881772d34f0c24be9e34725d213731aa009d/example_images/horse_mud.jpg
--------------------------------------------------------------------------------
/example_images/sketch_cat.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/inbarhub/DDPM_inversion/58fc881772d34f0c24be9e34725d213731aa009d/example_images/sketch_cat.jpg
--------------------------------------------------------------------------------
/imgs/teaser.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/inbarhub/DDPM_inversion/58fc881772d34f0c24be9e34725d213731aa009d/imgs/teaser.jpg
--------------------------------------------------------------------------------
/main_run.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from diffusers import StableDiffusionPipeline
3 | from diffusers import DDIMScheduler
4 | import os
5 | from prompt_to_prompt.ptp_classes import AttentionStore, AttentionReplace, AttentionRefine, EmptyControl,load_512
6 | from prompt_to_prompt.ptp_utils import register_attention_control, text2image_ldm_stable, view_images
7 | from ddm_inversion.inversion_utils import inversion_forward_process, inversion_reverse_process
8 | from ddm_inversion.utils import image_grid,dataset_from_yaml
9 |
10 | from torch import autocast, inference_mode
11 | from ddm_inversion.ddim_inversion import ddim_inversion
12 |
13 | import calendar
14 | import time
15 |
16 | if __name__ == "__main__":
17 | parser = argparse.ArgumentParser()
18 | parser.add_argument("--device_num", type=int, default=0)
19 | parser.add_argument("--cfg_src", type=float, default=3.5)
20 | parser.add_argument("--cfg_tar", type=float, default=15)
21 | parser.add_argument("--num_diffusion_steps", type=int, default=100)
22 | parser.add_argument("--dataset_yaml", default="test.yaml")
23 | parser.add_argument("--eta", type=float, default=1)
24 | parser.add_argument("--mode", default="our_inv", help="modes: our_inv,p2pinv,p2pddim,ddim")
25 | parser.add_argument("--skip", type=int, default=36)
26 | parser.add_argument("--xa", type=float, default=0.6)
27 | parser.add_argument("--sa", type=float, default=0.2)
28 |
29 | args = parser.parse_args()
30 | full_data = dataset_from_yaml(args.dataset_yaml)
31 |
32 | # create scheduler
33 | # load diffusion model
34 | model_id = "CompVis/stable-diffusion-v1-4"
35 | # model_id = "stable_diff_local" # load local save of model (for internet problems)
36 |
37 | device = f"cuda:{args.device_num}"
38 |
39 | cfg_scale_src = args.cfg_src
40 | cfg_scale_tar_list = [args.cfg_tar]
41 | eta = args.eta # = 1
42 | skip_zs = [args.skip]
43 | xa_sa_string = f'_xa_{args.xa}_sa{args.sa}_' if args.mode=='p2pinv' else '_'
44 |
45 | current_GMT = time.gmtime()
46 | time_stamp = calendar.timegm(current_GMT)
47 |
48 | # load/reload model:
49 | ldm_stable = StableDiffusionPipeline.from_pretrained(model_id).to(device)
50 |
51 | for i in range(len(full_data)):
52 | current_image_data = full_data[i]
53 | image_path = current_image_data['init_img']
54 | image_path = '.' + image_path
55 | image_folder = image_path.split('/')[1] # after '.'
56 | prompt_src = current_image_data.get('source_prompt', "") # default empty string
57 | prompt_tar_list = current_image_data['target_prompts']
58 |
59 | if args.mode=="p2pddim" or args.mode=="ddim":
60 | scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
61 | ldm_stable.scheduler = scheduler
62 | else:
63 | ldm_stable.scheduler = DDIMScheduler.from_config(model_id, subfolder = "scheduler")
64 |
65 | ldm_stable.scheduler.set_timesteps(args.num_diffusion_steps)
66 |
67 | # load image
68 | offsets=(0,0,0,0)
69 | x0 = load_512(image_path, *offsets, device)
70 |
71 | # vae encode image
72 | with autocast("cuda"), inference_mode():
73 | w0 = (ldm_stable.vae.encode(x0).latent_dist.mode() * 0.18215).float()
74 |
75 | # find Zs and wts - forward process
76 | if args.mode=="p2pddim" or args.mode=="ddim":
77 | wT = ddim_inversion(ldm_stable, w0, prompt_src, cfg_scale_src)
78 | else:
79 | wt, zs, wts = inversion_forward_process(ldm_stable, w0, etas=eta, prompt=prompt_src, cfg_scale=cfg_scale_src, prog_bar=True, num_inference_steps=args.num_diffusion_steps)
80 |
81 | # iterate over decoder prompts
82 | for k in range(len(prompt_tar_list)):
83 | prompt_tar = prompt_tar_list[k]
84 | save_path = os.path.join(f'./results/', args.mode+xa_sa_string+str(time_stamp), image_path.split(sep='.')[0], 'src_' + prompt_src.replace(" ", "_"), 'dec_' + prompt_tar.replace(" ", "_"))
85 | os.makedirs(save_path, exist_ok=True)
86 |
87 | # Check if number of words in encoder and decoder text are equal
88 | src_tar_len_eq = (len(prompt_src.split(" ")) == len(prompt_tar.split(" ")))
89 |
90 | for cfg_scale_tar in cfg_scale_tar_list:
91 | for skip in skip_zs:
92 | if args.mode=="our_inv":
93 | # reverse process (via Zs and wT)
94 | controller = AttentionStore()
95 | register_attention_control(ldm_stable, controller)
96 | w0, _ = inversion_reverse_process(ldm_stable, xT=wts[args.num_diffusion_steps-skip], etas=eta, prompts=[prompt_tar], cfg_scales=[cfg_scale_tar], prog_bar=True, zs=zs[:(args.num_diffusion_steps-skip)], controller=controller)
97 |
98 | elif args.mode=="p2pinv":
99 | # inversion with attention replace
100 | cfg_scale_list = [cfg_scale_src, cfg_scale_tar]
101 | prompts = [prompt_src, prompt_tar]
102 | if src_tar_len_eq:
103 | controller = AttentionReplace(prompts, args.num_diffusion_steps, cross_replace_steps=args.xa, self_replace_steps=args.sa, model=ldm_stable)
104 | else:
105 | # Should use Refine for target prompts with different number of tokens
106 | controller = AttentionRefine(prompts, args.num_diffusion_steps, cross_replace_steps=args.xa, self_replace_steps=args.sa, model=ldm_stable)
107 |
108 | register_attention_control(ldm_stable, controller)
109 | w0, _ = inversion_reverse_process(ldm_stable, xT=wts[args.num_diffusion_steps-skip], etas=eta, prompts=prompts, cfg_scales=cfg_scale_list, prog_bar=True, zs=zs[:(args.num_diffusion_steps-skip)], controller=controller)
110 | w0 = w0[1].unsqueeze(0)
111 |
112 | elif args.mode=="p2pddim" or args.mode=="ddim":
113 | # only z=0
114 | if skip != 0:
115 | continue
116 | prompts = [prompt_src, prompt_tar]
117 | if args.mode=="p2pddim":
118 | if src_tar_len_eq:
119 | controller = AttentionReplace(prompts, args.num_diffusion_steps, cross_replace_steps=.8, self_replace_steps=0.4, model=ldm_stable)
120 | # Should use Refine for target prompts with different number of tokens
121 | else:
122 | controller = AttentionRefine(prompts, args.num_diffusion_steps, cross_replace_steps=.8, self_replace_steps=0.4, model=ldm_stable)
123 | else:
124 | controller = EmptyControl()
125 |
126 | register_attention_control(ldm_stable, controller)
127 | # perform ddim inversion
128 | cfg_scale_list = [cfg_scale_src, cfg_scale_tar]
129 | w0, latent = text2image_ldm_stable(ldm_stable, prompts, controller, args.num_diffusion_steps, cfg_scale_list, None, wT)
130 | w0 = w0[1:2]
131 | else:
132 | raise NotImplementedError
133 |
134 | # vae decode image
135 | with autocast("cuda"), inference_mode():
136 | x0_dec = ldm_stable.vae.decode(1 / 0.18215 * w0).sample
137 | if x0_dec.dim()<4:
138 | x0_dec = x0_dec[None,:,:,:]
139 | img = image_grid(x0_dec)
140 |
141 |                     # save output
142 | current_GMT = time.gmtime()
143 | time_stamp_name = calendar.timegm(current_GMT)
144 | image_name_png = f'cfg_d_{cfg_scale_tar}_' + f'skip_{skip}_{time_stamp_name}' + ".png"
145 |
146 | save_full_path = os.path.join(save_path, image_name_png)
147 | img.save(save_full_path)
--------------------------------------------------------------------------------
/prompt_to_prompt/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright [yyyy] [name of copyright owner]
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
203 |
--------------------------------------------------------------------------------
/prompt_to_prompt/README.md:
--------------------------------------------------------------------------------
1 | # Prompt-to-Prompt
2 |
3 | > *Latent Diffusion* and *Stable Diffusion* Implementation
4 |
5 | ## :partying_face: ***New:*** :partying_face: Code for Null-Text Inversion is now provided [here](#null-text-inversion-for-editing-real-images)
6 |
7 |
8 | 
9 | ### [Project Page](https://prompt-to-prompt.github.io) [Paper](https://prompt-to-prompt.github.io/ptp_files/Prompt-to-Prompt_preprint.pdf)
10 |
11 |
12 | ## Setup
13 |
14 | This code was tested with Python 3.8, [Pytorch](https://pytorch.org/) 1.11 using pre-trained models through [huggingface / diffusers](https://github.com/huggingface/diffusers#readme).
15 | Specifically, we implemented our method over [Latent Diffusion](https://huggingface.co/CompVis/ldm-text2im-large-256) and [Stable Diffusion](https://huggingface.co/CompVis/stable-diffusion-v1-4).
16 | Additional required packages are listed in the requirements file.
17 | The code was tested on a Tesla V100 16GB but should work on other cards with at least **12GB** VRAM.
18 |
19 | ## Quickstart
20 |
21 | In order to get started, we recommend taking a look at our notebooks: [**prompt-to-prompt_ldm**][p2p-ldm] and [**prompt-to-prompt_stable**][p2p-stable]. The notebooks contain end-to-end examples of usage of prompt-to-prompt on top of *Latent Diffusion* and *Stable Diffusion* respectively. Take a look at these notebooks to learn how to use the different types of prompt edits and understand the API.
22 |
23 | ## Prompt Edits
24 |
25 | In our notebooks, we perform our main logic by implementing the abstract class `AttentionControl` object, of the following form:
26 |
27 | ``` python
28 | class AttentionControl(abc.ABC):
29 | @abc.abstractmethod
30 | def forward (self, attn, is_cross: bool, place_in_unet: str):
31 | raise NotImplementedError
32 | ```
33 |
34 | The `forward` method is called in each attention layer of the diffusion model during the image generation, and we use it to modify the weights of the attention. Our method (See Section 3 of our [paper](https://arxiv.org/abs/2208.01626)) edits images with the procedure above, and each different prompt edit type modifies the weights of the attention in a different manner.
35 |
36 | The general flow of our code is as follows, with variations based on the attention control type:
37 |
38 | ``` python
39 | prompts = ["A painting of a squirrel eating a burger", ...]
40 | controller = AttentionControl(prompts, ...)
41 | run_and_display(prompts, controller, ...)
42 | ```
43 |
44 | ### Replacement
45 | In this case, the user swaps tokens of the original prompt with others, e.g., editing the prompt `"A painting of a squirrel eating a burger"` to `"A painting of a squirrel eating a lasagna"` or `"A painting of a lion eating a burger"`. For this we define the class `AttentionReplace`.
46 |
47 | ### Refinement
48 | In this case, the user adds new tokens to the prompt, e.g., editing the prompt `"A painting of a squirrel eating a burger"` to `"A watercolor painting of a squirrel eating a burger"`. For this we define the class `AttentionRefine`.
49 |
50 | ### Re-weight
51 | In this case, the user changes the weight of certain tokens in the prompt, e.g., for the prompt `"A photo of a poppy field at night"`, strengthen or weaken the extent to which the word `night` affects the resulting image. For this we define the class `AttentionReweight`.
52 |
53 |
54 | ## Attention Control Options
55 | * `cross_replace_steps`: specifies the fraction of steps to edit the cross attention maps. Can also be set to a dictionary `[str:float]` which specifies fractions for different words in the prompt.
56 | * `self_replace_steps`: specifies the fraction of steps to replace the self attention maps.
57 | * `local_blend` (optional): `LocalBlend` object which is used to make local edits. `LocalBlend` is initialized with the words from each prompt that correspond with the region in the image we want to edit.
58 | * `equalizer`: used for attention Re-weighting only. A vector of coefficients to multiply each cross-attention weight
59 |
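As a concrete illustration, here is a minimal sketch of passing these options when constructing a controller. The constructor call mirrors the `AttentionReplace` usage in this repository's `main_run.py`; the prompts, the number of diffusion steps, and the replace fractions are only illustrative.

``` python
from diffusers import StableDiffusionPipeline
from prompt_to_prompt.ptp_classes import AttentionReplace
from prompt_to_prompt.ptp_utils import register_attention_control

# Load the Stable Diffusion pipeline used throughout this repository.
ldm_stable = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4").to("cuda")

prompts = ["A painting of a squirrel eating a burger",
           "A painting of a squirrel eating a lasagna"]

# Replace cross-attention maps for the first 80% of the 50 diffusion steps and
# self-attention maps for the first 40% of the steps.
controller = AttentionReplace(prompts, 50,
                              cross_replace_steps=0.8,
                              self_replace_steps=0.4,
                              model=ldm_stable)
register_attention_control(ldm_stable, controller)
```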
60 | ## Citation
61 |
62 | ``` bibtex
63 | @article{hertz2022prompt,
64 | title = {Prompt-to-Prompt Image Editing with Cross Attention Control},
65 | author = {Hertz, Amir and Mokady, Ron and Tenenbaum, Jay and Aberman, Kfir and Pritch, Yael and Cohen-Or, Daniel},
66 | journal = {arXiv preprint arXiv:2208.01626},
67 | year = {2022},
68 | }
69 | ```
70 |
71 | # Null-Text Inversion for Editing Real Images
72 |
73 | ### [Project Page](https://null-text-inversion.github.io/) [Paper](https://arxiv.org/abs/2211.09794)
74 |
75 |
76 |
77 | Null-text inversion enables intuitive text-based editing of **real images** with the Stable Diffusion model. We use an initial DDIM inversion as an anchor for our optimization which only tunes the null-text embedding used in classifier-free guidance.
78 |
79 |
80 | 
81 |
82 | ## Editing Real Images
83 |
84 | Prompt-to-Prompt editing of real images by first using Null-text inversion is provided in this [**Notebook**][null_text].
85 |
86 |
87 | ``` bibtex
88 | @article{mokady2022null,
89 | title={Null-text Inversion for Editing Real Images using Guided Diffusion Models},
90 | author={Mokady, Ron and Hertz, Amir and Aberman, Kfir and Pritch, Yael and Cohen-Or, Daniel},
91 | journal={arXiv preprint arXiv:2211.09794},
92 | year={2022}
93 | }
94 | ```
95 |
96 |
97 | ## Disclaimer
98 |
99 | This is not an officially supported Google product.
100 |
101 | [p2p-ldm]: prompt-to-prompt_ldm.ipynb
102 | [p2p-stable]: prompt-to-prompt_stable.ipynb
103 | [null_text]: null_text_w_ptp.ipynb
104 |
--------------------------------------------------------------------------------
/prompt_to_prompt/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/inbarhub/DDPM_inversion/58fc881772d34f0c24be9e34725d213731aa009d/prompt_to_prompt/__init__.py
--------------------------------------------------------------------------------
/prompt_to_prompt/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/inbarhub/DDPM_inversion/58fc881772d34f0c24be9e34725d213731aa009d/prompt_to_prompt/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/prompt_to_prompt/__pycache__/ddim_inversion.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/inbarhub/DDPM_inversion/58fc881772d34f0c24be9e34725d213731aa009d/prompt_to_prompt/__pycache__/ddim_inversion.cpython-38.pyc
--------------------------------------------------------------------------------
/prompt_to_prompt/__pycache__/inversion_utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/inbarhub/DDPM_inversion/58fc881772d34f0c24be9e34725d213731aa009d/prompt_to_prompt/__pycache__/inversion_utils.cpython-38.pyc
--------------------------------------------------------------------------------
/prompt_to_prompt/__pycache__/inversion_utils_vova.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/inbarhub/DDPM_inversion/58fc881772d34f0c24be9e34725d213731aa009d/prompt_to_prompt/__pycache__/inversion_utils_vova.cpython-38.pyc
--------------------------------------------------------------------------------
/prompt_to_prompt/__pycache__/ptp_classes.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/inbarhub/DDPM_inversion/58fc881772d34f0c24be9e34725d213731aa009d/prompt_to_prompt/__pycache__/ptp_classes.cpython-38.pyc
--------------------------------------------------------------------------------
/prompt_to_prompt/__pycache__/ptp_utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/inbarhub/DDPM_inversion/58fc881772d34f0c24be9e34725d213731aa009d/prompt_to_prompt/__pycache__/ptp_utils.cpython-38.pyc
--------------------------------------------------------------------------------
/prompt_to_prompt/__pycache__/seq_aligner.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/inbarhub/DDPM_inversion/58fc881772d34f0c24be9e34725d213731aa009d/prompt_to_prompt/__pycache__/seq_aligner.cpython-38.pyc
--------------------------------------------------------------------------------
/prompt_to_prompt/__pycache__/utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/inbarhub/DDPM_inversion/58fc881772d34f0c24be9e34725d213731aa009d/prompt_to_prompt/__pycache__/utils.cpython-38.pyc
--------------------------------------------------------------------------------
/prompt_to_prompt/contributing.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | We'd love to accept your patches and contributions to this project. There are
4 | just a few small guidelines you need to follow.
5 |
6 | ## Contributor License Agreement
7 |
8 | Contributions to this project must be accompanied by a Contributor License
9 | Agreement. You (or your employer) retain the copyright to your contribution;
10 | this simply gives us permission to use and redistribute your contributions as
11 | part of the project. Head over to <https://cla.developers.google.com/> to see
12 | your current agreements on file or to sign a new one.
13 |
14 | You generally only need to submit a CLA once, so if you've already submitted one
15 | (even if it was for a different project), you probably don't need to do it
16 | again.
17 |
18 | ## Code Reviews
19 |
20 | All submissions, including submissions by project members, require review. We
21 | use GitHub pull requests for this purpose. Consult
22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
23 | information on using pull requests.
24 |
25 | ## Community Guidelines
26 |
27 | This project follows [Google's Open Source Community
28 | Guidelines](https://opensource.google/conduct/).
29 |
--------------------------------------------------------------------------------
/prompt_to_prompt/data/horse.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/inbarhub/DDPM_inversion/58fc881772d34f0c24be9e34725d213731aa009d/prompt_to_prompt/data/horse.jpg
--------------------------------------------------------------------------------
/prompt_to_prompt/docs/null_text_teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/inbarhub/DDPM_inversion/58fc881772d34f0c24be9e34725d213731aa009d/prompt_to_prompt/docs/null_text_teaser.png
--------------------------------------------------------------------------------
/prompt_to_prompt/docs/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/inbarhub/DDPM_inversion/58fc881772d34f0c24be9e34725d213731aa009d/prompt_to_prompt/docs/teaser.png
--------------------------------------------------------------------------------
/prompt_to_prompt/example_images/gnochi_mirror.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/inbarhub/DDPM_inversion/58fc881772d34f0c24be9e34725d213731aa009d/prompt_to_prompt/example_images/gnochi_mirror.jpeg
--------------------------------------------------------------------------------
/prompt_to_prompt/ptp_classes.py:
--------------------------------------------------------------------------------
1 | """
2 | This code was originally taken from
3 | https://github.com/google/prompt-to-prompt
4 | """
5 |
6 |
7 | LOW_RESOURCE = True
8 | MAX_NUM_WORDS = 77
9 |
10 |
11 | from typing import Optional, Union, Tuple, List, Callable, Dict
12 | import prompt_to_prompt.ptp_utils as ptp_utils
13 | import prompt_to_prompt.seq_aligner as seq_aligner
14 | import torch
15 | import torch.nn.functional as nnf
16 | import abc
17 | import numpy as np
18 |
19 |
20 | class LocalBlend:
21 |
22 | def __call__(self, x_t, attention_store):
23 | k = 1
24 | maps = attention_store["down_cross"][2:4] + attention_store["up_cross"][:3]
25 | maps = [item.reshape(self.alpha_layers.shape[0], -1, 1, 16, 16, MAX_NUM_WORDS) for item in maps]
26 | maps = torch.cat(maps, dim=1)
27 | maps = (maps * self.alpha_layers).sum(-1).mean(1)
28 | mask = nnf.max_pool2d(maps, (k * 2 + 1, k * 2 +1), (1, 1), padding=(k, k))
29 | mask = nnf.interpolate(mask, size=(x_t.shape[2:]))
30 | mask = mask / mask.max(2, keepdims=True)[0].max(3, keepdims=True)[0]
31 | mask = mask.gt(self.threshold)
32 | mask = (mask[:1] + mask[1:]).float()
33 | x_t = x_t[:1] + mask * (x_t - x_t[:1])
34 | return x_t
35 |
36 |     def __init__(self, prompts: List[str], words: List[List[str]], threshold=.3, device=None, tokenizer=None):
37 | alpha_layers = torch.zeros(len(prompts), 1, 1, 1, 1, MAX_NUM_WORDS)
38 | for i, (prompt, words_) in enumerate(zip(prompts, words)):
39 | if type(words_) is str:
40 | words_ = [words_]
41 | for word in words_:
42 | ind = ptp_utils.get_word_inds(prompt, word, tokenizer)
43 | alpha_layers[i, :, :, :, :, ind] = 1
44 | self.alpha_layers = alpha_layers.to(device)
45 | self.threshold = threshold
46 |
47 |
48 | class AttentionControl(abc.ABC):
49 |
50 | def step_callback(self, x_t):
51 | return x_t
52 |
53 | def between_steps(self):
54 | return
55 |
56 | @property
57 | def num_uncond_att_layers(self):
58 | return self.num_att_layers if LOW_RESOURCE else 0
59 |
60 | @abc.abstractmethod
61 | def forward (self, attn, is_cross: bool, place_in_unet: str):
62 | raise NotImplementedError
63 |
64 | def __call__(self, attn, is_cross: bool, place_in_unet: str):
65 | if self.cur_att_layer >= self.num_uncond_att_layers:
66 | if LOW_RESOURCE:
67 | attn = self.forward(attn, is_cross, place_in_unet)
68 | else:
69 | h = attn.shape[0]
70 | attn[h // 2:] = self.forward(attn[h // 2:], is_cross, place_in_unet)
71 | self.cur_att_layer += 1
72 | if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers:
73 | self.cur_att_layer = 0
74 | self.cur_step += 1
75 | self.between_steps()
76 | return attn
77 |
78 | def reset(self):
79 | self.cur_step = 0
80 | self.cur_att_layer = 0
81 |
82 | def __init__(self):
83 | self.cur_step = 0
84 | self.num_att_layers = -1
85 | self.cur_att_layer = 0
86 |
87 | class EmptyControl(AttentionControl):
88 |
89 | def forward (self, attn, is_cross: bool, place_in_unet: str):
90 | return attn
91 |
92 |
93 | class AttentionStore(AttentionControl):
94 |
95 | @staticmethod
96 | def get_empty_store():
97 | return {"down_cross": [], "mid_cross": [], "up_cross": [],
98 | "down_self": [], "mid_self": [], "up_self": []}
99 |
100 | def forward(self, attn, is_cross: bool, place_in_unet: str):
101 | key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
102 | if attn.shape[1] <= 32 ** 2: # avoid memory overhead
103 | self.step_store[key].append(attn)
104 | return attn
105 |
106 | def between_steps(self):
107 | if len(self.attention_store) == 0:
108 | self.attention_store = self.step_store
109 | else:
110 | for key in self.attention_store:
111 | for i in range(len(self.attention_store[key])):
112 | self.attention_store[key][i] += self.step_store[key][i]
113 | self.step_store = self.get_empty_store()
114 |
115 | def get_average_attention(self):
116 | average_attention = {key: [item / self.cur_step for item in self.attention_store[key]] for key in self.attention_store}
117 | return average_attention
118 |
119 |
120 | def reset(self):
121 | super(AttentionStore, self).reset()
122 | self.step_store = self.get_empty_store()
123 | self.attention_store = {}
124 |
125 | def __init__(self):
126 | super(AttentionStore, self).__init__()
127 | self.step_store = self.get_empty_store()
128 | self.attention_store = {}
129 |
130 |
131 | class AttentionControlEdit(AttentionStore, abc.ABC):
132 |
133 | def step_callback(self, x_t):
134 | if self.local_blend is not None:
135 | x_t = self.local_blend(x_t, self.attention_store)
136 | return x_t
137 |
138 | def replace_self_attention(self, attn_base, att_replace):
139 | if att_replace.shape[2] <= 16 ** 2:
140 | return attn_base.unsqueeze(0).expand(att_replace.shape[0], *attn_base.shape)
141 | else:
142 | return att_replace
143 |
144 | @abc.abstractmethod
145 | def replace_cross_attention(self, attn_base, att_replace):
146 | raise NotImplementedError
147 |
148 | def forward(self, attn, is_cross: bool, place_in_unet: str):
149 | super(AttentionControlEdit, self).forward(attn, is_cross, place_in_unet)
150 | if is_cross or (self.num_self_replace[0] <= self.cur_step < self.num_self_replace[1]):
151 | h = attn.shape[0] // (self.batch_size)
152 | attn = attn.reshape(self.batch_size, h, *attn.shape[1:])
153 |             attn_base, attn_replace = attn[0], attn[1:]
154 |             if is_cross:
155 |                 alpha_words = self.cross_replace_alpha[self.cur_step]
156 |                 attn_replace_new = self.replace_cross_attention(attn_base, attn_replace) * alpha_words + (1 - alpha_words) * attn_replace
157 |                 attn[1:] = attn_replace_new
158 |             else:
159 |                 attn[1:] = self.replace_self_attention(attn_base, attn_replace)
160 | attn = attn.reshape(self.batch_size * h, *attn.shape[2:])
161 | return attn
162 |
163 | def __init__(self, prompts, num_steps: int,
164 | cross_replace_steps: Union[float, Tuple[float, float], Dict[str, Tuple[float, float]]],
165 | self_replace_steps: Union[float, Tuple[float, float]],
166 | local_blend: Optional[LocalBlend],
167 | device=None,
168 | tokenizer=None):
169 | super(AttentionControlEdit, self).__init__()
170 | self.batch_size = len(prompts)
171 | self.cross_replace_alpha = ptp_utils.get_time_words_attention_alpha(prompts, num_steps, cross_replace_steps, tokenizer).to(device)
172 | if type(self_replace_steps) is float:
173 | self_replace_steps = 0, self_replace_steps
174 | self.num_self_replace = int(num_steps * self_replace_steps[0]), int(num_steps * self_replace_steps[1])
175 | self.local_blend = local_blend
176 |
177 | class AttentionReplace(AttentionControlEdit):
178 |
179 | def replace_cross_attention(self, attn_base, att_replace):
180 | return torch.einsum('hpw,bwn->bhpn', attn_base, self.mapper)
181 |
182 | def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float,
183 | local_blend: Optional[LocalBlend] = None, model=None):
184 | super(AttentionReplace, self).__init__(prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend, device=model.device)
185 | self.mapper = seq_aligner.get_replacement_mapper(prompts, model.tokenizer).to(model.device)
186 |
187 |
188 | class AttentionRefine(AttentionControlEdit):
189 |
190 | def replace_cross_attention(self, attn_base, att_replace):
191 | attn_base_replace = attn_base[:, :, self.mapper].permute(2, 0, 1, 3)
192 | attn_replace = attn_base_replace * self.alphas + att_replace * (1 - self.alphas)
193 | return attn_replace
194 |
195 | def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float,
196 | local_blend: Optional[LocalBlend] = None, model=None):
197 | super(AttentionRefine, self).__init__(prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend, device=model.device)
198 | self.mapper, alphas = seq_aligner.get_refinement_mapper(prompts, model.tokenizer)
199 | self.mapper, alphas = self.mapper.to(model.device), alphas.to(model.device)
200 | self.alphas = alphas.reshape(alphas.shape[0], 1, 1, alphas.shape[1])
201 |
202 |
203 | class AttentionReweight(AttentionControlEdit):
204 |
205 | def replace_cross_attention(self, attn_base, att_replace):
206 | if self.prev_controller is not None:
207 | attn_base = self.prev_controller.replace_cross_attention(attn_base, att_replace)
208 | attn_replace = attn_base[None, :, :, :] * self.equalizer[:, None, None, :]
209 | return attn_replace
210 |
211 | def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float, equalizer,
212 | local_blend: Optional[LocalBlend] = None, controller: Optional[AttentionControlEdit] = None, device=None, tokenizer=None):
213 | super(AttentionReweight, self).__init__(prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend)
214 | self.equalizer = equalizer.to(device)
215 | self.prev_controller = controller
216 |
217 |
218 | def get_equalizer(text: str, word_select: Union[int, Tuple[int, ...]], values: Union[List[float],
219 | Tuple[float, ...]], tokenizer=None):
220 | if type(word_select) is int or type(word_select) is str:
221 | word_select = (word_select,)
222 | equalizer = torch.ones(len(values), 77)
223 | values = torch.tensor(values, dtype=torch.float32)
224 | for word in word_select:
225 | inds = ptp_utils.get_word_inds(text, word, tokenizer)
226 | equalizer[:, inds] = values
227 | return equalizer
228 |
229 | from PIL import Image
230 |
231 | def aggregate_attention(attention_store: AttentionStore, res: int, from_where: List[str], is_cross: bool, select: int, prompts=None):
232 | out = []
233 | attention_maps = attention_store.get_average_attention()
234 | num_pixels = res ** 2
235 | for location in from_where:
236 | for item in attention_maps[f"{location}_{'cross' if is_cross else 'self'}"]:
237 | if item.shape[1] == num_pixels:
238 | cross_maps = item.reshape(len(prompts), -1, res, res, item.shape[-1])[select]
239 | out.append(cross_maps)
240 | out = torch.cat(out, dim=0)
241 | out = out.sum(0) / out.shape[0]
242 | return out.cpu()
243 |
244 |
245 | def show_cross_attention(attention_store: AttentionStore, res: int, from_where: List[str], select: int = 0, prompts=None, tokenizer=None):
246 | tokens = tokenizer.encode(prompts[select])
247 | decoder = tokenizer.decode
248 | attention_maps = aggregate_attention(attention_store, res, from_where, True, select, prompts)
249 | images = []
250 | for i in range(len(tokens)):
251 | image = attention_maps[:, :, i]
252 | image = 255 * image / image.max()
253 | image = image.unsqueeze(-1).expand(*image.shape, 3)
254 | image = image.numpy().astype(np.uint8)
255 | image = np.array(Image.fromarray(image).resize((256, 256)))
256 | image = ptp_utils.text_under_image(image, decoder(int(tokens[i])))
257 | images.append(image)
258 | return(ptp_utils.view_images(np.stack(images, axis=0)))
259 |
260 |
261 | def show_self_attention_comp(attention_store: AttentionStore, res: int, from_where: List[str],
262 |                              max_com=10, select: int = 0, prompts=None):
263 |     attention_maps = aggregate_attention(attention_store, res, from_where, False, select, prompts).numpy().reshape((res ** 2, res ** 2))
264 | u, s, vh = np.linalg.svd(attention_maps - np.mean(attention_maps, axis=1, keepdims=True))
265 | images = []
266 | for i in range(max_com):
267 | image = vh[i].reshape(res, res)
268 | image = image - image.min()
269 | image = 255 * image / image.max()
270 | image = np.repeat(np.expand_dims(image, axis=2), 3, axis=2).astype(np.uint8)
271 | image = Image.fromarray(image).resize((256, 256))
272 | image = np.array(image)
273 | images.append(image)
274 | ptp_utils.view_images(np.concatenate(images, axis=1))
275 |
276 | def run_and_display(model, prompts, controller, latent=None, run_baseline=False, generator=None):
277 | if run_baseline:
278 | print("w.o. prompt-to-prompt")
279 | images, latent = run_and_display(model, prompts, EmptyControl(), latent=latent, run_baseline=False, generator=generator)
280 | print("with prompt-to-prompt")
281 |     images, x_t = ptp_utils.text2image_ldm_stable(model, prompts, controller, latent=latent, generator=generator, guidance_scale=[7.5] * len(prompts), low_resource=LOW_RESOURCE)
282 |     return images, x_t
283 |
284 | def load_512(image_path, left=0, right=0, top=0, bottom=0, device=None):
285 | if type(image_path) is str:
286 | image = np.array(Image.open(image_path).convert('RGB'))[:, :, :3]
287 | else:
288 | image = image_path
289 | h, w, c = image.shape
290 | left = min(left, w-1)
291 | right = min(right, w - left - 1)
292 |     top = min(top, h - 1)
293 | bottom = min(bottom, h - top - 1)
294 | image = image[top:h-bottom, left:w-right]
295 | h, w, c = image.shape
296 | if h < w:
297 | offset = (w - h) // 2
298 | image = image[:, offset:offset + h]
299 | elif w < h:
300 | offset = (h - w) // 2
301 | image = image[offset:offset + w]
302 | image = np.array(Image.fromarray(image).resize((512, 512)))
303 | image = torch.from_numpy(image).float() / 127.5 - 1
304 | image = image.permute(2, 0, 1).unsqueeze(0).to(device)
305 |
306 | return image
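307 |
308 | # Usage sketch: how the controllers above are typically wired into sampling via
309 | # ptp_utils.text2image_ldm_stable. The pipeline object `ldm_stable` and the
310 | # concrete step/scale values are illustrative assumptions only.
311 | #
312 | #   prompts = ["a photo of a horse in the mud", "a photo of a zebra in the mud"]
313 | #   controller = AttentionReplace(prompts, num_steps=50, cross_replace_steps=0.8,
314 | #                                 self_replace_steps=0.4, model=ldm_stable)
315 | #   latents, _ = ptp_utils.text2image_ldm_stable(ldm_stable, prompts, controller,
316 | #                                                num_inference_steps=50,
317 | #                                                guidance_scale=[7.5] * len(prompts))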
--------------------------------------------------------------------------------
/prompt_to_prompt/ptp_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | This code was originally taken from
3 | https://github.com/google/prompt-to-prompt
4 | """
5 |
6 | # Copyright 2022 Google LLC
7 | #
8 | # Licensed under the Apache License, Version 2.0 (the "License");
9 | # you may not use this file except in compliance with the License.
10 | # You may obtain a copy of the License at
11 | #
12 | # http://www.apache.org/licenses/LICENSE-2.0
13 | #
14 | # Unless required by applicable law or agreed to in writing, software
15 | # distributed under the License is distributed on an "AS IS" BASIS,
16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17 | # See the License for the specific language governing permissions and
18 | # limitations under the License.
19 |
20 |
21 |
22 | import numpy as np
23 | import torch
24 | from PIL import Image, ImageDraw, ImageFont
25 | import cv2
26 | from typing import Optional, Union, Tuple, List, Callable, Dict
27 | # from IPython.display import display
28 | from tqdm import tqdm
29 |
30 |
31 | def text_under_image(image: np.ndarray, text: str, text_color: Tuple[int, int, int] = (0, 0, 0)):
32 | h, w, c = image.shape
33 | offset = int(h * .2)
34 | img = np.ones((h + offset, w, c), dtype=np.uint8) * 255
35 | font = cv2.FONT_HERSHEY_SIMPLEX
36 | # font = ImageFont.truetype("/usr/share/fonts/truetype/noto/NotoMono-Regular.ttf", font_size)
37 | img[:h] = image
38 | textsize = cv2.getTextSize(text, font, 1, 2)[0]
39 | text_x, text_y = (w - textsize[0]) // 2, h + offset - textsize[1] // 2
40 | cv2.putText(img, text, (text_x, text_y ), font, 1, text_color, 2)
41 | return img
42 |
43 |
44 | def view_images(images, num_rows=1, offset_ratio=0.02):
45 | if type(images) is list:
46 | num_empty = len(images) % num_rows
47 | elif images.ndim == 4:
48 | num_empty = images.shape[0] % num_rows
49 | else:
50 | images = [images]
51 | num_empty = 0
52 |
53 | empty_images = np.ones(images[0].shape, dtype=np.uint8) * 255
54 | images = [image.astype(np.uint8) for image in images] + [empty_images] * num_empty
55 | num_items = len(images)
56 |
57 | h, w, c = images[0].shape
58 | offset = int(h * offset_ratio)
59 | num_cols = num_items // num_rows
60 | image_ = np.ones((h * num_rows + offset * (num_rows - 1),
61 | w * num_cols + offset * (num_cols - 1), 3), dtype=np.uint8) * 255
62 | for i in range(num_rows):
63 | for j in range(num_cols):
64 | image_[i * (h + offset): i * (h + offset) + h:, j * (w + offset): j * (w + offset) + w] = images[
65 | i * num_cols + j]
66 |
67 | pil_img = Image.fromarray(image_)
68 | # display(pil_img)
69 | return pil_img
70 |
71 |
72 |
73 | def diffusion_step(model, controller, latents, context, t, guidance_scale, low_resource=False):
74 | if low_resource:
75 | noise_pred_uncond = model.unet(latents, t, encoder_hidden_states=context[0])["sample"]
76 | noise_prediction_text = model.unet(latents, t, encoder_hidden_states=context[1])["sample"]
77 | else:
78 | latents_input = torch.cat([latents] * 2)
79 | noise_pred = model.unet(latents_input, t, encoder_hidden_states=context)["sample"]
80 | noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
81 | cfg_scales_tensor = torch.Tensor(guidance_scale).view(-1,1,1,1).to(model.device)
82 | noise_pred = noise_pred_uncond + cfg_scales_tensor * (noise_prediction_text - noise_pred_uncond)
83 | latents = model.scheduler.step(noise_pred, t, latents)["prev_sample"]
84 | latents = controller.step_callback(latents)
85 | return latents
86 |
87 |
88 | def latent2image(vae, latents):
89 | latents = 1 / 0.18215 * latents
90 | image = vae.decode(latents)['sample']
91 | image = (image / 2 + 0.5).clamp(0, 1)
92 | image = image.cpu().permute(0, 2, 3, 1).numpy()
93 | image = (image * 255).astype(np.uint8)
94 | return image
95 |
96 |
97 | def init_latent(latent, model, height, width, generator, batch_size):
98 | if latent is None:
99 | latent = torch.randn(
100 | (1, model.unet.in_channels, height // 8, width // 8),
101 | generator=generator,
102 | )
103 | latents = latent.expand(batch_size, model.unet.in_channels, height // 8, width // 8).to(model.device)
104 | return latent, latents
105 |
106 |
107 | @torch.no_grad()
108 | def text2image_ldm(
109 | model,
110 | prompt: List[str],
111 | controller,
112 | num_inference_steps: int = 50,
113 | guidance_scale: Optional[float] = 7.,
114 | generator: Optional[torch.Generator] = None,
115 | latent: Optional[torch.FloatTensor] = None,
116 | ):
117 | register_attention_control(model, controller)
118 | height = width = 256
119 | batch_size = len(prompt)
120 |
121 | uncond_input = model.tokenizer([""] * batch_size, padding="max_length", max_length=77, return_tensors="pt")
122 | uncond_embeddings = model.bert(uncond_input.input_ids.to(model.device))[0]
123 |
124 | text_input = model.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt")
125 | text_embeddings = model.bert(text_input.input_ids.to(model.device))[0]
126 | latent, latents = init_latent(latent, model, height, width, generator, batch_size)
127 | context = torch.cat([uncond_embeddings, text_embeddings])
128 |
129 | model.scheduler.set_timesteps(num_inference_steps)
130 | for t in tqdm(model.scheduler.timesteps):
131 | latents = diffusion_step(model, controller, latents, context, t, guidance_scale)
132 |
133 | image = latent2image(model.vqvae, latents)
134 |
135 | return image, latent
136 |
137 |
138 | @torch.no_grad()
139 | def text2image_ldm_stable(
140 | model,
141 | prompt: List[str],
142 | controller,
143 | num_inference_steps: int = 50,
144 | guidance_scale: float = 7.5,
145 | generator: Optional[torch.Generator] = None,
146 | latent: Optional[torch.FloatTensor] = None,
147 | restored_wt = None,
148 | restored_zs = None,
149 | low_resource: bool = False,
150 | ):
151 | register_attention_control(model, controller)
152 | height = width = 512
153 | batch_size = len(prompt)
154 |
155 | text_input = model.tokenizer(
156 | prompt,
157 | padding="max_length",
158 | max_length=model.tokenizer.model_max_length,
159 | truncation=True,
160 | return_tensors="pt",
161 | )
162 | text_embeddings = model.text_encoder(text_input.input_ids.to(model.device))[0]
163 | max_length = text_input.input_ids.shape[-1]
164 | uncond_input = model.tokenizer(
165 | [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
166 | )
167 | uncond_embeddings = model.text_encoder(uncond_input.input_ids.to(model.device))[0]
168 |
169 | context = [uncond_embeddings, text_embeddings]
170 | if not low_resource:
171 | context = torch.cat(context)
172 | latent, latents = init_latent(latent, model, height, width, generator, batch_size)
173 |
174 | # set timesteps
175 | # extra_set_kwargs = {"offset": 1}
176 | model.scheduler.set_timesteps(num_inference_steps)#, **extra_set_kwargs)
177 | for t in tqdm(model.scheduler.timesteps):
178 | latents = diffusion_step(model, controller, latents, context, t, guidance_scale, low_resource)
179 |
180 | # image = latent2image(model.vae, latents)
181 |
182 | return latents, latent
183 |
184 |
185 | def register_attention_control(model, controller):
186 | def ca_forward(self, place_in_unet):
187 | to_out = self.to_out
188 | if type(to_out) is torch.nn.modules.container.ModuleList:
189 | to_out = self.to_out[0]
190 | else:
191 | to_out = self.to_out
192 |
193 | def forward(x, context=None, mask=None):
194 | batch_size, sequence_length, dim = x.shape
195 | h = self.heads
196 | q = self.to_q(x)
197 | is_cross = context is not None
198 | context = context if is_cross else x
199 | k = self.to_k(context)
200 | v = self.to_v(context)
201 | q = self.reshape_heads_to_batch_dim(q)
202 | k = self.reshape_heads_to_batch_dim(k)
203 | v = self.reshape_heads_to_batch_dim(v)
204 |
205 | sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
206 |
207 | if mask is not None:
208 | mask = mask.reshape(batch_size, -1)
209 | max_neg_value = -torch.finfo(sim.dtype).max
210 | mask = mask[:, None, :].repeat(h, 1, 1)
211 | sim.masked_fill_(~mask, max_neg_value)
212 |
213 | # attention, what we cannot get enough of
214 | attn = sim.softmax(dim=-1)
215 | attn = controller(attn, is_cross, place_in_unet)
216 | out = torch.einsum("b i j, b j d -> b i d", attn, v)
217 | out = self.reshape_batch_dim_to_heads(out)
218 | return to_out(out)
219 |
220 | return forward
221 |
222 | class DummyController:
223 |
224 | def __call__(self, *args):
225 | return args[0]
226 |
227 | def __init__(self):
228 | self.num_att_layers = 0
229 |
230 | if controller is None:
231 | controller = DummyController()
232 |
233 | def register_recr(net_, count, place_in_unet):
234 | if net_.__class__.__name__ == 'CrossAttention':
235 | net_.forward = ca_forward(net_, place_in_unet)
236 | return count + 1
237 | elif hasattr(net_, 'children'):
238 | for net__ in net_.children():
239 | count = register_recr(net__, count, place_in_unet)
240 | return count
241 |
242 | cross_att_count = 0
243 | sub_nets = model.unet.named_children()
244 | for net in sub_nets:
245 | if "down" in net[0]:
246 | cross_att_count += register_recr(net[1], 0, "down")
247 | elif "up" in net[0]:
248 | cross_att_count += register_recr(net[1], 0, "up")
249 | elif "mid" in net[0]:
250 | cross_att_count += register_recr(net[1], 0, "mid")
251 |
252 | controller.num_att_layers = cross_att_count
253 |
254 |
255 | def get_word_inds(text: str, word_place: int, tokenizer):
256 | split_text = text.split(" ")
257 | if type(word_place) is str:
258 | word_place = [i for i, word in enumerate(split_text) if word_place == word]
259 | elif type(word_place) is int:
260 | word_place = [word_place]
261 | out = []
262 | if len(word_place) > 0:
263 | words_encode = [tokenizer.decode([item]).strip("#") for item in tokenizer.encode(text)][1:-1]
264 | cur_len, ptr = 0, 0
265 |
266 | for i in range(len(words_encode)):
267 | cur_len += len(words_encode[i])
268 | if ptr in word_place:
269 | out.append(i + 1)
270 | if cur_len >= len(split_text[ptr]):
271 | ptr += 1
272 | cur_len = 0
273 | return np.array(out)
274 |
275 |
276 | def update_alpha_time_word(alpha, bounds: Union[float, Tuple[float, float]], prompt_ind: int,
277 | word_inds: Optional[torch.Tensor]=None):
278 | if type(bounds) is float:
279 | bounds = 0, bounds
280 | start, end = int(bounds[0] * alpha.shape[0]), int(bounds[1] * alpha.shape[0])
281 | if word_inds is None:
282 | word_inds = torch.arange(alpha.shape[2])
283 | alpha[: start, prompt_ind, word_inds] = 0
284 | alpha[start: end, prompt_ind, word_inds] = 1
285 | alpha[end:, prompt_ind, word_inds] = 0
286 | return alpha
287 |
288 |
289 | def get_time_words_attention_alpha(prompts, num_steps,
290 | cross_replace_steps: Union[float, Dict[str, Tuple[float, float]]],
291 | tokenizer, max_num_words=77):
292 | if type(cross_replace_steps) is not dict:
293 | cross_replace_steps = {"default_": cross_replace_steps}
294 | if "default_" not in cross_replace_steps:
295 | cross_replace_steps["default_"] = (0., 1.)
296 | alpha_time_words = torch.zeros(num_steps + 1, len(prompts) - 1, max_num_words)
297 | for i in range(len(prompts) - 1):
298 | alpha_time_words = update_alpha_time_word(alpha_time_words, cross_replace_steps["default_"],
299 | i)
300 | for key, item in cross_replace_steps.items():
301 | if key != "default_":
302 | inds = [get_word_inds(prompts[i], key, tokenizer) for i in range(1, len(prompts))]
303 | for i, ind in enumerate(inds):
304 | if len(ind) > 0:
305 | alpha_time_words = update_alpha_time_word(alpha_time_words, item, i, ind)
306 | alpha_time_words = alpha_time_words.reshape(num_steps + 1, len(prompts) - 1, 1, 1, max_num_words)
307 | return alpha_time_words
308 |
--------------------------------------------------------------------------------
/prompt_to_prompt/requirements.txt:
--------------------------------------------------------------------------------
1 | diffusers==0.8.0
2 | transformers
3 | ftfy
4 | opencv-python
5 | ipywidgets
--------------------------------------------------------------------------------
/prompt_to_prompt/seq_aligner.py:
--------------------------------------------------------------------------------
1 | """
2 | This code was originally taken from
3 | https://github.com/google/prompt-to-prompt
4 | """
5 |
6 | # Copyright 2022 Google LLC
7 | #
8 | # Licensed under the Apache License, Version 2.0 (the "License");
9 | # you may not use this file except in compliance with the License.
10 | # You may obtain a copy of the License at
11 | #
12 | # http://www.apache.org/licenses/LICENSE-2.0
13 | #
14 | # Unless required by applicable law or agreed to in writing, software
15 | # distributed under the License is distributed on an "AS IS" BASIS,
16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17 | # See the License for the specific language governing permissions and
18 | # limitations under the License.
19 | import torch
20 | import numpy as np
21 |
22 |
23 | class ScoreParams:
24 |
25 | def __init__(self, gap, match, mismatch):
26 | self.gap = gap
27 | self.match = match
28 | self.mismatch = mismatch
29 |
30 | def mis_match_char(self, x, y):
31 | if x != y:
32 | return self.mismatch
33 | else:
34 | return self.match
35 |
36 |
51 | def get_matrix(size_x, size_y, gap):
52 | matrix = np.zeros((size_x + 1, size_y + 1), dtype=np.int32)
53 | matrix[0, 1:] = (np.arange(size_y) + 1) * gap
54 | matrix[1:, 0] = (np.arange(size_x) + 1) * gap
55 | return matrix
56 |
57 |
58 | def get_traceback_matrix(size_x, size_y):
59 | matrix = np.zeros((size_x + 1, size_y +1), dtype=np.int32)
60 | matrix[0, 1:] = 1
61 | matrix[1:, 0] = 2
62 | matrix[0, 0] = 4
63 | return matrix
64 |
65 |
66 | def global_align(x, y, score):
67 | matrix = get_matrix(len(x), len(y), score.gap)
68 | trace_back = get_traceback_matrix(len(x), len(y))
69 | for i in range(1, len(x) + 1):
70 | for j in range(1, len(y) + 1):
71 | left = matrix[i, j - 1] + score.gap
72 | up = matrix[i - 1, j] + score.gap
73 | diag = matrix[i - 1, j - 1] + score.mis_match_char(x[i - 1], y[j - 1])
74 | matrix[i, j] = max(left, up, diag)
75 | if matrix[i, j] == left:
76 | trace_back[i, j] = 1
77 | elif matrix[i, j] == up:
78 | trace_back[i, j] = 2
79 | else:
80 | trace_back[i, j] = 3
81 | return matrix, trace_back
82 |
83 |
84 | def get_aligned_sequences(x, y, trace_back):
85 | x_seq = []
86 | y_seq = []
87 | i = len(x)
88 | j = len(y)
89 | mapper_y_to_x = []
90 | while i > 0 or j > 0:
91 | if trace_back[i, j] == 3:
92 | x_seq.append(x[i-1])
93 | y_seq.append(y[j-1])
94 | i = i-1
95 | j = j-1
96 | mapper_y_to_x.append((j, i))
97 | elif trace_back[i][j] == 1:
98 | x_seq.append('-')
99 | y_seq.append(y[j-1])
100 | j = j-1
101 | mapper_y_to_x.append((j, -1))
102 | elif trace_back[i][j] == 2:
103 | x_seq.append(x[i-1])
104 | y_seq.append('-')
105 | i = i-1
106 | elif trace_back[i][j] == 4:
107 | break
108 | mapper_y_to_x.reverse()
109 | return x_seq, y_seq, torch.tensor(mapper_y_to_x, dtype=torch.int64)
110 |
111 |
112 | def get_mapper(x: str, y: str, tokenizer, max_len=77):
113 | x_seq = tokenizer.encode(x)
114 | y_seq = tokenizer.encode(y)
115 | score = ScoreParams(0, 1, -1)
116 | matrix, trace_back = global_align(x_seq, y_seq, score)
117 | mapper_base = get_aligned_sequences(x_seq, y_seq, trace_back)[-1]
118 | alphas = torch.ones(max_len)
119 | alphas[: mapper_base.shape[0]] = mapper_base[:, 1].ne(-1).float()
120 | mapper = torch.zeros(max_len, dtype=torch.int64)
121 | mapper[:mapper_base.shape[0]] = mapper_base[:, 1]
122 | mapper[mapper_base.shape[0]:] = len(y_seq) + torch.arange(max_len - len(y_seq))
123 | return mapper, alphas
124 |
125 |
126 | def get_refinement_mapper(prompts, tokenizer, max_len=77):
127 | x_seq = prompts[0]
128 | mappers, alphas = [], []
129 | for i in range(1, len(prompts)):
130 | mapper, alpha = get_mapper(x_seq, prompts[i], tokenizer, max_len)
131 | mappers.append(mapper)
132 | alphas.append(alpha)
133 | return torch.stack(mappers), torch.stack(alphas)
134 |
135 |
136 | def get_word_inds(text: str, word_place: int, tokenizer):
137 | split_text = text.split(" ")
138 | if type(word_place) is str:
139 | word_place = [i for i, word in enumerate(split_text) if word_place == word]
140 | elif type(word_place) is int:
141 | word_place = [word_place]
142 | out = []
143 | if len(word_place) > 0:
144 | words_encode = [tokenizer.decode([item]).strip("#") for item in tokenizer.encode(text)][1:-1]
145 | cur_len, ptr = 0, 0
146 |
147 | for i in range(len(words_encode)):
148 | cur_len += len(words_encode[i])
149 | if ptr in word_place:
150 | out.append(i + 1)
151 | if cur_len >= len(split_text[ptr]):
152 | ptr += 1
153 | cur_len = 0
154 | return np.array(out)
155 |
156 |
157 | def get_replacement_mapper_(x: str, y: str, tokenizer, max_len=77):
158 | words_x = x.split(' ')
159 | words_y = y.split(' ')
160 | if len(words_x) != len(words_y):
161 | raise ValueError(f"attention replacement edit can only be applied on prompts with the same length"
162 | f" but prompt A has {len(words_x)} words and prompt B has {len(words_y)} words.")
163 | inds_replace = [i for i in range(len(words_y)) if words_y[i] != words_x[i]]
164 | inds_source = [get_word_inds(x, i, tokenizer) for i in inds_replace]
165 | inds_target = [get_word_inds(y, i, tokenizer) for i in inds_replace]
166 | mapper = np.zeros((max_len, max_len))
167 | i = j = 0
168 | cur_inds = 0
169 | while i < max_len and j < max_len:
170 | if cur_inds < len(inds_source) and inds_source[cur_inds][0] == i:
171 | inds_source_, inds_target_ = inds_source[cur_inds], inds_target[cur_inds]
172 | if len(inds_source_) == len(inds_target_):
173 | mapper[inds_source_, inds_target_] = 1
174 | else:
175 | ratio = 1 / len(inds_target_)
176 | for i_t in inds_target_:
177 | mapper[inds_source_, i_t] = ratio
178 | cur_inds += 1
179 | i += len(inds_source_)
180 | j += len(inds_target_)
181 | elif cur_inds < len(inds_source):
182 | mapper[i, j] = 1
183 | i += 1
184 | j += 1
185 | else:
186 | mapper[j, j] = 1
187 | i += 1
188 | j += 1
189 |
190 | return torch.from_numpy(mapper).float()
191 |
192 |
193 |
194 | def get_replacement_mapper(prompts, tokenizer, max_len=77):
195 | x_seq = prompts[0]
196 | mappers = []
197 | for i in range(1, len(prompts)):
198 | mapper = get_replacement_mapper_(x_seq, prompts[i], tokenizer, max_len)
199 | mappers.append(mapper)
200 | return torch.stack(mappers)
201 |
202 |
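203 | # Alignment sketch: get_refinement_mapper aligns each target prompt to the source
204 | # prompt token-by-token; `tokenizer` below is assumed to be the CLIP tokenizer of
205 | # a Stable Diffusion pipeline (illustrative only).
206 | #
207 | #   mappers, alphas = get_refinement_mapper(
208 | #       ["a photo of a cat", "a photo of a smiling cat"], tokenizer)
209 | #   # mappers[0][j] is the source-token position aligned with target position j;
210 | #   # alphas[0][j] is 0 where the target token has no source counterpart
211 | #   # (e.g. the added word "smiling") and 1 elsewhere.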
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate==0.18.0
2 | certifi==2022.12.7
3 | charset-normalizer==3.1.0
4 | cmake==3.26.3
5 | contourpy==1.0.7
6 | cycler==0.11.0
7 | diffusers==0.8.0
8 | filelock==3.12.0
9 | fonttools==4.39.3
10 | huggingface-hub==0.13.4
11 | idna==3.4
12 | importlib-metadata==6.5.0
13 | importlib-resources==5.12.0
14 | Jinja2==3.1.2
15 | kiwisolver==1.4.4
16 | lit==16.0.1
17 | MarkupSafe==2.1.2
18 | matplotlib==3.7.1
19 | mpmath==1.3.0
20 | networkx==3.1
21 | numpy==1.24.2
22 | nvidia-cublas-cu11==11.10.3.66
23 | nvidia-cuda-cupti-cu11==11.7.101
24 | nvidia-cuda-nvrtc-cu11==11.7.99
25 | nvidia-cuda-runtime-cu11==11.7.99
26 | nvidia-cudnn-cu11==8.5.0.96
27 | nvidia-cufft-cu11==10.9.0.58
28 | nvidia-curand-cu11==10.2.10.91
29 | nvidia-cusolver-cu11==11.4.0.1
30 | nvidia-cusparse-cu11==11.7.4.91
31 | nvidia-nccl-cu11==2.14.3
32 | nvidia-nvtx-cu11==11.7.91
33 | opencv-python==4.7.0.72
34 | packaging==23.1
35 | Pillow==9.5.0
36 | pkg_resources==0.0.0
37 | psutil==5.9.5
38 | pyparsing==3.0.9
39 | python-dateutil==2.8.2
40 | PyYAML==6.0
41 | regex==2023.3.23
42 | requests==2.28.2
43 | six==1.16.0
44 | sympy==1.11.1
45 | tokenizers==0.13.3
46 | torch==2.0.0
47 | torchaudio==2.0.1
48 | torchvision==0.15.1
49 | tqdm==4.65.0
50 | transformers==4.28.1
51 | triton==2.0.0
52 | typing_extensions==4.5.0
53 | urllib3==1.26.15
54 | zipp==3.15.0
55 |
--------------------------------------------------------------------------------
/test.yaml:
--------------------------------------------------------------------------------
1 | -
2 | init_img: /example_images/horse_mud.jpg
3 | source_prompt: a photo of a horse in the mud
4 |
5 | target_prompts:
6 | - a photo of a horse in the snow
7 | - a photo of a zebra in the snow
8 | - a photo of a zebra in the mud
9 |
10 | -
11 | init_img: /example_images/sketch_cat.jpg
12 | source_prompt: a sketch of a cat
13 |
14 | target_prompts:
15 | - a sculpture of a cat
16 |
--------------------------------------------------------------------------------