├── dogs ├── source.png └── target1.png ├── PerMIRS ├── utils.py ├── eval_miou.py ├── dataset.py ├── permirs_gen_dataset.py ├── visualization_utils.py ├── extract_diff_features.py └── video.py ├── attention_store.py ├── sam_utils.py ├── README.md ├── pdm_matching.py ├── pdm_permir.py ├── ptp_utils.py ├── pdm_permis.py ├── StableDiffusionPipelineWithDDIMInversion.py └── dift.py /dogs/source.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvirsamuel/PDM/HEAD/dogs/source.png -------------------------------------------------------------------------------- /dogs/target1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvirsamuel/PDM/HEAD/dogs/target1.png -------------------------------------------------------------------------------- /PerMIRS/utils.py: -------------------------------------------------------------------------------- 1 | # Code from https://github.com/Ali2500/BURST-benchmark 2 | from typing import Dict, List, Any, Tuple 3 | 4 | import numpy as np 5 | import pycocotools.mask as cocomask 6 | 7 | 8 | def intify_track_ids(video_dict: Dict[str, Any]): 9 | video_dict["track_category_ids"] = { 10 | int(track_id): category_id for track_id, category_id in video_dict["track_category_ids"].items() 11 | } 12 | 13 | for t in range(len(video_dict["segmentations"])): 14 | video_dict["segmentations"][t] = { 15 | int(track_id): seg 16 | for track_id, seg in video_dict["segmentations"][t].items() 17 | } 18 | 19 | return video_dict 20 | 21 | 22 | def rle_ann_to_mask(rle: str, image_size: Tuple[int, int]) -> np.ndarray: 23 | return cocomask.decode({ 24 | "size": image_size, 25 | "counts": rle.encode("utf-8") 26 | }).astype(bool) 27 | 28 | 29 | def mask_to_rle_ann(mask: np.ndarray) -> Dict[str, Any]: 30 | assert mask.ndim == 2, f"Mask must be a 2-D array, but got array of shape {mask.shape}" 31 | rle = cocomask.encode(np.asfortranarray(mask.astype(np.uint8))) 32 | rle["counts"] = rle["counts"].decode("utf-8") 33 | return rle -------------------------------------------------------------------------------- /PerMIRS/eval_miou.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import argparse 3 | 4 | 5 | 6 | def get_arguments(): 7 | parser = argparse.ArgumentParser() 8 | 9 | parser.add_argument('--pred_path', type=str, default='') 10 | parser.add_argument('--gt_path', type=str, default='/inputs/Projects/PerDet/datasets/PerSeg/Annotations') 11 | 12 | parser.add_argument('--ref_idx', type=str, default='00') 13 | 14 | args = parser.parse_args() 15 | return args 16 | 17 | 18 | class AverageMeter(object): 19 | """Computes and stores the average and current value""" 20 | 21 | def __init__(self): 22 | self.reset() 23 | 24 | def reset(self): 25 | self.val = 0 26 | self.avg = 0 27 | self.sum = 0 28 | self.count = 0 29 | 30 | def update(self, val, n=1): 31 | self.val = val 32 | self.sum += val * n 33 | self.count += n 34 | self.avg = self.sum / self.count 35 | 36 | 37 | def intersectionAndUnion(output, target): 38 | assert (output.ndim in [1, 2, 3]) 39 | assert output.shape == target.shape 40 | output = output.reshape(output.size).copy() 41 | target = target.reshape(target.size) 42 | 43 | area_intersection = np.logical_and(output, target).sum() 44 | area_union = np.logical_or(output, target).sum() 45 | area_target = target.sum() 46 | 47 | return area_intersection, area_union, area_target 
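# --- Illustrative usage sketch (not part of the original repository file) ---
# A minimal example, under the assumption that `pred_masks` and `gt_masks` are
# aligned lists of binary masks of identical shape, showing how AverageMeter and
# intersectionAndUnion can be combined to report mIoU, mirroring the evaluation
# loop in pdm_permis.py. The helper name `example_miou` is hypothetical.
def example_miou(pred_masks, gt_masks):
    intersection_meter = AverageMeter()
    union_meter = AverageMeter()
    for pred_mask, gt_mask in zip(pred_masks, gt_masks):
        # both masks are binary arrays of the same shape
        intersection, union, _ = intersectionAndUnion(np.uint8(pred_mask), np.uint8(gt_mask))
        intersection_meter.update(intersection)
        union_meter.update(union)
    # accumulate intersections and unions over all pairs, then take their ratio
    return intersection_meter.sum / (union_meter.sum + 1e-10)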
-------------------------------------------------------------------------------- /PerMIRS/dataset.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | from video import BURSTVideo 3 | 4 | import json 5 | import os.path as osp 6 | 7 | import utils 8 | 9 | 10 | class BURSTDataset: 11 | def __init__(self, annotations_file: str, images_base_dir: Optional[str] = None): 12 | with open(annotations_file, 'r') as fh: 13 | content = json.load(fh) 14 | 15 | # convert track IDs from str to int wherever they are used as dict keys (JSON format always parses dict keys as 16 | # strings) 17 | self._videos = [utils.intify_track_ids(video) for video in content["sequences"]] 18 | self._videos = list(filter(lambda x: x["dataset"] != "HACS" and x["dataset"] != "AVA" and x["dataset"] != "LaSOT", self._videos)) 19 | 20 | self.category_names = { 21 | category["id"]: category["name"] for category in content["categories"] 22 | } 23 | 24 | self._split = content["split"] 25 | 26 | self.images_base_dir = images_base_dir 27 | 28 | @property 29 | def num_videos(self) -> int: 30 | return len(self._videos) 31 | 32 | def __getitem__(self, index) -> BURSTVideo: 33 | assert index < self.num_videos, f"Index {index} invalid since total number of videos is {self.num_videos}" 34 | 35 | video_dict = self._videos[index] 36 | if self.images_base_dir is None: 37 | video_images_dir = None 38 | else: 39 | video_images_dir = osp.join(self.images_base_dir, self._split, video_dict["dataset"], video_dict["seq_name"]) 40 | assert osp.exists(video_images_dir), f"Images directory for video not found at expected path: '{video_images_dir}'" 41 | 42 | return BURSTVideo(video_dict, video_images_dir) 43 | 44 | def __iter__(self): 45 | for i in range(self.num_videos): 46 | yield self[i] -------------------------------------------------------------------------------- /PerMIRS/permirs_gen_dataset.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | from argparse import ArgumentParser 3 | from tqdm import tqdm 4 | from dataset import BURSTDataset 5 | import numpy as np 6 | import os 7 | import os.path as osp 8 | 9 | 10 | def main(args): 11 | dataset = BURSTDataset(annotations_file=args.annotations_file, 12 | images_base_dir=args.images_base_dir) 13 | 14 | for i in tqdm(range(dataset.num_videos)): 15 | try: 16 | video = dataset[i] 17 | 18 | if args.seq and f"{video.dataset}/{video.name}" != args.seq: 19 | continue 20 | 21 | print(f"- Dataset: {video.dataset}\n" 22 | f"- Name: {video.name}") 23 | 24 | if args.first_frame_annotations: 25 | annotations = video.load_first_frame_annotations() 26 | else: 27 | annotations = video.load_masks() 28 | 29 | frames_idx, annotations = video.filter_dataset_for_benchmark(annotations) 30 | image_paths = video.images_paths(frames_idx) 31 | if len(image_paths) < 3: 32 | raise Exception(i, "not enough images found") 33 | 34 | base_path = f"/PerMIRS/{i}" 35 | for f_idx, image_p in enumerate(image_paths): 36 | frame_p = base_path + f"/{f_idx}.{osp.split(image_p)[-1].split('.')[-1]}" 37 | os.makedirs(os.path.dirname(frame_p), exist_ok=True) 38 | shutil.copyfile(image_p, frame_p) 39 | 40 | np.save(base_path + "/masks.npz", annotations) 41 | except Exception as e: 42 | print(i, e) 43 | 44 | 45 | if __name__ == '__main__': 46 | parser = ArgumentParser() 47 | 48 | parser.add_argument("--images_base_dir", required=True) 49 | parser.add_argument("--annotations_file", required=True) 50 | 
parser.add_argument("--first_frame_annotations", action='store_true') 51 | 52 | # extra options 53 | parser.add_argument("--save_dir", required=False) 54 | parser.add_argument("--seq", required=False) 55 | 56 | main(parser.parse_args()) -------------------------------------------------------------------------------- /PerMIRS/visualization_utils.py: -------------------------------------------------------------------------------- 1 | # Codef from https://github.com/Ali2500/BURST-benchmark 2 | import cv2 3 | import numpy as np 4 | 5 | 6 | def create_color_map(): 7 | # This function has been copied with minor changes from the DAVIS dataset API at: 8 | # https://github.com/davisvideochallenge/davis2017-evaluation (Caelles, Pont-Tuset, et al.) 9 | N = 256 10 | 11 | def bitget(byteval, idx): 12 | return (byteval & (1 << idx)) != 0 13 | 14 | cmap = np.zeros((N, 3), dtype=np.uint8) 15 | for i in range(N): 16 | r = g = b = 0 17 | c = i 18 | for j in range(8): 19 | r = r | (bitget(c, 0) << 7-j) 20 | g = g | (bitget(c, 1) << 7-j) 21 | b = b | (bitget(c, 2) << 7-j) 22 | c = c >> 3 23 | 24 | cmap[i] = np.array([r, g, b]) 25 | 26 | return cmap.tolist() 27 | 28 | 29 | def overlay_mask_on_image(image, mask, opacity, color): 30 | assert mask.ndim == 2 31 | mask_bgr = np.stack((mask, mask, mask), axis=2) 32 | masked_image = np.where(mask_bgr > 0, color, image) 33 | return ((opacity * masked_image) + ((1. - opacity) * image)).astype(np.uint8) 34 | 35 | 36 | def bbox_from_mask(mask): 37 | reduced_y = np.any(mask, axis=0) 38 | reduced_x = np.any(mask, axis=1) 39 | 40 | x_min = reduced_y.argmax() 41 | if x_min == 0 and reduced_y[0] == 0: # mask is all zeros 42 | return None 43 | 44 | x_max = len(reduced_y) - np.flip(reduced_y, 0).argmax() 45 | 46 | y_min = reduced_x.argmax() 47 | y_max = len(reduced_x) - np.flip(reduced_x, 0).argmax() 48 | 49 | return x_min, y_min, x_max, y_max 50 | 51 | 52 | def annotate_image(image, mask, color, label, point=None, **kwargs): 53 | """ 54 | :param image: np.ndarray(H, W, 3) 55 | :param mask: np.ndarray(H, W) 56 | :param color: tuple/list(int, int, int) in range [0, 255] 57 | :param label: str 58 | :param kwargs: "bbox_thickness", "text_font", "font_size", "mask_opacity" 59 | :return: np.ndarray(H, W, 3) 60 | """ 61 | annotated_image = overlay_mask_on_image(image, mask, color=color, opacity=kwargs.get("mask_opacity", 0.5)) 62 | xmin, ymin, xmax, ymax = [int(x) for x in bbox_from_mask(mask)] 63 | 64 | bbox_thickness = kwargs.get("bbox_thickness", 2) 65 | text_font = kwargs.get("text_font", cv2.FONT_HERSHEY_SIMPLEX) 66 | font_size = kwargs.get("font_size", 0.5) 67 | 68 | annotated_image = cv2.rectangle(cv2.UMat(annotated_image), (xmin, ymin), (xmax, ymax), color=tuple(color), 69 | thickness=bbox_thickness) 70 | 71 | (text_width, text_height), _ = cv2.getTextSize(label, text_font, font_size, thickness=1) 72 | text_offset_x, text_offset_y = int(xmin + 2), int(ymin + text_height + 2) 73 | 74 | text_bg_box_pt1 = int(text_offset_x), int(text_offset_y + 2) 75 | text_bg_box_pt2 = int(text_offset_x + text_width + 2), int(text_offset_y - text_height - 2) 76 | 77 | annotated_image = cv2.rectangle(cv2.UMat(annotated_image), text_bg_box_pt1, text_bg_box_pt2, color=(255, 255, 255), thickness=-1) 78 | annotated_image = cv2.putText(cv2.UMat(annotated_image), label, (text_offset_x, text_offset_y), text_font, font_size, (0, 0, 0)) 79 | 80 | if point is not None: 81 | # use a darker color so the point is more visible on the mask 82 | color = tuple([int(round(0.5 * c)) for c in color]) 83 | 
annotated_image = cv2.circle(cv2.UMat(annotated_image), point, radius=3, color=color, thickness=-1) 84 | 85 | if isinstance(annotated_image, cv2.UMat): 86 | # sometimes OpenCV functions return objects of type cv2.UMat instead of numpy arrays 87 | return annotated_image.get() 88 | else: 89 | return annotated_image 90 | -------------------------------------------------------------------------------- /attention_store.py: -------------------------------------------------------------------------------- 1 | # Code adapted from https://prompt-to-prompt.github.io/ 2 | import abc 3 | 4 | 5 | class EmptyControl: 6 | 7 | def step_callback(self, x_t): 8 | return x_t 9 | 10 | def between_steps(self): 11 | return 12 | 13 | def __call__(self, attn, is_cross: bool, place_in_unet: str): 14 | return attn 15 | 16 | 17 | class AttentionControl(abc.ABC): 18 | 19 | def step_callback(self, x_t): 20 | return x_t 21 | 22 | def between_steps(self): 23 | return 24 | 25 | def between_steps_inject(self): 26 | return 27 | 28 | @property 29 | def num_uncond_att_layers(self): 30 | return 0 #self.num_att_layers if LOW_RESOURCE else 0 31 | 32 | @abc.abstractmethod 33 | def forward(self, attn, is_cross: bool, place_in_unet: str): 34 | raise NotImplementedError 35 | 36 | def __call__(self, attn, is_cross: bool, place_in_unet: str): 37 | if self.cur_att_layer >= self.num_uncond_att_layers: 38 | attn = self.forward(attn, is_cross, place_in_unet) 39 | return attn 40 | 41 | def check_next_step(self): 42 | if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers: 43 | self.cur_att_layer = 0 44 | self.cur_step += 1 45 | if self.is_inject: 46 | self.between_steps_inject() 47 | else: 48 | self.between_steps() 49 | 50 | def reset(self): 51 | self.cur_step = 0 52 | self.cur_att_layer = 0 53 | 54 | def __init__(self): 55 | self.cur_step = 0 56 | self.num_att_layers = -1 57 | self.cur_att_layer = 0 58 | self.is_inject = False 59 | 60 | 61 | class AttentionStore(AttentionControl): 62 | 63 | @staticmethod 64 | def get_empty_store(): 65 | #return {"down_cross": [], "mid_cross": [], "up_cross": [], 66 | # "down_self": [], "mid_self": [], "up_self": []} 67 | return {} 68 | 69 | def forward(self, attn, is_cross: bool, place_in_unet: str): 70 | key = f"{place_in_unet}_{'cross' if is_cross else 'self'}" 71 | if attn is None: 72 | attn = self.attention_store[self.cur_step][key] 73 | else: 74 | self.step_store[key] = attn.cpu() 75 | #if attn.shape[1] <= 32 ** 2: # avoid memory overhead 76 | # self.step_store[key].append(attn) 77 | return attn 78 | 79 | def between_steps(self): 80 | self.attention_store[self.cur_step - 1] = self.step_store 81 | # if len(self.attention_store) == 0: 82 | # self.attention_store[self.cur_step-1] = self.step_store 83 | # else: 84 | # for key in self.attention_store: 85 | # for i in range(len(self.attention_store[key])): 86 | # self.attention_store[key][i] += self.step_store[key][i] 87 | self.step_store = self.get_empty_store() 88 | 89 | def between_steps_inject(self): 90 | self.step_store = self.get_empty_store() 91 | 92 | def get_average_attention(self): 93 | average_attention = {key: [item / self.cur_step for item in self.attention_store[key]] for key in 94 | self.attention_store} 95 | return average_attention 96 | 97 | def reset(self): 98 | super(AttentionStore, self).reset() 99 | self.step_store = self.get_empty_store() 100 | self.attention_store = {} 101 | 102 | def __init__(self): 103 | super(AttentionStore, self).__init__() 104 | self.step_store = self.get_empty_store() 105 | 
self.attention_store = {} 106 | 107 | 108 | 109 | 110 | 111 | -------------------------------------------------------------------------------- /sam_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | 4 | 5 | def show_mask(mask, ax, random_color=False): 6 | if random_color: 7 | color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0) 8 | else: 9 | color = np.array([30/255, 144/255, 255/255, 0.6]) 10 | h, w = mask.shape[-2:] 11 | mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1) 12 | ax.imshow(mask_image) 13 | 14 | 15 | def show_box(box, ax): 16 | x0, y0 = box[0], box[1] 17 | w, h = box[2] - box[0], box[3] - box[1] 18 | ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2)) 19 | 20 | def show_boxes_on_image(raw_image, boxes): 21 | plt.figure(figsize=(10,10)) 22 | plt.imshow(raw_image) 23 | for box in boxes: 24 | show_box(box, plt.gca()) 25 | plt.axis('on') 26 | plt.show() 27 | 28 | def show_points_on_image(raw_image, input_points, input_labels=None): 29 | plt.figure(figsize=(10,10)) 30 | plt.imshow(raw_image) 31 | input_points = np.array(input_points) 32 | if input_labels is None: 33 | labels = np.ones_like(input_points[:, 0]) 34 | else: 35 | labels = np.array(input_labels) 36 | show_points(input_points, labels, plt.gca()) 37 | plt.axis('on') 38 | plt.show() 39 | 40 | def show_points_and_boxes_on_image(raw_image, boxes, input_points, input_labels=None): 41 | plt.figure(figsize=(10,10)) 42 | plt.imshow(raw_image) 43 | input_points = np.array(input_points) 44 | if input_labels is None: 45 | labels = np.ones_like(input_points[:, 0]) 46 | else: 47 | labels = np.array(input_labels) 48 | show_points(input_points, labels, plt.gca()) 49 | for box in boxes: 50 | show_box(box, plt.gca()) 51 | plt.axis('on') 52 | plt.show() 53 | 54 | 55 | def show_points_and_boxes_on_image(raw_image, boxes, input_points, input_labels=None): 56 | plt.figure(figsize=(10,10)) 57 | plt.imshow(raw_image) 58 | input_points = np.array(input_points) 59 | if input_labels is None: 60 | labels = np.ones_like(input_points[:, 0]) 61 | else: 62 | labels = np.array(input_labels) 63 | show_points(input_points, labels, plt.gca()) 64 | for box in boxes: 65 | show_box(box, plt.gca()) 66 | plt.axis('on') 67 | plt.show() 68 | 69 | 70 | def show_points(coords, labels, ax, marker_size=375): 71 | pos_points = coords[labels==1] 72 | neg_points = coords[labels==0] 73 | ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25) 74 | ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25) 75 | 76 | 77 | def show_masks_on_image(raw_image, masks, scores): 78 | if len(masks.shape) == 4: 79 | masks = masks.squeeze() 80 | if scores.shape[0] == 1: 81 | scores = scores.squeeze() 82 | 83 | nb_predictions = scores.shape[-1] 84 | fig, axes = plt.subplots(1, nb_predictions, figsize=(15, 15)) 85 | 86 | for i, (mask, score) in enumerate(zip(masks, scores)): 87 | mask = mask.cpu().detach() 88 | axes[i].imshow(np.array(raw_image)) 89 | show_mask(mask, axes[i]) 90 | axes[i].title.set_text(f"Mask {i+1}, Score: {score.item():.3f}") 91 | axes[i].axis("off") 92 | plt.show() 93 | 94 | def show_single_mask_on_image(raw_image, mask, score): 95 | _ = plt.figure(figsize=(15, 15)) 96 | ax = plt.gca() 97 | mask = mask.cpu().detach() 98 | ax.imshow(np.array(raw_image)) 99 | 
show_mask(mask, ax) 100 | ax.title.set_text(f"Score: {score.item():.3f}") 101 | ax.axis("off") 102 | plt.show() -------------------------------------------------------------------------------- /PerMIRS/extract_diff_features.py: -------------------------------------------------------------------------------- 1 | import os 2 | import PIL 3 | import numpy as np 4 | import torch 5 | from diffusers import DDIMScheduler 6 | from tqdm import tqdm 7 | import ptp_utils 8 | from StableDiffusionPipelineWithDDIMInversion import StableDiffusionPipelineWithDDIMInversion 9 | from attention_store import AttentionStore 10 | from PerMIRS.visualization_utils import bbox_from_mask 11 | 12 | def center_crop(im, min_obj_x=None, max_obj_x=None, offsets=None): 13 | if offsets is None: 14 | width, height = im.size # Get dimensions 15 | min_dim = min(width, height) 16 | left = (width - min_dim) / 2 17 | top = (height - min_dim) / 2 18 | right = (width + min_dim) / 2 19 | bottom = (height + min_dim) / 2 20 | 21 | if min_obj_x < left: 22 | diff = abs(left - min_obj_x) 23 | left = min_obj_x 24 | right = right - diff 25 | if max_obj_x > right: 26 | diff = abs(right - max_obj_x) 27 | right = max_obj_x 28 | left = left + diff 29 | else: 30 | left, top, right, bottom = offsets 31 | 32 | # Crop the center of the image 33 | im = im.crop((left, top, right, bottom)) 34 | return im, (left, top, right, bottom) 35 | 36 | 37 | def load_im_into_format_from_path(im_path, size=512, offsets=None): 38 | im, offsets = center_crop(PIL.Image.open(im_path), offsets=offsets) 39 | return im 40 | 41 | 42 | def load_im_into_format_from_image(image, size=512, min_obj_x=None, max_obj_x=None): 43 | im, offsets = center_crop(image, min_obj_x=min_obj_x, max_obj_x=max_obj_x) 44 | return im.resize((size, size)), offsets 45 | 46 | def extract_attention(sd_model, image, prompt): 47 | controller = AttentionStore() 48 | inv_latents = sd_model.invert(prompt, image=image, guidance_scale=1.0).latents 49 | ptp_utils.register_attention_control_efficient(sd_model, controller) 50 | # recon_image = sd_model(prompt, latents=inv_latents, guidance_scale=1.0).images[0] 51 | key_format = f"up_blocks_3_attentions_1_transformer_blocks_0_attn1_self" 52 | timestamp = 49 53 | return [controller.attention_store[timestamp]["Q_" + key_format][0].to("cuda"), 54 | controller.attention_store[timestamp]["K_" + key_format][0].to("cuda"), 55 | ], sd_model.unet.up_blocks[3].attentions[1].transformer_blocks[0].attn1 56 | 57 | if __name__ == "__main__": 58 | model_id = "CompVis/stable-diffusion-v1-4" 59 | 60 | device = "cuda" # if torch.cuda.is_available() else "cpu" 61 | 62 | sd_model = StableDiffusionPipelineWithDDIMInversion.from_pretrained( 63 | model_id, 64 | safety_checker=None, 65 | scheduler=DDIMScheduler.from_pretrained(model_id, subfolder="scheduler") 66 | ) 67 | sd_model = sd_model.to(device) 68 | img_size = 512 69 | attn_size = 64 70 | 71 | dataset_dir = "/PerMIRS" 72 | for vid_id in tqdm(os.listdir(dataset_dir)): 73 | try: 74 | frames = [] 75 | masks = [] 76 | masks_img_size = [] 77 | masks_np = np.load(f"{dataset_dir}/{vid_id}/masks.npz.npy", allow_pickle=True) 78 | for f in range(3): 79 | xmin, ymin, xmax, ymax = [int(x) for x in bbox_from_mask(list(masks_np[f].values())[0])] 80 | m, curr_offsets = load_im_into_format_from_image( 81 | PIL.Image.fromarray(np.uint8(list(masks_np[f].values())[0])), 82 | min_obj_x=xmin, max_obj_x=xmax) 83 | masks += [np.asarray(m.resize((attn_size, attn_size)))] 84 | masks_img_size += [np.asarray(m.resize((img_size, img_size)))] 85 | 
frames += [ 86 | load_im_into_format_from_path(f"{dataset_dir}/{vid_id}/{f}.jpg", offsets=curr_offsets).resize((img_size, img_size)).convert("RGB")] 87 | 88 | all_attn = [] 89 | for f in frames: 90 | curr_attn, _ = extract_attention(sd_model, f, "A photo") 91 | all_attn += [curr_attn] 92 | torch.save(all_attn, f"{dataset_dir}/{vid_id}/diff_feats.pt") 93 | except Exception as e: 94 | f = open(f"{dataset_dir}/{vid_id}/error.txt", "w") 95 | f.writelines(str(e)) 96 | f.close() -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Where's Waldo: Diffusion Features For Personalized Segmentation and Retrieval (NeurIPS 2024) 2 | > Dvir Samuel, Rami Ben-Ari, Matan Levy, Nir Darshan, Gal Chechik 3 | > Bar Ilan University, The Hebrew University of Jerusalem, NVIDIA Research 4 | 5 | > 6 | > 7 | > Personalized retrieval and segmentation aim to locate specific instances within a dataset based on an input image and a short description of the reference instance. While supervised methods are effective, they require extensive labeled data for training. Recently, self-supervised foundation models have been introduced to these tasks showing comparable results to supervised methods. However, a significant flaw in these models is evident: they struggle to locate a desired instance when other instances within the same class are presented. In this paper, we explore text-to-image diffusion models for these tasks. Specifically, we propose a novel approach called PDM for Personalized Diffusion Features Matching, that leverages intermediate features of pre-trained text-to-image models for personalization tasks without any additional training. PDM demonstrates superior performance on popular retrieval and segmentation benchmarks, outperforming even supervised methods. We also highlight notable shortcomings in current instance and segmentation datasets and propose new benchmarks for these tasks. 8 | 9 | 10 | 11 |

12 | 13 | ![image](https://github.com/user-attachments/assets/c90fcb80-52f3-4a1e-9b08-7c93528d3c6d) 14 | The personalized segmentation task involves segmenting a specific reference object in a new scene. Our method is capable of accurately identifying the specific reference instance in the target image, even when other objects from the same class are present. While other methods capture visually or semantically similar objects, our method successfully extracts the identical instance by using a new personalized feature map and fusing semantic and appearance cues. Red and green indicate incorrect and correct segmentations, respectively. 15 | 16 | 17 |
18 | 19 | ## Requirements 20 | 21 | Quick installation using pip: 22 | ``` 23 | torch==2.0.1 24 | torchvision==0.15.2 25 | diffusers==0.18.2 26 | transformers==4.32.0.dev0 27 | ``` 28 | 29 | ## Personalized Diffusion Features Matching (PDM) 30 | 31 | To run PDM visualization between two images run the following: 32 | 33 | ``` 34 | python pdm_matching.py 35 | ``` 36 | 37 | ## PerMIR and PerMIS Datasets 38 | 39 | The PerMIR and PerMIS datasets were sourced from the [BURST](https://github.com/Ali2500/BURST-benchmark) repository. 40 | 41 | ### Instructions: 42 | 1. Download the datasets from the BURST repository. Place train,val, and test sets in the same directory. 43 | 2. Run the script `PerMIRS/permirs_gen_dataset.py` to prepare the personalization datasets. Ensure `--images_base_dir` contains the downloaded BURST splits. Additionally, set `--annotations_file` to all_classes.json. 44 | 3. Execute `PerMIRS/extract_diff_features.py` to extract PDM and DIFT features from each image in the dataset. 45 | 46 | 47 | ## Evaluation on PerMIR 48 | 49 | For PDM evaluation on PerMIR dataset (personalized retrieval) run: 50 | 51 | ``` 52 | python pdm_permir.py 53 | ``` 54 | 55 | ## Evaluation on PerMIS 56 | 57 | For PDM evaluation on PerMIS dataset (personalized segmentation) run: 58 | 59 | ``` 60 | python pdm_permis.py 61 | ``` 62 | 63 | 64 | 65 | ## Cite Our Paper 66 | If you find our paper and repo useful, please cite: 67 | ``` 68 | @article{Samuel2024Waldo, 69 | title={Where's Waldo: Diffusion Features For Personalized Segmentation and Retrieval}, 70 | author={Dvir Samuel and Rami Ben-Ari and Matan Levy and Nir Darshan and Gal Chechik}, 71 | journal={NeurIPS}, 72 | year={2024} 73 | } 74 | ``` 75 | -------------------------------------------------------------------------------- /pdm_matching.py: -------------------------------------------------------------------------------- 1 | import PIL 2 | import numpy as np 3 | import torch 4 | from diffusers import DDIMScheduler 5 | from matplotlib import pyplot as plt 6 | from torch import nn 7 | from transformers import SamModel, SamProcessor 8 | from attention_store import AttentionStore 9 | from StableDiffusionPipelineWithDDIMInversion import StableDiffusionPipelineWithDDIMInversion 10 | #from sam_utils import show_masks_on_image, show_points_on_image 11 | import ptp_utils 12 | 13 | 14 | def center_crop(im): 15 | width, height = im.size # Get dimensions 16 | min_dim = min(width, height) 17 | left = (width - min_dim) / 2 18 | top = (height - min_dim) / 2 19 | right = (width + min_dim) / 2 20 | bottom = (height + min_dim) / 2 21 | 22 | # Crop the center of the image 23 | im = im.crop((left, top, right, bottom)) 24 | return im 25 | 26 | 27 | def load_im_into_format_from_path(im_path): 28 | return center_crop(PIL.Image.open(im_path)).resize((512, 512)) 29 | 30 | def get_masks(raw_image, input_points): 31 | inputs = processor(raw_image, input_points=input_points, return_tensors="pt").to(device) 32 | with torch.no_grad(): 33 | outputs = model(**inputs) 34 | 35 | masks = processor.image_processor.post_process_masks( 36 | outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu() 37 | ) 38 | # visualize sam segmentation 39 | # scores = outputs.iou_scores 40 | # show_masks_on_image(raw_image, masks[0], scores) 41 | # show_points_on_image(raw_image, input_points[0]) 42 | return masks 43 | 44 | 45 | def extract_attention(sd_model, image, prompt): 46 | controller = AttentionStore() 47 | inv_latents = sd_model.invert(prompt, 
image=image, guidance_scale=1.0).latents 48 | ptp_utils.register_attention_control_efficient(sd_model, controller) 49 | recon_image = sd_model(prompt, latents=inv_latents, guidance_scale=1.0).images[0] 50 | key_format = f"up_blocks_3_attentions_2_transformer_blocks_0_attn1_self" 51 | timestamp = 49 52 | res_attn = 64 53 | return [controller.attention_store[timestamp]["Q_" + key_format][0].to("cuda"), 54 | controller.attention_store[timestamp]["K_" + key_format][0].to("cuda"), 55 | ], sd_model.unet.up_blocks[3].attentions[2].transformer_blocks[0].attn1, res_attn 56 | 57 | def heatmap(sd_model, ref_img, tgt_img): 58 | image_size = 512 59 | prompt = f"A photo" 60 | # extract query PDM features (Q and K) 61 | ref_attn_ls, _, res_attn = extract_attention(sd_model, ref_img, prompt) 62 | h = w = res_attn 63 | 64 | # get query mask using SAM or use provided mask from user 65 | source_masks = get_masks(ref_img.resize(size=(h, w)), [[[h // 2, w // 2]]]) 66 | source_mask = source_masks[0][:, 1:2, :, :].squeeze(dim=0).squeeze(dim=0) 67 | mask_idx_y, mask_idx_x = torch.where(source_mask) 68 | 69 | # extract target PDM features (Q and K) 70 | target_attn_ls, _, _ = extract_attention(sd_model, tgt_img, prompt) 71 | 72 | # apply matching and show heatmap 73 | for attn_idx, (ref_attn, target_attn) in enumerate(zip(ref_attn_ls, target_attn_ls)): 74 | heatmap = torch.zeros(ref_attn.shape[0]).to("cuda") 75 | for x, y in zip(mask_idx_x, mask_idx_y): 76 | t = np.ravel_multi_index((y, x), dims=(h, w)) 77 | source_vec = ref_attn[t].reshape(1, -1) 78 | euclidean_dist = torch.cdist(source_vec, target_attn) 79 | idx = torch.sort(euclidean_dist)[1][0][:100] 80 | heatmap[idx] += 1 81 | heatmap = heatmap / heatmap.max() 82 | heatmap_img_size = \ 83 | nn.Upsample(size=(image_size, image_size), mode='bilinear')(heatmap.reshape(1, 1, 64, 64))[0][ 84 | 0] 85 | plt.imshow(tgt_img) 86 | plt.imshow(heatmap_img_size.cpu(), alpha=0.6) 87 | plt.show() 88 | 89 | 90 | if __name__ == "__main__": 91 | model_id = "CompVis/stable-diffusion-v1-4" 92 | 93 | device = "cuda" 94 | 95 | sd_model = StableDiffusionPipelineWithDDIMInversion.from_pretrained( 96 | model_id, 97 | safety_checker=None, 98 | scheduler=DDIMScheduler.from_pretrained(model_id, subfolder="scheduler") 99 | ) 100 | sd_model = sd_model.to(device) 101 | img_size = 512 102 | 103 | model = SamModel.from_pretrained("facebook/sam-vit-huge").to(device) 104 | processor = SamProcessor.from_pretrained("facebook/sam-vit-huge") 105 | 106 | ref_img = load_im_into_format_from_path("dogs/source.png").convert("RGB") 107 | plt.imshow(ref_img) 108 | plt.show() 109 | tgt_img = load_im_into_format_from_path("dogs/target1.png").convert("RGB") 110 | plt.imshow(tgt_img) 111 | plt.show() 112 | heatmap(sd_model, ref_img, tgt_img) 113 | 114 | -------------------------------------------------------------------------------- /pdm_permir.py: -------------------------------------------------------------------------------- 1 | import PIL 2 | import numpy as np 3 | import torch 4 | import os 5 | from diffusers import DDIMScheduler 6 | from sklearn.metrics import average_precision_score 7 | from torchvision.transforms import PILToTensor 8 | from tqdm import tqdm 9 | from StableDiffusionPipelineWithDDIMInversion import StableDiffusionPipelineWithDDIMInversion 10 | from dift import SDFeaturizer, get_correspondences_seg 11 | from PerMIRS.visualization_utils import bbox_from_mask 12 | import torch.nn.functional as F 13 | 14 | 15 | def center_crop(im, min_obj_x=None, max_obj_x=None, offsets=None): 16 | if 
offsets is None: 17 | width, height = im.size # Get dimensions 18 | min_dim = min(width, height) 19 | left = (width - min_dim) / 2 20 | top = (height - min_dim) / 2 21 | right = (width + min_dim) / 2 22 | bottom = (height + min_dim) / 2 23 | 24 | if min_obj_x < left: 25 | diff = abs(left - min_obj_x) 26 | left = min_obj_x 27 | right = right - diff 28 | if max_obj_x > right: 29 | diff = abs(right - max_obj_x) 30 | right = max_obj_x 31 | left = left + diff 32 | else: 33 | left, top, right, bottom = offsets 34 | 35 | # Crop the center of the image 36 | im = im.crop((left, top, right, bottom)) 37 | return im, (left, top, right, bottom) 38 | 39 | 40 | def load_im_into_format_from_path(im_path, size=512, offsets=None): 41 | im, offsets = center_crop(PIL.Image.open(im_path), offsets=offsets) 42 | return im 43 | 44 | 45 | def load_im_into_format_from_image(image, size=512, min_obj_x=None, max_obj_x=None): 46 | im, offsets = center_crop(image, min_obj_x=min_obj_x, max_obj_x=max_obj_x) 47 | return im.resize((size, size)), offsets 48 | 49 | if __name__ == "__main__": 50 | model_id = "CompVis/stable-diffusion-v1-4" 51 | 52 | device = "cuda" 53 | sd_model = StableDiffusionPipelineWithDDIMInversion.from_pretrained( 54 | model_id, 55 | safety_checker=None, 56 | scheduler=DDIMScheduler.from_pretrained(model_id, subfolder="scheduler") 57 | ) 58 | sd_model = sd_model.to(device) 59 | 60 | dift = SDFeaturizer(sd_model) 61 | img_size = 512 62 | attn_size = 64 63 | dift_size = 32 64 | 65 | dataset_dir = "/PerMIRS" 66 | 67 | if "ret_data.pth" not in os.listdir(dataset_dir): 68 | # organize data for retrieval 69 | query_frames = [] 70 | query_labels = [] 71 | query_dift_features = [] 72 | query_dift_mask = [] 73 | query_perdiff_features = [] 74 | gallery_frames = [] 75 | gallery_labels = [] 76 | gallery_dift_features = [] 77 | gallery_perdiff_features = [] 78 | 79 | for vid_idx, vid_id in tqdm(enumerate(os.listdir(dataset_dir))): 80 | masks_attn = [] 81 | masks_dift = [] 82 | masks_relative_size = [] 83 | masks_np = np.load(f"{dataset_dir}/{vid_id}/masks.npz.npy", allow_pickle=True) 84 | attn_ls = torch.load(f"{dataset_dir}/{vid_id}/diff_feats.pt") 85 | dift_features = [] 86 | frame_paths = [] 87 | for f in range(3): 88 | xmin, ymin, xmax, ymax = [int(x) for x in bbox_from_mask(list(masks_np[f].values())[0])] 89 | m, curr_offsets = load_im_into_format_from_image( 90 | PIL.Image.fromarray(np.uint8(list(masks_np[f].values())[0])), 91 | min_obj_x=xmin, max_obj_x=xmax) 92 | 93 | masks_attn += [np.asarray(m.resize((attn_size, attn_size)))] 94 | masks_dift += [np.asarray(m.resize((dift_size, dift_size)))] 95 | masks_relative_size += [np.asarray(m).sum() / (img_size * img_size)] 96 | path = f"{dataset_dir}/{vid_id}/{f}.jpg" 97 | frame_paths += [path] 98 | frame = load_im_into_format_from_path(path, offsets=curr_offsets).resize( 99 | (img_size, img_size)).convert("RGB") 100 | curr_dift = dift.forward((PILToTensor()(frame) / 255.0 - 0.5) * 2, 101 | prompt="A photo", 102 | ensemble_size=2) 103 | dift_features += [curr_dift] 104 | 105 | masks_relative_size = np.array(masks_relative_size) 106 | # remove small frames 107 | if len(np.where(np.array(masks_relative_size) < 0.005)[0]) > 0: 108 | continue 109 | 110 | query_idx = masks_relative_size.argmax() 111 | for i in range(len(dift_features)): 112 | if i == query_idx: 113 | # query 114 | query_dift_features += [dift_features[i]] 115 | query_dift_mask += [masks_dift[i]] 116 | query_perdiff_features += [attn_ls[i]] 117 | query_labels += [vid_idx] 118 | query_frames += 
[frame_paths[i]] 119 | else: 120 | # gallery 121 | gallery_dift_features += [dift_features[i]] 122 | gallery_perdiff_features += [attn_ls[i]] 123 | gallery_labels += [vid_idx] 124 | gallery_frames += [frame_paths[i]] 125 | 126 | query_labels = torch.tensor(query_labels) 127 | gallery_labels = torch.tensor(gallery_labels) 128 | 129 | torch.save([query_frames, query_labels, query_dift_features, query_dift_mask, query_perdiff_features, 130 | gallery_frames, gallery_labels, gallery_dift_features, gallery_perdiff_features], f"{dataset_dir}/ret_data.pt") 131 | else: 132 | # retrieval performance on PerMIR 133 | query_frames, query_labels, query_dift_features, query_dift_mask, query_perdiff_features, \ 134 | gallery_frames, gallery_labels, gallery_dift_features, gallery_perdiff_features = torch.load( 135 | f"{dataset_dir}/ret_data.pt") 136 | topk = 1 137 | recall_dict = {1: 0, 5: 0, 10: 0, 50: 0} 138 | ap = [] 139 | for q_idx in tqdm(range(len(query_dift_features))): 140 | scores = [] 141 | for g_idx in range(len(gallery_dift_features)): 142 | # First, extract correspondences using DIFT 143 | ref_points, tgt_points, dift_scores = get_correspondences_seg(query_dift_features[q_idx], 144 | gallery_dift_features[g_idx], 145 | query_dift_mask[q_idx], 146 | img_size=attn_size, 147 | topk=topk) 148 | 149 | total_maps_scores = [] 150 | total_maps_scores_cs = [] 151 | for attn_idx, (ref_attn, target_attn) in enumerate( 152 | zip(query_perdiff_features[q_idx], gallery_perdiff_features[g_idx])): 153 | all_point_scores = [] 154 | all_point_scores_cs = [] 155 | for p_ref, p_tgt, dift_score in zip(ref_points, tgt_points, dift_scores): 156 | source_vec = ref_attn[ 157 | torch.tensor(np.ravel_multi_index(p_ref, dims=(attn_size, attn_size))).to("cuda")].reshape(1, 158 | -1) 159 | target_vec = target_attn[ 160 | torch.tensor(np.ravel_multi_index(p_tgt.T, dims=(attn_size, attn_size))).to( 161 | "cuda")].reshape(topk, -1) 162 | euclidean_dist = torch.cdist(source_vec, target_vec) 163 | all_point_scores += [dift_score * euclidean_dist.mean().cpu().numpy()] 164 | all_point_scores_cs += [ 165 | dift_score * (F.normalize(source_vec) @ F.normalize(target_vec).T).mean().cpu().numpy()] 166 | total_maps_scores += [np.mean(all_point_scores)] 167 | total_maps_scores_cs += [np.mean(all_point_scores_cs)] 168 | total_score = np.mean(total_maps_scores) 169 | total_score_cs = np.mean(total_maps_scores_cs) 170 | scores += [total_score_cs] 171 | 172 | pred_scores_idx = torch.argsort(torch.tensor(scores), descending=True) # change to false fro euclidean 173 | pred_g_labels = gallery_labels[pred_scores_idx] 174 | curr_query_lbl = query_labels[q_idx] 175 | 176 | ap += [average_precision_score((gallery_labels == curr_query_lbl).int().numpy(),scores)] 177 | for r in [1, 5, 10, 50]: 178 | if curr_query_lbl in pred_g_labels[:r]: 179 | recall_dict[r] += 1 180 | 181 | print("MAP:", np.array(ap).mean()) 182 | for k in recall_dict.keys(): 183 | print(f"Recall@{k}", recall_dict[k] / len(query_dift_features)) 184 | -------------------------------------------------------------------------------- /ptp_utils.py: -------------------------------------------------------------------------------- 1 | # Code adapted from https://prompt-to-prompt.github.io/ 2 | 3 | import numpy as np 4 | import torch 5 | from PIL import Image 6 | import cv2 7 | from typing import Optional, Union, Tuple, Dict 8 | from matplotlib import pyplot as plt 9 | 10 | 11 | def text_under_image(image: np.ndarray, text: str, text_color: Tuple[int, int, int] = (0, 0, 0)): 12 | h, w, c 
= image.shape 13 | offset = int(h * .2) 14 | img = np.ones((h + offset, w, c), dtype=np.uint8) * 255 15 | font = cv2.FONT_HERSHEY_SIMPLEX 16 | # font = ImageFont.truetype("/usr/share/fonts/truetype/noto/NotoMono-Regular.ttf", font_size) 17 | img[:h] = image 18 | textsize = cv2.getTextSize(text, font, 1, 2)[0] 19 | text_x, text_y = (w - textsize[0]) // 2, h + offset - textsize[1] // 2 20 | cv2.putText(img, text, (text_x, text_y), font, 1, text_color, 2) 21 | return img 22 | 23 | 24 | def view_images(images, num_rows=1, offset_ratio=0.02): 25 | if type(images) is list: 26 | num_empty = len(images) % num_rows 27 | elif images.ndim == 4: 28 | num_empty = images.shape[0] % num_rows 29 | else: 30 | images = [images] 31 | num_empty = 0 32 | 33 | empty_images = np.ones(images[0].shape, dtype=np.uint8) * 255 34 | images = [image.astype(np.uint8) for image in images] + [empty_images] * num_empty 35 | num_items = len(images) 36 | 37 | h, w, c = images[0].shape 38 | offset = int(h * offset_ratio) 39 | num_cols = num_items // num_rows 40 | image_ = np.ones((h * num_rows + offset * (num_rows - 1), 41 | w * num_cols + offset * (num_cols - 1), 3), dtype=np.uint8) * 255 42 | for i in range(num_rows): 43 | for j in range(num_cols): 44 | image_[i * (h + offset): i * (h + offset) + h:, j * (w + offset): j * (w + offset) + w] = images[ 45 | i * num_cols + j] 46 | 47 | pil_img = Image.fromarray(image_) 48 | # display(pil_img) 49 | plt.imshow(pil_img) 50 | plt.show() 51 | 52 | 53 | def diffusion_step(model, controller, latents, context, t, guidance_scale, low_resource=False): 54 | if low_resource: 55 | noise_pred_uncond = model.unet(latents, t, encoder_hidden_states=context[0])["sample"] 56 | noise_prediction_text = model.unet(latents, t, encoder_hidden_states=context[1])["sample"] 57 | else: 58 | latents_input = torch.cat([latents] * 2) 59 | noise_pred = model.unet(latents_input, t, encoder_hidden_states=context)["sample"] 60 | noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2) 61 | noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond) 62 | latents = model.scheduler.step(noise_pred, t, latents)["prev_sample"] 63 | latents = controller.step_callback(latents) 64 | return latents 65 | 66 | 67 | def latent2image(vae, latents): 68 | latents = 1 / 0.18215 * latents 69 | image = vae.decode(latents)['sample'] 70 | image = (image / 2 + 0.5).clamp(0, 1) 71 | image = image.cpu().permute(0, 2, 3, 1).numpy() 72 | image = (image * 255).astype(np.uint8) 73 | return image 74 | 75 | 76 | def init_latent(latent, model, height, width, generator, batch_size): 77 | if latent is None: 78 | latent = torch.randn( 79 | (1, model.unet.in_channels, height // 8, width // 8), 80 | generator=generator, 81 | ) 82 | latents = latent.expand(batch_size, model.unet.in_channels, height // 8, width // 8).to(model.device) 83 | return latent, latents 84 | 85 | 86 | 87 | 88 | 89 | def register_attention_control_efficient(model, controller): 90 | def ca_forward(self, place_in_unet): 91 | to_out = self.to_out 92 | if type(to_out) is torch.nn.modules.container.ModuleList: 93 | to_out = self.to_out[0] 94 | else: 95 | to_out = self.to_out 96 | 97 | def forward(x, encoder_hidden_states=None, attention_mask=None): 98 | batch_size, sequence_length, dim = x.shape 99 | h = self.heads 100 | 101 | is_cross = encoder_hidden_states is not None 102 | encoder_hidden_states = encoder_hidden_states if is_cross else x 103 | inject_cond = False 104 | 105 | if inject_cond: 106 | q = controller(None, is_cross, "Q_" + 
"_".join(place_in_unet)).cuda() 107 | k = controller(None, is_cross, "K_" + "_".join(place_in_unet)).cuda() 108 | else: 109 | q = self.to_q(x) 110 | k = self.to_k(encoder_hidden_states) 111 | if not controller.is_inject: 112 | q = controller(q, is_cross, "Q_" + "_".join(place_in_unet)) 113 | k = controller(k, is_cross, "K_" + "_".join(place_in_unet)) 114 | 115 | 116 | q = self.head_to_batch_dim(q) 117 | k = self.head_to_batch_dim(k) 118 | 119 | if inject_cond: 120 | v = controller(None, is_cross, "V_" + "_".join(place_in_unet)).cuda() 121 | else: 122 | v = self.to_v(encoder_hidden_states) 123 | if not controller.is_inject: 124 | v = controller(v, is_cross, "V_" + "_".join(place_in_unet)) 125 | controller.cur_att_layer += 1 126 | controller.check_next_step() 127 | v = self.head_to_batch_dim(v) 128 | 129 | sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale 130 | 131 | if attention_mask is not None: 132 | attention_mask = attention_mask.reshape(batch_size, -1) 133 | max_neg_value = -torch.finfo(sim.dtype).max 134 | attention_mask = attention_mask[:, None, :].repeat(h, 1, 1) 135 | sim.masked_fill_(~attention_mask, max_neg_value) 136 | 137 | # attention, what we cannot get enough of 138 | attn = sim.softmax(dim=-1) 139 | 140 | out = torch.einsum("b i j, b j d -> b i d", attn, v) 141 | out = self.batch_to_head_dim(out) 142 | 143 | return to_out(out) 144 | 145 | return forward 146 | 147 | class DummyController: 148 | 149 | def __call__(self, *args): 150 | return args[0] 151 | 152 | def __init__(self): 153 | self.num_att_layers = 0 154 | 155 | if controller is None: 156 | controller = DummyController() 157 | 158 | def register_recr(net_, count, place_in_unet): 159 | if net_.__class__.__name__ == 'Attention': 160 | net_.forward = ca_forward(net_, place_in_unet) 161 | return count + 1 162 | elif hasattr(net_, 'children'): 163 | for name, net__ in net_.named_children(): 164 | count = register_recr(net__, count, place_in_unet + [name]) 165 | return count 166 | 167 | att_count = 0 168 | ref_modules = [] 169 | sub_nets = model.unet.named_children() 170 | for name, net in sub_nets: 171 | if "down" in name or "up" in name or "mid" in name: 172 | att_count += register_recr(net, 0, [name]) 173 | controller.num_att_layers = att_count 174 | 175 | 176 | def get_word_inds(text: str, word_place: int, tokenizer): 177 | split_text = text.split(" ") 178 | if type(word_place) is str: 179 | word_place = [i for i, word in enumerate(split_text) if word_place == word] 180 | elif type(word_place) is int: 181 | word_place = [word_place] 182 | out = [] 183 | if len(word_place) > 0: 184 | words_encode = [tokenizer.decode([item]).strip("#") for item in tokenizer.encode(text)][1:-1] 185 | cur_len, ptr = 0, 0 186 | 187 | for i in range(len(words_encode)): 188 | cur_len += len(words_encode[i]) 189 | if ptr in word_place: 190 | out.append(i + 1) 191 | if cur_len >= len(split_text[ptr]): 192 | ptr += 1 193 | cur_len = 0 194 | return np.array(out) 195 | 196 | 197 | def update_alpha_time_word(alpha, bounds: Union[float, Tuple[float, float]], prompt_ind: int, 198 | word_inds: Optional[torch.Tensor] = None): 199 | if type(bounds) is float: 200 | bounds = 0, bounds 201 | start, end = int(bounds[0] * alpha.shape[0]), int(bounds[1] * alpha.shape[0]) 202 | if word_inds is None: 203 | word_inds = torch.arange(alpha.shape[2]) 204 | alpha[: start, prompt_ind, word_inds] = 0 205 | alpha[start: end, prompt_ind, word_inds] = 1 206 | alpha[end:, prompt_ind, word_inds] = 0 207 | return alpha 208 | 209 | 210 | def 
get_time_words_attention_alpha(prompts, num_steps, 211 | cross_replace_steps: Union[float, Dict[str, Tuple[float, float]]], 212 | tokenizer, max_num_words=77): 213 | if type(cross_replace_steps) is not dict: 214 | cross_replace_steps = {"default_": cross_replace_steps} 215 | if "default_" not in cross_replace_steps: 216 | cross_replace_steps["default_"] = (0., 1.) 217 | alpha_time_words = torch.zeros(num_steps + 1, len(prompts) - 1, max_num_words) 218 | for i in range(len(prompts) - 1): 219 | alpha_time_words = update_alpha_time_word(alpha_time_words, cross_replace_steps["default_"], 220 | i) 221 | for key, item in cross_replace_steps.items(): 222 | if key != "default_": 223 | inds = [get_word_inds(prompts[i], key, tokenizer) for i in range(1, len(prompts))] 224 | for i, ind in enumerate(inds): 225 | if len(ind) > 0: 226 | alpha_time_words = update_alpha_time_word(alpha_time_words, item, i, ind) 227 | alpha_time_words = alpha_time_words.reshape(num_steps + 1, len(prompts) - 1, 1, 1, max_num_words) 228 | return alpha_time_words 229 | -------------------------------------------------------------------------------- /pdm_permis.py: -------------------------------------------------------------------------------- 1 | import os 2 | import PIL 3 | import numpy as np 4 | import torch 5 | from diffusers import DDIMScheduler 6 | from matplotlib import pyplot as plt 7 | from torch import nn 8 | from tqdm import tqdm 9 | from transformers import SamModel, SamProcessor 10 | import ptp_utils 11 | from StableDiffusionPipelineWithDDIMInversion import StableDiffusionPipelineWithDDIMInversion 12 | from PerMIRS.eval_miou import AverageMeter, intersectionAndUnion 13 | from attention_store import AttentionStore 14 | from PerMIRS.visualization_utils import bbox_from_mask 15 | from sam_utils import show_points_on_image, show_masks_on_image, show_single_mask_on_image 16 | import torch.nn.functional as F 17 | 18 | 19 | def get_sam_masks(sam_model, processor, raw_image, input_points, input_labels, attn, ref_embeddings, input_masks=None, box=None, 20 | multimask_output=True, verbose=False): 21 | inputs = processor(raw_image, input_points=input_points, input_labels=input_labels, input_boxes=box, 22 | return_tensors="pt", 23 | attention_similarity=attn).to("cuda") 24 | inputs["attention_similarity"] = attn 25 | inputs["target_embedding"] = ref_embeddings 26 | inputs["input_masks"] = input_masks 27 | inputs["multimask_output"] = multimask_output 28 | with torch.no_grad(): 29 | outputs = sam_model(**inputs) 30 | 31 | masks = processor.image_processor.post_process_masks( 32 | outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu() 33 | ) 34 | scores = outputs.iou_scores 35 | logits = outputs.pred_masks 36 | if verbose: 37 | if not multimask_output: 38 | scores = scores.reshape(1, -1) 39 | masks = masks[0].reshape(1, masks[0].shape[-2], masks[0].shape[-1]) 40 | show_single_mask_on_image(raw_image, masks[0], scores) 41 | else: 42 | show_masks_on_image(raw_image, masks[0], scores) 43 | show_points_on_image(raw_image, input_points[0], input_labels[0]) 44 | return masks, scores, logits 45 | 46 | def extract_attention(sd_model, image, prompt): 47 | controller = AttentionStore() 48 | inv_latents = sd_model.invert(prompt, image=image, guidance_scale=1.0).latents 49 | ptp_utils.register_attention_control_efficient(sd_model, controller) 50 | # recon_image = sd_model(prompt, latents=inv_latents, guidance_scale=1.0).images[0] 51 | key_format = 
f"up_blocks_3_attentions_1_transformer_blocks_0_attn1_self" 52 | timestamp = 49 53 | return [controller.attention_store[timestamp]["Q_" + key_format][0].to("cuda"), 54 | controller.attention_store[timestamp]["K_" + key_format][0].to("cuda"), 55 | ], sd_model.unet.up_blocks[3].attentions[1].transformer_blocks[0].attn1 56 | 57 | 58 | def center_crop(im, min_obj_x=None, max_obj_x=None, offsets=None): 59 | if offsets is None: 60 | width, height = im.size # Get dimensions 61 | min_dim = min(width, height) 62 | left = (width - min_dim) / 2 63 | top = (height - min_dim) / 2 64 | right = (width + min_dim) / 2 65 | bottom = (height + min_dim) / 2 66 | 67 | if min_obj_x < left: 68 | diff = abs(left - min_obj_x) 69 | left = min_obj_x 70 | right = right - diff 71 | if max_obj_x > right: 72 | diff = abs(right - max_obj_x) 73 | right = max_obj_x 74 | left = left + diff 75 | else: 76 | left, top, right, bottom = offsets 77 | 78 | # Crop the center of the image 79 | im = im.crop((left, top, right, bottom)) 80 | return im, (left, top, right, bottom) 81 | 82 | 83 | def load_im_into_format_from_path(im_path, size=512, offsets=None): 84 | im, offsets = center_crop(PIL.Image.open(im_path), offsets=offsets) 85 | return im 86 | 87 | 88 | def load_im_into_format_from_image(image, size=512, min_obj_x=None, max_obj_x=None): 89 | im, offsets = center_crop(image, min_obj_x=min_obj_x, max_obj_x=max_obj_x) 90 | return im.resize((size, size)), offsets 91 | 92 | 93 | def diff_attn_images(sd_model, sam_model, sam_processor, ref_img, ref_mask, ref_attn_ls, tgt_img, target_attn_ls, 94 | verbose=True): 95 | image_size = 512 96 | prompt = f"A photo" 97 | h = w = 64 98 | mask_idx_y, mask_idx_x = torch.where(ref_mask) 99 | 100 | all_points = [] 101 | all_labels = [] 102 | for attn_idx, (ref_attn, target_attn) in enumerate(zip(ref_attn_ls, target_attn_ls)): 103 | heatmap = torch.zeros(ref_attn.shape[0]).to("cuda") 104 | for x, y in zip(mask_idx_x, mask_idx_y): 105 | t = np.ravel_multi_index((y, x), dims=(h, w)) 106 | source_vec = ref_attn[t].reshape(1, -1) 107 | cs_similarity = F.normalize(source_vec) @ F.normalize(target_attn).T 108 | idx = torch.sort(cs_similarity, descending=True)[1][0][:100] 109 | heatmap[idx] += 1 110 | heatmap = heatmap / heatmap.max() 111 | heatmap_img_size = \ 112 | nn.Upsample(size=(image_size, image_size), mode='bilinear')(heatmap.reshape(1, 1, 64, 64))[0][ 113 | 0] 114 | if verbose: 115 | plt.imshow(tgt_img) 116 | plt.imshow(heatmap_img_size.cpu(), alpha=0.6) 117 | plt.axis("off") 118 | plt.show() 119 | 120 | all_points += [(heatmap_img_size == torch.max(heatmap_img_size)).nonzero()[0].reshape(1, -1).cpu().numpy()] 121 | all_labels += [1] 122 | 123 | all_points = np.concatenate(all_points) 124 | all_points[:, [0, 1]] = all_points[:, [1, 0]] 125 | pred_masks, masks_scores, mask_logits = get_sam_masks(sam_model, sam_processor, tgt_img, 126 | input_points=[list(all_points)], 127 | input_labels=[all_labels], 128 | attn=None, 129 | ref_embeddings=None, 130 | verbose=verbose 131 | ) 132 | best_mask = 2 133 | y, x = pred_masks[0][0][best_mask].nonzero().T 134 | x_min = x.min().item() 135 | x_max = x.max().item() 136 | y_min = y.min().item() 137 | y_max = y.max().item() 138 | input_box = [[x_min, y_min, x_max, y_max]] 139 | pred_masks, masks_scores, mask_logits = get_sam_masks(sam_model, sam_processor, tgt_img, 140 | input_points=[list(all_points)], 141 | input_labels=[all_labels], 142 | attn=None, 143 | ref_embeddings=None, 144 | # input_masks=mask_logits[0, 0, best_mask: best_mask + 1, :,:], 145 | 
box=[input_box], 146 | multimask_output=True, 147 | verbose=verbose 148 | ) 149 | best_idx = 2 150 | final_pred_mask = pred_masks[0][:, best_idx, :, :].squeeze().numpy() 151 | return final_pred_mask 152 | 153 | 154 | if __name__ == "__main__": 155 | model_id = "CompVis/stable-diffusion-v1-4" 156 | 157 | device = "cuda" 158 | 159 | sam_model = SamModel.from_pretrained("facebook/sam-vit-huge").to(device) 160 | sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-huge") 161 | 162 | sd_model = StableDiffusionPipelineWithDDIMInversion.from_pretrained( 163 | model_id, 164 | safety_checker=None, 165 | scheduler=DDIMScheduler.from_pretrained(model_id, subfolder="scheduler") 166 | ) 167 | sd_model = sd_model.to(device) 168 | img_size = 512 169 | attn_size = 64 170 | 171 | dataset_dir = "/PerMIRS" 172 | intersection_meter = AverageMeter() 173 | union_meter = AverageMeter() 174 | target_meter = AverageMeter() 175 | for vid_id in tqdm(os.listdir(dataset_dir)): 176 | frames = [] 177 | masks = [] 178 | masks_img_size = [] 179 | masks_np = np.load(f"{dataset_dir}/{vid_id}/masks.npz.npy", allow_pickle=True) 180 | attn_ls = torch.load(f"{dataset_dir}/{vid_id}/diff_feats.pt") 181 | for f in range(3): 182 | xmin, ymin, xmax, ymax = [int(x) for x in bbox_from_mask(list(masks_np[f].values())[0])] 183 | m, curr_offsets = load_im_into_format_from_image( 184 | PIL.Image.fromarray(np.uint8(list(masks_np[f].values())[0])), 185 | min_obj_x=xmin, max_obj_x=xmax) 186 | masks += [np.asarray(m.resize((attn_size, attn_size)))] 187 | masks_img_size += [np.asarray(m.resize((img_size, img_size)))] 188 | frames += [ 189 | load_im_into_format_from_path(f"{dataset_dir}/{vid_id}/{f}.jpg", offsets=curr_offsets).resize( 190 | (img_size, img_size)).convert("RGB")] 191 | 192 | ref_idx = 0 193 | tgt_idx = 1 194 | 195 | # remove tiny objects 196 | if masks_img_size[tgt_idx].sum() / (img_size * img_size) < 0.005: 197 | continue 198 | pred_mask = diff_attn_images(sd_model, sam_model, sam_processor, frames[ref_idx], torch.tensor(masks[ref_idx]), 199 | attn_ls[ref_idx][0:2], 200 | frames[tgt_idx], attn_ls[tgt_idx][0:2], verbose=False) 201 | 202 | pred_mask = np.uint8(pred_mask) 203 | gt_mask = np.uint8(masks_img_size[tgt_idx]) 204 | 205 | intersection, union, target = intersectionAndUnion(pred_mask, gt_mask) 206 | print(vid_id, intersection, union, target) 207 | intersection_meter.update(intersection), union_meter.update(union), target_meter.update(target) 208 | iou_class = intersection_meter.sum / (union_meter.sum + 1e-10) 209 | accuracy_class = intersection_meter.sum / (target_meter.sum + 1e-10) 210 | 211 | print("\nmIoU: %.2f" % (100 * iou_class)) 212 | print("mAcc: %.2f\n" % (100 * accuracy_class)) 213 | -------------------------------------------------------------------------------- /PerMIRS/video.py: -------------------------------------------------------------------------------- 1 | # Code from https://github.com/Ali2500/BURST-benchmark 2 | from typing import Optional, List, Tuple, Union, Dict, Any 3 | 4 | import cv2 5 | import numpy as np 6 | import os.path as osp 7 | 8 | import utils 9 | from PerMIRS.visualization_utils import bbox_from_mask 10 | 11 | 12 | class BURSTVideo: 13 | def __init__(self, video_dict: Dict[str, Any], images_dir: Optional[str] = None): 14 | 15 | self.annotated_image_paths: List[str] = video_dict["annotated_image_paths"] 16 | self.all_images_paths: List[str] = video_dict["all_image_paths"] 17 | self.segmentations: List[Dict[int, Dict[str, Any]]] = video_dict["segmentations"] 18 | 
self._track_category_ids: Dict[int, int] = video_dict["track_category_ids"] 19 | self.image_size: Tuple[int, int] = (video_dict["height"], video_dict["width"]) 20 | 21 | self.id = video_dict["id"] 22 | self.dataset = video_dict["dataset"] 23 | self.name = video_dict["seq_name"] 24 | self.negative_category_ids = video_dict["neg_category_ids"] 25 | self.not_exhaustive_category_ids = video_dict["not_exhaustive_category_ids"] 26 | 27 | self._images_dir = images_dir 28 | 29 | @property 30 | def num_annotated_frames(self) -> int: 31 | return len(self.annotated_image_paths) 32 | 33 | @property 34 | def num_total_frames(self) -> int: 35 | return len(self.all_images_paths) 36 | 37 | @property 38 | def image_height(self) -> int: 39 | return self.image_size[0] 40 | 41 | @property 42 | def image_width(self) -> int: 43 | return self.image_size[1] 44 | 45 | @property 46 | def track_ids(self) -> List[int]: 47 | return list(sorted(self._track_category_ids.keys())) 48 | 49 | @property 50 | def track_category_ids(self) -> Dict[int, int]: 51 | return { 52 | track_id: self._track_category_ids[track_id] 53 | for track_id in self.track_ids 54 | } 55 | 56 | def filter_dataset_for_benchmark(self, masks_per_frame): 57 | # First, filter low quality frames and frames that do not contain more than two different instances of the same object 58 | global_relevant_tracks = [key for key, value in self.track_category_ids.items() if 59 | list(self.track_category_ids.values()).count(value) > 1] 60 | tracks_count = dict.fromkeys(global_relevant_tracks, 0) 61 | #tracks_bbox_size_per_frame = dict.fromkeys(global_relevant_tracks, list()) 62 | tracks_per_frame = dict.fromkeys(global_relevant_tracks, list()) 63 | # For each frame, first check if the frame contains multi-instance of the same object, then, proceed to 64 | # calculate #occurance for each track-id and it's size. 65 | for i in range(len(masks_per_frame)): 66 | curr_id_categories = dict(zip(self.segmentations[i].keys(), 67 | list(map(self.track_category_ids.get, self.segmentations[i].keys())))) 68 | curr_relevant_tracks = [key for key, value in curr_id_categories.items() if 69 | list(curr_id_categories.values()).count(value) > 1] 70 | if len(curr_relevant_tracks) >= 2: 71 | #potential_frames += [i] 72 | frame_masks = masks_per_frame[i] 73 | frame_size = list(frame_masks.values())[0].size 74 | tracks_count.update( 75 | zip(curr_relevant_tracks, list(map(lambda x: tracks_count.get(x) + 1, curr_relevant_tracks)))) 76 | for t_id in curr_relevant_tracks: 77 | xmin, ymin, xmax, ymax = [int(x) for x in bbox_from_mask(frame_masks[t_id])] 78 | relative_size = (abs(ymax - ymin) * abs(xmax - xmin) / frame_size) 79 | #relative_size = frame_masks[t_id].sum() / frame_size 80 | tracks_per_frame[t_id] = tracks_per_frame[t_id] + [(i, relative_size)] 81 | #tracks_per_frame.update( 82 | # zip(curr_relevant_tracks, 83 | # list(map(lambda x: tracks_per_frame.get(x) + [(i, relative_size)], curr_relevant_tracks)))) 84 | #tracks_bbox_size_per_frame[t_id] = tracks_bbox_size_per_frame[t_id] + [ 85 | # (abs(ymax - ymin) * abs(xmax - xmin) / frame_size)] 86 | 87 | # Take the instance which appeared the most during the video 88 | max_inst = max(tracks_count, key=tracks_count.get) 89 | # Select n frames, where the instance appeared the largest, make sure the frames are far apart. 
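# Concretely: candidate frames are sorted by the instance's relative size (largest first); the largest frame is kept, and up to two further frames are added only if they lie more than 15 frames away from it, giving at most 3 selected frames.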
90 |         inst_size_per_frame = dict(tracks_per_frame[max_inst])
91 |         sorted_frames = np.array(list(dict(sorted(inst_size_per_frame.items(), key=lambda item: item[1], reverse=True)).keys()))
92 |         first_frame = sorted_frames[0]
93 |         final_frames = [first_frame] + list(sorted_frames[np.where(abs(sorted_frames - first_frame) > 15)[0][:2]])
94 |         final_masks = []
95 |         for f in final_frames:
96 |             final_masks += [{max_inst: masks_per_frame[f][max_inst]}]
97 |         return final_frames, final_masks
98 | 
99 |     def get_image_paths(self, frame_indices: Optional[List[int]] = None) -> List[str]:
100 |         """
101 |         Get file paths to all image frames
102 |         :param frame_indices: Optional argument specifying list of frame indices to load. All indices should satisfy
103 |                               0 <= t < self.num_annotated_frames
104 |         :return: List of file paths
105 |         """
106 |         if frame_indices is None:
107 |             frame_indices = list(range(self.num_annotated_frames))
108 |         else:
109 |             assert all([0 <= t < self.num_annotated_frames for t in frame_indices]), f"One or more frame indices are " \
110 |                                                                                       f"invalid"
111 | 
112 |         return [osp.join(self._images_dir, self.annotated_image_paths[t]) for t in frame_indices]
113 | 
114 |     def load_images(self, frame_indices: Optional[List[int]] = None) -> List[np.ndarray]:
115 |         """
116 |         Load annotated image frames for the video
117 |         :param frame_indices: Optional argument specifying list of frame indices to load. All indices should satisfy
118 |                               0 <= t < self.num_annotated_frames
119 |         :return: List of images as numpy arrays of dtype uint8 and shape [H, W, 3] (RGB)
120 |         """
121 |         assert self._images_dir is not None, f"Images cannot be loaded because 'images_dir' is None"
122 | 
123 |         if frame_indices is None:
124 |             frame_indices = list(range(self.num_annotated_frames))
125 |         else:
126 |             assert all([0 <= t < self.num_annotated_frames for t in frame_indices]), f"One or more frame indices are " \
127 |                                                                                       f"invalid"
128 | 
129 |         images = []
130 | 
131 |         for t in frame_indices:
132 |             filepath = osp.join(self._images_dir, self.annotated_image_paths[t])
133 |             assert osp.exists(filepath), f"Image file not found: '{filepath}'"
134 |             images.append(cv2.imread(filepath, cv2.IMREAD_COLOR)[:, :, ::-1])  # convert BGR to RGB
135 | 
136 |         return images
137 | 
138 |     def images_paths(self, frame_indices: Optional[List[int]] = None) -> List[str]:
139 |         """
140 |         Get file paths to the annotated image frames of the video
141 |         :param frame_indices: Optional argument specifying list of frame indices to load. All indices should satisfy
142 |                               0 <= t < self.num_annotated_frames
143 |         :return: List of file paths
144 |         """
145 |         assert self._images_dir is not None, f"Images cannot be loaded because 'images_dir' is None"
146 | 
147 |         if frame_indices is None:
148 |             frame_indices = list(range(self.num_annotated_frames))
149 |         else:
150 |             assert all([0 <= t < self.num_annotated_frames for t in frame_indices]), f"One or more frame indices are " \
151 |                                                                                       f"invalid"
152 | 
153 |         paths = []
154 | 
155 |         for t in frame_indices:
156 |             filepath = osp.join(self._images_dir, self.annotated_image_paths[t])
157 |             assert osp.exists(filepath), f"Image file not found: '{filepath}'"
158 |             #images.append(cv2.imread(filepath, cv2.IMREAD_COLOR)[:, :, ::-1])  # convert BGR to RGB
159 |             paths += [filepath]
160 | 
161 |         return paths
162 | 
163 |     def load_masks(self, frame_indices: Optional[List[int]] = None) -> List[Dict[int, np.ndarray]]:
164 |         """
165 |         Decode RLE masks into mask images
166 |         :param frame_indices: Optional argument specifying list of frame indices to load. All indices should satisfy
167 |                               0 <= t < self.num_annotated_frames
168 |         :return: List of dicts (one per frame). Each dict has track IDs as keys and mask images as values.
169 |         """
170 |         if frame_indices is None:
171 |             frame_indices = list(range(self.num_annotated_frames))
172 |         else:
173 |             assert all([0 <= t < self.num_annotated_frames for t in frame_indices]), f"One or more frame indices are " \
174 |                                                                                       f"invalid"
175 | 
176 |         zero_mask = np.zeros(self.image_size, bool)
177 |         masks = []
178 | 
179 |         for t in frame_indices:
180 |             masks_t = dict()
181 | 
182 |             for track_id in self.track_ids:
183 |                 if track_id in self.segmentations[t]:
184 |                     masks_t[track_id] = utils.rle_ann_to_mask(self.segmentations[t][track_id]["rle"], self.image_size)
185 |                 else:
186 |                     masks_t[track_id] = zero_mask
187 | 
188 |             masks.append(masks_t)
189 | 
190 |         return masks
191 | 
192 |     def filter_category_ids(self, category_ids_to_keep: List[int]):
193 |         track_ids_to_keep = [
194 |             track_id for track_id, category_id in self._track_category_ids.items()
195 |             if category_id in category_ids_to_keep
196 |         ]
197 | 
198 |         self._track_category_ids = {
199 |             track_id: category_id for track_id, category_id in self._track_category_ids.items()
200 |             if track_id in track_ids_to_keep
201 |         }
202 | 
203 |         for t in range(self.num_annotated_frames):
204 |             self.segmentations[t] = {
205 |                 track_id: seg for track_id, seg in self.segmentations[t].items()
206 |                 if track_id in track_ids_to_keep
207 |             }
208 | 
209 |     def stats(self) -> Dict[str, Any]:
210 |         total_masks = 0
211 |         for segs_t in self.segmentations:
212 |             total_masks += len(segs_t)
213 | 
214 |         return {
215 |             "Annotated frames": self.num_annotated_frames,
216 |             "Object tracks": len(self.track_ids),
217 |             "Object masks": total_masks,
218 |             "Unique category IDs": list(set(self.track_category_ids.values()))
219 |         }
220 | 
221 |     def load_first_frame_annotations(self) -> List[Dict[int, Dict[str, Any]]]:
222 |         annotations = []
223 |         for t in range(self.num_annotated_frames):
224 |             annotations_t = dict()
225 | 
226 |             for track_id, annotation in self.segmentations[t].items():
227 |                 annotations_t[track_id] = {
228 |                     "mask": utils.rle_ann_to_mask(annotation["rle"], self.image_size),
229 |                     "bbox": annotation["bbox"],  # xywh format
230 |                     "point": annotation["point"]
231 |                 }
232 | 
233 |             annotations.append(annotations_t)
234 | 
235 |         return annotations
236 | 
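Taken together, the helpers above implement the frame and instance selection behind PerMIRS: pick a category that appears with at least two instances, keep the track that occurs most often, and choose frames where that instance is largest and at least 15 frames apart. The snippet below is a minimal usage sketch, not repository code; it assumes `video` is an already-constructed `BURSTVideo` with a valid `images_dir`.

# Usage sketch (illustrative; assumes `video` is a BURSTVideo with images_dir set).
masks_per_frame = video.load_masks()        # one {track_id: bool mask} dict per annotated frame
print(video.stats())

# May raise if no category has at least two co-occurring instances,
# so callers may want to wrap this in a try/except.
frames_idx, masks = video.filter_dataset_for_benchmark(masks_per_frame)
paths = video.images_paths(frames_idx)      # paths of the selected, well-separated frames

for path, frame_masks in zip(paths, masks):
    (track_id, mask), = frame_masks.items() # each dict holds the single selected track
    print(path, "track", track_id, "mask area:", int(mask.sum()))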
--------------------------------------------------------------------------------
/StableDiffusionPipelineWithDDIMInversion.py:
--------------------------------------------------------------------------------
1 | # Code adapted from https://github.com/mkshing/svdiff-pytorch
2 | from typing import Any, Callable, Dict, List, Optional, Union
3 | import PIL
4 | import torch
5 | from diffusers import StableDiffusionPipeline, DDIMInverseScheduler
6 | from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import preprocess
7 | from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_pix2pix_zero import Pix2PixInversionPipelineOutput
8 | from diffusers.schedulers.scheduling_ddim_inverse import DDIMSchedulerOutput
9 | import torch.nn.functional as F  # needed for F.avg_pool2d in auto_corr_loss
10 | 
11 | 
12 | class StableDiffusionPipelineWithDDIMInversion(StableDiffusionPipeline):
13 |     def __init__(self, vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor,
14 |                  requires_safety_checker: bool = True):
15 |         super().__init__(vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor,
16 |                          requires_safety_checker)
17 |         self.inverse_scheduler = DDIMInverseScheduler.from_config(self.scheduler.config)
18 |         # self.register_modules(inverse_scheduler=DDIMInverseScheduler.from_config(self.scheduler.config))
19 | 
20 |     def prepare_image_latents(self, image, batch_size, dtype, device, generator=None):
21 |         if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
22 |             raise ValueError(
23 |                 f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
24 |             )
25 | 
26 |         image = image.to(device=device, dtype=dtype)
27 | 
28 |         if isinstance(generator, list) and len(generator) != batch_size:
29 |             raise ValueError(
30 |                 f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
31 |                 f" size of {batch_size}. Make sure the batch size matches the length of the generators."
32 |             )
33 | 
34 |         if isinstance(generator, list):
35 |             init_latents = [
36 |                 self.vae.encode(image[i: i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
37 |             ]
38 |             init_latents = torch.cat(init_latents, dim=0)
39 |         else:
40 |             init_latents = self.vae.encode(image).latent_dist.sample(generator)
41 | 
42 |         init_latents = self.vae.config.scaling_factor * init_latents
43 | 
44 |         if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
45 |             raise ValueError(
46 |                 f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
47 | ) 48 | else: 49 | init_latents = torch.cat([init_latents], dim=0) 50 | 51 | latents = init_latents 52 | 53 | return latents 54 | 55 | def get_epsilon(self, model_output: torch.Tensor, sample: torch.Tensor, timestep: int): 56 | pred_type = self.inverse_scheduler.config.prediction_type 57 | alpha_prod_t = self.inverse_scheduler.alphas_cumprod[timestep] 58 | 59 | beta_prod_t = 1 - alpha_prod_t 60 | 61 | if pred_type == "epsilon": 62 | return model_output 63 | elif pred_type == "sample": 64 | return (sample - alpha_prod_t ** (0.5) * model_output) / beta_prod_t ** (0.5) 65 | elif pred_type == "v_prediction": 66 | return (alpha_prod_t ** 0.5) * model_output + (beta_prod_t ** 0.5) * sample 67 | else: 68 | raise ValueError( 69 | f"prediction_type given as {pred_type} must be one of `epsilon`, `sample`, or `v_prediction`" 70 | ) 71 | 72 | def auto_corr_loss(self, hidden_states, generator=None): 73 | batch_size, channel, height, width = hidden_states.shape 74 | if batch_size > 1: 75 | raise ValueError("Only batch_size 1 is supported for now") 76 | 77 | hidden_states = hidden_states.squeeze(0) 78 | # hidden_states must be shape [C,H,W] now 79 | reg_loss = 0.0 80 | for i in range(hidden_states.shape[0]): 81 | noise = hidden_states[i][None, None, :, :] 82 | while True: 83 | roll_amount = torch.randint(noise.shape[2] // 2, (1,), generator=generator).item() 84 | reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=2)).mean() ** 2 85 | reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=3)).mean() ** 2 86 | 87 | if noise.shape[2] <= 8: 88 | break 89 | noise = F.avg_pool2d(noise, kernel_size=2) 90 | return reg_loss 91 | 92 | def kl_divergence(self, hidden_states): 93 | mean = hidden_states.mean() 94 | var = hidden_states.var() 95 | return var + mean ** 2 - 1 - torch.log(var + 1e-7) 96 | 97 | def scheduler_next_step(self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor, return_dict: bool = True): 98 | # 1. get next step value (=t+1) 99 | next_timestep = timestep 100 | timestep = min( 101 | timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps, 999) 102 | # 2. compute alphas, betas 103 | # change original implementation to exactly match noise levels for analogous forward process 104 | alpha_prod_t = self.scheduler.alphas_cumprod[timestep] if timestep >= 0 else self.scheduler.final_alpha_cumprod 105 | alpha_prod_t_next = self.scheduler.alphas_cumprod[next_timestep] 106 | 107 | beta_prod_t = 1 - alpha_prod_t 108 | 109 | next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5 110 | pred_epsilon = model_output 111 | 112 | # 5. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf 113 | next_sample_direction = (1 - alpha_prod_t_next) ** (0.5) * pred_epsilon 114 | 115 | # 6. 
compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf 116 | next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction 117 | 118 | if not return_dict: 119 | return (next_sample, next_sample_direction) 120 | return DDIMSchedulerOutput(prev_sample=next_sample, pred_original_sample=next_sample_direction) 121 | 122 | def decode_to_tensor(self, latents): 123 | latents = 1 / self.vae.config.scaling_factor * latents 124 | image = self.vae.decode(latents).sample 125 | return image 126 | 127 | # based on https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py#L1063 128 | @torch.no_grad() 129 | def invert( 130 | self, 131 | prompt: Optional[str] = None, 132 | image: Union[torch.FloatTensor, PIL.Image.Image] = None, 133 | num_inference_steps: int = 50, 134 | guidance_scale: float = 1, 135 | num_images_per_prompt = 1, 136 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 137 | prompt_embeds: Optional[torch.FloatTensor] = None, 138 | output_type: Optional[str] = "pil", 139 | return_dict: bool = True, 140 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, 141 | callback_steps: Optional[int] = 1, 142 | cross_attention_kwargs: Optional[Dict[str, Any]] = None, 143 | lambda_auto_corr: float = 20.0, 144 | lambda_kl: float = 20.0, 145 | num_reg_steps: int = 0, # disabled 146 | num_auto_corr_rolls: int = 5, 147 | ): 148 | # 1. Define call parameters 149 | if prompt is not None and isinstance(prompt, str): 150 | batch_size = 1 151 | elif prompt is not None and isinstance(prompt, list): 152 | batch_size = len(prompt) 153 | else: 154 | batch_size = prompt_embeds.shape[0] 155 | if cross_attention_kwargs is None: 156 | cross_attention_kwargs = {} 157 | 158 | device = self._execution_device 159 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) 160 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` 161 | # corresponds to doing no classifier free guidance. 162 | do_classifier_free_guidance = guidance_scale > 1.0 163 | 164 | # 3. Preprocess image 165 | image = preprocess(image) 166 | 167 | # 4. Prepare latent variables 168 | latents = self.prepare_image_latents(image, batch_size, self.vae.dtype, device, generator) 169 | 170 | # 5. Encode input prompt 171 | #num_images_per_prompt = num_images_per_prompt 172 | prompt_embeds = self._encode_prompt( 173 | prompt, 174 | device, 175 | num_images_per_prompt, 176 | do_classifier_free_guidance, 177 | prompt_embeds=prompt_embeds, 178 | ) 179 | 180 | # 4. Prepare timesteps 181 | self.inverse_scheduler.set_timesteps(num_inference_steps, device=device) 182 | self.scheduler.set_timesteps(num_inference_steps, device=device) 183 | timesteps = self.inverse_scheduler.timesteps 184 | 185 | # 7. Denoising loop where we obtain the cross-attention maps. 
186 | num_warmup_steps = len(timesteps) - num_inference_steps * self.inverse_scheduler.order 187 | with self.progress_bar(total=num_inference_steps) as progress_bar: 188 | for i, t in enumerate(timesteps): 189 | # expand the latents if we are doing classifier free guidance 190 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents 191 | latent_model_input = self.inverse_scheduler.scale_model_input(latent_model_input, t) 192 | 193 | # predict the noise residual 194 | noise_pred = self.unet( 195 | latent_model_input, 196 | t, 197 | encoder_hidden_states=prompt_embeds, 198 | cross_attention_kwargs=cross_attention_kwargs, 199 | ).sample 200 | 201 | # perform guidance 202 | if do_classifier_free_guidance: 203 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 204 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) 205 | 206 | # regularization of the noise prediction 207 | with torch.enable_grad(): 208 | for _ in range(num_reg_steps): 209 | if lambda_auto_corr > 0: 210 | for _ in range(num_auto_corr_rolls): 211 | var = torch.autograd.Variable(noise_pred.detach().clone(), requires_grad=True) 212 | 213 | # Derive epsilon from model output before regularizing to IID standard normal 214 | var_epsilon = self.get_epsilon(var, latent_model_input.detach(), t) 215 | 216 | l_ac = self.auto_corr_loss(var_epsilon, generator=generator) 217 | l_ac.backward() 218 | 219 | grad = var.grad.detach() / num_auto_corr_rolls 220 | noise_pred = noise_pred - lambda_auto_corr * grad 221 | 222 | if lambda_kl > 0: 223 | var = torch.autograd.Variable(noise_pred.detach().clone(), requires_grad=True) 224 | 225 | # Derive epsilon from model output before regularizing to IID standard normal 226 | var_epsilon = self.get_epsilon(var, latent_model_input.detach(), t) 227 | 228 | l_kld = self.kl_divergence(var_epsilon) 229 | l_kld.backward() 230 | 231 | grad = var.grad.detach() 232 | noise_pred = noise_pred - lambda_kl * grad 233 | 234 | noise_pred = noise_pred.detach() 235 | 236 | # compute the previous noisy sample x_t -> x_t-1 237 | # latents = self.inverse_scheduler.step(noise_pred, t, latents).prev_sample 238 | latents = self.scheduler_next_step(noise_pred, t, latents).prev_sample 239 | 240 | # call the callback, if provided 241 | if i == len(timesteps) - 1 or ( 242 | (i + 1) > num_warmup_steps and (i + 1) % self.inverse_scheduler.order == 0 243 | ): 244 | progress_bar.update() 245 | if callback is not None and i % callback_steps == 0: 246 | callback(i, t, latents) 247 | 248 | inverted_latents = latents.detach().clone() 249 | 250 | # 8. Post-processing 251 | image = self.decode_latents(latents.detach()) 252 | 253 | # Offload last model to CPU 254 | if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: 255 | self.final_offload_hook.offload() 256 | 257 | # 9. Convert to PIL. 
258 | if output_type == "pil": 259 | image = self.numpy_to_pil(image) 260 | 261 | if not return_dict: 262 | return (inverted_latents, image) 263 | 264 | return Pix2PixInversionPipelineOutput(latents=inverted_latents, images=image) -------------------------------------------------------------------------------- /dift.py: -------------------------------------------------------------------------------- 1 | import gc 2 | 3 | import numpy as np 4 | import torch 5 | from typing import Any, Callable, Dict, List, Optional, Union, Tuple 6 | from diffusers import DDIMScheduler 7 | from diffusers import StableDiffusionPipeline 8 | from torch import nn 9 | import torch.nn.functional as F 10 | from diffusers.models.unet_2d_condition import UNet2DConditionModel 11 | 12 | class MyUNet2DConditionModel(UNet2DConditionModel): 13 | def forward( 14 | self, 15 | sample: torch.FloatTensor, 16 | timestep: Union[torch.Tensor, float, int], 17 | up_ft_indices, 18 | encoder_hidden_states: torch.Tensor, 19 | class_labels: Optional[torch.Tensor] = None, 20 | timestep_cond: Optional[torch.Tensor] = None, 21 | attention_mask: Optional[torch.Tensor] = None, 22 | cross_attention_kwargs: Optional[Dict[str, Any]] = None): 23 | r""" 24 | Args: 25 | sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor 26 | timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps 27 | encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states 28 | cross_attention_kwargs (`dict`, *optional*): 29 | A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under 30 | `self.processor` in 31 | [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). 32 | """ 33 | # By default samples have to be AT least a multiple of the overall upsampling factor. 34 | # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). 35 | # However, the upsampling interpolation output size can be forced to fit any upsampling size 36 | # on the fly if necessary. 37 | default_overall_up_factor = 2 ** self.num_upsamplers 38 | 39 | # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` 40 | forward_upsample_size = False 41 | upsample_size = None 42 | 43 | if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): 44 | # logger.info("Forward upsample size to force interpolation output size.") 45 | forward_upsample_size = True 46 | 47 | # prepare attention_mask 48 | if attention_mask is not None: 49 | attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 50 | attention_mask = attention_mask.unsqueeze(1) 51 | 52 | # 0. center input if necessary 53 | if self.config.center_input_sample: 54 | sample = 2 * sample - 1.0 55 | 56 | # 1. time 57 | timesteps = timestep 58 | if not torch.is_tensor(timesteps): 59 | # TODO: this requires sync between CPU and GPU. 
So try to pass timesteps as tensors if you can 60 | # This would be a good case for the `match` statement (Python 3.10+) 61 | is_mps = sample.device.type == "mps" 62 | if isinstance(timestep, float): 63 | dtype = torch.float32 if is_mps else torch.float64 64 | else: 65 | dtype = torch.int32 if is_mps else torch.int64 66 | timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) 67 | elif len(timesteps.shape) == 0: 68 | timesteps = timesteps[None].to(sample.device) 69 | 70 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML 71 | timesteps = timesteps.expand(sample.shape[0]) 72 | 73 | t_emb = self.time_proj(timesteps) 74 | 75 | # timesteps does not contain any weights and will always return f32 tensors 76 | # but time_embedding might actually be running in fp16. so we need to cast here. 77 | # there might be better ways to encapsulate this. 78 | t_emb = t_emb.to(dtype=self.dtype) 79 | 80 | emb = self.time_embedding(t_emb, timestep_cond) 81 | 82 | if self.class_embedding is not None: 83 | if class_labels is None: 84 | raise ValueError("class_labels should be provided when num_class_embeds > 0") 85 | 86 | if self.config.class_embed_type == "timestep": 87 | class_labels = self.time_proj(class_labels) 88 | 89 | class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) 90 | emb = emb + class_emb 91 | 92 | # 2. pre-process 93 | sample = self.conv_in(sample) 94 | 95 | # 3. down 96 | down_block_res_samples = (sample,) 97 | for downsample_block in self.down_blocks: 98 | if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: 99 | sample, res_samples = downsample_block( 100 | hidden_states=sample, 101 | temb=emb, 102 | encoder_hidden_states=encoder_hidden_states, 103 | attention_mask=attention_mask, 104 | cross_attention_kwargs=cross_attention_kwargs, 105 | ) 106 | else: 107 | sample, res_samples = downsample_block(hidden_states=sample, temb=emb) 108 | 109 | down_block_res_samples += res_samples 110 | 111 | # 4. mid 112 | if self.mid_block is not None: 113 | sample = self.mid_block( 114 | sample, 115 | emb, 116 | encoder_hidden_states=encoder_hidden_states, 117 | attention_mask=attention_mask, 118 | cross_attention_kwargs=cross_attention_kwargs, 119 | ) 120 | 121 | # 5. 
up 122 | up_ft = {} 123 | for i, upsample_block in enumerate(self.up_blocks): 124 | 125 | if i > np.max(up_ft_indices): 126 | break 127 | 128 | is_final_block = i == len(self.up_blocks) - 1 129 | 130 | res_samples = down_block_res_samples[-len(upsample_block.resnets):] 131 | down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] 132 | 133 | # if we have not reached the final block and need to forward the 134 | # upsample size, we do it here 135 | if not is_final_block and forward_upsample_size: 136 | upsample_size = down_block_res_samples[-1].shape[2:] 137 | 138 | if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: 139 | sample = upsample_block( 140 | hidden_states=sample, 141 | temb=emb, 142 | res_hidden_states_tuple=res_samples, 143 | encoder_hidden_states=encoder_hidden_states, 144 | cross_attention_kwargs=cross_attention_kwargs, 145 | upsample_size=upsample_size, 146 | attention_mask=attention_mask, 147 | ) 148 | else: 149 | sample = upsample_block( 150 | hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size 151 | ) 152 | 153 | if i in up_ft_indices: 154 | up_ft[i] = sample.detach() 155 | 156 | output = {} 157 | output['up_ft'] = up_ft 158 | return output 159 | 160 | 161 | class OneStepSDPipeline(StableDiffusionPipeline): 162 | @torch.no_grad() 163 | def __call__( 164 | self, 165 | img_tensor, 166 | t, 167 | up_ft_indices, 168 | negative_prompt: Optional[Union[str, List[str]]] = None, 169 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 170 | prompt_embeds: Optional[torch.FloatTensor] = None, 171 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, 172 | callback_steps: int = 1, 173 | cross_attention_kwargs: Optional[Dict[str, Any]] = None 174 | ): 175 | device = self._execution_device 176 | latents = self.vae.encode(img_tensor).latent_dist.sample() * self.vae.config.scaling_factor 177 | t = torch.tensor(t, dtype=torch.long, device=device) 178 | noise = torch.randn_like(latents).to(device) 179 | latents_noisy = self.scheduler.add_noise(latents, noise, t) 180 | unet_output = self.unet(latents_noisy, 181 | t, 182 | up_ft_indices, 183 | encoder_hidden_states=prompt_embeds, 184 | cross_attention_kwargs=cross_attention_kwargs) 185 | return unet_output 186 | 187 | 188 | class SDFeaturizer: 189 | def __init__(self, sd_model): 190 | unet = MyUNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet") 191 | onestep_pipe = OneStepSDPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", unet=unet, 192 | safety_checker=None) 193 | onestep_pipe.vae.decoder = None 194 | onestep_pipe.scheduler = sd_model.scheduler 195 | 196 | gc.collect() 197 | onestep_pipe = onestep_pipe.to("cuda") 198 | onestep_pipe.enable_attention_slicing() 199 | onestep_pipe.enable_xformers_memory_efficient_attention() 200 | self.pipe = onestep_pipe 201 | 202 | @torch.no_grad() 203 | def forward(self, 204 | img_tensor, # single image, [1,c,h,w] 205 | prompt, 206 | t=261, 207 | up_ft_index=1, 208 | ensemble_size=8): 209 | img_tensor = img_tensor.repeat(ensemble_size, 1, 1, 1).cuda() # ensem, c, h, w 210 | 211 | prompt_embeds = self.pipe._encode_prompt( 212 | prompt=prompt, 213 | device='cuda', 214 | num_images_per_prompt=1, 215 | do_classifier_free_guidance=False) # [1, 77, dim] 216 | 217 | prompt_embeds = prompt_embeds.repeat(ensemble_size, 1, 1) 218 | torch.manual_seed(42) 219 | unet_ft_all = self.pipe( 220 | img_tensor=img_tensor, 221 | t=t, 
222 | up_ft_indices=[up_ft_index], 223 | prompt_embeds=prompt_embeds) 224 | unet_ft = unet_ft_all['up_ft'][up_ft_index] # ensem, c, h, w 225 | unet_ft = unet_ft.mean(0, keepdim=True) # 1,c,h,w 226 | return unet_ft 227 | 228 | 229 | def get_correspondences(src_ft, tgt_ft, ref_bbox, img_size=512, topk=10): 230 | num_channel = src_ft.shape[1] 231 | ref_object_start_token = np.ravel_multi_index([(ref_bbox[1]), (ref_bbox[0])], (img_size, img_size)) 232 | bbox_w, bbox_h = ref_bbox[2] - ref_bbox[0], ref_bbox[3] - ref_bbox[1] 233 | with torch.no_grad(): 234 | src_ft = nn.Upsample(size=(img_size, img_size), mode='bilinear')(src_ft) 235 | src_vec = F.normalize(src_ft.view(num_channel, -1)) # C, HW 236 | tgt_ft = nn.Upsample(size=(img_size, img_size), mode='bilinear')(tgt_ft) 237 | trg_vec = F.normalize(tgt_ft.view(num_channel, -1)) # C, HW 238 | # For efficient computation on high-memory GPUs, process all tokens simultaneously rather than per row. 239 | all_ref_points = [] 240 | all_tgt_points = [] 241 | all_cosine_similarity = [] 242 | for i in range(bbox_h): 243 | curr_ref_tokens = list(range(ref_object_start_token + img_size*i, ref_object_start_token + img_size*i+bbox_w+1)) 244 | all_ref_points += [np.array(np.unravel_index(curr_ref_tokens, shape=(img_size, img_size))).T] 245 | cos_map = torch.matmul(src_vec.T[curr_ref_tokens], trg_vec) 246 | #tgt_tokens = cos_map.argmax(dim=-1).cpu().numpy() 247 | res = cos_map.topk(k=topk,dim=-1) 248 | all_cosine_similarity += [res[0].cpu().numpy()] 249 | tgt_tokens = res[1].cpu().numpy() 250 | all_tgt_points += [np.array(np.unravel_index(tgt_tokens, shape=(img_size, img_size))).T] 251 | 252 | return np.concatenate(all_ref_points), np.concatenate(all_tgt_points, axis=1).reshape(-1,topk,2), np.concatenate(all_cosine_similarity).reshape(-1) 253 | 254 | def get_correspondences_seg(src_ft, tgt_ft, src_mask, img_size=512, topk=10): 255 | num_channel = src_ft.shape[1] 256 | with torch.no_grad(): 257 | src_ft = nn.Upsample(size=(img_size, img_size), mode='bilinear')(src_ft) 258 | src_vec = F.normalize(src_ft.view(num_channel, -1)) # C, HW 259 | tgt_ft = nn.Upsample(size=(img_size, img_size), mode='bilinear')(tgt_ft) 260 | trg_vec = F.normalize(tgt_ft.view(num_channel, -1)) # C, HW 261 | # For efficient computation on high-memory GPUs, process all tokens simultaneously rather than per row. 262 | all_tgt_points = [] 263 | all_cosine_similarity = [] 264 | all_ref_points = np.column_stack(src_mask.nonzero()) 265 | for p in all_ref_points: 266 | curr_ref_tokens = np.ravel_multi_index([(p[1]), (p[0])], (img_size, img_size)) 267 | cos_map = torch.matmul(src_vec.T[curr_ref_tokens], trg_vec) 268 | #tgt_tokens = cos_map.argmax(dim=-1).cpu().numpy() 269 | res = cos_map.topk(k=topk,dim=-1) 270 | all_cosine_similarity += [res[0].cpu().numpy()] 271 | tgt_tokens = res[1].cpu().numpy() 272 | all_tgt_points += [np.array(np.unravel_index(tgt_tokens, shape=(img_size, img_size))).T] 273 | 274 | return all_ref_points, np.concatenate(all_tgt_points, axis=1).reshape(-1,topk,2), np.concatenate(all_cosine_similarity).reshape(-1) 275 | --------------------------------------------------------------------------------
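As a usage illustration (not code from the repository), the sketch below chains `SDFeaturizer` with `get_correspondences_seg` from dift.py: featurize a source and a target image with a one-step SD pass, then match every masked source pixel to its top-k most similar target pixels by cosine similarity. The preprocessing helper, the image paths, the prompt, and the placeholder mask are assumptions made for illustration only; the sketch also assumes a CUDA device with xformers installed, since `SDFeaturizer` moves its pipeline to `cuda` and enables memory-efficient attention, and note that upsampling features to 512x512 inside the correspondence function is memory-heavy.

import numpy as np
import torch
from PIL import Image
from diffusers import DDIMScheduler, StableDiffusionPipeline
from dift import SDFeaturizer, get_correspondences_seg

# Hypothetical preprocessing: resize to 512x512 and scale to [-1, 1] before VAE encoding.
def to_tensor(path, size=512):
    img = Image.open(path).convert("RGB").resize((size, size))
    x = torch.from_numpy(np.array(img)).float() / 255.0
    return x.permute(2, 0, 1).unsqueeze(0) * 2 - 1      # [1, 3, H, W]

model_id = "CompVis/stable-diffusion-v1-4"
sd_model = StableDiffusionPipeline.from_pretrained(
    model_id, safety_checker=None,
    scheduler=DDIMScheduler.from_pretrained(model_id, subfolder="scheduler"))

featurizer = SDFeaturizer(sd_model)                      # exposes ensemble-averaged up-block features
src_ft = featurizer.forward(to_tensor("dogs/source.png"), prompt="a photo of a dog")
tgt_ft = featurizer.forward(to_tensor("dogs/target1.png"), prompt="a photo of a dog")

# Placeholder binary mask over the reference object at 512x512 resolution;
# in practice this would come from an annotation or a SAM prediction.
src_mask = np.zeros((512, 512), dtype=bool)
src_mask[200:300, 200:300] = True

ref_pts, tgt_pts, sims = get_correspondences_seg(src_ft, tgt_ft, src_mask, img_size=512, topk=10)
print(ref_pts.shape, tgt_pts.shape, sims.shape)          # (N, 2), (N, 10, 2), (N * 10,)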