├── Correspondence
│   ├── eval.py
│   └── sc_models
│       ├── .DS_Store
│       ├── dift
│       │   ├── dift_sd.py
│       │   └── get_cor.py
│       ├── dino_vit
│       │   ├── extractor.py
│       │   └── get_cor.py
│       ├── ldm_sc
│       │   ├── get_cor.py
│       │   ├── optimize.py
│       │   └── ptp_utils.py
│       └── sd_dino
│           ├── cor_utils.py
│           ├── extractor_dino.py
│           ├── extractor_sd.py
│           └── get_cor.py
├── README.md
├── Retrieve
│   ├── Feature_extraction.py
│   └── Retriever.py
└── assets
    ├── pipeline.png
    └── teaser.png
/Correspondence/eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import cv2 4 | import torch 5 | import pickle 6 | import numpy as np 7 | from tqdm import tqdm 8 | from PIL import Image 9 | import matplotlib.pyplot as plt 10 | import logging 11 | import argparse 12 | import shutil 13 | 14 | def set_global_logging_level(level=logging.ERROR, prefices=[""]): 15 | prefix_re = re.compile(fr'^(?:{ "|".join(prefices) })') 16 | for name in logging.root.manager.loggerDict: 17 | if re.match(prefix_re, name): 18 | logging.getLogger(name).setLevel(level) 19 | 20 | def get_cor_cfg(method): 21 | cor_cfg = {} 22 | if method == 'dift': 23 | cor_cfg['img_size'] = 768 24 | cor_cfg['ensemble_size'] = 8 25 | elif method == 'ldm_sc': 26 | cor_cfg['img_size'] = 512 27 | elif method == 'sd_dino': 28 | cor_cfg['model_type'] = 'dinov2_vitb14' 29 | elif method == 'dino_vit': 30 | cor_cfg['img_size'] = 256 31 | cor_cfg['model_type'] = 'dino_vits8' 32 | cor_cfg['stride'] = 4 33 | return cor_cfg 34 | 35 | def get_cor_pairs(method, model, src_image, trg_image, src_points, src_prompt, trg_prompt, cfg, transpose_img_func=lambda x:x, transpose_pts_func=lambda x, y: (x, y), device='cuda'): 36 | if method == 'dift': 37 | from sc_models.dift.get_cor import get_cor_pairs 38 | return get_cor_pairs(model, src_image, trg_image, src_points, src_prompt, trg_prompt, cfg['img_size'], cfg['ensemble_size'], return_cos_maps=cfg['visualize'], transpose_img_func=transpose_img_func, transpose_pts_func=transpose_pts_func) 39 | elif method == 'ldm_sc': # ldm_sc does not take transpose_img_func and transpose_pts_func because it would be too slow.
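# NOTE: the ldm_sc branch below returns a 2-tuple (trg_points, None), whereas the caller in
# dataset_walkthrough unpacks five values (trg_pnts, src_pnts, cor_maps, src_image, cor_values)
# as returned by the other three methods; the ldm_sc path would need the same 5-value return
# signature to run inside the evaluation loop.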
40 | from sc_models.ldm_sc.get_cor import get_cor_pairs 41 | return get_cor_pairs(model, src_image, trg_image, src_points, cfg['img_size'], device), None 42 | elif method == 'sd_dino': 43 | from sc_models.sd_dino.get_cor import get_cor_pairs 44 | model, aug, extractor = model 45 | return get_cor_pairs(model, aug, extractor, src_image, trg_image, src_points, src_prompt, trg_prompt, transpose_img_func=transpose_img_func, transpose_pts_func=transpose_pts_func, device=device) 46 | elif method == 'dino_vit': 47 | from sc_models.dino_vit.get_cor import get_cor_pairs 48 | return get_cor_pairs(model, src_image, trg_image, src_points, cfg['img_size'], transpose_img_func=transpose_img_func, transpose_pts_func=transpose_pts_func, device=device) 49 | else: 50 | raise NotImplementedError 51 | 52 | def get_model(method, cor_cfg, device='cuda'): 53 | if method == 'dift': 54 | from sc_models.dift.dift_sd import SDFeaturizer 55 | return SDFeaturizer(device) 56 | elif method == 'ldm_sc': 57 | from sc_models.ldm_sc.optimize import load_ldm 58 | return load_ldm(device, 'CompVis/stable-diffusion-v1-4') 59 | elif method == 'sd_dino': 60 | from sc_models.sd_dino.extractor_sd import load_model 61 | from sc_models.sd_dino.extractor_dino import ViTExtractor 62 | model_type = cor_cfg['model_type'] 63 | stride = 14 if 'v2' in model_type else 8 64 | extractor = ViTExtractor(model_type, stride, device=device) 65 | model, aug = load_model(diffusion_ver='v1-5', image_size=960, num_timesteps=100, block_indices=(2,5,8,11)) 66 | return model, aug, extractor 67 | elif method == 'dino_vit': 68 | from sc_models.dino_vit.extractor import ViTExtractor 69 | model_type = cor_cfg['model_type'] 70 | stride = cor_cfg['stride'] 71 | return ViTExtractor(model_type, stride, device=device) 72 | 73 | def plot_img_pairs(imglist, src_points_list, trg_points_list, trg_mask, top_k=1, cos_map_list=None, save_name='corr.png', fig_size=5, alpha=0.45, scatter_size=30): 74 | num_imgs = top_k + 1 75 | src_images = len(imglist) - 1 76 | fig, axes = plt.subplots(src_images, num_imgs + 1, figsize=(fig_size*(num_imgs + 1), fig_size*src_images)) 77 | plt.tight_layout() 78 | 79 | for i in range(src_images): 80 | ax = axes[i] if src_images > 1 else axes 81 | ax[0].imshow(imglist[i]) 82 | ax[0].axis('off') 83 | ax[0].set_title('source') 84 | for x, y in src_points_list[i]: 85 | x, y = int(np.round(x)), int(np.round(y)) 86 | ax[0].scatter(x, y, s=scatter_size) 87 | 88 | for j in range(1, num_imgs): 89 | ax[j].imshow(imglist[-1]) 90 | if cos_map_list[0] is not None: 91 | heatmap = cos_map_list[i][j - 1][0] 92 | heatmap = (heatmap - np.min(heatmap)) / (np.max(heatmap) - np.min(heatmap)) # Normalize to [0, 1] 93 | ax[j].imshow(255 * heatmap, alpha=alpha, cmap='viridis') 94 | ax[j].axis('off') 95 | ax[j].scatter(trg_points_list[i][j - 1][0], trg_points_list[i][j - 1][1], c='C%d' % (j - 1), s=scatter_size) 96 | ax[j].set_title('target') 97 | 98 | ax[-1].imshow(trg_mask, cmap='gray') 99 | ax[-1].axis('off') 100 | ax[-1].set_title('target mask') 101 | trg_point = np.mean(trg_points_list[i], axis=0) 102 | ax[-1].scatter(trg_point[0], trg_point[1], c='C%d' % (j - 1), s=scatter_size) 103 | plt.plot() 104 | plt.savefig(save_name) 105 | plt.close() 106 | 107 | 108 | def nearest_distance_to_mask_contour(mask, x, y, threshold=122, stride=30): 109 | # Convert the boolean mask to an 8-bit image 110 | dist_list = [] 111 | last_mask = ((mask > 0).astype(np.uint8) * 255) 112 | threshold_list = list(range(0, 255, stride)) + [threshold] # the last one is the threshold value 113 
| for mask_threshold in threshold_list: 114 | mask_8bit = ((mask > mask_threshold).astype(np.uint8) * 255) 115 | if mask_8bit.sum() == 0: 116 | mask_8bit = (mask == mask.max()).astype(np.uint8) * 255 117 | # Find the contours in the mask 118 | contours, _ = cv2.findContours(mask_8bit, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) 119 | # Check if point is inside any contour 120 | num = 0 121 | y = min(y, np.array(mask).shape[0] - 1) 122 | x = min(x, np.array(mask).shape[1] - 1) 123 | for contour in contours: 124 | if cv2.pointPolygonTest(contour, (x, y), False) == 1: # Inside contour 125 | num += 1 126 | if num % 2 == 1: 127 | dist_list.append(0) 128 | continue 129 | 130 | # If point is outside all contours, find the minimum distance between the point and each contour 131 | min_distance = float('inf') 132 | for contour in contours: 133 | distance = cv2.pointPolygonTest(contour, (x, y), True) # Measure distance 134 | if abs(distance) < min_distance: 135 | min_distance = abs(distance) 136 | 137 | 138 | # normalize the distance with the diagonal length of the mask 139 | diag_len = np.sqrt(mask.shape[0]**2 + mask.shape[1]**2) 140 | dist_list.append(abs(min_distance) / diag_len) 141 | nss_value = np.array(mask)[int(y), int(x)] 142 | thres_dist = dist_list.pop() 143 | return dist_list, nss_value, thres_dist 144 | 145 | 146 | def dataset_walkthrough(base_dir, method, model, exp_name, cor_cfg={}, average_pts=True, visualize=False, mask_threshold=120, top_k=1, top_k_type='max', transpose_types=1, device='cuda'): 147 | eval_pairs = 0 148 | total_dists, nss_values, thres_dists, res_trg_points = {}, {}, {}, {} 149 | gt_dir = os.path.join(base_dir, 'GT') 150 | base_dir = os.path.join(base_dir, 'egocentric') 151 | transpose_img_funcs = [ 152 | lambda x:x, 153 | lambda x:x.rotate(90, expand=True), 154 | lambda x:x.rotate(180, expand=True), 155 | lambda x:x.rotate(-90, expand=True), 156 | lambda x:x.transpose(Image.FLIP_LEFT_RIGHT), 157 | lambda x:x.transpose(Image.FLIP_LEFT_RIGHT).rotate(90, expand=True), 158 | lambda x:x.transpose(Image.FLIP_LEFT_RIGHT).rotate(180, expand=True), 159 | lambda x:x.transpose(Image.FLIP_LEFT_RIGHT).rotate(-90, expand=True), 160 | ] 161 | for trg_object in os.listdir(base_dir): 162 | eval_pairs += len(os.listdir(os.path.join(base_dir, trg_object))) 163 | print(f'Start evaluating {eval_pairs} correspondance pairs...') 164 | 165 | cor_cfg['device'] = device 166 | cor_cfg['visualize'] = visualize 167 | 168 | pbar = tqdm(total=eval_pairs) 169 | if visualize: 170 | if os.path.exists(f'results_arxiv/{method}/{exp_name}'): 171 | confrim = input(f'The result folder {method}/{exp_name} already exists. 
input y to remove it...') 172 | if confrim == 'y': 173 | shutil.rmtree(f'results_arxiv/{method}/{exp_name}', ignore_errors=True) 174 | else: 175 | exit() 176 | for trg_object in os.listdir(base_dir): 177 | object_path = os.path.join(base_dir, trg_object) 178 | total_dists[trg_object], nss_values[trg_object], thres_dists[trg_object], res_trg_points[trg_object] = [], [], [], {} 179 | for instance in os.listdir(object_path): 180 | instance_path = os.path.join(object_path, instance) 181 | src_images, src_points_list, trg_points_list, cor_map_list, src_object_list = [], [], [], [], [] 182 | for file in os.listdir(instance_path): 183 | if file.endswith('.jpg'): 184 | trg_image = os.path.join(instance_path, file) 185 | mask_file = os.path.join(gt_dir, trg_object, file.replace('jpg', 'png')) 186 | with Image.open(mask_file) as img: 187 | try: 188 | trg_mask = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2GRAY) 189 | except: 190 | trg_mask = np.array(img) 191 | elif file.endswith('.txt') and ('top' not in file or int(file.strip('.txt').split('top')[1]) <= top_k): 192 | src_images.append(os.path.join(instance_path, file).replace('txt', 'png')) 193 | src_object_list.append(file.split('_')[0]) 194 | with open(os.path.join(instance_path, file), 'r') as f: 195 | lines = f.readlines() 196 | src_points = [list(map(float, line.rstrip().split(','))) for line in lines if re.match(r'^\d+.\d+,.*\d+.\d+$', line.rstrip())] 197 | if average_pts: 198 | src_points = [np.mean(np.array(src_points), axis=0).astype(np.int32)] 199 | src_points_list.append(src_points) 200 | pbar.set_description(f'{trg_object}-{instance}') 201 | trg_prompt = f'a photo of {trg_object}' 202 | imglist, new_src_points_list, cor_values_list = [], [], [] 203 | for i in range(len(src_images)): 204 | src_prompt = f'a photo of a {src_object_list[i]}' 205 | w, h = Image.open(src_images[i]).size 206 | transpose_pts_funcs = [ 207 | lambda x, y: (x, y), 208 | lambda x, y: (y, w - x), # 90 209 | lambda x, y: (w - x, h - y), # 180 210 | lambda x, y: (h - y, x), # -90 211 | lambda x, y: (w - x, y), # flip 212 | lambda x, y: (y, x), # flip 90 213 | lambda x, y: (x, h - y), # flip 180 214 | lambda x, y: (h - y, w - x), # flip -90 215 | ] 216 | trg_pnts_tmp, src_pnts_tmp, cor_maps_tmp, src_image_tmp, cor_values_tmp = [], [], [], [], [] 217 | os.makedirs(f'results_arxiv/{method}/{exp_name}/{trg_object}', exist_ok=True) 218 | for j in range(transpose_types): 219 | trg_pnts, src_pnts, cor_maps, src_image, cor_values = get_cor_pairs(method, model, src_images[i], trg_image, src_points_list[i], src_prompt, trg_prompt, cor_cfg, transpose_img_funcs[j], transpose_pts_funcs[j], device) 220 | trg_pnts_tmp.append(trg_pnts) 221 | src_pnts_tmp.append(src_pnts) 222 | cor_maps_tmp.append(cor_maps) 223 | src_image_tmp.append(src_image) 224 | cor_values_tmp.append(np.mean(cor_values)) 225 | # cor_values_str = ", ".join(map(lambda x: "%.2f" % x, cor_values)) 226 | # plot_img_pairs([src_image, Image.open(trg_image).convert('RGB')], [src_pnts], [trg_pnts], trg_mask, [cor_maps], os.path.join(f'results_arxiv/{method}/{exp_name}/{trg_object}', f'{instance}_{j}_{cor_values_str}.png')) 227 | selected_idx = np.argmax(cor_values_tmp) 228 | new_src_points_list.append(src_pnts_tmp[selected_idx]) 229 | trg_points_list.append(trg_pnts_tmp[selected_idx]) 230 | cor_map_list.append(cor_maps_tmp[selected_idx]) 231 | cor_values_list.append(cor_values_tmp[selected_idx]) 232 | imglist.append(src_image_tmp[selected_idx]) 233 | trg_points = np.mean(trg_points_list, axis=1) 234 | if top_k_type == 
'max': 235 | trg_point = trg_points[np.argmax(cor_values_list)] 236 | elif top_k_type == 'avg': 237 | trg_point = np.mean(trg_points, axis=0) 238 | trg_dist, nss_value, thres_dist = nearest_distance_to_mask_contour(trg_mask, trg_point[0], trg_point[1], mask_threshold) 239 | total_dists[trg_object].append(trg_dist) 240 | thres_dists[trg_object].append(thres_dist) 241 | nss_values[trg_object].append(nss_value) 242 | res_trg_points[trg_object][instance] = trg_points_list 243 | # print(trg_point, trg_dist)ipy 244 | if visualize: 245 | res_dir = f'results_arxiv/{method}/{exp_name}/{trg_object}' 246 | imglist.append(Image.open(trg_image).convert('RGB')) 247 | os.makedirs(res_dir, exist_ok=True) 248 | file_name = f'{instance}_{thres_dist:.2f}_{nss_value}' 249 | if top_k_type == 'max': 250 | file_name += f'_max_idx{np.argmax(cor_values_list)}' 251 | plot_img_pairs(imglist, new_src_points_list, trg_points_list, trg_mask, top_k, cor_map_list, os.path.join(res_dir, file_name + '.png')) 252 | pbar.update(1) 253 | pbar.close() 254 | return total_dists, nss_values, thres_dists, res_trg_points 255 | 256 | 257 | def analyze_dists(total_dists, nss_values, thres_dists, res_dir=None): 258 | all_dists, all_nss, thres_dist, lines = [], [], [], [] 259 | 260 | if res_dir is not None: 261 | fig, axes = plt.subplots(1, 3, figsize=(20,5)) 262 | for trg_object in total_dists.keys(): 263 | all_dists += total_dists[trg_object] 264 | dist_curve = np.array(total_dists[trg_object]).mean(axis=0) 265 | sr_curve = (np.array(total_dists[trg_object]) == 0).sum(axis=0) / len(total_dists[trg_object]) 266 | axes[0].plot(dist_curve, label=trg_object) 267 | axes[1].plot(sr_curve, label=trg_object) 268 | 269 | axes[0].plot(np.array(all_dists).mean(axis=0), label='all', linewidth=3, color='black') 270 | axes[1].plot((np.array(all_dists) == 0).sum(axis=0) / len(all_dists), label='all', linewidth=3, color='black') 271 | axes[0].legend() 272 | axes[1].legend() 273 | axes[0].set_title('DTM') 274 | axes[1].set_title('SR') 275 | 276 | plt.savefig(os.path.join(res_dir, 'dist_curve.png')) 277 | 278 | for trg_object in thres_dists.keys(): 279 | all_nss += nss_values[trg_object] 280 | thres_dist += thres_dists[trg_object] 281 | lines.append(f'{trg_object.split("_")[0]:12s}: dist mean:{np.array(thres_dists[trg_object]).mean():.3f}, nss mean: {np.array(nss_values[trg_object]).mean():.1f}, success rate: {(np.array(thres_dists[trg_object]) == 0).sum() / len(thres_dists[trg_object]):.3f} ({(np.array(thres_dists[trg_object]) == 0).sum()}/{len(thres_dists[trg_object])})') 282 | lines.append(f'=== ALL ===: dist mean:{np.mean(thres_dist):.3f}, nss mean: {np.array(all_nss).mean():.1f}, success rate: {(np.array(thres_dist)==0).sum() / len(thres_dist):.3f} ({(np.array(thres_dist)==0).sum()}/{len(thres_dist)})') 283 | if res_dir is not None: 284 | with open(os.path.join(res_dir, 'total_dists.txt'), 'w') as f: 285 | f.writelines([line + '\n' for line in lines]) 286 | for line in lines: 287 | print(line) 288 | 289 | if __name__ == '__main__': 290 | parser = argparse.ArgumentParser() 291 | parser.add_argument('--method', '-m', type=str, default='dift', choices=['dift', 'ldm_sc', 'sd_dino', 'dino_vit'], help='method for correspondance') 292 | parser.add_argument('--dataset', '-d', type=str, default='clip_b32_x2', choices=['clip_b32', 'clib_b32_x0.5', 'clip_b32_x2', 'clip_b16', 'clip_b16_x0.5', 'clip_b16_x2', 'clip_b32_lpips', 'clip_b32_lpips_x0.5', 'clip_b32_lpips_x2', 'resnet_50', 'resnet_50_x0.5', 'resnet_50_x2'], help='dataset for affordance memory') 293 | 
parser.add_argument('--exp_name', '-e', type=str, default='', help='experiment name') 294 | parser.add_argument('--mask_threshold', '-s', type=int, default=122, help='mask threshold for success rate calculation') 295 | parser.add_argument('--visualize', '-v', action='store_true', help='visualize the correspondence pairs') 296 | parser.add_argument('--avg_pts', '-a', action='store_true', help='average the five source points before (True) or after (False) correspondence') 297 | parser.add_argument('--top_k', '-k', type=int, default=5, help='use the top k retrieved images') 298 | parser.add_argument('--top_k_type', '-kt', type=str, default='max', choices=['max', 'avg'], help='max: use the top k with max cor value, avg: use the average results of top k') 299 | parser.add_argument('--transpose_types', '-t', type=int, default=8, help='1: no transpose, 4: rotations, 8: flip and rotations') 300 | args = parser.parse_args() 301 | args.avg_pts = False 302 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 303 | average_pts, visualize = args.avg_pts, args.visualize 304 | exp_name = args.dataset 305 | if args.mask_threshold != 122: 306 | exp_name += '_s' + str(args.mask_threshold) 307 | if args.top_k != 1: 308 | exp_name += f'_top{args.top_k}' 309 | if args.top_k_type == 'max': 310 | exp_name += '_max' 311 | if args.avg_pts: 312 | exp_name += '_avg' 313 | if args.transpose_types > 1: 314 | exp_name += '_transpose' 315 | exp_name = exp_name if len(args.exp_name) == 0 else args.exp_name 316 | cor_cfg = get_cor_cfg(args.method) 317 | 318 | model = get_model(args.method, cor_cfg, device=device) 319 | base_dir = f'datasets/{args.dataset}' 320 | res_dir = f'results/{args.method}/{exp_name}' 321 | print(f'res_dir: {res_dir}') 322 | 323 | total_dists, nss_values, thres_dists, trg_points = dataset_walkthrough(base_dir, args.method, model, exp_name, cor_cfg, average_pts, visualize, args.mask_threshold, args.top_k, args.top_k_type, args.transpose_types, device) 324 | os.makedirs(res_dir, exist_ok=True)  # create the results directory before writing results.pkl and the analysis files 325 | with open(os.path.join(res_dir, 'results.pkl'), 'wb') as f: 326 | pickle.dump({'args': args, 'total_dists': total_dists, 'nss_values': nss_values, 'thres_dists': thres_dists, 'trg_points': trg_points}, f) 327 | 328 | analyze_dists(total_dists, nss_values, thres_dists, res_dir) -------------------------------------------------------------------------------- /Correspondence/sc_models/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TEA-Lab/Robo-ABC/0ce7ac90d0ce61099988690f77f19785a388bb20/Correspondence/sc_models/.DS_Store -------------------------------------------------------------------------------- /Correspondence/sc_models/dift/dift_sd.py: -------------------------------------------------------------------------------- 1 | from diffusers import StableDiffusionPipeline 2 | import torch 3 | import torch.nn as nn 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | from typing import Any, Callable, Dict, List, Optional, Union 7 | from diffusers.models.unet_2d_condition import UNet2DConditionModel 8 | from diffusers import DDIMScheduler 9 | import gc 10 | from PIL import Image 11 | 12 | class MyUNet2DConditionModel(UNet2DConditionModel): 13 | def forward( 14 | self, 15 | sample: torch.FloatTensor, 16 | timestep: Union[torch.Tensor, float, int], 17 | up_ft_indices, 18 | encoder_hidden_states: torch.Tensor, 19 | class_labels: Optional[torch.Tensor] = None, 20 | timestep_cond: Optional[torch.Tensor] = None, 21 | attention_mask: Optional[torch.Tensor] =
None, 22 | cross_attention_kwargs: Optional[Dict[str, Any]] = None): 23 | r""" 24 | Args: 25 | sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor 26 | timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps 27 | encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states 28 | cross_attention_kwargs (`dict`, *optional*): 29 | A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under 30 | `self.processor` in 31 | [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). 32 | """ 33 | # By default samples have to be AT least a multiple of the overall upsampling factor. 34 | # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). 35 | # However, the upsampling interpolation output size can be forced to fit any upsampling size 36 | # on the fly if necessary. 37 | default_overall_up_factor = 2**self.num_upsamplers 38 | 39 | # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` 40 | forward_upsample_size = False 41 | upsample_size = None 42 | 43 | if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): 44 | # logger.info("Forward upsample size to force interpolation output size.") 45 | forward_upsample_size = True 46 | 47 | # prepare attention_mask 48 | if attention_mask is not None: 49 | attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 50 | attention_mask = attention_mask.unsqueeze(1) 51 | 52 | # 0. center input if necessary 53 | if self.config.center_input_sample: 54 | sample = 2 * sample - 1.0 55 | 56 | # 1. time 57 | timesteps = timestep 58 | if not torch.is_tensor(timesteps): 59 | # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can 60 | # This would be a good case for the `match` statement (Python 3.10+) 61 | is_mps = sample.device.type == "mps" 62 | if isinstance(timestep, float): 63 | dtype = torch.float32 if is_mps else torch.float64 64 | else: 65 | dtype = torch.int32 if is_mps else torch.int64 66 | timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) 67 | elif len(timesteps.shape) == 0: 68 | timesteps = timesteps[None].to(sample.device) 69 | 70 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML 71 | timesteps = timesteps.expand(sample.shape[0]) 72 | 73 | t_emb = self.time_proj(timesteps) 74 | 75 | # timesteps does not contain any weights and will always return f32 tensors 76 | # but time_embedding might actually be running in fp16. so we need to cast here. 77 | # there might be better ways to encapsulate this. 78 | t_emb = t_emb.to(dtype=self.dtype) 79 | 80 | emb = self.time_embedding(t_emb, timestep_cond) 81 | 82 | if self.class_embedding is not None: 83 | if class_labels is None: 84 | raise ValueError("class_labels should be provided when num_class_embeds > 0") 85 | 86 | if self.config.class_embed_type == "timestep": 87 | class_labels = self.time_proj(class_labels) 88 | 89 | class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) 90 | emb = emb + class_emb 91 | 92 | # 2. pre-process 93 | sample = self.conv_in(sample) 94 | 95 | # 3. 
down 96 | down_block_res_samples = (sample,) 97 | for downsample_block in self.down_blocks: 98 | if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: 99 | sample, res_samples = downsample_block( 100 | hidden_states=sample, 101 | temb=emb, 102 | encoder_hidden_states=encoder_hidden_states, 103 | attention_mask=attention_mask, 104 | cross_attention_kwargs=cross_attention_kwargs, 105 | ) 106 | else: 107 | sample, res_samples = downsample_block(hidden_states=sample, temb=emb) 108 | 109 | down_block_res_samples += res_samples 110 | 111 | # 4. mid 112 | if self.mid_block is not None: 113 | sample = self.mid_block( 114 | sample, 115 | emb, 116 | encoder_hidden_states=encoder_hidden_states, 117 | attention_mask=attention_mask, 118 | cross_attention_kwargs=cross_attention_kwargs, 119 | ) 120 | 121 | # 5. up 122 | up_ft = {} 123 | for i, upsample_block in enumerate(self.up_blocks): 124 | 125 | if i > np.max(up_ft_indices): 126 | break 127 | 128 | is_final_block = i == len(self.up_blocks) - 1 129 | 130 | res_samples = down_block_res_samples[-len(upsample_block.resnets) :] 131 | down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] 132 | 133 | # if we have not reached the final block and need to forward the 134 | # upsample size, we do it here 135 | if not is_final_block and forward_upsample_size: 136 | upsample_size = down_block_res_samples[-1].shape[2:] 137 | 138 | if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: 139 | sample = upsample_block( 140 | hidden_states=sample, 141 | temb=emb, 142 | res_hidden_states_tuple=res_samples, 143 | encoder_hidden_states=encoder_hidden_states, 144 | cross_attention_kwargs=cross_attention_kwargs, 145 | upsample_size=upsample_size, 146 | attention_mask=attention_mask, 147 | ) 148 | else: 149 | sample = upsample_block( 150 | hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size 151 | ) 152 | 153 | if i in up_ft_indices: 154 | up_ft[i] = sample.detach() 155 | 156 | output = {} 157 | output['up_ft'] = up_ft 158 | return output 159 | 160 | class OneStepSDPipeline(StableDiffusionPipeline): 161 | @torch.no_grad() 162 | def __call__( 163 | self, 164 | img_tensor, 165 | t, 166 | up_ft_indices, 167 | negative_prompt: Optional[Union[str, List[str]]] = None, 168 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 169 | prompt_embeds: Optional[torch.FloatTensor] = None, 170 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, 171 | callback_steps: int = 1, 172 | cross_attention_kwargs: Optional[Dict[str, Any]] = None 173 | ): 174 | 175 | device = self._execution_device 176 | latents = self.vae.encode(img_tensor).latent_dist.sample() * self.vae.config.scaling_factor 177 | t = torch.tensor(t, dtype=torch.long, device=device) 178 | noise = torch.randn_like(latents).to(device) 179 | latents_noisy = self.scheduler.add_noise(latents, noise, t) 180 | unet_output = self.unet(latents_noisy, 181 | t, 182 | up_ft_indices, 183 | encoder_hidden_states=prompt_embeds, 184 | cross_attention_kwargs=cross_attention_kwargs) 185 | return unet_output 186 | 187 | 188 | class SDFeaturizer: 189 | def __init__(self, device="cuda"): 190 | sd_id='stabilityai/stable-diffusion-2-1' 191 | unet = MyUNet2DConditionModel.from_pretrained(sd_id, subfolder="unet", local_files_only=True) 192 | onestep_pipe = OneStepSDPipeline.from_pretrained(sd_id, unet=unet, safety_checker=None, local_files_only=True) 193 | 
onestep_pipe.vae.decoder = None 194 | onestep_pipe.scheduler = DDIMScheduler.from_pretrained(sd_id, subfolder="scheduler", local_files_only=True) 195 | gc.collect() 196 | onestep_pipe = onestep_pipe.to(device) 197 | onestep_pipe.enable_attention_slicing() 198 | onestep_pipe.enable_xformers_memory_efficient_attention() 199 | self.pipe = onestep_pipe 200 | self.device = device 201 | 202 | @torch.no_grad() 203 | def forward(self, 204 | img_tensor, 205 | prompt, 206 | t=261, 207 | up_ft_index=1, 208 | ensemble_size=8): 209 | ''' 210 | Args: 211 | img_tensor: should be a single torch tensor in the shape of [1, C, H, W] or [C, H, W] 212 | prompt: the prompt to use, a string 213 | t: the time step to use, should be an int in the range of [0, 1000] 214 | up_ft_index: which upsampling block of the U-Net to extract feature, you can choose [0, 1, 2, 3] 215 | ensemble_size: the number of repeated images used in the batch to extract features 216 | Return: 217 | unet_ft: a torch tensor in the shape of [1, c, h, w] 218 | ''' 219 | img_tensor = img_tensor.repeat(ensemble_size, 1, 1, 1).to(self.device) # ensem, c, h, w 220 | prompt_embeds = self.pipe._encode_prompt( 221 | prompt=prompt, 222 | device=self.device, 223 | num_images_per_prompt=1, 224 | do_classifier_free_guidance=False) # [1, 77, dim] 225 | prompt_embeds = prompt_embeds.repeat(ensemble_size, 1, 1) 226 | unet_ft_all = self.pipe( 227 | img_tensor=img_tensor, 228 | t=t, 229 | up_ft_indices=[up_ft_index], 230 | prompt_embeds=prompt_embeds) 231 | unet_ft = unet_ft_all['up_ft'][up_ft_index] # ensem, c, h, w 232 | unet_ft = unet_ft.mean(0, keepdim=True) # 1,c,h,w 233 | return unet_ft -------------------------------------------------------------------------------- /Correspondence/sc_models/dift/get_cor.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import torch 3 | import numpy as np 4 | import torch.nn as nn 5 | from PIL import Image 6 | from torchvision.transforms import PILToTensor 7 | 8 | def pad_image(image, pixel_locs=[]): 9 | width, height = image.size 10 | # Calculate padding to make the image square 11 | if width > height: 12 | # Width is greater, pad height 13 | padding = (width - height) // 2 14 | padded_image = Image.new("RGB", (width, width), (255, 255, 255)) 15 | padded_image.paste(image, (0, padding)) 16 | padded_pixel_locs = [(pixel_loc[0], pixel_loc[1] + padding) for pixel_loc in pixel_locs] 17 | else: 18 | # Height is greater, pad width 19 | padding = (height - width) // 2 20 | padded_image = Image.new("RGB", (height, height), (255, 255, 255)) 21 | padded_image.paste(image, (padding, 0)) 22 | padded_pixel_locs = [(pixel_loc[0] + padding, pixel_loc[1]) for pixel_loc in pixel_locs] 23 | 24 | return padded_image, padded_pixel_locs 25 | 26 | def crop_array(array, org_h, org_w): 27 | org_h, org_w = round(org_h), round(org_w) 28 | padding = abs((org_w - org_h) // 2) 29 | # Convert the new location back to the original image's coordinates 30 | if org_w > org_h: 31 | # If the original width was greater, adjust the y-coordinate 32 | cropped_array = array[:, padding: -padding, :] 33 | else: 34 | # If the original height was greater, adjust the x-coordinate 35 | cropped_array = array[:, :, padding: -padding] 36 | return cropped_array 37 | 38 | 39 | def get_cor_pairs(dift, src_image, trg_image, src_points, src_prompt, trg_prompt, img_size, ensemble_size=8, return_cos_maps=False, transpose_img_func=lambda x:x, transpose_pts_func = lambda x, y: (x, y)): 40 | """ 41 | src_image, trg_image: 
relative path of src and trg images 42 | src_points: resized affordance points in src_image 43 | average_pts: average before correspondance or not 44 | ----- 45 | return: correspondance maps of each src_point and each target_point 46 | """ 47 | trg_points = [] 48 | 49 | with Image.open(src_image) as img: 50 | # src_image, src_points = pad_image(img, src_points) 51 | src_image = transpose_img_func(img) 52 | src_w, src_h = src_image.size 53 | src_image = src_image.resize((img_size, img_size)).convert('RGB') 54 | src_points = [transpose_pts_func(x, y) for x, y in src_points] 55 | src_x_scale, src_y_scale = img_size / src_w, img_size / src_h 56 | with Image.open(trg_image) as img: 57 | trg_w, trg_h = img.size 58 | # trg_image, _ = pad_image(img) 59 | trg_image = img.resize((img_size, img_size)).convert('RGB') 60 | trg_x_scale, trg_y_scale = img_size / trg_w, img_size / trg_h 61 | 62 | src_points = [[int(np.round(x * src_x_scale)), int(np.round(y * src_y_scale))] for (x, y) in src_points] 63 | 64 | src_tensor = (PILToTensor()(src_image) / 255.0 - 0.5) * 2 65 | trg_tensor = (PILToTensor()(trg_image) / 255.0 - 0.5) * 2 66 | src_ft = dift.forward(src_tensor, prompt=src_prompt, ensemble_size=ensemble_size) 67 | trg_ft = dift.forward(trg_tensor, prompt=trg_prompt, ensemble_size=ensemble_size) 68 | num_channel = src_ft.size(1) 69 | cos = nn.CosineSimilarity(dim=1) 70 | 71 | 72 | src_ft = nn.Upsample(size=(img_size, img_size), mode='bilinear')(src_ft) 73 | src_vectors = [src_ft[0, :, y, x].view(1, num_channel, 1, 1) for (x, y) in src_points] 74 | del src_ft 75 | gc.collect() 76 | torch.cuda.empty_cache() 77 | 78 | trg_ft = nn.Upsample(size=(img_size, img_size), mode='bilinear')(trg_ft) 79 | cos_maps = [cos(src_vec, trg_ft).cpu().numpy() for src_vec in src_vectors] 80 | cos_values = [cos_map.max() for cos_map in cos_maps] 81 | # cos_maps = [crop_array(cos(src_vec, trg_ft).cpu().numpy(), trg_h * trg_x_scale, trg_w * trg_y_scale) for src_vec in src_vectors] 82 | 83 | del trg_ft 84 | gc.collect() 85 | torch.cuda.empty_cache() 86 | 87 | for cos_map in cos_maps: 88 | max_yx = np.unravel_index(cos_map.argmax(), cos_map.shape)[1:] 89 | trg_points.append([max_yx[1], max_yx[0]]) 90 | 91 | trg_points = [[int(np.round(x / trg_x_scale)), int(np.round(y / trg_y_scale))] for (x, y) in trg_points] 92 | cos_maps = [nn.Upsample(size=(trg_h, trg_w), mode='bilinear')(torch.tensor(cos_map)[None, :, :, :]).numpy()[0] for cos_map in cos_maps] if return_cos_maps else None 93 | return trg_points, src_points, cos_maps, src_image, cos_values 94 | 95 | -------------------------------------------------------------------------------- /Correspondence/sc_models/dino_vit/extractor.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import torchvision.transforms 4 | from torch import nn 5 | from torchvision import transforms 6 | import torch.nn.modules.utils as nn_utils 7 | import math 8 | import timm 9 | import types 10 | from pathlib import Path 11 | from typing import Union, List, Tuple 12 | from PIL import Image 13 | 14 | 15 | class ViTExtractor: 16 | """ This class facilitates extraction of features, descriptors, and saliency maps from a ViT. 17 | 18 | We use the following notation in the documentation of the module's methods: 19 | B - batch size 20 | h - number of heads. usually takes place of the channel dimension in pytorch's convention BxCxHxW 21 | p - patch size of the ViT. either 8 or 16. 22 | t - number of tokens. 
equals the number of patches + 1, e.g. HW / p**2 + 1. Where H and W are the height and width 23 | of the input image. 24 | d - the embedding dimension in the ViT. 25 | """ 26 | 27 | def __init__(self, model_type: str = 'dino_vits8', stride: int = 4, model: nn.Module = None, device: str = 'cuda'): 28 | """ 29 | :param model_type: A string specifying the type of model to extract from. 30 | [dino_vits8 | dino_vits16 | dino_vitb8 | dino_vitb16 | vit_small_patch8_224 | 31 | vit_small_patch16_224 | vit_base_patch8_224 | vit_base_patch16_224] 32 | :param stride: stride of first convolution layer. small stride -> higher resolution. 33 | :param model: Optional parameter. The nn.Module to extract from instead of creating a new one in ViTExtractor. 34 | should be compatible with model_type. 35 | """ 36 | self.model_type = model_type 37 | self.device = device 38 | if model is not None: 39 | self.model = model 40 | else: 41 | self.model = ViTExtractor.create_model(model_type) 42 | 43 | self.model = ViTExtractor.patch_vit_resolution(self.model, stride=stride) 44 | self.model.eval() 45 | self.model.to(self.device) 46 | self.p = self.model.patch_embed.patch_size 47 | self.stride = self.model.patch_embed.proj.stride 48 | 49 | self.mean = (0.485, 0.456, 0.406) if "dino" in self.model_type else (0.5, 0.5, 0.5) 50 | self.std = (0.229, 0.224, 0.225) if "dino" in self.model_type else (0.5, 0.5, 0.5) 51 | 52 | self._feats = [] 53 | self.hook_handlers = [] 54 | self.load_size = None 55 | self.num_patches = None 56 | 57 | @staticmethod 58 | def create_model(model_type: str) -> nn.Module: 59 | """ 60 | :param model_type: a string specifying which model to load. [dino_vits8 | dino_vits16 | dino_vitb8 | 61 | dino_vitb16 | vit_small_patch8_224 | vit_small_patch16_224 | vit_base_patch8_224 | 62 | vit_base_patch16_224] 63 | :return: the model 64 | """ 65 | if 'dino' in model_type: 66 | model = torch.hub.load('facebookresearch/dino:main', model_type) 67 | else: # model from timm -- load weights from timm to dino model (enables working on arbitrary size images). 68 | temp_model = timm.create_model(model_type, pretrained=True) 69 | model_type_dict = { 70 | 'vit_small_patch16_224': 'dino_vits16', 71 | 'vit_small_patch8_224': 'dino_vits8', 72 | 'vit_base_patch16_224': 'dino_vitb16', 73 | 'vit_base_patch8_224': 'dino_vitb8' 74 | } 75 | model = torch.hub.load('facebookresearch/dino:main', model_type_dict[model_type]) 76 | temp_state_dict = temp_model.state_dict() 77 | del temp_state_dict['head.weight'] 78 | del temp_state_dict['head.bias'] 79 | model.load_state_dict(temp_state_dict) 80 | return model 81 | 82 | @staticmethod 83 | def _fix_pos_enc(patch_size: int, stride_hw: Tuple[int, int]): 84 | """ 85 | Creates a method for position encoding interpolation. 86 | :param patch_size: patch size of the model. 87 | :param stride_hw: A tuple containing the new height and width stride respectively. 
88 | :return: the interpolation method 89 | """ 90 | def interpolate_pos_encoding(self, x: torch.Tensor, w: int, h: int) -> torch.Tensor: 91 | npatch = x.shape[1] - 1 92 | N = self.pos_embed.shape[1] - 1 93 | if npatch == N and w == h: 94 | return self.pos_embed 95 | class_pos_embed = self.pos_embed[:, 0] 96 | patch_pos_embed = self.pos_embed[:, 1:] 97 | dim = x.shape[-1] 98 | # compute number of tokens taking stride into account 99 | w0 = 1 + (w - patch_size) // stride_hw[1] 100 | h0 = 1 + (h - patch_size) // stride_hw[0] 101 | assert (w0 * h0 == npatch), f"""got wrong grid size for {h}x{w} with patch_size {patch_size} and 102 | stride {stride_hw} got {h0}x{w0}={h0 * w0} expecting {npatch}""" 103 | # we add a small number to avoid floating point error in the interpolation 104 | # see discussion at https://github.com/facebookresearch/dino/issues/8 105 | w0, h0 = w0 + 0.1, h0 + 0.1 106 | patch_pos_embed = nn.functional.interpolate( 107 | patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), 108 | scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)), 109 | mode='bicubic', 110 | align_corners=False, recompute_scale_factor=False 111 | ) 112 | assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1] 113 | patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) 114 | return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) 115 | 116 | return interpolate_pos_encoding 117 | 118 | @staticmethod 119 | def patch_vit_resolution(model: nn.Module, stride: int) -> nn.Module: 120 | """ 121 | change resolution of model output by changing the stride of the patch extraction. 122 | :param model: the model to change resolution for. 123 | :param stride: the new stride parameter. 124 | :return: the adjusted model 125 | """ 126 | patch_size = model.patch_embed.patch_size 127 | if stride == patch_size: # nothing to do 128 | return model 129 | 130 | stride = nn_utils._pair(stride) 131 | assert all([(patch_size // s_) * s_ == patch_size for s_ in 132 | stride]), f'stride {stride} should divide patch_size {patch_size}' 133 | 134 | # fix the stride 135 | model.patch_embed.proj.stride = stride 136 | # fix the positional encoding code 137 | model.interpolate_pos_encoding = types.MethodType(ViTExtractor._fix_pos_enc(patch_size, stride), model) 138 | return model 139 | 140 | def preprocess(self, pil_image, 141 | load_size: Union[int, Tuple[int, int]] = None) -> Tuple[torch.Tensor, Image.Image]: 142 | """ 143 | Preprocesses an image before extraction. 144 | :param image_path: path to image to be extracted. 145 | :param load_size: optional. Size to resize image before the rest of preprocessing. 146 | :return: a tuple containing: 147 | (1) the preprocessed image as a tensor to insert the model of shape BxCxHxW. 148 | (2) the pil image in relevant dimensions 149 | """ 150 | if load_size is not None: 151 | pil_image = transforms.Resize(load_size, interpolation=transforms.InterpolationMode.LANCZOS)(pil_image) 152 | prep = transforms.Compose([ 153 | transforms.ToTensor(), 154 | transforms.Normalize(mean=self.mean, std=self.std) 155 | ]) 156 | prep_img = prep(pil_image)[None, ...] 157 | return prep_img, pil_image 158 | 159 | def _get_hook(self, facet: str): 160 | """ 161 | generate a hook method for a specific block and facet. 
162 | """ 163 | if facet in ['attn', 'token']: 164 | def _hook(model, input, output): 165 | self._feats.append(output) 166 | return _hook 167 | 168 | if facet == 'query': 169 | facet_idx = 0 170 | elif facet == 'key': 171 | facet_idx = 1 172 | elif facet == 'value': 173 | facet_idx = 2 174 | else: 175 | raise TypeError(f"{facet} is not a supported facet.") 176 | 177 | def _inner_hook(module, input, output): 178 | input = input[0] 179 | B, N, C = input.shape 180 | qkv = module.qkv(input).reshape(B, N, 3, module.num_heads, C // module.num_heads).permute(2, 0, 3, 1, 4) 181 | self._feats.append(qkv[facet_idx]) #Bxhxtxd 182 | return _inner_hook 183 | 184 | def _register_hooks(self, layers: List[int], facet: str) -> None: 185 | """ 186 | register hook to extract features. 187 | :param layers: layers from which to extract features. 188 | :param facet: facet to extract. One of the following options: ['key' | 'query' | 'value' | 'token' | 'attn'] 189 | """ 190 | for block_idx, block in enumerate(self.model.blocks): 191 | if block_idx in layers: 192 | if facet == 'token': 193 | self.hook_handlers.append(block.register_forward_hook(self._get_hook(facet))) 194 | elif facet == 'attn': 195 | self.hook_handlers.append(block.attn.attn_drop.register_forward_hook(self._get_hook(facet))) 196 | elif facet in ['key', 'query', 'value']: 197 | self.hook_handlers.append(block.attn.register_forward_hook(self._get_hook(facet))) 198 | else: 199 | raise TypeError(f"{facet} is not a supported facet.") 200 | 201 | def _unregister_hooks(self) -> None: 202 | """ 203 | unregisters the hooks. should be called after feature extraction. 204 | """ 205 | for handle in self.hook_handlers: 206 | handle.remove() 207 | self.hook_handlers = [] 208 | 209 | def _extract_features(self, batch: torch.Tensor, layers: List[int] = 11, facet: str = 'key') -> List[torch.Tensor]: 210 | """ 211 | extract features from the model 212 | :param batch: batch to extract features for. Has shape BxCxHxW. 213 | :param layers: layer to extract. A number between 0 to 11. 214 | :param facet: facet to extract. One of the following options: ['key' | 'query' | 'value' | 'token' | 'attn'] 215 | :return : tensor of features. 216 | if facet is 'key' | 'query' | 'value' has shape Bxhxtxd 217 | if facet is 'attn' has shape Bxhxtxt 218 | if facet is 'token' has shape Bxtxd 219 | """ 220 | B, C, H, W = batch.shape 221 | self._feats = [] 222 | self._register_hooks(layers, facet) 223 | _ = self.model(batch) 224 | self._unregister_hooks() 225 | self.load_size = (H, W) 226 | self.num_patches = (1 + (H - self.p) // self.stride[0], 1 + (W - self.p) // self.stride[1]) 227 | return self._feats 228 | 229 | def _log_bin(self, x: torch.Tensor, hierarchy: int = 2) -> torch.Tensor: 230 | """ 231 | create a log-binned descriptor. 232 | :param x: tensor of features. Has shape Bxhxtxd. 233 | :param hierarchy: how many bin hierarchies to use. 234 | """ 235 | B = x.shape[0] 236 | num_bins = 1 + 8 * hierarchy 237 | 238 | bin_x = x.permute(0, 2, 3, 1).flatten(start_dim=-2, end_dim=-1) # Bx(t-1)x(dxh) 239 | bin_x = bin_x.permute(0, 2, 1) 240 | bin_x = bin_x.reshape(B, bin_x.shape[1], self.num_patches[0], self.num_patches[1]) 241 | # Bx(dxh)xnum_patches[0]xnum_patches[1] 242 | sub_desc_dim = bin_x.shape[1] 243 | 244 | avg_pools = [] 245 | # compute bins of all sizes for all spatial locations. 
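# The nested loops below fill 1 + 8*hierarchy bins per patch: the patch's own (window-1) feature,
# its 8 immediate neighbours, and, for every level k >= 1, the 8 neighbours at an offset of 3**k
# patches taken from the 3**k x 3**k average-pooled map; out-of-range neighbours are clamped to
# the nearest valid patch instead of being zero-padded.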
246 | for k in range(0, hierarchy): 247 | # avg pooling with kernel 3**kx3**k 248 | win_size = 3 ** k 249 | avg_pool = torch.nn.AvgPool2d(win_size, stride=1, padding=win_size // 2, count_include_pad=False) 250 | avg_pools.append(avg_pool(bin_x)) 251 | 252 | bin_x = torch.zeros((B, sub_desc_dim * num_bins, self.num_patches[0], self.num_patches[1])).to(self.device) 253 | for y in range(self.num_patches[0]): 254 | for x in range(self.num_patches[1]): 255 | part_idx = 0 256 | # fill all bins for a spatial location (y, x) 257 | for k in range(0, hierarchy): 258 | kernel_size = 3 ** k 259 | for i in range(y - kernel_size, y + kernel_size + 1, kernel_size): 260 | for j in range(x - kernel_size, x + kernel_size + 1, kernel_size): 261 | if i == y and j == x and k != 0: 262 | continue 263 | if 0 <= i < self.num_patches[0] and 0 <= j < self.num_patches[1]: 264 | bin_x[:, part_idx * sub_desc_dim: (part_idx + 1) * sub_desc_dim, y, x] = avg_pools[k][ 265 | :, :, i, j] 266 | else: # handle padding in a more delicate way than zero padding 267 | temp_i = max(0, min(i, self.num_patches[0] - 1)) 268 | temp_j = max(0, min(j, self.num_patches[1] - 1)) 269 | bin_x[:, part_idx * sub_desc_dim: (part_idx + 1) * sub_desc_dim, y, x] = avg_pools[k][ 270 | :, :, temp_i, 271 | temp_j] 272 | part_idx += 1 273 | bin_x = bin_x.flatten(start_dim=-2, end_dim=-1).permute(0, 2, 1).unsqueeze(dim=1) 274 | # Bx1x(t-1)x(dxh) 275 | return bin_x 276 | 277 | def extract_descriptors(self, batch: torch.Tensor, layer: int = 11, facet: str = 'key', 278 | bin: bool = False, include_cls: bool = False) -> torch.Tensor: 279 | """ 280 | extract descriptors from the model 281 | :param batch: batch to extract descriptors for. Has shape BxCxHxW. 282 | :param layers: layer to extract. A number between 0 to 11. 283 | :param facet: facet to extract. One of the following options: ['key' | 'query' | 'value' | 'token'] 284 | :param bin: apply log binning to the descriptor. default is False. 285 | :return: tensor of descriptors. Bx1xtxd' where d' is the dimension of the descriptors. 286 | """ 287 | assert facet in ['key', 'query', 'value', 'token'], f"""{facet} is not a supported facet for descriptors. 288 | choose from ['key' | 'query' | 'value' | 'token'] """ 289 | self._extract_features(batch, [layer], facet) 290 | x = self._feats[0] 291 | if facet == 'token': 292 | x.unsqueeze_(dim=1) #Bx1xtxd 293 | if not include_cls: 294 | x = x[:, :, 1:, :] # remove cls token 295 | else: 296 | assert not bin, "bin = True and include_cls = True are not supported together, set one of them False." 297 | if not bin: 298 | desc = x.permute(0, 2, 3, 1).flatten(start_dim=-2, end_dim=-1).unsqueeze(dim=1) # Bx1xtx(dxh) 299 | else: 300 | desc = self._log_bin(x) 301 | return desc 302 | 303 | def extract_saliency_maps(self, batch: torch.Tensor) -> torch.Tensor: 304 | """ 305 | extract saliency maps. The saliency maps are extracted by averaging several attention heads from the last layer 306 | in of the CLS token. All values are then normalized to range between 0 and 1. 307 | :param batch: batch to extract saliency maps for. Has shape BxCxHxW. 308 | :return: a tensor of saliency maps. has shape Bxt-1 309 | """ 310 | assert self.model_type == "dino_vits8", f"saliency maps are supported only for dino_vits model_type." 
311 | self._extract_features(batch, [11], 'attn') 312 | head_idxs = [0, 2, 4, 5] 313 | curr_feats = self._feats[0] #Bxhxtxt 314 | cls_attn_map = curr_feats[:, head_idxs, 0, 1:].mean(dim=1) #Bx(t-1) 315 | temp_mins, temp_maxs = cls_attn_map.min(dim=1)[0], cls_attn_map.max(dim=1)[0] 316 | cls_attn_maps = (cls_attn_map - temp_mins) / (temp_maxs - temp_mins) # normalize to range [0,1] 317 | return cls_attn_maps 318 | 319 | """ taken from https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse""" 320 | def str2bool(v): 321 | if isinstance(v, bool): 322 | return v 323 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 324 | return True 325 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 326 | return False 327 | else: 328 | raise argparse.ArgumentTypeError('Boolean value expected.') 329 | 330 | if __name__ == "__main__": 331 | parser = argparse.ArgumentParser(description='Facilitate ViT Descriptor extraction.') 332 | parser.add_argument('--image_path', type=str, required=True, help='path of the extracted image.') 333 | parser.add_argument('--output_path', type=str, required=True, help='path to file containing extracted descriptors.') 334 | parser.add_argument('--load_size', default=224, type=int, help='load size of the input image.') 335 | parser.add_argument('--stride', default=4, type=int, help="""stride of first convolution layer. 336 | small stride -> higher resolution.""") 337 | parser.add_argument('--model_type', default='dino_vits8', type=str, 338 | help="""type of model to extract. 339 | Choose from [dino_vits8 | dino_vits16 | dino_vitb8 | dino_vitb16 | vit_small_patch8_224 | 340 | vit_small_patch16_224 | vit_base_patch8_224 | vit_base_patch16_224]""") 341 | parser.add_argument('--facet', default='key', type=str, help="""facet to create descriptors from. 
342 | options: ['key' | 'query' | 'value' | 'token']""") 343 | parser.add_argument('--layer', default=11, type=int, help="layer to create descriptors from.") 344 | parser.add_argument('--bin', default='False', type=str2bool, help="create a binned descriptor if True.") 345 | 346 | args = parser.parse_args() 347 | 348 | with torch.no_grad(): 349 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 350 | extractor = ViTExtractor(args.model_type, args.stride, device=device) 351 | image_batch, image_pil = extractor.preprocess(args.image_path, args.load_size) 352 | print(f"Image {args.image_path} is preprocessed to tensor of size {image_batch.shape}.") 353 | descriptors = extractor.extract_descriptors(image_batch.to(device), args.layer, args.facet, args.bin) 354 | print(f"Descriptors are of size: {descriptors.shape}") 355 | torch.save(descriptors, args.output_path) 356 | print(f"Descriptors saved to: {args.output_path}") 357 | -------------------------------------------------------------------------------- /Correspondence/sc_models/dino_vit/get_cor.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import torch 3 | import numpy as np 4 | from PIL import Image 5 | 6 | 7 | def get_cor_pairs(extractor, src_image: str, trg_image: str, src_points: list, image_size: int = 512, layer: int = 9, 8 | facet: str = 'key', bin: bool = True, transpose_img_func=lambda x:x, transpose_pts_func = lambda x, y: (x, y),device='cuda:0'): 9 | 10 | # extracting descriptors for each image 11 | with Image.open(src_image) as img: 12 | src_image = transpose_img_func(img) 13 | src_image_width, src_image_height = img.size 14 | src_image = src_image.resize((image_size, image_size)).convert('RGB') 15 | src_points = [transpose_pts_func(x, y) for x, y in src_points] 16 | src_x_scale, src_y_scale = image_size / src_image_width, image_size / src_image_height 17 | src_points = [[int(np.round(x * src_x_scale)), int(np.round(y * src_y_scale))] for (x, y) in src_points] 18 | 19 | with Image.open(trg_image) as img: 20 | trg_image_width, trg_image_height = img.size 21 | # trg_image, _ = pad_image(img) 22 | trg_image = img.resize((image_size, image_size)).convert('RGB') 23 | trg_x_scale, trg_y_scale = image_size / trg_image_width, image_size / trg_image_height 24 | 25 | src_image_batch, src_image_pil = extractor.preprocess(src_image, image_size) 26 | descriptors_src = extractor.extract_descriptors(src_image_batch.to(device), layer, facet, bin) 27 | 28 | num_patches_src, _ = extractor.num_patches, extractor.load_size 29 | 30 | indices_to_show = [] 31 | for i in range(len(src_points)): 32 | transferred_x1 = (src_points[i][0] - extractor.stride[1] - extractor.p // 2)/extractor.stride[1] + 1 33 | transferred_y1 = (src_points[i][1] - extractor.stride[0] - extractor.p // 2)/extractor.stride[0] + 1 34 | indices_to_show.append(int(transferred_y1) * num_patches_src[1] + int(transferred_x1)) 35 | 36 | descriptors_src_vec = descriptors_src[:, :, torch.Tensor(indices_to_show).to(torch.long)] 37 | 38 | del descriptors_src, src_image_batch 39 | gc.collect() 40 | torch.cuda.empty_cache() 41 | 42 | trg_image_batch, _ = extractor.preprocess(trg_image, image_size) 43 | descriptors_trg = extractor.extract_descriptors(trg_image_batch.to(device), layer, facet, bin) 44 | num_patches_trg, _ = extractor.num_patches, extractor.load_size 45 | 46 | # calculate similarity between src_image and trg_image descriptors 47 | similarities = chunk_cosine_sim(descriptors_src_vec, descriptors_trg) 48 | 49 | # 
calculate best buddies 50 | sim_src, nn_src = torch.max(similarities, dim=-1) # nn_1 - indices of block2 closest to block1 51 | sim_src, nn_src = sim_src[0, 0], nn_src[0, 0] 52 | 53 | del descriptors_trg, descriptors_src_vec, similarities, trg_image_batch 54 | gc.collect() 55 | torch.cuda.empty_cache() 56 | 57 | trg_img_indices_to_show = nn_src 58 | sim_values = sim_src.detach().cpu().numpy() 59 | # coordinates in descriptor map's dimensions 60 | trg_img_y_to_show = (trg_img_indices_to_show / num_patches_trg[1]).cpu().numpy() 61 | trg_img_x_to_show = (trg_img_indices_to_show % num_patches_trg[1]).cpu().numpy() 62 | trg_points = [] 63 | for y, x in zip(trg_img_y_to_show, trg_img_x_to_show): 64 | x_trg_show = (int(x) - 1) * extractor.stride[1] + extractor.stride[1] + extractor.p // 2 65 | y_trg_show = (int(y) - 1) * extractor.stride[0] + extractor.stride[0] + extractor.p // 2 66 | trg_points.append([y_trg_show, x_trg_show]) 67 | trg_points = [[int(np.round(x / trg_x_scale)), int(np.round(y / trg_y_scale))] for (x, y) in trg_points] 68 | return trg_points, src_points, None, src_image, sim_values 69 | 70 | def chunk_cosine_sim(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 71 | """ Computes cosine similarity between all possible pairs in two sets of vectors. 72 | Operates on chunks so no large amount of GPU RAM is required. 73 | :param x: an tensor of descriptors of shape Bx1x(t_x)xd' where d' is the dimensionality of the descriptors and t_x 74 | is the number of tokens in x. 75 | :param y: a tensor of descriptors of shape Bx1x(t_y)xd' where d' is the dimensionality of the descriptors and t_y 76 | is the number of tokens in y. 77 | :return: cosine similarity between all descriptors in x and all descriptors in y. Has shape of Bx1x(t_x)x(t_y) """ 78 | result_list = [] 79 | num_token_x = x.shape[2] 80 | for token_idx in range(num_token_x): 81 | token = x[:, :, token_idx, :].unsqueeze(dim=2) # Bx1x1xd' 82 | result_list.append(torch.nn.CosineSimilarity(dim=3)(token, y)) # Bx1xt 83 | return torch.stack(result_list, dim=2) # Bx1x(t_x)x(t_y) 84 | -------------------------------------------------------------------------------- /Correspondence/sc_models/ldm_sc/get_cor.py: -------------------------------------------------------------------------------- 1 | import re 2 | import os 3 | import cv2 4 | import torch 5 | import pickle 6 | import numpy as np 7 | from PIL import Image 8 | from tqdm import tqdm 9 | import matplotlib.pyplot as plt 10 | from sc_models.ldm_sc.optimize import optimize_prompt, run_image_with_tokens_cropped, find_max_pixel_value, load_ldm 11 | 12 | 13 | def get_cor_pairs(ldm, src_image, trg_image, src_points, img_size, device='cuda:0'): 14 | """ 15 | src_image, trg_image: relative path of src and trg images 16 | src_points: resized affordance points in src_image 17 | average_pts: average before correspondance or not 18 | ----- 19 | return: correspondance maps of each src_point and each target_point 20 | """ 21 | trg_points = [] 22 | layers = [5,6,7,8] 23 | with Image.open(src_image) as img: 24 | src_w, src_h = img.size 25 | src_image = img.resize((img_size, img_size), Image.BILINEAR).convert('RGB') 26 | src_tensor = torch.Tensor(np.array(src_image).transpose(2, 0, 1)) / 255.0 27 | src_x_scale, src_y_scale = img_size / src_w, img_size / src_h 28 | with Image.open(trg_image) as img: 29 | trg_w, trg_h = img.size 30 | trg_image = img.resize((img_size, img_size), Image.BILINEAR).convert('RGB') 31 | trg_tensor = torch.Tensor(np.array(trg_image).transpose(2, 0, 1)) / 255.0 32 | 
trg_x_scale, trg_y_scale = img_size / trg_w, img_size / trg_h 33 | 34 | src_points = [torch.Tensor([int(np.round(x * src_x_scale)), int(np.round(y * src_y_scale))]) for (x, y) in src_points] 35 | all_contexts = [] 36 | for src_point in src_points: 37 | contexts = [] 38 | for _ in range(5): 39 | context = optimize_prompt(ldm, src_tensor, src_point/img_size, num_steps=129, device=device, layers=layers, lr = 0.0023755632081200314, upsample_res=img_size, noise_level=-8, sigma = 27.97853316316864, flip_prob=0.0, crop_percent=93.16549294381423) 40 | contexts.append(context) 41 | all_contexts.append(torch.stack(contexts)) 42 | 43 | all_maps = [] 44 | for context in contexts: 45 | maps = [] 46 | attn_maps, _ = run_image_with_tokens_cropped(ldm, trg_tensor, context, index=0, upsample_res = img_size, noise_level=-8, layers=layers, device=device, crop_percent=93.16549294381423, num_iterations=20, image_mask = None) 47 | for k in range(attn_maps.shape[0]): 48 | avg = torch.mean(attn_maps[k], dim=0, keepdim=True) 49 | maps.append(avg) 50 | maps = torch.stack(maps, dim=0) 51 | all_maps.append(maps) 52 | all_maps = torch.stack(all_maps, dim=0) 53 | all_maps = torch.mean(all_maps, dim=0) 54 | all_maps = torch.nn.Softmax(dim=-1)(all_maps.reshape(len(layers), img_size*img_size)) 55 | all_maps = all_maps.reshape(len(layers), img_size, img_size) 56 | 57 | all_maps = torch.mean(all_maps, dim=0) 58 | trg_points.append(find_max_pixel_value(all_maps, img_size = img_size).cpu().numpy()) 59 | 60 | 61 | trg_points = [[int(np.round(x / trg_x_scale)), int(np.round(y / trg_y_scale))] for (x, y) in trg_points] 62 | 63 | return trg_points 64 | 65 | -------------------------------------------------------------------------------- /Correspondence/sc_models/ldm_sc/optimize.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
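# This module implements the prompt-optimization correspondence baseline used by
# sc_models/ldm_sc/get_cor.py: load_ldm builds a Stable Diffusion pipeline with frozen VAE, text
# encoder and U-Net; optimize_prompt fits a text embedding ("context") to a given source point; and
# run_image_with_tokens_cropped / find_max_pixel_value read out the matching location in the target
# image from the resulting attention maps.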
14 | 15 | import torch 16 | from diffusers import StableDiffusionPipeline, DDIMScheduler 17 | import numpy as np 18 | import abc 19 | from PIL import Image 20 | from sc_models.ldm_sc.ptp_utils import diffusion_step, register_attention_control 21 | import torch.nn.functional as F 22 | 23 | import torch.nn as nn 24 | 25 | 26 | def load_ldm(device, type="CompVis/stable-diffusion-v1-4"): 27 | 28 | scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False, steps_offset=1) 29 | 30 | MY_TOKEN = '' 31 | LOW_RESOURCE = False 32 | NUM_DDIM_STEPS = 50 33 | GUIDANCE_SCALE = 7.5 34 | MAX_NUM_WORDS = 77 35 | scheduler.set_timesteps(NUM_DDIM_STEPS) 36 | 37 | ldm = StableDiffusionPipeline.from_pretrained(type, use_auth_token=MY_TOKEN, scheduler=scheduler, local_files_only=True).to(device) 38 | 39 | for param in ldm.vae.parameters(): 40 | param.requires_grad = False 41 | for param in ldm.text_encoder.parameters(): 42 | param.requires_grad = False 43 | for param in ldm.unet.parameters(): 44 | param.requires_grad = False 45 | 46 | return ldm 47 | 48 | 49 | 50 | class AttentionControl(abc.ABC): 51 | 52 | def step_callback(self, x_t): 53 | 54 | return x_t 55 | 56 | def between_steps(self): 57 | return 58 | 59 | @property 60 | def num_uncond_att_layers(self): 61 | return 0 62 | 63 | @abc.abstractmethod 64 | def forward (self, attn, is_cross: bool, place_in_unet: str): 65 | raise NotImplementedError 66 | 67 | def __call__(self, attn, is_cross: bool, place_in_unet: str): 68 | 69 | if self.cur_att_layer >= self.num_uncond_att_layers: 70 | h = attn.shape[0] 71 | attn[h // 2:] = self.forward(attn[h // 2:], is_cross, place_in_unet) 72 | self.cur_att_layer += 1 73 | if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers: 74 | self.cur_att_layer = 0 75 | self.cur_step += 1 76 | self.between_steps() 77 | return attn 78 | 79 | def reset(self): 80 | self.cur_step = 0 81 | self.cur_att_layer = 0 82 | 83 | def __init__(self): 84 | self.cur_step = 0 85 | self.num_att_layers = -1 86 | self.cur_att_layer = 0 87 | 88 | 89 | 90 | class AttentionStore(AttentionControl): 91 | 92 | @staticmethod 93 | def get_empty_store(): 94 | return {"down_cross": [], "mid_cross": [], "up_cross": [], 95 | "down_self": [], "mid_self": [], "up_self": []} 96 | 97 | def forward(self, attn, is_cross: bool, place_in_unet: str): 98 | 99 | key = f"{place_in_unet}_{'cross' if is_cross else 'self'}" 100 | if attn.shape[1] <= 32 ** 2: # avoid memory overhead 101 | self.step_store[key].append(attn) 102 | return attn 103 | 104 | def between_steps(self): 105 | 106 | if len(self.attention_store) == 0: 107 | self.attention_store = self.step_store 108 | else: 109 | for key in self.attention_store: 110 | for i in range(len(self.attention_store[key])): 111 | self.attention_store[key][i] += self.step_store[key][i] 112 | self.step_store = self.get_empty_store() 113 | 114 | def get_average_attention(self): 115 | 116 | average_attention = {key: [item / self.cur_step for item in self.attention_store[key]] for key in self.attention_store} 117 | return average_attention 118 | 119 | 120 | def reset(self): 121 | super(AttentionStore, self).reset() 122 | self.step_store = self.get_empty_store() 123 | self.attention_store = {} 124 | 125 | def __init__(self): 126 | super(AttentionStore, self).__init__() 127 | self.step_store = self.get_empty_store() 128 | self.attention_store = {} 129 | 130 | 131 | def load_512(image_path, left=0, right=0, top=0, bottom=0): 132 | if type(image_path) 
is str: 133 | image = np.array(Image.open(image_path))[:, :, :3] 134 | else: 135 | image = image_path 136 | h, w, c = image.shape 137 | left = min(left, w-1) 138 | right = min(right, w - left - 1) 139 | top = min(top, h - left - 1) 140 | bottom = min(bottom, h - top - 1) 141 | image = image[top:h-bottom, left:w-right] 142 | h, w, c = image.shape 143 | if h < w: 144 | offset = (w - h) // 2 145 | image = image[:, offset:offset + h] 146 | elif w < h: 147 | offset = (h - w) // 2 148 | image = image[offset:offset + w] 149 | image = np.array(Image.fromarray(image).resize((512, 512))) 150 | return image 151 | 152 | 153 | def init_prompt(model, prompt: str): 154 | uncond_input = model.tokenizer( 155 | [""], padding="max_length", max_length=model.tokenizer.model_max_length, 156 | return_tensors="pt" 157 | ) 158 | uncond_embeddings = model.text_encoder(uncond_input.input_ids.to(model.device))[0] 159 | text_input = model.tokenizer( 160 | [prompt], 161 | padding="max_length", 162 | max_length=model.tokenizer.model_max_length, 163 | truncation=True, 164 | return_tensors="pt", 165 | ) 166 | text_embeddings = model.text_encoder(text_input.input_ids.to(model.device))[0] 167 | context = torch.cat([uncond_embeddings, text_embeddings]) 168 | prompt = prompt 169 | 170 | return context, prompt 171 | 172 | def init_random_noise(device, num_words = 77): 173 | return torch.randn(1, num_words, 768).to(device) 174 | 175 | def image2latent(model, image, device): 176 | with torch.no_grad(): 177 | if isinstance(image, Image.Image): # convert PIL images to numpy before the tensor conversion below 178 | image = np.array(image) 179 | if type(image) is torch.Tensor and image.dim() == 4: 180 | latents = image 181 | else: 182 | # scale the image from [0, 1] to [-1, 1] before encoding with the VAE 183 | image = torch.from_numpy(image).float() * 2 - 1 184 | image = image.permute(2, 0, 1).unsqueeze(0).to(device) 185 | latents = model.vae.encode(image)['latent_dist'].mean 186 | latents = latents * 0.18215 187 | return latents 188 | 189 | 190 | def reshape_attention(attention_map): 191 | """takes average over 0th dimension and reshapes into square image 192 | 193 | Args: 194 | attention_map (heads, num_tokens, d): stack of attention maps; averaged over heads, then reshaped to (sqrt(num_tokens), sqrt(num_tokens), d) 195 | """ 196 | attention_map = attention_map.mean(0) 197 | img_size = int(np.sqrt(attention_map.shape[0])) 198 | attention_map = attention_map.reshape(img_size, img_size, -1) 199 | return attention_map 200 | 201 | def visualize_attention_map(attention_map, file_name): 202 | # save attention map 203 | attention_map = attention_map.unsqueeze(-1).repeat(1, 1, 3) 204 | attention_map = (attention_map - attention_map.min()) / (attention_map.max() - attention_map.min()) 205 | attention_map = attention_map.detach().cpu().numpy() 206 | attention_map = (attention_map * 255).astype(np.uint8) 207 | img = Image.fromarray(attention_map) 208 | img.save(file_name) 209 | 210 | 211 | @torch.no_grad() 212 | def run_image_with_tokens_cropped(ldm, image, tokens, device='cuda', from_where = ["down_cross", "mid_cross", "up_cross"], index=0, upsample_res=512, noise_level=10, layers=[0, 1, 2, 3, 4, 5], num_iterations=20, crop_percent=100.0, image_mask = None): 213 | 214 | # if image is a torch.tensor, convert to numpy 215 | if type(image) == torch.Tensor: 216 | image = image.permute(1, 2, 0).detach().cpu().numpy() 217 | 218 | num_samples = torch.zeros(len(layers), 4, 512, 512).to(device) 219 | sum_samples = torch.zeros(len(layers), 4, 512, 512).to(device) 220 | 221 | pixel_locs = torch.tensor([[0, 0], [0, 512], [512, 0], [512, 512]]).float().to(device) 222 | 223 | collected_attention_maps = [] 224 | 225 | for i in
range(num_iterations): 226 | 227 | if i < 4: 228 | pixel_loc = pixel_locs[i] 229 | else: 230 | 231 | _attention_maps = sum_samples/num_samples 232 | 233 | # remove all the nans 234 | _attention_maps[_attention_maps != _attention_maps] = 0 235 | 236 | _attention_maps = torch.mean(_attention_maps, dim=0) 237 | _attention_maps = torch.mean(_attention_maps, dim=0) 238 | 239 | max_val = find_max_pixel_value(_attention_maps, img_size = 512)+0.5 240 | 241 | pixel_loc = max_val.clone() 242 | 243 | cropped_image, cropped_pixel, y_start, height, x_start, width = crop_image(image, pixel_loc, crop_percent = crop_percent) 244 | 245 | latents = image2latent(ldm, cropped_image, device) 246 | 247 | controller = AttentionStore() 248 | 249 | register_attention_control(ldm, controller) 250 | 251 | latents = ldm.scheduler.add_noise(latents, torch.rand_like(latents), ldm.scheduler.timesteps[-3]) 252 | 253 | latents = diffusion_step(ldm, controller, latents, tokens, ldm.scheduler.timesteps[-3], cfg=False) 254 | 255 | assert height == width 256 | 257 | _attention_maps = upscale_to_img_size(controller, from_where = from_where, upsample_res=height, layers=layers) 258 | 259 | num_samples[:, :, y_start:y_start+height, x_start:x_start+width] += 1 260 | sum_samples[:, :, y_start:y_start+height, x_start:x_start+width] += _attention_maps 261 | 262 | _attention_maps = sum_samples/num_samples 263 | 264 | if image_mask is not None: 265 | _attention_maps = _attention_maps * image_mask[None, None].to(device) 266 | 267 | collected_attention_maps.append(_attention_maps.clone()) 268 | 269 | # visualize sum_samples/num_samples 270 | attention_maps = sum_samples/num_samples 271 | 272 | if image_mask is not None: 273 | attention_maps = attention_maps * image_mask[None, None].to(device) 274 | 275 | return attention_maps, collected_attention_maps 276 | 277 | 278 | def upscale_to_img_size(controller, from_where = ["down_cross", "mid_cross", "up_cross"], upsample_res=512, layers=[0, 1, 2, 3, 4, 5]): 279 | """ 280 | returns the bilinearly upsampled attention map of size upsample_res x upsample_res for the first word in the prompt 281 | """ 282 | 283 | attention_maps = controller.get_average_attention() 284 | 285 | imgs = [] 286 | 287 | layer_overall = -1 288 | 289 | for key in from_where: 290 | for layer in range(len(attention_maps[key])): 291 | 292 | layer_overall += 1 293 | 294 | 295 | if layer_overall not in layers: 296 | continue 297 | 298 | img = attention_maps[key][layer] 299 | 300 | img = img.reshape(4, int(img.shape[1]**0.5), int(img.shape[1]**0.5), img.shape[2])[None, :, :, :, 1] 301 | 302 | if upsample_res != -1: 303 | # bilinearly upsample the image to img_sizeximg_size 304 | img = F.interpolate(img, size=(upsample_res, upsample_res), mode='bilinear', align_corners=False) 305 | 306 | imgs.append(img) 307 | 308 | imgs = torch.cat(imgs, dim=0) 309 | 310 | return imgs 311 | 312 | 313 | def softargmax2d(input, beta = 1000): 314 | *_, h, w = input.shape 315 | 316 | assert h == w, "only square images are supported" 317 | 318 | input = input.reshape(*_, h * w) 319 | input = nn.functional.softmax(input*beta, dim=-1) 320 | 321 | indices_c, indices_r = np.meshgrid( 322 | np.linspace(0, 1, w), 323 | np.linspace(0, 1, h), 324 | indexing='xy' 325 | ) 326 | 327 | indices_r = torch.tensor(np.reshape(indices_r, (-1, h * w))).to(input.device).float() 328 | indices_c = torch.tensor(np.reshape(indices_c, (-1, h * w))).to(input.device).float() 329 | 330 | result_r = torch.sum((h - 1) * input * indices_r, dim=-1) 331 | result_c = torch.sum((w - 
1) * input * indices_c, dim=-1) 332 | 333 | result = torch.stack([result_c, result_r], dim=-1) 334 | 335 | return result/h 336 | 337 | 338 | def find_context(image, ldm, pixel_loc, context_estimator, device='cuda'): 339 | 340 | with torch.no_grad(): 341 | latent = image2latent(ldm, image.numpy().transpose(1, 2, 0), device) 342 | 343 | context = context_estimator(latent, pixel_loc) 344 | 345 | return context 346 | 347 | 348 | def find_max_pixel_value(tens, img_size=512, ignore_border = True): 349 | """finds the 2d pixel location that is the max value in the tensor 350 | 351 | Args: 352 | tens (tensor): shape (height, width) 353 | """ 354 | 355 | assert len(tens.shape) == 2, "tens must be 2d" 356 | 357 | _tens = tens.clone() 358 | height = _tens.shape[0] 359 | 360 | _tens = _tens.reshape(-1) 361 | max_loc = torch.argmax(_tens) 362 | max_pixel = torch.stack([max_loc % height, torch.div(max_loc, height, rounding_mode='floor')]) 363 | 364 | max_pixel = max_pixel/height*img_size 365 | 366 | return max_pixel 367 | 368 | def visualize_image_with_points(image, point, name, save_folder = "outputs"): 369 | 370 | """The point is in pixel numbers 371 | """ 372 | 373 | import matplotlib.pyplot as plt 374 | 375 | # if image is a torch.tensor, convert to numpy 376 | if type(image) == torch.Tensor: 377 | try: 378 | image = image.permute(1, 2, 0).detach().cpu().numpy() 379 | except: 380 | import ipdb; ipdb.set_trace() 381 | 382 | 383 | # make the figure without a border 384 | fig = plt.figure(frameon=False) 385 | fig.set_size_inches(10, 10) 386 | 387 | ax = plt.Axes(fig, [0., 0., 1., 1.]) 388 | ax.set_axis_off() 389 | fig.add_axes(ax) 390 | 391 | plt.imshow(image, aspect='auto') 392 | 393 | if point is not None: 394 | # plot point on image 395 | plt.scatter(point[0].cpu(), point[1].cpu(), s=20, marker='o', c='r') 396 | 397 | 398 | plt.savefig(f'{save_folder}/{name}.png', dpi=200) 399 | plt.close() 400 | 401 | 402 | def gaussian_circle(pos, size=64, sigma=16, device = "cuda"): 403 | """Create a 2D Gaussian circle with a given size, standard deviation, and center coordinates. 404 | 405 | pos is in between 0 and 1 406 | 407 | """ 408 | _pos = pos*size 409 | grid = torch.meshgrid(torch.arange(size).to(device), torch.arange(size).to(device), indexing='ij') 410 | grid = torch.stack(grid, dim=-1) 411 | dist_sq = (grid[..., 1] - _pos[0])**2 + (grid[..., 0] - _pos[1])**2 412 | dist_sq = -1*dist_sq / (2. * sigma**2.) 
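# i.e. gaussian[y, x] = exp(-((x - pos[0]*size)**2 + (y - pos[1]*size)**2) / (2 * sigma**2)), a heatmap peaked at the target pixel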
413 | gaussian = torch.exp(dist_sq) 414 | return gaussian 415 | 416 | 417 | def crop_image(image, pixel, crop_percent=80, margin=0.15): 418 | 419 | """pixel is an integer between 0 and image.shape[1] or image.shape[2] 420 | """ 421 | 422 | assert 0 < crop_percent <= 100, "crop_percent should be between 0 and 100" 423 | 424 | height, width, channels = image.shape 425 | crop_height = int(height * crop_percent / 100) 426 | crop_width = int(width * crop_percent / 100) 427 | 428 | # Calculate the crop region's top-left corner 429 | x, y = pixel 430 | 431 | # Calculate safe margin 432 | safe_margin_x = int(crop_width * margin) 433 | safe_margin_y = int(crop_height * margin) 434 | 435 | x_start_min = max(0, x - crop_width + safe_margin_x) 436 | x_start_min = min(x_start_min, width - crop_width) 437 | x_start_max = max(0, x - safe_margin_x) 438 | x_start_max = min(x_start_max, width - crop_width) 439 | 440 | y_start_min = max(0, y - crop_height + safe_margin_y) 441 | y_start_min = min(y_start_min, height - crop_height) 442 | y_start_max = max(0, y - safe_margin_y) 443 | y_start_max = min(y_start_max, height - crop_height) 444 | 445 | # Choose a random top-left corner within the allowed bounds 446 | x_start = torch.randint(int(x_start_min), int(x_start_max) + 1, (1,)).item() 447 | y_start = torch.randint(int(y_start_min), int(y_start_max) + 1, (1,)).item() 448 | 449 | # Crop the image 450 | cropped_image = image[y_start:y_start + crop_height, x_start:x_start + crop_width] 451 | 452 | # bilinearly upsample to 512x512 453 | cropped_image = torch.nn.functional.interpolate(torch.tensor(cropped_image[None]).permute(0, 3, 1, 2), size=(512, 512), mode='bilinear', align_corners=False)[0] 454 | 455 | # calculate new pixel location 456 | new_pixel = torch.stack([x-x_start, y-y_start]) 457 | new_pixel = new_pixel/crop_width 458 | 459 | return cropped_image.permute(1, 2, 0).numpy(), new_pixel, y_start, crop_height, x_start, crop_width 460 | 461 | 462 | def optimize_prompt(ldm, image, pixel_loc, context=None, device="cuda", num_steps=100, from_where = ["down_cross", "mid_cross", "up_cross"], upsample_res = 32, layers = [0, 1, 2, 3, 4, 5], lr=1e-3, noise_level = -1, sigma = 32, flip_prob = 0.5, crop_percent=80): 463 | 464 | # if image is a torch.tensor, convert to numpy 465 | if type(image) == torch.Tensor: 466 | image = image.permute(1, 2, 0).detach().cpu().numpy() 467 | 468 | if context is None: 469 | context = init_random_noise(device) 470 | 471 | context.requires_grad = True 472 | 473 | # optimize context to maximize attention at pixel_loc 474 | optimizer = torch.optim.Adam([context], lr=lr) 475 | 476 | # time the optimization 477 | import time 478 | start = time.time() 479 | 480 | for iteration in range(num_steps): 481 | 482 | with torch.no_grad(): 483 | 484 | if np.random.rand() > flip_prob: 485 | 486 | cropped_image, cropped_pixel, _, _, _, _ = crop_image(image, pixel_loc*512, crop_percent = crop_percent) 487 | 488 | latent = image2latent(ldm, cropped_image, device) 489 | 490 | _pixel_loc = cropped_pixel.clone() 491 | else: 492 | 493 | image_flipped = np.flip(image, axis=1).copy() 494 | 495 | pixel_loc_flipped = pixel_loc.clone() 496 | # flip pixel loc 497 | pixel_loc_flipped[0] = 1 - pixel_loc_flipped[0] 498 | 499 | cropped_image, cropped_pixel, _, _, _, _ = crop_image(image_flipped, pixel_loc_flipped*512, crop_percent = crop_percent) 500 | 501 | _pixel_loc = cropped_pixel.clone() 502 | 503 | latent = image2latent(ldm, cropped_image, device) 504 | 505 | noisy_image = ldm.scheduler.add_noise(latent, 
torch.rand_like(latent), ldm.scheduler.timesteps[noise_level]) 506 | 507 | controller = AttentionStore() 508 | 509 | register_attention_control(ldm, controller) 510 | 511 | _ = diffusion_step(ldm, controller, noisy_image, context, ldm.scheduler.timesteps[noise_level], cfg = False) 512 | 513 | attention_maps = upscale_to_img_size(controller, from_where = from_where, upsample_res=upsample_res, layers = layers) 514 | num_maps = attention_maps.shape[0] 515 | 516 | # divide by the mean along the dim=1 517 | attention_maps = torch.mean(attention_maps, dim=1) 518 | 519 | gt_maps = gaussian_circle(_pixel_loc, size=upsample_res, sigma=sigma, device = device) 520 | 521 | gt_maps = gt_maps.reshape(1, -1).repeat(num_maps, 1) 522 | attention_maps = attention_maps.reshape(num_maps, -1) 523 | 524 | loss = torch.nn.MSELoss()(attention_maps, gt_maps) 525 | loss.backward() 526 | optimizer.step() 527 | optimizer.zero_grad() 528 | 529 | return context 530 | 531 | -------------------------------------------------------------------------------- /Correspondence/sc_models/ldm_sc/ptp_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
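# --- Minimal sketch of the classifier-free guidance combination used by diffusion_step defined later
# in this file, shown with dummy tensors in place of real UNet outputs; the shapes and the guidance
# scale are illustrative assumptions.
def _cfg_combination_sketch():
    import torch
    guidance_scale = 7.5                                 # assumed scale
    noise_pred_uncond = torch.randn(1, 4, 64, 64)        # stand-in for the unconditional prediction
    noise_pred_text = torch.randn(1, 4, 64, 64)          # stand-in for the text-conditioned prediction
    # push the prediction away from the unconditional branch, towards the text-conditioned one
    return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)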
14 | 15 | import numpy as np 16 | import torch 17 | from typing import Optional, Union, Tuple, List, Dict 18 | from tqdm.notebook import tqdm 19 | import torch.nn.functional as F 20 | 21 | 22 | 23 | def diffusion_step(model, controller, latents, context, t, guidance_scale=None, cfg = True): 24 | 25 | if cfg: 26 | latents_input = torch.cat([latents] * 2) 27 | noise_pred = model.unet(latents_input, t, encoder_hidden_states=context)["sample"] 28 | noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2) 29 | noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond) 30 | else: 31 | noise_pred = model.unet(latents, t, encoder_hidden_states=context)["sample"] 32 | 33 | latents = model.scheduler.step(noise_pred, t, latents)["prev_sample"] 34 | latents = controller.step_callback(latents) 35 | return latents 36 | 37 | 38 | def latent2image(vae, latents): 39 | latents = 1 / 0.18215 * latents 40 | image = vae.decode(latents)['sample'] 41 | image = (image / 2 + 0.5).clamp(0, 1) 42 | image = image.cpu().permute(0, 2, 3, 1).numpy() 43 | image = (image * 255).astype(np.uint8) 44 | return image 45 | 46 | 47 | def init_latent(latent, model, height, width, generator, batch_size): 48 | if latent is None: 49 | latent = torch.randn( 50 | (1, model.unet.in_channels, height // 8, width // 8), 51 | generator=generator, 52 | ) 53 | latents = latent.expand(batch_size, model.unet.in_channels, height // 8, width // 8).to(model.device) 54 | return latent, latents 55 | 56 | 57 | @torch.no_grad() 58 | def text2image_ldm( 59 | model, 60 | prompt: List[str], 61 | controller, 62 | num_inference_steps: int = 50, 63 | guidance_scale: Optional[float] = 7., 64 | generator: Optional[torch.Generator] = None, 65 | latent: Optional[torch.FloatTensor] = None, 66 | ): 67 | register_attention_control(model, controller) 68 | height = width = 256 69 | batch_size = len(prompt) 70 | 71 | uncond_input = model.tokenizer([""] * batch_size, padding="max_length", max_length=77, return_tensors="pt") 72 | uncond_embeddings = model.bert(uncond_input.input_ids.to(model.device))[0] 73 | 74 | text_input = model.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt") 75 | text_embeddings = model.bert(text_input.input_ids.to(model.device))[0] 76 | latent, latents = init_latent(latent, model, height, width, generator, batch_size) 77 | context = torch.cat([uncond_embeddings, text_embeddings]) 78 | 79 | model.scheduler.set_timesteps(num_inference_steps) 80 | for t in tqdm(model.scheduler.timesteps): 81 | latents = diffusion_step(model, controller, latents, context, t, guidance_scale) 82 | 83 | image = latent2image(model.vqvae, latents) 84 | 85 | return image, latent 86 | 87 | 88 | @torch.no_grad() 89 | def text2image_ldm_stable( 90 | model, 91 | prompt: List[str], 92 | controller, 93 | num_inference_steps: int = 50, 94 | guidance_scale: float = 7.5, 95 | generator: Optional[torch.Generator] = None, 96 | latent: Optional[torch.FloatTensor] = None, 97 | low_resource: bool = False, 98 | ): 99 | register_attention_control(model, controller) 100 | height = width = 512 101 | batch_size = len(prompt) 102 | 103 | text_input = model.tokenizer( 104 | prompt, 105 | padding="max_length", 106 | max_length=model.tokenizer.model_max_length, 107 | truncation=True, 108 | return_tensors="pt", 109 | ) 110 | text_embeddings = model.text_encoder(text_input.input_ids.to(model.device))[0] 111 | max_length = text_input.input_ids.shape[-1] 112 | uncond_input = model.tokenizer( 113 | [""] * batch_size, 
padding="max_length", max_length=max_length, return_tensors="pt" 114 | ) 115 | uncond_embeddings = model.text_encoder(uncond_input.input_ids.to(model.device))[0] 116 | 117 | context = [uncond_embeddings, text_embeddings] 118 | if not low_resource: 119 | context = torch.cat(context) 120 | latent, latents = init_latent(latent, model, height, width, generator, batch_size) 121 | 122 | # set timesteps 123 | # extra_set_kwargs = {"offset": 1} 124 | extra_set_kwargs = {} 125 | model.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs) 126 | for t in tqdm(model.scheduler.timesteps): 127 | latents = diffusion_step(model, controller, latents, context, t, guidance_scale) 128 | 129 | image = latent2image(model.vae, latents) 130 | 131 | return image, latent 132 | 133 | def softmax_torch(x): # Assuming x has atleast 2 dimensions 134 | maxes = torch.max(x, -1, keepdim=True)[0] 135 | x_exp = torch.exp(x-maxes) 136 | x_exp_sum = torch.sum(x_exp, -1, keepdim=True) 137 | probs = x_exp/x_exp_sum 138 | return probs 139 | 140 | 141 | def register_attention_control(model, controller): 142 | def ca_forward(self, place_in_unet): 143 | to_out = self.to_out 144 | if type(to_out) is torch.nn.modules.container.ModuleList: 145 | to_out = self.to_out[0] 146 | else: 147 | to_out = self.to_out 148 | 149 | def forward(x, context=None, mask=None): 150 | batch_size, sequence_length, dim = x.shape 151 | h = self.heads 152 | q = self.to_q(x) 153 | is_cross = context is not None 154 | context = context if is_cross else x 155 | k = self.to_k(context) 156 | v = self.to_v(context) 157 | q = self.reshape_heads_to_batch_dim(q) 158 | k = self.reshape_heads_to_batch_dim(k) 159 | v = self.reshape_heads_to_batch_dim(v) 160 | 161 | sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale 162 | # sim = torch.matmul(q, k.permute(0, 2, 1)) * self.scale 163 | 164 | if mask is not None: 165 | mask = mask.reshape(batch_size, -1) 166 | max_neg_value = -torch.finfo(sim.dtype).max 167 | mask = mask[:, None, :].repeat(h, 1, 1) 168 | sim = sim.masked_fill(~mask, max_neg_value) 169 | 170 | # attention, what we cannot get enough of 171 | attn = torch.nn.Softmax(dim=-1)(sim) 172 | attn = attn.clone() 173 | attn = controller(attn, is_cross, place_in_unet) 174 | out = torch.matmul(attn, v) 175 | 176 | out = self.reshape_batch_dim_to_heads(out) 177 | return to_out(out) 178 | 179 | return forward 180 | 181 | class DummyController: 182 | 183 | def __call__(self, *args): 184 | return args[0] 185 | 186 | def __init__(self): 187 | self.num_att_layers = 0 188 | 189 | if controller is None: 190 | controller = DummyController() 191 | 192 | def register_recr(net_, count, place_in_unet): 193 | if net_.__class__.__name__ == 'CrossAttention': 194 | net_.forward = ca_forward(net_, place_in_unet) 195 | return count + 1 196 | elif hasattr(net_, 'children'): 197 | for net__ in net_.children(): 198 | count = register_recr(net__, count, place_in_unet) 199 | return count 200 | 201 | cross_att_count = 0 202 | sub_nets = model.unet.named_children() 203 | for net in sub_nets: 204 | if "down" in net[0]: 205 | cross_att_count += register_recr(net[1], 0, "down") 206 | elif "up" in net[0]: 207 | cross_att_count += register_recr(net[1], 0, "up") 208 | elif "mid" in net[0]: 209 | cross_att_count += register_recr(net[1], 0, "mid") 210 | 211 | controller.num_att_layers = cross_att_count 212 | 213 | 214 | def get_word_inds(text: str, word_place: int, tokenizer): 215 | split_text = text.split(" ") 216 | if type(word_place) is str: 217 | word_place = [i for i, word in 
enumerate(split_text) if word_place == word] 218 | elif type(word_place) is int: 219 | word_place = [word_place] 220 | out = [] 221 | if len(word_place) > 0: 222 | words_encode = [tokenizer.decode([item]).strip("#") for item in tokenizer.encode(text)][1:-1] 223 | cur_len, ptr = 0, 0 224 | 225 | for i in range(len(words_encode)): 226 | cur_len += len(words_encode[i]) 227 | if ptr in word_place: 228 | out.append(i + 1) 229 | if cur_len >= len(split_text[ptr]): 230 | ptr += 1 231 | cur_len = 0 232 | return np.array(out) 233 | 234 | 235 | def update_alpha_time_word(alpha, bounds: Union[float, Tuple[float, float]], prompt_ind: int, 236 | word_inds: Optional[torch.Tensor]=None): 237 | if type(bounds) is float: 238 | bounds = 0, bounds 239 | start, end = int(bounds[0] * alpha.shape[0]), int(bounds[1] * alpha.shape[0]) 240 | if word_inds is None: 241 | word_inds = torch.arange(alpha.shape[2]) 242 | alpha[: start, prompt_ind, word_inds] = 0 243 | alpha[start: end, prompt_ind, word_inds] = 1 244 | alpha[end:, prompt_ind, word_inds] = 0 245 | return alpha 246 | 247 | 248 | def get_time_words_attention_alpha(prompts, num_steps, 249 | cross_replace_steps: Union[float, Dict[str, Tuple[float, float]]], 250 | tokenizer, max_num_words=77): 251 | if type(cross_replace_steps) is not dict: 252 | cross_replace_steps = {"default_": cross_replace_steps} 253 | if "default_" not in cross_replace_steps: 254 | cross_replace_steps["default_"] = (0., 1.) 255 | alpha_time_words = torch.zeros(num_steps + 1, len(prompts) - 1, max_num_words) 256 | for i in range(len(prompts) - 1): 257 | alpha_time_words = update_alpha_time_word(alpha_time_words, cross_replace_steps["default_"], 258 | i) 259 | for key, item in cross_replace_steps.items(): 260 | if key != "default_": 261 | inds = [get_word_inds(prompts[i], key, tokenizer) for i in range(1, len(prompts))] 262 | for i, ind in enumerate(inds): 263 | if len(ind) > 0: 264 | alpha_time_words = update_alpha_time_word(alpha_time_words, item, i, ind) 265 | alpha_time_words = alpha_time_words.reshape(num_steps + 1, len(prompts) - 1, 1, 1, max_num_words) 266 | return alpha_time_words 267 | -------------------------------------------------------------------------------- /Correspondence/sc_models/sd_dino/cor_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | from PIL import Image 5 | import matplotlib.pyplot as plt 6 | from matplotlib.colors import ListedColormap 7 | from typing import List, Tuple 8 | import faiss 9 | import cv2 10 | import os 11 | from matplotlib.patches import ConnectionPatch 12 | 13 | def resize(img, target_res, resize=True, to_pil=True, edge=False): 14 | original_width, original_height = img.size 15 | original_channels = len(img.getbands()) 16 | if not edge: 17 | canvas = np.zeros([target_res, target_res, 3], dtype=np.uint8) 18 | if original_channels == 1: 19 | canvas = np.zeros([target_res, target_res], dtype=np.uint8) 20 | if original_height <= original_width: 21 | if resize: 22 | img = img.resize((target_res, int(np.around(target_res * original_height / original_width))), Image.Resampling.LANCZOS) 23 | width, height = img.size 24 | img = np.asarray(img) 25 | canvas[(width - height) // 2: (width + height) // 2] = img 26 | else: 27 | if resize: 28 | img = img.resize((int(np.around(target_res * original_width / original_height)), target_res), Image.Resampling.LANCZOS) 29 | width, height = img.size 30 | img = np.asarray(img) 31 | canvas[:, (height - 
width) // 2: (height + width) // 2] = img 32 | else: 33 | if original_height <= original_width: 34 | if resize: 35 | img = img.resize((target_res, int(np.around(target_res * original_height / original_width))), Image.Resampling.LANCZOS) 36 | width, height = img.size 37 | img = np.asarray(img) 38 | top_pad = (target_res - height) // 2 39 | bottom_pad = target_res - height - top_pad 40 | img = np.pad(img, pad_width=[(top_pad, bottom_pad), (0, 0), (0, 0)], mode='edge') 41 | else: 42 | if resize: 43 | img = img.resize((int(np.around(target_res * original_width / original_height)), target_res), Image.Resampling.LANCZOS) 44 | width, height = img.size 45 | img = np.asarray(img) 46 | left_pad = (target_res - width) // 2 47 | right_pad = target_res - width - left_pad 48 | img = np.pad(img, pad_width=[(0, 0), (left_pad, right_pad), (0, 0)], mode='edge') 49 | canvas = img 50 | if to_pil: 51 | canvas = Image.fromarray(canvas) 52 | return canvas 53 | 54 | 55 | def find_nearest_patchs(mask1, mask2, image1, image2, features1, features2, mask=False, resolution=None, edit_image=None): 56 | def polar_color_map(image_shape): 57 | h, w = image_shape[:2] 58 | x = np.linspace(-1, 1, w) 59 | y = np.linspace(-1, 1, h) 60 | xx, yy = np.meshgrid(x, y) 61 | 62 | # Find the center of the mask 63 | mask=mask2.cpu() 64 | mask_center = np.array(np.where(mask > 0)) 65 | mask_center = np.round(np.mean(mask_center, axis=1)).astype(int) 66 | mask_center_y, mask_center_x = mask_center 67 | 68 | # Calculate distance and angle based on mask_center 69 | xx_shifted, yy_shifted = xx - x[mask_center_x], yy - y[mask_center_y] 70 | max_radius = np.sqrt(h**2 + w**2) / 2 71 | radius = np.sqrt(xx_shifted**2 + yy_shifted**2) * max_radius 72 | angle = np.arctan2(yy_shifted, xx_shifted) / (2 * np.pi) + 0.5 73 | 74 | angle = 0.2 + angle * 0.6 # Map angle to the range [0.25, 0.75] 75 | radius = np.where(radius <= max_radius, radius, max_radius) # Limit radius values to the unit circle 76 | radius = 0.2 + radius * 0.6 / max_radius # Map radius to the range [0.1, 1] 77 | 78 | return angle, radius 79 | 80 | if resolution is not None: # resize the feature map to the resolution 81 | features1 = F.interpolate(features1, size=resolution, mode='bilinear') 82 | features2 = F.interpolate(features2, size=resolution, mode='bilinear') 83 | 84 | # resize the image to the shape of the feature map 85 | resized_image1 = resize(image1, features1.shape[2], resize=True, to_pil=False) 86 | resized_image2 = resize(image2, features2.shape[2], resize=True, to_pil=False) 87 | 88 | if mask: # mask the features 89 | resized_mask1 = F.interpolate(mask1.cuda().unsqueeze(0).unsqueeze(0).float(), size=features1.shape[2:], mode='nearest') 90 | resized_mask2 = F.interpolate(mask2.cuda().unsqueeze(0).unsqueeze(0).float(), size=features2.shape[2:], mode='nearest') 91 | features1 = features1 * resized_mask1.repeat(1, features1.shape[1], 1, 1) 92 | features2 = features2 * resized_mask2.repeat(1, features2.shape[1], 1, 1) 93 | # set where mask==0 a very large number 94 | features1[(features1.sum(1)==0).repeat(1, features1.shape[1], 1, 1)] = 100000 95 | features2[(features2.sum(1)==0).repeat(1, features2.shape[1], 1, 1)] = 100000 96 | 97 | features1_2d = features1.reshape(features1.shape[1], -1).permute(1, 0).cpu().detach().numpy() 98 | features2_2d = features2.reshape(features2.shape[1], -1).permute(1, 0).cpu().detach().numpy() 99 | 100 | features1_2d = torch.tensor(features1_2d).to("cuda") 101 | features2_2d = torch.tensor(features2_2d).to("cuda") 102 | resized_image1 = 
torch.tensor(resized_image1).to("cuda").float() 103 | resized_image2 = torch.tensor(resized_image2).to("cuda").float() 104 | 105 | mask1 = F.interpolate(mask1.cuda().unsqueeze(0).unsqueeze(0).float(), size=resized_image1.shape[:2], mode='nearest').squeeze(0).squeeze(0) 106 | mask2 = F.interpolate(mask2.cuda().unsqueeze(0).unsqueeze(0).float(), size=resized_image2.shape[:2], mode='nearest').squeeze(0).squeeze(0) 107 | 108 | # Mask the images 109 | resized_image1 = resized_image1 * mask1.unsqueeze(-1).repeat(1, 1, 3) 110 | resized_image2 = resized_image2 * mask2.unsqueeze(-1).repeat(1, 1, 3) 111 | # Normalize the images to the range [0, 1] 112 | resized_image1 = (resized_image1 - resized_image1.min()) / (resized_image1.max() - resized_image1.min()) 113 | resized_image2 = (resized_image2 - resized_image2.min()) / (resized_image2.max() - resized_image2.min()) 114 | 115 | angle, radius = polar_color_map(resized_image2.shape) 116 | 117 | angle_mask = angle * mask2.cpu().numpy() 118 | radius_mask = radius * mask2.cpu().numpy() 119 | 120 | hsv_mask = np.zeros(resized_image2.shape, dtype=np.float32) 121 | hsv_mask[:, :, 0] = angle_mask 122 | hsv_mask[:, :, 1] = radius_mask 123 | hsv_mask[:, :, 2] = 1 124 | 125 | rainbow_mask2 = cv2.cvtColor((hsv_mask * 255).astype(np.uint8), cv2.COLOR_HSV2BGR) / 255 126 | 127 | if edit_image is not None: 128 | rainbow_mask2 = cv2.imread(edit_image, cv2.IMREAD_COLOR) 129 | rainbow_mask2 = cv2.cvtColor(rainbow_mask2, cv2.COLOR_BGR2RGB) / 255 130 | rainbow_mask2 = cv2.resize(rainbow_mask2, (resized_image2.shape[1], resized_image2.shape[0])) 131 | 132 | # Apply the rainbow mask to image2 133 | rainbow_image2 = rainbow_mask2 * mask2.cpu().numpy()[:, :, None] 134 | 135 | # Create a white background image 136 | background_color = np.array([1, 1, 1], dtype=np.float32) 137 | background_image = np.ones(resized_image2.shape, dtype=np.float32) * background_color 138 | 139 | # Apply the rainbow mask to image2 only in the regions where mask2 is 1 140 | rainbow_image2 = np.where(mask2.cpu().numpy()[:, :, None] == 1, rainbow_mask2, background_image) 141 | 142 | nearest_patches = [] 143 | 144 | distances = torch.cdist(features1_2d, features2_2d) 145 | nearest_patch_indices = torch.argmin(distances, dim=1) 146 | nearest_patches = torch.index_select(torch.tensor(rainbow_mask2).cuda().reshape(-1, 3), 0, nearest_patch_indices) 147 | 148 | nearest_patches_image = nearest_patches.reshape(resized_image1.shape) 149 | rainbow_image2 = torch.tensor(rainbow_image2).to("cuda") 150 | 151 | # TODO: upsample the nearest_patches_image to the resolution of the original image 152 | # nearest_patches_image = F.interpolate(nearest_patches_image.permute(2,0,1).unsqueeze(0), size=256, mode='bilinear').squeeze(0).permute(1,2,0) 153 | # rainbow_image2 = F.interpolate(rainbow_image2.permute(2,0,1).unsqueeze(0), size=256, mode='bilinear').squeeze(0).permute(1,2,0) 154 | 155 | nearest_patches_image = (nearest_patches_image).cpu().numpy() 156 | resized_image2 = (rainbow_image2).cpu().numpy() 157 | 158 | return nearest_patches_image, resized_image2 159 | 160 | 161 | def find_nearest_patchs_replace(mask1, mask2, image1, image2, features1, features2, mask=False, resolution=128, draw_gif=False, save_path=None, gif_reverse=False): 162 | 163 | if resolution is not None: # resize the feature map to the resolution 164 | features1 = F.interpolate(features1, size=resolution, mode='bilinear') 165 | features2 = F.interpolate(features2, size=resolution, mode='bilinear') 166 | 167 | # resize the image to the shape of the 
feature map 168 | resized_image1 = resize(image1, features1.shape[2], resize=True, to_pil=False) 169 | resized_image2 = resize(image2, features2.shape[2], resize=True, to_pil=False) 170 | 171 | if mask: # mask the features 172 | resized_mask1 = F.interpolate(mask1.cuda().unsqueeze(0).unsqueeze(0).float(), size=features1.shape[2:], mode='nearest') 173 | resized_mask2 = F.interpolate(mask2.cuda().unsqueeze(0).unsqueeze(0).float(), size=features2.shape[2:], mode='nearest') 174 | features1 = features1 * resized_mask1.repeat(1, features1.shape[1], 1, 1) 175 | features2 = features2 * resized_mask2.repeat(1, features2.shape[1], 1, 1) 176 | # set where mask==0 a very large number 177 | features1[(features1.sum(1)==0).repeat(1, features1.shape[1], 1, 1)] = 100000 178 | features2[(features2.sum(1)==0).repeat(1, features2.shape[1], 1, 1)] = 100000 179 | 180 | features1_2d = features1.reshape(features1.shape[1], -1).permute(1, 0) 181 | features2_2d = features2.reshape(features2.shape[1], -1).permute(1, 0) 182 | 183 | resized_image1 = torch.tensor(resized_image1).to("cuda").float() 184 | resized_image2 = torch.tensor(resized_image2).to("cuda").float() 185 | 186 | mask1 = F.interpolate(mask1.cuda().unsqueeze(0).unsqueeze(0).float(), size=resized_image1.shape[:2], mode='nearest').squeeze(0).squeeze(0) 187 | mask2 = F.interpolate(mask2.cuda().unsqueeze(0).unsqueeze(0).float(), size=resized_image2.shape[:2], mode='nearest').squeeze(0).squeeze(0) 188 | 189 | # Mask the images 190 | resized_image1 = resized_image1 * mask1.unsqueeze(-1).repeat(1, 1, 3) 191 | resized_image2 = resized_image2 * mask2.unsqueeze(-1).repeat(1, 1, 3) 192 | # Normalize the images to the range [0, 1] 193 | resized_image1 = (resized_image1 - resized_image1.min()) / (resized_image1.max() - resized_image1.min()) 194 | resized_image2 = (resized_image2 - resized_image2.min()) / (resized_image2.max() - resized_image2.min()) 195 | 196 | distances = torch.cdist(features1_2d, features2_2d) 197 | nearest_patch_indices = torch.argmin(distances, dim=1) 198 | nearest_patches = torch.index_select(resized_image2.cuda().clone().detach().reshape(-1, 3), 0, nearest_patch_indices) 199 | 200 | nearest_patches_image = nearest_patches.reshape(resized_image1.shape) 201 | 202 | if draw_gif: 203 | assert save_path is not None, "save_path must be provided when draw_gif is True" 204 | img_1 = resize(image1, features1.shape[2], resize=True, to_pil=True) 205 | img_2 = resize(image2, features2.shape[2], resize=True, to_pil=True) 206 | mapping = torch.zeros((img_1.size[1], img_1.size[0], 2)) 207 | for i in range(len(nearest_patch_indices)): 208 | mapping[i // img_1.size[0], i % img_1.size[0]] = torch.tensor([nearest_patch_indices[i] // img_2.size[0], nearest_patch_indices[i] % img_2.size[0]]) 209 | animate_image_transfer(img_1, img_2, mapping, save_path) if gif_reverse else animate_image_transfer_reverse(img_1, img_2, mapping, save_path) 210 | 211 | # TODO: upsample the nearest_patches_image to the resolution of the original image 212 | # nearest_patches_image = F.interpolate(nearest_patches_image.permute(2,0,1).unsqueeze(0), size=256, mode='bilinear').squeeze(0).permute(1,2,0) 213 | # resized_image2 = F.interpolate(resized_image2.permute(2,0,1).unsqueeze(0), size=256, mode='bilinear').squeeze(0).permute(1,2,0) 214 | 215 | nearest_patches_image = (nearest_patches_image).cpu().numpy() 216 | resized_image2 = (resized_image2).cpu().numpy() 217 | 218 | return nearest_patches_image, resized_image2 219 | 220 | def chunk_cosine_sim(x: torch.Tensor, y: torch.Tensor) -> 
torch.Tensor: 221 | """ Computes cosine similarity between all possible pairs in two sets of vectors. 222 | Operates on chunks so no large amount of GPU RAM is required. 223 | :param x: a tensor of descriptors of shape Bx1x(t_x)xd' where d' is the dimensionality of the descriptors and t_x 224 | is the number of tokens in x. 225 | :param y: a tensor of descriptors of shape Bx1x(t_y)xd' where d' is the dimensionality of the descriptors and t_y 226 | is the number of tokens in y. 227 | :return: cosine similarity between all descriptors in x and all descriptors in y. Has shape of Bx1x(t_x)x(t_y) """ 228 | result_list = [] 229 | num_token_x = x.shape[2] 230 | for token_idx in range(num_token_x): 231 | token = x[:, :, token_idx, :].unsqueeze(dim=2) # Bx1x1xd' 232 | result_list.append(torch.nn.CosineSimilarity(dim=3)(token, y)) # Bx1x(t_y) 233 | return torch.stack(result_list, dim=2) # Bx1x(t_x)x(t_y) 234 | 235 | def pairwise_sim(x: torch.Tensor, y: torch.Tensor, p=2, normalize=False) -> torch.Tensor: 236 | # similarity as the negative p-norm distance between descriptors (optionally L2-normalized first) 237 | if normalize: 238 | x = torch.nn.functional.normalize(x, dim=-1) 239 | y = torch.nn.functional.normalize(y, dim=-1) 240 | result_list=[] 241 | num_token_x = x.shape[2] 242 | for token_idx in range(num_token_x): 243 | token = x[:, :, token_idx, :].unsqueeze(dim=2) 244 | result_list.append(torch.nn.PairwiseDistance(p=p)(token, y)*(-1)) 245 | return torch.stack(result_list, dim=2) 246 | 247 | def draw_correspondences_gathered(points1: List[Tuple[float, float]], points2: List[Tuple[float, float]], 248 | image1: Image.Image, image2: Image.Image) -> plt.Figure: 249 | """ 250 | draw point correspondences on images. 251 | :param points1: a list of (y, x) coordinates of image1, corresponding to points2. 252 | :param points2: a list of (y, x) coordinates of image2, corresponding to points1. 253 | :param image1: a PIL image. 254 | :param image2: a PIL image. 255 | :return: a figure of images with marked points. 256 | """ 257 | assert len(points1) == len(points2), f"points lengths are incompatible: {len(points1)} != {len(points2)}."
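# one colour per correspondence: a fixed 15-colour palette when there are at most 15 points, matplotlib's 'tab10' colormap otherwise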
258 | num_points = len(points1) 259 | 260 | if num_points > 15: 261 | cmap = plt.get_cmap('tab10') 262 | else: 263 | cmap = ListedColormap(["red", "yellow", "blue", "lime", "magenta", "indigo", "orange", "cyan", "darkgreen", 264 | "maroon", "black", "white", "chocolate", "gray", "blueviolet"]) 265 | colors = np.array([cmap(x) for x in range(num_points)]) 266 | radius1, radius2 = 0.03*max(image1.size), 0.01*max(image1.size) 267 | 268 | # plot a subfigure put image1 in the top, image2 in the bottom 269 | fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 8)) 270 | ax1.axis('off') 271 | ax2.axis('off') 272 | ax1.imshow(image1) 273 | ax2.imshow(image2) 274 | 275 | for point1, point2, color in zip(points1, points2, colors): 276 | y1, x1 = point1 277 | circ1_1 = plt.Circle((x1, y1), radius1, facecolor=color, edgecolor='white', alpha=0.5) 278 | circ1_2 = plt.Circle((x1, y1), radius2, facecolor=color, edgecolor='white') 279 | ax1.add_patch(circ1_1) 280 | ax1.add_patch(circ1_2) 281 | y2, x2 = point2 282 | circ2_1 = plt.Circle((x2, y2), radius1, facecolor=color, edgecolor='white', alpha=0.5) 283 | circ2_2 = plt.Circle((x2, y2), radius2, facecolor=color, edgecolor='white') 284 | ax2.add_patch(circ2_1) 285 | ax2.add_patch(circ2_2) 286 | 287 | return fig 288 | 289 | def draw_correspondences_lines(points1: List[Tuple[float, float]], points2: List[Tuple[float, float]], 290 | gt_points2: List[Tuple[float, float]], image1: Image.Image, 291 | image2: Image.Image, threshold=None) -> plt.Figure: 292 | """ 293 | draw point correspondences on images. 294 | :param points1: a list of (y, x) coordinates of image1, corresponding to points2. 295 | :param points2: a list of (y, x) coordinates of image2, corresponding to points1. 296 | :param gt_points2: a list of ground truth (y, x) coordinates of image2. 297 | :param image1: a PIL image. 298 | :param image2: a PIL image. 299 | :param threshold: distance threshold to determine correct matches. 300 | :return: a figure of images with marked points and lines between them showing correspondence. 301 | """ 302 | 303 | points2=points2.cpu().numpy() 304 | gt_points2=gt_points2.cpu().numpy() 305 | 306 | def compute_correct(): 307 | alpha = torch.tensor([0.1, 0.05, 0.01]) 308 | correct = torch.zeros(len(alpha)) 309 | err = (torch.tensor(points2) - torch.tensor(gt_points2)).norm(dim=-1) 310 | err = err.unsqueeze(0).repeat(len(alpha), 1) 311 | correct = err < threshold.unsqueeze(-1) if len(threshold.shape)==1 else err < threshold 312 | return correct 313 | 314 | correct = compute_correct()[0] 315 | # print(correct.shape, len(points1)) 316 | 317 | assert len(points1) == len(points2), f"points lengths are incompatible: {len(points1)} != {len(points2)}." 
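# as above, one colour per point; the connecting lines drawn below are blue when a match falls within `threshold` of the ground truth and red otherwise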
318 | num_points = len(points1) 319 | 320 | if num_points > 15: 321 | cmap = plt.get_cmap('tab10') 322 | else: 323 | cmap = ListedColormap(["red", "yellow", "blue", "lime", "magenta", "indigo", "orange", "cyan", "darkgreen", 324 | "maroon", "black", "white", "chocolate", "gray", "blueviolet"]) 325 | colors = np.array([cmap(x) for x in range(num_points)]) 326 | radius1, radius2 = 0.03*max(image1.size), 0.01*max(image1.size) 327 | 328 | fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 8)) 329 | ax1.axis('off') 330 | ax2.axis('off') 331 | ax1.imshow(image1) 332 | ax2.imshow(image2) 333 | ax1.set_xlim(0, image1.size[0]) 334 | ax1.set_ylim(image1.size[1], 0) 335 | ax2.set_xlim(0, image2.size[0]) 336 | ax2.set_ylim(image2.size[1], 0) 337 | 338 | for i, (point1, point2) in enumerate(zip(points1, points2)): 339 | y1, x1 = point1 340 | circ1_1 = plt.Circle((x1, y1), radius1, facecolor=colors[i], edgecolor='white', alpha=0.5) 341 | circ1_2 = plt.Circle((x1, y1), radius2, facecolor=colors[i], edgecolor='white') 342 | ax1.add_patch(circ1_1) 343 | ax1.add_patch(circ1_2) 344 | y2, x2 = point2 345 | circ2_1 = plt.Circle((x2, y2), radius1, facecolor=colors[i], edgecolor='white', alpha=0.5) 346 | circ2_2 = plt.Circle((x2, y2), radius2, facecolor=colors[i], edgecolor='white') 347 | ax2.add_patch(circ2_1) 348 | ax2.add_patch(circ2_2) 349 | 350 | # Draw lines 351 | color = 'blue' if correct[i].item() else 'red' 352 | con = ConnectionPatch(xyA=(x2, y2), xyB=(x1, y1), coordsA="data", coordsB="data", 353 | axesA=ax2, axesB=ax1, color=color, linewidth=1.5) 354 | ax2.add_artist(con) 355 | 356 | return fig 357 | 358 | def co_pca(features1, features2, dim=[128,128,128]): 359 | 360 | processed_features1 = {} 361 | processed_features2 = {} 362 | s5_size = features1['s5'].shape[-1] 363 | s4_size = features1['s4'].shape[-1] 364 | s3_size = features1['s3'].shape[-1] 365 | # Get the feature tensors 366 | s5_1 = features1['s5'].reshape(features1['s5'].shape[0], features1['s5'].shape[1], -1) 367 | s4_1 = features1['s4'].reshape(features1['s4'].shape[0], features1['s4'].shape[1], -1) 368 | s3_1 = features1['s3'].reshape(features1['s3'].shape[0], features1['s3'].shape[1], -1) 369 | 370 | s5_2 = features2['s5'].reshape(features2['s5'].shape[0], features2['s5'].shape[1], -1) 371 | s4_2 = features2['s4'].reshape(features2['s4'].shape[0], features2['s4'].shape[1], -1) 372 | s3_2 = features2['s3'].reshape(features2['s3'].shape[0], features2['s3'].shape[1], -1) 373 | # Define the target dimensions 374 | target_dims = {'s5': dim[0], 's4': dim[1], 's3': dim[2]} 375 | 376 | # Compute the PCA 377 | for name, tensors in zip(['s5', 's4', 's3'], [[s5_1, s5_2], [s4_1, s4_2], [s3_1, s3_2]]): 378 | target_dim = target_dims[name] 379 | 380 | # Concatenate the features 381 | features = torch.cat(tensors, dim=-1) # along the spatial dimension 382 | features = features.permute(0, 2, 1) # Bx(t_x+t_y)x(d) 383 | 384 | # Compute the PCA 385 | # pca = faiss.PCAMatrix(features.shape[-1], target_dim) 386 | 387 | # Train the PCA 388 | # pca.train(features[0].cpu().numpy()) 389 | 390 | # Apply the PCA 391 | # features = pca.apply(features[0].cpu().numpy()) # (t_x+t_y)x(d) 392 | 393 | # convert to tensor 394 | # features = torch.tensor(features, device=features1['s5'].device).unsqueeze(0).permute(0, 2, 1) # Bx(d)x(t_x+t_y) 395 | 396 | 397 | # equivalent to the above, pytorch implementation 398 | mean = torch.mean(features[0], dim=0, keepdim=True) 399 | centered_features = features[0] - mean 400 | U, S, V = torch.pca_lowrank(centered_features, 
q=target_dim) 401 | reduced_features = torch.matmul(centered_features, V[:, :target_dim]) # (t_x+t_y)x(d) 402 | features = reduced_features.unsqueeze(0).permute(0, 2, 1) # Bx(d)x(t_x+t_y) 403 | 404 | 405 | # Split the features 406 | processed_features1[name] = features[:, :, :features.shape[-1] // 2] # Bx(d)x(t_x) 407 | processed_features2[name] = features[:, :, features.shape[-1] // 2:] # Bx(d)x(t_y) 408 | 409 | # reshape the features 410 | processed_features1['s5']=processed_features1['s5'].reshape(processed_features1['s5'].shape[0], -1, s5_size, s5_size) 411 | processed_features1['s4']=processed_features1['s4'].reshape(processed_features1['s4'].shape[0], -1, s4_size, s4_size) 412 | processed_features1['s3']=processed_features1['s3'].reshape(processed_features1['s3'].shape[0], -1, s3_size, s3_size) 413 | 414 | processed_features2['s5']=processed_features2['s5'].reshape(processed_features2['s5'].shape[0], -1, s5_size, s5_size) 415 | processed_features2['s4']=processed_features2['s4'].reshape(processed_features2['s4'].shape[0], -1, s4_size, s4_size) 416 | processed_features2['s3']=processed_features2['s3'].reshape(processed_features2['s3'].shape[0], -1, s3_size, s3_size) 417 | 418 | # Upsample s5 spatially by a factor of 2 419 | processed_features1['s5'] = F.interpolate(processed_features1['s5'], size=(processed_features1['s4'].shape[-2:]), mode='bilinear', align_corners=False) 420 | processed_features2['s5'] = F.interpolate(processed_features2['s5'], size=(processed_features2['s4'].shape[-2:]), mode='bilinear', align_corners=False) 421 | 422 | # Concatenate upsampled_s5 and s4 to create a new s5 423 | processed_features1['s5'] = torch.cat([processed_features1['s4'], processed_features1['s5']], dim=1) 424 | processed_features2['s5'] = torch.cat([processed_features2['s4'], processed_features2['s5']], dim=1) 425 | 426 | # Set s3 as the new s4 427 | processed_features1['s4'] = processed_features1['s3'] 428 | processed_features2['s4'] = processed_features2['s3'] 429 | 430 | # Remove s3 from the features dictionary 431 | processed_features1.pop('s3') 432 | processed_features2.pop('s3') 433 | 434 | # current order are layer 8, 5, 2 435 | features1_gether_s4_s5 = torch.cat([processed_features1['s4'], F.interpolate(processed_features1['s5'], size=(processed_features1['s4'].shape[-2:]), mode='bilinear')], dim=1) 436 | features2_gether_s4_s5 = torch.cat([processed_features2['s4'], F.interpolate(processed_features2['s5'], size=(processed_features2['s4'].shape[-2:]), mode='bilinear')], dim=1) 437 | 438 | return features1_gether_s4_s5, features2_gether_s4_s5 439 | 440 | def animate_image_transfer(image1, image2, mapping, output_path): 441 | import numpy as np 442 | from PIL import Image 443 | import matplotlib.pyplot as plt 444 | import matplotlib.animation as animation 445 | 446 | # # Load your two images 447 | # image1 = Image.open(image1_path) 448 | # image2 = Image.open(image2_path) 449 | 450 | # Ensure the two images are the same size 451 | assert image1.size == image2.size, "Images must be the same size." 
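# rec_size: side length, in pixels, of the square block redrawn for every moved source pixel in the animation frames below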
452 | rec_size = 2 453 | # Convert the images into numpy arrays 454 | image1_array = np.array(image1) 455 | image2_array = np.array(image2) 456 | 457 | # Retrieve the width and height of the images 458 | height, width, _ = image1_array.shape 459 | 460 | # Assume we have a mapping list 461 | mapping = mapping.cpu().numpy() 462 | 463 | # We add a column of white pixels between the two images 464 | gap = width // 10 465 | 466 | # Create a canvas with a width that is the sum of the widths of the two images and the gap. 467 | # The height is the same as the height of the images. 468 | fig, ax = plt.subplots(figsize=((2 * width + gap) / 200, height / 200), dpi=300) 469 | 470 | # Remove the axes 471 | ax.axis('off') 472 | 473 | # Create an image object, initializing it as entirely white 474 | combined_image = np.ones((height, 2 * width + gap, 3), dtype=np.uint8) * 255 475 | 476 | # Place image1 on the left, image2 on the right, with a gap in the middle 477 | combined_image[:, :width] = image1_array 478 | combined_image[:, width + gap:] = image2_array 479 | 480 | img_obj = ax.imshow(combined_image) 481 | 482 | # For each frame of the computation and animation, we need to know the start and target positions of each pixel 483 | starts = np.mgrid[:height, :width].reshape(2, -1).T 484 | targets = np.array([mapping[i, j] for i in range(height) for j in range(width)]) + [0, width + gap] 485 | 486 | # To better display the animation, we divide the pixel movement into several frames 487 | num_frames = 30 488 | 489 | def calculate_path(start, target, num_frames): 490 | """Calculate the path of a pixel from start to target over num_frames.""" 491 | # Generate linear values from 0 to 1 492 | t = np.linspace(0, 1, num_frames) 493 | 494 | # Apply the quadratic easing out function (starts fast, then slows down) 495 | t = 1 - (1 - t) ** 2 496 | 497 | # Calculate the path 498 | path = start + t[:, np.newaxis] * (target - start) 499 | 500 | return path 501 | 502 | def update(frame): 503 | # At the start of each frame, we initialize the canvas with image1 on the left, image2 on the right, and white in the middle 504 | combined_image.fill(255) 505 | combined_image[:, :width] = image1_array 506 | combined_image[:, width + gap:] = image2_array 507 | # In each frame, we move a small portion of pixels from the left image to the right image 508 | # This gives a better view of how the pixels move 509 | if frame >= num_frames - 1: 510 | frame = num_frames - 1 511 | for i in range(height): 512 | for j in range(width): 513 | # Calculate the current pixel's position 514 | start = starts[i * width + j] 515 | target = targets[i * width + j] 516 | # If the mapped target position is greater than 0, move the pixel, otherwise keep it the same 517 | if target[0] > 0 and target[1] > 0: 518 | position = calculate_path(start, target, num_frames)[frame] 519 | # Copy the current pixel's color to the new position 520 | combined_image[int(position[0])-rec_size//2:int(position[0])-rec_size//2+rec_size, int(position[1])-rec_size//2:int(position[1])-rec_size//2+rec_size] = image1_array[i, j] 521 | img_obj.set_array(combined_image) # Update the displayed image 522 | return img_obj, 523 | 524 | # Create the animation 525 | ani = animation.FuncAnimation(fig, update, frames=num_frames + 30, blit=True) 526 | if not os.path.exists(os.path.dirname(output_path)): 527 | os.makedirs(os.path.dirname(output_path)) 528 | # Save the animation 529 | ani.save(output_path, writer='pillow', fps=30) 530 | # save mapping 531 | np.save(output_path[:-4]+'.npy', 
mapping) 532 | 533 | 534 | def animate_image_transfer_reverse(image1, image2, mapping, output_path): 535 | import numpy as np 536 | from PIL import Image 537 | import matplotlib.pyplot as plt 538 | import matplotlib.animation as animation 539 | 540 | # # Load your two images 541 | # image1 = Image.open(image1_path) 542 | # image2 = Image.open(image2_path) 543 | 544 | # Ensure the two images are the same size 545 | assert image1.size == image2.size, "Images must be the same size." 546 | # rec_size = 2 547 | # Convert the images into numpy arrays 548 | image1_array = np.array(image1) 549 | image2_array = np.array(image2) 550 | 551 | # Retrieve the width and height of the images 552 | height, width, _ = image1_array.shape 553 | 554 | # Assume we have a mapping list 555 | mapping = mapping.cpu().numpy() 556 | 557 | # We add a column of white pixels between the two images 558 | gap = width // 10 559 | 560 | # Create a canvas with a width that is the sum of the widths of the two images and the gap. 561 | # The height is the same as the height of the images. 562 | fig, ax = plt.subplots(figsize=((2 * width + gap) / 200, height / 200), dpi=300) 563 | 564 | # Remove the axes 565 | ax.axis('off') 566 | 567 | # Create an image object, initializing it as entirely white 568 | combined_image = np.ones((height, 2 * width + gap, 3), dtype=np.uint8) * 255 569 | 570 | # Place image1 on the left, image2 on the right, with a gap in the middle 571 | combined_image[:, :width] = image2_array 572 | combined_image[:, width + gap:] = image1_array 573 | 574 | img_obj = ax.imshow(combined_image) 575 | 576 | # For each frame of the computation and animation, we need to know the start and target positions of each pixel 577 | starts = np.mgrid[:height, :width].reshape(2, -1).T + [0, width + gap] 578 | targets = np.array([mapping[i, j] for i in range(height) for j in range(width)]) 579 | 580 | # To better display the animation, we divide the pixel movement into several frames 581 | num_frames = 30 582 | 583 | def calculate_path(start, target, num_frames): 584 | """Calculate the path of a pixel from start to target over num_frames.""" 585 | # Generate linear values from 0 to 1 586 | t = np.linspace(1, 0, num_frames) 587 | 588 | # Apply the quadratic easing out function (starts fast, then slows down) 589 | t = 1 - (1 - t) ** 2 590 | 591 | # Calculate the path 592 | path = start + t[:, np.newaxis] * (target - start) 593 | 594 | return path 595 | 596 | def update(frame): 597 | # At the start of each frame, we initialize the canvas with image1 on the left, image2 on the right, and white in the middle 598 | combined_image.fill(255) 599 | combined_image[:, :width] = image2_array 600 | combined_image[:, width + gap:] = image1_array 601 | # In each frame, we move a small portion of pixels from the left image to the right image 602 | # This gives a better view of how the pixels move 603 | if frame >= num_frames - 1: 604 | frame = num_frames - 1 605 | if frame >= num_frames // 6 * 5: 606 | rec_size = 1 607 | else: 608 | rec_size = 2 609 | for i in range(height): 610 | for j in range(width): 611 | # Calculate the current pixel's position 612 | start = starts[i * width + j] 613 | target = targets[i * width + j] 614 | # If the mapped target position is greater than 0, move the pixel, otherwise keep it the same 615 | if target[0] > 0 and target[1] > 0: 616 | position = calculate_path(start, target, num_frames)[frame] 617 | # Copy the current pixel's color to the new position 618 | 
combined_image[int(position[0])-rec_size//2:int(position[0])-rec_size//2+rec_size, int(position[1])-rec_size//2:int(position[1])-rec_size//2+rec_size] = image2_array[int(mapping[i, j][0]), int(mapping[i, j][1])] 619 | img_obj.set_array(combined_image) # Update the displayed image 620 | return img_obj, 621 | 622 | # Create the animation 623 | ani = animation.FuncAnimation(fig, update, frames=num_frames + 30, blit=True) 624 | if not os.path.exists(os.path.dirname(output_path)): 625 | os.makedirs(os.path.dirname(output_path)) 626 | # Save the animation 627 | ani.save(output_path, writer='pillow', fps=30) 628 | # save the maping 629 | np.save(output_path[:-4]+'.npy', mapping) -------------------------------------------------------------------------------- /Correspondence/sc_models/sd_dino/extractor_dino.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import torchvision.transforms 4 | from torch import nn 5 | from torchvision import transforms 6 | import torch.nn.modules.utils as nn_utils 7 | import math 8 | import timm 9 | import types 10 | from pathlib import Path 11 | from typing import Union, List, Tuple 12 | from PIL import Image 13 | 14 | 15 | class ViTExtractor: 16 | """ This class facilitates extraction of features, descriptors, and saliency maps from a ViT. 17 | We use the following notation in the documentation of the module's methods: 18 | B - batch size 19 | h - number of heads. usually takes place of the channel dimension in pytorch's convention BxCxHxW 20 | p - patch size of the ViT. either 8 or 16. 21 | t - number of tokens. equals the number of patches + 1, e.g. HW / p**2 + 1. Where H and W are the height and width 22 | of the input image. 23 | d - the embedding dimension in the ViT. 24 | """ 25 | 26 | def __init__(self, model_type: str = 'dino_vits8', stride: int = 4, model: nn.Module = None, device: str = 'cuda'): 27 | """ 28 | :param model_type: A string specifying the type of model to extract from. 29 | [dino_vits8 | dino_vits16 | dino_vitb8 | dino_vitb16 | vit_small_patch8_224 | 30 | vit_small_patch16_224 | vit_base_patch8_224 | vit_base_patch16_224] 31 | :param stride: stride of first convolution layer. small stride -> higher resolution. 32 | :param model: Optional parameter. The nn.Module to extract from instead of creating a new one in ViTExtractor. 33 | should be compatible with model_type. 34 | """ 35 | self.model_type = model_type 36 | self.device = device 37 | if model is not None: 38 | self.model = model 39 | else: 40 | self.model = ViTExtractor.create_model(model_type) 41 | 42 | self.model = ViTExtractor.patch_vit_resolution(self.model, stride=stride) 43 | self.model.eval() 44 | self.model.to(self.device) 45 | self.p = self.model.patch_embed.patch_size 46 | if type(self.p)==tuple: 47 | self.p = self.p[0] 48 | self.stride = self.model.patch_embed.proj.stride 49 | 50 | self.mean = (0.485, 0.456, 0.406) if "dino" in self.model_type else (0.5, 0.5, 0.5) 51 | self.std = (0.229, 0.224, 0.225) if "dino" in self.model_type else (0.5, 0.5, 0.5) 52 | 53 | self._feats = [] 54 | self.hook_handlers = [] 55 | self.load_size = None 56 | self.num_patches = None 57 | 58 | @staticmethod 59 | def create_model(model_type: str) -> nn.Module: 60 | """ 61 | :param model_type: a string specifying which model to load. 
[dino_vits8 | dino_vits16 | dino_vitb8 | 62 | dino_vitb16 | vit_small_patch8_224 | vit_small_patch16_224 | vit_base_patch8_224 | 63 | vit_base_patch16_224] 64 | :return: the model 65 | """ 66 | torch.hub._validate_not_a_forked_repo=lambda a,b,c: True 67 | if 'v2' in model_type: 68 | model = torch.hub.load('facebookresearch/dinov2', model_type) 69 | elif 'dino' in model_type: 70 | model = torch.hub.load('facebookresearch/dino:main', model_type) 71 | else: # model from timm -- load weights from timm to dino model (enables working on arbitrary size images). 72 | temp_model = timm.create_model(model_type, pretrained=True) 73 | model_type_dict = { 74 | 'vit_small_patch16_224': 'dino_vits16', 75 | 'vit_small_patch8_224': 'dino_vits8', 76 | 'vit_base_patch16_224': 'dino_vitb16', 77 | 'vit_base_patch8_224': 'dino_vitb8' 78 | } 79 | model = torch.hub.load('facebookresearch/dino:main', model_type_dict[model_type]) 80 | temp_state_dict = temp_model.state_dict() 81 | del temp_state_dict['head.weight'] 82 | del temp_state_dict['head.bias'] 83 | model.load_state_dict(temp_state_dict) 84 | return model 85 | 86 | @staticmethod 87 | def _fix_pos_enc(patch_size: int, stride_hw: Tuple[int, int]): 88 | """ 89 | Creates a method for position encoding interpolation. 90 | :param patch_size: patch size of the model. 91 | :param stride_hw: A tuple containing the new height and width stride respectively. 92 | :return: the interpolation method 93 | """ 94 | def interpolate_pos_encoding(self, x: torch.Tensor, w: int, h: int) -> torch.Tensor: 95 | npatch = x.shape[1] - 1 96 | N = self.pos_embed.shape[1] - 1 97 | if npatch == N and w == h: 98 | return self.pos_embed 99 | class_pos_embed = self.pos_embed[:, 0] 100 | patch_pos_embed = self.pos_embed[:, 1:] 101 | dim = x.shape[-1] 102 | # compute number of tokens taking stride into account 103 | w0 = 1 + (w - patch_size) // stride_hw[1] 104 | h0 = 1 + (h - patch_size) // stride_hw[0] 105 | assert (w0 * h0 == npatch), f"""got wrong grid size for {h}x{w} with patch_size {patch_size} and 106 | stride {stride_hw} got {h0}x{w0}={h0 * w0} expecting {npatch}""" 107 | # we add a small number to avoid floating point error in the interpolation 108 | # see discussion at https://github.com/facebookresearch/dino/issues/8 109 | w0, h0 = w0 + 0.1, h0 + 0.1 110 | patch_pos_embed = nn.functional.interpolate( 111 | patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), 112 | scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)), 113 | mode='bicubic', 114 | align_corners=False, recompute_scale_factor=False 115 | ) 116 | assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1] 117 | patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) 118 | return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) 119 | 120 | return interpolate_pos_encoding 121 | 122 | @staticmethod 123 | def patch_vit_resolution(model: nn.Module, stride: int) -> nn.Module: 124 | """ 125 | change resolution of model output by changing the stride of the patch extraction. 126 | :param model: the model to change resolution for. 127 | :param stride: the new stride parameter. 
128 | :return: the adjusted model 129 | """ 130 | patch_size = model.patch_embed.patch_size 131 | if type(patch_size) == tuple: 132 | patch_size = patch_size[0] 133 | if stride == patch_size: # nothing to do 134 | return model 135 | 136 | stride = nn_utils._pair(stride) 137 | assert all([(patch_size // s_) * s_ == patch_size for s_ in 138 | stride]), f'stride {stride} should divide patch_size {patch_size}' 139 | 140 | # fix the stride 141 | model.patch_embed.proj.stride = stride 142 | # fix the positional encoding code 143 | model.interpolate_pos_encoding = types.MethodType(ViTExtractor._fix_pos_enc(patch_size, stride), model) 144 | return model 145 | 146 | def preprocess(self, image_path: Union[str, Path], 147 | load_size: Union[int, Tuple[int, int]] = None, patch_size: int = 14) -> Tuple[torch.Tensor, Image.Image]: 148 | """ 149 | Preprocesses an image before extraction. 150 | :param image_path: path to image to be extracted. 151 | :param load_size: optional. Size to resize image before the rest of preprocessing. 152 | :return: a tuple containing: 153 | (1) the preprocessed image as a tensor to insert the model of shape BxCxHxW. 154 | (2) the pil image in relevant dimensions 155 | """ 156 | def divisible_by_num(num, dim): 157 | return num * (dim // num) 158 | pil_image = Image.open(image_path).convert('RGB') 159 | if load_size is not None: 160 | pil_image = transforms.Resize(load_size, interpolation=transforms.InterpolationMode.LANCZOS)(pil_image) 161 | 162 | width, height = pil_image.size 163 | new_width = divisible_by_num(patch_size, width) 164 | new_height = divisible_by_num(patch_size, height) 165 | pil_image = pil_image.resize((new_width, new_height), resample=Image.LANCZOS) 166 | 167 | prep = transforms.Compose([ 168 | transforms.ToTensor(), 169 | transforms.Normalize(mean=self.mean, std=self.std) 170 | ]) 171 | prep_img = prep(pil_image)[None, ...] 172 | return prep_img, pil_image 173 | 174 | def preprocess_pil(self, pil_image): 175 | """ 176 | Preprocesses an image before extraction. 177 | :param image_path: path to image to be extracted. 178 | :param load_size: optional. Size to resize image before the rest of preprocessing. 179 | :return: a tuple containing: 180 | (1) the preprocessed image as a tensor to insert the model of shape BxCxHxW. 181 | (2) the pil image in relevant dimensions 182 | """ 183 | prep = transforms.Compose([ 184 | transforms.ToTensor(), 185 | transforms.Normalize(mean=self.mean, std=self.std) 186 | ]) 187 | prep_img = prep(pil_image)[None, ...] 188 | return prep_img 189 | 190 | def _get_hook(self, facet: str): 191 | """ 192 | generate a hook method for a specific block and facet. 193 | """ 194 | if facet in ['attn', 'token']: 195 | def _hook(model, input, output): 196 | self._feats.append(output) 197 | return _hook 198 | 199 | if facet == 'query': 200 | facet_idx = 0 201 | elif facet == 'key': 202 | facet_idx = 1 203 | elif facet == 'value': 204 | facet_idx = 2 205 | else: 206 | raise TypeError(f"{facet} is not a supported facet.") 207 | 208 | def _inner_hook(module, input, output): 209 | input = input[0] 210 | B, N, C = input.shape 211 | qkv = module.qkv(input).reshape(B, N, 3, module.num_heads, C // module.num_heads).permute(2, 0, 3, 1, 4) 212 | self._feats.append(qkv[facet_idx]) #Bxhxtxd 213 | return _inner_hook 214 | 215 | def _register_hooks(self, layers: List[int], facet: str) -> None: 216 | """ 217 | register hook to extract features. 218 | :param layers: layers from which to extract features. 219 | :param facet: facet to extract. 
One of the following options: ['key' | 'query' | 'value' | 'token' | 'attn'] 220 | """ 221 | for block_idx, block in enumerate(self.model.blocks): 222 | if block_idx in layers: 223 | if facet == 'token': 224 | self.hook_handlers.append(block.register_forward_hook(self._get_hook(facet))) 225 | elif facet == 'attn': 226 | self.hook_handlers.append(block.attn.attn_drop.register_forward_hook(self._get_hook(facet))) 227 | elif facet in ['key', 'query', 'value']: 228 | self.hook_handlers.append(block.attn.register_forward_hook(self._get_hook(facet))) 229 | else: 230 | raise TypeError(f"{facet} is not a supported facet.") 231 | 232 | def _unregister_hooks(self) -> None: 233 | """ 234 | unregisters the hooks. should be called after feature extraction. 235 | """ 236 | for handle in self.hook_handlers: 237 | handle.remove() 238 | self.hook_handlers = [] 239 | 240 | def _extract_features(self, batch: torch.Tensor, layers: List[int] = 11, facet: str = 'key') -> List[torch.Tensor]: 241 | """ 242 | extract features from the model 243 | :param batch: batch to extract features for. Has shape BxCxHxW. 244 | :param layers: layer to extract. A number between 0 to 11. 245 | :param facet: facet to extract. One of the following options: ['key' | 'query' | 'value' | 'token' | 'attn'] 246 | :return : tensor of features. 247 | if facet is 'key' | 'query' | 'value' has shape Bxhxtxd 248 | if facet is 'attn' has shape Bxhxtxt 249 | if facet is 'token' has shape Bxtxd 250 | """ 251 | B, C, H, W = batch.shape 252 | self._feats = [] 253 | self._register_hooks(layers, facet) 254 | _ = self.model(batch) 255 | self._unregister_hooks() 256 | self.load_size = (H, W) 257 | self.num_patches = (1 + (H - self.p) // self.stride[0], 1 + (W - self.p) // self.stride[1]) 258 | return self._feats 259 | 260 | def _log_bin(self, x: torch.Tensor, hierarchy: int = 2) -> torch.Tensor: 261 | """ 262 | create a log-binned descriptor. 263 | :param x: tensor of features. Has shape Bxhxtxd. 264 | :param hierarchy: how many bin hierarchies to use. 265 | """ 266 | B = x.shape[0] 267 | num_bins = 1 + 8 * hierarchy 268 | 269 | bin_x = x.permute(0, 2, 3, 1).flatten(start_dim=-2, end_dim=-1) # Bx(t-1)x(dxh) 270 | bin_x = bin_x.permute(0, 2, 1) 271 | bin_x = bin_x.reshape(B, bin_x.shape[1], self.num_patches[0], self.num_patches[1]) 272 | # Bx(dxh)xnum_patches[0]xnum_patches[1] 273 | sub_desc_dim = bin_x.shape[1] 274 | 275 | avg_pools = [] 276 | # compute bins of all sizes for all spatial locations. 
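        # Each hierarchy level k average-pools the patch grid with a 3**k x 3**k kernel (stride 1,
        # padding keeps the spatial size). A spatial location then gathers its own value plus its 8
        # neighbours at spacing 3**k from every level, i.e. num_bins = 1 + 8 * hierarchy bins and a
        # descriptor with sub_desc_dim * num_bins channels; out-of-range neighbours are clamped to
        # the nearest valid patch rather than zero-padded.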
277 | for k in range(0, hierarchy): 278 | # avg pooling with kernel 3**kx3**k 279 | win_size = 3 ** k 280 | avg_pool = torch.nn.AvgPool2d(win_size, stride=1, padding=win_size // 2, count_include_pad=False) 281 | avg_pools.append(avg_pool(bin_x)) 282 | 283 | bin_x = torch.zeros((B, sub_desc_dim * num_bins, self.num_patches[0], self.num_patches[1])).to(self.device) 284 | for y in range(self.num_patches[0]): 285 | for x in range(self.num_patches[1]): 286 | part_idx = 0 287 | # fill all bins for a spatial location (y, x) 288 | for k in range(0, hierarchy): 289 | kernel_size = 3 ** k 290 | for i in range(y - kernel_size, y + kernel_size + 1, kernel_size): 291 | for j in range(x - kernel_size, x + kernel_size + 1, kernel_size): 292 | if i == y and j == x and k != 0: 293 | continue 294 | if 0 <= i < self.num_patches[0] and 0 <= j < self.num_patches[1]: 295 | bin_x[:, part_idx * sub_desc_dim: (part_idx + 1) * sub_desc_dim, y, x] = avg_pools[k][ 296 | :, :, i, j] 297 | else: # handle padding in a more delicate way than zero padding 298 | temp_i = max(0, min(i, self.num_patches[0] - 1)) 299 | temp_j = max(0, min(j, self.num_patches[1] - 1)) 300 | bin_x[:, part_idx * sub_desc_dim: (part_idx + 1) * sub_desc_dim, y, x] = avg_pools[k][ 301 | :, :, temp_i, 302 | temp_j] 303 | part_idx += 1 304 | bin_x = bin_x.flatten(start_dim=-2, end_dim=-1).permute(0, 2, 1).unsqueeze(dim=1) 305 | # Bx1x(t-1)x(dxh) 306 | return bin_x 307 | 308 | def extract_descriptors(self, batch: torch.Tensor, layer: int = 11, facet: str = 'key', 309 | bin: bool = False, include_cls: bool = False) -> torch.Tensor: 310 | """ 311 | extract descriptors from the model 312 | :param batch: batch to extract descriptors for. Has shape BxCxHxW. 313 | :param layers: layer to extract. A number between 0 to 11. 314 | :param facet: facet to extract. One of the following options: ['key' | 'query' | 'value' | 'token'] 315 | :param bin: apply log binning to the descriptor. default is False. 316 | :return: tensor of descriptors. Bx1xtxd' where d' is the dimension of the descriptors. 317 | """ 318 | assert facet in ['key', 'query', 'value', 'token'], f"""{facet} is not a supported facet for descriptors. 319 | choose from ['key' | 'query' | 'value' | 'token'] """ 320 | self._extract_features(batch, [layer], facet) 321 | x = self._feats[0] 322 | if facet == 'token': 323 | x.unsqueeze_(dim=1) #Bx1xtxd 324 | if not include_cls: 325 | x = x[:, :, 1:, :] # remove cls token 326 | else: 327 | assert not bin, "bin = True and include_cls = True are not supported together, set one of them False." 328 | if not bin: 329 | desc = x.permute(0, 2, 3, 1).flatten(start_dim=-2, end_dim=-1).unsqueeze(dim=1) # Bx1xtx(dxh) 330 | else: 331 | desc = self._log_bin(x) 332 | return desc 333 | 334 | def extract_saliency_maps(self, batch: torch.Tensor) -> torch.Tensor: 335 | """ 336 | extract saliency maps. The saliency maps are extracted by averaging several attention heads from the last layer 337 | in of the CLS token. All values are then normalized to range between 0 and 1. 338 | :param batch: batch to extract saliency maps for. Has shape BxCxHxW. 339 | :return: a tensor of saliency maps. has shape Bxt-1 340 | """ 341 | assert self.model_type == "dino_vits8", f"saliency maps are supported only for dino_vits model_type." 
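        # The map comes from the last block's attention (facet 'attn', layer 11): take the CLS
        # token's attention over all patch tokens, keep heads [0, 2, 4, 5], average them, and
        # min-max normalize each map to the range [0, 1].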
342 | self._extract_features(batch, [11], 'attn') 343 | head_idxs = [0, 2, 4, 5] 344 | curr_feats = self._feats[0] #Bxhxtxt 345 | cls_attn_map = curr_feats[:, head_idxs, 0, 1:].mean(dim=1) #Bx(t-1) 346 | temp_mins, temp_maxs = cls_attn_map.min(dim=1)[0], cls_attn_map.max(dim=1)[0] 347 | cls_attn_maps = (cls_attn_map - temp_mins) / (temp_maxs - temp_mins) # normalize to range [0,1] 348 | return cls_attn_maps 349 | 350 | """ taken from https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse""" 351 | def str2bool(v): 352 | if isinstance(v, bool): 353 | return v 354 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 355 | return True 356 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 357 | return False 358 | else: 359 | raise argparse.ArgumentTypeError('Boolean value expected.') 360 | 361 | if __name__ == "__main__": 362 | parser = argparse.ArgumentParser(description='Facilitate ViT Descriptor extraction.') 363 | parser.add_argument('--image_path', type=str, required=True, help='path of the extracted image.') 364 | parser.add_argument('--output_path', type=str, required=True, help='path to file containing extracted descriptors.') 365 | parser.add_argument('--load_size', default=224, type=int, help='load size of the input image.') 366 | parser.add_argument('--stride', default=4, type=int, help="""stride of first convolution layer. 367 | small stride -> higher resolution.""") 368 | parser.add_argument('--model_type', default='dino_vits8', type=str, 369 | help="""type of model to extract. 370 | Choose from [dino_vits8 | dino_vits16 | dino_vitb8 | dino_vitb16 | vit_small_patch8_224 | 371 | vit_small_patch16_224 | vit_base_patch8_224 | vit_base_patch16_224]""") 372 | parser.add_argument('--facet', default='key', type=str, help="""facet to create descriptors from. 
373 | options: ['key' | 'query' | 'value' | 'token']""") 374 | parser.add_argument('--layer', default=11, type=int, help="layer to create descriptors from.") 375 | parser.add_argument('--bin', default='False', type=str2bool, help="create a binned descriptor if True.") 376 | parser.add_argument('--patch_size', default=14, type=int, help="patch size of the model.") 377 | args = parser.parse_args() 378 | 379 | with torch.no_grad(): 380 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 381 | extractor = ViTExtractor(args.model_type, args.stride, device=device) 382 | image_batch, image_pil = extractor.preprocess(args.image_path, args.load_size, args.patch_size) 383 | print(f"Image {args.image_path} is preprocessed to tensor of size {image_batch.shape}.") 384 | descriptors = extractor.extract_descriptors(image_batch.to(device), args.layer, args.facet, args.bin) 385 | print(f"Descriptors are of size: {descriptors.shape}") 386 | torch.save(descriptors, args.output_path) 387 | print(f"Descriptors saved to: {args.output_path}") -------------------------------------------------------------------------------- /Correspondence/sc_models/sd_dino/extractor_sd.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | from contextlib import ExitStack 3 | import torch 4 | from mask2former.data.datasets.register_ade20k_panoptic import ADE20K_150_CATEGORIES 5 | from PIL import Image 6 | import numpy as np 7 | import torch.nn.functional as F 8 | from detectron2.config import instantiate 9 | from detectron2.data import MetadataCatalog 10 | from detectron2.data import detection_utils as utils 11 | from detectron2.config import LazyCall as L 12 | from detectron2.data import transforms as T 13 | from detectron2.data.datasets.builtin_meta import COCO_CATEGORIES 14 | from detectron2.evaluation import inference_context 15 | from detectron2.utils.env import seed_all_rng 16 | from detectron2.utils.visualizer import ColorMode, Visualizer, random_color 17 | from detectron2.utils.logger import setup_logger 18 | 19 | from odise import model_zoo 20 | from odise.checkpoint import ODISECheckpointer 21 | from odise.config import instantiate_odise 22 | from odise.data import get_openseg_labels 23 | from odise.modeling.wrapper import OpenPanopticInference 24 | 25 | from utils.utils_correspondence import resize 26 | import faiss 27 | 28 | COCO_THING_CLASSES = [ 29 | label 30 | for idx, label in enumerate(get_openseg_labels("coco_panoptic", True)) 31 | if COCO_CATEGORIES[idx]["isthing"] == 1 32 | ] 33 | COCO_THING_COLORS = [c["color"] for c in COCO_CATEGORIES if c["isthing"] == 1] 34 | COCO_STUFF_CLASSES = [ 35 | label 36 | for idx, label in enumerate(get_openseg_labels("coco_panoptic", True)) 37 | if COCO_CATEGORIES[idx]["isthing"] == 0 38 | ] 39 | COCO_STUFF_COLORS = [c["color"] for c in COCO_CATEGORIES if c["isthing"] == 0] 40 | 41 | ADE_THING_CLASSES = [ 42 | label 43 | for idx, label in enumerate(get_openseg_labels("ade20k_150", True)) 44 | if ADE20K_150_CATEGORIES[idx]["isthing"] == 1 45 | ] 46 | ADE_THING_COLORS = [c["color"] for c in ADE20K_150_CATEGORIES if c["isthing"] == 1] 47 | ADE_STUFF_CLASSES = [ 48 | label 49 | for idx, label in enumerate(get_openseg_labels("ade20k_150", True)) 50 | if ADE20K_150_CATEGORIES[idx]["isthing"] == 0 51 | ] 52 | ADE_STUFF_COLORS = [c["color"] for c in ADE20K_150_CATEGORIES if c["isthing"] == 0] 53 | 54 | LVIS_CLASSES = get_openseg_labels("lvis_1203", True) 55 | # use beautiful coco colors 56 | LVIS_COLORS = list( 57 | 
itertools.islice(itertools.cycle([c["color"] for c in COCO_CATEGORIES]), len(LVIS_CLASSES)) 58 | ) 59 | 60 | 61 | class StableDiffusionSeg(object): 62 | def __init__(self, model, metadata, aug, instance_mode=ColorMode.IMAGE): 63 | """ 64 | Args: 65 | model (nn.Module): 66 | metadata (MetadataCatalog): image metadata. 67 | instance_mode (ColorMode): 68 | parallel (bool): whether to run the model in different processes from visualization. 69 | Useful since the visualization logic can be slow. 70 | """ 71 | self.model = model 72 | self.metadata = metadata 73 | self.aug = aug 74 | self.cpu_device = torch.device("cpu") 75 | self.instance_mode = instance_mode 76 | 77 | def get_features(self, original_image, caption=None, pca=None): 78 | """ 79 | Args: 80 | original_image (np.ndarray): an image of shape (H, W, C) (in BGR order). 81 | 82 | Returns: 83 | features (dict): 84 | the output of the model for one image only. 85 | """ 86 | height, width = original_image.shape[:2] 87 | aug_input = T.AugInput(original_image, sem_seg=None) 88 | self.aug(aug_input) 89 | image = aug_input.image 90 | image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1)) 91 | 92 | inputs = {"image": image, "height": height, "width": width} 93 | if caption is not None: 94 | features = self.model.get_features([inputs],caption,pca=pca) 95 | else: 96 | features = self.model.get_features([inputs],pca=pca) 97 | return features 98 | 99 | def predict(self, original_image): 100 | """ 101 | Args: 102 | original_image (np.ndarray): an image of shape (H, W, C) (in BGR order). 103 | 104 | Returns: 105 | predictions (dict): 106 | the output of the model for one image only. 107 | See :doc:`/tutorials/models` for details about the format. 108 | """ 109 | height, width = original_image.shape[:2] 110 | aug_input = T.AugInput(original_image, sem_seg=None) 111 | self.aug(aug_input) 112 | image = aug_input.image 113 | image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1)) 114 | 115 | inputs = {"image": image, "height": height, "width": width} 116 | predictions = self.model([inputs])[0] 117 | return predictions 118 | 119 | def build_demo_classes_and_metadata(vocab, label_list): 120 | extra_classes = [] 121 | 122 | if vocab: 123 | for words in vocab.split(";"): 124 | extra_classes.append([word.strip() for word in words.split(",")]) 125 | extra_colors = [random_color(rgb=True, maximum=1) for _ in range(len(extra_classes))] 126 | 127 | demo_thing_classes = extra_classes 128 | demo_stuff_classes = [] 129 | demo_thing_colors = extra_colors 130 | demo_stuff_colors = [] 131 | 132 | if "COCO" in label_list: 133 | demo_thing_classes += COCO_THING_CLASSES 134 | demo_stuff_classes += COCO_STUFF_CLASSES 135 | demo_thing_colors += COCO_THING_COLORS 136 | demo_stuff_colors += COCO_STUFF_COLORS 137 | if "ADE" in label_list: 138 | demo_thing_classes += ADE_THING_CLASSES 139 | demo_stuff_classes += ADE_STUFF_CLASSES 140 | demo_thing_colors += ADE_THING_COLORS 141 | demo_stuff_colors += ADE_STUFF_COLORS 142 | if "LVIS" in label_list: 143 | demo_thing_classes += LVIS_CLASSES 144 | demo_thing_colors += LVIS_COLORS 145 | 146 | MetadataCatalog.pop("odise_demo_metadata", None) 147 | demo_metadata = MetadataCatalog.get("odise_demo_metadata") 148 | demo_metadata.thing_classes = [c[0] for c in demo_thing_classes] 149 | demo_metadata.stuff_classes = [ 150 | *demo_metadata.thing_classes, 151 | *[c[0] for c in demo_stuff_classes], 152 | ] 153 | demo_metadata.thing_colors = demo_thing_colors 154 | demo_metadata.stuff_colors = demo_thing_colors + 
demo_stuff_colors 155 | demo_metadata.stuff_dataset_id_to_contiguous_id = { 156 | idx: idx for idx in range(len(demo_metadata.stuff_classes)) 157 | } 158 | demo_metadata.thing_dataset_id_to_contiguous_id = { 159 | idx: idx for idx in range(len(demo_metadata.thing_classes)) 160 | } 161 | 162 | demo_classes = demo_thing_classes + demo_stuff_classes 163 | 164 | return demo_classes, demo_metadata 165 | 166 | import sys 167 | 168 | 169 | def load_model(device="cuda", config_path="Panoptic/odise_label_coco_50e.py", seed=42, diffusion_ver="v1-5", image_size=960, num_timesteps=100, block_indices=(2,5,8,11), decoder_only=True, encoder_only=False, resblock_only=False): 170 | cfg = model_zoo.get_config(config_path, trained=True) 171 | 172 | cfg.model.backbone.feature_extractor.init_checkpoint = "sd://"+diffusion_ver 173 | cfg.model.backbone.feature_extractor.steps = (num_timesteps,) 174 | cfg.model.backbone.feature_extractor.unet_block_indices = block_indices 175 | cfg.model.backbone.feature_extractor.encoder_only = encoder_only 176 | cfg.model.backbone.feature_extractor.decoder_only = decoder_only 177 | cfg.model.backbone.feature_extractor.resblock_only = resblock_only 178 | cfg.model.overlap_threshold = 0 179 | seed_all_rng(seed) 180 | 181 | cfg.dataloader.test.mapper.augmentations=[ 182 | L(T.ResizeShortestEdge)(short_edge_length=image_size, sample_style="choice", max_size=2560), 183 | ] 184 | dataset_cfg = cfg.dataloader.test 185 | 186 | aug = instantiate(dataset_cfg.mapper).augmentations 187 | 188 | model = instantiate_odise(cfg.model) 189 | model.to(device) 190 | ODISECheckpointer(model).load(cfg.train.init_checkpoint) 191 | 192 | return model, aug 193 | 194 | def inference(model, aug, image, vocab, label_list): 195 | 196 | demo_classes, demo_metadata = build_demo_classes_and_metadata(vocab, label_list) 197 | with ExitStack() as stack: 198 | inference_model = OpenPanopticInference( 199 | model=model, 200 | labels=demo_classes, 201 | metadata=demo_metadata, 202 | semantic_on=False, 203 | instance_on=False, 204 | panoptic_on=True, 205 | ) 206 | stack.enter_context(inference_context(inference_model)) 207 | stack.enter_context(torch.no_grad()) 208 | 209 | demo = StableDiffusionSeg(inference_model, demo_metadata, aug) 210 | pred = demo.predict(np.array(image)) 211 | return (pred, demo_classes) 212 | 213 | def get_features(model, aug, image, vocab, label_list, caption=None, pca=False): 214 | 215 | demo_classes, demo_metadata = build_demo_classes_and_metadata(vocab, label_list) 216 | with ExitStack() as stack: 217 | inference_model = OpenPanopticInference( 218 | model=model, 219 | labels=demo_classes, 220 | metadata=demo_metadata, 221 | semantic_on=False, 222 | instance_on=False, 223 | panoptic_on=True, 224 | ) 225 | stack.enter_context(inference_context(inference_model)) 226 | stack.enter_context(torch.no_grad()) 227 | 228 | demo = StableDiffusionSeg(inference_model, demo_metadata, aug) 229 | if caption is not None: 230 | features = demo.get_features(np.array(image), caption, pca=pca) 231 | else: 232 | features = demo.get_features(np.array(image), pca=pca) 233 | return features 234 | 235 | 236 | def pca_process(features): 237 | # Get the feature tensors 238 | size_s5=features['s5'].shape[-1] 239 | size_s4=features['s4'].shape[-1] 240 | size_s3=features['s3'].shape[-1] 241 | 242 | s5 = features['s5'].reshape(features['s5'].shape[0], features['s5'].shape[1], -1) 243 | s4 = features['s4'].reshape(features['s4'].shape[0], features['s4'].shape[1], -1) 244 | s3 = 
features['s3'].reshape(features['s3'].shape[0], features['s3'].shape[1], -1) 245 | 246 | # Define the target dimensions 247 | target_dims = {'s5': 128, 's4': 128, 's3': 128} 248 | 249 | # Apply PCA to each tensor using Faiss CPU 250 | for name, tensor in zip(['s5', 's4', 's3'], [s5, s4, s3]): 251 | target_dim = target_dims[name] 252 | 253 | # Transpose the tensor so that the last dimension is the number of features 254 | tensor = tensor.permute(0, 2, 1) 255 | 256 | # # Norm the tensor 257 | # tensor = tensor / tensor.norm(dim=-1, keepdim=True) 258 | 259 | # Initialize a Faiss PCA object 260 | pca = faiss.PCAMatrix(tensor.shape[-1], target_dim) 261 | 262 | # Train the PCA object 263 | pca.train(tensor[0].cpu().numpy()) 264 | 265 | # Apply PCA to the data 266 | transformed_tensor_np = pca.apply(tensor[0].cpu().numpy()) 267 | 268 | # Convert the transformed data back to a tensor 269 | transformed_tensor = torch.tensor(transformed_tensor_np, device=tensor.device).unsqueeze(0) 270 | 271 | # Store the transformed tensor in the features dictionary 272 | features[name] = transformed_tensor 273 | 274 | # Reshape the tensors back to their original shapes 275 | features['s5'] = features['s5'].permute(0, 2, 1).reshape(features['s5'].shape[0], -1, size_s5, size_s5) 276 | features['s4'] = features['s4'].permute(0, 2, 1).reshape(features['s4'].shape[0], -1, size_s4, size_s4) 277 | features['s3'] = features['s3'].permute(0, 2, 1).reshape(features['s3'].shape[0], -1, size_s3, size_s3) 278 | # Upsample s5 spatially by a factor of 2 279 | upsampled_s5 = torch.nn.functional.interpolate(features['s5'], scale_factor=2, mode='bilinear', align_corners=False) 280 | 281 | # Concatenate upsampled_s5 and s4 to create a new s5 282 | features['s5'] = torch.cat((upsampled_s5, features['s4']), dim=1) 283 | 284 | # Set s3 as the new s4 285 | features['s4'] = features['s3'] 286 | 287 | # Remove s3 from the features dictionary 288 | del features['s3'] 289 | 290 | return features 291 | 292 | 293 | def process_features_and_mask(model, aug, image, category=None, input_text=None, mask=True, pca=False, raw=False): 294 | 295 | input_image = image 296 | caption = input_text 297 | vocab = "" 298 | label_list = ["COCO"] 299 | category_convert_dict={ 300 | 'aeroplane':'airplane', 301 | 'motorbike':'motorcycle', 302 | 'pottedplant':'potted plant', 303 | 'tvmonitor':'tv', 304 | } 305 | if type(category) is not list and category in category_convert_dict: 306 | category=category_convert_dict[category] 307 | elif type(category) is list: 308 | category=[category_convert_dict[cat] if cat in category_convert_dict else cat for cat in category] 309 | features = get_features(model, aug, input_image, vocab, label_list, caption, pca=(pca or raw)) 310 | if pca: 311 | features = pca_process(features) 312 | if raw: 313 | return features 314 | features_gether_s4_s5 = torch.cat([features['s4'], F.interpolate(features['s5'], size=(features['s4'].shape[-2:]), mode='bilinear')], dim=1) 315 | 316 | if mask: 317 | (pred,classes) =inference(model, aug, input_image, vocab, label_list) 318 | seg_map=pred['panoptic_seg'][0] 319 | target_mask_id = [] 320 | for item in pred['panoptic_seg'][1]: 321 | item['category_name']=classes[item['category_id']] 322 | if category in item['category_name']: 323 | target_mask_id.append(item['id']) 324 | resized_seg_map_s4 = F.interpolate(seg_map.unsqueeze(0).unsqueeze(0).float(), 325 | size=(features['s4'].shape[-2:]), mode='nearest') 326 | # to do adjust size 327 | binary_seg_map = torch.zeros_like(resized_seg_map_s4) 328 | 
for i in target_mask_id: 329 | binary_seg_map += (resized_seg_map_s4 == i).float() 330 | if len(target_mask_id) == 0 or binary_seg_map.sum() < 6: 331 | binary_seg_map = torch.ones_like(resized_seg_map_s4) 332 | features_gether_s4_s5 = features_gether_s4_s5 * binary_seg_map 333 | # set where mask is 0 to inf 334 | features_gether_s4_s5[(binary_seg_map == 0).repeat(1,features_gether_s4_s5.shape[1],1,1)] = -1 335 | 336 | return features_gether_s4_s5 337 | 338 | def get_mask(model, aug, image, category=None, input_text=None): 339 | model.backbone.feature_extractor.decoder_only = False 340 | model.backbone.feature_extractor.encoder_only = False 341 | model.backbone.feature_extractor.resblock_only = False 342 | input_image = image 343 | caption = input_text 344 | vocab = "" 345 | label_list = ["COCO"] 346 | category_convert_dict={ 347 | 'aeroplane':'airplane', 348 | 'motorbike':'motorcycle', 349 | 'pottedplant':'potted plant', 350 | 'tvmonitor':'tv', 351 | } 352 | if type(category) is not list and category in category_convert_dict: 353 | category=category_convert_dict[category] 354 | elif type(category) is list: 355 | category=[category_convert_dict[cat] if cat in category_convert_dict else cat for cat in category] 356 | 357 | (pred,classes) =inference(model, aug, input_image, vocab, label_list) 358 | seg_map=pred['panoptic_seg'][0] 359 | target_mask_id = [] 360 | for item in pred['panoptic_seg'][1]: 361 | item['category_name']=classes[item['category_id']] 362 | if type(category) is list: 363 | for cat in category: 364 | if cat in item['category_name']: 365 | target_mask_id.append(item['id']) 366 | else: 367 | if category in item['category_name']: 368 | target_mask_id.append(item['id']) 369 | resized_seg_map_s4 = seg_map.float() 370 | binary_seg_map = torch.zeros_like(resized_seg_map_s4) 371 | for i in target_mask_id: 372 | binary_seg_map += (resized_seg_map_s4 == i).float() 373 | if len(target_mask_id) == 0 or binary_seg_map.sum() < 6: 374 | binary_seg_map = torch.ones_like(resized_seg_map_s4) 375 | 376 | return binary_seg_map 377 | 378 | if __name__ == "__main__": 379 | image_path = sys.argv[1] 380 | try: 381 | input_text = sys.argv[2] 382 | except: 383 | input_text = None 384 | 385 | model, aug = load_model() 386 | img_size = 960 387 | image = Image.open(image_path).convert('RGB') 388 | image = resize(image, img_size, resize=True, to_pil=True) 389 | 390 | features = process_features_and_mask(model, aug, image, category=input_text, pca=False, raw=True) 391 | features = features['s4'] # save the features of layer 5 392 | 393 | # save the features 394 | np.save(image_path[:-4]+'.npy', features.cpu().numpy()) -------------------------------------------------------------------------------- /Correspondence/sc_models/sd_dino/get_cor.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from tqdm import tqdm 4 | from PIL import Image 5 | from extractor_sd import process_features_and_mask 6 | from sc_models.sd_dino.cor_utils import resize, pairwise_sim, co_pca, chunk_cosine_sim 7 | import numpy as np 8 | import torch.nn.functional as F 9 | torch.backends.cudnn.benchmark = True 10 | 11 | MASK = False 12 | SAMPLE = 20 13 | TOTAL_SAVE_RESULT = 5 14 | BBOX_THRE = True 15 | VER = 'v1-5' 16 | CO_PCA = True 17 | CO_PCA_DINO = False 18 | PCA_DIMS = [256, 256, 256] 19 | SIZE = 960 20 | EDGE_PAD = False 21 | 22 | FUSE_DINO = True 23 | ONLY_DINO = False 24 | DINOV2 = True 25 | MODEL_SIZE = 'base' 26 | TEXT_INPUT = False 27 | 28 | SEED = 42 29 | 
WEIGHT = [1, 1, 1, 1, 1] # corresponde to three groups for the sd features, and one group for the dino features 30 | PASCAL = False 31 | RAW = False 32 | 33 | @torch.no_grad() 34 | def get_cor_pairs(model, aug, extractor, src_image, trg_image, src_points, src_prompt, trg_prompt, dist='l2', transpose_img_func=lambda x:x, transpose_pts_func = lambda x, y: (x, y), device='cuda'): 35 | sd_size = 960 36 | dino_size = 840 if DINOV2 else 224 if ONLY_DINO else 480 37 | model_dict={'small':'dinov2_vits14', 38 | 'base':'dinov2_vitb14', 39 | 'large':'dinov2_vitl14', 40 | 'giant':'dinov2_vitg14'} 41 | 42 | model_type = model_dict[MODEL_SIZE] if DINOV2 else 'dino_vits8' 43 | layer = 11 if DINOV2 else 9 44 | if 'l' in model_type: 45 | layer = 23 46 | elif 'g' in model_type: 47 | layer = 39 48 | facet = 'token' if DINOV2 else 'key' 49 | stride = 14 if DINOV2 else 4 if ONLY_DINO else 8 50 | # indiactor = 'v2' if DINOV2 else 'v1' 51 | # model_size = model_type.split('vit')[-1] 52 | 53 | patch_size = extractor.model.patch_embed.patch_size[0] if DINOV2 else extractor.model.patch_embed.patch_size 54 | num_patches = int(patch_size / stride * (dino_size // patch_size - 1) + 1) 55 | 56 | # Load src image 57 | src_image = transpose_img_func(Image.open(src_image).convert('RGB')) 58 | src_w, src_h = src_image.size 59 | src_sd_input = resize(src_image, sd_size, resize=True, to_pil=True, edge=EDGE_PAD) 60 | src_dino_input = resize(src_image, dino_size, resize=True, to_pil=True, edge=EDGE_PAD) 61 | src_points = [transpose_pts_func(x, y) for x, y in src_points] 62 | src_x_scale, src_y_scale = dino_size / src_w, dino_size / src_h 63 | src_points = torch.Tensor([[int(np.round(x * src_x_scale)), int(np.round(y * src_y_scale))] for (x, y) in src_points]) 64 | # Get patch index for the keypoints 65 | src_y, src_x = src_points[:, 1].numpy(), src_points[:, 0].numpy() 66 | src_y_patch = (num_patches / dino_size * src_y).astype(np.int32) 67 | src_x_patch = (num_patches / dino_size * src_x).astype(np.int32) 68 | src_patch_idx = num_patches * src_y_patch + src_x_patch 69 | 70 | # Load trg image 71 | trg_image = Image.open(trg_image).convert('RGB') 72 | trg_w, trg_h = trg_image.size 73 | trg_sd_input = resize(trg_image, sd_size, resize=True, to_pil=True, edge=EDGE_PAD) 74 | trg_dino_input = resize(trg_image, dino_size, resize=True, to_pil=True, edge=EDGE_PAD) 75 | trg_x_scale, trg_y_scale = dino_size / trg_w, dino_size / trg_h 76 | 77 | if not CO_PCA: 78 | if not ONLY_DINO: 79 | src_desc = process_features_and_mask(model, aug, src_sd_input, input_text=src_prompt, mask=False).reshape(1,1,-1, num_patches**2).permute(0,1,3,2) 80 | trg_desc = process_features_and_mask(model, aug, trg_sd_input, input_text=trg_prompt, mask=False).reshape(1,1,-1, num_patches**2).permute(0,1,3,2) 81 | if FUSE_DINO: 82 | src_batch = extractor.preprocess_pil(src_dino_input) 83 | src_desc_dino = extractor.extract_descriptors(src_batch.to(device), layer, facet) 84 | trg_batch = extractor.preprocess_pil(trg_dino_input) 85 | trg_desc_dino = extractor.extract_descriptors(trg_batch.to(device), layer, facet) 86 | 87 | else: 88 | if not ONLY_DINO: 89 | features1 = process_features_and_mask(model, aug, src_sd_input, input_text=src_prompt, mask=False, raw=True) 90 | features2 = process_features_and_mask(model, aug, trg_sd_input, input_text=trg_prompt, mask=False, raw=True) 91 | if not RAW: 92 | processed_features1, processed_features2 = co_pca(features1, features2, PCA_DIMS) 93 | else: 94 | if WEIGHT[0]: 95 | processed_features1 = features1['s5'] 96 | 
processed_features2 = features2['s5'] 97 | elif WEIGHT[1]: 98 | processed_features1 = features1['s4'] 99 | processed_features2 = features2['s4'] 100 | elif WEIGHT[2]: 101 | processed_features1 = features1['s3'] 102 | processed_features2 = features2['s3'] 103 | elif WEIGHT[3]: 104 | processed_features1 = features1['s2'] 105 | processed_features2 = features2['s2'] 106 | else: 107 | raise NotImplementedError 108 | # rescale the features 109 | processed_features1 = F.interpolate(processed_features1, size=(num_patches, num_patches), mode='bilinear', align_corners=False) 110 | processed_features2 = F.interpolate(processed_features2, size=(num_patches, num_patches), mode='bilinear', align_corners=False) 111 | 112 | src_desc = processed_features1.reshape(1, 1, -1, num_patches**2).permute(0,1,3,2) 113 | trg_desc = processed_features2.reshape(1, 1, -1, num_patches**2).permute(0,1,3,2) 114 | if FUSE_DINO: 115 | src_batch = extractor.preprocess_pil(src_dino_input) 116 | src_desc_dino = extractor.extract_descriptors(src_batch.to(device), layer, facet) 117 | trg_batch = extractor.preprocess_pil(trg_dino_input) 118 | trg_desc_dino = extractor.extract_descriptors(trg_batch.to(device), layer, facet) 119 | 120 | if CO_PCA_DINO: 121 | cat_desc_dino = torch.cat((src_desc_dino, trg_desc_dino), dim=2).squeeze() # (1, 1, num_patches**2, dim) 122 | mean = torch.mean(cat_desc_dino, dim=0, keepdim=True) 123 | centered_features = cat_desc_dino - mean 124 | U, S, V = torch.pca_lowrank(centered_features, q=CO_PCA_DINO) 125 | reduced_features = torch.matmul(centered_features, V[:, :CO_PCA_DINO]) # (t_x+t_y)x(d) 126 | processed_co_features = reduced_features.unsqueeze(0).unsqueeze(0) 127 | src_desc_dino = processed_co_features[:, :, :src_desc_dino.shape[2], :] 128 | trg_desc_dino = processed_co_features[:, :, src_desc_dino.shape[2]:, :] 129 | 130 | if not ONLY_DINO and not RAW: # reweight different layers of sd 131 | 132 | src_desc[...,:PCA_DIMS[0]]*=WEIGHT[0] 133 | src_desc[...,PCA_DIMS[0]:PCA_DIMS[1]+PCA_DIMS[0]]*=WEIGHT[1] 134 | src_desc[...,PCA_DIMS[1]+PCA_DIMS[0]:PCA_DIMS[2]+PCA_DIMS[1]+PCA_DIMS[0]]*=WEIGHT[2] 135 | 136 | trg_desc[...,:PCA_DIMS[0]]*=WEIGHT[0] 137 | trg_desc[...,PCA_DIMS[0]:PCA_DIMS[1]+PCA_DIMS[0]]*=WEIGHT[1] 138 | trg_desc[...,PCA_DIMS[1]+PCA_DIMS[0]:PCA_DIMS[2]+PCA_DIMS[1]+PCA_DIMS[0]]*=WEIGHT[2] 139 | 140 | if 'l1' in dist or 'l2' in dist or dist == 'plus_norm': 141 | # normalize the features 142 | src_desc = src_desc / src_desc.norm(dim=-1, keepdim=True) 143 | trg_desc = trg_desc / trg_desc.norm(dim=-1, keepdim=True) 144 | src_desc_dino = src_desc_dino / src_desc_dino.norm(dim=-1, keepdim=True) 145 | trg_desc_dino = trg_desc_dino / trg_desc_dino.norm(dim=-1, keepdim=True) 146 | 147 | if FUSE_DINO and not ONLY_DINO and dist!='plus' and dist!='plus_norm': 148 | # cat two features together 149 | src_desc = torch.cat((src_desc, src_desc_dino), dim=-1) 150 | trg_desc = torch.cat((trg_desc, trg_desc_dino), dim=-1) 151 | if not RAW: 152 | # reweight sd and dino 153 | src_desc[...,:PCA_DIMS[2]+PCA_DIMS[1]+PCA_DIMS[0]]*=WEIGHT[3] 154 | src_desc[...,PCA_DIMS[2]+PCA_DIMS[1]+PCA_DIMS[0]:]*=WEIGHT[4] 155 | trg_desc[...,:PCA_DIMS[2]+PCA_DIMS[1]+PCA_DIMS[0]]*=WEIGHT[3] 156 | trg_desc[...,PCA_DIMS[2]+PCA_DIMS[1]+PCA_DIMS[0]:]*=WEIGHT[4] 157 | 158 | elif dist=='plus' or dist=='plus_norm': 159 | src_desc = src_desc + src_desc_dino 160 | trg_desc = trg_desc + trg_desc_dino 161 | dist='cos' 162 | 163 | if ONLY_DINO: 164 | src_desc = src_desc_dino 165 | trg_desc = trg_desc_dino 166 | 167 | # Get similarity matrix 
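    # sim_1_to_2 compares every source patch descriptor with every target patch descriptor, giving a
    # (num_patches**2) x (num_patches**2) similarity matrix: 'cos' uses chunked cosine similarity, and
    # the 'l1'/'l2' variants use the pairwise similarities defined in cor_utils; the best target patch
    # for each source keypoint is then picked by row-wise argmax below.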
168 | if dist == 'cos': 169 | sim_1_to_2 = chunk_cosine_sim(src_desc, trg_desc).squeeze() 170 | elif dist == 'l2': 171 | sim_1_to_2 = pairwise_sim(src_desc, trg_desc, p=2).squeeze() 172 | elif dist == 'l1': 173 | sim_1_to_2 = pairwise_sim(src_desc, trg_desc, p=1).squeeze() 174 | elif dist == 'l2_norm': 175 | sim_1_to_2 = pairwise_sim(src_desc, trg_desc, p=2, normalize=True).squeeze() 176 | elif dist == 'l1_norm': 177 | sim_1_to_2 = pairwise_sim(src_desc, trg_desc, p=1, normalize=True).squeeze() 178 | else: 179 | raise ValueError('Unknown distance metric') 180 | 181 | # Get nearest neighbors 182 | nn_1_to_2 = torch.argmax(sim_1_to_2[src_patch_idx], dim=1) 183 | max_sim = torch.max(sim_1_to_2[src_patch_idx], dim=1)[0].detach().cpu().numpy() 184 | 185 | nn_y_patch, nn_x_patch = nn_1_to_2 // num_patches, nn_1_to_2 % num_patches 186 | nn_x = (nn_x_patch - 1) * stride + stride + patch_size // 2 - .5 187 | nn_y = (nn_y_patch - 1) * stride + stride + patch_size // 2 - .5 188 | trg_points = torch.stack([nn_x, nn_y]).permute(1, 0).cpu().numpy() 189 | trg_points = [[int(np.round(x / trg_x_scale)), int(np.round(y / trg_y_scale))] for (x, y) in trg_points] 190 | 191 | return trg_points, src_points, None, src_dino_input, max_sim 192 | 193 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | # Robo-ABC: Affordance Generalization Beyond Categories via Semantic Correspondence for Robot Manipulation 3 | 4 | ECCV 2024 5 |

6 | 7 | This is the official repository of [Robo-ABC: Affordance Generalization Beyond Categories via Semantic Correspondence for Robot Manipulation](https://arxiv.org/pdf/2401.07487). 8 | 9 | 10 |
11 | 12 |
13 | 14 | # 🍒Pipeline 15 | 16 |
17 | 18 |
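As a rough sketch of how the retrieval stage is meant to be used (the CLIP checkpoint id, file paths and HDF5 keys below are placeholders; see `Retrieve/Feature_extraction.py` and `Retrieve/Retriever.py` for the full scripts):

```python
import h5py
import faiss
import numpy as np
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

# Placeholder checkpoint; the repo scripts point to a local CLIP ViT-B/32 path instead.
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Features written by Retrieve/Feature_extraction.py (file name and dataset keys are placeholders).
with h5py.File("memory_features.h5", "r") as h5f:
    feats = np.array(h5f["all_features"], dtype="float32")
    names = [n.decode("utf-8") for n in h5f["all_filenames"]]

index = faiss.IndexFlatIP(feats.shape[1])  # inner product = cosine, since features are L2-normalized
index.add(feats)

inputs = processor(images=Image.open("query.jpg").convert("RGB"), return_tensors="pt")
q = model.get_image_features(**inputs)
q = (q / q.norm(dim=-1, keepdim=True)).detach().numpy()

scores, idxs = index.search(q, 5)  # retrieve the 5 closest affordance memories
print([names[i] for i in idxs[0]])
```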
19 | 20 | # TODO 21 | - [x] Release the semantic correspondence method code. 22 | - [x] Release the retriever code. 23 | - [ ] Release the affordance memory extraction code. -------------------------------------------------------------------------------- /Retrieve/Feature_extraction.py: -------------------------------------------------------------------------------- 1 | import os 2 | import h5py 3 | import numpy as np 4 | from tqdm import tqdm 5 | from PIL import Image 6 | 7 | from torch.utils.data import Dataset 8 | from torch.utils.data import DataLoader 9 | from torchvision import transforms 10 | 11 | from transformers import CLIPModel, CLIPProcessor 12 | 13 | MODEL_NAME = "/home/ /.cache/huggingface/hub/model--openai--clip-vit-base-patch32" 14 | model = CLIPModel.from_pretrained(MODEL_NAME) 15 | processor = CLIPProcessor.from_pretrained(MODEL_NAME) 16 | 17 | def get_labels(files): 18 | labels = [] 19 | for file_path in files: 20 | directory = os.path.dirname(file_path) 21 | label = os.path.basename(directory) 22 | if label not in labels: 23 | labels.append(label) 24 | return labels 25 | 26 | def list_files(dataset_path): 27 | images = [] 28 | valid_images = [".jpg",".gif",".png",".jpeg"] 29 | for root, _, files in os.walk(dataset_path): 30 | for name in files: 31 | if os.path.splitext(name)[1].lower() in valid_images: 32 | images.append(os.path.join(root, name)) 33 | return images 34 | 35 | class CustomImageDataset(Dataset): 36 | def __init__(self, img_dir): 37 | self.img_dir = img_dir 38 | self.images = list_files(self.img_dir) 39 | self.transform = transforms.Compose([ 40 | transforms.Resize(256), 41 | transforms.CenterCrop(224), 42 | transforms.ToTensor(), 43 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 44 | ]) 45 | 46 | def __len__(self): 47 | return len(self.images) 48 | 49 | def __getitem__(self, idx): 50 | img_path = self.images[idx] 51 | image = Image.open(img_path).convert("RGB") 52 | if self.transform: 53 | image = self.transform(image) 54 | label = os.path.basename(os.path.dirname(img_path)) # get the parent directory name 55 | return image, img_path 56 | 57 | 58 | dir_path = "" 59 | 60 | # Iterate over each subdirectory in the main directory 61 | for sub_dir in os.listdir(dir_path): 62 | full_sub_dir_path = os.path.join(dir_path, sub_dir) 63 | if os.path.isdir(full_sub_dir_path): 64 | dataset = CustomImageDataset(full_sub_dir_path) 65 | train_dataloader = DataLoader(dataset, batch_size=16, shuffle=True,) 66 | 67 | final_img_features = [] 68 | final_img_filepaths = [] 69 | for image_tensors, file_paths in tqdm(train_dataloader): 70 | try: 71 | image_features = model.get_image_features(image_tensors) #512 72 | image_features /= image_features.norm(dim=-1, keepdim=True) 73 | image_features = image_features.tolist() 74 | final_img_features.extend(image_features) 75 | final_img_filepaths.extend((list(file_paths))) 76 | except Exception as e: 77 | print("Exception occurred: ",e) 78 | break 79 | 80 | # Create a unique h5 filename for each sub-directory 81 | h5_filename = f"/home/ /shared/biggest/{sub_dir}_features.h5" 82 | with h5py.File(h5_filename, 'w') as h5f: 83 | h5f.create_dataset(f"{sub_dir}_features", data= np.array(final_img_features)) 84 | # to save file names strings in byte format. 
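        # HDF5 has no native dtype for arrays of Python str objects, so the paths are stored as
        # UTF-8 byte strings here and decoded back to str in Retrieve/Retriever.py.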
85 | h5f.create_dataset(f"{sub_dir}_filenames", data=np.array([fp.encode('utf-8') for fp in final_img_filepaths])) -------------------------------------------------------------------------------- /Retrieve/Retriever.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import h5py 4 | import faiss 5 | import numpy as np 6 | from PIL import Image 7 | 8 | from torchvision import transforms 9 | from transformers import CLIPModel, CLIPProcessor 10 | 11 | MODEL_NAME = "" 12 | model = CLIPModel.from_pretrained(MODEL_NAME) 13 | processor = CLIPProcessor.from_pretrained(MODEL_NAME) 14 | 15 | transform = transforms.Compose([ 16 | transforms.Resize(256), 17 | transforms.CenterCrop(224), 18 | transforms.Lambda(lambda image: image.convert("RGB")), 19 | transforms.ToTensor(), 20 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 21 | ]) 22 | 23 | with h5py.File('.h5', 'r') as h5f: 24 | all_features = np.array(h5f['all_features']) 25 | all_filenames = np.array(h5f['all_filenames']) 26 | 27 | faiss_index = faiss.IndexFlatIP(all_features.shape[1]) 28 | faiss_index.add(all_features) 29 | 30 | folder_path = "" 31 | txt_folder_path = "" 32 | 33 | for subdir, dirs, files in os.walk(folder_path): 34 | for file in files: 35 | if file.endswith((".png", ".jpg", ".jpeg")): 36 | img_path = os.path.join(subdir, file) 37 | image = Image.open(img_path) 38 | t_image = transform(image) 39 | inputs = processor(images=t_image, return_tensors="pt") 40 | 41 | query_features = model.get_image_features(**inputs) # CLIP image embedding, same encoder as Feature_extraction.py 42 | query_features /= query_features.norm(dim=-1, keepdim=True) 43 | query_features = query_features.detach().numpy() 44 | 45 | K_neighbours = 5 46 | distances, indices = faiss_index.search(query_features, K_neighbours) 47 | 48 | for index in range(K_neighbours): 49 | similar_image_filename = all_filenames[indices[0][index]] 50 | similar_image_filename_str = similar_image_filename.decode('utf-8') 51 | 52 | # Add suffix to filename 53 | filename, ext = os.path.splitext(similar_image_filename_str) 54 | new_filename = f'{filename}_top{index+1}{ext}' 55 | new_path = os.path.join(subdir, os.path.basename(new_filename)) 56 | shutil.copy(similar_image_filename_str, new_path) 57 | 58 | txt_filename_without_path = os.path.splitext(os.path.basename(similar_image_filename_str))[0] + '.txt' 59 | for txt_subdir, txt_dirs, txt_files in os.walk(txt_folder_path): 60 | if txt_filename_without_path in txt_files: 61 | txt_filename = os.path.join(txt_subdir, txt_filename_without_path) 62 | 63 | # Add suffix to txt file 64 | txt_filename_without_ext, txt_ext = os.path.splitext(txt_filename) 65 | new_txt_filename = f'{txt_filename_without_ext}_top{index+1}{txt_ext}' 66 | new_txt_path = os.path.join(subdir, os.path.basename(new_txt_filename)) 67 | shutil.copy(txt_filename, new_txt_path) -------------------------------------------------------------------------------- /assets/pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TEA-Lab/Robo-ABC/0ce7ac90d0ce61099988690f77f19785a388bb20/assets/pipeline.png -------------------------------------------------------------------------------- /assets/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TEA-Lab/Robo-ABC/0ce7ac90d0ce61099988690f77f19785a388bb20/assets/teaser.png --------------------------------------------------------------------------------