├── README.md ├── app.py ├── canvas.png ├── config.py ├── figures └── teaser.jpg ├── pipeline_layout.py ├── pipeline_stable_diffusion.py ├── requirements.txt ├── run_learning.py ├── run_sample.py ├── samples ├── .DS_Store └── image1 │ ├── .DS_Store │ ├── image.png │ ├── mask.png │ ├── semantic_dict.json │ └── target_layout │ ├── layout1.png │ ├── layout2.png │ └── layout3.png ├── single_image_learning.py └── utils ├── __init__.py ├── __pycache__ ├── __init__.cpython-310.pyc ├── custom_utils.cpython-310.pyc ├── ddim_inversion.cpython-310.pyc └── ptp_utils.cpython-310.pyc ├── custom_utils.py ├── ddim_inversion.py ├── ptp_utils.py └── retrieve.py /README.md: -------------------------------------------------------------------------------- 1 | # Continuous Layout Editing of Single Images with Diffusion Models 2 | 3 | ![alt text](figures/teaser.jpg) 4 | 5 | Zhiyuan Zhang $^{1*}$, Zhitong Huang $^{1*}$, [Jing Liao](https://liaojing.github.io/html/) $^{1\dagger}$ 6 | 7 | $^1$: City University of Hong Kong, Hong Kong SAR 8 | $^*$: Both authors contributed equally to this research    $^\dagger$: Corresponding author 9 | 10 | ## Abstract: 11 | Recent advancements in large-scale text-to-image diffusion models have enabled many applications in image editing. However, none of these methods have been able to edit the layout of single existing images. To address this gap, we propose the first framework for layout editing of a single image while preserving its visual properties, thus allowing for continuous editing on a single image. Our approach is achieved through two key modules. First, to preserve the characteristics of multiple objects within an image, we disentangle the concepts of different objects and embed them into separate textual tokens using a novel method called masked textual inversion. 
Next, we propose a training-free optimization method to perform layout control for a pre-trained diffusion model, which allows us to regenerate images with learned concepts and align them with user-specified layouts. As the first framework to edit the layout of existing images, we demonstrate that our method is effective and outperforms other baselines that were modified to support this task. Our code will be freely available for public use upon acceptance. 12 | 13 | ## Installation 14 | 15 | ```pip install -r requirements.txt ``` 16 | 17 | We tested on a single V100 GPU. To train on a GPU with 20 GB of memory, set iv_train_batch_size: 2, iv_gradient_accumulation_steps: 2, ft_train_batch_size: 2 and ft_gradient_accumulation_steps: 1. 18 | 19 | ## Usage 20 | 21 | 1. Prepare the input image, mask image, target layout images and semantic dict (for example, in the samples/image1 folder, semantic_dict.json specifies the color mapping: cat -> yellow(75,254,1), pot -> green(251,1,1) and plant -> red(251,1,1)). 22 | 23 | 2. Retrieve images for regularization 24 | 25 | ```bash 26 | python utils/retrieve.py --target_name "cat+pot+plant" --outpath real_reg --num_class_images 200 27 | ``` 28 | 29 | 3. Single Image Learning 30 | 31 | ```bash 32 | accelerate launch run_learning.py --train_prompt "A high quality picture of cat, pot and plant" \ 33 | --scale_lr --with_prior_preservation --image_path samples/image1/image.png \ 34 | --mask_image samples/image1/mask.png \ 35 | --semantic_dict samples/image1/semantic_dict.json \ 36 | --iv_initializer_tokens "cat+pot+plant" \ 37 | --addition_tokens "sks+uy" \ 38 | --output_dir samples/image1/outputs/embeds \ 39 | --iv_max_train_steps 200 \ 40 | --ft_max_train_steps 800 \ 41 | --iv_lr 5e-4 \ 42 | --ft_lr 1e-5 \ 43 | --iv_train_batch_size 4 \ 44 | --iv_gradient_accumulation_steps 1 \ 45 | --ft_train_batch_size 2 \ 46 | --ft_gradient_accumulation_steps 1 \ 47 | --report_to="wandb" 48 | ``` 49 | 50 | 4. 
Continuous Layout Editing 51 | 52 | ```bash 53 | python run_sample.py --prompt "A high quality picture of , and " \ 54 | --image_path samples/image1/image.png \ 55 | --mask_image samples/image1/mask.png \ 56 | --target_layout samples/image1/target_layout/layout1.png \ 57 | --semantic_dict samples/image1/semantic_dict.json \ 58 | --delta_ckpt samples/image1/outputs/embeds/fine_tune/delta.bin \ 59 | --output_dir samples/image1/outputs/images \ 60 | --blend_steps 15 61 | 62 | python run_sample.py --prompt "A high quality picture of , and " \ 63 | --image_path samples/image1/image.png \ 64 | --mask_image samples/image1/mask.png \ 65 | --target_layout samples/image1/target_layout/layout2.png \ 66 | --semantic_dict samples/image1/semantic_dict.json \ 67 | --delta_ckpt samples/image1/outputs/embeds/fine_tune/delta.bin \ 68 | --output_dir samples/image1/outputs/images2 \ 69 | --blend_steps 15 70 | 71 | python run_sample.py --prompt "A high quality picture of , and " \ 72 | --image_path samples/image1/image.png \ 73 | --mask_image samples/image1/mask.png \ 74 | --target_layout samples/image1/target_layout/layout3.png \ 75 | --semantic_dict samples/image1/semantic_dict.json \ 76 | --delta_ckpt samples/image1/outputs/embeds/fine_tune/delta.bin \ 77 | --output_dir samples/image1/outputs/images3 \ 78 | --blend_steps 15 79 | ``` 80 | 81 | ## UI 82 | 83 | ```bash 84 | gradio app.py 85 | ``` 86 | 87 | 1. Put the regularization dataset under the real_reg folder first. 88 | 2. Upload the image, draw the target layout, fill in all the text fields, and then press GetColor. 89 | 3. Assign each object to a color and then press Submit. The first run takes longer because the learning code has to run. 90 | 4. 
Press Clear and try more layouts (or press GetColor and Submit to get a new image using the current layout). 91 | 92 | ### checkpoints 93 | 94 | cat_pot_plant: https://drive.google.com/file/d/1PA49yIjM_7fh97iPzYTABYpVw8OUJWFU/view?usp=sharing 95 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | import gradio as gr 2 | from sklearn.cluster import KMeans 3 | import numpy as np 4 | from PIL import Image 5 | from transformers import AutoProcessor, CLIPSegForImageSegmentation 6 | from utils.ptp_utils import AttentionStore 7 | from utils.ddim_inversion import create_inversion_latents 8 | from single_image_learning import run as run1 9 | from run_sample import run as run2 10 | from run_sample import load_model 11 | import torch 12 | import json 13 | import os 14 | import random 15 | 16 | # -------------------------------------Helper Functions----------------------------------------- 17 | def seg_image(img, color2token, bg_color): 18 | global seg_processor 19 | global seg_model 20 | colors = [] 21 | texts = [] 22 | for color, token in color2token: 23 | colors.append(np.array(list(color))) 24 | texts.append(token) 25 | 26 | texts.append("background") 27 | inputs = seg_processor(text=texts, images=[img] * len(texts), padding=True, return_tensors="pt") 28 | outputs = seg_model(**inputs) 29 | logits = np.argmax(outputs.logits.detach().numpy(), axis=0) 30 | img = np.zeros((352, 352, 3)) 31 | img[:,:] = bg_color 32 | for i, token in enumerate(texts[:-1]): 33 | pred_i = logits == i 34 | img[pred_i] = colors[i] 35 | image = Image.fromarray(img.astype(np.uint8)).resize((64,64), resample = 0) 36 | image.save('UI/ini_mask.png') 37 | return image 38 | 39 | def get_masks(color2token, mask_image, size = (16,16)): 40 | masks = [] 41 | pil_image = Image.open(mask_image).resize(size,resample=0) 42 | img = np.array(pil_image) 43 | for k, _ in color2token: 44 | 
color = np.array(k) 45 | mask = np.all(img == color, axis=-1) 46 | assert mask.sum()!=0 47 | masks.append(mask) 48 | return masks 49 | 50 | def get_blend_mask(tgt_bg): 51 | m1 = np.array(Image.open("UI/ini_mask.png")) 52 | m2 = np.array(Image.open(f"UI/target_layout/tgt_mask{layout_count}.png")) 53 | b = np.array(tgt_bg) 54 | blend_mask = np.all(m1 == b, axis=-1) & np.all(m2 == b, axis=-1) 55 | blend_mask = torch.from_numpy(blend_mask).cuda() 56 | return blend_mask 57 | 58 | def prompt_to_idx(prompt): 59 | words = list(prompt.strip(" ").replace(", ", ",").replace(",", " , ").split(" ")) 60 | # "cat,dog" -> ["cat", ",", "dog"] 61 | # "cat, dog" -> ["cat", ",", "dog"] 62 | if len(words) > 1: 63 | return str(dict(zip(range(1, 1+len(words)), words))) 64 | return "" 65 | 66 | def RGB_to_Hex(rgb): 67 | if rgb[0]>=250 and rgb[1]>=250 and rgb[2]>=250: 68 | return "#FFFFFF" 69 | color = "#" 70 | for i in rgb: 71 | num = int(i) 72 | color += str(hex(num))[-2:].replace("x", "0").upper() 73 | return color 74 | 75 | def get_hex_colors(mask): 76 | global colors 77 | colors = [] 78 | hex_colors = [] 79 | image_colors = mask.getcolors(maxcolors=64*64) 80 | for i, (_, color) in enumerate(image_colors): 81 | hex_color = RGB_to_Hex(color) 82 | if hex_color != "#FFFFFF": 83 | colors.append(color) 84 | hex_colors.append(hex_color) 85 | else: 86 | white = color 87 | colors.append(white) 88 | return hex_colors 89 | 90 | def clean_image(image, num_color): 91 | image = image.resize((64,64), resample=0) 92 | X = np.array(image).reshape(-1,3) 93 | kmeans = KMeans(n_clusters=num_color, random_state=0, n_init='auto').fit(X) 94 | for i in range(num_color): 95 | t = kmeans.labels_ == i 96 | mean_color = X[t].mean(axis = 0).astype(int) 97 | X[t] = mean_color 98 | X = X.reshape(64,64,3) 99 | return Image.fromarray(X) 100 | 101 | def generate(prompt, pipe, blend_mask, masks, indices_to_alter): 102 | inversion_latents = create_inversion_latents(pipe, "UI/initial.png", prompt, \ 103 | 
guidance_scale=5, ddim_steps=50) 104 | blend_dict = { 105 | "blend_mask":blend_mask, 106 | "inversion_latents":inversion_latents, 107 | "blend_steps":15 108 | } 109 | seed = random.randint(0, 9999) 110 | g = torch.Generator('cuda').manual_seed(seed) 111 | controller = AttentionStore() 112 | image = run2(pipe, 113 | prompt=prompt, 114 | guidance_scale = 5, 115 | n_inference_steps = 50, 116 | eta = 0, 117 | controller=controller, 118 | indices_to_alter= indices_to_alter, 119 | generator=g, 120 | run_standard_sd=False, 121 | scale_factor = 20, 122 | thresholds = {0:0.6, 10: 0.7, 20: 0.8}, 123 | max_iter_to_alter=25, 124 | max_refinement_steps=40, 125 | scale_range = (1., 0.5), 126 | masks = masks, 127 | blend_dict = blend_dict, 128 | ) 129 | image.save(f"UI/images/{layout_count}.png") 130 | return image 131 | 132 | def train_model(color2token, train_prompt, out_dir): 133 | from config import RunConfig 134 | 135 | args = RunConfig() 136 | args.image_path = "UI/initial.png" 137 | args.output_dir = out_dir 138 | os.makedirs(out_dir, exist_ok = True) 139 | 140 | args.iv_initializer_tokens = [token for _, token in color2token] 141 | args.iv_modifier_tokens = [f"<{token}>" for token in args.iv_initializer_tokens] 142 | masks = get_masks(color2token, "UI/ini_mask.png",(64,64)) 143 | args.iv_mask = np.stack(masks)[:,None,:,:] 144 | 145 | args.ft_initializer_tokens = ["sks","uy"] 146 | args.ft_modifier_tokens = [f"" for i in range(len(args.ft_initializer_tokens))] 147 | tail_str = " ".join(args.ft_modifier_tokens) 148 | 149 | args.reg_dirs = [] 150 | for i in args.iv_initializer_tokens: 151 | path = f"real_reg/samples_{i}" 152 | if os.path.exists(path): 153 | args.reg_dirs.append(path) 154 | 155 | train_prompts = [] 156 | all_replaced_prompt = train_prompt 157 | for initializer_token, modifier_token in zip(args.iv_initializer_tokens, args.iv_modifier_tokens): 158 | all_replaced_prompt = all_replaced_prompt.replace(initializer_token, modifier_token) 159 | 
train_prompts.append(train_prompt.replace(initializer_token, modifier_token)) 160 | train_prompts.append(all_replaced_prompt + " " + tail_str) 161 | args.train_prompt = train_prompts 162 | print(args.train_prompt) 163 | 164 | run1(args) 165 | 166 | def replace_color(path, old_mapping, new_mapping): 167 | image = np.array(Image.open(path)) 168 | for k in old_mapping: 169 | v = old_mapping[k] 170 | m = np.all(image == v, axis=-1) 171 | image[m] = np.array(new_mapping[k]) 172 | image = Image.fromarray(image) 173 | image.save(path) 174 | return image 175 | # -------------------------------------------------------------------------------------------- 176 | 177 | def generateImage(ini_image, text, obj1, obj2, obj3, obj4): 178 | global colors 179 | global pipe 180 | global global_idxs_to_alter 181 | global new_photo 182 | global seg_mask 183 | global layout_count 184 | 185 | bg_color = colors[-1] 186 | # mapping from token to color 187 | mapping = dict(zip([obj1, obj2, obj3, obj4], colors)) 188 | mapping["background"] = bg_color 189 | 190 | color2token = [] 191 | # reorder the mapping in the order of the tokens appear in the sentence 192 | for choice in choices: 193 | color2token.append((mapping[choice], choice)) 194 | 195 | if new_photo: 196 | layout_count = 0 197 | torch.cuda.empty_cache() 198 | new_photo = False 199 | out_dir = "UI/embeds/" 200 | # use clip segmentation for getting segmentatin mask 201 | seg_mask = seg_image(ini_image, color2token, bg_color) 202 | # run learning code 203 | train_model(color2token, text, out_dir) 204 | # load trained model 205 | delta_ckpt = out_dir + "fine_tune/delta.bin" 206 | pipe, _ = load_model(delta_ckpt) 207 | print("model loaded") 208 | else: 209 | # replace the colors of mask with the new color 210 | with open("UI/semantic_dict.json", 'r') as f: 211 | old_mapping = json.load(f) 212 | for i in range(layout_count): 213 | replace_color(f"UI/target_layout/tgt_mask{i}.png", old_mapping, mapping) 214 | seg_mask = 
replace_color("UI/ini_mask.png", old_mapping, mapping) 215 | 216 | # save color2token to semantic_dict.json 217 | with open("UI/semantic_dict.json", 'w') as f: 218 | json.dump(mapping, f) 219 | 220 | tokens = [token for _, token in color2token] 221 | for token in tokens: 222 | text = text.replace(token, f"<{token}>") 223 | blend_mask = get_blend_mask(bg_color) 224 | masks = [torch.Tensor(mask).cuda() for mask in get_masks(color2token, f"UI/target_layout/tgt_mask{layout_count}.png")] 225 | 226 | indices_to_alter = [int(idx) for idx in global_idxs_to_alter] 227 | 228 | image = generate(text, pipe, blend_mask, masks, indices_to_alter) 229 | 230 | outputs = [gr.update(value=image), gr.update(value=seg_mask),gr.update(value="UI/initial.png"), gr.update(value=f"UI/target_layout/tgt_mask{layout_count}.png")] 231 | labels = [gr.update(label=f"color1({obj1})"), gr.update(label=f"color2({obj2})"),gr.update(label=f"color3({obj3})"), gr.update(label=obj4)] 232 | layout_count += 1 233 | return outputs + labels 234 | 235 | 236 | def upload_file(files): 237 | global new_photo 238 | file_paths = [file.name for file in files] 239 | file_path = file_paths[0] 240 | image = Image.open(file_path).resize((512,512)) 241 | image.save("UI/initial.png") 242 | new_photo = True 243 | return gr.update(visible=True, value=image) 244 | 245 | def fetch_colors(image, idxs_to_alter, prompt): 246 | global global_idxs_to_alter 247 | global choices 248 | global_idxs_to_alter = idxs_to_alter.split(',') 249 | tokens = prompt.strip(" ").replace(", ", ",").replace(",", " , ").split(" ") 250 | mask = clean_image(image, len(global_idxs_to_alter)+1) 251 | mask.save(f"UI/target_layout/tgt_mask{layout_count}.png") 252 | hex_colors = get_hex_colors(mask) 253 | 254 | choices = [tokens[int(idx)-1] for idx in global_idxs_to_alter] 255 | visibility = [] 256 | values = [] 257 | for i in range(4): 258 | if i < len(hex_colors): 259 | visibility.append(True) 260 | values.append(hex_colors[i]) 261 | else: 262 | 
visibility.append(False) 263 | values.append(None) 264 | 265 | # update the Dropdown 266 | ret = [gr.Dropdown.update(choices=choices,visible=v1) for v1 in visibility] 267 | # update the color selector 268 | ret += [gr.update(visible=v1, value=v2) for v1,v2 in zip(visibility,values)] 269 | return ret 270 | 271 | def clear_image(): 272 | res = [gr.update(value = None) for _ in range(4)] + [gr.update(visible=False) for _ in range(8)] 273 | return res 274 | 275 | colors = [] 276 | global_idxs_to_alter = [] 277 | choices = [] 278 | new_photo = False 279 | seg_mask = None 280 | layout_count = 0 281 | 282 | seg_processor = AutoProcessor.from_pretrained("CIDAS/clipseg-rd64-refined") 283 | seg_model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined") 284 | print("segmentation model loaded") 285 | os.makedirs("UI/target_layout/", exist_ok=True) 286 | os.makedirs("UI/images/", exist_ok=True) 287 | 288 | with gr.Blocks() as demo: 289 | with gr.Row(): 290 | with gr.Column(): 291 | with gr.Box(): 292 | tgt_layout = gr.Image("canvas.png", source = "canvas", tool="color-sketch", type="pil", shape=(256, 256)) 293 | upload_button = gr.UploadButton("Click to Upload a Image", file_types=["image"], file_count="multiple") 294 | prompt = gr.Textbox(label="Prompt") 295 | idxs_to_alter = gr.Textbox(label="Index of the objects") 296 | with gr.Row(): 297 | token1 = gr.Dropdown(visible=False,label="color1_object",interactive = True) 298 | token2 = gr.Dropdown(visible=False,label="color2_object",interactive = True) 299 | token3 = gr.Dropdown(visible=False,label="color3_object",interactive = True) 300 | token4 = gr.Dropdown(visible=False,label="color4_object",interactive = True) 301 | with gr.Row(): 302 | fetch = gr.Button(value="Getcolor") 303 | submit = gr.Button(value="Submit") 304 | clear = gr.Button(value="Clear") 305 | 306 | with gr.Column(): 307 | with gr.Box(): 308 | tokenIndex = gr.Textbox(label="Token Index") 309 | with gr.Row(): 310 | input_image = 
gr.Image(label="Input Image", visible=True,interactive = False, shape=(256, 256)) 311 | input_mask = gr.Image(label="Input Mask", visible=True,interactive = False, shape=(256, 256)) 312 | with gr.Row(): 313 | output_image = gr.Image(label="Output Image",visible=True,interactive = False, shape=(256, 256)) 314 | output_mask = gr.Image(label="Output Mask",visible=True,interactive = False, shape=(256, 256)) 315 | with gr.Row(): 316 | c1 = gr.ColorPicker(visible=False,label="color1") 317 | c2 = gr.ColorPicker(visible=False,label="color2") 318 | c3 = gr.ColorPicker(visible=False,label="color3") 319 | c4 = gr.ColorPicker(visible=False,label="color4") 320 | 321 | submit_inputs = [ 322 | input_image, 323 | prompt, 324 | token1, 325 | token2, 326 | token3, 327 | token4, 328 | ] 329 | 330 | fetch_outputs = [ 331 | token1, 332 | token2, 333 | token3, 334 | token4, 335 | c1, 336 | c2, 337 | c3, 338 | c4 339 | ] 340 | prompt.change(prompt_to_idx, prompt, tokenIndex) 341 | fetch.click(fetch_colors, 342 | inputs=[tgt_layout,idxs_to_alter,prompt], 343 | outputs=fetch_outputs) 344 | upload_button.upload(upload_file, upload_button, input_image) 345 | submit.click(generateImage, inputs=submit_inputs, outputs=[output_image,input_mask, input_image, output_mask,c1,c2,c3,c4]) 346 | clear.click(clear_image, outputs=[tgt_layout, output_image, input_mask, output_mask] + fetch_outputs) 347 | 348 | if __name__ == "__main__": 349 | demo.launch() 350 | -------------------------------------------------------------------------------- /canvas.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bestzzhang/continuous-layout-editing-code/4bc4dfc5644afaec0498e5ec47b1bb74cf10b05b/canvas.png -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | @dataclass 4 | class 
RunConfig: 5 | save_steps = 500 6 | prior_loss_weight = 1.0 7 | with_prior_preservation = True 8 | freeze_model = "crossattn_kv" 9 | pretrained_model_name_or_path = "CompVis/stable-diffusion-v1-4" 10 | revision = None 11 | tokenizer_name = None 12 | iv_modifier_tokens = None 13 | iv_initializer_tokens = None 14 | output_dir = None 15 | seed = 42 16 | resolution = 512 17 | iv_train_batch_size = 4 18 | ft_train_batch_size = 1 19 | iv_max_train_steps = 200 20 | ft_max_train_steps = 800 21 | iv_gradient_accumulation_steps = 1 22 | ft_gradient_accumulation_steps = 2 23 | gradient_checkpointing = False 24 | iv_lr = 0.0005 25 | ft_lr = 1e-05 26 | scale_lr = True 27 | lr_scheduler = "constant" 28 | lr_warmup_steps = 0 29 | dataloader_num_workers = 0 30 | adam_beta1 = 0.9 31 | adam_beta2 = 0.999 32 | adam_weight_decay = 0.01 33 | adam_epsilon = 1e-08 34 | push_to_hub = False 35 | hub_token = None 36 | hub_model_id = None 37 | logging_dir = "logs" 38 | mixed_precision = "no" 39 | allow_tf32 = False 40 | report_to = "tensorboard" 41 | train_prompt = None 42 | num_validation_images = 1 43 | local_rank = -1 44 | checkpointing_steps = 500 45 | checkpoints_total_limit = None 46 | enable_xformers_memory_efficient_attention = False 47 | image_path = None 48 | mask_image = None 49 | semantic_dict = None 50 | addition_tokens = None 51 | lambda_factor = 0 52 | -------------------------------------------------------------------------------- /figures/teaser.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bestzzhang/continuous-layout-editing-code/4bc4dfc5644afaec0498e5ec47b1bb74cf10b05b/figures/teaser.jpg -------------------------------------------------------------------------------- /pipeline_layout.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Union, Tuple 2 | import torch 3 | from utils.ptp_utils import AttentionStore, 
aggregate_attention 4 | import numpy as np 5 | from pipeline_stable_diffusion import StableDiffusionPipeline 6 | from diffusers import LMSDiscreteScheduler 7 | ######################################### 8 | ''' 9 | This is modified from Attend-and-Excite/pipeline_attend_and_excite.py 10 | https://github.com/yuval-alaluf/Attend-and-Excite 11 | ''' 12 | ######################################### 13 | 14 | def get_out_loss(image, mask): 15 | mask_region = image*mask 16 | out_of_region_loss = 1-mask_region.sum()/image.sum() 17 | return out_of_region_loss 18 | 19 | class LayoutPipeline(StableDiffusionPipeline): 20 | @staticmethod 21 | def _compute_out_losses(masks: List, 22 | attention_maps: torch.Tensor, 23 | indices_to_alter: List[int]) -> List[torch.Tensor]: 24 | """ Computes the out-of-region attention loss for each of the tokens we wish to alter. """ 25 | attention_for_text = attention_maps[:, :, 1:-1] 26 | attention_for_text *= 100 27 | attention_for_text = torch.nn.functional.softmax(attention_for_text, dim=-1) 28 | 29 | # Shift indices since we removed the first token 30 | indices_to_alter = [index - 1 for index in indices_to_alter] 31 | 32 | # Extract out of region loss for each object 33 | out_losses = [] 34 | for mask, i in zip(masks, indices_to_alter): 35 | image = attention_for_text[:, :, i] 36 | metrics = get_out_loss(image, mask) 37 | out_losses.append(metrics) 38 | return out_losses 39 | 40 | def _aggregate_and_get_out_losses(self, masks: List, 41 | attention_store: AttentionStore, 42 | indices_to_alter: List[int], 43 | attention_res: int = 16): 44 | """ Aggregates the attention for each token and computes the out-of-region loss for each token to alter. 
""" 45 | attention_maps = aggregate_attention( 46 | attention_store=attention_store, 47 | res=attention_res, 48 | from_where=("up", "down", "mid"), 49 | is_cross=True, 50 | select=0) 51 | out_losses = self._compute_out_losses( 52 | masks=masks, 53 | attention_maps=attention_maps, 54 | indices_to_alter=indices_to_alter) 55 | return out_losses 56 | 57 | @staticmethod 58 | def _compute_loss(losses: List[torch.Tensor], return_losses: bool = False) -> torch.Tensor: 59 | """ Computes the mean + max out of region loss. """ 60 | loss = sum(losses)/len(losses) + max(losses) 61 | if return_losses: 62 | return loss, losses 63 | else: 64 | return loss 65 | 66 | @staticmethod 67 | def _update_latent(latents: torch.Tensor, loss: torch.Tensor, step_size: float) -> torch.Tensor: 68 | """ Update the latent according to the computed loss. """ 69 | grad_cond = torch.autograd.grad(loss.requires_grad_(True), [latents], retain_graph=True)[0] 70 | latents = latents - step_size * grad_cond 71 | return latents 72 | 73 | def _perform_iterative_refinement_step(self, 74 | masks: List, 75 | latents: torch.Tensor, 76 | indices_to_alter: List[int], 77 | loss: torch.Tensor, 78 | threshold: float, 79 | text_embeddings: torch.Tensor, 80 | text_input, 81 | attention_store: AttentionStore, 82 | step_size: float, 83 | t: int, 84 | attention_res: int = 16, 85 | max_refinement_steps: int = 20): 86 | """ 87 | Performs the iterative latent refinement introduced in the paper. Here, we continuously update the latent 88 | code according to our loss objective until the given threshold is reached for all tokens. 89 | """ 90 | iteration = 0 91 | target_loss = max(0, 1. 
- threshold) 92 | while loss > target_loss: 93 | iteration += 1 94 | 95 | latents = latents.clone().detach().requires_grad_(True) 96 | noise_pred_text = self.unet(latents, t, encoder_hidden_states=text_embeddings[1].unsqueeze(0)).sample 97 | self.unet.zero_grad() 98 | 99 | # Get max activation value for each subject token 100 | out_losses = self._aggregate_and_get_out_losses( 101 | masks=masks, 102 | attention_store=attention_store, 103 | indices_to_alter=indices_to_alter, 104 | attention_res=attention_res) 105 | 106 | loss, losses = self._compute_loss(out_losses, return_losses=True) 107 | 108 | if loss != 0: 109 | latents = self._update_latent(latents, loss, step_size) 110 | 111 | try: 112 | low_token = np.argmax([l.item() if type(l) != int else l for l in losses]) 113 | except Exception as e: 114 | print(e) # catch edge case :) 115 | low_token = np.argmax(losses) 116 | 117 | low_word = self.tokenizer.decode(text_input.input_ids[0][indices_to_alter[low_token]]) 118 | print(f'\t Try {iteration}. {low_word} has the max losses of {out_losses[low_token]}') 119 | 120 | if iteration >= max_refinement_steps: 121 | print(f'\t Exceeded max number of iterations ({max_refinement_steps})! ' 122 | f'Finished with the max loss: {out_losses[low_token]}') 123 | break 124 | 125 | # Run one more time but don't compute gradients and update the latents. 
126 | # We just need to compute the new loss - the grad update will occur below 127 | latents = latents.clone().detach().requires_grad_(True) 128 | noise_pred_text = self.unet(latents, t, encoder_hidden_states=text_embeddings[1].unsqueeze(0)).sample 129 | self.unet.zero_grad() 130 | 131 | # Get max activation value for each subject token 132 | out_losses = self._aggregate_and_get_out_losses( 133 | masks=masks, 134 | attention_store=attention_store, 135 | indices_to_alter=indices_to_alter, 136 | attention_res=attention_res) 137 | loss, losses = self._compute_loss(out_losses, return_losses=True) 138 | print(f"\t Finished with max + mean loss of: {loss}") 139 | return loss, latents, out_losses 140 | 141 | def encode_text(self, prompt): 142 | text_input = self.tokenizer( 143 | prompt, 144 | padding="max_length", 145 | max_length=self.tokenizer.model_max_length, 146 | truncation=True, 147 | return_tensors="pt", 148 | ) 149 | text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0] 150 | max_length = text_input.input_ids.shape[-1] 151 | uncond_input = self.tokenizer( 152 | [""] * 1, padding="max_length", max_length=max_length, return_tensors="pt" 153 | ) 154 | uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0] 155 | text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) 156 | return text_embeddings 157 | 158 | @torch.no_grad() 159 | def __call__( 160 | self, 161 | max_refinement_steps: int, 162 | prompt: Union[str, List[str]], 163 | attention_store: AttentionStore, 164 | indices_to_alter: List[int], 165 | attention_res: int = 16, 166 | height: Optional[int] = 512, 167 | width: Optional[int] = 512, 168 | num_inference_steps: Optional[int] = 50, 169 | guidance_scale: Optional[float] = 7.5, 170 | eta: Optional[float] = 0.0, 171 | generator: Optional[torch.Generator] = None, 172 | latents: Optional[torch.FloatTensor] = None, 173 | output_type: Optional[str] = "pil", 174 | return_dict: bool = True, 175 | 
max_iter_to_alter: Optional[int] = 25, 176 | run_standard_sd: bool = False, 177 | thresholds: Optional[dict] = {0: 0.05, 10: 0.5, 20: 0.8}, 178 | scale_factor: int = 20, 179 | scale_range: Tuple[float, float] = (1., 0.5), 180 | masks: List = [], 181 | blend_dict: dict = {}, 182 | **kwargs): 183 | 184 | text_embeddings, text_input, latents, do_classifier_free_guidance, extra_step_kwargs = self._setup_inference( 185 | prompt=prompt, 186 | height=height, 187 | width=width, 188 | num_inference_steps=num_inference_steps, 189 | guidance_scale=guidance_scale, 190 | eta=eta, 191 | generator=generator, 192 | latents=latents, **kwargs 193 | ) 194 | 195 | scale_range = np.linspace(scale_range[0], scale_range[1], len(self.scheduler.timesteps)) 196 | 197 | if max_iter_to_alter is None: 198 | max_iter_to_alter = len(self.scheduler.timesteps) + 1 199 | 200 | blend_mask = blend_dict["blend_mask"].repeat(1,4,1,1) 201 | inversion_latents = blend_dict["inversion_latents"] 202 | 203 | for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)): 204 | 205 | with torch.enable_grad(): 206 | 207 | latents = latents.clone().detach().requires_grad_(True) 208 | 209 | # Forward pass of denoising with text conditioning 210 | noise_pred_text = self.unet(latents, t, encoder_hidden_states=text_embeddings[1].unsqueeze(0)).sample 211 | self.unet.zero_grad() 212 | 213 | # Get out region loss of each object 214 | out_losses = self._aggregate_and_get_out_losses( 215 | masks=masks, 216 | attention_store=attention_store, 217 | indices_to_alter=indices_to_alter, 218 | attention_res=attention_res) 219 | 220 | if not run_standard_sd: 221 | # Calculate the mean + max loss 222 | loss = self._compute_loss(out_losses) 223 | 224 | # If this is an iterative refinement step, verify we have reached the desired threshold for all 225 | if i in thresholds.keys() and loss > 1. 
- thresholds[i]: 226 | del noise_pred_text 227 | torch.cuda.empty_cache() 228 | loss, latents, out_losses = self._perform_iterative_refinement_step( 229 | masks=masks, 230 | latents=latents, 231 | indices_to_alter=indices_to_alter, 232 | loss=loss, 233 | threshold=thresholds[i], 234 | text_embeddings=text_embeddings, 235 | text_input=text_input, 236 | attention_store=attention_store, 237 | step_size=scale_factor * np.sqrt(scale_range[i]), 238 | t=t, 239 | attention_res=attention_res, 240 | max_refinement_steps = max_refinement_steps) 241 | 242 | # Perform gradient update 243 | if i < max_iter_to_alter: 244 | loss = self._compute_loss(out_losses) 245 | if loss != 0: 246 | latents = self._update_latent(latents=latents, loss=loss, 247 | step_size=scale_factor * np.sqrt(scale_range[i])) 248 | print(f'Iteration {i} | Loss: {loss:0.4f}') 249 | 250 | # blending(for the overlap background region) 251 | if i < blend_dict["blend_steps"]: 252 | latents[blend_mask] = inversion_latents[i][blend_mask] 253 | 254 | noise_pred_uncond = self.unet(latents, t, encoder_hidden_states=text_embeddings[0].unsqueeze(0)).sample 255 | noise_pred_text = self.unet(latents, t, encoder_hidden_states=text_embeddings[1].unsqueeze(0)).sample 256 | 257 | # Perform guidance 258 | if do_classifier_free_guidance: 259 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) 260 | 261 | # Compute the previous noisy sample x_t -> x_t-1 262 | if isinstance(self.scheduler, LMSDiscreteScheduler): 263 | latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample 264 | else: 265 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample 266 | 267 | 268 | outputs = self._prepare_output(latents, output_type, return_dict) 269 | 270 | return outputs 271 | -------------------------------------------------------------------------------- /pipeline_stable_diffusion.py: 
-------------------------------------------------------------------------------- 1 | import inspect 2 | import warnings 3 | from typing import List, Optional, Union 4 | 5 | import torch 6 | from diffusers.models import AutoencoderKL, UNet2DConditionModel 7 | from diffusers.pipeline_utils import DiffusionPipeline 8 | from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput 9 | from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker 10 | from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler 11 | from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer 12 | 13 | 14 | class StableDiffusionPipeline(DiffusionPipeline): 15 | def __init__( 16 | self, 17 | vae: AutoencoderKL, 18 | text_encoder: CLIPTextModel, 19 | tokenizer: CLIPTokenizer, 20 | unet: UNet2DConditionModel, 21 | scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], 22 | safety_checker: StableDiffusionSafetyChecker, 23 | feature_extractor: CLIPFeatureExtractor, 24 | ): 25 | super().__init__() 26 | self.register_modules( 27 | vae=vae, 28 | text_encoder=text_encoder, 29 | tokenizer=tokenizer, 30 | unet=unet, 31 | scheduler=scheduler, 32 | safety_checker=safety_checker, 33 | feature_extractor=feature_extractor, 34 | ) 35 | 36 | def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): 37 | if slice_size == "auto": 38 | # half the attention head size is usually a good trade-off between 39 | # speed and memory 40 | slice_size = self.unet.config.attention_head_dim // 2 41 | self.unet.set_attention_slice(slice_size) 42 | 43 | def disable_attention_slicing(self): 44 | # set slice_size = `None` to disable `attention slicing` 45 | self.enable_attention_slicing(None) 46 | 47 | def _setup_inference(self, 48 | prompt: Union[str, List[str]], 49 | height: Optional[int] = 512, 50 | width: Optional[int] = 512, 51 | num_inference_steps: Optional[int] = 50, 52 | guidance_scale: 
Optional[float] = 7.5, 53 | eta: Optional[float] = 0.0, 54 | generator: Optional[torch.Generator] = None, 55 | latents: Optional[torch.FloatTensor] = None, 56 | **kwargs): 57 | """ Setup the pipeline for inference. """ 58 | 59 | if "torch_device" in kwargs: 60 | device = kwargs.pop("torch_device") 61 | warnings.warn( 62 | "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0." 63 | " Consider using `pipe.to(torch_device)` instead." 64 | ) 65 | 66 | # Set device as before (to be removed in 0.3.0) 67 | if device is None: 68 | device = "cuda" if torch.cuda.is_available() else "cpu" 69 | self.to(device) 70 | 71 | if isinstance(prompt, str): 72 | batch_size = 1 73 | elif isinstance(prompt, list): 74 | batch_size = len(prompt) 75 | else: 76 | raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") 77 | 78 | if height % 8 != 0 or width % 8 != 0: 79 | raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") 80 | 81 | # get prompt text embeddings 82 | text_input = self.tokenizer( 83 | prompt, 84 | padding="max_length", 85 | max_length=self.tokenizer.model_max_length, 86 | truncation=True, 87 | return_tensors="pt", 88 | ) 89 | text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0] 90 | 91 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) 92 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` 93 | # corresponds to doing no classifier free guidance. 
94 | do_classifier_free_guidance = guidance_scale > 1.0 95 | # get unconditional embeddings for classifier free guidance 96 | if do_classifier_free_guidance: 97 | max_length = text_input.input_ids.shape[-1] 98 | uncond_input = self.tokenizer( 99 | [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt" 100 | ) 101 | uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0] 102 | 103 | # For classifier free guidance, we need to do two forward passes. 104 | # Here we concatenate the unconditional and text embeddings into a single batch 105 | # to avoid doing two forward passes 106 | text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) 107 | 108 | # get the initial random noise unless the user supplied it 109 | # Unlike in other pipelines, latents need to be generated in the target device 110 | # for 1-to-1 results reproducibility with the CompVis implementation. 111 | # However this currently doesn't work in `mps`. 112 | latents_device = "cpu" if self.device.type == "mps" else self.device 113 | latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8) 114 | if latents is None: 115 | latents = torch.randn( 116 | latents_shape, 117 | generator=generator, 118 | device=latents_device, 119 | ) 120 | else: 121 | if latents.shape != latents_shape: 122 | raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") 123 | latents = latents.to(self.device) 124 | 125 | # set timesteps 126 | accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys()) 127 | extra_set_kwargs = {} 128 | if accepts_offset: 129 | extra_set_kwargs["offset"] = 1 130 | 131 | self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs) 132 | 133 | # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas 134 | if isinstance(self.scheduler, LMSDiscreteScheduler): 135 | latents = latents * self.scheduler.sigmas[0] 136 |
137 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature 138 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 139 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 140 | # and should be between [0, 1] 141 | accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) 142 | extra_step_kwargs = {} 143 | if accepts_eta: 144 | extra_step_kwargs["eta"] = eta 145 | 146 | return text_embeddings, text_input, latents, do_classifier_free_guidance, extra_step_kwargs 147 | 148 | def _prepare_output(self, latents: torch.Tensor, 149 | output_type: Optional[str] = "pil", 150 | return_dict: bool = True) -> StableDiffusionPipelineOutput: 151 | """ Given the final latent code, generate the output image. """ 152 | # scale and decode the image latents with vae 153 | latents = 1 / 0.18215 * latents 154 | image = self.vae.decode(latents).sample 155 | 156 | image = (image / 2 + 0.5).clamp(0, 1) 157 | image = image.cpu().permute(0, 2, 3, 1).numpy() 158 | 159 | # run safety checker 160 | safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device) 161 | image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_cheker_input.pixel_values) 162 | 163 | if output_type == "pil": 164 | image = self.numpy_to_pil(image) 165 | 166 | if not return_dict: 167 | return (image, has_nsfw_concept) 168 | 169 | output = StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) 170 | return output 171 | 172 | def save_pretrained(self, save_path, only_text_inv, to_merge, modifier_tokens): 173 | delta_dict = {} 174 | if only_text_inv: 175 | key = modifier_tokens[0] 176 | delta_dict['modifier_token'] = {key:to_merge[key]} 177 | else: 178 | delta_dict['modifier_token'] = to_merge 179 | delta_dict['unet'] = {} 180 | for name, params in self.unet.named_parameters(): 181 | if 'attn2.to_k' in 
name or 'attn2.to_v' in name: 182 | delta_dict['unet'][name] = params.cpu().clone() 183 | 184 | torch.save(delta_dict, save_path) 185 | 186 | @torch.no_grad() 187 | def __call__( 188 | self, 189 | prompt: Union[str, List[str]], 190 | height: Optional[int] = 512, 191 | width: Optional[int] = 512, 192 | num_inference_steps: Optional[int] = 50, 193 | guidance_scale: Optional[float] = 7.5, 194 | eta: Optional[float] = 0.0, 195 | generator: Optional[torch.Generator] = None, 196 | latents: Optional[torch.FloatTensor] = None, 197 | output_type: Optional[str] = "pil", 198 | return_dict: bool = True, 199 | **kwargs): 200 | 201 | text_embeddings, text_input, latents, do_classifier_free_guidance, extra_step_kwargs = self._setup_inference( 202 | prompt=prompt, 203 | height=height, 204 | width=width, 205 | num_inference_steps=num_inference_steps, 206 | guidance_scale=guidance_scale, 207 | eta=eta, 208 | generator=generator, 209 | latents=latents, **kwargs 210 | ) 211 | 212 | for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)): 213 | 214 | # expand the latents if we are doing classifier free guidance 215 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents 216 | if isinstance(self.scheduler, LMSDiscreteScheduler): 217 | sigma = self.scheduler.sigmas[i] 218 | # the model input needs to be scaled to match the continuous ODE formulation in K-LMS 219 | latent_model_input = latent_model_input / ((sigma ** 2 + 1) ** 0.5) 220 | 221 | # predict the noise residual 222 | noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample 223 | 224 | # perform guidance 225 | if do_classifier_free_guidance: 226 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 227 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) 228 | 229 | # compute the previous noisy sample x_t -> x_t-1 230 | if isinstance(self.scheduler, LMSDiscreteScheduler): 231 | latents = 
self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample 232 | else: 233 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample 234 | 235 | outputs = self._prepare_output(latents, output_type, return_dict) 236 | return outputs 237 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | clip-retrieval==2.0.0 2 | accelerate==0.19.0 3 | diffusers==0.14.0 4 | wandb==0.15.3 5 | transformers==4.29.2 6 | torchvision==0.15.2 7 | opencv-python==4.7.0.72 8 | gradio==3.43.2 -------------------------------------------------------------------------------- /run_learning.py: -------------------------------------------------------------------------------- 1 | from single_image_learning import run 2 | import argparse 3 | import os 4 | import numpy as np 5 | from PIL import Image 6 | import json 7 | 8 | def parse_args(): 9 | parser = argparse.ArgumentParser(description="Simple example of a training script.") 10 | parser.add_argument( 11 | "--save_steps", 12 | type=int, 13 | default=500, 14 | help="Save learned_embeds.bin every X update steps.", 15 | ) 16 | parser.add_argument( 17 | "--prior_loss_weight", 18 | type=float, 19 | default=1.0, 20 | help="The weight for prior preservation", 21 | ) 22 | parser.add_argument( 23 | "--with_prior_preservation", 24 | default=False, 25 | action="store_true", 26 | help="with prior preservation", 27 | ) 28 | parser.add_argument( 29 | "--freeze_model", 30 | type=str, 31 | default='crossattn_kv', 32 | help="crossattn to enable fine-tuning of all key, value, query matrices", 33 | ) 34 | parser.add_argument( 35 | "--pretrained_model_name_or_path", 36 | type=str, 37 | default="CompVis/stable-diffusion-v1-4", 38 | help="Path to pretrained model or model identifier from huggingface.co/models.", 39 | ) 40 | parser.add_argument( 41 | "--revision", 42 | type=str, 43 |
default=None, 44 | help="Revision of pretrained model identifier from huggingface.co/models.", 45 | ) 46 | parser.add_argument( 47 | "--tokenizer_name", 48 | type=str, 49 | default=None, 50 | help="Pretrained tokenizer name or path if not the same as model_name", 51 | ) 52 | parser.add_argument( 53 | "--iv_modifier_tokens", 54 | type=str, 55 | default=None, 56 | help="A token to use as a placeholder for the concept.", 57 | ) 58 | parser.add_argument( 59 | "--iv_initializer_tokens", type=str, default=None, help="A token to use as initializer word." 60 | ) 61 | parser.add_argument( 62 | "--output_dir", 63 | type=str, 64 | help="The output directory where the model predictions and checkpoints will be written.", 65 | ) 66 | parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.") 67 | parser.add_argument( 68 | "--resolution", 69 | type=int, 70 | default=512, 71 | help=( 72 | "The resolution for input images, all the images in the train/validation dataset will be resized to this" 73 | " resolution" 74 | ), 75 | ) 76 | parser.add_argument( 77 | "--iv_train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." 78 | ) 79 | parser.add_argument( 80 | "--ft_train_batch_size", type=int, default=2, help="Batch size (per device) for the training dataloader." 81 | ) 82 | parser.add_argument( 83 | "--iv_max_train_steps", 84 | type=int, 85 | default=200, 86 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", 87 | ) 88 | parser.add_argument( 89 | "--ft_max_train_steps", 90 | type=int, 91 | default=800, 92 | help="Total number of training steps to perform. 
If provided, overrides num_train_epochs.", 93 | ) 94 | parser.add_argument( 95 | "--iv_gradient_accumulation_steps", 96 | type=int, 97 | default=1, 98 | help="Number of updates steps to accumulate before performing a backward/update pass.", 99 | ) 100 | parser.add_argument( 101 | "--ft_gradient_accumulation_steps", 102 | type=int, 103 | default=1, 104 | help="Number of updates steps to accumulate before performing a backward/update pass.", 105 | ) 106 | parser.add_argument( 107 | "--gradient_checkpointing", 108 | action="store_true", 109 | help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", 110 | ) 111 | parser.add_argument( 112 | "--iv_lr", 113 | type=float, 114 | default=5e-4, 115 | help="Initial learning rate (after the potential warmup period) to use.", 116 | ) 117 | parser.add_argument( 118 | "--ft_lr", 119 | type=float, 120 | default=1e-5, 121 | help="Initial learning rate (after the potential warmup period) to use.", 122 | ) 123 | parser.add_argument( 124 | "--scale_lr", 125 | action="store_true", 126 | default=False, 127 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", 128 | ) 129 | parser.add_argument( 130 | "--lr_scheduler", 131 | type=str, 132 | default="constant", 133 | help=( 134 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' 135 | ' "constant", "constant_with_warmup"]' 136 | ), 137 | ) 138 | parser.add_argument( 139 | "--lr_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler." 140 | ) 141 | parser.add_argument( 142 | "--dataloader_num_workers", 143 | type=int, 144 | default=0, 145 | help=( 146 | "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." 
147 | ), 148 | ) 149 | parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") 150 | parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") 151 | parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") 152 | parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") 153 | parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") 154 | parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") 155 | parser.add_argument( 156 | "--hub_model_id", 157 | type=str, 158 | default=None, 159 | help="The name of the repository to keep in sync with the local `output_dir`.", 160 | ) 161 | parser.add_argument( 162 | "--logging_dir", 163 | type=str, 164 | default="logs", 165 | help=( 166 | "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" 167 | " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." 168 | ), 169 | ) 170 | parser.add_argument( 171 | "--mixed_precision", 172 | type=str, 173 | default="no", 174 | choices=["no", "fp16", "bf16"], 175 | help=( 176 | "Whether to use mixed precision. Choose" 177 | " between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10" 178 | " and an Nvidia Ampere GPU." 179 | ), 180 | ) 181 | parser.add_argument( 182 | "--allow_tf32", 183 | action="store_true", 184 | help=( 185 | "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. 
For more information, see" 186 | " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" 187 | ), 188 | ) 189 | parser.add_argument( 190 | "--report_to", 191 | type=str, 192 | default="wandb", 193 | ) 194 | parser.add_argument( 195 | "--train_prompt", 196 | type=str, 197 | default=None, 198 | help="A prompt that is used during training", 199 | ) 200 | parser.add_argument( 201 | "--num_validation_images", 202 | type=int, 203 | default=1, 204 | help="Number of images that should be generated during validation with `validation_prompt`.", 205 | ) 206 | parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") 207 | parser.add_argument( 208 | "--checkpointing_steps", 209 | type=int, 210 | default=500, 211 | help=( 212 | "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming" 213 | " training using `--resume_from_checkpoint`." 214 | ), 215 | ) 216 | parser.add_argument( 217 | "--checkpoints_total_limit", 218 | type=int, 219 | default=None, 220 | help=( 221 | "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`." 222 | " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state" 223 | " for more docs" 224 | ), 225 | ) 226 | parser.add_argument( 227 | "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." 
228 | ) 229 | parser.add_argument( 230 | "--image_path", 231 | type=str, 232 | help="Path to the input image", 233 | ) 234 | parser.add_argument( 235 | "--mask_image", 236 | type=str, 237 | help="Path to the mask image", 238 | ) 239 | parser.add_argument( 240 | "--semantic_dict", 241 | type=str, 242 | help="Path to the semantic dict", 243 | ) 244 | parser.add_argument( 245 | "--addition_tokens", 246 | type=str, 247 | help="1-3 additional tokens", 248 | ) 249 | parser.add_argument( 250 | "--lambda_factor", 251 | type=float, 252 | default=0, 253 | help="preserve the trained area", 254 | ) 255 | 256 | args = parser.parse_args() 257 | env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) 258 | if env_local_rank != -1 and env_local_rank != args.local_rank: 259 | args.local_rank = env_local_rank 260 | 261 | return args 262 | 263 | 264 | def get_mask(sem_dict_path, image_path): 265 | sem_dict = json.load(open(sem_dict_path)) 266 | image = Image.open(image_path) 267 | nc = len(image.getcolors(10000)) 268 | assert nc==(len(sem_dict)) 269 | img = np.array(image) 270 | masks = [] 271 | for k in list(sem_dict.keys())[:-1]: 272 | color = np.array(list(sem_dict[k])) 273 | color = np.all(img == color, axis=-1) 274 | assert color.sum() != 0 275 | masks.append(color) 276 | masks = np.stack(masks)[:,None,:,:] 277 | return masks 278 | 279 | 280 | def main(): 281 | args = parse_args() 282 | 283 | args.iv_initializer_tokens = args.iv_initializer_tokens.split("+") 284 | args.iv_modifier_tokens = [f"<{token}>" for token in args.iv_initializer_tokens] 285 | args.iv_mask = get_mask(args.semantic_dict, args.mask_image) 286 | 287 | args.ft_initializer_tokens = args.addition_tokens.split("+") 288 | args.ft_modifier_tokens = [f"" for i in range(len(args.ft_initializer_tokens))] 289 | tail_str = " ".join(args.ft_modifier_tokens) 290 | 291 | args.reg_dirs = [] 292 | for i in args.iv_initializer_tokens: 293 | path = f"real_reg/samples_{i}" 294 | if os.path.exists(path): 295 | 
args.reg_dirs.append(path) 296 | 297 | train_prompts = [] 298 | all_replaced_prompt = args.train_prompt 299 | for initializer_token, modifier_token in zip(args.iv_initializer_tokens, args.iv_modifier_tokens): 300 | all_replaced_prompt = all_replaced_prompt.replace(initializer_token, modifier_token) 301 | train_prompts.append(args.train_prompt.replace(initializer_token, modifier_token)) 302 | train_prompts.append(all_replaced_prompt + " " + tail_str) 303 | args.train_prompt = train_prompts 304 | 305 | run(args) 306 | 307 | if __name__ == "__main__": 308 | main() 309 | -------------------------------------------------------------------------------- /run_sample.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List, Dict, Tuple 3 | import torch 4 | import numpy as np 5 | from PIL import Image 6 | import argparse 7 | import pprint 8 | import json 9 | 10 | from pipeline_layout import LayoutPipeline 11 | from utils import ptp_utils, custom_utils 12 | from utils.ptp_utils import AttentionStore 13 | from diffusers import DDIMScheduler 14 | from utils.ddim_inversion import create_inversion_latents 15 | 16 | def get_indices_to_alter(stable, prompt: str) -> List[int]: 17 | token_idx_to_word = {idx: stable.tokenizer.decode(t) 18 | for idx, t in enumerate(stable.tokenizer(prompt)['input_ids']) 19 | if 0 < idx < len(stable.tokenizer(prompt)['input_ids']) - 1} 20 | pprint.pprint(token_idx_to_word) 21 | token_indices = input("Please enter a comma-separated list of indices of the tokens you wish to " 22 | "alter (e.g., 2,5): ") 23 | token_indices = [int(i) for i in token_indices.split(",")] 24 | print(f"Altering tokens: {[token_idx_to_word[i] for i in token_indices]}") 25 | return token_indices 26 | 27 | def run(pipe, 28 | prompt: str, 29 | guidance_scale: float, 30 | eta: float, 31 | n_inference_steps: int, 32 | controller: AttentionStore, 33 | indices_to_alter: List[int], 34 | generator: torch.Generator, 35
| run_standard_sd: bool = False, 36 | scale_factor: int = 20, 37 | thresholds: Dict[int, float] = {0:0.6, 10: 0.7, 20: 0.8}, 38 | max_iter_to_alter: int = 25, 39 | max_refinement_steps: int = 20, 40 | scale_range: Tuple[float, float] = (1., 0.5), 41 | attention_res: int = 16, 42 | masks: List = [], 43 | blend_dict: dict = {}): 44 | if controller is not None: 45 | ptp_utils.register_attention_control(pipe, controller) 46 | outputs = pipe(masks=masks, 47 | blend_dict=blend_dict, 48 | max_refinement_steps=max_refinement_steps, 49 | prompt=prompt, 50 | attention_store=controller, 51 | indices_to_alter=indices_to_alter, 52 | attention_res=attention_res, 53 | guidance_scale=guidance_scale, 54 | generator=generator, 55 | eta = eta, 56 | num_inference_steps=n_inference_steps, 57 | max_iter_to_alter=max_iter_to_alter, 58 | run_standard_sd=run_standard_sd, 59 | thresholds=thresholds, 60 | scale_factor=scale_factor, 61 | scale_range=scale_range) 62 | image = outputs.images[0] 63 | return image 64 | 65 | def get_masks(fp, sem_dict_path): 66 | masks = [] 67 | sem_dict = json.load(open(sem_dict_path)) 68 | for k in list(sem_dict.keys())[:-1]: 69 | pil_image = Image.open(fp).resize((16,16),resample=0) 70 | img = np.array(pil_image) 71 | color = np.array(sem_dict[k]) 72 | mask = torch.Tensor(np.all(img == color, axis=-1)).cuda() 73 | assert mask.sum()!=0 74 | masks.append(mask) 75 | return masks 76 | 77 | def get_blend_mask(mask_image, target_layout, sem_dict_path): 78 | m1 = np.array(Image.open(mask_image)) 79 | m2 = np.array(Image.open(target_layout)) 80 | bg_color = np.array(json.load(open(sem_dict_path))["background"]) 81 | blend_mask = np.all(m1 == bg_color, axis=-1) & np.all(m2 == bg_color, axis=-1) 82 | blend_mask = torch.from_numpy(blend_mask).cuda() 83 | return blend_mask 84 | 85 | def load_model(delta_ckpt): 86 | device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu') 87 | scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, 
beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False) 88 | pipe = LayoutPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", scheduler=scheduler).to(device) 89 | 90 | tokenizer = pipe.tokenizer 91 | if delta_ckpt is not None: 92 | # load the custom cross-attention matrix 93 | custom_utils.load_model(pipe.text_encoder, pipe.tokenizer, pipe.unet, delta_ckpt) 94 | return pipe, tokenizer 95 | 96 | def parse_args(): 97 | parser = argparse.ArgumentParser(description="Layout Control Inference.") 98 | parser.add_argument( 99 | "--prompt", 100 | type=str, 101 | required=True, 102 | help="Prompt for generation", 103 | ) 104 | parser.add_argument( 105 | "--image_path", 106 | type=str, 107 | required=True, 108 | help="Path to input image(use for blending)", 109 | ) 110 | parser.add_argument( 111 | "--output_dir", 112 | type=str, 113 | required=True, 114 | help="Path to generated images", 115 | ) 116 | parser.add_argument( 117 | "--mask_image", 118 | type=str, 119 | required=True, 120 | help="Path to mask image", 121 | ) 122 | parser.add_argument( 123 | "--target_layout", 124 | type=str, 125 | required=True, 126 | help="Target Layout", 127 | ) 128 | parser.add_argument( 129 | "--semantic_dict", 130 | type=str, 131 | required=True, 132 | help="semantic dict", 133 | ) 134 | parser.add_argument( 135 | "--delta_ckpt", 136 | type=str, 137 | default=None, 138 | help="Path to trained model", 139 | ) 140 | parser.add_argument( 141 | "--blend_steps", 142 | type=int, 143 | default=15, 144 | help="Number of blending steps", 145 | ) 146 | args = parser.parse_args() 147 | return args 148 | 149 | if __name__ == "__main__": 150 | n_inference_steps = 50 151 | guidance_scale = 5 152 | eta = 0 153 | max_iter_to_alter = 25 154 | max_refinement_steps = 40 155 | scale_factor = 20 156 | run_standard_sd = False 157 | args = parse_args() 158 | os.makedirs(args.output_dir, exist_ok=True) 159 | pipe, tokenizer = load_model(args.delta_ckpt) 160 | 161 | token_indices = 
get_indices_to_alter(pipe, args.prompt) 162 | masks = get_masks(args.target_layout, args.semantic_dict) 163 | 164 | blend_mask = get_blend_mask(args.mask_image, args.target_layout, args.semantic_dict) 165 | inversion_latents = create_inversion_latents(pipe, args.image_path, args.prompt, guidance_scale, n_inference_steps) 166 | blend_dict = { 167 | "blend_mask":blend_mask, 168 | "inversion_latents":inversion_latents, 169 | "blend_steps":args.blend_steps 170 | } 171 | 172 | for i, seed in enumerate([0,8,88,888,8888]): 173 | g = torch.Generator('cuda').manual_seed(seed) 174 | controller = AttentionStore() 175 | image = run(pipe, 176 | prompt=args.prompt, 177 | guidance_scale = guidance_scale, 178 | n_inference_steps = n_inference_steps, 179 | eta = eta, 180 | controller=controller, 181 | indices_to_alter= token_indices, 182 | generator=g, 183 | run_standard_sd=run_standard_sd, 184 | scale_factor = scale_factor, 185 | thresholds = {0:0.6, 10: 0.7, 20: 0.8}, 186 | max_iter_to_alter=max_iter_to_alter, 187 | max_refinement_steps=max_refinement_steps, 188 | scale_range = (1., 0.5), 189 | masks = masks, 190 | blend_dict = blend_dict, 191 | ) 192 | 193 | image_name = os.path.join(args.output_dir, f"{seed}_prior.png") 194 | image.save(image_name) 195 | 196 | -------------------------------------------------------------------------------- /samples/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bestzzhang/continuous-layout-editing-code/4bc4dfc5644afaec0498e5ec47b1bb74cf10b05b/samples/.DS_Store -------------------------------------------------------------------------------- /samples/image1/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bestzzhang/continuous-layout-editing-code/4bc4dfc5644afaec0498e5ec47b1bb74cf10b05b/samples/image1/.DS_Store -------------------------------------------------------------------------------- 
/samples/image1/image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bestzzhang/continuous-layout-editing-code/4bc4dfc5644afaec0498e5ec47b1bb74cf10b05b/samples/image1/image.png -------------------------------------------------------------------------------- /samples/image1/mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bestzzhang/continuous-layout-editing-code/4bc4dfc5644afaec0498e5ec47b1bb74cf10b05b/samples/image1/mask.png -------------------------------------------------------------------------------- /samples/image1/semantic_dict.json: -------------------------------------------------------------------------------- 1 | {"cat": [251,216,1], "pot": [75,254,1], "plant": [251,1,1], "background": [254,254,254]} -------------------------------------------------------------------------------- /samples/image1/target_layout/layout1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bestzzhang/continuous-layout-editing-code/4bc4dfc5644afaec0498e5ec47b1bb74cf10b05b/samples/image1/target_layout/layout1.png -------------------------------------------------------------------------------- /samples/image1/target_layout/layout2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bestzzhang/continuous-layout-editing-code/4bc4dfc5644afaec0498e5ec47b1bb74cf10b05b/samples/image1/target_layout/layout2.png -------------------------------------------------------------------------------- /samples/image1/target_layout/layout3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bestzzhang/continuous-layout-editing-code/4bc4dfc5644afaec0498e5ec47b1bb74cf10b05b/samples/image1/target_layout/layout3.png 
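The sample files listed above pair a color-coded `mask.png` (and the target layouts) with `semantic_dict.json`, and both `get_mask` in `run_learning.py` and `get_masks` in `run_sample.py` turn each label color into a binary mask by exact RGB comparison. The following is a minimal sketch of that idea on a synthetic 2x2 image, not the repository's code: the helper name `color_masks` and its `skip` parameter are hypothetical, and it skips the background by name, whereas the repository drops the last dictionary key.

```python
import numpy as np

def color_masks(mask_rgb, sem_dict, skip=("background",)):
    """Hypothetical helper: one boolean H x W mask per label not in `skip`."""
    masks = {}
    for label, color in sem_dict.items():
        if label in skip:
            continue
        # A pixel belongs to a label only if all three channels match exactly.
        masks[label] = np.all(mask_rgb == np.array(color), axis=-1)
    return masks

# Colors copied from samples/image1/semantic_dict.json
sem_dict = {
    "cat": [251, 216, 1],
    "pot": [75, 254, 1],
    "plant": [251, 1, 1],
    "background": [254, 254, 254],
}

# Synthetic 2x2 stand-in for mask.png: one pixel per label
mask_rgb = np.array(
    [[[251, 216, 1], [75, 254, 1]],
     [[251, 1, 1], [254, 254, 254]]],
    dtype=np.uint8,
)

masks = color_masks(mask_rgb, sem_dict)
for label, m in masks.items():
    print(label, int(m.sum()))
```

Because the comparison is exact, any resizing of a mask or layout image should use nearest-neighbor resampling (as `get_masks` does with `resample=0`); interpolating filters would blend colors at object borders and those pixels would match no label.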
-------------------------------------------------------------------------------- /single_image_learning.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2023 The HuggingFace Inc. team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | 16 | import argparse 17 | import logging 18 | import math 19 | import os 20 | import random 21 | import itertools 22 | import warnings 23 | from pathlib import Path 24 | from typing import Optional 25 | 26 | import numpy as np 27 | import PIL 28 | import torch 29 | import torch.nn.functional as F 30 | import torch.utils.checkpoint 31 | import transformers 32 | from accelerate import Accelerator 33 | from accelerate.logging import get_logger 34 | from accelerate.utils import ProjectConfiguration, set_seed 35 | from huggingface_hub import HfFolder, Repository, create_repo, whoami 36 | 37 | from packaging import version 38 | from PIL import Image 39 | from torch.utils.data import Dataset 40 | from torchvision import transforms 41 | from tqdm.auto import tqdm 42 | from transformers import CLIPTextModel, CLIPTokenizer 43 | 44 | import diffusers 45 | from diffusers import ( 46 | AutoencoderKL, 47 | DDPMScheduler, 48 | DPMSolverMultistepScheduler, 49 | UNet2DConditionModel, 50 | ) 51 | from diffusers.optimization import get_scheduler 52 | from diffusers.utils import is_wandb_available 53 | from diffusers.utils.import_utils 
import is_xformers_available 54 | from diffusers.models.cross_attention import CrossAttention 55 | from utils.custom_utils import CustomDiffusionAttnProcessor, set_use_memory_efficient_attention_xformers 56 | from pipeline_stable_diffusion import StableDiffusionPipeline 57 | 58 | if is_wandb_available(): 59 | import wandb 60 | 61 | if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"): 62 | PIL_INTERPOLATION = { 63 | "linear": PIL.Image.Resampling.BILINEAR, 64 | "bilinear": PIL.Image.Resampling.BILINEAR, 65 | "bicubic": PIL.Image.Resampling.BICUBIC, 66 | "lanczos": PIL.Image.Resampling.LANCZOS, 67 | "nearest": PIL.Image.Resampling.NEAREST, 68 | } 69 | else: 70 | PIL_INTERPOLATION = { 71 | "linear": PIL.Image.LINEAR, 72 | "bilinear": PIL.Image.BILINEAR, 73 | "bicubic": PIL.Image.BICUBIC, 74 | "lanczos": PIL.Image.LANCZOS, 75 | "nearest": PIL.Image.NEAREST, 76 | } 77 | # ------------------------------------------------------------------------------ 78 | 79 | logger = get_logger(__name__) 80 | 81 | 82 | def log_validation(args, pipeline, accelerator, epoch, validation_prompt): 83 | logger.info( 84 | f"Running validation... \n Generating {args.num_validation_images} images with prompt:" 85 | f" {validation_prompt}." 
86 | ) 87 | 88 | pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config) 89 | pipeline = pipeline.to(accelerator.device) 90 | pipeline.set_progress_bar_config(disable=True) 91 | 92 | # run inference 93 | generator = None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed) 94 | images = [] 95 | for _ in range(args.num_validation_images): 96 | with torch.autocast("cuda"): 97 | images += pipeline(validation_prompt, num_inference_steps=25, generator=generator).images 98 | for tracker in accelerator.trackers: 99 | if tracker.name == "tensorboard": 100 | np_images = np.stack([np.asarray(img) for img in images]) 101 | tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC") 102 | if tracker.name == "wandb": 103 | tracker.log( 104 | { 105 | "validation": [ 106 | wandb.Image(image, caption=f"{i}: {validation_prompt}") for i, image in enumerate(images) 107 | ] 108 | } 109 | ) 110 | del pipeline 111 | torch.cuda.empty_cache() 112 | 113 | 114 | 115 | def create_custom_diffusion(unet, freeze_model): 116 | for name, params in unet.named_parameters(): 117 | if freeze_model == 'crossattn': 118 | if 'attn2' in name: 119 | params.requires_grad = True 120 | else: 121 | params.requires_grad = False 122 | elif freeze_model == "crossattn_kv": 123 | if 'attn2.to_k' in name or 'attn2.to_v' in name: 124 | params.requires_grad = True 125 | else: 126 | params.requires_grad = False 127 | else: 128 | raise ValueError( 129 | "freeze_model argument only supports crossattn_kv or crossattn" 130 | ) 131 | 132 | # change attn class 133 | def change_attn(unet): 134 | for layer in unet.children(): 135 | if type(layer) == CrossAttention: 136 | bound_method = set_use_memory_efficient_attention_xformers.__get__(layer, layer.__class__) 137 | setattr(layer, 'set_use_memory_efficient_attention_xformers', bound_method) 138 | else: 139 | change_attn(layer) 140 | 141 | change_attn(unet) 142 | 
unet.set_attn_processor(CustomDiffusionAttnProcessor()) 143 | return unet 144 | 145 | 146 | class SingleImageDataset(Dataset): 147 | def __init__( 148 | self, 149 | prompt, 150 | mask, 151 | tokenizer, 152 | repeats, 153 | reg_dirs, 154 | with_prior_preservation, 155 | modifier_tokens, 156 | text_inv 157 | ): 158 | self.tokenizer = tokenizer 159 | self.text_inv = text_inv 160 | self.prompt = [prompt, f"a photo of a {modifier_tokens[0]}"] 161 | self._length = 1*repeats # number of images * repeats 162 | self.mask = torch.Tensor(mask).bool().repeat(4,1,1) 163 | self.with_prior_preservation = with_prior_preservation 164 | self.class_images_path = [] 165 | 166 | self.input_ids = self.tokenizer( 167 | self.prompt, 168 | padding="max_length", 169 | truncation=True, 170 | max_length=self.tokenizer.model_max_length, 171 | return_tensors="pt", 172 | ).input_ids 173 | 174 | self.image_transforms = transforms.Compose( 175 | [ 176 | transforms.Resize((512, 512), interpolation=transforms.InterpolationMode.BILINEAR), 177 | transforms.ToTensor(), 178 | transforms.Normalize([0.5], [0.5]), 179 | ] 180 | ) 181 | 182 | if with_prior_preservation: 183 | for reg_dir in reg_dirs: 184 | with open(f"{reg_dir}/images.txt", "r") as f: 185 | class_images_path = f.read().splitlines() 186 | with open(f"{reg_dir}/caption.txt", "r") as f: 187 | class_prompt = f.read().splitlines() 188 | class_img_path = [(x, y) for (x, y) in zip(class_images_path, class_prompt)] 189 | self.class_images_path.extend(class_img_path) 190 | random.shuffle(self.class_images_path) 191 | self.num_class_images = len(self.class_images_path) 192 | print("num of class image: ",self.num_class_images) 193 | 194 | def __len__(self): 195 | return self._length 196 | 197 | def __getitem__(self, idx): 198 | example = {} 199 | example["mask"] = self.mask 200 | if self.text_inv: 201 | example["input_ids"] = random.choice(self.input_ids) 202 | else: 203 | example["input_ids"] = self.input_ids[0] 204 | 205 | if 
self.with_prior_preservation: 206 | class_image, class_prompt = self.class_images_path[idx % self.num_class_images] 207 | class_image = Image.open(class_image) 208 | if not class_image.mode == "RGB": 209 | class_image = class_image.convert("RGB") 210 | example["class_images"] = self.image_transforms(class_image) 211 | example["class_prompt_ids"] = self.tokenizer( 212 | class_prompt, 213 | truncation=True, 214 | padding="max_length", 215 | max_length=self.tokenizer.model_max_length, 216 | return_tensors="pt", 217 | ).input_ids[0] 218 | 219 | return example 220 | 221 | 222 | def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): 223 | if token is None: 224 | token = HfFolder.get_token() 225 | if organization is None: 226 | username = whoami(token)["name"] 227 | return f"{username}/{model_id}" 228 | else: 229 | return f"{organization}/{model_id}" 230 | 231 | 232 | def add_tokens_for_inversion(tokenizer, text_encoder, modifier_tokens, initializer_tokens): 233 | # Add the placeholder token in tokenizer 234 | modifier_token_ids = [] 235 | initializer_token_ids = [] 236 | for modifier_token,initializer_token in zip(modifier_tokens,initializer_tokens): 237 | num_added_tokens = tokenizer.add_tokens(modifier_token) 238 | if num_added_tokens == 0 and modifier_token != initializer_token: 239 | raise ValueError( 240 | f"The tokenizer already contains the token {modifier_token}. Please pass a different" 241 | " `modifier_token` that is not already in the tokenizer." 
242 | ) 243 | 244 | # Convert the initializer_token, modifier_token to ids 245 | token_ids = tokenizer.encode(initializer_token, add_special_tokens=False) 246 | 247 | # Check if initializer_token is a single token or a sequence of tokens 248 | if len(token_ids) > 1: 249 | raise ValueError("The initializer token must be a single token.") 250 | 251 | initializer_token_id = token_ids[0] 252 | modifier_token_id = tokenizer.convert_tokens_to_ids(modifier_token) 253 | initializer_token_ids.append(initializer_token_id) 254 | modifier_token_ids.append(modifier_token_id) 255 | 256 | # Resize the token embeddings as we are adding new special tokens to the tokenizer 257 | text_encoder.resize_token_embeddings(len(tokenizer)) 258 | 259 | # Initialise the newly added placeholder token with the embeddings of the initializer token 260 | token_embeds = text_encoder.get_input_embeddings().weight.data 261 | for modifier_token_id, initializer_token_id in zip(modifier_token_ids, initializer_token_ids): 262 | token_embeds[modifier_token_id] = token_embeds[initializer_token_id] 263 | return modifier_token_ids 264 | 265 | def get_processed_image(image_path, size=512, interpolation="bicubic"): 266 | interpolation = { 267 | "linear": PIL_INTERPOLATION["linear"], 268 | "bilinear": PIL_INTERPOLATION["bilinear"], 269 | "bicubic": PIL_INTERPOLATION["bicubic"], 270 | "lanczos": PIL_INTERPOLATION["lanczos"], 271 | }[interpolation] 272 | 273 | image = Image.open(image_path) 274 | if not image.mode == "RGB": 275 | image = image.convert("RGB") 276 | 277 | image = image.resize((size, size), resample=interpolation) 278 | image = np.array(image).astype(np.uint8) 279 | image = (image / 127.5 - 1.0).astype(np.float32) 280 | image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0) 281 | return image 282 | 283 | 284 | def run(args): 285 | 286 | logging_dir = os.path.join(args.output_dir, args.logging_dir) 287 | 288 | accelerator_project_config = 
ProjectConfiguration(total_limit=args.checkpoints_total_limit) 289 | 290 | accelerator = Accelerator( 291 | mixed_precision=args.mixed_precision, 292 | log_with=args.report_to, 293 | logging_dir=logging_dir, 294 | project_config=accelerator_project_config, 295 | ) 296 | 297 | if args.report_to == "wandb": 298 | if not is_wandb_available(): 299 | raise ImportError("Make sure to install wandb if you want to use it for logging during training.") 300 | 301 | # Make one log on every process with the configuration for debugging. 302 | logging.basicConfig( 303 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 304 | datefmt="%m/%d/%Y %H:%M:%S", 305 | level=logging.INFO, 306 | ) 307 | logger.info(accelerator.state, main_process_only=False) 308 | if accelerator.is_local_main_process: 309 | transformers.utils.logging.set_verbosity_warning() 310 | diffusers.utils.logging.set_verbosity_info() 311 | else: 312 | transformers.utils.logging.set_verbosity_error() 313 | diffusers.utils.logging.set_verbosity_error() 314 | 315 | # If passed along, set the training seed now. 
316 | if args.seed is not None: 317 | set_seed(args.seed) 318 | 319 | # Handle the repository creation 320 | if accelerator.is_main_process: 321 | if args.push_to_hub: 322 | if args.hub_model_id is None: 323 | repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) 324 | else: 325 | repo_name = args.hub_model_id 326 | create_repo(repo_name, exist_ok=True, token=args.hub_token) 327 | repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token) 328 | 329 | with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: 330 | if "step_*" not in gitignore: 331 | gitignore.write("step_*\n") 332 | if "epoch_*" not in gitignore: 333 | gitignore.write("epoch_*\n") 334 | elif args.output_dir is not None: 335 | os.makedirs(args.output_dir, exist_ok=True) 336 | 337 | # Load tokenizer and text encoder 338 | tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer") 339 | text_encoder = CLIPTextModel.from_pretrained( 340 | args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision 341 | ) 342 | 343 | # Load vae and unet 344 | vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision) 345 | unet = UNet2DConditionModel.from_pretrained( 346 | args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision 347 | ) 348 | noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") 349 | 350 | # Freeze all parameters except for the token embeddings in text encoder 351 | text_encoder.text_model.encoder.requires_grad_(False) 352 | text_encoder.text_model.final_layer_norm.requires_grad_(False) 353 | text_encoder.text_model.embeddings.position_embedding.requires_grad_(False) 354 | 355 | if args.enable_xformers_memory_efficient_attention: 356 | if is_xformers_available(): 357 | import xformers 358 | 359 | xformers_version = 
version.parse(xformers.__version__) 360 | if xformers_version == version.parse("0.0.16"): 361 | logger.warning( 362 | "xFormers 0.0.16 cannot be used for training on some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." 363 | ) 364 | unet.enable_xformers_memory_efficient_attention() 365 | else: 366 | raise ValueError("xformers is not available. Make sure it is installed correctly") 367 | 368 | # Enable TF32 for faster training on Ampere GPUs, 369 | # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices 370 | if args.allow_tf32: 371 | torch.backends.cuda.matmul.allow_tf32 = True 372 | 373 | # For mixed precision training we cast the unet and vae weights to half-precision 374 | # as these models are only used for inference, keeping weights in full precision is not required. 375 | weight_dtype = torch.float32 376 | if accelerator.mixed_precision == "fp16": 377 | weight_dtype = torch.float16 378 | elif accelerator.mixed_precision == "bf16": 379 | weight_dtype = torch.bfloat16 380 | 381 | # Freeze vae and move to device 382 | vae.requires_grad_(False) 383 | vae.to(accelerator.device, dtype=weight_dtype) 384 | 385 | # We need to initialize the trackers we use, and also store our configuration. 386 | # The trackers initialize automatically on the main process.
387 | if accelerator.is_main_process: 388 | accelerator.init_trackers("Single Image Learning") 389 | 390 | # Convert images to latent space 391 | pixel_values = get_processed_image(args.image_path, size = args.resolution).to(accelerator.device) 392 | latents = vae.encode(pixel_values.to(dtype = weight_dtype)).latent_dist.sample().detach() 393 | latents = latents * vae.config.scaling_factor 394 | 395 | # Run textual inversion before fine-tuning 396 | stages = [f"Textual Inversion {modifier_token}" for modifier_token in args.iv_modifier_tokens] 397 | stages.append("Fine Tuning") 398 | to_merge = {} 399 | for stage_i, stage in enumerate(stages): 400 | logger.info(f"***** Running {stage} *****") 401 | if stage != "Fine Tuning": 402 | text_inv = True 403 | learning_rate, train_batch_size, max_train_steps, gradient_accumulation_steps, with_prior_preservation = \ 404 | args.iv_lr, args.iv_train_batch_size, args.iv_max_train_steps, args.iv_gradient_accumulation_steps, False 405 | initializer_tokens = [args.iv_initializer_tokens[stage_i]] 406 | embed_output_dir = os.path.join(args.output_dir, initializer_tokens[0]) 407 | modifier_tokens = [args.iv_modifier_tokens[stage_i]] 408 | validation_prompt = f"a photo of a {modifier_tokens[0]}" 409 | mask = args.iv_mask[stage_i] 410 | unet.requires_grad_(False) 411 | else: 412 | text_inv = False 413 | learning_rate, train_batch_size, max_train_steps, gradient_accumulation_steps, with_prior_preservation = \ 414 | args.ft_lr, args.ft_train_batch_size, args.ft_max_train_steps, args.ft_gradient_accumulation_steps, args.with_prior_preservation 415 | initializer_tokens = args.ft_initializer_tokens 416 | embed_output_dir = os.path.join(args.output_dir, "fine_tune") 417 | modifier_tokens = args.ft_modifier_tokens 418 | validation_prompt = args.train_prompt[-1] 419 | mask = np.ones((1,64,64),dtype=bool) 420 | unet = create_custom_diffusion(unet, args.freeze_model) 421 | 422 | accelerator.gradient_accumulation_steps = 
gradient_accumulation_steps 423 | os.makedirs(embed_output_dir, exist_ok = True) 424 | 425 | # Recalculate the number of training epochs 426 | total_batch_size = train_batch_size * accelerator.num_processes * gradient_accumulation_steps 427 | # Log training information 428 | if not text_inv and with_prior_preservation: 429 | # The batch size is doubled because prior images are added 430 | num_train_epochs = math.ceil(max_train_steps / (gradient_accumulation_steps * train_batch_size * 2)) 431 | logger.info(f" Num Epochs = {num_train_epochs}") 432 | logger.info(f" Instantaneous batch size per device = {train_batch_size*2}") 433 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size*2}") 434 | else: 435 | num_train_epochs = math.ceil(max_train_steps / (gradient_accumulation_steps * train_batch_size)) 436 | logger.info(f" Num Epochs = {num_train_epochs}") 437 | logger.info(f" Instantaneous batch size per device = {train_batch_size}") 438 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") 439 | logger.info(f" Gradient Accumulation steps = {gradient_accumulation_steps}") 440 | 441 | # Set up parameters to optimize 442 | if modifier_tokens: 443 | modifier_token_ids = add_tokens_for_inversion(tokenizer, text_encoder, modifier_tokens,initializer_tokens) 444 | if text_inv: 445 | params_to_optimize = text_encoder.get_input_embeddings().parameters() 446 | else: 447 | params_to_optimize = itertools.chain(text_encoder.get_input_embeddings().parameters() , [x[1] for x in unet.named_parameters() if ('attn2.to_k' in x[0] or 'attn2.to_v' in x[0])] ) 448 | else: 449 | modifier_token_ids = None 450 | params_to_optimize = itertools.chain([x[1] for x in unet.named_parameters() if ('attn2.to_k' in x[0] or 'attn2.to_v' in x[0])] ) 451 | 452 | if args.gradient_checkpointing: 453 | # Keep unet in train mode if we are using gradient checkpointing to save memory.
454 | # The dropout cannot be != 0 so it doesn't matter if we are in eval or train mode. 455 | unet.train() 456 | text_encoder.gradient_checkpointing_enable() 457 | unet.enable_gradient_checkpointing() 458 | 459 | if args.scale_lr: 460 | learning_rate = ( 461 | learning_rate * gradient_accumulation_steps * train_batch_size * accelerator.num_processes 462 | ) 463 | 464 | # Initialize the optimizer 465 | optimizer = torch.optim.AdamW( 466 | params_to_optimize, 467 | lr = learning_rate, 468 | betas=(args.adam_beta1, args.adam_beta2), 469 | weight_decay=args.adam_weight_decay, 470 | eps=args.adam_epsilon, 471 | ) 472 | 473 | # Dataset and DataLoaders creation: 474 | train_dataset = SingleImageDataset( 475 | prompt = args.train_prompt[stage_i], 476 | mask = mask, 477 | tokenizer = tokenizer, 478 | repeats = num_train_epochs*train_batch_size*gradient_accumulation_steps, 479 | reg_dirs = args.reg_dirs, 480 | with_prior_preservation = with_prior_preservation, 481 | modifier_tokens = modifier_tokens, 482 | text_inv = text_inv, 483 | ) 484 | 485 | train_dataloader = torch.utils.data.DataLoader( 486 | train_dataset, batch_size=train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers 487 | ) 488 | 489 | if latents.shape[0] < train_batch_size: 490 | # Repeat the single-image latent so the batch holds exactly train_batch_size copies, even if an earlier stage already repeated it 491 | latents = latents[:1].repeat_interleave(train_batch_size, dim=0) 492 | else: 493 | latents = latents[:train_batch_size] 494 | 495 | # Scheduler 496 | lr_scheduler = get_scheduler( 497 | args.lr_scheduler, 498 | optimizer=optimizer, 499 | num_warmup_steps=args.lr_warmup_steps * gradient_accumulation_steps, 500 | num_training_steps=max_train_steps * gradient_accumulation_steps, 501 | ) 502 | 503 | # Prepare everything with our `accelerator`.
504 | text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( 505 | text_encoder, optimizer, train_dataloader, lr_scheduler 506 | ) 507 | 508 | # Move unet to device and cast to weight_dtype 509 | unet.to(accelerator.device, dtype=weight_dtype) 510 | 511 | # keep original embeddings as reference 512 | orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.clone() 513 | 514 | # Only show the progress bar once on each machine. 515 | progress_bar = tqdm(range(num_train_epochs), disable=not accelerator.is_local_main_process) 516 | progress_bar.set_description("Steps") 517 | progress_bar.reset() 518 | 519 | global_step = 0 520 | # Run training! 521 | for batch in train_dataloader: 522 | text_encoder.train() 523 | with accelerator.accumulate(text_encoder): 524 | if with_prior_preservation: 525 | reg_latents = vae.encode(batch["class_images"].to(dtype=weight_dtype)).latent_dist.sample() * vae.config.scaling_factor 526 | latents = torch.cat([latents,reg_latents]) 527 | input_ids = torch.cat([batch["input_ids"], batch["class_prompt_ids"]]) 528 | else: 529 | input_ids = batch["input_ids"] 530 | 531 | # Sample noise that we'll add to the latents 532 | noise = torch.randn_like(latents) 533 | bsz = latents.shape[0] 534 | 535 | # Sample a random timestep for each image 536 | timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) 537 | timesteps = timesteps.long() 538 | 539 | # Add noise to the latents according to the noise magnitude at each timestep 540 | # (this is the forward diffusion process) 541 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) 542 | 543 | # Get the text embedding for conditioning 544 | encoder_hidden_states = text_encoder(input_ids)[0].to(dtype=weight_dtype) 545 | 546 | # Predict the noise residual 547 | model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample 548 | # Get the target for loss depending on the 
prediction type 549 | if noise_scheduler.config.prediction_type == "epsilon": 550 | target = noise 551 | elif noise_scheduler.config.prediction_type == "v_prediction": 552 | target = noise_scheduler.get_velocity(latents, noise, timesteps) 553 | else: 554 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") 555 | 556 | mask = batch["mask"] 557 | if text_inv: 558 | loss = F.mse_loss(model_pred[mask].float(), target[mask].float(), reduction="mean") 559 | else: 560 | # fine tuning 561 | if with_prior_preservation: 562 | latents, _ = torch.chunk(latents, 2, dim=0) 563 | # Chunk the noise and model_pred into two parts and compute the loss on each part separately. 564 | model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) 565 | target, target_prior = torch.chunk(target, 2, dim=0) 566 | # Compute instance loss 567 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") 568 | # Compute prior loss 569 | prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") 570 | # Add the prior loss to the instance loss. 
571 | loss = loss + args.prior_loss_weight * prior_loss 572 | else: 573 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") 574 | 575 | accelerator.backward(loss) 576 | optimizer.step() 577 | lr_scheduler.step() 578 | optimizer.zero_grad() 579 | 580 | if modifier_token_ids: 581 | # Let's make sure we don't update any embedding weights besides the newly added tokens 582 | index_no_updates = torch.arange(len(tokenizer)) != modifier_token_ids[0] 583 | for token_id in modifier_token_ids[1:]: 584 | index_no_updates = index_no_updates & (torch.arange(len(tokenizer)) != token_id) 585 | with torch.no_grad(): 586 | accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[ 587 | index_no_updates 588 | ] = orig_embeds_params[index_no_updates] 589 | 590 | # Checks if the accelerator has performed an optimization step behind the scenes 591 | if accelerator.sync_gradients: 592 | progress_bar.update(1) 593 | global_step += 1 594 | 595 | logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} 596 | progress_bar.set_postfix(**logs) 597 | accelerator.log(logs, step=global_step) 598 | 599 | 600 | if global_step >= max_train_steps: 601 | break 602 | 603 | if text_inv: 604 | learned_embeds = text_encoder.get_input_embeddings().weight[modifier_token_ids[0]] 605 | to_merge[modifier_tokens[0]] = learned_embeds 606 | 607 | # Create the pipeline using the trained modules and save it.
608 | accelerator.wait_for_everyone() 609 | if accelerator.is_main_process: 610 | save_path = os.path.join(embed_output_dir, "delta.bin") 611 | pipeline = StableDiffusionPipeline.from_pretrained( 612 | args.pretrained_model_name_or_path, 613 | unet=accelerator.unwrap_model(unet), 614 | text_encoder=accelerator.unwrap_model(text_encoder), 615 | tokenizer=tokenizer, 616 | revision=args.revision) 617 | 618 | pipeline.save_pretrained(save_path, only_text_inv=text_inv, to_merge=to_merge, modifier_tokens = modifier_tokens) 619 | log_validation(args, pipeline, accelerator, num_train_epochs, validation_prompt) 620 | 621 | if args.push_to_hub: 622 | repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True) 623 | 624 | accelerator.end_training() 625 | 626 | 627 | 628 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bestzzhang/continuous-layout-editing-code/4bc4dfc5644afaec0498e5ec47b1bb74cf10b05b/utils/__init__.py -------------------------------------------------------------------------------- /utils/__pycache__/ddim_inversion.cpython-310.pyc:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/bestzzhang/continuous-layout-editing-code/4bc4dfc5644afaec0498e5ec47b1bb74cf10b05b/utils/__pycache__/ddim_inversion.cpython-310.pyc -------------------------------------------------------------------------------- /utils/custom_utils.py: -------------------------------------------------------------------------------- 1 | # This code is built from the Huggingface repository: https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth.py, and 2 | # https://github.com/huggingface/diffusers/blob/main/examples/textual_inversion/textual_inversion.py 3 | # Copyright 2022- The Hugging Face team. All rights reserved. 4 | # Apache License 5 | # Version 2.0, January 2004 6 | # http://www.apache.org/licenses/ 7 | # ========================================================================================== 8 | # 9 | # modifications are MIT License. To view a copy of the license, visit MIT_LICENSE.md. 10 | # 11 | # ========================================================================================== 12 | # Apache License 13 | # Version 2.0, January 2004 14 | # http://www.apache.org/licenses/ 15 | 16 | # TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 17 | 18 | # 1. Definitions. 19 | 20 | # "License" shall mean the terms and conditions for use, reproduction, 21 | # and distribution as defined by Sections 1 through 9 of this document.
22 | 23 | # "Licensor" shall mean the copyright owner or entity authorized by 24 | # the copyright owner that is granting the License. 25 | 26 | # "Legal Entity" shall mean the union of the acting entity and all 27 | # other entities that control, are controlled by, or are under common 28 | # control with that entity. For the purposes of this definition, 29 | # "control" means (i) the power, direct or indirect, to cause the 30 | # direction or management of such entity, whether by contract or 31 | # otherwise, or (ii) ownership of fifty percent (50%) or more of the 32 | # outstanding shares, or (iii) beneficial ownership of such entity. 33 | 34 | # "You" (or "Your") shall mean an individual or Legal Entity 35 | # exercising permissions granted by this License. 36 | 37 | # "Source" form shall mean the preferred form for making modifications, 38 | # including but not limited to software source code, documentation 39 | # source, and configuration files. 40 | 41 | # "Object" form shall mean any form resulting from mechanical 42 | # transformation or translation of a Source form, including but 43 | # not limited to compiled object code, generated documentation, 44 | # and conversions to other media types. 45 | 46 | # "Work" shall mean the work of authorship, whether in Source or 47 | # Object form, made available under the License, as indicated by a 48 | # copyright notice that is included in or attached to the work 49 | # (an example is provided in the Appendix below). 50 | 51 | # "Derivative Works" shall mean any work, whether in Source or Object 52 | # form, that is based on (or derived from) the Work and for which the 53 | # editorial revisions, annotations, elaborations, or other modifications 54 | # represent, as a whole, an original work of authorship. 
For the purposes 55 | # of this License, Derivative Works shall not include works that remain 56 | # separable from, or merely link (or bind by name) to the interfaces of, 57 | # the Work and Derivative Works thereof. 58 | 59 | # "Contribution" shall mean any work of authorship, including 60 | # the original version of the Work and any modifications or additions 61 | # to that Work or Derivative Works thereof, that is intentionally 62 | # submitted to Licensor for inclusion in the Work by the copyright owner 63 | # or by an individual or Legal Entity authorized to submit on behalf of 64 | # the copyright owner. For the purposes of this definition, "submitted" 65 | # means any form of electronic, verbal, or written communication sent 66 | # to the Licensor or its representatives, including but not limited to 67 | # communication on electronic mailing lists, source code control systems, 68 | # and issue tracking systems that are managed by, or on behalf of, the 69 | # Licensor for the purpose of discussing and improving the Work, but 70 | # excluding communication that is conspicuously marked or otherwise 71 | # designated in writing by the copyright owner as "Not a Contribution." 72 | 73 | # "Contributor" shall mean Licensor and any individual or Legal Entity 74 | # on behalf of whom a Contribution has been received by Licensor and 75 | # subsequently incorporated within the Work. 76 | 77 | # 2. Grant of Copyright License. Subject to the terms and conditions of 78 | # this License, each Contributor hereby grants to You a perpetual, 79 | # worldwide, non-exclusive, no-charge, royalty-free, irrevocable 80 | # copyright license to reproduce, prepare Derivative Works of, 81 | # publicly display, publicly perform, sublicense, and distribute the 82 | # Work and such Derivative Works in Source or Object form. 83 | 84 | # 3. Grant of Patent License. 
Subject to the terms and conditions of 85 | # this License, each Contributor hereby grants to You a perpetual, 86 | # worldwide, non-exclusive, no-charge, royalty-free, irrevocable 87 | # (except as stated in this section) patent license to make, have made, 88 | # use, offer to sell, sell, import, and otherwise transfer the Work, 89 | # where such license applies only to those patent claims licensable 90 | # by such Contributor that are necessarily infringed by their 91 | # Contribution(s) alone or by combination of their Contribution(s) 92 | # with the Work to which such Contribution(s) was submitted. If You 93 | # institute patent litigation against any entity (including a 94 | # cross-claim or counterclaim in a lawsuit) alleging that the Work 95 | # or a Contribution incorporated within the Work constitutes direct 96 | # or contributory patent infringement, then any patent licenses 97 | # granted to You under this License for that Work shall terminate 98 | # as of the date such litigation is filed. 99 | 100 | # 4. Redistribution. 
You may reproduce and distribute copies of the 101 | # Work or Derivative Works thereof in any medium, with or without 102 | # modifications, and in Source or Object form, provided that You 103 | # meet the following conditions: 104 | 105 | # (a) You must give any other recipients of the Work or 106 | # Derivative Works a copy of this License; and 107 | 108 | # (b) You must cause any modified files to carry prominent notices 109 | # stating that You changed the files; and 110 | 111 | # (c) You must retain, in the Source form of any Derivative Works 112 | # that You distribute, all copyright, patent, trademark, and 113 | # attribution notices from the Source form of the Work, 114 | # excluding those notices that do not pertain to any part of 115 | # the Derivative Works; and 116 | 117 | # (d) If the Work includes a "NOTICE" text file as part of its 118 | # distribution, then any Derivative Works that You distribute must 119 | # include a readable copy of the attribution notices contained 120 | # within such NOTICE file, excluding those notices that do not 121 | # pertain to any part of the Derivative Works, in at least one 122 | # of the following places: within a NOTICE text file distributed 123 | # as part of the Derivative Works; within the Source form or 124 | # documentation, if provided along with the Derivative Works; or, 125 | # within a display generated by the Derivative Works, if and 126 | # wherever such third-party notices normally appear. The contents 127 | # of the NOTICE file are for informational purposes only and 128 | # do not modify the License. You may add Your own attribution 129 | # notices within Derivative Works that You distribute, alongside 130 | # or as an addendum to the NOTICE text from the Work, provided 131 | # that such additional attribution notices cannot be construed 132 | # as modifying the License. 
133 | 134 | # You may add Your own copyright statement to Your modifications and 135 | # may provide additional or different license terms and conditions 136 | # for use, reproduction, or distribution of Your modifications, or 137 | # for any such Derivative Works as a whole, provided Your use, 138 | # reproduction, and distribution of the Work otherwise complies with 139 | # the conditions stated in this License. 140 | 141 | # 5. Submission of Contributions. Unless You explicitly state otherwise, 142 | # any Contribution intentionally submitted for inclusion in the Work 143 | # by You to the Licensor shall be under the terms and conditions of 144 | # this License, without any additional terms or conditions. 145 | # Notwithstanding the above, nothing herein shall supersede or modify 146 | # the terms of any separate license agreement you may have executed 147 | # with Licensor regarding such Contributions. 148 | 149 | # 6. Trademarks. This License does not grant permission to use the trade 150 | # names, trademarks, service marks, or product names of the Licensor, 151 | # except as required for reasonable and customary use in describing the 152 | # origin of the Work and reproducing the content of the NOTICE file. 153 | 154 | # 7. Disclaimer of Warranty. Unless required by applicable law or 155 | # agreed to in writing, Licensor provides the Work (and each 156 | # Contributor provides its Contributions) on an "AS IS" BASIS, 157 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 158 | # implied, including, without limitation, any warranties or conditions 159 | # of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 160 | # PARTICULAR PURPOSE. You are solely responsible for determining the 161 | # appropriateness of using or redistributing the Work and assume any 162 | # risks associated with Your exercise of permissions under this License. 163 | 164 | # 8. Limitation of Liability. 
In no event and under no legal theory, 165 | # whether in tort (including negligence), contract, or otherwise, 166 | # unless required by applicable law (such as deliberate and grossly 167 | # negligent acts) or agreed to in writing, shall any Contributor be 168 | # liable to You for damages, including any direct, indirect, special, 169 | # incidental, or consequential damages of any character arising as a 170 | # result of this License or out of the use or inability to use the 171 | # Work (including but not limited to damages for loss of goodwill, 172 | # work stoppage, computer failure or malfunction, or any and all 173 | # other commercial damages or losses), even if such Contributor 174 | # has been advised of the possibility of such damages. 175 | 176 | # 9. Accepting Warranty or Additional Liability. While redistributing 177 | # the Work or Derivative Works thereof, You may choose to offer, 178 | # and charge a fee for, acceptance of support, warranty, indemnity, 179 | # or other liability obligations and/or rights consistent with this 180 | # License. However, in accepting such obligations, You may act only 181 | # on Your own behalf and on Your sole responsibility, not on behalf 182 | # of any other Contributor, and only if You agree to indemnify, 183 | # defend, and hold each Contributor harmless for any liability 184 | # incurred by, or claims asserted against, such Contributor by reason 185 | # of your accepting any such warranty or additional liability. 186 | 187 | # END OF TERMS AND CONDITIONS 188 | 189 | # APPENDIX: How to apply the Apache License to your work. 190 | 191 | # To apply the Apache License to your work, attach the following 192 | # boilerplate notice, with the fields enclosed by brackets "[]" 193 | # replaced with your own identifying information. (Don't include 194 | # the brackets!) The text should be enclosed in the appropriate 195 | # comment syntax for the file format. 
We also recommend that a 196 | # file or class name and description of purpose be included on the 197 | # same "printed page" as the copyright notice for easier 198 | # identification within third-party archives. 199 | 200 | # Copyright [yyyy] [name of copyright owner] 201 | 202 | # Licensed under the Apache License, Version 2.0 (the "License"); 203 | # you may not use this file except in compliance with the License. 204 | # You may obtain a copy of the License at 205 | 206 | # http://www.apache.org/licenses/LICENSE-2.0 207 | 208 | # Unless required by applicable law or agreed to in writing, software 209 | # distributed under the License is distributed on an "AS IS" BASIS, 210 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 211 | # See the License for the specific language governing permissions and 212 | # limitations under the License. 213 | ######################################### 214 | ''' 215 | This is modified from custom-diffusion/src/diffusers_model_pipeline.py 216 | https://github.com/adobe-research/custom-diffusion 217 | ''' 218 | ######################################### 219 | from typing import Callable, Optional 220 | import torch 221 | from diffusers.models.cross_attention import CrossAttention 222 | from diffusers.utils.import_utils import is_xformers_available 223 | 224 | if is_xformers_available(): 225 | import xformers 226 | import xformers.ops 227 | else: 228 | xformers = None 229 | 230 | 231 | 232 | def set_use_memory_efficient_attention_xformers( 233 | self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None 234 | ): 235 | if use_memory_efficient_attention_xformers: 236 | if self.added_kv_proj_dim is not None: 237 | # TODO(Anton, Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP 238 | # which uses this type of cross attention ONLY because the attention mask of format 239 | # [0, ..., -10.000, ..., 0, ...,] is not supported 240 | raise NotImplementedError( 241 | 
"Memory efficient attention with `xformers` is currently not supported when" 242 | " `self.added_kv_proj_dim` is defined." 243 | ) 244 | elif not is_xformers_available(): 245 | raise ModuleNotFoundError( 246 | ( 247 | "Refer to https://github.com/facebookresearch/xformers for more information on how to install" 248 | " xformers" 249 | ), 250 | name="xformers", 251 | ) 252 | elif not torch.cuda.is_available(): 253 | raise ValueError( 254 | "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is" 255 | " only available for GPU " 256 | ) 257 | else: 258 | try: 259 | # Make sure we can run the memory efficient attention 260 | _ = xformers.ops.memory_efficient_attention( 261 | torch.randn((1, 2, 40), device="cuda"), 262 | torch.randn((1, 2, 40), device="cuda"), 263 | torch.randn((1, 2, 40), device="cuda"), 264 | ) 265 | except Exception as e: 266 | raise e 267 | 268 | processor = CustomDiffusionXFormersAttnProcessor(attention_op=attention_op) 269 | else: 270 | processor = CustomDiffusionAttnProcessor() 271 | 272 | self.set_processor(processor) 273 | 274 | 275 | class CustomDiffusionAttnProcessor: 276 | def __call__( 277 | self, 278 | attn: CrossAttention, 279 | hidden_states, 280 | encoder_hidden_states=None, 281 | attention_mask=None, 282 | ): 283 | batch_size, sequence_length, _ = hidden_states.shape 284 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) 285 | query = attn.to_q(hidden_states) 286 | 287 | crossattn = False 288 | if encoder_hidden_states is None: 289 | encoder_hidden_states = hidden_states 290 | else: 291 | crossattn = True 292 | if attn.cross_attention_norm: 293 | encoder_hidden_states = attn.norm_cross(encoder_hidden_states) 294 | 295 | key = attn.to_k(encoder_hidden_states) 296 | value = attn.to_v(encoder_hidden_states) 297 | if crossattn: 298 | detach = torch.ones_like(key) # shape: (4, 77, hidden_dim) 299 | detach[:, :1, :] = detach[:, :1, :]*0. 
300 | key = detach*key + (1-detach)*key.detach() 301 | value = detach*value + (1-detach)*value.detach() 302 | 303 | query = attn.head_to_batch_dim(query) 304 | key = attn.head_to_batch_dim(key) 305 | value = attn.head_to_batch_dim(value) 306 | 307 | attention_probs = attn.get_attention_scores(query, key, attention_mask) 308 | hidden_states = torch.bmm(attention_probs, value) 309 | hidden_states = attn.batch_to_head_dim(hidden_states) 310 | 311 | # linear proj 312 | hidden_states = attn.to_out[0](hidden_states) 313 | # dropout 314 | hidden_states = attn.to_out[1](hidden_states) 315 | 316 | return hidden_states 317 | 318 | 319 | class CustomDiffusionXFormersAttnProcessor: 320 | def __init__(self, attention_op: Optional[Callable] = None): 321 | self.attention_op = attention_op 322 | 323 | def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None): 324 | batch_size, sequence_length, _ = hidden_states.shape 325 | 326 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) 327 | 328 | query = attn.to_q(hidden_states) 329 | 330 | crossattn = False 331 | if encoder_hidden_states is None: 332 | encoder_hidden_states = hidden_states 333 | else: 334 | crossattn = True 335 | if attn.cross_attention_norm: 336 | encoder_hidden_states = attn.norm_cross(encoder_hidden_states) 337 | 338 | key = attn.to_k(encoder_hidden_states) 339 | value = attn.to_v(encoder_hidden_states) 340 | if crossattn: 341 | detach = torch.ones_like(key) 342 | detach[:, :1, :] = detach[:, :1, :]*0. 
343 | key = detach*key + (1-detach)*key.detach() 344 | value = detach*value + (1-detach)*value.detach() 345 | 346 | query = attn.head_to_batch_dim(query).contiguous() 347 | key = attn.head_to_batch_dim(key).contiguous() 348 | value = attn.head_to_batch_dim(value).contiguous() 349 | 350 | hidden_states = xformers.ops.memory_efficient_attention( 351 | query, key, value, attn_bias=attention_mask, op=self.attention_op 352 | ) 353 | hidden_states = hidden_states.to(query.dtype) 354 | hidden_states = attn.batch_to_head_dim(hidden_states) 355 | 356 | # linear proj 357 | hidden_states = attn.to_out[0](hidden_states) 358 | # dropout 359 | hidden_states = attn.to_out[1](hidden_states) 360 | return hidden_states 361 | 362 | def load_model(text_encoder, tokenizer, unet, save_path, compress=False): 363 | st = torch.load(save_path) 364 | if 'text_encoder' in st: 365 | text_encoder.load_state_dict(st['text_encoder']) 366 | if 'modifier_token' in st: 367 | modifier_tokens = list(st['modifier_token'].keys()) 368 | modifier_token_id = [] 369 | for modifier_token in modifier_tokens: 370 | _ = tokenizer.add_tokens(modifier_token) 371 | modifier_token_id.append(tokenizer.convert_tokens_to_ids(modifier_token)) 372 | 373 | # Resize the token embeddings as we are adding new special tokens to the tokenizer 374 | text_encoder.resize_token_embeddings(len(tokenizer)) 375 | token_embeds = text_encoder.get_input_embeddings().weight.data 376 | for i, id_ in enumerate(modifier_token_id): 377 | token_embeds[id_] = st['modifier_token'][modifier_tokens[i]] 378 | 379 | if 'unet' in st: 380 | for name, params in unet.named_parameters(): 381 | if 'attn2.to_k' in name or 'attn2.to_v' in name: 382 | if compress: 383 | params.data += st['unet'][name]['u']@st['unet'][name]['v'] 384 | else: 385 | params.data.copy_(st['unet'][f'{name}']) 386 | -------------------------------------------------------------------------------- /utils/ddim_inversion.py: 
-------------------------------------------------------------------------------- 1 | from typing import Union 2 | import torch 3 | from PIL import Image 4 | import numpy as np 5 | ######################################### 6 | ''' 7 | This is modified from prompt-to-prompt/null_text_w_ptp.ipynb 8 | https://github.com/google/prompt-to-prompt 9 | ''' 10 | ######################################### 11 | 12 | def load_512(image_path, left=0, right=0, top=0, bottom=0): 13 | if type(image_path) is str: 14 | image = np.array(Image.open(image_path))[:, :, :3] 15 | else: 16 | image = image_path 17 | h, w, c = image.shape 18 | left = min(left, w-1) 19 | right = min(right, w - left - 1) 20 | top = min(top, h - 1) 21 | bottom = min(bottom, h - top - 1) 22 | image = image[top:h-bottom, left:w-right] 23 | h, w, c = image.shape 24 | if h < w: 25 | offset = (w - h) // 2 26 | image = image[:, offset:offset + h] 27 | elif w < h: 28 | offset = (h - w) // 2 29 | image = image[offset:offset + w] 30 | image = np.array(Image.fromarray(image).resize((512, 512))) 31 | return image 32 | 33 | class DDIMInversion: 34 | def prev_step(self, model_output: Union[torch.FloatTensor, np.ndarray], timestep: int, sample: Union[torch.FloatTensor, np.ndarray]): 35 | prev_timestep = timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps 36 | alpha_prod_t = self.scheduler.alphas_cumprod[timestep] 37 | alpha_prod_t_prev = self.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.scheduler.final_alpha_cumprod 38 | beta_prod_t = 1 - alpha_prod_t 39 | pred_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5 40 | pred_sample_direction = (1 - alpha_prod_t_prev) ** 0.5 * model_output 41 | prev_sample = alpha_prod_t_prev ** 0.5 * pred_original_sample + pred_sample_direction 42 | return prev_sample 43 | 44 | def next_step(self, model_output: Union[torch.FloatTensor, np.ndarray], timestep: int, sample: 
Union[torch.FloatTensor, np.ndarray]): 45 | timestep, next_timestep = min(timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps, 999), timestep 46 | alpha_prod_t = self.scheduler.alphas_cumprod[timestep] if timestep >= 0 else self.scheduler.final_alpha_cumprod 47 | alpha_prod_t_next = self.scheduler.alphas_cumprod[next_timestep] 48 | beta_prod_t = 1 - alpha_prod_t 49 | next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5 50 | next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output 51 | next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction 52 | return next_sample 53 | 54 | def get_noise_pred_single(self, latents, t, context): 55 | noise_pred = self.model.unet(latents, t, encoder_hidden_states=context)["sample"] 56 | return noise_pred 57 | 58 | def get_noise_pred(self, latents, t, is_forward=True, context=None): 59 | latents_input = torch.cat([latents] * 2) 60 | if context is None: 61 | context = self.context 62 | guidance_scale = 1 if is_forward else self.GUIDANCE_SCALE 63 | noise_pred = self.model.unet(latents_input, t, encoder_hidden_states=context)["sample"] 64 | noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2) 65 | noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond) 66 | if is_forward: 67 | latents = self.next_step(noise_pred, t, latents) 68 | else: 69 | latents = self.prev_step(noise_pred, t, latents) 70 | return latents 71 | 72 | @torch.no_grad() 73 | def latent2image(self, latents, return_type='np'): 74 | latents = 1 / 0.18215 * latents.detach() 75 | image = self.model.vae.decode(latents)['sample'] 76 | if return_type == 'np': 77 | image = (image / 2 + 0.5).clamp(0, 1) 78 | image = image.cpu().permute(0, 2, 3, 1).numpy()[0] 79 | image = (image * 255).astype(np.uint8) 80 | return image 81 | 82 | @torch.no_grad() 83 | def image2latent(self, image): 84 | with torch.no_grad(): 85 | if 
isinstance(image, Image.Image): 86 | image = np.array(image) 87 | if type(image) is torch.Tensor and image.dim() == 4: 88 | latents = image 89 | else: 90 | image = torch.from_numpy(image).float() / 127.5 - 1 91 | image = image.permute(2, 0, 1).unsqueeze(0).to(self.model.device) 92 | latents = self.model.vae.encode(image)['latent_dist'].mean 93 | latents = latents * 0.18215 94 | return latents 95 | 96 | @torch.no_grad() 97 | def init_prompt(self, prompt: str): 98 | uncond_input = self.model.tokenizer( 99 | [""], padding="max_length", max_length=self.model.tokenizer.model_max_length, 100 | return_tensors="pt" 101 | ) 102 | uncond_embeddings = self.model.text_encoder(uncond_input.input_ids.to(self.model.device))[0] 103 | text_input = self.model.tokenizer( 104 | [prompt], 105 | padding="max_length", 106 | max_length=self.model.tokenizer.model_max_length, 107 | truncation=True, 108 | return_tensors="pt", 109 | ) 110 | text_embeddings = self.model.text_encoder(text_input.input_ids.to(self.model.device))[0] 111 | self.context = torch.cat([uncond_embeddings, text_embeddings]) 112 | self.prompt = prompt 113 | 114 | @torch.no_grad() 115 | def ddim_loop(self, latent): 116 | uncond_embeddings, cond_embeddings = self.context.chunk(2) 117 | all_latent = [latent] 118 | latent = latent.clone().detach() 119 | for i in range(self.NUM_DDIM_STEPS): 120 | t = self.model.scheduler.timesteps[len(self.model.scheduler.timesteps) - i - 1] 121 | noise_pred = self.get_noise_pred_single(latent, t, cond_embeddings) 122 | latent = self.next_step(noise_pred, t, latent) 123 | all_latent.append(latent) 124 | return all_latent 125 | 126 | @property 127 | def scheduler(self): 128 | return self.model.scheduler 129 | 130 | @torch.no_grad() 131 | def ddim_inversion(self, image): 132 | latent = self.image2latent(image) 133 | ddim_latents = self.ddim_loop(latent) 134 | return latent, ddim_latents 135 | 136 | def invert(self, image_path: str, prompt: str, offsets=(0,0,0,0), num_inner_steps=10, 
early_stop_epsilon=1e-5, verbose=False): 137 | self.init_prompt(prompt) 138 | image_gt = load_512(image_path, *offsets) 139 | if verbose: 140 | print("DDIM inversion...") 141 | _, ddim_latents = self.ddim_inversion(image_gt) 142 | 143 | return ddim_latents 144 | 145 | 146 | def __init__(self, model, guidance_scale = 5, ddim_steps = 50): 147 | self.model = model 148 | self.GUIDANCE_SCALE = guidance_scale 149 | self.NUM_DDIM_STEPS = ddim_steps 150 | self.tokenizer = model.tokenizer 151 | self.prompt = None 152 | self.context = None 153 | 154 | def create_inversion_latents(ldm_stable, image_path, prompt, guidance_scale, ddim_steps): 155 | # set up number of inference steps 156 | ldm_stable.scheduler.set_timesteps(ddim_steps) 157 | ddim_inversion = DDIMInversion(ldm_stable, guidance_scale, ddim_steps) 158 | z = ddim_inversion.invert(image_path, prompt) 159 | return z[::-1] -------------------------------------------------------------------------------- /utils/ptp_utils.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import cv2 3 | import numpy as np 4 | import torch 5 | from PIL import Image 6 | from typing import Union, Tuple, List 7 | from diffusers.models.cross_attention import CrossAttention 8 | ######################################### 9 | ''' 10 | This is retrieved from Attend-and-Excite/utils/ptp_utils.py 11 | https://github.com/yuval-alaluf/Attend-and-Excite 12 | ''' 13 | ######################################### 14 | 15 | def text_under_image(image: np.ndarray, text: str, text_color: Tuple[int, int, int] = (0, 0, 0)) -> np.ndarray: 16 | h, w, c = image.shape 17 | offset = int(h * .2) 18 | img = np.ones((h + offset, w, c), dtype=np.uint8) * 255 19 | font = cv2.FONT_HERSHEY_SIMPLEX 20 | img[:h] = image 21 | textsize = cv2.getTextSize(text, font, 1, 2)[0] 22 | text_x, text_y = (w - textsize[0]) // 2, h + offset - textsize[1] // 2 23 | cv2.putText(img, text, (text_x, text_y), font, 1, text_color, 2) 24 | 
return img 25 | 26 | def view_images(images: Union[np.ndarray, List], 27 | num_rows: int = 1, 28 | offset_ratio: float = 0.02, 29 | display_image: bool = True) -> Image.Image: 30 | """ Displays a list of images in a grid. """ 31 | if type(images) is list: 32 | num_empty = len(images) % num_rows 33 | elif images.ndim == 4: 34 | num_empty = images.shape[0] % num_rows 35 | else: 36 | images = [images] 37 | num_empty = 0 38 | 39 | empty_images = np.ones(images[0].shape, dtype=np.uint8) * 255 40 | images = [image.astype(np.uint8) for image in images] + [empty_images] * num_empty 41 | num_items = len(images) 42 | 43 | h, w, c = images[0].shape 44 | offset = int(h * offset_ratio) 45 | num_cols = num_items // num_rows 46 | image_ = np.ones((h * num_rows + offset * (num_rows - 1), 47 | w * num_cols + offset * (num_cols - 1), 3), dtype=np.uint8) * 255 48 | for i in range(num_rows): 49 | for j in range(num_cols): 50 | image_[i * (h + offset): i * (h + offset) + h:, j * (w + offset): j * (w + offset) + w] = images[ 51 | i * num_cols + j] 52 | 53 | pil_img = Image.fromarray(image_) 54 | return pil_img 55 | 56 | class AttendExciteCrossAttnProcessor: 57 | 58 | def __init__(self, attnstore, place_in_unet): 59 | super().__init__() 60 | self.attnstore = attnstore 61 | self.place_in_unet = place_in_unet 62 | 63 | def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None): 64 | batch_size, sequence_length, _ = hidden_states.shape 65 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size=batch_size) 66 | 67 | query = attn.to_q(hidden_states) 68 | 69 | is_cross = encoder_hidden_states is not None 70 | encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states 71 | key = attn.to_k(encoder_hidden_states) 72 | value = attn.to_v(encoder_hidden_states) 73 | 74 | query = attn.head_to_batch_dim(query) 75 | key = attn.head_to_batch_dim(key) 76 | value = 
attn.head_to_batch_dim(value) 77 | 78 | attention_probs = attn.get_attention_scores(query, key, attention_mask) 79 | 80 | self.attnstore(attention_probs, is_cross, self.place_in_unet) 81 | 82 | hidden_states = torch.bmm(attention_probs, value) 83 | hidden_states = attn.batch_to_head_dim(hidden_states) 84 | 85 | # linear proj 86 | hidden_states = attn.to_out[0](hidden_states) 87 | # dropout 88 | hidden_states = attn.to_out[1](hidden_states) 89 | 90 | return hidden_states 91 | 92 | def register_attention_control(model, controller): 93 | attn_procs = {} 94 | cross_att_count = 0 95 | for name in model.unet.attn_processors.keys(): 96 | cross_attention_dim = None if name.endswith("attn1.processor") else model.unet.config.cross_attention_dim 97 | if name.startswith("mid_block"): 98 | hidden_size = model.unet.config.block_out_channels[-1] 99 | place_in_unet = "mid" 100 | elif name.startswith("up_blocks"): 101 | block_id = int(name[len("up_blocks.")]) 102 | hidden_size = list(reversed(model.unet.config.block_out_channels))[block_id] 103 | place_in_unet = "up" 104 | elif name.startswith("down_blocks"): 105 | block_id = int(name[len("down_blocks.")]) 106 | hidden_size = model.unet.config.block_out_channels[block_id] 107 | place_in_unet = "down" 108 | else: 109 | continue 110 | 111 | cross_att_count += 1 112 | attn_procs[name] = AttendExciteCrossAttnProcessor( 113 | attnstore=controller, place_in_unet=place_in_unet 114 | ) 115 | 116 | model.unet.set_attn_processor(attn_procs) 117 | controller.num_att_layers = cross_att_count 118 | 119 | class AttentionControl(abc.ABC): 120 | 121 | def step_callback(self, x_t): 122 | return x_t 123 | 124 | def between_steps(self): 125 | return 126 | 127 | @property 128 | def num_uncond_att_layers(self): 129 | return 0 130 | 131 | @abc.abstractmethod 132 | def forward(self, attn, is_cross: bool, place_in_unet: str): 133 | raise NotImplementedError 134 | 135 | def __call__(self, attn, is_cross: bool, place_in_unet: str): 136 | if 
self.cur_att_layer >= self.num_uncond_att_layers: 137 | self.forward(attn, is_cross, place_in_unet) 138 | self.cur_att_layer += 1 139 | if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers: 140 | self.cur_att_layer = 0 141 | self.cur_step += 1 142 | self.between_steps() 143 | 144 | def reset(self): 145 | self.cur_step = 0 146 | self.cur_att_layer = 0 147 | 148 | def __init__(self): 149 | self.cur_step = 0 150 | self.num_att_layers = -1 151 | self.cur_att_layer = 0 152 | 153 | class AttentionStore(AttentionControl): 154 | 155 | @staticmethod 156 | def get_empty_store(): 157 | return {"down_cross": [], "mid_cross": [], "up_cross": [], 158 | "down_self": [], "mid_self": [], "up_self": []} 159 | 160 | def forward(self, attn, is_cross: bool, place_in_unet: str): 161 | key = f"{place_in_unet}_{'cross' if is_cross else 'self'}" 162 | if attn.shape[1] <= 32 ** 2: # avoid memory overhead 163 | self.step_store[key].append(attn) 164 | return attn 165 | 166 | def between_steps(self): 167 | self.attention_store = self.step_store 168 | if self.save_global_store: 169 | with torch.no_grad(): 170 | if len(self.global_store) == 0: 171 | self.global_store = self.step_store 172 | else: 173 | for key in self.global_store: 174 | for i in range(len(self.global_store[key])): 175 | self.global_store[key][i] += self.step_store[key][i].detach() 176 | self.step_store = self.get_empty_store() 177 | self.step_store = self.get_empty_store() 178 | 179 | def get_average_attention(self): 180 | average_attention = self.attention_store 181 | return average_attention 182 | 183 | def get_average_global_attention(self): 184 | average_attention = {key: [item / self.cur_step for item in self.global_store[key]] for key in 185 | self.attention_store} 186 | return average_attention 187 | 188 | def reset(self): 189 | super(AttentionStore, self).reset() 190 | self.step_store = self.get_empty_store() 191 | self.attention_store = {} 192 | self.global_store = {} 193 | 194 | def __init__(self, 
save_global_store=False): 195 | ''' 196 | Initialize an empty AttentionStore 197 | :param save_global_store: if True, also accumulate the attention maps of every step in global_store 198 | ''' 199 | super(AttentionStore, self).__init__() 200 | self.save_global_store = save_global_store 201 | self.step_store = self.get_empty_store() 202 | self.attention_store = {} 203 | self.global_store = {} 204 | self.curr_step_index = 0 205 | 206 | def aggregate_attention(attention_store: AttentionStore, 207 | res: int, 208 | from_where: List[str], 209 | is_cross: bool, 210 | select: int) -> torch.Tensor: 211 | """ Aggregates the attention across the different layers and heads at the specified resolution. """ 212 | out = [] 213 | attention_maps = attention_store.get_average_attention() 214 | num_pixels = res ** 2 215 | for location in from_where: 216 | for item in attention_maps[f"{location}_{'cross' if is_cross else 'self'}"]: 217 | if item.shape[1] == num_pixels: 218 | cross_maps = item.reshape(-1, 8, res, res, item.shape[-1])[select] 219 | out.append(cross_maps) 220 | out = torch.cat(out, dim=0) 221 | out = out.sum(0) / out.shape[0] 222 | return out 223 | -------------------------------------------------------------------------------- /utils/retrieve.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Adobe Research. All rights reserved. 2 | # To view a copy of the license, visit LICENSE.md. 
3 | 4 | import argparse 5 | import os 6 | import tqdm 7 | from pathlib import Path 8 | import requests 9 | from PIL import Image 10 | from io import BytesIO 11 | from clip_retrieval.clip_client import ClipClient 12 | 13 | 14 | def retrieve(target_name, outpath, num_class_images): 15 | num_images = 2*num_class_images 16 | client = ClipClient(url="https://knn.laion.ai/knn-service", indice_name="laion5B-H-14", num_images=num_images, aesthetic_weight=0.1) 17 | 18 | if len(target_name.split()): 19 | target = '_'.join(target_name.split()) 20 | else: 21 | target = target_name 22 | os.makedirs(f'{outpath}/{target}', exist_ok=True) 23 | 24 | if len(list(Path(f'{outpath}/{target}').iterdir())) >= num_class_images: 25 | return 26 | 27 | while True: 28 | text = f"a photo of {target_name}" 29 | print(text) 30 | results = client.query(text=text) 31 | if len(results) >= num_class_images or num_images > 1e4: 32 | break 33 | else: 34 | num_images = int(1.5*num_images) 35 | client = ClipClient(url="https://knn.laion.ai/knn-service", indice_name="laion_400m", num_images=num_images, aesthetic_weight=0.1) 36 | 37 | count = 0 38 | urls = [] 39 | captions = [] 40 | 41 | pbar = tqdm.tqdm(desc='downloading real regularization images', total=num_class_images) 42 | 43 | for each in results: 44 | name = f'{outpath}/{target}/{count}.jpg' 45 | success = True 46 | while True: 47 | try: 48 | img = requests.get(each['url']) 49 | success = True 50 | break 51 | except Exception: 52 | success = False 53 | break 54 | if success and img.status_code == 200: 55 | try: 56 | _ = Image.open(BytesIO(img.content)) 57 | with open(name, 'wb') as f: 58 | f.write(img.content) 59 | urls.append(each['url']) 60 | captions.append(each['caption']) 61 | count += 1 62 | pbar.update(1) 63 | except Exception: 64 | pass 65 | if count >= num_class_images: 66 | break 67 | 68 | with open(f'{outpath}/caption.txt', 'w') as f: 69 | for each in captions: 70 | f.write(each.strip() + '\n') 71 | 72 | with open(f'{outpath}/urls.txt', 'w') as f: 73 | for 
each in urls: 74 | f.write(each.strip() + '\n') 75 | 76 | with open(f'{outpath}/images.txt', 'w') as f: 77 | for p in range(count): 78 | f.write(f'{outpath}/{target}/{p}.jpg' + '\n') 79 | 80 | 81 | def parse_args(): 82 | parser = argparse.ArgumentParser('', add_help=False) 83 | parser.add_argument('--target_name', help='target string for query', 84 | type=str) 85 | parser.add_argument('--outpath', help='path to save retrieved images', default='./', 86 | type=str) 87 | parser.add_argument('--num_class_images', help='number of retrieved images', default=200, 88 | type=int) 89 | return parser.parse_args() 90 | 91 | 92 | if __name__ == "__main__": 93 | args = parse_args() 94 | names = args.target_name.split("+") 95 | for name in names: 96 | retrieve(name, f"{args.outpath}/samples_{name}", args.num_class_images) 97 | --------------------------------------------------------------------------------
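
The `prev_step` and `next_step` methods in `utils/ddim_inversion.py` apply the same deterministic DDIM update in opposite directions. The block below is a minimal scalar sketch of that update, not the repository's code: `ddim_step` and the alpha values are hypothetical names and numbers chosen for illustration, and plain floats stand in for latent tensors and for the scheduler's `alphas_cumprod` entries.

```python
import math


def ddim_step(sample, noise_pred, alpha_from, alpha_to):
    """Move `sample` from cumulative alpha `alpha_from` to `alpha_to`.

    Hypothetical helper mirroring the arithmetic of prev_step/next_step.
    """
    # Estimate the clean sample implied by the current noise prediction.
    pred_original = (sample - math.sqrt(1 - alpha_from) * noise_pred) / math.sqrt(alpha_from)
    # Re-noise that estimate to the target noise level with the same prediction.
    direction = math.sqrt(1 - alpha_to) * noise_pred
    return math.sqrt(alpha_to) * pred_original + direction


# Illustrative values only (assumptions, not taken from a real scheduler).
x, eps = 0.8, -0.3
alpha_clean, alpha_noisy = 0.95, 0.60

x_noisy = ddim_step(x, eps, alpha_clean, alpha_noisy)      # inversion direction, as in next_step
x_back = ddim_step(x_noisy, eps, alpha_noisy, alpha_clean)  # sampling direction, as in prev_step
```

With an identical noise prediction the two directions cancel exactly. In the real pipeline the UNet predicts slightly different noise at each latent, so inversion followed by sampling only approximately reconstructs the input.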