├── README.md ├── self_guidance.ipynb └── self_guidance_clean.ipynb /README.md: -------------------------------------------------------------------------------- 1 | # Diffusion Self-Guidance for Controllable Image Generation 2 | 3 | This repo is an unofficial implementation of ["Diffusion Self-Guidance for Controllable Image Generation"](https://arxiv.org/abs/2306.00986) (Epstein et al., 2023), built with Stable Diffusion. All code and experiments are contained in the two notebooks: `self_guidance_clean.ipynb` holds just the code, while `self_guidance.ipynb` also includes the experiments. 4 | -------------------------------------------------------------------------------- /self_guidance_clean.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "f9d062fb-2ec8-452c-8d90-d0286d942a23", 6 | "metadata": {}, 7 | "source": [ 8 | "# Diffusion self-guidance for controllable image generation\n", 9 | "\n", 10 | "This notebook is an unofficial implementation of the paper [Diffusion Self-Guidance for Controllable Image Generation](https://arxiv.org/abs/2306.00986). If you want to use it, my suggestion is to treat this implementation as a starting point rather than a finished tool: it works in some cases, but more research is needed to get reliable results for every kind of edit." 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "id": "df53dc0c-b3c5-4c4f-b20c-211af3c4dc37", 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "from __future__ import annotations\n", 21 | "import math, random, torch, matplotlib.pyplot as plt, numpy as np, matplotlib as mpl, shutil, os, gzip, pickle, re, copy\n", 22 | "from pathlib import Path\n", 23 | "from operator import itemgetter\n", 24 | "from itertools import zip_longest\n", 25 | "from functools import partial\n", 26 | "import fastcore.all as fc\n", 27 | "from glob import glob\n", 28 | "\n", 29 | "from torch import tensor, nn, optim\n", 30 | "import torch.nn.functional as F\n", 31 | "from tqdm.auto import tqdm\n", 32 | "import torchvision.transforms.functional as TF\n", 33 | "from torch.nn import init\n", 34 | "from diffusers import LMSDiscreteScheduler, UNet2DConditionModel, AutoencoderKL\n", 35 | "from transformers import AutoTokenizer, CLIPTextModel\n", 36 | "\n", 37 | "# from miniai.core import *\n", 38 | "\n", 39 | "from einops import rearrange\n", 40 | "from fastprogress import progress_bar\n", 41 | "from PIL import Image\n", 42 | "from torchvision.io import read_image,ImageReadMode\n", 43 | "\n", 44 | "torch.set_printoptions(precision=2, linewidth=140, sci_mode=False)\n", 45 | "torch.manual_seed(1)\n", 46 | "mpl.rcParams['image.cmap'] = 'gray_r'" 47 | ] 48 | }, 49 | { 50 | "cell_type": "markdown", 51 | "id": "ed5aafee-e571-4976-b02b-c9ce63631a99", 52 | "metadata": {}, 53 | "source": [ 54 | "#### Helper functions" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": null, 60 | "id": "7c1abc70-2d51-4cc4-b789-eef43218088e", 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "def add_dims_right(x,y):\n", 65 | " dim = y.ndim - x.ndim\n", 66 | " return x[(...,) + (None,)*dim]\n", 67 | "\n", 68 | "def add_dims_left(x, y):\n", 69 | " dim = y.ndim - x.ndim\n", 70 | " return x[(None,)*dim + (...,)]" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": null, 76 | "id": "0a4277c9-a37c-4012-a25c-89ce0cd3d611", 77 | "metadata": {}, 78 | "outputs": [], 79 | "source": [ 80 | "def get_embeddings(prompt, 
concat_unconditional=False, device='cpu'):\n", 81 | " text_input = tokeniser(prompt, padding=\"max_length\", max_length=tokeniser.model_max_length, truncation=True, return_tensors=\"pt\")\n", 82 | " max_length = text_input.input_ids.shape[-1]\n", 83 | " with torch.no_grad():\n", 84 | " embeds = text_encoder(text_input.input_ids)[0]\n", 85 | " if concat_unconditional:\n", 86 | " uncond_input = tokeniser([\"\"], padding=\"max_length\", max_length=max_length, return_tensors=\"pt\")\n", 87 | " uncond_embeddings = text_encoder(uncond_input.input_ids)[0]\n", 88 | " embeds = torch.cat([uncond_embeddings, embeds])\n", 89 | " return embeds.to(device)" 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": null, 95 | "id": "276ae626-ffc5-41d2-9397-2b02cc096d23", 96 | "metadata": {}, 97 | "outputs": [], 98 | "source": [ 99 | "def encode_img(input_img):\n", 100 | " if len(input_img.shape)<4: input_img = input_img.unsqueeze(0)\n", 101 | " with torch.no_grad(): latent = vae.encode(input_img*2 - 1)\n", 102 | " return 0.18215 * latent.latent_dist.sample()" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": null, 108 | "id": "505b96fd-e9ff-40b0-ae72-8287cab91a38", 109 | "metadata": {}, 110 | "outputs": [], 111 | "source": [ 112 | "def process(image): return (image.clip(-1,1) + 1) / 2" 113 | ] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "execution_count": null, 118 | "id": "5f210145-0f72-49cd-be68-6b73fb120286", 119 | "metadata": {}, 120 | "outputs": [], 121 | "source": [ 122 | "@fc.delegates(plt.Axes.imshow)\n", 123 | "def show_image(im, ax=None, figsize=None, title=None, noframe=True, **kwargs):\n", 124 | " \"Show a PIL or PyTorch image on `ax`.\"\n", 125 | " if fc.hasattrs(im, ('cpu','permute','detach')):\n", 126 | " im = im.detach().cpu()\n", 127 | " if len(im.shape)==3 and im.shape[0]<5: im=im.permute(1,2,0)\n", 128 | " elif not isinstance(im,np.ndarray): im=np.array(im)\n", 129 | " if im.shape[-1]==1: im=im[...,0]\n", 130 | " if ax is None: _,ax = plt.subplots(figsize=figsize)\n", 131 | " ax.imshow(im, **kwargs)\n", 132 | " if title is not None: ax.set_title(title)\n", 133 | " ax.set_xticks([]) \n", 134 | " ax.set_yticks([]) \n", 135 | " if noframe: ax.axis('off')\n", 136 | " return ax" 137 | ] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "execution_count": null, 142 | "id": "22e2211d-a0ae-4158-9523-8769f95a1e9f", 143 | "metadata": {}, 144 | "outputs": [], 145 | "source": [ 146 | "@fc.delegates(subplots)\n", 147 | "def get_grid(\n", 148 | " n:int, # Number of axes\n", 149 | " nrows:int=None, # Number of rows, defaulting to `int(math.sqrt(n))`\n", 150 | " ncols:int=None, # Number of columns, defaulting to `ceil(n/rows)`\n", 151 | " title:str=None, # If passed, title set to the figure\n", 152 | " weight:str='bold', # Title font weight\n", 153 | " size:int=14, # Title font size\n", 154 | " **kwargs,\n", 155 | "): # fig and axs\n", 156 | " \"Return a grid of `n` axes, `rows` by `cols`\"\n", 157 | " if nrows: ncols = ncols or int(np.floor(n/nrows))\n", 158 | " elif ncols: nrows = nrows or int(np.ceil(n/ncols))\n", 159 | " else:\n", 160 | " nrows = int(math.sqrt(n))\n", 161 | " ncols = int(np.floor(n/nrows))\n", 162 | " fig,axs = subplots(nrows, ncols, **kwargs)\n", 163 | " for i in range(n, nrows*ncols): axs.flat[i].set_axis_off()\n", 164 | " if title is not None: fig.suptitle(title, weight=weight, size=size)\n", 165 | " return fig,axs" 166 | ] 167 | }, 168 | { 169 | "cell_type": "code", 170 | "execution_count": null, 171 | "id": 
"d9125689-7f7d-4ed4-a8c7-12707833dd31", 172 | "metadata": {}, 173 | "outputs": [], 174 | "source": [ 175 | "@fc.delegates(plt.subplots, keep=True)\n", 176 | "def subplots(\n", 177 | " nrows:int=1, # Number of rows in returned axes grid\n", 178 | " ncols:int=1, # Number of columns in returned axes grid\n", 179 | " figsize:tuple=None, # Width, height in inches of the returned figure\n", 180 | " imsize:int=3, # Size (in inches) of images that will be displayed in the returned figure\n", 181 | " suptitle:str=None, # Title to be set to returned figure\n", 182 | " **kwargs\n", 183 | "): # fig and axs\n", 184 | " \"A figure and set of subplots to display images of `imsize` inches\"\n", 185 | " if figsize is None: figsize=(ncols*imsize, nrows*imsize)\n", 186 | " fig,ax = plt.subplots(nrows, ncols, figsize=figsize, **kwargs)\n", 187 | " if suptitle is not None: fig.suptitle(suptitle)\n", 188 | " if nrows*ncols==1: ax = np.array([ax])\n", 189 | " return fig,ax" 190 | ] 191 | }, 192 | { 193 | "cell_type": "code", 194 | "execution_count": null, 195 | "id": "a1daf504-ef04-494b-903f-bd1197069768", 196 | "metadata": {}, 197 | "outputs": [], 198 | "source": [ 199 | "@fc.delegates(subplots)\n", 200 | "def show_images(ims:list, # Images to show\n", 201 | " nrows:int=None, # Number of rows in grid\n", 202 | " ncols:int=None, # Number of columns in grid (auto-calculated if None)\n", 203 | " titles:list=None, # Optional list of titles for each image\n", 204 | " **kwargs):\n", 205 | " \"Show all images `ims` as subplots with `rows` using `titles`\"\n", 206 | " axs = get_grid(len(ims), nrows, ncols, **kwargs)[1].flat\n", 207 | " for im,t,ax in zip_longest(ims, titles or [], axs): show_image(im, ax=ax, title=t)" 208 | ] 209 | }, 210 | { 211 | "cell_type": "markdown", 212 | "id": "0d1b2c2e-f829-401e-9c3f-d806756245e9", 213 | "metadata": {}, 214 | "source": [ 215 | "#### Attention and activation collection/storage" 216 | ] 217 | }, 218 | { 219 | "cell_type": "code", 220 | "execution_count": null, 221 | "id": "33e74ceb-8ce6-4c4e-9b11-b40f1e00c67e", 222 | "metadata": {}, 223 | "outputs": [], 224 | "source": [ 225 | "from diffusers.models.attention_processor import AttnProcessor, Attention\n", 226 | "\n", 227 | "def get_features(hook, layer, inp, out):\n", 228 | " if not hasattr(hook, 'feats'): hook.feats = out\n", 229 | " hook.feats = out\n", 230 | "\n", 231 | "class Hook():\n", 232 | " def __init__(self, model, func): self.hook = model.register_forward_hook(partial(func, self))\n", 233 | " def remove(self): self.hook.remove()\n", 234 | " def __del__(self): self.remove()\n", 235 | "\n", 236 | "def get_attn_dict(processor, model):\n", 237 | " attn_procs = {}\n", 238 | " for name in model.attn_processors.keys():\n", 239 | " attn_procs[name] = processor(name=name)\n", 240 | " return attn_procs\n", 241 | "\n", 242 | "class AttnStorage:\n", 243 | " def __init__(self): self.storage = {}\n", 244 | " def __call__(self, attention_map, name, pred_type='orig'): \n", 245 | " if not name in self.storage: self.storage[name] = {}\n", 246 | " self.storage[name][pred_type] = attention_map\n", 247 | " def flush(self): self.storage = {}\n", 248 | "\n", 249 | "class CustomAttnProcessor(AttnProcessor):\n", 250 | " def __init__(self, attn_storage, name=None): \n", 251 | " fc.store_attr()\n", 252 | " self.store = False\n", 253 | " self.type = \"attn2\" if \"attn2\" in name else \"attn1\"\n", 254 | " def set_storage(self, store, pred_type): \n", 255 | " self.store = store\n", 256 | " self.pred_type = pred_type\n", 257 | " def __call__(self, 
attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):\n", 258 | " batch_size, sequence_length, _ = (\n", 259 | " hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape\n", 260 | " )\n", 261 | " attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)\n", 262 | " query = attn.to_q(hidden_states)\n", 263 | "\n", 264 | " if encoder_hidden_states is None:\n", 265 | " encoder_hidden_states = hidden_states\n", 266 | " elif attn.norm_cross:\n", 267 | " encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)\n", 268 | "\n", 269 | " key = attn.to_k(encoder_hidden_states)\n", 270 | " value = attn.to_v(encoder_hidden_states)\n", 271 | " \n", 272 | " query = attn.head_to_batch_dim(query)\n", 273 | " key = attn.head_to_batch_dim(key)\n", 274 | " value = attn.head_to_batch_dim(value)\n", 275 | " \n", 276 | " attention_probs = attn.get_attention_scores(query, key, attention_mask)\n", 277 | " attention_probs.requires_grad_(True)\n", 278 | " \n", 279 | " if self.store: self.attn_storage(attention_probs, self.name, pred_type=self.pred_type) ## stores the attention maps in attn_storage\n", 280 | " \n", 281 | " hidden_states = torch.bmm(attention_probs, value)\n", 282 | " hidden_states = attn.batch_to_head_dim(hidden_states)\n", 283 | "\n", 284 | " # linear proj\n", 285 | " hidden_states = attn.to_out[0](hidden_states)\n", 286 | " # dropout\n", 287 | " hidden_states = attn.to_out[1](hidden_states)\n", 288 | " \n", 289 | " return hidden_states\n", 290 | "\n", 291 | "def prepare_attention(model, attn_storage, pred_type='orig', set_store=True):\n", 292 | " for name, module in model.attn_processors.items(): module.set_storage(set_store, pred_type)" 293 | ] 294 | }, 295 | { 296 | "cell_type": "markdown", 297 | "id": "e9e419a0-93dd-4e57-b2a8-8896aa130d8a", 298 | "metadata": {}, 299 | "source": [ 300 | "#### Self guidance equations" 301 | ] 302 | }, 303 | { 304 | "cell_type": "code", 305 | "execution_count": null, 306 | "id": "ebe1ec81-646e-42c4-bf08-c0374c00badf", 307 | "metadata": {}, 308 | "outputs": [], 309 | "source": [ 310 | "def normalise(x): return (x - x.min()) / (x.max() - x.min())" 311 | ] 312 | }, 313 | { 314 | "cell_type": "code", 315 | "execution_count": null, 316 | "id": "80563788-f815-42e8-90a4-3926f6353d87", 317 | "metadata": {}, 318 | "outputs": [], 319 | "source": [ 320 | "def threshold_attention(attn, s=10):\n", 321 | " norm_attn = s * (normalise(attn) - 0.5)\n", 322 | " return normalise(norm_attn.sigmoid())" 323 | ] 324 | }, 325 | { 326 | "cell_type": "code", 327 | "execution_count": null, 328 | "id": "30527cfc-5578-46fe-9b41-325755dd21f0", 329 | "metadata": {}, 330 | "outputs": [], 331 | "source": [ 332 | "def get_shape(attn, s=20): return threshold_attention(attn, s)\n", 333 | "def get_size(attn): return 1/attn.shape[-2] * threshold_attention(attn).sum((1,2)).mean()\n", 334 | "def get_centroid(attn):\n", 335 | " if not len(attn.shape) == 3: attn = attn[:,:,None]\n", 336 | " h = w = int(tensor(attn.shape[-2]).sqrt().item())\n", 337 | " hs = torch.arange(h).view(-1, 1, 1).to(attn.device)\n", 338 | " ws = torch.arange(w).view(1, -1, 1).to(attn.device)\n", 339 | " attn = rearrange(attn.mean(0), '(h w) d -> h w d', h=h)\n", 340 | " weighted_w = torch.sum(ws * attn, dim=[0,1])\n", 341 | " weighted_h = torch.sum(hs * attn, dim=[0,1])\n", 342 | " return torch.stack([weighted_w, weighted_h]) / attn.sum((0,1))\n", 343 | "def get_appearance(attn, feats):\n", 344 | " if not len(attn.shape) == 
3: attn = attn[:,:,None]\n", 345 | " h = w = int(tensor(attn.shape[-2]).sqrt().item())\n", 346 | " shape = get_shape(attn).detach().mean(0).view(h,w,attn.shape[-1])\n", 347 | " feats = feats.mean((0,1))[:,:,None]\n", 348 | " return (shape*feats).sum() / shape.sum()" 349 | ] 350 | }, 351 | { 352 | "cell_type": "markdown", 353 | "id": "98c760b2-6a04-422e-909d-ab86d7061b61", 354 | "metadata": {}, 355 | "source": [ 356 | "#### G functions" 357 | ] 358 | }, 359 | { 360 | "cell_type": "markdown", 361 | "id": "f5105c4d-f515-4bfa-8898-1a225ca04e3c", 362 | "metadata": {}, 363 | "source": [ 364 | "Single image editing. These are the functions that are closest to the paper. In the experiments section below, I played around with variations on these equations in pursuit of better results." 365 | ] 366 | }, 367 | { 368 | "cell_type": "code", 369 | "execution_count": null, 370 | "id": "c73d6ef1-0c0c-45b2-afd8-755d83580019", 371 | "metadata": {}, 372 | "outputs": [], 373 | "source": [ 374 | "def fix_shapes(orig_attns, edit_attns, indices, tau=1):\n", 375 | " shapes = []\n", 376 | " for o in indices:\n", 377 | " deltas = []\n", 378 | " for i in range(len(edit_attns)):\n", 379 | " orig, edit = orig_attns[i][:,:,o], edit_attns[i][:,:,o]\n", 380 | " delta = tau*get_shape(orig) - get_shape(edit)\n", 381 | " deltas.append(delta.mean())\n", 382 | " shapes.append(torch.stack(deltas).mean())\n", 383 | " return torch.stack(shapes).mean()\n", 384 | "\n", 385 | "def fix_appearances(orig_attns, orig_feats, edit_attns, edit_feats, indices, attn_idx=-1):\n", 386 | " appearances = []\n", 387 | " for o in indices:\n", 388 | " orig = torch.stack([a[:,:,o] for a in orig_attns[-3:]]).mean(0)\n", 389 | " edit = torch.stack([a[:,:,o] for a in edit_attns[-3:]]).mean(0)\n", 390 | " appearances.append((get_appearance(orig, orig_feats) - get_appearance(edit, edit_feats)).pow(2).mean())\n", 391 | " return torch.stack(appearances).mean()\n", 392 | "\n", 393 | "def fix_sizes(orig_attns, edit_attns, indices, tau=1):\n", 394 | " sizes = []\n", 395 | " for i in range(len(edit_attns)):\n", 396 | " orig, edit = orig_attns[i][:,:,indices], edit_attns[i][:,:,indices]\n", 397 | " sizes.append(tau*get_size(orig) - get_size(edit))\n", 398 | " return torch.stack(sizes).mean()\n", 399 | "\n", 400 | "def position_deltas(orig_attns, edit_attns, indices, target_centroid=None):\n", 401 | " positions = []\n", 402 | " for i in range(len(edit_attns)):\n", 403 | " orig, edit = orig_attns[i][:,:,indices], edit_attns[i][:,:,indices]\n", 404 | " target = tensor(target_centroid) if target_centroid is not None else get_centroid(orig)\n", 405 | " positions.append(target.to(orig.device) - get_centroid(edit))\n", 406 | " return torch.stack(positions).mean()\n", 407 | "\n", 408 | "def fix_selfs(origs, edits):\n", 409 | " shapes = []\n", 410 | " for i in range(len(edits)):\n", 411 | " shapes.append((threshold_attention(origs[i]) - threshold_attention(edits[i])).mean())\n", 412 | " return torch.stack(shapes).mean()" 413 | ] 414 | }, 415 | { 416 | "cell_type": "code", 417 | "execution_count": null, 418 | "id": "64bfb404-8e28-4b67-948d-603a83779f0b", 419 | "metadata": {}, 420 | "outputs": [], 421 | "source": [ 422 | "def get_attns(attn_storage, attn_type='attn2'):\n", 423 | " origs = [v['orig'] for k,v in attn_storage.storage.items() if attn_type in k]\n", 424 | " edits = [v['edit'] for k,v in attn_storage.storage.items() if attn_type in k]\n", 425 | " return origs, edits\n", 426 | "\n", 427 | "def edit_layout(attn_storage, indices, appearance_weight=0.5, 
orig_feats=None, edit_feats=None, **kwargs):\n", 428 | " origs, edits = get_attns(attn_storage)\n", 429 | " return appearance_weight*fix_appearances(origs, orig_feats, edits, edit_feats, indices, **kwargs)\n", 430 | "\n", 431 | "def edit_appearance(attn_storage, indices, shape_weight=1, **kwargs):\n", 432 | " origs, edits = get_attns(attn_storage)\n", 433 | " return shape_weight*fix_shapes(origs, edits, indices)\n", 434 | "\n", 435 | "def resize_object(attn_storage, indices, relative_size=2, shape_weight=1, size_weight=1, appearance_weight=0.1, orig_feats=None, edit_feats=None, **kwargs):\n", 436 | " origs, edits = get_attns(attn_storage)\n", 437 | " if len(indices) > 1: \n", 438 | " obj_idx, other_idx = indices\n", 439 | " indices = torch.cat([obj_idx, other_idx])\n", 440 | " shape_term = shape_weight*fix_shapes(origs, edits, indices)\n", 441 | " appearance_term = appearance_weight*fix_appearances(origs, orig_feats, edits, edit_feats, indices)\n", 442 | " size_term = size_weight*fix_sizes(origs, edits, indices, tau=relative_size)\n", 443 | " return shape_term + appearance_term + size_term\n", 444 | "\n", 445 | "def move_object(attn_storage, indices, target_centroid=None, shape_weight=1, size_weight=1, appearance_weight=0.5, position_weight=1, orig_feats=None, edit_feats=None, **kwargs):\n", 446 | " origs, edits = get_attns(attn_storage)\n", 447 | " if len(indices) > 1: \n", 448 | " obj_idx, other_idx = indices\n", 449 | " indices = torch.cat([obj_idx, other_idx])\n", 450 | " shape_term = shape_weight*fix_shapes(origs, edits, indices)\n", 451 | " appearance_term = appearance_weight*fix_appearances(origs, orig_feats, edits, edit_feats, indices)\n", 452 | " size_term = size_weight*fix_sizes(origs, edits, obj_idx)\n", 453 | " position_term = position_weight*position_deltas(origs, edits, obj_idx, target_centroid=target_centroid)\n", 454 | " return shape_term + appearance_term + size_term + position_term" 455 | ] 456 | }, 457 | { 458 | "cell_type": "markdown", 459 | "id": "3111350f-7d41-4a4e-8000-735423d3dbcd", 460 | "metadata": {}, 461 | "source": [ 462 | "#### Inference loop" 463 | ] 464 | }, 465 | { 466 | "cell_type": "code", 467 | "execution_count": null, 468 | "id": "0b699d76-42f4-46c3-a464-0499c206b3d3", 469 | "metadata": {}, 470 | "outputs": [], 471 | "source": [ 472 | "def do_self_guidance(t, n, scheduler):\n", 473 | " if type(scheduler).__name__ == \"DDPMScheduler\":\n", 474 | " if t <= int((3*n)/16): return True\n", 475 | " elif t >= int(n - n/32): return False\n", 476 | " elif t % 2 == 0: return True\n", 477 | " else: return False\n", 478 | " elif type(scheduler).__name__ == \"LMSDiscreteScheduler\":\n", 479 | " # return True\n", 480 | " if t <= int(n/5): return True\n", 481 | " elif t >= n - 5: return False\n", 482 | " elif t % 2 == 0: return True\n", 483 | " else: return False" 484 | ] 485 | }, 486 | { 487 | "cell_type": "code", 488 | "execution_count": null, 489 | "id": "c5876e6f-b226-4a98-9646-567163fe36b3", 490 | "metadata": {}, 491 | "outputs": [], 492 | "source": [ 493 | "def all_word_indexes(prompt, tokeniser, object_to_edit=None, **kwargs):\n", 494 | " \"\"\"Extracts token indexes by treating all words in the prompt as separate objects.\"\"\"\n", 495 | " prompt_inputs = tokeniser(prompt, padding=\"max_length\", max_length=tokeniser.model_max_length, truncation=True, return_tensors=\"pt\").input_ids\n", 496 | " if object_to_edit is not None: \n", 497 | " obj_inputs = tokeniser(object_to_edit, add_special_tokens=False).input_ids\n", 498 | " obj_idx = 
torch.cat([torch.where(prompt_inputs == o)[1] for o in obj_inputs])\n", 499 | " a = set(torch.cat([torch.where(prompt_inputs != o)[1] for o in obj_inputs]).numpy())\n", 500 | " b = set(torch.where(prompt_inputs < 49405)[1].numpy())\n", 501 | " other_idx = tensor(list(a&b))\n", 502 | " return obj_idx, other_idx\n", 503 | " else: return torch.where(prompt_inputs < 49405)[1]\n", 504 | "\n", 505 | "def choose_object_indexes(prompt, tokeniser, objects:list=None, object_to_edit=None):\n", 506 | " \"\"\"Extracts token indexes only for user-defined objects.\"\"\"\n", 507 | " prompt_inputs = tokeniser(prompt, padding=\"max_length\", max_length=tokeniser.model_max_length, truncation=True, return_tensors=\"pt\").input_ids\n", 508 | " if object_to_edit is not None: \n", 509 | " obj_inputs = tokeniser(object_to_edit, add_special_tokens=False).input_ids\n", 510 | " obj_idx = torch.cat([torch.where(prompt_inputs == o)[1] for o in obj_inputs])\n", 511 | " if object_to_edit in objects: objects.remove(object_to_edit)\n", 512 | " other_idx = []\n", 513 | " for o in objects:\n", 514 | " inps = tokeniser(o, add_special_tokens=False).input_ids\n", 515 | " other_idx.append(torch.cat([torch.where(prompt_inputs == o)[1] for o in inps]))\n", 516 | " if object_to_edit is None: return torch.cat(other_idx)\n", 517 | " else: return obj_idx, torch.cat(other_idx)" 518 | ] 519 | }, 520 | { 521 | "cell_type": "code", 522 | "execution_count": null, 523 | "id": "3bdc3a7e-ea66-4672-82b9-5b7e0684de2f", 524 | "metadata": {}, 525 | "outputs": [], 526 | "source": [ 527 | "def sg_sample(\n", 528 | " prompt,\n", 529 | " model,\n", 530 | " scheduler,\n", 531 | " guidance_func,\n", 532 | " g_weight=10, \n", 533 | " feature_layer=None, \n", 534 | " idx_func=all_word_indexes,\n", 535 | " objects:list=None,\n", 536 | " obj_to_edit=None,\n", 537 | " use_same_seed=False,\n", 538 | " seed=None, steps=50, guidance_scale=5., device='cuda', height=512, width=512, return_original=True\n", 539 | "):\n", 540 | " if seed is None: seed = int(torch.rand((1,)) * 1000000)\n", 541 | " seed_2 = int(torch.rand((1,)) * 1000000) if not use_same_seed else seed\n", 542 | " \n", 543 | " # set up the custom attn processor and use to replace standard model processors\n", 544 | " storage = AttnStorage()\n", 545 | " processor = partial(CustomAttnProcessor, storage)\n", 546 | " attn_dict = get_attn_dict(processor, model)\n", 547 | " model.set_attn_processor(attn_dict)\n", 548 | " \n", 549 | " # set up the hook to collect activations from feature_layer\n", 550 | " g_name = guidance_func.func.__name__ if isinstance(guidance_func, partial) else guidance_func.__name__\n", 551 | " if g_name not in ['edit_appearance'] and feature_layer is None:\n", 552 | " feature_layer = model.up_blocks[-1].resnets[-2]\n", 553 | " if feature_layer is not None: hook = Hook(feature_layer, get_features)\n", 554 | " \n", 555 | " # get indexes of editable and non-editable objects from token sequence\n", 556 | " if idx_func.__name__ == 'choose_object_indexes' and objects is None:\n", 557 | " raise ValueError('Provide a list of object strings from the prompt.')\n", 558 | " if g_name not in ['edit_layout', 'edit_appearance', 'edit_layout_2'] and obj_to_edit is None:\n", 559 | " raise ValueError('Provide an object string for editing.')\n", 560 | " indices = idx_func(prompt, tokeniser, objects=objects, object_to_edit=obj_to_edit)\n", 561 | " \n", 562 | " # set up embeddings, latents and scheduler\n", 563 | " uncond_embeddings = get_embeddings(\"\", concat_unconditional=False, device=device)\n", 
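" # note: cond and uncond embeddings are kept separate here (unlike the concatenated batch in sample_original) because the two UNet passes below are run one at a time, with attention maps stored only for the conditional pass\n",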
564 | " cond_embeddings = get_embeddings(prompt, concat_unconditional=False, device=device)\n", 565 | " scheduler.set_timesteps(steps)\n", 566 | " scheduler_2 = copy.deepcopy(scheduler)\n", 567 | " shape = (1, model.config.in_channels, height // 8, width // 8)\n", 568 | " orig_latents = torch.randn(shape, generator=torch.manual_seed(seed)).to(device) * scheduler.init_noise_sigma\n", 569 | " edit_latents = torch.randn(shape, generator=torch.manual_seed(seed_2)).to(device) * scheduler.init_noise_sigma\n", 570 | " \n", 571 | " for i, t in enumerate(progress_bar(scheduler.timesteps, leave=False)):\n", 572 | " # calculate noise_pred on the original unedited solution path\n", 573 | " latent_model_input = scheduler.scale_model_input(orig_latents, t) ## note orig_latents\n", 574 | " with torch.no_grad(): \n", 575 | " # don't store attention for the uncond prediction\n", 576 | " prepare_attention(model, storage, set_store=False)\n", 577 | " uncond = model(latent_model_input, t, encoder_hidden_states=uncond_embeddings).sample\n", 578 | "\n", 579 | " # do store attention for the cond prediction\n", 580 | " prepare_attention(model, storage, pred_type='orig', set_store=True)\n", 581 | " cond = model(latent_model_input, t, encoder_hidden_states=cond_embeddings).sample\n", 582 | " orig_feats = hook.feats if feature_layer is not None else None\n", 583 | " \n", 584 | " # classifier-free guidance on original solution path\n", 585 | " orig_noise_pred = uncond + guidance_scale * (cond - uncond)\n", 586 | " orig_latents = scheduler.step(orig_noise_pred, t, orig_latents).prev_sample\n", 587 | " \n", 588 | " edit_latents.requires_grad_(True)\n", 589 | " edit_latents.retain_grad()\n", 590 | " \n", 591 | " # recalculate noise_pred for edited solution path and allow grads to flow this time\n", 592 | " latent_model_input = scheduler_2.scale_model_input(edit_latents, t) ## note edit_latents\n", 593 | " prepare_attention(model, storage, set_store=False)\n", 594 | " uncond = model(latent_model_input, t, encoder_hidden_states=uncond_embeddings).sample\n", 595 | "\n", 596 | " prepare_attention(model, storage, pred_type='edit', set_store=True)\n", 597 | " cond = model(latent_model_input, t, encoder_hidden_states=cond_embeddings).sample\n", 598 | " edit_feats = hook.feats if feature_layer is not None else None\n", 599 | " \n", 600 | " # perform guidance with flexible g function\n", 601 | " edit_noise_pred = uncond + guidance_scale * (cond - uncond)\n", 602 | " if do_self_guidance(i, len(scheduler.timesteps), scheduler):\n", 603 | " g = guidance_func(storage, indices, orig_feats=orig_feats, edit_feats=edit_feats)\n", 604 | " g.backward()\n", 605 | " sig_t = scheduler.sigmas[i]\n", 606 | " edit_noise_pred += g_weight*sig_t*edit_latents.grad\n", 607 | " edit_latents = scheduler_2.step(edit_noise_pred.detach(), t, edit_latents.detach()).prev_sample\n", 608 | " storage.flush()\n", 609 | " \n", 610 | " orig_latents = 1 / 0.18215 * orig_latents\n", 611 | " edit_latents = 1 / 0.18215 * edit_latents\n", 612 | "\n", 613 | " with torch.no_grad(): edit_img = vae.decode(edit_latents).sample\n", 614 | " if not return_original: return edit_img\n", 615 | " with torch.no_grad(): orig_img = vae.decode(orig_latents).sample\n", 616 | " return orig_img, edit_img" 617 | ] 618 | }, 619 | { 620 | "cell_type": "code", 621 | "execution_count": null, 622 | "id": "03b8e281-b9d0-4905-8f7e-8bdbdd05317d", 623 | "metadata": {}, 624 | "outputs": [], 625 | "source": [ 626 | "def sample_original(prompt, seed=None, height=512, width=512, steps=50, 
guidance_scale=5, device='cuda'):\n", 627 | " if seed is None: seed = int(torch.rand((1,)) * 1000000)\n", 628 | " embeddings = get_embeddings(prompt, concat_unconditional=True, device=device)\n", 629 | " scheduler.set_timesteps(steps)\n", 630 | " shape = (1, model.in_channels, height // 8, width // 8)\n", 631 | " latents = torch.randn(shape, generator=torch.manual_seed(seed)).to(device)\n", 632 | " latents = latents * scheduler.init_noise_sigma\n", 633 | " \n", 634 | " for i, t in enumerate(progress_bar(scheduler.timesteps, leave=False)):\n", 635 | " latent_model_input = torch.cat([latents] * 2).to(device)\n", 636 | " latent_model_input = scheduler.scale_model_input(latent_model_input, t)\n", 637 | " with torch.no_grad():\n", 638 | " noise_pred = model(latent_model_input, t, encoder_hidden_states=embeddings).sample\n", 639 | "\n", 640 | " noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n", 641 | " noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)\n", 642 | " latents = scheduler.step(noise_pred, t, latents).prev_sample\n", 643 | " \n", 644 | " latents = 1 / 0.18215 * latents\n", 645 | " with torch.no_grad(): image = vae.decode(latents).sample\n", 646 | " return image" 647 | ] 648 | }, 649 | { 650 | "cell_type": "code", 651 | "execution_count": null, 652 | "id": "c5048dbe-b277-4547-9a9e-ea65edc423f2", 653 | "metadata": {}, 654 | "outputs": [], 655 | "source": [ 656 | "model = UNet2DConditionModel.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder='unet').to('cuda')\n", 657 | "tokeniser = AutoTokenizer.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder='tokenizer')\n", 658 | "text_encoder = CLIPTextModel.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder='text_encoder')\n", 659 | "vae = AutoencoderKL.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder='vae').to('cuda')\n", 660 | "scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule='scaled_linear')" 661 | ] 662 | }, 663 | { 664 | "cell_type": "markdown", 665 | "id": "d0c3a0c6-c7ae-4ef7-9047-82e703b13a53", 666 | "metadata": {}, 667 | "source": [ 668 | "#### Sample new appearances" 669 | ] 670 | }, 671 | { 672 | "cell_type": "markdown", 673 | "id": "7b5a711d-290b-4cab-a849-d681d05038f2", 674 | "metadata": {}, 675 | "source": [ 676 | "#### Sample new layouts\n", 677 | "\n", 678 | "Additional experimental code" 679 | ] 680 | }, 681 | { 682 | "cell_type": "code", 683 | "execution_count": null, 684 | "id": "eaa0868f-490f-456f-9cb4-949595d40540", 685 | "metadata": {}, 686 | "outputs": [], 687 | "source": [ 688 | "def fix_appearances_2(orig_attns, orig_feats, edit_attns, edit_feats, indices, attn_idx=-1):\n", 689 | " appearances = []\n", 690 | " for o in indices: appearances.append((orig_feats - edit_feats).pow(2).mean())\n", 691 | " return torch.stack(appearances).mean()\n", 692 | "\n", 693 | "def edit_layout_2(attn_storage, indices, appearance_weight=0.5, orig_feats=None, edit_feats=None, **kwargs):\n", 694 | " origs, edits = get_attns(attn_storage)\n", 695 | " \n", 696 | " return appearance_weight*fix_appearances_2(origs, orig_feats, edits, edit_feats, indices, **kwargs)" 697 | ] 698 | }, 699 | { 700 | "cell_type": "markdown", 701 | "id": "0c6c569e-ae6a-459b-8938-5ab6543d4b2f", 702 | "metadata": {}, 703 | "source": [ 704 | "#### Move an object\n", 705 | "\n", 706 | "Additional experimental code." 
707 | ] 708 | }, 709 | { 710 | "cell_type": "code", 711 | "execution_count": null, 712 | "id": "90d0b5b8-3334-4eb2-867d-2d9975ef7b5d", 713 | "metadata": {}, 714 | "outputs": [], 715 | "source": [ 716 | "def roll_shape(x, direction='up', factor=0.5):\n", 717 | " h = w = int(math.sqrt(x.shape[-2]))\n", 718 | " mag = (0,0)\n", 719 | " if direction == 'up': mag = (int(-h*factor),0)\n", 720 | " elif direction == 'down': mag = (int(h*factor),0)\n", 721 | " elif direction == 'right': mag = (0,int(w*factor))\n", 722 | " elif direction == 'left': mag = (0,int(-w*factor))\n", 723 | " shape = (x.shape[0], h, h, x.shape[-1])\n", 724 | " x = x.view(shape)\n", 725 | " move = x.roll(mag, dims=(1,2))\n", 726 | " return move.view(x.shape[0], h*h, x.shape[-1])" 727 | ] 728 | }, 729 | { 730 | "cell_type": "code", 731 | "execution_count": null, 732 | "id": "8cc5e6d3-bf78-497d-8654-d25054c7ffba", 733 | "metadata": {}, 734 | "outputs": [], 735 | "source": [ 736 | "def shift_shape(x, direction='up'): \n", 737 | " h = w = int(math.sqrt(x.shape[-2]))\n", 738 | " shape = (x.shape[0], h, w, x.shape[-1])\n", 739 | " x = x.view(shape)\n", 740 | " shift = torch.zeros_like(x)\n", 741 | " \n", 742 | " if direction == 'up':\n", 743 | " shift[:, :h//4*3, :, :] = x[:, h//4:, :, :]\n", 744 | " shift[:, h//4*3:, :, :] = x[:, :h//4, :, :]\n", 745 | " elif direction == 'down':\n", 746 | " shift[:, h//4:, :, :] = x[:, :h//4*3, :, :]\n", 747 | " shift[:, :h//4, :, :] = x[:, h//4*3:, :, :]\n", 748 | " elif direction == 'right':\n", 749 | " shift[:, :, :w//4*3, :] = x[:, :, w//4:, :]\n", 750 | " shift[:, :, w//4*3:, :] = x[:, :, :w//4, :]\n", 751 | " elif direction == 'left':\n", 752 | " shift[:, :, w//4:, :] = x[:, :, :w//4*3, :]\n", 753 | " shift[:, :, :w//4, :] = x[:, :, w//4*3:, :]\n", 754 | " \n", 755 | " return shift.view(x.shape[0], h*h, x.shape[-1])\n", 756 | "\n", 757 | "# def shift_shape(x, direction='up'): \n", 758 | "# h = w = int(math.sqrt(x.shape[-2]))\n", 759 | "# shape = (x.shape[0], h, w, x.shape[-1])\n", 760 | "# x = x.view(shape)\n", 761 | "# shift = torch.zeros_like(x)\n", 762 | " \n", 763 | "# if direction == 'up':\n", 764 | "# shift[:, :h//4, :, :] = x[:, h//4:, :, :]\n", 765 | "# elif direction == 'down':\n", 766 | "# shift[:, h//4:, :, :] = x[:, :h//4, :, :]\n", 767 | "# elif direction == 'right':\n", 768 | "# shift[:, :, :w//4, :] = x[:, :, w//4:, :]\n", 769 | "# elif direction == 'left':\n", 770 | "# shift[:, :, w//4:, :] = x[:, :, :w//4, :]\n", 771 | " \n", 772 | "# return shift.view(x.shape[0], h*h, x.shape[-1])" 773 | ] 774 | }, 775 | { 776 | "cell_type": "code", 777 | "execution_count": null, 778 | "id": "f8ddbf4a-3f97-4327-bf7d-62257dd16c4a", 779 | "metadata": {}, 780 | "outputs": [], 781 | "source": [ 782 | "def fix_shapes_3(orig_attns, edit_attns, indices, tau=fc.noop):\n", 783 | " shapes = []\n", 784 | " for o in indices:\n", 785 | " deltas = []\n", 786 | " for i in range(len(edit_attns)):\n", 787 | " orig, edit = orig_attns[i][:,:,o], edit_attns[i][:,:,o]\n", 788 | " if len(orig.shape) < 3: orig, edit = orig[...,None], edit[...,None]\n", 789 | " delta = (tau(get_shape(orig)) - get_shape(edit)).pow(2).mean()\n", 790 | " deltas.append(delta.mean())\n", 791 | " shapes.append(torch.stack(deltas).mean())\n", 792 | " return torch.stack(shapes).mean()" 793 | ] 794 | }, 795 | { 796 | "cell_type": "code", 797 | "execution_count": null, 798 | "id": "f4928d31-50d7-44c7-9e4d-c8822f7baac7", 799 | "metadata": {}, 800 | "outputs": [], 801 | "source": [ 802 | "def fix_selfs_2(origs, edits, t=fc.noop):\n", 803 | 
" deltas = []\n", 804 | " for i in range(len(edits)):\n", 805 | " orig, edit = origs[i][...,None].mean(0), edits[i]\n", 806 | " delta = t(orig).squeeze() - edit\n", 807 | " deltas.append(delta.mean())\n", 808 | " return torch.stack(deltas).mean()" 809 | ] 810 | }, 811 | { 812 | "cell_type": "code", 813 | "execution_count": null, 814 | "id": "4a0dc805-97a7-4c26-b1d9-79665b05e018", 815 | "metadata": {}, 816 | "outputs": [], 817 | "source": [ 818 | "def move_object(attn_storage, indices, t=fc.noop, shape_weight=1, size_weight=1, self_weight=0.1, appearance_weight=0.5, position_weight=1, orig_feats=None, edit_feats=None, **kwargs):\n", 819 | " origs, edits = get_attns(attn_storage)\n", 820 | " # orig_selfs = [v['orig'] for k,v in attn_storage.storage.items() if 'attn1' in k and v['orig'].shape[-1] == 4096]\n", 821 | " # edit_selfs = [v['edit'] for k,v in attn_storage.storage.items() if 'attn1' in k and v['orig'].shape[-1] == 4096]\n", 822 | " if len(indices) > 1: \n", 823 | " obj_idx, other_idx = indices\n", 824 | " indices = torch.cat([obj_idx, other_idx])\n", 825 | " shape_term = shape_weight*fix_shapes(origs, edits, obj_idx)\n", 826 | " appearance_term = appearance_weight*fix_appearances_2(origs, orig_feats, edits, edit_feats, indices)\n", 827 | " # size_term = size_weight*fix_sizes(origs, edits, obj_idx)\n", 828 | " # position_term = position_weight*position_deltas_2(origs, edits, obj_idx, target_centroid=target_centroid)\n", 829 | " # self_term = self_weight*fix_selfs_2(orig_selfs, edit_selfs, t=t)\n", 830 | " move_term = position_weight*fix_shapes_3(origs, edits, other_idx, tau=t)\n", 831 | " return move_term + shape_term + appearance_term" 832 | ] 833 | }, 834 | { 835 | "cell_type": "markdown", 836 | "id": "e6be3bba-a811-4073-b116-3db5d3171600", 837 | "metadata": {}, 838 | "source": [ 839 | "#### Resize an object" 840 | ] 841 | }, 842 | { 843 | "cell_type": "markdown", 844 | "id": "9c5d3f88-eba7-4061-b571-1f226e32f402", 845 | "metadata": {}, 846 | "source": [ 847 | "Additional experimental code." 
848 | ] 849 | }, 850 | { 851 | "cell_type": "code", 852 | "execution_count": null, 853 | "id": "85052ad1-e5cd-4600-8cbf-bff93671fc87", 854 | "metadata": {}, 855 | "outputs": [], 856 | "source": [ 857 | "def enlarge(x, scale_factor=1):\n", 858 | " assert scale_factor >= 1\n", 859 | " h = w = int(math.sqrt(x.shape[-2]))\n", 860 | " x = rearrange(x, 'n (h w) d -> n d h w', h=h)\n", 861 | " x = F.interpolate(x, scale_factor=scale_factor)\n", 862 | " new_h = new_w = x.shape[-1]\n", 863 | " x_l, x_r = (new_w//2) - w//2, (new_w//2) + w//2\n", 864 | " x_t, x_b = (new_h//2) - h//2, (new_h//2) + h//2\n", 865 | " x = x[:,:,x_t:x_b,x_l:x_r]\n", 866 | " return rearrange(x, 'n d h w -> n (h w) d', h=h) * scale_factor" 867 | ] 868 | }, 869 | { 870 | "cell_type": "code", 871 | "execution_count": null, 872 | "id": "8ee85bd5-5028-40a8-814a-effed8e48f60", 873 | "metadata": {}, 874 | "outputs": [], 875 | "source": [ 876 | "def shrink(x, scale_factor=1):\n", 877 | " assert scale_factor <= 1\n", 878 | " h = w = int(math.sqrt(x.shape[-2]))\n", 879 | " x = rearrange(x, 'n (h w) d -> n d h w', h=h)\n", 880 | " sf = int(1/scale_factor)\n", 881 | " new_h, new_w = h*sf, w*sf\n", 882 | " x1 = torch.zeros(x.shape[0], x.shape[1], new_h, new_w).to(x.device)\n", 883 | " x_l, x_r = (new_w//2) - w//2, (new_w//2) + w//2\n", 884 | " x_t, x_b = (new_h//2) - h//2, (new_h//2) + h//2\n", 885 | " x1[:,:,x_t:x_b,x_l:x_r] = x\n", 886 | " shrink = F.interpolate(x1, scale_factor=scale_factor)\n", 887 | " return rearrange(shrink, 'n d h w -> n (h w) d', h=h) * scale_factor" 888 | ] 889 | }, 890 | { 891 | "cell_type": "code", 892 | "execution_count": null, 893 | "id": "b43561d8-0f45-4000-85d3-10fb27222d1c", 894 | "metadata": {}, 895 | "outputs": [], 896 | "source": [ 897 | "def resize(x, scale_factor=1):\n", 898 | " if scale_factor > 1: return enlarge(x, scale_factor)\n", 899 | " elif scale_factor < 1: return shrink(x, scale_factor)\n", 900 | " else: return x" 901 | ] 902 | }, 903 | { 904 | "cell_type": "code", 905 | "execution_count": null, 906 | "id": "8d53fe9f-9a82-4e4c-8d72-0978a4df4109", 907 | "metadata": {}, 908 | "outputs": [], 909 | "source": [ 910 | "def fix_shapes_2(orig_attns, edit_attns, indices, tau=1):\n", 911 | " shapes = []\n", 912 | " for o in indices:\n", 913 | " deltas = []\n", 914 | " for i in range(len(edit_attns)):\n", 915 | " orig, edit = orig_attns[i][:,:,o], edit_attns[i][:,:,o]\n", 916 | " t = orig + (orig.max() * tau)\n", 917 | " delta = (get_shape((orig + t).clip(min=0))) - get_shape(edit)\n", 918 | " deltas.append(delta.mean())\n", 919 | " shapes.append(torch.stack(deltas).mean())\n", 920 | " return torch.stack(shapes).mean()" 921 | ] 922 | }, 923 | { 924 | "cell_type": "code", 925 | "execution_count": null, 926 | "id": "14f5de8e-d4fa-4a72-9979-6301a11045d8", 927 | "metadata": {}, 928 | "outputs": [], 929 | "source": [ 930 | "# def fix_shapes_3(orig_attns, edit_attns, indices, tau=fc.noop):\n", 931 | "# shapes = []\n", 932 | "# for o in indices:\n", 933 | "# deltas = []\n", 934 | "# for i in range(len(edit_attns)):\n", 935 | "# orig, edit = orig_attns[i][:,:,o], edit_attns[i][:,:,o]\n", 936 | "# if len(orig.shape) < 3: orig, edit = orig[...,None], edit[...,None]\n", 937 | "# delta = (tau(get_shape(orig)) - get_shape(edit)).pow(2).mean()\n", 938 | "# deltas.append(delta.mean())\n", 939 | "# shapes.append(torch.stack(deltas).mean())\n", 940 | "# return torch.stack(shapes).mean()" 941 | ] 942 | }, 943 | { 944 | "cell_type": "code", 945 | "execution_count": null, 946 | "id": "64d910a0-69cf-49f9-b536-7a8d5ef73e41", 947 | "metadata": {}, 948 
| "outputs": [], 949 | "source": [ 950 | "def resize_object_2(attn_storage, indices, t=fc.noop, relative_size=2, shape_weight=1, size_weight=1, appearance_weight=0.1, orig_feats=None, edit_feats=None, self_weight=0.1, **kwargs):\n", 951 | " origs, edits = get_attns(attn_storage)\n", 952 | " # orig_selfs = [v['orig'] for k,v in attn_storage.storage.items() if 'attn1' in k][-1]\n", 953 | " # edit_selfs = [v['edit'] for k,v in attn_storage.storage.items() if 'attn1' in k][-1]\n", 954 | " if len(indices) > 1:\n", 955 | " obj_idx, other_idx = indices\n", 956 | " indices = torch.cat([obj_idx, other_idx])\n", 957 | " shape_term = shape_weight*fix_shapes(origs, edits, other_idx)\n", 958 | " appearance_term = appearance_weight*fix_appearances(origs, orig_feats, edits, edit_feats, indices)\n", 959 | " size_term = size_weight*fix_shapes_2(origs, edits, obj_idx, tau=t)\n", 960 | " # self_term = self_weight*fix_selfs(orig_selfs, edit_selfs)\n", 961 | " return shape_term + appearance_term + size_term" 962 | ] 963 | } 964 | ], 965 | "metadata": { 966 | "kernelspec": { 967 | "display_name": "main_env", 968 | "language": "python", 969 | "name": "main_env" 970 | } 971 | }, 972 | "nbformat": 4, 973 | "nbformat_minor": 5 974 | } 975 | --------------------------------------------------------------------------------