├── dogs
│   ├── source.png
│   └── target1.png
├── PerMIRS
│   ├── utils.py
│   ├── eval_miou.py
│   ├── dataset.py
│   ├── permirs_gen_dataset.py
│   ├── visualization_utils.py
│   ├── extract_diff_features.py
│   └── video.py
├── attention_store.py
├── sam_utils.py
├── README.md
├── pdm_matching.py
├── pdm_permir.py
├── ptp_utils.py
├── pdm_permis.py
├── StableDiffusionPipelineWithDDIMInversion.py
└── dift.py
/dogs/source.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvirsamuel/PDM/HEAD/dogs/source.png
--------------------------------------------------------------------------------
/dogs/target1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvirsamuel/PDM/HEAD/dogs/target1.png
--------------------------------------------------------------------------------
/PerMIRS/utils.py:
--------------------------------------------------------------------------------
1 | # Code from https://github.com/Ali2500/BURST-benchmark
2 | from typing import Dict, List, Any, Tuple
3 |
4 | import numpy as np
5 | import pycocotools.mask as cocomask
6 |
7 |
8 | def intify_track_ids(video_dict: Dict[str, Any]):
9 | video_dict["track_category_ids"] = {
10 | int(track_id): category_id for track_id, category_id in video_dict["track_category_ids"].items()
11 | }
12 |
13 | for t in range(len(video_dict["segmentations"])):
14 | video_dict["segmentations"][t] = {
15 | int(track_id): seg
16 | for track_id, seg in video_dict["segmentations"][t].items()
17 | }
18 |
19 | return video_dict
20 |
21 |
22 | def rle_ann_to_mask(rle: str, image_size: Tuple[int, int]) -> np.ndarray:
23 | return cocomask.decode({
24 | "size": image_size,
25 | "counts": rle.encode("utf-8")
26 | }).astype(bool)
27 |
28 |
29 | def mask_to_rle_ann(mask: np.ndarray) -> Dict[str, Any]:
30 | assert mask.ndim == 2, f"Mask must be a 2-D array, but got array of shape {mask.shape}"
31 | rle = cocomask.encode(np.asfortranarray(mask.astype(np.uint8)))
32 | rle["counts"] = rle["counts"].decode("utf-8")
33 | return rle
--------------------------------------------------------------------------------
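
A minimal round-trip sketch for the two RLE helpers above, assuming it is run from the PerMIRS directory so the file imports as `utils`; the toy mask is purely illustrative.

```
import numpy as np
from utils import mask_to_rle_ann, rle_ann_to_mask

# Encode a toy binary mask to COCO RLE, decode it back, and verify the round trip.
mask = np.zeros((4, 6), dtype=bool)
mask[1:3, 2:5] = True

rle = mask_to_rle_ann(mask)                            # {"size": [4, 6], "counts": "..."}
decoded = rle_ann_to_mask(rle["counts"], rle["size"])  # back to a boolean (4, 6) array
assert np.array_equal(mask, decoded)
```
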
/PerMIRS/eval_miou.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import argparse
3 |
4 |
5 |
6 | def get_arguments():
7 | parser = argparse.ArgumentParser()
8 |
9 | parser.add_argument('--pred_path', type=str, default='')
10 | parser.add_argument('--gt_path', type=str, default='/inputs/Projects/PerDet/datasets/PerSeg/Annotations')
11 |
12 | parser.add_argument('--ref_idx', type=str, default='00')
13 |
14 | args = parser.parse_args()
15 | return args
16 |
17 |
18 | class AverageMeter(object):
19 | """Computes and stores the average and current value"""
20 |
21 | def __init__(self):
22 | self.reset()
23 |
24 | def reset(self):
25 | self.val = 0
26 | self.avg = 0
27 | self.sum = 0
28 | self.count = 0
29 |
30 | def update(self, val, n=1):
31 | self.val = val
32 | self.sum += val * n
33 | self.count += n
34 | self.avg = self.sum / self.count
35 |
36 |
37 | def intersectionAndUnion(output, target):
38 | assert (output.ndim in [1, 2, 3])
39 | assert output.shape == target.shape
40 | output = output.reshape(output.size).copy()
41 | target = target.reshape(target.size)
42 |
43 | area_intersection = np.logical_and(output, target).sum()
44 | area_union = np.logical_or(output, target).sum()
45 | area_target = target.sum()
46 |
47 | return area_intersection, area_union, area_target
--------------------------------------------------------------------------------
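
A small sketch of how these two utilities combine into the mIoU reported by the evaluation scripts later in the repository; the two toy masks are illustrative only, and the import assumes the PerMIRS directory is the working directory.

```
import numpy as np
from eval_miou import AverageMeter, intersectionAndUnion

# Two toy binary masks (prediction vs. ground truth).
pred = np.array([[1, 1, 0],
                 [0, 1, 0]], dtype=np.uint8)
gt = np.array([[1, 0, 0],
               [0, 1, 1]], dtype=np.uint8)

intersection_meter, union_meter = AverageMeter(), AverageMeter()
intersection, union, _ = intersectionAndUnion(pred, gt)
intersection_meter.update(intersection)
union_meter.update(union)
print("IoU: %.2f" % (intersection_meter.sum / (union_meter.sum + 1e-10)))  # 2 / 4 = 0.50
```
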
/PerMIRS/dataset.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 | from video import BURSTVideo
3 |
4 | import json
5 | import os.path as osp
6 |
7 | import utils
8 |
9 |
10 | class BURSTDataset:
11 | def __init__(self, annotations_file: str, images_base_dir: Optional[str] = None):
12 | with open(annotations_file, 'r') as fh:
13 | content = json.load(fh)
14 |
15 | # convert track IDs from str to int wherever they are used as dict keys (JSON format always parses dict keys as
16 | # strings)
17 | self._videos = [utils.intify_track_ids(video) for video in content["sequences"]]
18 | self._videos = list(filter(lambda x: x["dataset"] != "HACS" and x["dataset"] != "AVA" and x["dataset"] != "LaSOT", self._videos))
19 |
20 | self.category_names = {
21 | category["id"]: category["name"] for category in content["categories"]
22 | }
23 |
24 | self._split = content["split"]
25 |
26 | self.images_base_dir = images_base_dir
27 |
28 | @property
29 | def num_videos(self) -> int:
30 | return len(self._videos)
31 |
32 | def __getitem__(self, index) -> BURSTVideo:
33 | assert index < self.num_videos, f"Index {index} invalid since total number of videos is {self.num_videos}"
34 |
35 | video_dict = self._videos[index]
36 | if self.images_base_dir is None:
37 | video_images_dir = None
38 | else:
39 | video_images_dir = osp.join(self.images_base_dir, self._split, video_dict["dataset"], video_dict["seq_name"])
40 | assert osp.exists(video_images_dir), f"Images directory for video not found at expected path: '{video_images_dir}'"
41 |
42 | return BURSTVideo(video_dict, video_images_dir)
43 |
44 | def __iter__(self):
45 | for i in range(self.num_videos):
46 | yield self[i]
--------------------------------------------------------------------------------
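
A minimal iteration sketch; the annotation and image paths are hypothetical placeholders for a local copy of BURST, and the import assumes the PerMIRS directory is the working directory.

```
from dataset import BURSTDataset

# Hypothetical local paths to the downloaded BURST annotations and frames.
dataset = BURSTDataset(annotations_file="/path/to/BURST/annotations/all_classes.json",
                       images_base_dir="/path/to/BURST/frames")
print("number of videos:", dataset.num_videos)

for video in dataset:
    # BURSTVideo (see video.py) exposes per-video metadata and mask/frame loaders.
    print(video.dataset, video.name, video.num_annotated_frames)
    break
```
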
/PerMIRS/permirs_gen_dataset.py:
--------------------------------------------------------------------------------
1 | import shutil
2 | from argparse import ArgumentParser
3 | from tqdm import tqdm
4 | from dataset import BURSTDataset
5 | import numpy as np
6 | import os
7 | import os.path as osp
8 |
9 |
10 | def main(args):
11 | dataset = BURSTDataset(annotations_file=args.annotations_file,
12 | images_base_dir=args.images_base_dir)
13 |
14 | for i in tqdm(range(dataset.num_videos)):
15 | try:
16 | video = dataset[i]
17 |
18 | if args.seq and f"{video.dataset}/{video.name}" != args.seq:
19 | continue
20 |
21 | print(f"- Dataset: {video.dataset}\n"
22 | f"- Name: {video.name}")
23 |
24 | if args.first_frame_annotations:
25 | annotations = video.load_first_frame_annotations()
26 | else:
27 | annotations = video.load_masks()
28 |
29 | frames_idx, annotations = video.filter_dataset_for_benchmark(annotations)
30 | image_paths = video.images_paths(frames_idx)
31 | if len(image_paths) < 3:
32 | raise Exception(i, "not enough images found")
33 |
34 |             base_path = f"/PerMIRS/{i}"  # output root for the generated dataset (the --save_dir option is currently unused)
35 | for f_idx, image_p in enumerate(image_paths):
36 | frame_p = base_path + f"/{f_idx}.{osp.split(image_p)[-1].split('.')[-1]}"
37 | os.makedirs(os.path.dirname(frame_p), exist_ok=True)
38 | shutil.copyfile(image_p, frame_p)
39 |
40 |             np.save(base_path + "/masks.npz", annotations)  # np.save appends ".npy", so this is written as masks.npz.npy
41 | except Exception as e:
42 | print(i, e)
43 |
44 |
45 | if __name__ == '__main__':
46 | parser = ArgumentParser()
47 |
48 | parser.add_argument("--images_base_dir", required=True)
49 | parser.add_argument("--annotations_file", required=True)
50 | parser.add_argument("--first_frame_annotations", action='store_true')
51 |
52 | # extra options
53 | parser.add_argument("--save_dir", required=False)
54 | parser.add_argument("--seq", required=False)
55 |
56 | main(parser.parse_args())
--------------------------------------------------------------------------------
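
A short sketch of how one generated video folder can be inspected afterwards; `/PerMIRS/0` is a hypothetical example of the per-video directories written by the script. Because `np.save` appends `.npy`, the annotations end up in `masks.npz.npy`, which is the name the other scripts load.

```
import numpy as np

# Load the three selected masks of one (hypothetical) generated video folder.
masks_np = np.load("/PerMIRS/0/masks.npz.npy", allow_pickle=True)
print(len(masks_np))             # 3 selected frames
print(list(masks_np[0].keys()))  # the single track id kept for this video
```
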
/PerMIRS/visualization_utils.py:
--------------------------------------------------------------------------------
1 | # Code from https://github.com/Ali2500/BURST-benchmark
2 | import cv2
3 | import numpy as np
4 |
5 |
6 | def create_color_map():
7 | # This function has been copied with minor changes from the DAVIS dataset API at:
8 | # https://github.com/davisvideochallenge/davis2017-evaluation (Caelles, Pont-Tuset, et al.)
9 | N = 256
10 |
11 | def bitget(byteval, idx):
12 | return (byteval & (1 << idx)) != 0
13 |
14 | cmap = np.zeros((N, 3), dtype=np.uint8)
15 | for i in range(N):
16 | r = g = b = 0
17 | c = i
18 | for j in range(8):
19 | r = r | (bitget(c, 0) << 7-j)
20 | g = g | (bitget(c, 1) << 7-j)
21 | b = b | (bitget(c, 2) << 7-j)
22 | c = c >> 3
23 |
24 | cmap[i] = np.array([r, g, b])
25 |
26 | return cmap.tolist()
27 |
28 |
29 | def overlay_mask_on_image(image, mask, opacity, color):
30 | assert mask.ndim == 2
31 | mask_bgr = np.stack((mask, mask, mask), axis=2)
32 | masked_image = np.where(mask_bgr > 0, color, image)
33 | return ((opacity * masked_image) + ((1. - opacity) * image)).astype(np.uint8)
34 |
35 |
36 | def bbox_from_mask(mask):
37 | reduced_y = np.any(mask, axis=0)
38 | reduced_x = np.any(mask, axis=1)
39 |
40 | x_min = reduced_y.argmax()
41 | if x_min == 0 and reduced_y[0] == 0: # mask is all zeros
42 | return None
43 |
44 | x_max = len(reduced_y) - np.flip(reduced_y, 0).argmax()
45 |
46 | y_min = reduced_x.argmax()
47 | y_max = len(reduced_x) - np.flip(reduced_x, 0).argmax()
48 |
49 | return x_min, y_min, x_max, y_max
50 |
51 |
52 | def annotate_image(image, mask, color, label, point=None, **kwargs):
53 | """
54 | :param image: np.ndarray(H, W, 3)
55 | :param mask: np.ndarray(H, W)
56 | :param color: tuple/list(int, int, int) in range [0, 255]
57 | :param label: str
58 | :param kwargs: "bbox_thickness", "text_font", "font_size", "mask_opacity"
59 | :return: np.ndarray(H, W, 3)
60 | """
61 | annotated_image = overlay_mask_on_image(image, mask, color=color, opacity=kwargs.get("mask_opacity", 0.5))
62 | xmin, ymin, xmax, ymax = [int(x) for x in bbox_from_mask(mask)]
63 |
64 | bbox_thickness = kwargs.get("bbox_thickness", 2)
65 | text_font = kwargs.get("text_font", cv2.FONT_HERSHEY_SIMPLEX)
66 | font_size = kwargs.get("font_size", 0.5)
67 |
68 | annotated_image = cv2.rectangle(cv2.UMat(annotated_image), (xmin, ymin), (xmax, ymax), color=tuple(color),
69 | thickness=bbox_thickness)
70 |
71 | (text_width, text_height), _ = cv2.getTextSize(label, text_font, font_size, thickness=1)
72 | text_offset_x, text_offset_y = int(xmin + 2), int(ymin + text_height + 2)
73 |
74 | text_bg_box_pt1 = int(text_offset_x), int(text_offset_y + 2)
75 | text_bg_box_pt2 = int(text_offset_x + text_width + 2), int(text_offset_y - text_height - 2)
76 |
77 | annotated_image = cv2.rectangle(cv2.UMat(annotated_image), text_bg_box_pt1, text_bg_box_pt2, color=(255, 255, 255), thickness=-1)
78 | annotated_image = cv2.putText(cv2.UMat(annotated_image), label, (text_offset_x, text_offset_y), text_font, font_size, (0, 0, 0))
79 |
80 | if point is not None:
81 | # use a darker color so the point is more visible on the mask
82 | color = tuple([int(round(0.5 * c)) for c in color])
83 | annotated_image = cv2.circle(cv2.UMat(annotated_image), point, radius=3, color=color, thickness=-1)
84 |
85 | if isinstance(annotated_image, cv2.UMat):
86 | # sometimes OpenCV functions return objects of type cv2.UMat instead of numpy arrays
87 | return annotated_image.get()
88 | else:
89 | return annotated_image
90 |
--------------------------------------------------------------------------------
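
A small self-contained sketch of the helpers above on a synthetic image and mask; shapes and values are illustrative, and the import assumes the PerMIRS directory is the working directory.

```
import numpy as np
from visualization_utils import create_color_map, bbox_from_mask, annotate_image

# Synthetic 64x64 image with a rectangular object mask.
image = np.zeros((64, 64, 3), dtype=np.uint8)
mask = np.zeros((64, 64), dtype=np.uint8)
mask[10:30, 20:50] = 1

print(bbox_from_mask(mask))      # (x_min, y_min, x_max, y_max); here 20, 10, 50, 30

color = create_color_map()[1]    # first non-background color of the DAVIS palette
annotated = annotate_image(image, mask, color=color, label="object")
print(annotated.shape)           # (64, 64, 3)
```
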
/attention_store.py:
--------------------------------------------------------------------------------
1 | # Code adapted from https://prompt-to-prompt.github.io/
2 | import abc
3 |
4 |
5 | class EmptyControl:
6 |
7 | def step_callback(self, x_t):
8 | return x_t
9 |
10 | def between_steps(self):
11 | return
12 |
13 | def __call__(self, attn, is_cross: bool, place_in_unet: str):
14 | return attn
15 |
16 |
17 | class AttentionControl(abc.ABC):
18 |
19 | def step_callback(self, x_t):
20 | return x_t
21 |
22 | def between_steps(self):
23 | return
24 |
25 | def between_steps_inject(self):
26 | return
27 |
28 | @property
29 | def num_uncond_att_layers(self):
30 | return 0 #self.num_att_layers if LOW_RESOURCE else 0
31 |
32 | @abc.abstractmethod
33 | def forward(self, attn, is_cross: bool, place_in_unet: str):
34 | raise NotImplementedError
35 |
36 | def __call__(self, attn, is_cross: bool, place_in_unet: str):
37 | if self.cur_att_layer >= self.num_uncond_att_layers:
38 | attn = self.forward(attn, is_cross, place_in_unet)
39 | return attn
40 |
41 | def check_next_step(self):
42 | if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers:
43 | self.cur_att_layer = 0
44 | self.cur_step += 1
45 | if self.is_inject:
46 | self.between_steps_inject()
47 | else:
48 | self.between_steps()
49 |
50 | def reset(self):
51 | self.cur_step = 0
52 | self.cur_att_layer = 0
53 |
54 | def __init__(self):
55 | self.cur_step = 0
56 | self.num_att_layers = -1
57 | self.cur_att_layer = 0
58 | self.is_inject = False
59 |
60 |
61 | class AttentionStore(AttentionControl):
62 |
63 | @staticmethod
64 | def get_empty_store():
65 | #return {"down_cross": [], "mid_cross": [], "up_cross": [],
66 | # "down_self": [], "mid_self": [], "up_self": []}
67 | return {}
68 |
69 | def forward(self, attn, is_cross: bool, place_in_unet: str):
70 | key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
71 | if attn is None:
72 | attn = self.attention_store[self.cur_step][key]
73 | else:
74 | self.step_store[key] = attn.cpu()
75 | #if attn.shape[1] <= 32 ** 2: # avoid memory overhead
76 | # self.step_store[key].append(attn)
77 | return attn
78 |
79 | def between_steps(self):
80 | self.attention_store[self.cur_step - 1] = self.step_store
81 | # if len(self.attention_store) == 0:
82 | # self.attention_store[self.cur_step-1] = self.step_store
83 | # else:
84 | # for key in self.attention_store:
85 | # for i in range(len(self.attention_store[key])):
86 | # self.attention_store[key][i] += self.step_store[key][i]
87 | self.step_store = self.get_empty_store()
88 |
89 | def between_steps_inject(self):
90 | self.step_store = self.get_empty_store()
91 |
92 | def get_average_attention(self):
93 | average_attention = {key: [item / self.cur_step for item in self.attention_store[key]] for key in
94 | self.attention_store}
95 | return average_attention
96 |
97 | def reset(self):
98 | super(AttentionStore, self).reset()
99 | self.step_store = self.get_empty_store()
100 | self.attention_store = {}
101 |
102 | def __init__(self):
103 | super(AttentionStore, self).__init__()
104 | self.step_store = self.get_empty_store()
105 | self.attention_store = {}
106 |
107 |
108 |
109 |
110 |
111 |
--------------------------------------------------------------------------------
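
A minimal sketch of the storage protocol driven by `ptp_utils.register_attention_control_efficient`: each attention layer passes its Q/K/V tensors through the controller under a string key, and `check_next_step()` rolls the per-step store into `attention_store` once all registered layers have run. The layer name and tensor shape below are illustrative.

```
import torch
from attention_store import AttentionStore

controller = AttentionStore()
controller.num_att_layers = 1   # pretend the UNet exposes a single attention layer

# Store a fake query tensor for step 0 (is_cross=False appends "_self" to the key).
q = torch.randn(1, 4096, 320)
controller(q, False, "Q_up_blocks_3_attentions_1_transformer_blocks_0_attn1")
controller.cur_att_layer += 1
controller.check_next_step()    # all layers done -> step_store moves into attention_store[0]

print(list(controller.attention_store[0].keys()))
# ['Q_up_blocks_3_attentions_1_transformer_blocks_0_attn1_self']
```
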
/sam_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 |
4 |
5 | def show_mask(mask, ax, random_color=False):
6 | if random_color:
7 | color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
8 | else:
9 | color = np.array([30/255, 144/255, 255/255, 0.6])
10 | h, w = mask.shape[-2:]
11 | mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
12 | ax.imshow(mask_image)
13 |
14 |
15 | def show_box(box, ax):
16 | x0, y0 = box[0], box[1]
17 | w, h = box[2] - box[0], box[3] - box[1]
18 | ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))
19 |
20 | def show_boxes_on_image(raw_image, boxes):
21 | plt.figure(figsize=(10,10))
22 | plt.imshow(raw_image)
23 | for box in boxes:
24 | show_box(box, plt.gca())
25 | plt.axis('on')
26 | plt.show()
27 |
28 | def show_points_on_image(raw_image, input_points, input_labels=None):
29 | plt.figure(figsize=(10,10))
30 | plt.imshow(raw_image)
31 | input_points = np.array(input_points)
32 | if input_labels is None:
33 | labels = np.ones_like(input_points[:, 0])
34 | else:
35 | labels = np.array(input_labels)
36 | show_points(input_points, labels, plt.gca())
37 | plt.axis('on')
38 | plt.show()
39 |
40 | def show_points_and_boxes_on_image(raw_image, boxes, input_points, input_labels=None):
41 | plt.figure(figsize=(10,10))
42 | plt.imshow(raw_image)
43 | input_points = np.array(input_points)
44 | if input_labels is None:
45 | labels = np.ones_like(input_points[:, 0])
46 | else:
47 | labels = np.array(input_labels)
48 | show_points(input_points, labels, plt.gca())
49 | for box in boxes:
50 | show_box(box, plt.gca())
51 | plt.axis('on')
52 | plt.show()
68 |
69 |
70 | def show_points(coords, labels, ax, marker_size=375):
71 | pos_points = coords[labels==1]
72 | neg_points = coords[labels==0]
73 | ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
74 | ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
75 |
76 |
77 | def show_masks_on_image(raw_image, masks, scores):
78 | if len(masks.shape) == 4:
79 | masks = masks.squeeze()
80 | if scores.shape[0] == 1:
81 | scores = scores.squeeze()
82 |
83 | nb_predictions = scores.shape[-1]
84 | fig, axes = plt.subplots(1, nb_predictions, figsize=(15, 15))
85 |
86 | for i, (mask, score) in enumerate(zip(masks, scores)):
87 | mask = mask.cpu().detach()
88 | axes[i].imshow(np.array(raw_image))
89 | show_mask(mask, axes[i])
90 | axes[i].title.set_text(f"Mask {i+1}, Score: {score.item():.3f}")
91 | axes[i].axis("off")
92 | plt.show()
93 |
94 | def show_single_mask_on_image(raw_image, mask, score):
95 | _ = plt.figure(figsize=(15, 15))
96 | ax = plt.gca()
97 | mask = mask.cpu().detach()
98 | ax.imshow(np.array(raw_image))
99 | show_mask(mask, ax)
100 | ax.title.set_text(f"Score: {score.item():.3f}")
101 | ax.axis("off")
102 | plt.show()
--------------------------------------------------------------------------------
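
A quick sketch of the plotting helpers on a random image; coordinates are (x, y) pixels, the values are illustrative, and an interactive matplotlib backend is assumed for `plt.show()`.

```
import numpy as np
from sam_utils import show_points_on_image, show_boxes_on_image

raw_image = np.random.randint(0, 255, (256, 256, 3), dtype=np.uint8)

# One positive point prompt and one box prompt, in pixel coordinates.
show_points_on_image(raw_image, input_points=[[128, 128]])
show_boxes_on_image(raw_image, boxes=[[64, 64, 192, 192]])
```
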
/PerMIRS/extract_diff_features.py:
--------------------------------------------------------------------------------
1 | import os
2 | import PIL
3 | import numpy as np
4 | import torch
5 | from diffusers import DDIMScheduler
6 | from tqdm import tqdm
7 | import ptp_utils
8 | from StableDiffusionPipelineWithDDIMInversion import StableDiffusionPipelineWithDDIMInversion
9 | from attention_store import AttentionStore
10 | from PerMIRS.visualization_utils import bbox_from_mask
11 |
12 | def center_crop(im, min_obj_x=None, max_obj_x=None, offsets=None):
13 | if offsets is None:
14 | width, height = im.size # Get dimensions
15 | min_dim = min(width, height)
16 | left = (width - min_dim) / 2
17 | top = (height - min_dim) / 2
18 | right = (width + min_dim) / 2
19 | bottom = (height + min_dim) / 2
20 |
21 | if min_obj_x < left:
22 | diff = abs(left - min_obj_x)
23 | left = min_obj_x
24 | right = right - diff
25 | if max_obj_x > right:
26 | diff = abs(right - max_obj_x)
27 | right = max_obj_x
28 | left = left + diff
29 | else:
30 | left, top, right, bottom = offsets
31 |
32 | # Crop the center of the image
33 | im = im.crop((left, top, right, bottom))
34 | return im, (left, top, right, bottom)
35 |
36 |
37 | def load_im_into_format_from_path(im_path, size=512, offsets=None):
38 | im, offsets = center_crop(PIL.Image.open(im_path), offsets=offsets)
39 | return im
40 |
41 |
42 | def load_im_into_format_from_image(image, size=512, min_obj_x=None, max_obj_x=None):
43 | im, offsets = center_crop(image, min_obj_x=min_obj_x, max_obj_x=max_obj_x)
44 | return im.resize((size, size)), offsets
45 |
46 | def extract_attention(sd_model, image, prompt):
47 | controller = AttentionStore()
48 | inv_latents = sd_model.invert(prompt, image=image, guidance_scale=1.0).latents
49 | ptp_utils.register_attention_control_efficient(sd_model, controller)
50 | # recon_image = sd_model(prompt, latents=inv_latents, guidance_scale=1.0).images[0]
51 | key_format = f"up_blocks_3_attentions_1_transformer_blocks_0_attn1_self"
52 | timestamp = 49
53 | return [controller.attention_store[timestamp]["Q_" + key_format][0].to("cuda"),
54 | controller.attention_store[timestamp]["K_" + key_format][0].to("cuda"),
55 | ], sd_model.unet.up_blocks[3].attentions[1].transformer_blocks[0].attn1
56 |
57 | if __name__ == "__main__":
58 | model_id = "CompVis/stable-diffusion-v1-4"
59 |
60 | device = "cuda" # if torch.cuda.is_available() else "cpu"
61 |
62 | sd_model = StableDiffusionPipelineWithDDIMInversion.from_pretrained(
63 | model_id,
64 | safety_checker=None,
65 | scheduler=DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
66 | )
67 | sd_model = sd_model.to(device)
68 | img_size = 512
69 | attn_size = 64
70 |
71 | dataset_dir = "/PerMIRS"
72 | for vid_id in tqdm(os.listdir(dataset_dir)):
73 | try:
74 | frames = []
75 | masks = []
76 | masks_img_size = []
77 | masks_np = np.load(f"{dataset_dir}/{vid_id}/masks.npz.npy", allow_pickle=True)
78 | for f in range(3):
79 | xmin, ymin, xmax, ymax = [int(x) for x in bbox_from_mask(list(masks_np[f].values())[0])]
80 | m, curr_offsets = load_im_into_format_from_image(
81 | PIL.Image.fromarray(np.uint8(list(masks_np[f].values())[0])),
82 | min_obj_x=xmin, max_obj_x=xmax)
83 | masks += [np.asarray(m.resize((attn_size, attn_size)))]
84 | masks_img_size += [np.asarray(m.resize((img_size, img_size)))]
85 | frames += [
86 | load_im_into_format_from_path(f"{dataset_dir}/{vid_id}/{f}.jpg", offsets=curr_offsets).resize((img_size, img_size)).convert("RGB")]
87 |
88 | all_attn = []
89 | for f in frames:
90 | curr_attn, _ = extract_attention(sd_model, f, "A photo")
91 | all_attn += [curr_attn]
92 | torch.save(all_attn, f"{dataset_dir}/{vid_id}/diff_feats.pt")
93 | except Exception as e:
94 |             # record the failure so the video can be inspected or re-processed later
95 |             with open(f"{dataset_dir}/{vid_id}/error.txt", "w") as err_f:
96 |                 err_f.write(str(e))
--------------------------------------------------------------------------------
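
A short sketch of reading the features written by the script above; the path is a hypothetical per-video folder. Each list entry holds the `[Q, K]` pair extracted from the chosen self-attention layer at the final inversion timestep.

```
import torch

# Load the per-frame [Q, K] features saved as diff_feats.pt.
all_attn = torch.load("/PerMIRS/0/diff_feats.pt", map_location="cpu")
print(len(all_attn))      # 3 frames
q, k = all_attn[0]
print(q.shape, k.shape)   # (4096, C): one feature vector per token of the 64x64 map
```
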
/README.md:
--------------------------------------------------------------------------------
1 | # Where's Waldo: Diffusion Features For Personalized Segmentation and Retrieval (NeurIPS 2024)
2 | > Dvir Samuel, Rami Ben-Ari, Matan Levy, Nir Darshan, Gal Chechik
3 | > Bar Ilan University, The Hebrew University of Jerusalem, NVIDIA Research
4 |
5 | >
6 | >
7 | > Personalized retrieval and segmentation aim to locate specific instances within a dataset based on an input image and a short description of the reference instance. While supervised methods are effective, they require extensive labeled data for training. Recently, self-supervised foundation models have been introduced to these tasks, showing results comparable to supervised methods. However, a significant flaw in these models is evident: they struggle to locate a desired instance when other instances of the same class are present. In this paper, we explore text-to-image diffusion models for these tasks. Specifically, we propose PDM, Personalized Diffusion Features Matching, a novel approach that leverages intermediate features of pre-trained text-to-image models for personalization tasks without any additional training. PDM demonstrates superior performance on popular retrieval and segmentation benchmarks, outperforming even supervised methods. We also highlight notable shortcomings in current instance retrieval and segmentation datasets and propose new benchmarks for these tasks.
8 |
9 |
10 |
11 | 
12 |
13 | 
14 | The personalized segmentation task involves segmenting a specific reference object in a new scene. Our method accurately identifies the reference instance in the target image, even when other objects from the same class are present. While other methods capture visually or semantically similar objects, ours extracts the identical instance by using a new personalized feature map that fuses semantic and appearance cues. Red and green indicate incorrect and correct segmentations, respectively.
15 |
16 |
17 |
18 |
19 | ## Requirements
20 |
21 | Required package versions (install with pip):
22 | ```
23 | torch==2.0.1
24 | torchvision==0.15.2
25 | diffusers==0.18.2
26 | transformers==4.32.0.dev0
27 | ```
28 |
29 | ## Personalized Diffusion Features Matching (PDM)
30 |
31 | To visualize PDM matching between two images, run:
32 |
33 | ```
34 | python pdm_matching.py
35 | ```
36 |
37 | ## PerMIR and PerMIS Datasets
38 |
39 | The PerMIR and PerMIS datasets were sourced from the [BURST](https://github.com/Ali2500/BURST-benchmark) repository.
40 |
41 | ### Instructions:
42 | 1. Download the data from the BURST repository and place the train, val, and test splits in the same directory.
43 | 2. Run `PerMIRS/permirs_gen_dataset.py` to prepare the personalization datasets. Point `--images_base_dir` at the directory containing the downloaded BURST splits and set `--annotations_file` to the downloaded `all_classes.json`.
44 | 3. Execute `PerMIRS/extract_diff_features.py` to extract PDM and DIFT features from each image in the dataset.
45 |
46 |
47 | ## Evaluation on PerMIR
48 |
49 | To evaluate PDM on the PerMIR dataset (personalized retrieval), run:
50 |
51 | ```
52 | python pdm_permir.py
53 | ```
54 |
55 | ## Evaluation on PerMIS
56 |
57 | To evaluate PDM on the PerMIS dataset (personalized segmentation), run:
58 |
59 | ```
60 | python pdm_permis.py
61 | ```
62 |
63 |
64 |
65 | ## Cite Our Paper
66 | If you find our paper and repo useful, please cite:
67 | ```
68 | @article{Samuel2024Waldo,
69 | title={Where's Waldo: Diffusion Features For Personalized Segmentation and Retrieval},
70 | author={Dvir Samuel and Rami Ben-Ari and Matan Levy and Nir Darshan and Gal Chechik},
71 | journal={NeurIPS},
72 | year={2024}
73 | }
74 | ```
75 |
--------------------------------------------------------------------------------
/pdm_matching.py:
--------------------------------------------------------------------------------
1 | import PIL
2 | import numpy as np
3 | import torch
4 | from diffusers import DDIMScheduler
5 | from matplotlib import pyplot as plt
6 | from torch import nn
7 | from transformers import SamModel, SamProcessor
8 | from attention_store import AttentionStore
9 | from StableDiffusionPipelineWithDDIMInversion import StableDiffusionPipelineWithDDIMInversion
10 | #from sam_utils import show_masks_on_image, show_points_on_image
11 | import ptp_utils
12 |
13 |
14 | def center_crop(im):
15 | width, height = im.size # Get dimensions
16 | min_dim = min(width, height)
17 | left = (width - min_dim) / 2
18 | top = (height - min_dim) / 2
19 | right = (width + min_dim) / 2
20 | bottom = (height + min_dim) / 2
21 |
22 | # Crop the center of the image
23 | im = im.crop((left, top, right, bottom))
24 | return im
25 |
26 |
27 | def load_im_into_format_from_path(im_path):
28 | return center_crop(PIL.Image.open(im_path)).resize((512, 512))
29 |
30 | def get_masks(raw_image, input_points):
31 | inputs = processor(raw_image, input_points=input_points, return_tensors="pt").to(device)
32 | with torch.no_grad():
33 | outputs = model(**inputs)
34 |
35 | masks = processor.image_processor.post_process_masks(
36 | outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu()
37 | )
38 | # visualize sam segmentation
39 | # scores = outputs.iou_scores
40 | # show_masks_on_image(raw_image, masks[0], scores)
41 | # show_points_on_image(raw_image, input_points[0])
42 | return masks
43 |
44 |
45 | def extract_attention(sd_model, image, prompt):
46 | controller = AttentionStore()
47 | inv_latents = sd_model.invert(prompt, image=image, guidance_scale=1.0).latents
48 | ptp_utils.register_attention_control_efficient(sd_model, controller)
49 | recon_image = sd_model(prompt, latents=inv_latents, guidance_scale=1.0).images[0]
50 | key_format = f"up_blocks_3_attentions_2_transformer_blocks_0_attn1_self"
51 | timestamp = 49
52 | res_attn = 64
53 | return [controller.attention_store[timestamp]["Q_" + key_format][0].to("cuda"),
54 | controller.attention_store[timestamp]["K_" + key_format][0].to("cuda"),
55 | ], sd_model.unet.up_blocks[3].attentions[2].transformer_blocks[0].attn1, res_attn
56 |
57 | def heatmap(sd_model, ref_img, tgt_img):
58 | image_size = 512
59 | prompt = f"A photo"
60 | # extract query PDM features (Q and K)
61 | ref_attn_ls, _, res_attn = extract_attention(sd_model, ref_img, prompt)
62 | h = w = res_attn
63 |
64 | # get query mask using SAM or use provided mask from user
65 | source_masks = get_masks(ref_img.resize(size=(h, w)), [[[h // 2, w // 2]]])
66 | source_mask = source_masks[0][:, 1:2, :, :].squeeze(dim=0).squeeze(dim=0)
67 | mask_idx_y, mask_idx_x = torch.where(source_mask)
68 |
69 | # extract target PDM features (Q and K)
70 | target_attn_ls, _, _ = extract_attention(sd_model, tgt_img, prompt)
71 |
72 | # apply matching and show heatmap
73 | for attn_idx, (ref_attn, target_attn) in enumerate(zip(ref_attn_ls, target_attn_ls)):
74 | heatmap = torch.zeros(ref_attn.shape[0]).to("cuda")
75 | for x, y in zip(mask_idx_x, mask_idx_y):
76 | t = np.ravel_multi_index((y, x), dims=(h, w))
77 | source_vec = ref_attn[t].reshape(1, -1)
78 | euclidean_dist = torch.cdist(source_vec, target_attn)
79 | idx = torch.sort(euclidean_dist)[1][0][:100]
80 | heatmap[idx] += 1
81 | heatmap = heatmap / heatmap.max()
82 | heatmap_img_size = \
83 | nn.Upsample(size=(image_size, image_size), mode='bilinear')(heatmap.reshape(1, 1, 64, 64))[0][
84 | 0]
85 | plt.imshow(tgt_img)
86 | plt.imshow(heatmap_img_size.cpu(), alpha=0.6)
87 | plt.show()
88 |
89 |
90 | if __name__ == "__main__":
91 | model_id = "CompVis/stable-diffusion-v1-4"
92 |
93 | device = "cuda"
94 |
95 | sd_model = StableDiffusionPipelineWithDDIMInversion.from_pretrained(
96 | model_id,
97 | safety_checker=None,
98 | scheduler=DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
99 | )
100 | sd_model = sd_model.to(device)
101 | img_size = 512
102 |
103 | model = SamModel.from_pretrained("facebook/sam-vit-huge").to(device)
104 | processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
105 |
106 | ref_img = load_im_into_format_from_path("dogs/source.png").convert("RGB")
107 | plt.imshow(ref_img)
108 | plt.show()
109 | tgt_img = load_im_into_format_from_path("dogs/target1.png").convert("RGB")
110 | plt.imshow(tgt_img)
111 | plt.show()
112 | heatmap(sd_model, ref_img, tgt_img)
113 |
114 |
--------------------------------------------------------------------------------
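
An isolated sketch of the voting step inside `heatmap()` above, with random tensors standing in for the PDM feature maps: every masked query token votes for its 100 nearest target tokens, and the accumulated votes form a 64x64 similarity map over the target image.

```
import numpy as np
import torch

h = w = 64
ref_attn = torch.randn(h * w, 320)      # stand-in for a query PDM map (Q or K)
target_attn = torch.randn(h * w, 320)   # stand-in for the target PDM map

# A small square region plays the role of the SAM mask on the reference image.
mask = torch.zeros(h, w, dtype=torch.bool)
mask[30:34, 30:34] = True
mask_idx_y, mask_idx_x = torch.where(mask)

heatmap = torch.zeros(h * w)
for x, y in zip(mask_idx_x, mask_idx_y):
    t = np.ravel_multi_index((int(y), int(x)), dims=(h, w))
    dist = torch.cdist(ref_attn[t].reshape(1, -1), target_attn)  # (1, h*w) distances
    heatmap[torch.sort(dist)[1][0][:100]] += 1                   # vote for the 100 nearest tokens
heatmap = heatmap / heatmap.max()
print(heatmap.reshape(h, w).shape)      # (64, 64) similarity map
```
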
/pdm_permir.py:
--------------------------------------------------------------------------------
1 | import PIL
2 | import numpy as np
3 | import torch
4 | import os
5 | from diffusers import DDIMScheduler
6 | from sklearn.metrics import average_precision_score
7 | from torchvision.transforms import PILToTensor
8 | from tqdm import tqdm
9 | from StableDiffusionPipelineWithDDIMInversion import StableDiffusionPipelineWithDDIMInversion
10 | from dift import SDFeaturizer, get_correspondences_seg
11 | from PerMIRS.visualization_utils import bbox_from_mask
12 | import torch.nn.functional as F
13 |
14 |
15 | def center_crop(im, min_obj_x=None, max_obj_x=None, offsets=None):
16 | if offsets is None:
17 | width, height = im.size # Get dimensions
18 | min_dim = min(width, height)
19 | left = (width - min_dim) / 2
20 | top = (height - min_dim) / 2
21 | right = (width + min_dim) / 2
22 | bottom = (height + min_dim) / 2
23 |
24 | if min_obj_x < left:
25 | diff = abs(left - min_obj_x)
26 | left = min_obj_x
27 | right = right - diff
28 | if max_obj_x > right:
29 | diff = abs(right - max_obj_x)
30 | right = max_obj_x
31 | left = left + diff
32 | else:
33 | left, top, right, bottom = offsets
34 |
35 | # Crop the center of the image
36 | im = im.crop((left, top, right, bottom))
37 | return im, (left, top, right, bottom)
38 |
39 |
40 | def load_im_into_format_from_path(im_path, size=512, offsets=None):
41 | im, offsets = center_crop(PIL.Image.open(im_path), offsets=offsets)
42 | return im
43 |
44 |
45 | def load_im_into_format_from_image(image, size=512, min_obj_x=None, max_obj_x=None):
46 | im, offsets = center_crop(image, min_obj_x=min_obj_x, max_obj_x=max_obj_x)
47 | return im.resize((size, size)), offsets
48 |
49 | if __name__ == "__main__":
50 | model_id = "CompVis/stable-diffusion-v1-4"
51 |
52 | device = "cuda"
53 | sd_model = StableDiffusionPipelineWithDDIMInversion.from_pretrained(
54 | model_id,
55 | safety_checker=None,
56 | scheduler=DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
57 | )
58 | sd_model = sd_model.to(device)
59 |
60 | dift = SDFeaturizer(sd_model)
61 | img_size = 512
62 | attn_size = 64
63 | dift_size = 32
64 |
65 | dataset_dir = "/PerMIRS"
66 |
67 |     if "ret_data.pt" not in os.listdir(dataset_dir):
68 | # organize data for retrieval
69 | query_frames = []
70 | query_labels = []
71 | query_dift_features = []
72 | query_dift_mask = []
73 | query_perdiff_features = []
74 | gallery_frames = []
75 | gallery_labels = []
76 | gallery_dift_features = []
77 | gallery_perdiff_features = []
78 |
79 | for vid_idx, vid_id in tqdm(enumerate(os.listdir(dataset_dir))):
80 | masks_attn = []
81 | masks_dift = []
82 | masks_relative_size = []
83 | masks_np = np.load(f"{dataset_dir}/{vid_id}/masks.npz.npy", allow_pickle=True)
84 | attn_ls = torch.load(f"{dataset_dir}/{vid_id}/diff_feats.pt")
85 | dift_features = []
86 | frame_paths = []
87 | for f in range(3):
88 | xmin, ymin, xmax, ymax = [int(x) for x in bbox_from_mask(list(masks_np[f].values())[0])]
89 | m, curr_offsets = load_im_into_format_from_image(
90 | PIL.Image.fromarray(np.uint8(list(masks_np[f].values())[0])),
91 | min_obj_x=xmin, max_obj_x=xmax)
92 |
93 | masks_attn += [np.asarray(m.resize((attn_size, attn_size)))]
94 | masks_dift += [np.asarray(m.resize((dift_size, dift_size)))]
95 | masks_relative_size += [np.asarray(m).sum() / (img_size * img_size)]
96 | path = f"{dataset_dir}/{vid_id}/{f}.jpg"
97 | frame_paths += [path]
98 | frame = load_im_into_format_from_path(path, offsets=curr_offsets).resize(
99 | (img_size, img_size)).convert("RGB")
100 | curr_dift = dift.forward((PILToTensor()(frame) / 255.0 - 0.5) * 2,
101 | prompt="A photo",
102 | ensemble_size=2)
103 | dift_features += [curr_dift]
104 |
105 | masks_relative_size = np.array(masks_relative_size)
106 | # remove small frames
107 | if len(np.where(np.array(masks_relative_size) < 0.005)[0]) > 0:
108 | continue
109 |
110 | query_idx = masks_relative_size.argmax()
111 | for i in range(len(dift_features)):
112 | if i == query_idx:
113 | # query
114 | query_dift_features += [dift_features[i]]
115 | query_dift_mask += [masks_dift[i]]
116 | query_perdiff_features += [attn_ls[i]]
117 | query_labels += [vid_idx]
118 | query_frames += [frame_paths[i]]
119 | else:
120 | # gallery
121 | gallery_dift_features += [dift_features[i]]
122 | gallery_perdiff_features += [attn_ls[i]]
123 | gallery_labels += [vid_idx]
124 | gallery_frames += [frame_paths[i]]
125 |
126 | query_labels = torch.tensor(query_labels)
127 | gallery_labels = torch.tensor(gallery_labels)
128 |
129 | torch.save([query_frames, query_labels, query_dift_features, query_dift_mask, query_perdiff_features,
130 | gallery_frames, gallery_labels, gallery_dift_features, gallery_perdiff_features], f"{dataset_dir}/ret_data.pt")
131 | else:
132 | # retrieval performance on PerMIR
133 | query_frames, query_labels, query_dift_features, query_dift_mask, query_perdiff_features, \
134 | gallery_frames, gallery_labels, gallery_dift_features, gallery_perdiff_features = torch.load(
135 | f"{dataset_dir}/ret_data.pt")
136 | topk = 1
137 | recall_dict = {1: 0, 5: 0, 10: 0, 50: 0}
138 | ap = []
139 | for q_idx in tqdm(range(len(query_dift_features))):
140 | scores = []
141 | for g_idx in range(len(gallery_dift_features)):
142 | # First, extract correspondences using DIFT
143 | ref_points, tgt_points, dift_scores = get_correspondences_seg(query_dift_features[q_idx],
144 | gallery_dift_features[g_idx],
145 | query_dift_mask[q_idx],
146 | img_size=attn_size,
147 | topk=topk)
148 |
149 | total_maps_scores = []
150 | total_maps_scores_cs = []
151 | for attn_idx, (ref_attn, target_attn) in enumerate(
152 | zip(query_perdiff_features[q_idx], gallery_perdiff_features[g_idx])):
153 | all_point_scores = []
154 | all_point_scores_cs = []
155 | for p_ref, p_tgt, dift_score in zip(ref_points, tgt_points, dift_scores):
156 | source_vec = ref_attn[
157 | torch.tensor(np.ravel_multi_index(p_ref, dims=(attn_size, attn_size))).to("cuda")].reshape(1,
158 | -1)
159 | target_vec = target_attn[
160 | torch.tensor(np.ravel_multi_index(p_tgt.T, dims=(attn_size, attn_size))).to(
161 | "cuda")].reshape(topk, -1)
162 | euclidean_dist = torch.cdist(source_vec, target_vec)
163 | all_point_scores += [dift_score * euclidean_dist.mean().cpu().numpy()]
164 | all_point_scores_cs += [
165 | dift_score * (F.normalize(source_vec) @ F.normalize(target_vec).T).mean().cpu().numpy()]
166 | total_maps_scores += [np.mean(all_point_scores)]
167 | total_maps_scores_cs += [np.mean(all_point_scores_cs)]
168 | total_score = np.mean(total_maps_scores)
169 | total_score_cs = np.mean(total_maps_scores_cs)
170 | scores += [total_score_cs]
171 |
172 |         pred_scores_idx = torch.argsort(torch.tensor(scores), descending=True)  # set descending=False when ranking by Euclidean distance
173 | pred_g_labels = gallery_labels[pred_scores_idx]
174 | curr_query_lbl = query_labels[q_idx]
175 |
176 |         ap += [average_precision_score((gallery_labels == curr_query_lbl).int().numpy(), scores)]
177 | for r in [1, 5, 10, 50]:
178 | if curr_query_lbl in pred_g_labels[:r]:
179 | recall_dict[r] += 1
180 |
181 | print("MAP:", np.array(ap).mean())
182 | for k in recall_dict.keys():
183 | print(f"Recall@{k}", recall_dict[k] / len(query_dift_features))
184 |
--------------------------------------------------------------------------------
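
A toy illustration of the retrieval metrics computed in the evaluation branch above (Recall@k and average precision); the labels and scores are made up.

```
import torch
from sklearn.metrics import average_precision_score

# Four gallery items; the first two share the query's label (7).
gallery_labels = torch.tensor([7, 7, 3, 5])
scores = [0.9, 0.2, 0.8, 0.1]            # higher = better match, as with cosine similarity
query_label = 7

order = torch.argsort(torch.tensor(scores), descending=True)
print("Recall@1:", int(query_label in gallery_labels[order][:1]))  # 1: the top match is correct
print("AP:", average_precision_score((gallery_labels == query_label).int().numpy(), scores))
```
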
/ptp_utils.py:
--------------------------------------------------------------------------------
1 | # Code adapted from https://prompt-to-prompt.github.io/
2 |
3 | import numpy as np
4 | import torch
5 | from PIL import Image
6 | import cv2
7 | from typing import Optional, Union, Tuple, Dict
8 | from matplotlib import pyplot as plt
9 |
10 |
11 | def text_under_image(image: np.ndarray, text: str, text_color: Tuple[int, int, int] = (0, 0, 0)):
12 | h, w, c = image.shape
13 | offset = int(h * .2)
14 | img = np.ones((h + offset, w, c), dtype=np.uint8) * 255
15 | font = cv2.FONT_HERSHEY_SIMPLEX
16 | # font = ImageFont.truetype("/usr/share/fonts/truetype/noto/NotoMono-Regular.ttf", font_size)
17 | img[:h] = image
18 | textsize = cv2.getTextSize(text, font, 1, 2)[0]
19 | text_x, text_y = (w - textsize[0]) // 2, h + offset - textsize[1] // 2
20 | cv2.putText(img, text, (text_x, text_y), font, 1, text_color, 2)
21 | return img
22 |
23 |
24 | def view_images(images, num_rows=1, offset_ratio=0.02):
25 | if type(images) is list:
26 | num_empty = len(images) % num_rows
27 | elif images.ndim == 4:
28 | num_empty = images.shape[0] % num_rows
29 | else:
30 | images = [images]
31 | num_empty = 0
32 |
33 | empty_images = np.ones(images[0].shape, dtype=np.uint8) * 255
34 | images = [image.astype(np.uint8) for image in images] + [empty_images] * num_empty
35 | num_items = len(images)
36 |
37 | h, w, c = images[0].shape
38 | offset = int(h * offset_ratio)
39 | num_cols = num_items // num_rows
40 | image_ = np.ones((h * num_rows + offset * (num_rows - 1),
41 | w * num_cols + offset * (num_cols - 1), 3), dtype=np.uint8) * 255
42 | for i in range(num_rows):
43 | for j in range(num_cols):
44 | image_[i * (h + offset): i * (h + offset) + h:, j * (w + offset): j * (w + offset) + w] = images[
45 | i * num_cols + j]
46 |
47 | pil_img = Image.fromarray(image_)
48 | # display(pil_img)
49 | plt.imshow(pil_img)
50 | plt.show()
51 |
52 |
53 | def diffusion_step(model, controller, latents, context, t, guidance_scale, low_resource=False):
54 | if low_resource:
55 | noise_pred_uncond = model.unet(latents, t, encoder_hidden_states=context[0])["sample"]
56 | noise_prediction_text = model.unet(latents, t, encoder_hidden_states=context[1])["sample"]
57 | else:
58 | latents_input = torch.cat([latents] * 2)
59 | noise_pred = model.unet(latents_input, t, encoder_hidden_states=context)["sample"]
60 | noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
61 | noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
62 | latents = model.scheduler.step(noise_pred, t, latents)["prev_sample"]
63 | latents = controller.step_callback(latents)
64 | return latents
65 |
66 |
67 | def latent2image(vae, latents):
68 | latents = 1 / 0.18215 * latents
69 | image = vae.decode(latents)['sample']
70 | image = (image / 2 + 0.5).clamp(0, 1)
71 | image = image.cpu().permute(0, 2, 3, 1).numpy()
72 | image = (image * 255).astype(np.uint8)
73 | return image
74 |
75 |
76 | def init_latent(latent, model, height, width, generator, batch_size):
77 | if latent is None:
78 | latent = torch.randn(
79 | (1, model.unet.in_channels, height // 8, width // 8),
80 | generator=generator,
81 | )
82 | latents = latent.expand(batch_size, model.unet.in_channels, height // 8, width // 8).to(model.device)
83 | return latent, latents
84 |
85 |
86 |
87 |
88 |
89 | def register_attention_control_efficient(model, controller):
90 | def ca_forward(self, place_in_unet):
91 |         # when to_out is a ModuleList, keep only its linear projection (index 0) and skip the dropout
92 |         if isinstance(self.to_out, torch.nn.ModuleList):
93 |             to_out = self.to_out[0]
94 |         else:
95 |             to_out = self.to_out
96 |
97 | def forward(x, encoder_hidden_states=None, attention_mask=None):
98 | batch_size, sequence_length, dim = x.shape
99 | h = self.heads
100 |
101 | is_cross = encoder_hidden_states is not None
102 | encoder_hidden_states = encoder_hidden_states if is_cross else x
103 | inject_cond = False
104 |
105 | if inject_cond:
106 | q = controller(None, is_cross, "Q_" + "_".join(place_in_unet)).cuda()
107 | k = controller(None, is_cross, "K_" + "_".join(place_in_unet)).cuda()
108 | else:
109 | q = self.to_q(x)
110 | k = self.to_k(encoder_hidden_states)
111 | if not controller.is_inject:
112 | q = controller(q, is_cross, "Q_" + "_".join(place_in_unet))
113 | k = controller(k, is_cross, "K_" + "_".join(place_in_unet))
114 |
115 |
116 | q = self.head_to_batch_dim(q)
117 | k = self.head_to_batch_dim(k)
118 |
119 | if inject_cond:
120 | v = controller(None, is_cross, "V_" + "_".join(place_in_unet)).cuda()
121 | else:
122 | v = self.to_v(encoder_hidden_states)
123 | if not controller.is_inject:
124 | v = controller(v, is_cross, "V_" + "_".join(place_in_unet))
125 | controller.cur_att_layer += 1
126 | controller.check_next_step()
127 | v = self.head_to_batch_dim(v)
128 |
129 | sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
130 |
131 | if attention_mask is not None:
132 | attention_mask = attention_mask.reshape(batch_size, -1)
133 | max_neg_value = -torch.finfo(sim.dtype).max
134 | attention_mask = attention_mask[:, None, :].repeat(h, 1, 1)
135 | sim.masked_fill_(~attention_mask, max_neg_value)
136 |
137 | # attention, what we cannot get enough of
138 | attn = sim.softmax(dim=-1)
139 |
140 | out = torch.einsum("b i j, b j d -> b i d", attn, v)
141 | out = self.batch_to_head_dim(out)
142 |
143 | return to_out(out)
144 |
145 | return forward
146 |
147 | class DummyController:
148 |
149 | def __call__(self, *args):
150 | return args[0]
151 |
152 | def __init__(self):
153 | self.num_att_layers = 0
154 |
155 | if controller is None:
156 | controller = DummyController()
157 |
158 | def register_recr(net_, count, place_in_unet):
159 | if net_.__class__.__name__ == 'Attention':
160 | net_.forward = ca_forward(net_, place_in_unet)
161 | return count + 1
162 | elif hasattr(net_, 'children'):
163 | for name, net__ in net_.named_children():
164 | count = register_recr(net__, count, place_in_unet + [name])
165 | return count
166 |
167 | att_count = 0
168 | ref_modules = []
169 | sub_nets = model.unet.named_children()
170 | for name, net in sub_nets:
171 | if "down" in name or "up" in name or "mid" in name:
172 | att_count += register_recr(net, 0, [name])
173 | controller.num_att_layers = att_count
174 |
175 |
176 | def get_word_inds(text: str, word_place: int, tokenizer):
177 | split_text = text.split(" ")
178 | if type(word_place) is str:
179 | word_place = [i for i, word in enumerate(split_text) if word_place == word]
180 | elif type(word_place) is int:
181 | word_place = [word_place]
182 | out = []
183 | if len(word_place) > 0:
184 | words_encode = [tokenizer.decode([item]).strip("#") for item in tokenizer.encode(text)][1:-1]
185 | cur_len, ptr = 0, 0
186 |
187 | for i in range(len(words_encode)):
188 | cur_len += len(words_encode[i])
189 | if ptr in word_place:
190 | out.append(i + 1)
191 | if cur_len >= len(split_text[ptr]):
192 | ptr += 1
193 | cur_len = 0
194 | return np.array(out)
195 |
196 |
197 | def update_alpha_time_word(alpha, bounds: Union[float, Tuple[float, float]], prompt_ind: int,
198 | word_inds: Optional[torch.Tensor] = None):
199 | if type(bounds) is float:
200 | bounds = 0, bounds
201 | start, end = int(bounds[0] * alpha.shape[0]), int(bounds[1] * alpha.shape[0])
202 | if word_inds is None:
203 | word_inds = torch.arange(alpha.shape[2])
204 | alpha[: start, prompt_ind, word_inds] = 0
205 | alpha[start: end, prompt_ind, word_inds] = 1
206 | alpha[end:, prompt_ind, word_inds] = 0
207 | return alpha
208 |
209 |
210 | def get_time_words_attention_alpha(prompts, num_steps,
211 | cross_replace_steps: Union[float, Dict[str, Tuple[float, float]]],
212 | tokenizer, max_num_words=77):
213 | if type(cross_replace_steps) is not dict:
214 | cross_replace_steps = {"default_": cross_replace_steps}
215 | if "default_" not in cross_replace_steps:
216 | cross_replace_steps["default_"] = (0., 1.)
217 | alpha_time_words = torch.zeros(num_steps + 1, len(prompts) - 1, max_num_words)
218 | for i in range(len(prompts) - 1):
219 | alpha_time_words = update_alpha_time_word(alpha_time_words, cross_replace_steps["default_"],
220 | i)
221 | for key, item in cross_replace_steps.items():
222 | if key != "default_":
223 | inds = [get_word_inds(prompts[i], key, tokenizer) for i in range(1, len(prompts))]
224 | for i, ind in enumerate(inds):
225 | if len(ind) > 0:
226 | alpha_time_words = update_alpha_time_word(alpha_time_words, item, i, ind)
227 | alpha_time_words = alpha_time_words.reshape(num_steps + 1, len(prompts) - 1, 1, 1, max_num_words)
228 | return alpha_time_words
229 |
--------------------------------------------------------------------------------
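
A tiny sketch of the two standalone image helpers above on a blank image; `view_images` tiles the inputs into a grid and shows them with matplotlib (an interactive backend is assumed).

```
import numpy as np
from ptp_utils import text_under_image, view_images

# Caption a blank 256x256 image and show two copies side by side.
img = np.full((256, 256, 3), 255, dtype=np.uint8)
captioned = text_under_image(img, "reference")
view_images([captioned, captioned], num_rows=1)
```
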
/pdm_permis.py:
--------------------------------------------------------------------------------
1 | import os
2 | import PIL
3 | import numpy as np
4 | import torch
5 | from diffusers import DDIMScheduler
6 | from matplotlib import pyplot as plt
7 | from torch import nn
8 | from tqdm import tqdm
9 | from transformers import SamModel, SamProcessor
10 | import ptp_utils
11 | from StableDiffusionPipelineWithDDIMInversion import StableDiffusionPipelineWithDDIMInversion
12 | from PerMIRS.eval_miou import AverageMeter, intersectionAndUnion
13 | from attention_store import AttentionStore
14 | from PerMIRS.visualization_utils import bbox_from_mask
15 | from sam_utils import show_points_on_image, show_masks_on_image, show_single_mask_on_image
16 | import torch.nn.functional as F
17 |
18 |
19 | def get_sam_masks(sam_model, processor, raw_image, input_points, input_labels, attn, ref_embeddings, input_masks=None, box=None,
20 | multimask_output=True, verbose=False):
21 | inputs = processor(raw_image, input_points=input_points, input_labels=input_labels, input_boxes=box,
22 | return_tensors="pt",
23 | attention_similarity=attn).to("cuda")
24 | inputs["attention_similarity"] = attn
25 | inputs["target_embedding"] = ref_embeddings
26 | inputs["input_masks"] = input_masks
27 | inputs["multimask_output"] = multimask_output
28 | with torch.no_grad():
29 | outputs = sam_model(**inputs)
30 |
31 | masks = processor.image_processor.post_process_masks(
32 | outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu()
33 | )
34 | scores = outputs.iou_scores
35 | logits = outputs.pred_masks
36 | if verbose:
37 | if not multimask_output:
38 | scores = scores.reshape(1, -1)
39 | masks = masks[0].reshape(1, masks[0].shape[-2], masks[0].shape[-1])
40 | show_single_mask_on_image(raw_image, masks[0], scores)
41 | else:
42 | show_masks_on_image(raw_image, masks[0], scores)
43 | show_points_on_image(raw_image, input_points[0], input_labels[0])
44 | return masks, scores, logits
45 |
46 | def extract_attention(sd_model, image, prompt):
47 | controller = AttentionStore()
48 | inv_latents = sd_model.invert(prompt, image=image, guidance_scale=1.0).latents
49 | ptp_utils.register_attention_control_efficient(sd_model, controller)
50 | # recon_image = sd_model(prompt, latents=inv_latents, guidance_scale=1.0).images[0]
51 | key_format = f"up_blocks_3_attentions_1_transformer_blocks_0_attn1_self"
52 | timestamp = 49
53 | return [controller.attention_store[timestamp]["Q_" + key_format][0].to("cuda"),
54 | controller.attention_store[timestamp]["K_" + key_format][0].to("cuda"),
55 | ], sd_model.unet.up_blocks[3].attentions[1].transformer_blocks[0].attn1
56 |
57 |
58 | def center_crop(im, min_obj_x=None, max_obj_x=None, offsets=None):
59 | if offsets is None:
60 | width, height = im.size # Get dimensions
61 | min_dim = min(width, height)
62 | left = (width - min_dim) / 2
63 | top = (height - min_dim) / 2
64 | right = (width + min_dim) / 2
65 | bottom = (height + min_dim) / 2
66 |
67 | if min_obj_x < left:
68 | diff = abs(left - min_obj_x)
69 | left = min_obj_x
70 | right = right - diff
71 | if max_obj_x > right:
72 | diff = abs(right - max_obj_x)
73 | right = max_obj_x
74 | left = left + diff
75 | else:
76 | left, top, right, bottom = offsets
77 |
78 | # Crop the center of the image
79 | im = im.crop((left, top, right, bottom))
80 | return im, (left, top, right, bottom)
81 |
82 |
83 | def load_im_into_format_from_path(im_path, size=512, offsets=None):
84 | im, offsets = center_crop(PIL.Image.open(im_path), offsets=offsets)
85 | return im
86 |
87 |
88 | def load_im_into_format_from_image(image, size=512, min_obj_x=None, max_obj_x=None):
89 | im, offsets = center_crop(image, min_obj_x=min_obj_x, max_obj_x=max_obj_x)
90 | return im.resize((size, size)), offsets
91 |
92 |
93 | def diff_attn_images(sd_model, sam_model, sam_processor, ref_img, ref_mask, ref_attn_ls, tgt_img, target_attn_ls,
94 | verbose=True):
95 | image_size = 512
96 | prompt = f"A photo"
97 | h = w = 64
98 | mask_idx_y, mask_idx_x = torch.where(ref_mask)
99 |
100 | all_points = []
101 | all_labels = []
102 | for attn_idx, (ref_attn, target_attn) in enumerate(zip(ref_attn_ls, target_attn_ls)):
103 | heatmap = torch.zeros(ref_attn.shape[0]).to("cuda")
104 | for x, y in zip(mask_idx_x, mask_idx_y):
105 | t = np.ravel_multi_index((y, x), dims=(h, w))
106 | source_vec = ref_attn[t].reshape(1, -1)
107 | cs_similarity = F.normalize(source_vec) @ F.normalize(target_attn).T
108 | idx = torch.sort(cs_similarity, descending=True)[1][0][:100]
109 | heatmap[idx] += 1
110 | heatmap = heatmap / heatmap.max()
111 | heatmap_img_size = \
112 | nn.Upsample(size=(image_size, image_size), mode='bilinear')(heatmap.reshape(1, 1, 64, 64))[0][
113 | 0]
114 | if verbose:
115 | plt.imshow(tgt_img)
116 | plt.imshow(heatmap_img_size.cpu(), alpha=0.6)
117 | plt.axis("off")
118 | plt.show()
119 |
120 | all_points += [(heatmap_img_size == torch.max(heatmap_img_size)).nonzero()[0].reshape(1, -1).cpu().numpy()]
121 | all_labels += [1]
122 |
123 | all_points = np.concatenate(all_points)
124 | all_points[:, [0, 1]] = all_points[:, [1, 0]]
125 | pred_masks, masks_scores, mask_logits = get_sam_masks(sam_model, sam_processor, tgt_img,
126 | input_points=[list(all_points)],
127 | input_labels=[all_labels],
128 | attn=None,
129 | ref_embeddings=None,
130 | verbose=verbose
131 | )
132 | best_mask = 2
133 | y, x = pred_masks[0][0][best_mask].nonzero().T
134 | x_min = x.min().item()
135 | x_max = x.max().item()
136 | y_min = y.min().item()
137 | y_max = y.max().item()
138 | input_box = [[x_min, y_min, x_max, y_max]]
139 | pred_masks, masks_scores, mask_logits = get_sam_masks(sam_model, sam_processor, tgt_img,
140 | input_points=[list(all_points)],
141 | input_labels=[all_labels],
142 | attn=None,
143 | ref_embeddings=None,
144 | # input_masks=mask_logits[0, 0, best_mask: best_mask + 1, :,:],
145 | box=[input_box],
146 | multimask_output=True,
147 | verbose=verbose
148 | )
149 | best_idx = 2
150 | final_pred_mask = pred_masks[0][:, best_idx, :, :].squeeze().numpy()
151 | return final_pred_mask
152 |
153 |
154 | if __name__ == "__main__":
155 | model_id = "CompVis/stable-diffusion-v1-4"
156 |
157 | device = "cuda"
158 |
159 | sam_model = SamModel.from_pretrained("facebook/sam-vit-huge").to(device)
160 | sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
161 |
162 | sd_model = StableDiffusionPipelineWithDDIMInversion.from_pretrained(
163 | model_id,
164 | safety_checker=None,
165 | scheduler=DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
166 | )
167 | sd_model = sd_model.to(device)
168 | img_size = 512
169 | attn_size = 64
170 |
171 | dataset_dir = "/PerMIRS"
172 | intersection_meter = AverageMeter()
173 | union_meter = AverageMeter()
174 | target_meter = AverageMeter()
175 | for vid_id in tqdm(os.listdir(dataset_dir)):
176 | frames = []
177 | masks = []
178 | masks_img_size = []
179 | masks_np = np.load(f"{dataset_dir}/{vid_id}/masks.npz.npy", allow_pickle=True)
180 | attn_ls = torch.load(f"{dataset_dir}/{vid_id}/diff_feats.pt")
181 | for f in range(3):
182 | xmin, ymin, xmax, ymax = [int(x) for x in bbox_from_mask(list(masks_np[f].values())[0])]
183 | m, curr_offsets = load_im_into_format_from_image(
184 | PIL.Image.fromarray(np.uint8(list(masks_np[f].values())[0])),
185 | min_obj_x=xmin, max_obj_x=xmax)
186 | masks += [np.asarray(m.resize((attn_size, attn_size)))]
187 | masks_img_size += [np.asarray(m.resize((img_size, img_size)))]
188 | frames += [
189 | load_im_into_format_from_path(f"{dataset_dir}/{vid_id}/{f}.jpg", offsets=curr_offsets).resize(
190 | (img_size, img_size)).convert("RGB")]
191 |
192 | ref_idx = 0
193 | tgt_idx = 1
194 |
195 | # remove tiny objects
196 | if masks_img_size[tgt_idx].sum() / (img_size * img_size) < 0.005:
197 | continue
198 | pred_mask = diff_attn_images(sd_model, sam_model, sam_processor, frames[ref_idx], torch.tensor(masks[ref_idx]),
199 | attn_ls[ref_idx][0:2],
200 | frames[tgt_idx], attn_ls[tgt_idx][0:2], verbose=False)
201 |
202 | pred_mask = np.uint8(pred_mask)
203 | gt_mask = np.uint8(masks_img_size[tgt_idx])
204 |
205 | intersection, union, target = intersectionAndUnion(pred_mask, gt_mask)
206 | print(vid_id, intersection, union, target)
207 | intersection_meter.update(intersection), union_meter.update(union), target_meter.update(target)
208 | iou_class = intersection_meter.sum / (union_meter.sum + 1e-10)
209 | accuracy_class = intersection_meter.sum / (target_meter.sum + 1e-10)
210 |
211 | print("\nmIoU: %.2f" % (100 * iou_class))
212 | print("mAcc: %.2f\n" % (100 * accuracy_class))
213 |
--------------------------------------------------------------------------------
/PerMIRS/video.py:
--------------------------------------------------------------------------------
1 | # Code from https://github.com/Ali2500/BURST-benchmark
2 | from typing import Optional, List, Tuple, Union, Dict, Any
3 |
4 | import cv2
5 | import numpy as np
6 | import os.path as osp
7 |
8 | import utils
9 | from PerMIRS.visualization_utils import bbox_from_mask
10 |
11 |
12 | class BURSTVideo:
13 | def __init__(self, video_dict: Dict[str, Any], images_dir: Optional[str] = None):
14 |
15 | self.annotated_image_paths: List[str] = video_dict["annotated_image_paths"]
16 | self.all_images_paths: List[str] = video_dict["all_image_paths"]
17 | self.segmentations: List[Dict[int, Dict[str, Any]]] = video_dict["segmentations"]
18 | self._track_category_ids: Dict[int, int] = video_dict["track_category_ids"]
19 | self.image_size: Tuple[int, int] = (video_dict["height"], video_dict["width"])
20 |
21 | self.id = video_dict["id"]
22 | self.dataset = video_dict["dataset"]
23 | self.name = video_dict["seq_name"]
24 | self.negative_category_ids = video_dict["neg_category_ids"]
25 | self.not_exhaustive_category_ids = video_dict["not_exhaustive_category_ids"]
26 |
27 | self._images_dir = images_dir
28 |
29 | @property
30 | def num_annotated_frames(self) -> int:
31 | return len(self.annotated_image_paths)
32 |
33 | @property
34 | def num_total_frames(self) -> int:
35 | return len(self.all_images_paths)
36 |
37 | @property
38 | def image_height(self) -> int:
39 | return self.image_size[0]
40 |
41 | @property
42 | def image_width(self) -> int:
43 | return self.image_size[1]
44 |
45 | @property
46 | def track_ids(self) -> List[int]:
47 | return list(sorted(self._track_category_ids.keys()))
48 |
49 | @property
50 | def track_category_ids(self) -> Dict[int, int]:
51 | return {
52 | track_id: self._track_category_ids[track_id]
53 | for track_id in self.track_ids
54 | }
55 |
56 | def filter_dataset_for_benchmark(self, masks_per_frame):
57 |         # First, filter out low-quality frames and frames that do not contain at least two instances of the same object category
58 | global_relevant_tracks = [key for key, value in self.track_category_ids.items() if
59 | list(self.track_category_ids.values()).count(value) > 1]
60 | tracks_count = dict.fromkeys(global_relevant_tracks, 0)
61 | #tracks_bbox_size_per_frame = dict.fromkeys(global_relevant_tracks, list())
62 | tracks_per_frame = dict.fromkeys(global_relevant_tracks, list())
63 |         # For each frame, first check whether it contains multiple instances of the same object category, then
64 |         # count the number of occurrences of each track id and compute its relative size.
65 | for i in range(len(masks_per_frame)):
66 | curr_id_categories = dict(zip(self.segmentations[i].keys(),
67 | list(map(self.track_category_ids.get, self.segmentations[i].keys()))))
68 | curr_relevant_tracks = [key for key, value in curr_id_categories.items() if
69 | list(curr_id_categories.values()).count(value) > 1]
70 | if len(curr_relevant_tracks) >= 2:
71 | #potential_frames += [i]
72 | frame_masks = masks_per_frame[i]
73 | frame_size = list(frame_masks.values())[0].size
74 | tracks_count.update(
75 | zip(curr_relevant_tracks, list(map(lambda x: tracks_count.get(x) + 1, curr_relevant_tracks))))
76 | for t_id in curr_relevant_tracks:
77 | xmin, ymin, xmax, ymax = [int(x) for x in bbox_from_mask(frame_masks[t_id])]
78 | relative_size = (abs(ymax - ymin) * abs(xmax - xmin) / frame_size)
79 | #relative_size = frame_masks[t_id].sum() / frame_size
80 | tracks_per_frame[t_id] = tracks_per_frame[t_id] + [(i, relative_size)]
81 | #tracks_per_frame.update(
82 | # zip(curr_relevant_tracks,
83 | # list(map(lambda x: tracks_per_frame.get(x) + [(i, relative_size)], curr_relevant_tracks))))
84 | #tracks_bbox_size_per_frame[t_id] = tracks_bbox_size_per_frame[t_id] + [
85 | # (abs(ymax - ymin) * abs(xmax - xmin) / frame_size)]
86 |
87 | # Take the instance which appeared the most during the video
88 | max_inst = max(tracks_count, key=tracks_count.get)
89 | # Select n frames, where the instance appeared the largest, make sure the frames are far apart.
90 | inst_size_per_frame = dict(tracks_per_frame[max_inst])
91 | sorted_frames = np.array(list(dict(sorted(inst_size_per_frame.items(), key=lambda item: item[1], reverse=True)).keys()))
92 | first_frame = sorted_frames[0]
93 | final_frames = [first_frame] + list(sorted_frames[np.where(abs(sorted_frames - first_frame) > 15)[0][:2]])
94 | final_masks = []
95 | for f in final_frames:
96 | final_masks += [{max_inst: masks_per_frame[f][max_inst]}]
97 | return final_frames, final_masks
98 |
99 | def get_image_paths(self, frame_indices: Optional[List[int]] = None) -> List[str]:
100 | """
101 |         Get file paths to all annotated image frames of the video
102 |         :param frame_indices: Optional argument specifying list of frame indices to load. All indices must satisfy
103 |         0 <= t < self.num_annotated_frames
104 | :return: List of file paths
105 | """
106 | if frame_indices is None:
107 | frame_indices = list(range(self.num_annotated_frames))
108 | else:
109 | assert all([0 <= t < self.num_annotated_frames for t in frame_indices]), f"One or more frame indices are " \
110 | f"invalid"
111 |
112 | return [osp.join(self._images_dir, self.annotated_image_paths[t]) for t in frame_indices]
113 |
114 | def load_images(self, frame_indices: Optional[List[int]] = None) -> List[np.ndarray]:
115 | """
116 | Load annotated image frames for the video
117 |         :param frame_indices: Optional argument specifying list of frame indices to load. All indices must satisfy
118 |         0 <= t < self.num_annotated_frames
119 | :return: List of images as numpy arrays of dtype uint8 and shape [H, W, 3] (RGB)
120 | """
121 | assert self._images_dir is not None, f"Images cannot be loaded because 'images_dir' is None"
122 |
123 | if frame_indices is None:
124 | frame_indices = list(range(self.num_annotated_frames))
125 | else:
126 | assert all([0 <= t < self.num_annotated_frames for t in frame_indices]), f"One or more frame indices are " \
127 | f"invalid"
128 |
129 | images = []
130 |
131 | for t in frame_indices:
132 | filepath = osp.join(self._images_dir, self.annotated_image_paths[t])
133 | assert osp.exists(filepath), f"Image file not found: '{filepath}'"
134 | images.append(cv2.imread(filepath, cv2.IMREAD_COLOR)[:, :, ::-1]) # convert BGR to RGB
135 |
136 | return images
137 |
138 |     def images_paths(self, frame_indices: Optional[List[int]] = None) -> List[str]:
139 |         """
140 |         Get file paths to the annotated image frames of the video
141 |         :param frame_indices: Optional argument specifying list of frame indices to load. All indices must satisfy
142 |         0 <= t < self.num_annotated_frames
143 |         :return: List of file paths to the annotated frames
144 |         """
145 | assert self._images_dir is not None, f"Images cannot be loaded because 'images_dir' is None"
146 |
147 | if frame_indices is None:
148 | frame_indices = list(range(self.num_annotated_frames))
149 | else:
150 | assert all([0 <= t < self.num_annotated_frames for t in frame_indices]), f"One or more frame indices are " \
151 | f"invalid"
152 |
153 | paths = []
154 |
155 | for t in frame_indices:
156 | filepath = osp.join(self._images_dir, self.annotated_image_paths[t])
157 | assert osp.exists(filepath), f"Image file not found: '{filepath}'"
158 | #images.append(cv2.imread(filepath, cv2.IMREAD_COLOR)[:, :, ::-1]) # convert BGR to RGB
159 | paths += [filepath]
160 |
161 | return paths
162 |
163 | def load_masks(self, frame_indices: Optional[List[int]] = None) -> List[Dict[int, np.ndarray]]:
164 | """
165 | Decode RLE masks into mask images
166 |         :param frame_indices: Optional argument specifying list of frame indices to load. All indices must satisfy
167 |         0 <= t < self.num_annotated_frames
168 | :return: List of dicts (one per frame). Each dict has track IDs as keys and mask images as values.
169 | """
170 | if frame_indices is None:
171 | frame_indices = list(range(self.num_annotated_frames))
172 | else:
173 | assert all([0 <= t < self.num_annotated_frames for t in frame_indices]), f"One or more frame indices are " \
174 | f"invalid"
175 |
176 | zero_mask = np.zeros(self.image_size, bool)
177 | masks = []
178 |
179 | for t in frame_indices:
180 | masks_t = dict()
181 |
182 | for track_id in self.track_ids:
183 | if track_id in self.segmentations[t]:
184 | masks_t[track_id] = utils.rle_ann_to_mask(self.segmentations[t][track_id]["rle"], self.image_size)
185 | else:
186 | masks_t[track_id] = zero_mask
187 |
188 | masks.append(masks_t)
189 |
190 | return masks
191 |
192 | def filter_category_ids(self, category_ids_to_keep: List[int]):
193 | track_ids_to_keep = [
194 | track_id for track_id, category_id in self._track_category_ids.items()
195 | if category_id in category_ids_to_keep
196 | ]
197 |
198 | self._track_category_ids = {
199 | track_id: category_id for track_id, category_id in self._track_category_ids.items()
200 | if track_id in track_ids_to_keep
201 | }
202 |
203 | for t in range(self.num_annotated_frames):
204 | self.segmentations[t] = {
205 | track_id: seg for track_id, seg in self.segmentations[t].items()
206 | if track_id in track_ids_to_keep
207 | }
208 |
209 | def stats(self) -> Dict[str, Any]:
210 | total_masks = 0
211 | for segs_t in self.segmentations:
212 | total_masks += len(segs_t)
213 |
214 | return {
215 | "Annotated frames": self.num_annotated_frames,
216 | "Object tracks": len(self.track_ids),
217 | "Object masks": total_masks,
218 | "Unique category IDs": list(set(self.track_category_ids.values()))
219 | }
220 |
221 | def load_first_frame_annotations(self) -> List[Dict[int, Dict[str, Any]]]:
222 | annotations = []
223 | for t in range(self.num_annotated_frames):
224 | annotations_t = dict()
225 |
226 | for track_id, annotation in self.segmentations[t].items():
227 | annotations_t[track_id] = {
228 | "mask": utils.rle_ann_to_mask(annotation["rle"], self.image_size),
229 | "bbox": annotation["bbox"], # xywh format
230 | "point": annotation["point"]
231 | }
232 |
233 | annotations.append(annotations_t)
234 |
235 | return annotations
236 |
--------------------------------------------------------------------------------
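A minimal usage sketch for `BURSTVideo`, run from the repository root. The `video_dict` argument stands for a single entry of a BURST-style annotation file (with track ids already converted to ints, e.g. via `utils.intify_track_ids`), and the helper name below is ours, not part of the repository:

```python
import sys
sys.path.append("PerMIRS")  # video.py imports its sibling `utils` module directly

from PerMIRS.video import BURSTVideo

def select_benchmark_frames(video_dict, images_dir):
    """Sketch: pick the frames and masks the PerMIRS benchmark uses for one BURST video."""
    video = BURSTVideo(video_dict, images_dir=images_dir)
    masks_per_frame = video.load_masks()  # one {track_id: mask} dict per annotated frame
    frame_ids, masks = video.filter_dataset_for_benchmark(masks_per_frame)
    # `frame_ids` holds up to three well-separated frames in which the most frequent
    # multi-instance track appears largest; `masks` holds that track's mask in each frame.
    return video.images_paths(frame_ids), masks
```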
/StableDiffusionPipelineWithDDIMInversion.py:
--------------------------------------------------------------------------------
1 | # Code adapted from https://github.com/mkshing/svdiff-pytorch
2 | from typing import Any, Callable, Dict, List, Optional, Union
3 | import PIL
4 | import torch
5 | from diffusers import StableDiffusionPipeline, DDIMInverseScheduler
6 | from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import preprocess
7 | from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_pix2pix_zero import Pix2PixInversionPipelineOutput
8 | from diffusers.schedulers.scheduling_ddim_inverse import DDIMSchedulerOutput
9 | from matplotlib import pyplot as plt
10 |
11 |
12 | class StableDiffusionPipelineWithDDIMInversion(StableDiffusionPipeline):
13 | def __init__(self, vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor,
14 | requires_safety_checker: bool = True):
15 | super().__init__(vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor,
16 | requires_safety_checker)
17 | self.inverse_scheduler = DDIMInverseScheduler.from_config(self.scheduler.config)
18 | # self.register_modules(inverse_scheduler=DDIMInverseScheduler.from_config(self.scheduler.config))
19 |
20 | def prepare_image_latents(self, image, batch_size, dtype, device, generator=None):
21 | if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
22 | raise ValueError(
23 | f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
24 | )
25 |
26 | image = image.to(device=device, dtype=dtype)
27 |
28 | if isinstance(generator, list) and len(generator) != batch_size:
29 | raise ValueError(
30 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
31 | f" size of {batch_size}. Make sure the batch size matches the length of the generators."
32 | )
33 |
34 | if isinstance(generator, list):
35 | init_latents = [
36 | self.vae.encode(image[i: i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
37 | ]
38 | init_latents = torch.cat(init_latents, dim=0)
39 | else:
40 | init_latents = self.vae.encode(image).latent_dist.sample(generator)
41 |
42 | init_latents = self.vae.config.scaling_factor * init_latents
43 |
44 | if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
45 | raise ValueError(
46 | f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
47 | )
48 | else:
49 | init_latents = torch.cat([init_latents], dim=0)
50 |
51 | latents = init_latents
52 |
53 | return latents
54 |
55 | def get_epsilon(self, model_output: torch.Tensor, sample: torch.Tensor, timestep: int):
56 | pred_type = self.inverse_scheduler.config.prediction_type
57 | alpha_prod_t = self.inverse_scheduler.alphas_cumprod[timestep]
58 |
59 | beta_prod_t = 1 - alpha_prod_t
60 |
61 | if pred_type == "epsilon":
62 | return model_output
63 | elif pred_type == "sample":
64 | return (sample - alpha_prod_t ** (0.5) * model_output) / beta_prod_t ** (0.5)
65 | elif pred_type == "v_prediction":
66 | return (alpha_prod_t ** 0.5) * model_output + (beta_prod_t ** 0.5) * sample
67 | else:
68 | raise ValueError(
69 | f"prediction_type given as {pred_type} must be one of `epsilon`, `sample`, or `v_prediction`"
70 | )
71 |
72 | def auto_corr_loss(self, hidden_states, generator=None):
73 | batch_size, channel, height, width = hidden_states.shape
74 | if batch_size > 1:
75 | raise ValueError("Only batch_size 1 is supported for now")
76 |
77 | hidden_states = hidden_states.squeeze(0)
78 | # hidden_states must be shape [C,H,W] now
79 | reg_loss = 0.0
80 | for i in range(hidden_states.shape[0]):
81 | noise = hidden_states[i][None, None, :, :]
82 | while True:
83 | roll_amount = torch.randint(noise.shape[2] // 2, (1,), generator=generator).item()
84 | reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=2)).mean() ** 2
85 | reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=3)).mean() ** 2
86 |
87 | if noise.shape[2] <= 8:
88 | break
89 |                 noise = torch.nn.functional.avg_pool2d(noise, kernel_size=2)
90 | return reg_loss
91 |
92 | def kl_divergence(self, hidden_states):
93 | mean = hidden_states.mean()
94 | var = hidden_states.var()
95 | return var + mean ** 2 - 1 - torch.log(var + 1e-7)
96 |
97 | def scheduler_next_step(self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor, return_dict: bool = True):
98 | # 1. get next step value (=t+1)
99 | next_timestep = timestep
100 | timestep = min(
101 | timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps, 999)
102 | # 2. compute alphas, betas
103 | # change original implementation to exactly match noise levels for analogous forward process
104 | alpha_prod_t = self.scheduler.alphas_cumprod[timestep] if timestep >= 0 else self.scheduler.final_alpha_cumprod
105 | alpha_prod_t_next = self.scheduler.alphas_cumprod[next_timestep]
106 |
107 | beta_prod_t = 1 - alpha_prod_t
108 |
109 | next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
110 | pred_epsilon = model_output
111 |
112 | # 5. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
113 | next_sample_direction = (1 - alpha_prod_t_next) ** (0.5) * pred_epsilon
114 |
115 | # 6. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
116 | next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction
117 |
118 | if not return_dict:
119 |             return (next_sample, next_original_sample)
120 |         return DDIMSchedulerOutput(prev_sample=next_sample, pred_original_sample=next_original_sample)
121 |
122 | def decode_to_tensor(self, latents):
123 | latents = 1 / self.vae.config.scaling_factor * latents
124 | image = self.vae.decode(latents).sample
125 | return image
126 |
127 | # based on https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py#L1063
128 | @torch.no_grad()
129 | def invert(
130 | self,
131 | prompt: Optional[str] = None,
132 | image: Union[torch.FloatTensor, PIL.Image.Image] = None,
133 | num_inference_steps: int = 50,
134 | guidance_scale: float = 1,
135 | num_images_per_prompt = 1,
136 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
137 | prompt_embeds: Optional[torch.FloatTensor] = None,
138 | output_type: Optional[str] = "pil",
139 | return_dict: bool = True,
140 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
141 | callback_steps: Optional[int] = 1,
142 | cross_attention_kwargs: Optional[Dict[str, Any]] = None,
143 | lambda_auto_corr: float = 20.0,
144 | lambda_kl: float = 20.0,
145 | num_reg_steps: int = 0, # disabled
146 | num_auto_corr_rolls: int = 5,
147 | ):
148 | # 1. Define call parameters
149 | if prompt is not None and isinstance(prompt, str):
150 | batch_size = 1
151 | elif prompt is not None and isinstance(prompt, list):
152 | batch_size = len(prompt)
153 | else:
154 | batch_size = prompt_embeds.shape[0]
155 | if cross_attention_kwargs is None:
156 | cross_attention_kwargs = {}
157 |
158 | device = self._execution_device
159 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
160 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
161 | # corresponds to doing no classifier free guidance.
162 | do_classifier_free_guidance = guidance_scale > 1.0
163 |
164 | # 3. Preprocess image
165 | image = preprocess(image)
166 |
167 | # 4. Prepare latent variables
168 | latents = self.prepare_image_latents(image, batch_size, self.vae.dtype, device, generator)
169 |
170 | # 5. Encode input prompt
171 | #num_images_per_prompt = num_images_per_prompt
172 | prompt_embeds = self._encode_prompt(
173 | prompt,
174 | device,
175 | num_images_per_prompt,
176 | do_classifier_free_guidance,
177 | prompt_embeds=prompt_embeds,
178 | )
179 |
180 | # 4. Prepare timesteps
181 | self.inverse_scheduler.set_timesteps(num_inference_steps, device=device)
182 | self.scheduler.set_timesteps(num_inference_steps, device=device)
183 | timesteps = self.inverse_scheduler.timesteps
184 |
185 |         # 7. DDIM inversion loop (the reverse of the usual denoising loop).
186 | num_warmup_steps = len(timesteps) - num_inference_steps * self.inverse_scheduler.order
187 | with self.progress_bar(total=num_inference_steps) as progress_bar:
188 | for i, t in enumerate(timesteps):
189 | # expand the latents if we are doing classifier free guidance
190 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
191 | latent_model_input = self.inverse_scheduler.scale_model_input(latent_model_input, t)
192 |
193 | # predict the noise residual
194 | noise_pred = self.unet(
195 | latent_model_input,
196 | t,
197 | encoder_hidden_states=prompt_embeds,
198 | cross_attention_kwargs=cross_attention_kwargs,
199 | ).sample
200 |
201 | # perform guidance
202 | if do_classifier_free_guidance:
203 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
204 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
205 |
206 | # regularization of the noise prediction
207 | with torch.enable_grad():
208 | for _ in range(num_reg_steps):
209 | if lambda_auto_corr > 0:
210 | for _ in range(num_auto_corr_rolls):
211 | var = torch.autograd.Variable(noise_pred.detach().clone(), requires_grad=True)
212 |
213 | # Derive epsilon from model output before regularizing to IID standard normal
214 | var_epsilon = self.get_epsilon(var, latent_model_input.detach(), t)
215 |
216 | l_ac = self.auto_corr_loss(var_epsilon, generator=generator)
217 | l_ac.backward()
218 |
219 | grad = var.grad.detach() / num_auto_corr_rolls
220 | noise_pred = noise_pred - lambda_auto_corr * grad
221 |
222 | if lambda_kl > 0:
223 | var = torch.autograd.Variable(noise_pred.detach().clone(), requires_grad=True)
224 |
225 | # Derive epsilon from model output before regularizing to IID standard normal
226 | var_epsilon = self.get_epsilon(var, latent_model_input.detach(), t)
227 |
228 | l_kld = self.kl_divergence(var_epsilon)
229 | l_kld.backward()
230 |
231 | grad = var.grad.detach()
232 | noise_pred = noise_pred - lambda_kl * grad
233 |
234 | noise_pred = noise_pred.detach()
235 |
236 |                 # compute the next noisy sample x_t -> x_t+1 (inversion step)
237 | # latents = self.inverse_scheduler.step(noise_pred, t, latents).prev_sample
238 | latents = self.scheduler_next_step(noise_pred, t, latents).prev_sample
239 |
240 | # call the callback, if provided
241 | if i == len(timesteps) - 1 or (
242 | (i + 1) > num_warmup_steps and (i + 1) % self.inverse_scheduler.order == 0
243 | ):
244 | progress_bar.update()
245 | if callback is not None and i % callback_steps == 0:
246 | callback(i, t, latents)
247 |
248 | inverted_latents = latents.detach().clone()
249 |
250 | # 8. Post-processing
251 | image = self.decode_latents(latents.detach())
252 |
253 | # Offload last model to CPU
254 | if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
255 | self.final_offload_hook.offload()
256 |
257 | # 9. Convert to PIL.
258 | if output_type == "pil":
259 | image = self.numpy_to_pil(image)
260 |
261 | if not return_dict:
262 | return (inverted_latents, image)
263 |
264 | return Pix2PixInversionPipelineOutput(latents=inverted_latents, images=image)
--------------------------------------------------------------------------------
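A minimal sketch of driving the inversion pipeline. The checkpoint id, the blank prompt, the DDIM scheduler swap, and the 512x512 resize are assumptions on our side; `guidance_scale=1` keeps classifier-free guidance off during inversion, as the code above expects:

```python
import PIL.Image
import torch
from diffusers import DDIMScheduler
from StableDiffusionPipelineWithDDIMInversion import StableDiffusionPipelineWithDDIMInversion

device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = StableDiffusionPipelineWithDDIMInversion.from_pretrained(
    "CompVis/stable-diffusion-v1-4", safety_checker=None
).to(device)
# Use DDIM so the noising steps taken by scheduler_next_step match DDIM inversion.
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)

image = PIL.Image.open("dogs/source.png").convert("RGB").resize((512, 512))
inv = pipe.invert(prompt="", image=image, num_inference_steps=50, guidance_scale=1)

x_T = inv.latents      # inverted latents, usable as the starting noise for sampling
recon = inv.images[0]  # decoded image after inversion, useful as a sanity check
```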
/dift.py:
--------------------------------------------------------------------------------
1 | import gc
2 |
3 | import numpy as np
4 | import torch
5 | from typing import Any, Callable, Dict, List, Optional, Union, Tuple
6 | from diffusers import DDIMScheduler
7 | from diffusers import StableDiffusionPipeline
8 | from torch import nn
9 | import torch.nn.functional as F
10 | from diffusers.models.unet_2d_condition import UNet2DConditionModel
11 |
12 | class MyUNet2DConditionModel(UNet2DConditionModel):
13 | def forward(
14 | self,
15 | sample: torch.FloatTensor,
16 | timestep: Union[torch.Tensor, float, int],
17 | up_ft_indices,
18 | encoder_hidden_states: torch.Tensor,
19 | class_labels: Optional[torch.Tensor] = None,
20 | timestep_cond: Optional[torch.Tensor] = None,
21 | attention_mask: Optional[torch.Tensor] = None,
22 | cross_attention_kwargs: Optional[Dict[str, Any]] = None):
23 | r"""
24 | Args:
25 | sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
26 | timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
27 | encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
28 | cross_attention_kwargs (`dict`, *optional*):
29 | A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
30 | `self.processor` in
31 | [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
32 | """
33 |         # By default, samples have to be at least a multiple of the overall upsampling factor.
34 |         # The overall upsampling factor is equal to 2 ** (number of upsampling layers).
35 | # However, the upsampling interpolation output size can be forced to fit any upsampling size
36 | # on the fly if necessary.
37 | default_overall_up_factor = 2 ** self.num_upsamplers
38 |
39 | # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
40 | forward_upsample_size = False
41 | upsample_size = None
42 |
43 | if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
44 | # logger.info("Forward upsample size to force interpolation output size.")
45 | forward_upsample_size = True
46 |
47 | # prepare attention_mask
48 | if attention_mask is not None:
49 | attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
50 | attention_mask = attention_mask.unsqueeze(1)
51 |
52 | # 0. center input if necessary
53 | if self.config.center_input_sample:
54 | sample = 2 * sample - 1.0
55 |
56 | # 1. time
57 | timesteps = timestep
58 | if not torch.is_tensor(timesteps):
59 | # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
60 | # This would be a good case for the `match` statement (Python 3.10+)
61 | is_mps = sample.device.type == "mps"
62 | if isinstance(timestep, float):
63 | dtype = torch.float32 if is_mps else torch.float64
64 | else:
65 | dtype = torch.int32 if is_mps else torch.int64
66 | timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
67 | elif len(timesteps.shape) == 0:
68 | timesteps = timesteps[None].to(sample.device)
69 |
70 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
71 | timesteps = timesteps.expand(sample.shape[0])
72 |
73 | t_emb = self.time_proj(timesteps)
74 |
75 | # timesteps does not contain any weights and will always return f32 tensors
76 | # but time_embedding might actually be running in fp16. so we need to cast here.
77 | # there might be better ways to encapsulate this.
78 | t_emb = t_emb.to(dtype=self.dtype)
79 |
80 | emb = self.time_embedding(t_emb, timestep_cond)
81 |
82 | if self.class_embedding is not None:
83 | if class_labels is None:
84 | raise ValueError("class_labels should be provided when num_class_embeds > 0")
85 |
86 | if self.config.class_embed_type == "timestep":
87 | class_labels = self.time_proj(class_labels)
88 |
89 | class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
90 | emb = emb + class_emb
91 |
92 | # 2. pre-process
93 | sample = self.conv_in(sample)
94 |
95 | # 3. down
96 | down_block_res_samples = (sample,)
97 | for downsample_block in self.down_blocks:
98 | if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
99 | sample, res_samples = downsample_block(
100 | hidden_states=sample,
101 | temb=emb,
102 | encoder_hidden_states=encoder_hidden_states,
103 | attention_mask=attention_mask,
104 | cross_attention_kwargs=cross_attention_kwargs,
105 | )
106 | else:
107 | sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
108 |
109 | down_block_res_samples += res_samples
110 |
111 | # 4. mid
112 | if self.mid_block is not None:
113 | sample = self.mid_block(
114 | sample,
115 | emb,
116 | encoder_hidden_states=encoder_hidden_states,
117 | attention_mask=attention_mask,
118 | cross_attention_kwargs=cross_attention_kwargs,
119 | )
120 |
121 | # 5. up
122 | up_ft = {}
123 | for i, upsample_block in enumerate(self.up_blocks):
124 |
125 | if i > np.max(up_ft_indices):
126 | break
127 |
128 | is_final_block = i == len(self.up_blocks) - 1
129 |
130 | res_samples = down_block_res_samples[-len(upsample_block.resnets):]
131 | down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
132 |
133 | # if we have not reached the final block and need to forward the
134 | # upsample size, we do it here
135 | if not is_final_block and forward_upsample_size:
136 | upsample_size = down_block_res_samples[-1].shape[2:]
137 |
138 | if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
139 | sample = upsample_block(
140 | hidden_states=sample,
141 | temb=emb,
142 | res_hidden_states_tuple=res_samples,
143 | encoder_hidden_states=encoder_hidden_states,
144 | cross_attention_kwargs=cross_attention_kwargs,
145 | upsample_size=upsample_size,
146 | attention_mask=attention_mask,
147 | )
148 | else:
149 | sample = upsample_block(
150 | hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
151 | )
152 |
153 | if i in up_ft_indices:
154 | up_ft[i] = sample.detach()
155 |
156 | output = {}
157 | output['up_ft'] = up_ft
158 | return output
159 |
160 |
161 | class OneStepSDPipeline(StableDiffusionPipeline):
162 | @torch.no_grad()
163 | def __call__(
164 | self,
165 | img_tensor,
166 | t,
167 | up_ft_indices,
168 | negative_prompt: Optional[Union[str, List[str]]] = None,
169 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
170 | prompt_embeds: Optional[torch.FloatTensor] = None,
171 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
172 | callback_steps: int = 1,
173 | cross_attention_kwargs: Optional[Dict[str, Any]] = None
174 | ):
175 | device = self._execution_device
176 | latents = self.vae.encode(img_tensor).latent_dist.sample() * self.vae.config.scaling_factor
177 | t = torch.tensor(t, dtype=torch.long, device=device)
178 | noise = torch.randn_like(latents).to(device)
179 | latents_noisy = self.scheduler.add_noise(latents, noise, t)
180 | unet_output = self.unet(latents_noisy,
181 | t,
182 | up_ft_indices,
183 | encoder_hidden_states=prompt_embeds,
184 | cross_attention_kwargs=cross_attention_kwargs)
185 | return unet_output
186 |
187 |
188 | class SDFeaturizer:
189 | def __init__(self, sd_model):
190 | unet = MyUNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet")
191 | onestep_pipe = OneStepSDPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", unet=unet,
192 | safety_checker=None)
193 | onestep_pipe.vae.decoder = None
194 | onestep_pipe.scheduler = sd_model.scheduler
195 |
196 | gc.collect()
197 | onestep_pipe = onestep_pipe.to("cuda")
198 | onestep_pipe.enable_attention_slicing()
199 | onestep_pipe.enable_xformers_memory_efficient_attention()
200 | self.pipe = onestep_pipe
201 |
202 | @torch.no_grad()
203 | def forward(self,
204 | img_tensor, # single image, [1,c,h,w]
205 | prompt,
206 | t=261,
207 | up_ft_index=1,
208 | ensemble_size=8):
209 | img_tensor = img_tensor.repeat(ensemble_size, 1, 1, 1).cuda() # ensem, c, h, w
210 |
211 | prompt_embeds = self.pipe._encode_prompt(
212 | prompt=prompt,
213 | device='cuda',
214 | num_images_per_prompt=1,
215 | do_classifier_free_guidance=False) # [1, 77, dim]
216 |
217 | prompt_embeds = prompt_embeds.repeat(ensemble_size, 1, 1)
218 | torch.manual_seed(42)
219 | unet_ft_all = self.pipe(
220 | img_tensor=img_tensor,
221 | t=t,
222 | up_ft_indices=[up_ft_index],
223 | prompt_embeds=prompt_embeds)
224 | unet_ft = unet_ft_all['up_ft'][up_ft_index] # ensem, c, h, w
225 | unet_ft = unet_ft.mean(0, keepdim=True) # 1,c,h,w
226 | return unet_ft
227 |
228 |
229 | def get_correspondences(src_ft, tgt_ft, ref_bbox, img_size=512, topk=10):
230 | num_channel = src_ft.shape[1]
231 | ref_object_start_token = np.ravel_multi_index([(ref_bbox[1]), (ref_bbox[0])], (img_size, img_size))
232 | bbox_w, bbox_h = ref_bbox[2] - ref_bbox[0], ref_bbox[3] - ref_bbox[1]
233 | with torch.no_grad():
234 | src_ft = nn.Upsample(size=(img_size, img_size), mode='bilinear')(src_ft)
235 | src_vec = F.normalize(src_ft.view(num_channel, -1)) # C, HW
236 | tgt_ft = nn.Upsample(size=(img_size, img_size), mode='bilinear')(tgt_ft)
237 | trg_vec = F.normalize(tgt_ft.view(num_channel, -1)) # C, HW
238 |         # Process the reference tokens row by row to limit GPU memory; on a high-memory GPU all rows could be matched in a single batched matmul.
239 | all_ref_points = []
240 | all_tgt_points = []
241 | all_cosine_similarity = []
242 | for i in range(bbox_h):
243 | curr_ref_tokens = list(range(ref_object_start_token + img_size*i, ref_object_start_token + img_size*i+bbox_w+1))
244 | all_ref_points += [np.array(np.unravel_index(curr_ref_tokens, shape=(img_size, img_size))).T]
245 | cos_map = torch.matmul(src_vec.T[curr_ref_tokens], trg_vec)
246 | #tgt_tokens = cos_map.argmax(dim=-1).cpu().numpy()
247 | res = cos_map.topk(k=topk,dim=-1)
248 | all_cosine_similarity += [res[0].cpu().numpy()]
249 | tgt_tokens = res[1].cpu().numpy()
250 | all_tgt_points += [np.array(np.unravel_index(tgt_tokens, shape=(img_size, img_size))).T]
251 |
252 | return np.concatenate(all_ref_points), np.concatenate(all_tgt_points, axis=1).reshape(-1,topk,2), np.concatenate(all_cosine_similarity).reshape(-1)
253 |
254 | def get_correspondences_seg(src_ft, tgt_ft, src_mask, img_size=512, topk=10):
255 | num_channel = src_ft.shape[1]
256 | with torch.no_grad():
257 | src_ft = nn.Upsample(size=(img_size, img_size), mode='bilinear')(src_ft)
258 | src_vec = F.normalize(src_ft.view(num_channel, -1)) # C, HW
259 | tgt_ft = nn.Upsample(size=(img_size, img_size), mode='bilinear')(tgt_ft)
260 | trg_vec = F.normalize(tgt_ft.view(num_channel, -1)) # C, HW
261 |         # Process the reference points one at a time to limit GPU memory; on a high-memory GPU all points could be matched in a single batched matmul.
262 | all_tgt_points = []
263 | all_cosine_similarity = []
264 | all_ref_points = np.column_stack(src_mask.nonzero())
265 | for p in all_ref_points:
266 | curr_ref_tokens = np.ravel_multi_index([(p[1]), (p[0])], (img_size, img_size))
267 | cos_map = torch.matmul(src_vec.T[curr_ref_tokens], trg_vec)
268 | #tgt_tokens = cos_map.argmax(dim=-1).cpu().numpy()
269 | res = cos_map.topk(k=topk,dim=-1)
270 | all_cosine_similarity += [res[0].cpu().numpy()]
271 | tgt_tokens = res[1].cpu().numpy()
272 | all_tgt_points += [np.array(np.unravel_index(tgt_tokens, shape=(img_size, img_size))).T]
273 |
274 | return all_ref_points, np.concatenate(all_tgt_points, axis=1).reshape(-1,topk,2), np.concatenate(all_cosine_similarity).reshape(-1)
275 |
--------------------------------------------------------------------------------
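Finally, a minimal sketch of extracting features with `SDFeaturizer` and matching points inside a reference mask via `get_correspondences_seg`. It needs a CUDA GPU (the featurizer moves the one-step pipeline to cuda); the checkpoint id, prompt, image paths, and the toy square mask are illustrative assumptions only:

```python
import numpy as np
import PIL.Image
import torch
from diffusers import DDIMScheduler, StableDiffusionPipeline

from dift import SDFeaturizer, get_correspondences_seg

def load_image(path, size=512):
    """Load an RGB image and scale it to [-1, 1] with shape [1, 3, H, W]."""
    img = PIL.Image.open(path).convert("RGB").resize((size, size))
    x = torch.from_numpy(np.array(img)).permute(2, 0, 1).float() / 127.5 - 1.0
    return x.unsqueeze(0)

# Any SD v1.x pipeline works here; SDFeaturizer only borrows its scheduler.
sd_model = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", safety_checker=None)
sd_model.scheduler = DDIMScheduler.from_config(sd_model.scheduler.config)

featurizer = SDFeaturizer(sd_model)
src_ft = featurizer.forward(load_image("dogs/source.png"), prompt="a photo of a dog")
tgt_ft = featurizer.forward(load_image("dogs/target1.png"), prompt="a photo of a dog")

# Toy reference mask: a 100x100 square; in practice this is the object's segmentation mask.
src_mask = np.zeros((512, 512), dtype=bool)
src_mask[200:300, 200:300] = True

ref_pts, tgt_pts, sims = get_correspondences_seg(src_ft, tgt_ft, src_mask, img_size=512, topk=10)
print(ref_pts.shape, tgt_pts.shape, sims.shape)  # (N, 2), (N, 10, 2), (N * 10,)
```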