├── CrossAttention_Release.ipynb ├── CrossAttention_Release_NoImages.ipynb ├── InverseCrossAttention_Release.ipynb ├── InverseCrossAttention_Release_NoImages.ipynb ├── LICENSE ├── README.md ├── images ├── A fantasy landscape with a pine forest without fog and rocks, dry sunny day, grass.png ├── A fantasy landscape with a pine forest without fog and without rocks.png ├── a fantasy landscape with a pine forest - A fantasy landscape with a pine forest and a river.png ├── a fantasy landscape with a pine forest - a watercolor painting of a landscape with a pine forest.png ├── a fantasy landscape with a pine forest - a winter fantasy landscape with a pine forest.png ├── a fantasy landscape with a pine forest - decrease clouds.png ├── a fantasy landscape with a pine forest - decrease fantasy.png ├── a fantasy landscape with a pine forest - decrease fog.png ├── a fantasy landscape with a pine forest - decrease rocks.png ├── a fantasy landscape with a pine forest - increase fantasy and forest.png ├── faces_test.png ├── fouranimals.png ├── fourseasons.png └── fourstyles.png └── portrait.png /CrossAttention_Release_NoImages.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "88b4974c-6437-422d-afae-daa2884ad633", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import torch\n", 11 | "from transformers import CLIPModel, CLIPTextModel, CLIPTokenizer\n", 12 | "from diffusers import AutoencoderKL, UNet2DConditionModel\n", 13 | "\n", 14 | "#NOTE: Last tested working diffusers version is diffusers==0.4.1, https://github.com/huggingface/diffusers/releases/tag/v0.4.1" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": null, 20 | "id": "38ebfbd7-5026-4830-93e5-d43272db8912", 21 | "metadata": {}, 22 | "outputs": [], 23 | "source": [ 24 | "#Init CLIP tokenizer and model\n", 25 | "model_path_clip = 
\"openai/clip-vit-large-patch14\"\n", 26 | "clip_tokenizer = CLIPTokenizer.from_pretrained(model_path_clip)\n", 27 | "clip_model = CLIPModel.from_pretrained(model_path_clip, torch_dtype=torch.float16)\n", 28 | "clip = clip_model.text_model\n", 29 | "\n", 30 | "#Init diffusion model\n", 31 | "auth_token = True #Replace this with huggingface auth token as a string if model is not already downloaded\n", 32 | "model_path_diffusion = \"CompVis/stable-diffusion-v1-4\"\n", 33 | "unet = UNet2DConditionModel.from_pretrained(model_path_diffusion, subfolder=\"unet\", use_auth_token=auth_token, revision=\"fp16\", torch_dtype=torch.float16)\n", 34 | "vae = AutoencoderKL.from_pretrained(model_path_diffusion, subfolder=\"vae\", use_auth_token=auth_token, revision=\"fp16\", torch_dtype=torch.float16)\n", 35 | "\n", 36 | "#Move to GPU\n", 37 | "device = \"cuda\"\n", 38 | "unet.to(device)\n", 39 | "vae.to(device)\n", 40 | "clip.to(device)\n", 41 | "print(\"Loaded all models\")" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": null, 47 | "id": "08dee7d1-e050-43d3-86a9-5776276aad78", 48 | "metadata": {}, 49 | "outputs": [], 50 | "source": [ 51 | "import numpy as np\n", 52 | "import random\n", 53 | "from PIL import Image\n", 54 | "from diffusers import LMSDiscreteScheduler\n", 55 | "from tqdm.auto import tqdm\n", 56 | "from torch import autocast\n", 57 | "from difflib import SequenceMatcher\n", 58 | "\n", 59 | "def init_attention_weights(weight_tuples):\n", 60 | " tokens_length = clip_tokenizer.model_max_length\n", 61 | " weights = torch.ones(tokens_length)\n", 62 | " \n", 63 | " for i, w in weight_tuples:\n", 64 | " if i < tokens_length and i >= 0:\n", 65 | " weights[i] = w\n", 66 | " \n", 67 | " \n", 68 | " for name, module in unet.named_modules():\n", 69 | " module_name = type(module).__name__\n", 70 | " if module_name == \"CrossAttention\" and \"attn2\" in name:\n", 71 | " module.last_attn_slice_weights = weights.to(device)\n", 72 | " if module_name == 
\"CrossAttention\" and \"attn1\" in name:\n", 73 | " module.last_attn_slice_weights = None\n", 74 | " \n", 75 | "\n", 76 | "def init_attention_edit(tokens, tokens_edit):\n", 77 | " tokens_length = clip_tokenizer.model_max_length\n", 78 | " mask = torch.zeros(tokens_length)\n", 79 | " indices_target = torch.arange(tokens_length, dtype=torch.long)\n", 80 | " indices = torch.zeros(tokens_length, dtype=torch.long)\n", 81 | "\n", 82 | " tokens = tokens.input_ids.numpy()[0]\n", 83 | " tokens_edit = tokens_edit.input_ids.numpy()[0]\n", 84 | " \n", 85 | " for name, a0, a1, b0, b1 in SequenceMatcher(None, tokens, tokens_edit).get_opcodes():\n", 86 | " if b0 < tokens_length:\n", 87 | " if name == \"equal\" or (name == \"replace\" and a1-a0 == b1-b0):\n", 88 | " mask[b0:b1] = 1\n", 89 | " indices[b0:b1] = indices_target[a0:a1]\n", 90 | "\n", 91 | " for name, module in unet.named_modules():\n", 92 | " module_name = type(module).__name__\n", 93 | " if module_name == \"CrossAttention\" and \"attn2\" in name:\n", 94 | " module.last_attn_slice_mask = mask.to(device)\n", 95 | " module.last_attn_slice_indices = indices.to(device)\n", 96 | " if module_name == \"CrossAttention\" and \"attn1\" in name:\n", 97 | " module.last_attn_slice_mask = None\n", 98 | " module.last_attn_slice_indices = None\n", 99 | "\n", 100 | "\n", 101 | "def init_attention_func():\n", 102 | " #ORIGINAL SOURCE CODE: https://github.com/huggingface/diffusers/blob/91ddd2a25b848df0fa1262d4f1cd98c7ccb87750/src/diffusers/models/attention.py#L276\n", 103 | " def new_attention(self, query, key, value):\n", 104 | " # TODO: use baddbmm for better performance\n", 105 | " attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale\n", 106 | " attn_slice = attention_scores.softmax(dim=-1)\n", 107 | " # compute attention output\n", 108 | " \n", 109 | " if self.use_last_attn_slice:\n", 110 | " if self.last_attn_slice_mask is not None:\n", 111 | " new_attn_slice = torch.index_select(self.last_attn_slice, -1, 
self.last_attn_slice_indices)\n", 112 | " attn_slice = attn_slice * (1 - self.last_attn_slice_mask) + new_attn_slice * self.last_attn_slice_mask\n", 113 | " else:\n", 114 | " attn_slice = self.last_attn_slice\n", 115 | "\n", 116 | " self.use_last_attn_slice = False\n", 117 | "\n", 118 | " if self.save_last_attn_slice:\n", 119 | " self.last_attn_slice = attn_slice\n", 120 | " self.save_last_attn_slice = False\n", 121 | "\n", 122 | " if self.use_last_attn_weights and self.last_attn_slice_weights is not None:\n", 123 | " attn_slice = attn_slice * self.last_attn_slice_weights\n", 124 | " self.use_last_attn_weights = False\n", 125 | " \n", 126 | " hidden_states = torch.matmul(attn_slice, value)\n", 127 | " # reshape hidden_states\n", 128 | " hidden_states = self.reshape_batch_dim_to_heads(hidden_states)\n", 129 | " return hidden_states\n", 130 | " \n", 131 | " def new_sliced_attention(self, query, key, value, sequence_length, dim):\n", 132 | " \n", 133 | " batch_size_attention = query.shape[0]\n", 134 | " hidden_states = torch.zeros(\n", 135 | " (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype\n", 136 | " )\n", 137 | " slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]\n", 138 | " for i in range(hidden_states.shape[0] // slice_size):\n", 139 | " start_idx = i * slice_size\n", 140 | " end_idx = (i + 1) * slice_size\n", 141 | " attn_slice = (\n", 142 | " torch.matmul(query[start_idx:end_idx], key[start_idx:end_idx].transpose(1, 2)) * self.scale\n", 143 | " ) # TODO: use baddbmm for better performance\n", 144 | " attn_slice = attn_slice.softmax(dim=-1)\n", 145 | " \n", 146 | " if self.use_last_attn_slice:\n", 147 | " if self.last_attn_slice_mask is not None:\n", 148 | " new_attn_slice = torch.index_select(self.last_attn_slice, -1, self.last_attn_slice_indices)\n", 149 | " attn_slice = attn_slice * (1 - self.last_attn_slice_mask) + new_attn_slice * self.last_attn_slice_mask\n", 150 | " 
else:\n", 151 | " attn_slice = self.last_attn_slice\n", 152 | " \n", 153 | " self.use_last_attn_slice = False\n", 154 | " \n", 155 | " if self.save_last_attn_slice:\n", 156 | " self.last_attn_slice = attn_slice\n", 157 | " self.save_last_attn_slice = False\n", 158 | " \n", 159 | " if self.use_last_attn_weights and self.last_attn_slice_weights is not None:\n", 160 | " attn_slice = attn_slice * self.last_attn_slice_weights\n", 161 | " self.use_last_attn_weights = False\n", 162 | " \n", 163 | " attn_slice = torch.matmul(attn_slice, value[start_idx:end_idx])\n", 164 | "\n", 165 | " hidden_states[start_idx:end_idx] = attn_slice\n", 166 | "\n", 167 | " # reshape hidden_states\n", 168 | " hidden_states = self.reshape_batch_dim_to_heads(hidden_states)\n", 169 | " return hidden_states\n", 170 | "\n", 171 | " for name, module in unet.named_modules():\n", 172 | " module_name = type(module).__name__\n", 173 | " if module_name == \"CrossAttention\":\n", 174 | " module.last_attn_slice = None\n", 175 | " module.use_last_attn_slice = False\n", 176 | " module.use_last_attn_weights = False\n", 177 | " module.save_last_attn_slice = False\n", 178 | " module._sliced_attention = new_sliced_attention.__get__(module, type(module))\n", 179 | " module._attention = new_attention.__get__(module, type(module))\n", 180 | " \n", 181 | "def use_last_tokens_attention(use=True):\n", 182 | " for name, module in unet.named_modules():\n", 183 | " module_name = type(module).__name__\n", 184 | " if module_name == \"CrossAttention\" and \"attn2\" in name:\n", 185 | " module.use_last_attn_slice = use\n", 186 | " \n", 187 | "def use_last_tokens_attention_weights(use=True):\n", 188 | " for name, module in unet.named_modules():\n", 189 | " module_name = type(module).__name__\n", 190 | " if module_name == \"CrossAttention\" and \"attn2\" in name:\n", 191 | " module.use_last_attn_weights = use\n", 192 | " \n", 193 | "def use_last_self_attention(use=True):\n", 194 | " for name, module in 
unet.named_modules():\n", 195 | " module_name = type(module).__name__\n", 196 | " if module_name == \"CrossAttention\" and \"attn1\" in name:\n", 197 | " module.use_last_attn_slice = use\n", 198 | " \n", 199 | "def save_last_tokens_attention(save=True):\n", 200 | " for name, module in unet.named_modules():\n", 201 | " module_name = type(module).__name__\n", 202 | " if module_name == \"CrossAttention\" and \"attn2\" in name:\n", 203 | " module.save_last_attn_slice = save\n", 204 | " \n", 205 | "def save_last_self_attention(save=True):\n", 206 | " for name, module in unet.named_modules():\n", 207 | " module_name = type(module).__name__\n", 208 | " if module_name == \"CrossAttention\" and \"attn1\" in name:\n", 209 | " module.save_last_attn_slice = save\n", 210 | " \n", 211 | "@torch.no_grad()\n", 212 | "def stablediffusion(prompt=\"\", prompt_edit=None, prompt_edit_token_weights=[], prompt_edit_tokens_start=0.0, prompt_edit_tokens_end=1.0, prompt_edit_spatial_start=0.0, prompt_edit_spatial_end=1.0, guidance_scale=7.5, steps=50, seed=None, width=512, height=512, init_image=None, init_image_strength=0.5):\n", 213 | " #Change size to multiple of 64 to prevent size mismatches inside model\n", 214 | " width = width - width % 64\n", 215 | " height = height - height % 64\n", 216 | " \n", 217 | " #If seed is None, randomly select seed from 0 to 2^32-1\n", 218 | " if seed is None: seed = random.randrange(2**32 - 1)\n", 219 | " generator = torch.cuda.manual_seed(seed)\n", 220 | " \n", 221 | " #Set inference timesteps to scheduler\n", 222 | " scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule=\"scaled_linear\", num_train_timesteps=1000)\n", 223 | " scheduler.set_timesteps(steps)\n", 224 | " \n", 225 | " #Preprocess image if it exists (img2img)\n", 226 | " if init_image is not None:\n", 227 | " #Resize and transpose for numpy b h w c -> torch b c h w\n", 228 | " init_image = init_image.resize((width, height), resample=Image.Resampling.LANCZOS)\n", 
229 | " init_image = np.array(init_image).astype(np.float32) / 255.0 * 2.0 - 1.0\n", 230 | " init_image = torch.from_numpy(init_image[np.newaxis, ...].transpose(0, 3, 1, 2))\n", 231 | " \n", 232 | " #If there is alpha channel, composite alpha for white, as the diffusion model does not support alpha channel\n", 233 | " if init_image.shape[1] > 3:\n", 234 | " init_image = init_image[:, :3] * init_image[:, 3:] + (1 - init_image[:, 3:])\n", 235 | " \n", 236 | " #Move image to GPU\n", 237 | " init_image = init_image.to(device)\n", 238 | " \n", 239 | " #Encode image\n", 240 | " with autocast(device):\n", 241 | " init_latent = vae.encode(init_image).latent_dist.sample(generator=generator) * 0.18215\n", 242 | " \n", 243 | " t_start = steps - int(steps * init_image_strength)\n", 244 | " \n", 245 | " else:\n", 246 | " init_latent = torch.zeros((1, unet.in_channels, height // 8, width // 8), device=device)\n", 247 | " t_start = 0\n", 248 | " \n", 249 | " #Generate random normal noise\n", 250 | " noise = torch.randn(init_latent.shape, generator=generator, device=device)\n", 251 | " #latent = noise * scheduler.init_noise_sigma\n", 252 | " latent = scheduler.add_noise(init_latent, noise, torch.tensor([scheduler.timesteps[t_start]], device=device)).to(device)\n", 253 | " \n", 254 | " #Process clip\n", 255 | " with autocast(device):\n", 256 | " tokens_unconditional = clip_tokenizer(\"\", padding=\"max_length\", max_length=clip_tokenizer.model_max_length, truncation=True, return_tensors=\"pt\", return_overflowing_tokens=True)\n", 257 | " embedding_unconditional = clip(tokens_unconditional.input_ids.to(device)).last_hidden_state\n", 258 | "\n", 259 | " tokens_conditional = clip_tokenizer(prompt, padding=\"max_length\", max_length=clip_tokenizer.model_max_length, truncation=True, return_tensors=\"pt\", return_overflowing_tokens=True)\n", 260 | " embedding_conditional = clip(tokens_conditional.input_ids.to(device)).last_hidden_state\n", 261 | "\n", 262 | " #Process prompt editing\n", 
263 | " if prompt_edit is not None:\n", 264 | " tokens_conditional_edit = clip_tokenizer(prompt_edit, padding=\"max_length\", max_length=clip_tokenizer.model_max_length, truncation=True, return_tensors=\"pt\", return_overflowing_tokens=True)\n", 265 | " embedding_conditional_edit = clip(tokens_conditional_edit.input_ids.to(device)).last_hidden_state\n", 266 | " \n", 267 | " init_attention_edit(tokens_conditional, tokens_conditional_edit)\n", 268 | " \n", 269 | " init_attention_func()\n", 270 | " init_attention_weights(prompt_edit_token_weights)\n", 271 | " \n", 272 | " timesteps = scheduler.timesteps[t_start:]\n", 273 | " \n", 274 | " for i, t in tqdm(enumerate(timesteps), total=len(timesteps)):\n", 275 | " t_index = t_start + i\n", 276 | "\n", 277 | " #sigma = scheduler.sigmas[t_index]\n", 278 | " latent_model_input = latent\n", 279 | " latent_model_input = scheduler.scale_model_input(latent_model_input, t)\n", 280 | "\n", 281 | " #Predict the unconditional noise residual\n", 282 | " noise_pred_uncond = unet(latent_model_input, t, encoder_hidden_states=embedding_unconditional).sample\n", 283 | " \n", 284 | " #Prepare the Cross-Attention layers\n", 285 | " if prompt_edit is not None:\n", 286 | " save_last_tokens_attention()\n", 287 | " save_last_self_attention()\n", 288 | " else:\n", 289 | " #Use weights on non-edited prompt when edit is None\n", 290 | " use_last_tokens_attention_weights()\n", 291 | " \n", 292 | " #Predict the conditional noise residual and save the cross-attention layer activations\n", 293 | " noise_pred_cond = unet(latent_model_input, t, encoder_hidden_states=embedding_conditional).sample\n", 294 | " \n", 295 | " #Edit the Cross-Attention layer activations\n", 296 | " if prompt_edit is not None:\n", 297 | " t_scale = t / scheduler.num_train_timesteps\n", 298 | " if t_scale >= prompt_edit_tokens_start and t_scale <= prompt_edit_tokens_end:\n", 299 | " use_last_tokens_attention()\n", 300 | " if t_scale >= prompt_edit_spatial_start and t_scale <= 
prompt_edit_spatial_end:\n", 301 | " use_last_self_attention()\n", 302 | " \n", 303 | " #Use weights on edited prompt\n", 304 | " use_last_tokens_attention_weights()\n", 305 | "\n", 306 | " #Predict the edited conditional noise residual using the cross-attention masks\n", 307 | " noise_pred_cond = unet(latent_model_input, t, encoder_hidden_states=embedding_conditional_edit).sample\n", 308 | " \n", 309 | " #Perform guidance\n", 310 | " noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)\n", 311 | "\n", 312 | " latent = scheduler.step(noise_pred, t_index, latent).prev_sample\n", 313 | "\n", 314 | " #scale and decode the image latents with vae\n", 315 | " latent = latent / 0.18215\n", 316 | " image = vae.decode(latent.to(vae.dtype)).sample\n", 317 | "\n", 318 | " image = (image / 2 + 0.5).clamp(0, 1)\n", 319 | " image = image.cpu().permute(0, 2, 3, 1).numpy()\n", 320 | " image = (image[0] * 255).round().astype(\"uint8\")\n", 321 | " return Image.fromarray(image)\n" 322 | ] 323 | }, 324 | { 325 | "cell_type": "code", 326 | "execution_count": null, 327 | "id": "24e9d39a-df21-45ba-bcfc-58b8784617d6", 328 | "metadata": {}, 329 | "outputs": [], 330 | "source": [ 331 | "def prompt_token(prompt, index):\n", 332 | " tokens = clip_tokenizer(prompt, padding=\"max_length\", max_length=clip_tokenizer.model_max_length, truncation=True, return_tensors=\"pt\", return_overflowing_tokens=True).input_ids[0]\n", 333 | " return clip_tokenizer.decode(tokens[index:index+1])" 334 | ] 335 | }, 336 | { 337 | "cell_type": "code", 338 | "execution_count": null, 339 | "id": "f44ff9cb-6d43-4b32-bf2a-17bfca665218", 340 | "metadata": {}, 341 | "outputs": [], 342 | "source": [ 343 | "prompt_token(\"a cat sitting on a car\", 2)" 344 | ] 345 | }, 346 | { 347 | "cell_type": "code", 348 | "execution_count": null, 349 | "id": "3f69d449-0192-4db2-8687-9f9adab3907d", 350 | "metadata": {}, 351 | "outputs": [], 352 | "source": [ 353 | "stablediffusion(\"a cat sitting on a 
car\", seed=248396402679, steps=50)" 354 | ] 355 | }, 356 | { 357 | "cell_type": "code", 358 | "execution_count": null, 359 | "id": "ba17dce4-f14c-4df3-8325-7adcabb33c1d", 360 | "metadata": {}, 361 | "outputs": [], 362 | "source": [ 363 | "stablediffusion(\"a cat sitting on a car\", \"a smiling dog sitting on a car\", prompt_edit_spatial_start=0.7, seed=248396402679)" 364 | ] 365 | }, 366 | { 367 | "cell_type": "code", 368 | "execution_count": null, 369 | "id": "31c10c7e-38d4-4627-816d-241073684825", 370 | "metadata": {}, 371 | "outputs": [], 372 | "source": [ 373 | "stablediffusion(\"a cat sitting on a car\", \"a hamster sitting on a car\", prompt_edit_spatial_start=0.5, seed=248396402679)" 374 | ] 375 | }, 376 | { 377 | "cell_type": "code", 378 | "execution_count": null, 379 | "id": "b62f03c4-ee26-4f1d-a3c2-c83666e0a37e", 380 | "metadata": {}, 381 | "outputs": [], 382 | "source": [ 383 | "stablediffusion(\"a cat sitting on a car\", \"a tiger sitting on a car\", prompt_edit_spatial_start=1.0, seed=248396402679)" 384 | ] 385 | }, 386 | { 387 | "cell_type": "code", 388 | "execution_count": null, 389 | "id": "4a16be9a-0411-4fbd-83f9-e968243d94fa", 390 | "metadata": {}, 391 | "outputs": [], 392 | "source": [ 393 | "stablediffusion(\"A fantasy landscape with a pine forest, trending on artstation\", seed=2483964025, width=768)" 394 | ] 395 | }, 396 | { 397 | "cell_type": "code", 398 | "execution_count": null, 399 | "id": "54f31e7f-5527-4310-9f86-909c2f3525dd", 400 | "metadata": {}, 401 | "outputs": [], 402 | "source": [ 403 | "#Remove fantasy\n", 404 | "stablediffusion(\"A fantasy landscape with a pine forest, trending on artstation\", prompt_edit_token_weights=[(2, -8)], seed=2483964025, width=768)" 405 | ] 406 | }, 407 | { 408 | "cell_type": "code", 409 | "execution_count": null, 410 | "id": "f919ec34-5038-4bcc-9bcf-c959b6d2e9f1", 411 | "metadata": {}, 412 | "outputs": [], 413 | "source": [ 414 | "#Winter\n", 415 | "stablediffusion(\"A fantasy landscape with a pine 
forest, trending on artstation\", \"A winter fantasy landscape with a pine forest, trending on artstation\", seed=2483964025, width=768)" 416 | ] 417 | }, 418 | { 419 | "cell_type": "code", 420 | "execution_count": null, 421 | "id": "8e8bfb6c-4510-46e2-adfe-e99e38abf7a0", 422 | "metadata": {}, 423 | "outputs": [], 424 | "source": [ 425 | "#Watercolor style\n", 426 | "stablediffusion(\"A fantasy landscape with a pine forest, trending on artstation\", \"A watercolor painting of a landscape with a pine forest, trending on artstation\", seed=2483964025, width=768)" 427 | ] 428 | }, 429 | { 430 | "cell_type": "code", 431 | "execution_count": null, 432 | "id": "995ea7e7-21c8-42ab-99db-3388469fe443", 433 | "metadata": {}, 434 | "outputs": [], 435 | "source": [ 436 | "#Remove fog\n", 437 | "stablediffusion(\"A fantasy landscape with a pine forest, trending on artstation\", \"A fantasy landscape with a pine forest with fog, trending on artstation\", prompt_edit_token_weights=[(9, -6)], seed=2483964025, width=768)" 438 | ] 439 | }, 440 | { 441 | "cell_type": "code", 442 | "execution_count": null, 443 | "id": "9cbdf169-a0c3-47aa-8cf4-dee1312edab6", 444 | "metadata": {}, 445 | "outputs": [], 446 | "source": [ 447 | "#Removing fog and rocks\n", 448 | "stablediffusion(\"A fantasy landscape with a pine forest, trending on artstation\", \"A fantasy landscape with a pine forest with fog and rocks, trending on artstation\", prompt_edit_token_weights=[(9, -6), (11, -6)], seed=2483964025, width=768)" 449 | ] 450 | }, 451 | { 452 | "cell_type": "code", 453 | "execution_count": null, 454 | "id": "34d8cd67-a1ab-4572-82b1-337637b5cfce", 455 | "metadata": {}, 456 | "outputs": [], 457 | "source": [ 458 | "#Adding a river\n", 459 | "stablediffusion(\"A fantasy landscape with a pine forest, trending on artstation\", \"A fantasy landscape with a pine forest and a river, trending on artstation\", prompt_edit_spatial_end=0.8, seed=2483964025, width=768)" 460 | ] 461 | }, 462 | { 463 | 
"cell_type": "code", 464 | "execution_count": null, 465 | "id": "57cf95b7-38f6-49f6-8492-266af0569e3c", 466 | "metadata": {}, 467 | "outputs": [], 468 | "source": [ 469 | "#Adding a lone cabin\n", 470 | "stablediffusion(\"A fantasy landscape with a pine forest, trending on artstation\", \"A fantasy landscape with a pine forest, lone cabin, trending on artstation\", prompt_edit_spatial_start=1.0, seed=2483964025, width=768)" 471 | ] 472 | }, 473 | { 474 | "cell_type": "code", 475 | "execution_count": null, 476 | "id": "fb871a4d-153f-4afe-b6fe-a4cdd4635021", 477 | "metadata": {}, 478 | "outputs": [], 479 | "source": [] 480 | } 481 | ], 482 | "metadata": { 483 | "kernelspec": { 484 | "display_name": "diffusers", 485 | "language": "python", 486 | "name": "diffusers" 487 | }, 488 | "language_info": { 489 | "codemirror_mode": { 490 | "name": "ipython", 491 | "version": 3 492 | }, 493 | "file_extension": ".py", 494 | "mimetype": "text/x-python", 495 | "name": "python", 496 | "nbconvert_exporter": "python", 497 | "pygments_lexer": "ipython3", 498 | "version": "3.10.4" 499 | } 500 | }, 501 | "nbformat": 4, 502 | "nbformat_minor": 5 503 | } 504 | -------------------------------------------------------------------------------- /InverseCrossAttention_Release_NoImages.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "88b4974c-6437-422d-afae-daa2884ad633", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import torch\n", 11 | "from transformers import CLIPModel, CLIPTextModel, CLIPTokenizer\n", 12 | "from diffusers import AutoencoderKL, UNet2DConditionModel\n", 13 | "\n", 14 | "#NOTE: Last tested working diffusers version is diffusers==0.4.1, https://github.com/huggingface/diffusers/releases/tag/v0.4.1" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": null, 20 | "id": "38ebfbd7-5026-4830-93e5-d43272db8912", 21 | 
"metadata": {}, 22 | "outputs": [], 23 | "source": [ 24 | "#Init CLIP tokenizer and model\n", 25 | "model_path_clip = \"openai/clip-vit-large-patch14\"\n", 26 | "clip_tokenizer = CLIPTokenizer.from_pretrained(model_path_clip)\n", 27 | "clip_model = CLIPModel.from_pretrained(model_path_clip, torch_dtype=torch.float16)\n", 28 | "clip = clip_model.text_model\n", 29 | "\n", 30 | "#Init diffusion model\n", 31 | "auth_token = True #Replace this with huggingface auth token as a string if model is not already downloaded\n", 32 | "model_path_diffusion = \"CompVis/stable-diffusion-v1-4\"\n", 33 | "unet = UNet2DConditionModel.from_pretrained(model_path_diffusion, subfolder=\"unet\", use_auth_token=auth_token, revision=\"fp16\", torch_dtype=torch.float16)\n", 34 | "vae = AutoencoderKL.from_pretrained(model_path_diffusion, subfolder=\"vae\", use_auth_token=auth_token, revision=\"fp16\", torch_dtype=torch.float16)\n", 35 | "\n", 36 | "#Move to GPU\n", 37 | "device = \"cuda\"\n", 38 | "unet.to(device)\n", 39 | "vae.to(device)\n", 40 | "clip.to(device)\n", 41 | "print(\"Loaded all models\")" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": null, 47 | "id": "e6aa4473-05c4-4d35-82b4-23a01f004cfc", 48 | "metadata": {}, 49 | "outputs": [], 50 | "source": [ 51 | "import numpy as np\n", 52 | "import random\n", 53 | "from PIL import Image\n", 54 | "from diffusers import LMSDiscreteScheduler, DDIMScheduler\n", 55 | "from tqdm.auto import tqdm\n", 56 | "from torch import autocast\n", 57 | "from difflib import SequenceMatcher\n", 58 | "\n", 59 | "def init_attention_weights(weight_tuples):\n", 60 | " tokens_length = clip_tokenizer.model_max_length\n", 61 | " weights = torch.ones(tokens_length)\n", 62 | " \n", 63 | " for i, w in weight_tuples:\n", 64 | " if i < tokens_length and i >= 0:\n", 65 | " weights[i] = w\n", 66 | " \n", 67 | " \n", 68 | " for name, module in unet.named_modules():\n", 69 | " module_name = type(module).__name__\n", 70 | " if module_name == 
\"CrossAttention\" and \"attn2\" in name:\n", 71 | " module.last_attn_slice_weights = weights.to(device)\n", 72 | " if module_name == \"CrossAttention\" and \"attn1\" in name:\n", 73 | " module.last_attn_slice_weights = None\n", 74 | " \n", 75 | "\n", 76 | "def init_attention_edit(tokens, tokens_edit):\n", 77 | " tokens_length = clip_tokenizer.model_max_length\n", 78 | " mask = torch.zeros(tokens_length)\n", 79 | " indices_target = torch.arange(tokens_length, dtype=torch.long)\n", 80 | " indices = torch.zeros(tokens_length, dtype=torch.long)\n", 81 | "\n", 82 | " tokens = tokens.input_ids.numpy()[0]\n", 83 | " tokens_edit = tokens_edit.input_ids.numpy()[0]\n", 84 | " \n", 85 | " for name, a0, a1, b0, b1 in SequenceMatcher(None, tokens, tokens_edit).get_opcodes():\n", 86 | " if b0 < tokens_length:\n", 87 | " if name == \"equal\" or (name == \"replace\" and a1-a0 == b1-b0):\n", 88 | " mask[b0:b1] = 1\n", 89 | " indices[b0:b1] = indices_target[a0:a1]\n", 90 | "\n", 91 | " for name, module in unet.named_modules():\n", 92 | " module_name = type(module).__name__\n", 93 | " if module_name == \"CrossAttention\" and \"attn2\" in name:\n", 94 | " module.last_attn_slice_mask = mask.to(device)\n", 95 | " module.last_attn_slice_indices = indices.to(device)\n", 96 | " if module_name == \"CrossAttention\" and \"attn1\" in name:\n", 97 | " module.last_attn_slice_mask = None\n", 98 | " module.last_attn_slice_indices = None\n", 99 | "\n", 100 | "\n", 101 | "def init_attention_func():\n", 102 | " #ORIGINAL SOURCE CODE: https://github.com/huggingface/diffusers/blob/91ddd2a25b848df0fa1262d4f1cd98c7ccb87750/src/diffusers/models/attention.py#L276\n", 103 | " def new_attention(self, query, key, value):\n", 104 | " # TODO: use baddbmm for better performance\n", 105 | " attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale\n", 106 | " attn_slice = attention_scores.softmax(dim=-1)\n", 107 | " # compute attention output\n", 108 | " \n", 109 | " if 
self.use_last_attn_slice:\n", 110 | " if self.last_attn_slice_mask is not None:\n", 111 | " new_attn_slice = torch.index_select(self.last_attn_slice, -1, self.last_attn_slice_indices)\n", 112 | " attn_slice = attn_slice * (1 - self.last_attn_slice_mask) + new_attn_slice * self.last_attn_slice_mask\n", 113 | " else:\n", 114 | " attn_slice = self.last_attn_slice\n", 115 | "\n", 116 | " self.use_last_attn_slice = False\n", 117 | "\n", 118 | " if self.save_last_attn_slice:\n", 119 | " self.last_attn_slice = attn_slice\n", 120 | " self.save_last_attn_slice = False\n", 121 | "\n", 122 | " if self.use_last_attn_weights and self.last_attn_slice_weights is not None:\n", 123 | " attn_slice = attn_slice * self.last_attn_slice_weights\n", 124 | " self.use_last_attn_weights = False\n", 125 | " \n", 126 | " hidden_states = torch.matmul(attn_slice, value)\n", 127 | " # reshape hidden_states\n", 128 | " hidden_states = self.reshape_batch_dim_to_heads(hidden_states)\n", 129 | " return hidden_states\n", 130 | " \n", 131 | " def new_sliced_attention(self, query, key, value, sequence_length, dim):\n", 132 | " \n", 133 | " batch_size_attention = query.shape[0]\n", 134 | " hidden_states = torch.zeros(\n", 135 | " (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype\n", 136 | " )\n", 137 | " slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]\n", 138 | " for i in range(hidden_states.shape[0] // slice_size):\n", 139 | " start_idx = i * slice_size\n", 140 | " end_idx = (i + 1) * slice_size\n", 141 | " attn_slice = (\n", 142 | " torch.matmul(query[start_idx:end_idx], key[start_idx:end_idx].transpose(1, 2)) * self.scale\n", 143 | " ) # TODO: use baddbmm for better performance\n", 144 | " attn_slice = attn_slice.softmax(dim=-1)\n", 145 | " \n", 146 | " if self.use_last_attn_slice:\n", 147 | " if self.last_attn_slice_mask is not None:\n", 148 | " new_attn_slice = torch.index_select(self.last_attn_slice, -1, 
self.last_attn_slice_indices)\n", 149 | " attn_slice = attn_slice * (1 - self.last_attn_slice_mask) + new_attn_slice * self.last_attn_slice_mask\n", 150 | " else:\n", 151 | " attn_slice = self.last_attn_slice\n", 152 | " \n", 153 | " self.use_last_attn_slice = False\n", 154 | " \n", 155 | " if self.save_last_attn_slice:\n", 156 | " self.last_attn_slice = attn_slice\n", 157 | " self.save_last_attn_slice = False\n", 158 | " \n", 159 | " if self.use_last_attn_weights and self.last_attn_slice_weights is not None:\n", 160 | " attn_slice = attn_slice * self.last_attn_slice_weights\n", 161 | " self.use_last_attn_weights = False\n", 162 | " \n", 163 | " attn_slice = torch.matmul(attn_slice, value[start_idx:end_idx])\n", 164 | "\n", 165 | " hidden_states[start_idx:end_idx] = attn_slice\n", 166 | "\n", 167 | " # reshape hidden_states\n", 168 | " hidden_states = self.reshape_batch_dim_to_heads(hidden_states)\n", 169 | " return hidden_states\n", 170 | "\n", 171 | " for name, module in unet.named_modules():\n", 172 | " module_name = type(module).__name__\n", 173 | " if module_name == \"CrossAttention\":\n", 174 | " module.last_attn_slice = None\n", 175 | " module.use_last_attn_slice = False\n", 176 | " module.use_last_attn_weights = False\n", 177 | " module.save_last_attn_slice = False\n", 178 | " module._sliced_attention = new_sliced_attention.__get__(module, type(module))\n", 179 | " module._attention = new_attention.__get__(module, type(module))\n", 180 | " \n", 181 | "def use_last_tokens_attention(use=True):\n", 182 | " for name, module in unet.named_modules():\n", 183 | " module_name = type(module).__name__\n", 184 | " if module_name == \"CrossAttention\" and \"attn2\" in name:\n", 185 | " module.use_last_attn_slice = use\n", 186 | " \n", 187 | "def use_last_tokens_attention_weights(use=True):\n", 188 | " for name, module in unet.named_modules():\n", 189 | " module_name = type(module).__name__\n", 190 | " if module_name == \"CrossAttention\" and \"attn2\" in name:\n", 191 
| " module.use_last_attn_weights = use\n", 192 | " \n", 193 | "def use_last_self_attention(use=True):\n", 194 | " for name, module in unet.named_modules():\n", 195 | " module_name = type(module).__name__\n", 196 | " if module_name == \"CrossAttention\" and \"attn1\" in name:\n", 197 | " module.use_last_attn_slice = use\n", 198 | " \n", 199 | "def save_last_tokens_attention(save=True):\n", 200 | " for name, module in unet.named_modules():\n", 201 | " module_name = type(module).__name__\n", 202 | " if module_name == \"CrossAttention\" and \"attn2\" in name:\n", 203 | " module.save_last_attn_slice = save\n", 204 | " \n", 205 | "def save_last_self_attention(save=True):\n", 206 | " for name, module in unet.named_modules():\n", 207 | " module_name = type(module).__name__\n", 208 | " if module_name == \"CrossAttention\" and \"attn1\" in name:\n", 209 | " module.save_last_attn_slice = save\n", 210 | " \n", 211 | "@torch.no_grad()\n", 212 | "def stablediffusion(prompt=\"\", prompt_edit=None, prompt_edit_token_weights=[], init_latents=None, prompt_edit_tokens_start=0.0, prompt_edit_tokens_end=1.0, prompt_edit_spatial_start=0.0, prompt_edit_spatial_end=1.0, guidance_scale=7.5, steps=50, seed=None, width=512, height=512):\n", 213 | " #Change size to multiple of 64 to prevent size mismatches inside model\n", 214 | " if init_latents is not None:\n", 215 | " width = init_latents.shape[-1] * 8\n", 216 | " height = init_latents.shape[-2] * 8\n", 217 | " \n", 218 | " width = width - width % 64\n", 219 | " height = height - height % 64\n", 220 | " \n", 221 | " #If seed is None, randomly select seed from 0 to 2^32-1\n", 222 | " if seed is None: seed = random.randrange(2**32 - 1)\n", 223 | " generator = torch.cuda.manual_seed(seed)\n", 224 | " \n", 225 | " #Set inference timesteps to scheduler\n", 226 | " scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule=\"scaled_linear\", num_train_timesteps=1000)\n", 227 | " scheduler.set_timesteps(steps)\n", 228 | " 
\n", 229 | " init_latent = torch.zeros((1, unet.in_channels, height // 8, width // 8), device=device)\n", 230 | " t_start = 0\n", 231 | " \n", 232 | " #Generate random normal noise\n", 233 | " noise = torch.randn(init_latent.shape, generator=generator, device=device)\n", 234 | " \n", 235 | " #If init_latents is used, initialize noise as init_latent\n", 236 | " if init_latents is not None:\n", 237 | " noise = init_latents\n", 238 | " \n", 239 | " init_latents = noise\n", 240 | " latent = scheduler.add_noise(init_latent, noise, torch.tensor([scheduler.timesteps[t_start]], device=device)).to(device)\n", 241 | " \n", 242 | " #Process clip\n", 243 | " with autocast(device):\n", 244 | " tokens_unconditional = clip_tokenizer(\"\", padding=\"max_length\", max_length=clip_tokenizer.model_max_length, truncation=True, return_tensors=\"pt\", return_overflowing_tokens=True)\n", 245 | " embedding_unconditional = clip(tokens_unconditional.input_ids.to(device)).last_hidden_state\n", 246 | "\n", 247 | " tokens_conditional = clip_tokenizer(prompt, padding=\"max_length\", max_length=clip_tokenizer.model_max_length, truncation=True, return_tensors=\"pt\", return_overflowing_tokens=True)\n", 248 | " embedding_conditional = clip(tokens_conditional.input_ids.to(device)).last_hidden_state\n", 249 | "\n", 250 | " #Process prompt editing\n", 251 | " if prompt_edit is not None:\n", 252 | " tokens_conditional_edit = clip_tokenizer(prompt_edit, padding=\"max_length\", max_length=clip_tokenizer.model_max_length, truncation=True, return_tensors=\"pt\", return_overflowing_tokens=True)\n", 253 | " embedding_conditional_edit = clip(tokens_conditional_edit.input_ids.to(device)).last_hidden_state\n", 254 | " \n", 255 | " init_attention_edit(tokens_conditional, tokens_conditional_edit)\n", 256 | " \n", 257 | " init_attention_func()\n", 258 | " init_attention_weights(prompt_edit_token_weights)\n", 259 | " \n", 260 | " timesteps = scheduler.timesteps[t_start:]\n", 261 | " \n", 262 | " for i, t in 
tqdm(enumerate(timesteps), total=len(timesteps)):\n", 263 | " t_index = t_start + i\n", 264 | "\n", 265 | " #sigma = scheduler.sigmas[t_index]\n", 266 | " latent_model_input = latent\n", 267 | " latent_model_input = scheduler.scale_model_input(latent_model_input, t)\n", 268 | " \n", 269 | " #Predict the unconditional noise residual\n", 270 | " noise_pred_uncond = unet(latent_model_input, t, encoder_hidden_states=embedding_unconditional).sample\n", 271 | " \n", 272 | " #Prepare the Cross-Attention layers\n", 273 | " if prompt_edit is not None:\n", 274 | " save_last_tokens_attention()\n", 275 | " save_last_self_attention()\n", 276 | " else:\n", 277 | " #Use weights on non-edited prompt when edit is None\n", 278 | " use_last_tokens_attention_weights()\n", 279 | " \n", 280 | " #Predict the conditional noise residual and save the cross-attention layer activations\n", 281 | " noise_pred_cond = unet(latent_model_input, t, encoder_hidden_states=embedding_conditional).sample\n", 282 | " \n", 283 | " #Edit the Cross-Attention layer activations\n", 284 | " if prompt_edit is not None:\n", 285 | " t_scale = t / scheduler.num_train_timesteps\n", 286 | " if t_scale >= prompt_edit_tokens_start and t_scale <= prompt_edit_tokens_end:\n", 287 | " use_last_tokens_attention()\n", 288 | " if t_scale >= prompt_edit_spatial_start and t_scale <= prompt_edit_spatial_end:\n", 289 | " use_last_self_attention()\n", 290 | " \n", 291 | " #Use weights on edited prompt\n", 292 | " use_last_tokens_attention_weights()\n", 293 | "\n", 294 | " #Predict the edited conditional noise residual using the cross-attention masks\n", 295 | " noise_pred_cond = unet(latent_model_input, t, encoder_hidden_states=embedding_conditional_edit).sample\n", 296 | " \n", 297 | " #Perform guidance\n", 298 | " noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)\n", 299 | " \n", 300 | " latent = scheduler.step(noise_pred, t_index, latent).prev_sample\n", 301 | "\n", 302 | " #scale and 
decode the image latents with vae\n", 303 | " latent = latent / 0.18215\n", 304 | " image = vae.decode(latent.to(vae.dtype)).sample\n", 305 | "\n", 306 | " image = (image / 2 + 0.5).clamp(0, 1)\n", 307 | " image = image.cpu().permute(0, 2, 3, 1).numpy()\n", 308 | " image = (image[0] * 255).round().astype(\"uint8\")\n", 309 | " return Image.fromarray(image)\n" 310 | ] 311 | }, 312 | { 313 | "cell_type": "code", 314 | "execution_count": null, 315 | "id": "ab397f07-f069-4e5d-96fe-6edcc779884f", 316 | "metadata": {}, 317 | "outputs": [], 318 | "source": [ 319 | "@torch.no_grad()\n", 320 | "def inversestablediffusion(init_image, prompt=\"\", guidance_scale=3.0, steps=50, refine_iterations=3, refine_strength=0.9, refine_skip=0.7):\n", 321 | " #Change size to multiple of 64 to prevent size mismatches inside model\n", 322 | " width, height = init_image.size\n", 323 | " width = width - width % 64\n", 324 | " height = height - height % 64\n", 325 | " \n", 326 | " image_width, image_height = init_image.size\n", 327 | " left = (image_width - width)/2\n", 328 | " top = (image_height - height)/2\n", 329 | " right = left + width\n", 330 | " bottom = top + height\n", 331 | " \n", 332 | " init_image = init_image.crop((left, top, right, bottom))\n", 333 | " init_image = np.array(init_image).astype(np.float32) / 255.0 * 2.0 - 1.0\n", 334 | " init_image = torch.from_numpy(init_image[np.newaxis, ...].transpose(0, 3, 1, 2))\n", 335 | "\n", 336 | " #If there is alpha channel, composite alpha for white, as the diffusion model does not support alpha channel\n", 337 | " if init_image.shape[1] > 3:\n", 338 | " init_image = init_image[:, :3] * init_image[:, 3:] + (1 - init_image[:, 3:])\n", 339 | "\n", 340 | " #Move image to GPU\n", 341 | " init_image = init_image.to(device)\n", 342 | "\n", 343 | " train_steps = 1000\n", 344 | " step_ratio = train_steps // steps\n", 345 | " timesteps = torch.from_numpy(np.linspace(0, train_steps - 1, steps + 1, dtype=float)).int().to(device)\n", 346 | " \n", 
347 | " betas = torch.linspace(0.00085**0.5, 0.012**0.5, train_steps, dtype=torch.float32) ** 2\n", 348 | " alphas = torch.cumprod(1 - betas, dim=0)\n", 349 | " \n", 350 | " init_step = 0\n", 351 | " \n", 352 | " #Fixed seed such that the vae sampling is deterministic, shouldn't need to be changed by the user...\n", 353 | " generator = torch.cuda.manual_seed(798122)\n", 354 | " \n", 355 | " #Process clip\n", 356 | " with autocast(device):\n", 357 | " init_latent = vae.encode(init_image).latent_dist.sample(generator=generator) * 0.18215\n", 358 | " \n", 359 | " tokens_unconditional = clip_tokenizer(\"\", padding=\"max_length\", max_length=clip_tokenizer.model_max_length, truncation=True, return_tensors=\"pt\", return_overflowing_tokens=True)\n", 360 | " embedding_unconditional = clip(tokens_unconditional.input_ids.to(device)).last_hidden_state\n", 361 | "\n", 362 | " tokens_conditional = clip_tokenizer(prompt, padding=\"max_length\", max_length=clip_tokenizer.model_max_length, truncation=True, return_tensors=\"pt\", return_overflowing_tokens=True)\n", 363 | " embedding_conditional = clip(tokens_conditional.input_ids.to(device)).last_hidden_state\n", 364 | " \n", 365 | " latent = init_latent\n", 366 | "\n", 367 | " for i in tqdm(range(steps), total=steps):\n", 368 | " t_index = i + init_step\n", 369 | " \n", 370 | " t = timesteps[t_index]\n", 371 | " t1 = timesteps[t_index + 1]\n", 372 | " #Magic number for tless taken from Narnia, used for backwards CFG correction\n", 373 | " tless = t - (t1 - t) * 0.25\n", 374 | " \n", 375 | " ap = alphas[t] ** 0.5\n", 376 | " bp = (1 - alphas[t]) ** 0.5\n", 377 | " ap1 = alphas[t1] ** 0.5\n", 378 | " bp1 = (1 - alphas[t1]) ** 0.5\n", 379 | " \n", 380 | " latent_model_input = latent\n", 381 | " #Predict the unconditional noise residual\n", 382 | " noise_pred_uncond = unet(latent_model_input, t, encoder_hidden_states=embedding_unconditional).sample\n", 383 | " \n", 384 | " #Predict the conditional noise residual and save the 
cross-attention layer activations\n", 385 | " noise_pred_cond = unet(latent_model_input, t, encoder_hidden_states=embedding_conditional).sample\n", 386 | " \n", 387 | " #Perform guidance\n", 388 | " noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)\n", 389 | " \n", 390 | " #One reverse DDIM step\n", 391 | " px0 = (latent_model_input - bp * noise_pred) / ap\n", 392 | " latent = ap1 * px0 + bp1 * noise_pred\n", 393 | " \n", 394 | " #Initialize loop variables\n", 395 | " latent_refine = latent\n", 396 | " latent_orig = latent_model_input\n", 397 | " min_error = 1e10\n", 398 | " lr = refine_strength\n", 399 | " \n", 400 | " #Finite difference gradient descent method to correct for classifier free guidance, performs best when CFG is high\n", 401 | " #Very slow and unoptimized, might be able to use Newton's method or some other multidimensional root finding method\n", 402 | " if i > (steps * refine_skip):\n", 403 | " for k in range(refine_iterations):\n", 404 | " #Compute reverse diffusion process to get better prediction for noise at t+1\n", 405 | " #tless and t are used instead of the \"numerically correct\" t+1, produces way better results in practice, reason unknown...\n", 406 | " noise_pred_uncond = unet(latent_refine, tless, encoder_hidden_states=embedding_unconditional).sample\n", 407 | " noise_pred_cond = unet(latent_refine, t, encoder_hidden_states=embedding_conditional).sample\n", 408 | " noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)\n", 409 | " \n", 410 | " #One forward DDIM Step\n", 411 | " px0 = (latent_refine - bp1 * noise_pred) / ap1\n", 412 | " latent_refine_orig = ap * px0 + bp * noise_pred\n", 413 | " \n", 414 | " #Save latent if error is smaller\n", 415 | " error = float((latent_orig - latent_refine_orig).abs_().sum())\n", 416 | " if error < min_error:\n", 417 | " latent = latent_refine\n", 418 | " min_error = error\n", 419 | "\n", 420 | " #print(k, error)\n", 421 | " \n", 
422 | " #Break to avoid \"overfitting\", too low error does not produce good results in practice, why?\n", 423 | " if min_error < 5:\n", 424 | " break\n", 425 | " \n", 426 | " #\"Learning rate\" decay if error decrease is too small or negative (dampens oscillations)\n", 427 | " if (min_error - error) < 1:\n", 428 | " lr *= 0.9\n", 429 | " \n", 430 | " #Finite difference gradient descent\n", 431 | " latent_refine = latent_refine + (latent_model_input - latent_refine_orig) * lr\n", 432 | " \n", 433 | " \n", 434 | " return latent" 435 | ] 436 | }, 437 | { 438 | "cell_type": "code", 439 | "execution_count": null, 440 | "id": "f2e9bb8d-2050-4283-89b1-7ffe18a8475f", 441 | "metadata": {}, 442 | "outputs": [], 443 | "source": [ 444 | "def prompt_token(prompt, index):\n", 445 | " tokens = clip_tokenizer(prompt, padding=\"max_length\", max_length=clip_tokenizer.model_max_length, truncation=True, return_tensors=\"pt\", return_overflowing_tokens=True).input_ids[0]\n", 446 | " return clip_tokenizer.decode(tokens[index:index+1])" 447 | ] 448 | }, 449 | { 450 | "cell_type": "code", 451 | "execution_count": null, 452 | "id": "38221201-23c7-4760-bfb2-6aae227316b3", 453 | "metadata": {}, 454 | "outputs": [], 455 | "source": [ 456 | "prompt_token(\"a photo of a woman with blonde hair\", 2)" 457 | ] 458 | }, 459 | { 460 | "cell_type": "code", 461 | "execution_count": null, 462 | "id": "05b63c1f-9bb6-4575-87ab-05663b8846f1", 463 | "metadata": {}, 464 | "outputs": [], 465 | "source": [ 466 | "#https://pixabay.com/photos/portrait-woman-model-face-6054910/\n", 467 | "input_image = Image.open(\"portrait.png\")\n", 468 | "input_image" 469 | ] 470 | }, 471 | { 472 | "cell_type": "code", 473 | "execution_count": null, 474 | "id": "66463f61-6dde-41d1-a77e-36280964c9be", 475 | "metadata": { 476 | "scrolled": true, 477 | "tags": [] 478 | }, 479 | "outputs": [], 480 | "source": [ 481 | "gen_latents = inversestablediffusion(input_image, \"a photo of a woman with blonde hair\", refine_iterations=5, 
guidance_scale=4.0)" 482 | ] 483 | }, 484 | { 485 | "cell_type": "code", 486 | "execution_count": null, 487 | "id": "2d29ee30-6f45-4bfb-9626-473460869b19", 488 | "metadata": {}, 489 | "outputs": [], 490 | "source": [ 491 | "stablediffusion(\"a photo of a woman with blonde hair\", guidance_scale=4.0, init_latents=gen_latents)" 492 | ] 493 | }, 494 | { 495 | "cell_type": "code", 496 | "execution_count": null, 497 | "id": "b74f85fb-c13c-4467-80e5-abe5a832938b", 498 | "metadata": {}, 499 | "outputs": [], 500 | "source": [ 501 | "stablediffusion(\"a photo of a woman with blonde hair\", \"a photo of a young man with short blonde hair\", prompt_edit_spatial_end=0.9, guidance_scale=4.0, init_latents=gen_latents)" 502 | ] 503 | }, 504 | { 505 | "cell_type": "code", 506 | "execution_count": null, 507 | "id": "c811556f-9707-44d1-b136-d190bdfdecba", 508 | "metadata": {}, 509 | "outputs": [], 510 | "source": [ 511 | "stablediffusion(\"a photo of a woman with blonde hair\", \"a photo of a young man with long blonde hair and glasses\", guidance_scale=4.0, init_latents=gen_latents)" 512 | ] 513 | }, 514 | { 515 | "cell_type": "code", 516 | "execution_count": null, 517 | "id": "ab4f5614-1327-40f7-ba96-9cef764f39e0", 518 | "metadata": {}, 519 | "outputs": [], 520 | "source": [ 521 | "stablediffusion(\"a photo of a woman with blonde hair\", \"a photo of a woman with blonde hair, tanned skin\", prompt_edit_spatial_end=0.9, guidance_scale=4.0, init_latents=gen_latents)" 522 | ] 523 | }, 524 | { 525 | "cell_type": "code", 526 | "execution_count": null, 527 | "id": "d48a32e2-d32b-4cd7-a277-2c133f981100", 528 | "metadata": {}, 529 | "outputs": [], 530 | "source": [ 531 | "stablediffusion(\"a photo of a woman with blonde hair\", \"a photo of a young woman with short blonde hair and earrings, smiling\", prompt_edit_token_weights=[(14, -2.5)], prompt_edit_spatial_end=0.9, guidance_scale=4.0, init_latents=gen_latents)" 532 | ] 533 | }, 534 | { 535 | "cell_type": "code", 536 | 
"execution_count": null, 537 | "id": "e9779806-212a-4b21-abd7-1057b029af23", 538 | "metadata": {}, 539 | "outputs": [], 540 | "source": [ 541 | "stablediffusion(\"a photo of a woman with blonde hair\", \"a photo of a young woman with brown hair, smiling\", prompt_edit_token_weights=[(11, -2.5)], guidance_scale=4.0, init_latents=gen_latents)" 542 | ] 543 | }, 544 | { 545 | "cell_type": "code", 546 | "execution_count": null, 547 | "id": "eee76816-9192-4c9e-b547-1fcdea42a719", 548 | "metadata": {}, 549 | "outputs": [], 550 | "source": [ 551 | "stablediffusion(\"a photo of a woman with blonde hair\", \"a photo of a young woman with wavy black hair, smiling, raised eyebrows\", prompt_edit_token_weights=[(12, -3), (14, -3)], guidance_scale=4.0, init_latents=gen_latents)" 552 | ] 553 | }, 554 | { 555 | "cell_type": "code", 556 | "execution_count": null, 557 | "id": "050427ce-d8a1-493d-bc9c-da4b4c882ed4", 558 | "metadata": {}, 559 | "outputs": [], 560 | "source": [ 561 | "stablediffusion(\"a photo of a woman with blonde hair\", \"a photo of a young woman with curly black hair, smiling, raised eyebrows\", prompt_edit_token_weights=[(12, -2.5), (14, -3)], init_latents=gen_latents, guidance_scale=4.0)" 562 | ] 563 | }, 564 | { 565 | "cell_type": "code", 566 | "execution_count": null, 567 | "id": "7918056e-a5b4-40ed-9b49-511b00caa5b8", 568 | "metadata": {}, 569 | "outputs": [], 570 | "source": [] 571 | } 572 | ], 573 | "metadata": { 574 | "kernelspec": { 575 | "display_name": "diffusers", 576 | "language": "python", 577 | "name": "diffusers" 578 | }, 579 | "language_info": { 580 | "codemirror_mode": { 581 | "name": "ipython", 582 | "version": 3 583 | }, 584 | "file_extension": ".py", 585 | "mimetype": "text/x-python", 586 | "name": "python", 587 | "nbconvert_exporter": "python", 588 | "pygments_lexer": "ipython3", 589 | "version": "3.10.4" 590 | } 591 | }, 592 | "nbformat": 4, 593 | "nbformat_minor": 5 594 | } 595 | 
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 bloc97 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Cross Attention Control with Stable Diffusion 2 | Unofficial implementation of "Prompt-to-Prompt Image Editing with Cross Attention Control" with Stable Diffusion, some modifications were made to the methods described in the paper in order to make them work with Stable Diffusion. 3 | 4 | Paper: https://arxiv.org/abs/2208.01626 5 | Official implementation: https://github.com/google/prompt-to-prompt 6 | 7 | ## What is Cross Attention Control? 8 | Large-scale language-image models (eg. 
Stable Diffusion) are usually hard to control by editing the prompt alone and can be very unpredictable and unintuitive for users. Most existing methods require the user to input a mask, which is cumbersome and might not yield good results if the mask has an inadequate shape. Cross Attention Control allows much finer control of the prompt by modifying the internal attention maps of the diffusion model during inference, without requiring the user to input a mask, and does so with minimal performance penalties (compared to CLIP guidance) and no additional training or fine-tuning of the diffusion model. 9 | 10 | ## Getting started 11 | This notebook uses the following libraries: `torch transformers diffusers numpy PIL tqdm difflib` 12 | The last known working version of `diffusers` for the notebook is `diffusers==0.4.1`. A different version of diffusers might cause errors, as this notebook injects code into the model and any code change in the `diffusers` library is likely to break compatibility. 13 | Simply install the required libraries using `pip` and run the Jupyter notebook; some examples are given inside. 14 | A description of the parameters is given at the end of the readme. 15 | 16 | Alternatively, there is this easy-to-follow colab demo by [Lewington-pitsos](https://github.com/Lewington-pitsos): [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1PsWKXtqAAoDz-KGB45VeCXdTsqW-Mumo) 17 | 18 | # Results/Demonstrations 19 | **All images shown below are generated using the same seed. The initial and target images must be generated with the same seed for cross attention control to work.** 20 | 21 | ## New: Image inversion 22 | This method takes an existing image and finds its corresponding Gaussian latent vector using a modified inverse DDIM process that keeps compatibility with other ODE schedulers such as K-LMS, then edits it using prompt-to-prompt editing with cross attention control.
A finite difference gradient descent method that corrects for high CFG values is also provided. It allows inversion with higher CFG values (e.g. 3.0-5.0), while without it only CFG values below 3.0 are usable. 23 | 24 | Middle: Original image 25 | Top left: Reconstructed image using the prompt `a photo of a woman with blonde hair` 26 | Clockwise: See [InverseCrossAttention_Release.ipynb](https://github.com/bloc97/CrossAttentionControl/blob/main/InverseCrossAttention_Release.ipynb) for the prompts in order. 27 | Note that some fine-tuning of the prompts has been done to make these images consistent. For example, when changing the hair color, sometimes the person starts smiling, which can be removed by adding a `smile` token to the prompt and adjusting its weight downwards using cross attention control. 28 | ![Demo](https://github.com/bloc97/CrossAttentionControl/blob/main/images/faces_test.png?raw=true) 29 | 30 | 31 | ## Target replacement 32 | Top left prompt: `[a cat] sitting on a car` 33 | Clockwise: `a smiling dog...`, `a hamster...`, `a tiger...` 34 | Note: different strength values for `prompt_edit_spatial_start` were used, clockwise: `0.7`, `0.5`, `1.0` 35 | ![Demo](https://github.com/bloc97/CrossAttentionControl/blob/main/images/fouranimals.png?raw=true) 36 | 37 | ## Style injection 38 | Top left prompt: `a fantasy landscape with a maple forest` 39 | Clockwise: `a watercolor painting of...`, `a van gogh painting of...`, `a charcoal pencil sketch of...` 40 | ![Demo](https://github.com/bloc97/CrossAttentionControl/blob/main/images/fourstyles.png?raw=true) 41 | 42 | ## Global editing 43 | Top left prompt: `a fantasy landscape with a pine forest` 44 | Clockwise: `..., autumn`, `..., winter`, `..., spring, green` 45 | ![Demo](https://github.com/bloc97/CrossAttentionControl/blob/main/images/fourseasons.png?raw=true) 46 | 47 | ## Reducing unpredictability when modifying prompts 48 | 49 | Left image prompt: `a fantasy landscape with a pine forest` 50 | Right image
prompt: `a winter fantasy landscape with a pine forest` 51 | Middle image: Cross attention enabled prompt editing (left image -> right image) 52 | ![Demo](https://github.com/bloc97/CrossAttentionControl/blob/main/images/a%20fantasy%20landscape%20with%20a%20pine%20forest%20-%20a%20winter%20fantasy%20landscape%20with%20a%20pine%20forest.png?raw=true) 53 | 54 | Left image prompt: `a fantasy landscape with a pine forest` 55 | Right image prompt: `a watercolor painting of a landscape with a pine forest` 56 | Middle image: Cross attention enabled prompt editing (left image -> right image) 57 | ![Demo](https://github.com/bloc97/CrossAttentionControl/blob/main/images/a%20fantasy%20landscape%20with%20a%20pine%20forest%20-%20a%20watercolor%20painting%20of%20a%20landscape%20with%20a%20pine%20forest.png?raw=true) 58 | 59 | Left image prompt: `a fantasy landscape with a pine forest` 60 | Right image prompt: `a fantasy landscape with a pine forest and a river` 61 | Middle image: Cross attention enabled prompt editing (left image -> right image) 62 | ![Demo](https://github.com/bloc97/CrossAttentionControl/blob/main/images/a%20fantasy%20landscape%20with%20a%20pine%20forest%20-%20A%20fantasy%20landscape%20with%20a%20pine%20forest%20and%20a%20river.png?raw=true) 63 | 64 | ## Direct token attention control 65 | Left image prompt: `a fantasy landscape with a pine forest` 66 | Towards the right: `-fantasy` 67 | ![Demo](https://github.com/bloc97/CrossAttentionControl/blob/main/images/a%20fantasy%20landscape%20with%20a%20pine%20forest%20-%20decrease%20fantasy.png?raw=true) 68 | 69 | Left image prompt: `a fantasy landscape with a pine forest` 70 | Towards the right: `+fantasy` and `+forest` 71 | ![Demo](https://github.com/bloc97/CrossAttentionControl/blob/main/images/a%20fantasy%20landscape%20with%20a%20pine%20forest%20-%20increase%20fantasy%20and%20forest.png?raw=true) 72 | 73 | Left image prompt: `a fantasy landscape with a pine forest` 74 | Towards the right: `-fog` 75 | 
![Demo](https://github.com/bloc97/CrossAttentionControl/blob/main/images/a%20fantasy%20landscape%20with%20a%20pine%20forest%20-%20decrease%20fog.png?raw=true) 76 | 77 | Left image: from previous example 78 | Towards the right: `-rocks` 79 | ![Demo](https://github.com/bloc97/CrossAttentionControl/blob/main/images/a%20fantasy%20landscape%20with%20a%20pine%20forest%20-%20decrease%20rocks.png?raw=true) 80 | 81 | ## Comparison to standard prompt editing 82 | Let's compare our results above, where we removed fog and rocks from our fantasy landscape using cross attention maps, against what people usually do: editing the prompt alone. 83 | We can first try adding "without fog and without rocks" to our prompt. 84 | 85 | Image prompt: `A fantasy landscape with a pine forest without fog and without rocks` 86 | However, we still see fog and rocks. 87 | ![Demo](https://github.com/bloc97/CrossAttentionControl/blob/main/images/A%20fantasy%20landscape%20with%20a%20pine%20forest%20without%20fog%20and%20without%20rocks.png?raw=true) 88 | 89 | We can try adding words like dry, sunny and grass. 90 | Image prompt: `A fantasy landscape with a pine forest without fog and rocks, dry sunny day, grass` 91 | There are fewer rocks and less fog, but the image's composition and style are completely different from before, and we still haven't obtained our desired fog- and rock-free image... 92 | ![Demo](https://github.com/bloc97/CrossAttentionControl/blob/main/images/A%20fantasy%20landscape%20with%20a%20pine%20forest%20without%20fog%20and%20rocks%2C%20dry%20sunny%20day%2C%20grass.png?raw=true) 93 | 94 | 95 | ## Usage 96 | Two functions are included: `stablediffusion(...)`, which generates images, and `prompt_token(...)`, which helps the user find the token index of a word in the prompt; that index is used to tweak token weights in `prompt_edit_token_weights`.
97 | 98 | Parameters of `stablediffusion(...)`: 99 | | Name = Default Value | Description | Example | 100 | |---|---|---| 101 | | `prompt=""` | the prompt as a string | `"a cat riding a bicycle"` | 102 | | `prompt_edit=None` | the second prompt as a string, used to edit the first prompt using cross attention, set `None` to disable | `"a dog riding a bicycle"` | 103 | | `prompt_edit_token_weights=[]` | values to scale the importance of the tokens in cross attention layers, as a list of tuples representing `(token index, strength)` (use `prompt_token(...)` to find a token's index); this is used to increase or decrease the importance of a word in the prompt. It is applied to `prompt_edit` when possible (if `prompt_edit` is `None`, weights are applied to `prompt`) | `[(2, 2.5), (6, -5.0)]` | 104 | | `prompt_edit_tokens_start=0.0` | how strict is the generation with respect to the initial prompt, increasing this will let the network be more creative for smaller details/textures, should be smaller than `prompt_edit_tokens_end` | `0.0` | 105 | | `prompt_edit_tokens_end=1.0` | how strict is the generation with respect to the initial prompt, decreasing this will let the network be more creative for larger features/general scene composition, should be bigger than `prompt_edit_tokens_start` | `1.0` | 106 | | `prompt_edit_spatial_start=0.0` | how strict is the generation with respect to the initial image *(generated from the first prompt, not from img2img)*, increasing this will let the network be more creative for smaller details/textures, should be smaller than `prompt_edit_spatial_end` | `0.0` | 107 | | `prompt_edit_spatial_end=1.0` | how strict is the generation with respect to the initial image *(generated from the first prompt, not from img2img)*, decreasing this will let the network be more creative for larger features/general scene composition, should be bigger than `prompt_edit_spatial_start` | `1.0` | 108 | | `guidance_scale=7.5` | standard classifier-free guidance strength for Stable Diffusion | `7.5` | 109 |
`steps=50` | number of diffusion steps as an integer, higher usually produces better images but is slower | `50` | 110 | | `seed=None` | random seed as an integer, set `None` to use a random seed | `126794873` | 111 | | `width=512` | image width | `512` | 112 | | `height=512` | image height | `512` | 113 | | `init_image=None` | init image for image to image generation, as a PIL image, it will be resized to `width x height` | `PIL.Image()` | 114 | | `init_image_strength=0.5` | strength of the noise added for image to image generation, higher will make the generation care less about the initial image | `0.5` | 115 | 116 | Parameters of `inversestablediffusion(...)`: 117 | | Name = Default Value | Description | Example | 118 | |---|---|---| 119 | | `init_image` | the image to invert, as a PIL image | `PIL.Image.open("portrait.png")` | 120 | | `prompt=""` | the prompt as a string used for inversion | `"portrait of a person"` | 121 | | `guidance_scale=3.0` | standard classifier-free guidance strength for Stable Diffusion | `3.0` | 122 | | `steps=50` | number of diffusion steps used for inversion, as an integer | `50` | 123 | | `refine_iterations=3` | inversion refinement iterations for high CFG values; set to 0 to disable refinement when using lower CFG values, and consider increasing it for higher CFG values. Higher values slow down the algorithm significantly.
| `3` | 124 | | `refine_strength=0.9` | initial strength value for the refinement steps, the internal strength is adaptive | `0.9` | 125 | | `refine_skip=0.7` | how many diffusion steps of refinement are skipped (value between `0.0` and `1.0`), there is usually no need to refine earlier diffusion steps as CFG is not very important in lower time steps, higher values will skip even more steps | `0.7` | 126 | -------------------------------------------------------------------------------- /images/A fantasy landscape with a pine forest without fog and rocks, dry sunny day, grass.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bloc97/CrossAttentionControl/bcb095b11e7270051e4a329ac1dcbd6ed75f129b/images/A fantasy landscape with a pine forest without fog and rocks, dry sunny day, grass.png -------------------------------------------------------------------------------- /images/A fantasy landscape with a pine forest without fog and without rocks.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bloc97/CrossAttentionControl/bcb095b11e7270051e4a329ac1dcbd6ed75f129b/images/A fantasy landscape with a pine forest without fog and without rocks.png -------------------------------------------------------------------------------- /images/a fantasy landscape with a pine forest - A fantasy landscape with a pine forest and a river.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bloc97/CrossAttentionControl/bcb095b11e7270051e4a329ac1dcbd6ed75f129b/images/a fantasy landscape with a pine forest - A fantasy landscape with a pine forest and a river.png -------------------------------------------------------------------------------- /images/a fantasy landscape with a pine forest - a watercolor painting of a landscape with a pine forest.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/bloc97/CrossAttentionControl/bcb095b11e7270051e4a329ac1dcbd6ed75f129b/images/a fantasy landscape with a pine forest - a watercolor painting of a landscape with a pine forest.png -------------------------------------------------------------------------------- /images/a fantasy landscape with a pine forest - a winter fantasy landscape with a pine forest.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bloc97/CrossAttentionControl/bcb095b11e7270051e4a329ac1dcbd6ed75f129b/images/a fantasy landscape with a pine forest - a winter fantasy landscape with a pine forest.png -------------------------------------------------------------------------------- /images/a fantasy landscape with a pine forest - decrease clouds.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bloc97/CrossAttentionControl/bcb095b11e7270051e4a329ac1dcbd6ed75f129b/images/a fantasy landscape with a pine forest - decrease clouds.png -------------------------------------------------------------------------------- /images/a fantasy landscape with a pine forest - decrease fantasy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bloc97/CrossAttentionControl/bcb095b11e7270051e4a329ac1dcbd6ed75f129b/images/a fantasy landscape with a pine forest - decrease fantasy.png -------------------------------------------------------------------------------- /images/a fantasy landscape with a pine forest - decrease fog.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bloc97/CrossAttentionControl/bcb095b11e7270051e4a329ac1dcbd6ed75f129b/images/a fantasy landscape with a pine forest - decrease fog.png 
-------------------------------------------------------------------------------- /images/a fantasy landscape with a pine forest - decrease rocks.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bloc97/CrossAttentionControl/bcb095b11e7270051e4a329ac1dcbd6ed75f129b/images/a fantasy landscape with a pine forest - decrease rocks.png -------------------------------------------------------------------------------- /images/a fantasy landscape with a pine forest - increase fantasy and forest.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bloc97/CrossAttentionControl/bcb095b11e7270051e4a329ac1dcbd6ed75f129b/images/a fantasy landscape with a pine forest - increase fantasy and forest.png -------------------------------------------------------------------------------- /images/faces_test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bloc97/CrossAttentionControl/bcb095b11e7270051e4a329ac1dcbd6ed75f129b/images/faces_test.png -------------------------------------------------------------------------------- /images/fouranimals.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bloc97/CrossAttentionControl/bcb095b11e7270051e4a329ac1dcbd6ed75f129b/images/fouranimals.png -------------------------------------------------------------------------------- /images/fourseasons.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bloc97/CrossAttentionControl/bcb095b11e7270051e4a329ac1dcbd6ed75f129b/images/fourseasons.png -------------------------------------------------------------------------------- /images/fourstyles.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/bloc97/CrossAttentionControl/bcb095b11e7270051e4a329ac1dcbd6ed75f129b/images/fourstyles.png -------------------------------------------------------------------------------- /portrait.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bloc97/CrossAttentionControl/bcb095b11e7270051e4a329ac1dcbd6ed75f129b/portrait.png --------------------------------------------------------------------------------