├── .gitignore ├── README.md ├── assets └── teaser_nohuman.png ├── datasets.py ├── main_most.py ├── networks.py ├── object_discovery.py ├── requirements.txt ├── saliency_utils.py └── visualizations.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.sh 2 | *.out 3 | *.err 4 | *__pycache__/* 5 | dino 6 | datasets 7 | saliency 8 | most_* 9 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Official Implementation of MOST (ICCV 2023 Oral) 2 | [Sai Saketh Rambhatla](https://rssaketh.github.io), [Ishan Misra](https://imisra.github.io), [Rama Chellappa](https://engineering.jhu.edu/faculty/rama-chellappa/), [Abhinav Shrivastava](https://www.cs.umd.edu/~abhinav/) 3 | 4 | ![alt text](https://github.com/rssaketh/MOST/blob/main/assets/teaser_nohuman.png?raw=true) 5 | 6 | This is the official repository for **MOST**: **M**ultiple **O**bject localization using **S**elf-supervised **T**ransformers for object discovery, accepted as an Oral at ICCV 2023; see the [Project Page](https://rssaketh.github.io/most). 7 | 8 | ## Introduction 9 | We tackle the challenging task of unsupervised object localization. Transformers trained with self-supervised learning have recently been shown to exhibit object localization properties without being trained for this task. In this work, we present Multiple Object localization with Self-supervised Transformers (MOST), which uses features from transformers trained with self-supervised learning to localize multiple objects in real-world images. MOST analyzes the similarity maps of the features using box counting, a fractal analysis tool, to identify tokens lying on foreground patches. The identified tokens are then clustered, and the tokens of each cluster are used to generate bounding boxes on foreground regions. Unlike recent state-of-the-art object localization methods, MOST can localize multiple objects per image, and it outperforms state-of-the-art methods on several object localization and discovery benchmarks on the PASCAL VOC 2007, PASCAL VOC 2012 and COCO20k datasets. Additionally, we show that MOST can be used for self-supervised pre-training of object detectors and yields consistent improvements on fully supervised and semi-supervised object detection as well as unsupervised region proposal generation. 10 | 11 | 12 | ## Installation instructions 13 | - This code was tested with Python 3.7 14 | 15 | We recommend using conda to create a new environment. 16 | 17 | ``` 18 | conda create -n most python=3.7 19 | ``` 20 | 21 | Then activate the virtual environment 22 | 23 | ``` 24 | conda activate most 25 | ``` 26 | 27 | Install PyTorch 1.7.1 (CUDA 10.2) 28 | 29 | ``` 30 | conda install pytorch==1.7.1 torchvision==0.8.2 torchaudio==0.7.2 cudatoolkit=10.2 -c pytorch 31 | ``` 32 | 33 | - Install other requirements 34 | 35 | ``` 36 | pip install -r requirements.txt 37 | ``` 38 | - Install DINO 39 | ``` 40 | git clone https://github.com/facebookresearch/dino.git 41 | cd dino; 42 | touch __init__.py 43 | echo -e "import sys\nfrom os.path import dirname, join\nsys.path.insert(0, join(dirname(__file__), '.'))" >> __init__.py; cd ../; 44 | ``` 45 | ## MOST on a single image 46 | To apply MOST to a single image, run 47 | ``` 48 | python main_most.py --image_path <path/to/image> --visualize pred 49 | ``` 50 | Results are stored in the output directory given by the `--output_dir` argument. 
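When MOST is run on a dataset (see the next section), the predicted boxes for each image are also saved to a `preds.pkl` file inside the output directory (single-image runs with `--image_path` only produce the visualization). The snippet below is a minimal sketch of how these saved predictions can be loaded; the path assumes the default `--output_dir outputs`, a VOC07 `trainval` run and the default ViT-S/16 backbone with key features, so adjust it to your own settings.

```
import pickle

# Path layout follows main_most.py: {output_dir}/{dataset}_{set}/MOST-{arch}{patch_size}_{features}/preds.pkl
# (example path for a default VOC07 trainval run; adjust to your configuration)
preds_file = "outputs/VOC07_trainval/MOST-vit_small16_k/preds.pkl"

with open(preds_file, "rb") as f:
    preds = pickle.load(f)  # dict mapping image name -> list of predicted boxes

# Each box is stored as [xmin, ymin, xmax, ymax] in image coordinates
name, boxes = next(iter(preds.items()))
print(name, [list(map(int, b)) for b in boxes])
```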
51 | 52 | ## Run MOST on datasets 53 | To run MOST on PASCAL VOC, COCO or other custom datasets, follow the dataset instructions of [LOST](https://github.com/valeoai/LOST#launching-lost-on-datasets). 54 | To launch MOST on PASCAL VOC 2007 and 2012 datasets, run 55 | ``` 56 | python main_most.py --dataset VOC07 --set trainval 57 | python main_most.py --dataset VOC12 --set trainval 58 | ``` 59 | For COCO dataset, run 60 | ``` 61 | python main_most.py --dataset COCO20k --set train 62 | ``` 63 | 64 | To run with different patch sizes or architectures, run 65 | ``` 66 | python main_most.py --dataset VOC07 --set trainval #VIT-S/16 67 | python main_most.py --dataset VOC07 --set trainval --patch_size 8 #VIT-S/8 68 | python main_most.py --dataset VOC07 --set trainval --arch vit_base #VIT-B/16 69 | ``` 70 | -------------------------------------------------------------------------------- /assets/teaser_nohuman.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rssaketh/MOST/5ed0a81bba53e9f1e0408bf17620946a080a7767/assets/teaser_nohuman.png -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import json 4 | import torchvision 5 | import numpy as np 6 | import skimage.io 7 | import pdb 8 | import pandas as pd 9 | import pickle 10 | 11 | from PIL import Image 12 | from tqdm import tqdm 13 | from torchvision import transforms as pth_transforms 14 | from torch.utils.data import Dataset 15 | 16 | # Image transformation applied to all images 17 | transform = pth_transforms.Compose( 18 | [ 19 | pth_transforms.ToTensor(), 20 | pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 21 | ] 22 | ) 23 | 24 | 25 | class ImagenetDataset(Dataset): 26 | def __init__(self, root_path, image_path, transform=None): 27 | self.image_path = image_path 28 | self.root_path = root_path 29 | self.transform=transform 30 | self.images = pd.read_csv(os.path.join(self.root_path,'imagenet_images.csv'), 31 | header=None)[0].unique().tolist() 32 | if not self.images: 33 | raise Exception("Path to images {} invalid!!".format(image_path)) 34 | 35 | def __len__(self): 36 | return len(self.images) 37 | 38 | def __getitem__(self, idx): 39 | im_n = self.images[idx] 40 | img = Image.open(os.path.join(self.image_path, im_n + '.JPEG')).convert('RGB') 41 | if self.transform is not None: 42 | img = self.transform(img) 43 | return img, im_n 44 | 45 | 46 | 47 | class ECSSDDataset(Dataset): 48 | def __init__(self, image_path, transform=None): 49 | self.image_path = image_path 50 | self.transform=transform 51 | self.images = os.listdir(image_path) 52 | if not self.images: 53 | raise Exception("Path to images {} invalid!!".format(image_path)) 54 | 55 | def __len__(self): 56 | return len(self.images) 57 | 58 | def __getitem__(self, idx): 59 | im_n = self.images[idx] 60 | img = Image.open(os.path.join(self.image_path, im_n)).convert('RGB') 61 | if self.transform is not None: 62 | img = self.transform(img) 63 | return img, im_n 64 | 65 | 66 | class CUBDataset(Dataset): 67 | def __init__(self, image_path, dataset_set='train', transform=None): 68 | self.image_path = os.path.join(image_path, 'images') 69 | self.image_list_file = os.path.join(image_path, 'images.txt') 70 | self.images = pd.read_csv(self.image_list_file, header=None, sep=" ", names=['id', 'path']) 71 | self.split = pd.read_csv(os.path.join(image_path, 
72 | 'train_test_split.txt'), 73 | header=None, sep=" ", names=['id', 'split']) 74 | 75 | self.images = self.images.merge(self.split, on="id") 76 | split = 1 if dataset_set=='train' else 0 77 | self.images = self.images[self.images.split == split] 78 | self.transform=transform 79 | 80 | def __len__(self): 81 | return self.images.shape[0] 82 | 83 | def __getitem__(self, idx): 84 | im_n = self.images.iloc[idx] 85 | img = Image.open(os.path.join(self.image_path, im_n['path'])).convert('RGB') 86 | if self.transform is not None: 87 | img = self.transform(img) 88 | return img, im_n['id'] 89 | 90 | 91 | class ImageDataset: 92 | def __init__(self, image_path): 93 | self.image_path = image_path 94 | self.name = image_path.split("/")[-1] 95 | 96 | # Read the image 97 | with open(image_path, "rb") as f: 98 | img = Image.open(f) 99 | img = img.convert("RGB") 100 | 101 | # Build a dataloader 102 | img = transform(img) 103 | self.dataloader = [[img, image_path]] 104 | 105 | def get_image_name(self, *args, **kwargs): 106 | return self.image_path.split("/")[-1].split(".")[0] 107 | 108 | def load_image(self, *args, **kwargs): 109 | return skimage.io.imread(self.image_path) 110 | 111 | class Dataset: 112 | def __init__(self, dataset_name, dataset_set, remove_hards): 113 | """ 114 | Build the dataloader 115 | """ 116 | 117 | self.dataset_name = dataset_name 118 | self.set = dataset_set 119 | 120 | if dataset_name == "VOC07": 121 | self.root_path = "datasets/VOC2007" 122 | self.year = "2007" 123 | elif dataset_name == "VOC12": 124 | self.root_path = "datasets/VOC2012" 125 | self.year = "2012" 126 | elif dataset_name == "COCO20k": 127 | self.year = "2014" 128 | self.root_path = f"datasets/COCO/images/{dataset_set}{self.year}" 129 | self.sel20k = 'datasets/coco_20k_filenames.txt' 130 | # JSON file constructed based on COCO train2014 gt 131 | self.all_annfile = "datasets/COCO/annotations/instances_train2014.json" 132 | self.annfile = "datasets/instances_train2014_sel20k.json" 133 | if not os.path.exists(self.annfile): 134 | select_coco_20k(self.sel20k, self.all_annfile) 135 | elif dataset_name == "COCO": 136 | self.year = "2014" 137 | self.root_path = f"datasets/COCO/images/{dataset_set}{self.year}" 138 | # JSON file constructed based on COCO train2014 gt 139 | self.all_annfile = "datasets/instances_train2017.json" 140 | self.annfile = "datasets/instances_train2017.json" 141 | elif dataset_name == "COCOminival": 142 | self.year = 2017 143 | self.root_path = f"/fs/vulcan-datasets/coco/images/val2017" 144 | self.all_annfile = "/fs/vulcan-datasets/coco/annotations/instances_val2017.json" 145 | self.annfile = "/fs/vulcan-datasets/coco/annotations/instances_val2017.json" 146 | elif dataset_name == "ECSSD": 147 | self.root_path = "saliency/data/ECSSD/images/" 148 | elif dataset_name == "DUTS": 149 | self.root_path = "saliency/data/DUTS-TE/DUTS-TE-Image/" 150 | elif dataset_name == "DUT-OMRON": 151 | self.root_path = "saliency/data/DUT-OMRON/DUT-OMRON-image/" 152 | else: 153 | raise ValueError("Unknown dataset.") 154 | if not os.path.exists(self.root_path): 155 | raise ValueError("Please follow the README to setup the datasets.") 156 | 157 | self.name = f"{self.dataset_name}_{self.set}" 158 | 159 | # Build the dataloader 160 | if "VOC" in dataset_name: 161 | self.dataloader = torchvision.datasets.VOCDetection( 162 | self.root_path, 163 | year=self.year, 164 | image_set=self.set, 165 | transform=transform, 166 | download=False, 167 | ) 168 | elif "COCO20k" == dataset_name: 169 | self.dataloader = 
torchvision.datasets.CocoDetection( 170 | self.root_path, annFile=self.annfile, transform=transform 171 | ) 172 | elif "COCO" == dataset_name: 173 | self.dataloader = torchvision.datasets.CocoDetection( 174 | self.root_path, annFile=self.annfile, transform=transform 175 | ) 176 | elif "LVIS" == dataset_name: 177 | self.dataloader = LVISDetection(self.root_path, 178 | annFile=self.annFile, 179 | transform=transform) 180 | elif "COCOminival" == dataset_name: 181 | self.dataloader = torchvision.datasets.CocoDetection( 182 | self.root_path, annFile=self.annfile, transform=transform) 183 | elif "ECSSD" == dataset_name: 184 | self.dataloader = ECSSDDataset(self.root_path, transform=transform) 185 | elif "DUTS" == dataset_name: 186 | self.dataloader = ECSSDDataset(self.root_path, transform=transform) 187 | elif "DUT-OMRON" == dataset_name: 188 | self.dataloader = ECSSDDataset(self.root_path, transform=transform) 189 | elif "CUB" == dataset_name: 190 | self.dataloader = CUBDataset(self.root_path, dataset_set=self.set, transform=transform) 191 | elif "Imagenet" == dataset_name: 192 | self.dataloader = ImagenetDataset(self.annfile, self.root_path, transform=transform) 193 | 194 | else: 195 | raise ValueError("Unknown dataset.") 196 | 197 | # Set hards images that are not included 198 | self.remove_hards = remove_hards 199 | self.hards = [] 200 | if remove_hards: 201 | self.name += f"-nohards" 202 | self.hards = self.get_hards() 203 | print(f"Nb images discarded {len(self.hards)}") 204 | 205 | def load_image(self, im_name): 206 | """ 207 | Load the image corresponding to the im_name 208 | """ 209 | if "VOC" in self.dataset_name: 210 | image = skimage.io.imread(f"datasets/VOC{self.year}/VOCdevkit/VOC{self.year}/JPEGImages/{im_name}") 211 | elif "COCO" in self.dataset_name: 212 | # im_path = self.path_20k[self.sel_20k.index(im_name)] 213 | filename = im_name.zfill(12)+'.jpg' 214 | image = skimage.io.imread(f"datasets/COCO/images/train2014/{filename}") 215 | elif self.dataset_name in ['ECSSD']: 216 | image = skimage.io.imread(f"saliency/data/ECSSD/images/{im_name}") 217 | elif self.dataset_name in ['DUTS']: 218 | image = skimage.io.imread(f"saliency/data/DUTS-TE/DUTS-TE-Image/{im_name}") 219 | elif self.dataset_name == 'DUT-OMRON': 220 | image = skimage.io.imread(f"saliency/data/DUT-OMRON/DUT-OMRON-image/{im_name}") 221 | elif self.dataset_name == 'CUB': 222 | image = skimage.io.imread(f"/fs/vulcan-datasets/CUB/CUB_200_2011/images/{im_name}") 223 | elif self.dataset_name == 'Imagenet': 224 | image = skimage.io.imread(f"/fs/vulcan-datasets/imagenet/val/{im_name}") 225 | 226 | else: 227 | raise ValueError("Unkown dataset.") 228 | return image 229 | 230 | def get_image_name(self, inp): 231 | """ 232 | Return the image name 233 | """ 234 | if "VOC" in self.dataset_name: 235 | im_name = inp["annotation"]["filename"] 236 | elif "COCO" in self.dataset_name: 237 | try: 238 | im_name = str(inp[0]["image_id"]) 239 | except: 240 | pdb.set_trace() 241 | elif "LVIS" in self.dataset_name: 242 | # for lvis return split/img_name 243 | try: 244 | im_name = str(inp[0]["image_id"]) 245 | except: 246 | pdb.set_trace() 247 | elif "ECSSD" in self.dataset_name: 248 | im_name = inp 249 | elif "DUTS" in self.dataset_name: 250 | im_name = inp 251 | elif "DUT-OMRON" in self.dataset_name: 252 | im_name = inp 253 | elif "CUB" in self.dataset_name: 254 | im_name = inp 255 | elif "Imagenet" in self.dataset_name: 256 | im_name = inp 257 | 258 | return im_name 259 | 260 | def extract_gt(self, targets, im_name): 261 | if "VOC" in 
self.dataset_name: 262 | return extract_gt_VOC(targets, remove_hards=self.remove_hards) 263 | elif "COCO" in self.dataset_name: 264 | return extract_gt_COCO(targets, remove_iscrowd=True) 265 | elif "LVIS" in self.dataset_name: 266 | return extract_gt_LVIS(targets, im_name) 267 | elif "ECSSD" in self.dataset_name: 268 | return extract_gt_ECSSD(targets, im_name) 269 | elif "DUTS" in self.dataset_name: 270 | return extract_gt_DUTS(targets, im_name) 271 | elif "DUT-OMRON" in self.dataset_name: 272 | return extract_gt_DUT_OMRON(targets, im_name) 273 | elif "CUB" in self.dataset_name: 274 | return extract_gt_CUB(targets, im_name) 275 | elif "Imagenet" in self.dataset_name: 276 | return extract_gt_Imagenet(targets, im_name) 277 | 278 | else: 279 | raise ValueError("Unknown dataset") 280 | 281 | def extract_classes(self): 282 | if "VOC" in self.dataset_name: 283 | cls_path = f"classes_{self.set}_{self.year}.txt" 284 | elif "COCO" in self.dataset_name: 285 | cls_path = f"classes_{self.dataset}_{self.set}_{self.year}.txt" 286 | 287 | # Load if exists 288 | if os.path.exists(cls_path): 289 | all_classes = [] 290 | with open(cls_path, "r") as f: 291 | for line in f: 292 | all_classes.append(line.strip()) 293 | else: 294 | print("Extract all classes from the dataset") 295 | if "VOC" in self.dataset_name: 296 | all_classes = self.extract_classes_VOC() 297 | elif "COCO" in self.dataset_name: 298 | all_classes = self.extract_classes_COCO() 299 | 300 | with open(cls_path, "w") as f: 301 | for s in all_classes: 302 | f.write(str(s) + "\n") 303 | 304 | return all_classes 305 | 306 | def extract_classes_VOC(self): 307 | all_classes = [] 308 | for im_id, inp in enumerate(tqdm(self.dataloader)): 309 | objects = inp[1]["annotation"]["object"] 310 | 311 | for o in range(len(objects)): 312 | if objects[o]["name"] not in all_classes: 313 | all_classes.append(objects[o]["name"]) 314 | 315 | return all_classes 316 | 317 | def extract_classes_COCO(self): 318 | all_classes = [] 319 | for im_id, inp in enumerate(tqdm(self.dataloader)): 320 | objects = inp[1] 321 | 322 | for o in range(len(objects)): 323 | if objects[o]["category_id"] not in all_classes: 324 | all_classes.append(objects[o]["category_id"]) 325 | 326 | return all_classes 327 | 328 | def get_hards(self): 329 | hard_path = "datasets/hard_%s_%s_%s.txt" % (self.dataset_name, self.set, self.year) 330 | if os.path.exists(hard_path): 331 | hards = [] 332 | with open(hard_path, "r") as f: 333 | for line in f: 334 | hards.append(int(line.strip())) 335 | else: 336 | print("Discover hard images that should be discarded") 337 | 338 | if "VOC" in self.dataset_name: 339 | # set the hards 340 | hards = discard_hard_voc(self.dataloader) 341 | 342 | with open(hard_path, "w") as f: 343 | for s in hards: 344 | f.write(str(s) + "\n") 345 | 346 | return hards 347 | 348 | 349 | def discard_hard_voc(dataloader): 350 | hards = [] 351 | for im_id, inp in enumerate(tqdm(dataloader)): 352 | objects = inp[1]["annotation"]["object"] 353 | nb_obj = len(objects) 354 | 355 | hard = np.zeros(nb_obj) 356 | for i, o in enumerate(range(nb_obj)): 357 | hard[i] = ( 358 | 1 359 | if (objects[o]["truncated"] == "1" or objects[o]["difficult"] == "1") 360 | else 0 361 | ) 362 | 363 | # all images with only truncated or difficult objects 364 | if np.sum(hard) == nb_obj: 365 | hards.append(im_id) 366 | return hards 367 | 368 | def extract_gt_ECSSD(targets, im_name): 369 | gt_path = "saliency/data/ECSSD/ground_truth_mask" 370 | im_name = im_name.split('.')[0] + '.png' 371 | if not 
os.path.isfile(os.path.join(gt_path, im_name)): 372 | raise Exception("Gt file {} not found".format(im_name)) 373 | img = np.asarray(Image.open(os.path.join(gt_path, im_name)).convert('L')) 374 | if np.unique(img).shape[0] > 2: 375 | img[img > 0] = 255 376 | return img, None 377 | 378 | def extract_gt_DUTS(targets, im_name): 379 | gt_path = "saliency/data/DUTS-TE/DUTS-TE-Mask/" 380 | im_name = im_name.split('.')[0] + '.png' 381 | if not os.path.isfile(os.path.join(gt_path, im_name)): 382 | raise Exception("Gt file {} not found".format(im_name)) 383 | img = np.asarray(Image.open(os.path.join(gt_path, im_name)).convert('L')) 384 | if np.unique(img).shape[0] > 2: 385 | img[img > 0] = 255 386 | return img, None 387 | 388 | def extract_gt_CUB(targets, im_name): 389 | gt_path = "/fs/vulcan-datasets/CUB/CUB_200_2011/" 390 | img_cls_label_file = os.path.join(gt_path, 'image_class_labels.txt') 391 | gt_file = os.path.join(gt_path, 'bounding_boxes.txt') 392 | gts = pd.read_csv(gt_file, header=None, sep=" ", names=["id", 393 | "x","y","w","h"]) 394 | labels = pd.read_csv(img_cls_label_file, header=None, sep=" ", names=["id", 395 | "label"]) 396 | 397 | cur_gt = gts[gts['id'] == im_name] 398 | cur_label = labels[labels['id'] == im_name]['label'] 399 | x1 = cur_gt['x'] 400 | y1 = cur_gt['y'] 401 | x2 = cur_gt['w'] + x1 402 | y2 = cur_gt['h'] + y1 403 | # pdb.set_trace() 404 | gt_bbxs = np.asarray([x1,y1,x2,y2]).reshape(1,-1) 405 | return gt_bbxs, labels 406 | 407 | def extract_gt_Imagenet(targets, im_name): 408 | gt_path = "/vulcanscratch/rssaketh/LOST/" 409 | box_file = os.path.join(gt_path, 'imagenet_box_annotations.pkl') 410 | boxes = pickle.load(open(box_file,'rb')) 411 | cbox = boxes[im_name] 412 | labels = os.path.dirname(im_name) 413 | # pdb.set_trace() 414 | gt_bbxs = np.asarray(cbox) 415 | return gt_bbxs, labels 416 | 417 | def extract_gt_DUT_OMRON(targets, im_name): 418 | gt_path = "saliency/data/DUT-OMRON/pixelwiseGT-new-PNG/" 419 | im_name = im_name.split('.')[0] + '.png' 420 | if not os.path.isfile(os.path.join(gt_path, im_name)): 421 | raise Exception("Gt file {} not found".format(im_name)) 422 | img = np.asarray(Image.open(os.path.join(gt_path, im_name)).convert('L')) 423 | if np.unique(img).shape[0] > 2: 424 | img[img > 0] = 255 425 | return img, None 426 | 427 | def extract_gt_LVIS(targets): 428 | objects = targets 429 | nb_obj = len(objects) 430 | 431 | gt_bbxs = [] 432 | gt_clss = [] 433 | for o in range(nb_obj): 434 | # Remove iscrowd boxes 435 | gt_cls = objects[o]["category_id"] 436 | gt_clss.append(gt_cls) 437 | bbx = objects[o]["bbox"] 438 | x1y1x2y2 = [bbx[0], bbx[1], bbx[0] + bbx[2], bbx[1] + bbx[3]] 439 | x1y1x2y2 = [int(round(x)) for x in x1y1x2y2] 440 | gt_bbxs.append(x1y1x2y2) 441 | 442 | return np.asarray(gt_bbxs), gt_clss 443 | 444 | 445 | def extract_gt_COCO(targets, remove_iscrowd=True): 446 | objects = targets 447 | nb_obj = len(objects) 448 | 449 | gt_bbxs = [] 450 | gt_clss = [] 451 | for o in range(nb_obj): 452 | # Remove iscrowd boxes 453 | if remove_iscrowd and objects[o]["iscrowd"] == 1: 454 | continue 455 | gt_cls = objects[o]["category_id"] 456 | gt_clss.append(gt_cls) 457 | bbx = objects[o]["bbox"] 458 | x1y1x2y2 = [bbx[0], bbx[1], bbx[0] + bbx[2], bbx[1] + bbx[3]] 459 | x1y1x2y2 = [int(round(x)) for x in x1y1x2y2] 460 | gt_bbxs.append(x1y1x2y2) 461 | 462 | return np.asarray(gt_bbxs), gt_clss 463 | 464 | 465 | def extract_gt_VOC(targets, remove_hards=False): 466 | objects = targets["annotation"]["object"] 467 | nb_obj = len(objects) 468 | 469 | gt_bbxs = [] 
470 | gt_clss = [] 471 | for o in range(nb_obj): 472 | if remove_hards and ( 473 | objects[o]["truncated"] == "1" or objects[o]["difficult"] == "1" 474 | ): 475 | continue 476 | gt_cls = objects[o]["name"] 477 | gt_clss.append(gt_cls) 478 | obj = objects[o]["bndbox"] 479 | x1y1x2y2 = [ 480 | int(obj["xmin"]), 481 | int(obj["ymin"]), 482 | int(obj["xmax"]), 483 | int(obj["ymax"]), 484 | ] 485 | # Original annotations are integers in the range [1, W or H] 486 | # Assuming they mean 1-based pixel indices (inclusive), 487 | # a box with annotation (xmin=1, xmax=W) covers the whole image. 488 | # In coordinate space this is represented by (xmin=0, xmax=W) 489 | x1y1x2y2[0] -= 1 490 | x1y1x2y2[1] -= 1 491 | gt_bbxs.append(x1y1x2y2) 492 | 493 | return np.asarray(gt_bbxs), gt_clss 494 | 495 | 496 | def bbox_iou(box1, box2, x1y1x2y2=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7): 497 | # https://github.com/ultralytics/yolov5/blob/develop/utils/general.py 498 | # Returns the IoU of box1 to box2. box1 is 4, box2 is nx4 499 | box2 = box2.T 500 | 501 | # Get the coordinates of bounding boxes 502 | if x1y1x2y2: # x1, y1, x2, y2 = box1 503 | b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3] 504 | b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3] 505 | else: # transform from xywh to xyxy 506 | b1_x1, b1_x2 = box1[0] - box1[2] / 2, box1[0] + box1[2] / 2 507 | b1_y1, b1_y2 = box1[1] - box1[3] / 2, box1[1] + box1[3] / 2 508 | b2_x1, b2_x2 = box2[0] - box2[2] / 2, box2[0] + box2[2] / 2 509 | b2_y1, b2_y2 = box2[1] - box2[3] / 2, box2[1] + box2[3] / 2 510 | 511 | # Intersection area 512 | inter = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * ( 513 | torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1) 514 | ).clamp(0) 515 | 516 | # Union Area 517 | w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps 518 | w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps 519 | union = w1 * h1 + w2 * h2 - inter + eps 520 | 521 | iou = inter / union 522 | if GIoU or DIoU or CIoU: 523 | cw = torch.max(b1_x2, b2_x2) - torch.min( 524 | b1_x1, b2_x1 525 | ) # convex (smallest enclosing box) width 526 | ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1) # convex height 527 | if CIoU or DIoU: # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1 528 | c2 = cw ** 2 + ch ** 2 + eps # convex diagonal squared 529 | rho2 = ( 530 | (b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 531 | + (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2 532 | ) / 4 # center distance squared 533 | if DIoU: 534 | return iou - rho2 / c2 # DIoU 535 | elif ( 536 | CIoU 537 | ): # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47 538 | v = (4 / math.pi ** 2) * torch.pow( 539 | torch.atan(w2 / h2) - torch.atan(w1 / h1), 2 540 | ) 541 | with torch.no_grad(): 542 | alpha = v / (v - iou + (1 + eps)) 543 | return iou - (rho2 / c2 + v * alpha) # CIoU 544 | else: # GIoU https://arxiv.org/pdf/1902.09630.pdf 545 | c_area = cw * ch + eps # convex area 546 | return iou - (c_area - union) / c_area # GIoU 547 | else: 548 | return iou # IoU 549 | 550 | def select_coco_20k(sel_file, all_annotations_file): 551 | print('Building COCO 20k dataset.') 552 | 553 | # load all annotations 554 | with open(all_annotations_file, "r") as f: 555 | train2014 = json.load(f) 556 | 557 | # load selected images 558 | with open(sel_file, "r") as f: 559 | sel_20k = f.readlines() 560 | sel_20k = [s.replace("\n", "") for s in sel_20k] 561 | im20k = [str(int(s.split("_")[-1].split(".")[0])) for s in sel_20k] 562 | 563 | new_anno = [] 564 | 
new_images = [] 565 | 566 | for i in tqdm(im20k): 567 | new_anno.extend( 568 | [a for a in train2014["annotations"] if a["image_id"] == int(i)] 569 | ) 570 | new_images.extend([a for a in train2014["images"] if a["id"] == int(i)]) 571 | 572 | train2014_20k = {} 573 | train2014_20k["images"] = new_images 574 | train2014_20k["annotations"] = new_anno 575 | train2014_20k["categories"] = train2014["categories"] 576 | 577 | with open("datasets/instances_train2014_sel20k.json", "w") as outfile: 578 | json.dump(train2014_20k, outfile) 579 | 580 | print('Done.') 581 | 582 | -------------------------------------------------------------------------------- /main_most.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import random 4 | import pickle 5 | import pdb 6 | 7 | import torch 8 | import torch.nn as nn 9 | import numpy as np 10 | 11 | from tqdm import tqdm 12 | from PIL import Image 13 | 14 | from networks import get_model 15 | from datasets import ImageDataset, Dataset, bbox_iou 16 | from visualizations import visualize_predictions, visualize_map 17 | from object_discovery import most 18 | from saliency_utils import saliency_metrics 19 | 20 | if __name__ == "__main__": 21 | parser = argparse.ArgumentParser("Arguments for MOST") 22 | parser.add_argument( 23 | "--arch", 24 | default="vit_small", 25 | type=str, 26 | choices=[ 27 | "vit_tiny", 28 | "vit_small", 29 | "vit_base", 30 | "ibot_base", 31 | ], 32 | help="Model architecture.", 33 | ) 34 | parser.add_argument( 35 | "--patch_size", default=16, type=int, help="Patch resolution of the model." 36 | ) 37 | 38 | # Use a dataset 39 | parser.add_argument( 40 | "--dataset", 41 | default="VOC07", 42 | type=str, 43 | choices=[None, "VOC07", "VOC12", "COCO20k", "COCO", "COCOminival", 44 | "ECSSD", "DUTS", "DUT-OMRON"], 45 | help="Dataset name.", 46 | ) 47 | parser.add_argument( 48 | "--set", 49 | default="train", 50 | type=str, 51 | choices=["val", "train", "trainval", "test"], 52 | help="Dataset split to use.", 53 | ) 54 | # Or use a single image 55 | parser.add_argument( 56 | "--image_path", 57 | type=str, 58 | default=None, 59 | help="To apply MOST to a single image, give the file path.", 60 | ) 61 | 62 | # Folder used to output visualizations and predictions 63 | parser.add_argument( 64 | "--output_dir", type=str, default="outputs", help="Output directory to store predictions and visualizations." 65 | ) 66 | 67 | # Evaluation setup 68 | parser.add_argument("--no_hard", action="store_true", help="Only used in the case of the VOC_all setup (see the paper).") 69 | parser.add_argument("--no_evaluation", action="store_true", help="Skip the evaluation.") 70 | parser.add_argument("--save_predictions", default=True, type=bool, help="Save predicted bounding boxes.") 71 | 72 | # Visualization 73 | parser.add_argument( 74 | "--visualize", 75 | type=str, 76 | choices=["pred", None], 77 | default=None, 78 | help="Select the type of visualization to save.", 79 | ) 80 | 81 | # MOST parameters 82 | parser.add_argument( 83 | "--which_features", 84 | type=str, 85 | default="k", 86 | choices=["k", "q", "v"], 87 | help="Which features to use", 88 | ) 89 | parser.add_argument( 90 | "--k_patches", 91 | type=int, 92 | default=100, 93 | help="Number of patches with the lowest degree considered." 
94 | ) 95 | parser.add_argument( 96 | "--dbscan_eps", 97 | type=int, 98 | default=2, 99 | help="DBSCAN min distance per sample" 100 | ) 101 | parser.add_argument( 102 | "--ks", 103 | nargs='+', 104 | default=[1,2,3,4,5] 105 | ) 106 | 107 | 108 | args = parser.parse_args() 109 | 110 | if args.image_path is not None: 111 | args.save_predictions = False 112 | args.no_evaluation = True 113 | args.dataset = None 114 | 115 | # ------------------------------------------------------------------------------------------------------- 116 | # Dataset 117 | 118 | # If an image_path is given, apply the method only to the image 119 | if args.image_path is not None: 120 | dataset = ImageDataset(args.image_path) 121 | else: 122 | dataset = Dataset(args.dataset, args.set, args.no_hard) 123 | 124 | # ------------------------------------------------------------------------------------------------------- 125 | # Model 126 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 127 | model = get_model(args.arch, args.patch_size, device) 128 | # ------------------------------------------------------------------------------------------------------- 129 | # Directories 130 | if args.image_path is None: 131 | args.output_dir = os.path.join(args.output_dir, dataset.name) 132 | os.makedirs(args.output_dir, exist_ok=True) 133 | 134 | # Naming 135 | # Experiment with MOST 136 | exp_name = f"MOST-{args.arch}" 137 | if "vit" in args.arch: 138 | exp_name += f"{args.patch_size}_{args.which_features}" 139 | 140 | print(f"Running MOST on the dataset {dataset.name} (exp: {exp_name})") 141 | 142 | # Visualization 143 | if args.visualize: 144 | vis_folder = f"{args.output_dir}/visualizations/{exp_name}" 145 | os.makedirs(vis_folder, exist_ok=True) 146 | 147 | # ------------------------------------------------------------------------------------------------------- 148 | # Loop over images 149 | preds_dict = {} 150 | cnt = 0 151 | corloc = np.zeros(len(dataset.dataloader)) 152 | fbmax = np.zeros(len(dataset.dataloader)) 153 | iou_arr = np.zeros(len(dataset.dataloader)) 154 | acc = np.zeros(len(dataset.dataloader)) 155 | pbar = tqdm(dataset.dataloader) 156 | for im_id, inp in enumerate(pbar): 157 | # ------------ IMAGE PROCESSING ------------------------------------------- 158 | img = inp[0] 159 | init_image_size = img.shape 160 | if (torch.tensor(init_image_size[1:]) > 1000).any(): 161 | continue 162 | # Get the name of the image 163 | if not inp[1]: 164 | continue 165 | im_name = dataset.get_image_name(inp[1]) 166 | 167 | # Pass in case of no gt boxes in the image 168 | if im_name is None: 169 | continue 170 | 171 | # Padding the image with zeros to fit multiple of patch-size 172 | size_im = ( 173 | img.shape[0], 174 | int(np.ceil(img.shape[1] / args.patch_size) * args.patch_size), 175 | int(np.ceil(img.shape[2] / args.patch_size) * args.patch_size), 176 | ) 177 | paded = torch.zeros(size_im) 178 | paded[:, : img.shape[1], : img.shape[2]] = img 179 | img = paded 180 | 181 | # Move to gpu 182 | img = img.cuda(non_blocking=True) 183 | # Size for transformers 184 | w_featmap = img.shape[-2] // args.patch_size 185 | h_featmap = img.shape[-1] // args.patch_size 186 | 187 | # ------------ GROUND-TRUTH ------------------------------------------- 188 | if not args.no_evaluation: 189 | gt_bbxs, gt_cls = dataset.extract_gt(inp[1], im_name) 190 | if gt_bbxs is not None: 191 | # Discard images with no gt annotations 192 | # Happens only in the case of VOC07 and VOC12 193 | if gt_bbxs.shape[0] == 0 and 
args.no_hard: 194 | continue 195 | 196 | # ------------ EXTRACT FEATURES ------------------------------------------- 197 | with torch.no_grad(): 198 | 199 | # ------------ FORWARD PASS ------------------------------------------- 200 | if "vit" in args.arch or "ibot" in args.arch: 201 | # Store the outputs of qkv layer from the last attention layer 202 | feat_out = {} 203 | def hook_fn_forward_qkv(module, input, output): 204 | feat_out["qkv"] = output 205 | model._modules["blocks"][-1]._modules["attn"]._modules["qkv"].register_forward_hook(hook_fn_forward_qkv) 206 | 207 | # Forward pass in the model 208 | attentions = model.get_last_selfattention(img[None, :, :, :]) 209 | 210 | # Scaling factor 211 | scales = [args.patch_size, args.patch_size] 212 | 213 | # Dimensions 214 | nb_im = attentions.shape[0] # Batch size 215 | nh = attentions.shape[1] # Number of heads 216 | nb_tokens = attentions.shape[2] # Number of tokens 217 | # Extract the qkv features of the last attention layer 218 | qkv = ( 219 | feat_out["qkv"] 220 | .reshape(nb_im, nb_tokens, 3, nh, -1 // nh) 221 | .permute(2, 0, 3, 1, 4) 222 | ) 223 | q, k, v = qkv[0], qkv[1], qkv[2] 224 | k = k.transpose(1, 2).reshape(nb_im, nb_tokens, -1) 225 | q = q.transpose(1, 2).reshape(nb_im, nb_tokens, -1) 226 | v = v.transpose(1, 2).reshape(nb_im, nb_tokens, -1) 227 | # Modality selection 228 | if args.which_features == "k": 229 | feats = k[:, 1:, :] 230 | elif args.which_features == "q": 231 | feats = q[:, 1:, :] 232 | elif args.which_features == "v": 233 | feats = v[:, 1:, :] 234 | else: 235 | raise ValueError("Unknown model.") 236 | 237 | # ------------ Apply MOST ------------------------------------------- 238 | pred, A, seed, others = most( 239 | feats, 240 | [w_featmap, h_featmap], 241 | scales, 242 | init_image_size, 243 | k_patches=args.k_patches, 244 | dbscan_eps=args.dbscan_eps, 245 | return_mask = (args.dataset in ['ECSSD', "DUTS", 246 | "DUT-OMRON"]), 247 | ks=args.ks 248 | ) 249 | # ------------ Visualizations ------------------------------------------- 250 | if args.visualize == "pred": 251 | image = dataset.load_image(im_name) 252 | if args.dataset in ['ECSSD', 'DUTS', 'DUT-OMRON']: 253 | visualize_map( 254 | image, 255 | pred, 256 | gt_bbxs, 257 | vis_folder, 258 | im_name, 259 | others 260 | ) 261 | else: 262 | visualize_predictions( 263 | image, 264 | pred, 265 | seed, 266 | scales, 267 | [w_featmap, h_featmap], 268 | vis_folder, 269 | im_name 270 | ) 271 | # Save the prediction 272 | if args.dataset in ['ECSSD', 'DUTS', 'DUT-OMRON']: 273 | if pred: 274 | preds_dict[im_name] = np.max(np.stack(pred, axis=0),axis=0) 275 | else: 276 | preds_dict[im_name] = np.asarray([]) 277 | else: 278 | preds_dict[im_name] = pred 279 | 280 | # Evaluation 281 | if args.no_evaluation: 282 | continue 283 | 284 | # Compare prediction to GT boxes 285 | if args.dataset in ['ECSSD', 'DUTS', 'DUT-OMRON']: 286 | # For saliency detection, compute different metrics 287 | fbmax_, iou_, acc_ = saliency_metrics(pred, gt_bbxs) 288 | fbmax[im_id] = fbmax_ 289 | iou_arr[im_id] = iou_ 290 | acc[im_id] = acc_ 291 | cnt += 1 292 | # if cnt % 50 == 0: 293 | cur_fbmax = np.sum(fbmax)/cnt 294 | cur_iou = np.sum(iou_arr)/cnt 295 | cur_acc = np.sum(acc)/cnt 296 | pbar.set_description(f"F1: {cur_fbmax}, IoU: {cur_iou}, Acc: {cur_acc}") 297 | 298 | else: 299 | if isinstance(pred, list): 300 | ious = [] 301 | for p in pred: 302 | try: 303 | iou = bbox_iou(torch.from_numpy(p), torch.from_numpy(gt_bbxs)) 304 | except: 305 | pdb.set_trace() 306 | ious.append(iou) 307 | if 
ious: 308 | ious = torch.cat(ious) 309 | else: 310 | ious = torch.tensor([]) 311 | else: 312 | ious = bbox_iou(torch.from_numpy(pred), torch.from_numpy(gt_bbxs)) 313 | if ious.shape[0] > 0: 314 | if torch.any(ious >= 0.5): 315 | corloc[im_id] = 1 316 | 317 | cnt += 1 318 | pbar.set_description(f"Found {int(np.sum(corloc))}/{cnt}") 319 | 320 | 321 | # Save predicted bounding boxes 322 | if args.save_predictions: 323 | folder = f"{args.output_dir}/{exp_name}" 324 | os.makedirs(folder, exist_ok=True) 325 | filename = os.path.join(folder, "preds.pkl") 326 | with open(filename, "wb") as f: 327 | pickle.dump(preds_dict, f) 328 | print("Predictions saved at %s" % filename) 329 | 330 | # Evaluate 331 | if not args.no_evaluation: 332 | result_file = os.path.join(folder, 'results.txt') 333 | if args.dataset in ['ECSSD', 'DUTS', 'DUT-OMRON']: 334 | print(f"F1: {int(np.sum(fbmax))}/{cnt}, IoU: {int(np.sum(iou_arr))}/{cnt}, Acc: {int(np.sum(acc))}/{cnt}") 335 | res_str = 'fbmax,%.1f,iou,%.1f,acc,%.1f,,\n'%(100*np.sum(fbmax)/cnt, 336 | 100*np.sum(iou_arr)/cnt, 337 | 100*np.sum(acc)/cnt) 338 | else: 339 | print(f"corloc: {100*np.sum(corloc)/cnt:.2f} ({int(np.sum(corloc))}/{cnt})") 340 | res_str = 'corloc,%.1f,,\n'%(100*np.sum(corloc)/cnt) 341 | with open(result_file, 'w') as f: 342 | f.write(res_str) 343 | print('File saved at %s'%result_file) 344 | 345 | -------------------------------------------------------------------------------- /networks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import pdb 4 | 5 | from torchvision.models.resnet import resnet50 6 | from torchvision.models.vgg import vgg16 7 | 8 | import dino.vision_transformer as vits 9 | 10 | def get_model(arch, patch_size, device): 11 | if "ibot" in arch: 12 | # Currently only supporting base 13 | assert patch_size == 16 14 | model = vits.__dict__['vit_base'](patch_size=patch_size, return_all_tokens=True) 15 | 16 | else: 17 | model = vits.__dict__[arch](patch_size=patch_size, num_classes=0) 18 | 19 | for p in model.parameters(): 20 | p.requires_grad = False 21 | 22 | # Initialize model with pretraining 23 | if "imagenet" not in arch: 24 | url = None 25 | if arch == "vit_small" and patch_size == 16: 26 | url = "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth" 27 | # url = None 28 | elif arch == "vit_small" and patch_size == 8: 29 | url = "dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth" 30 | elif arch == "vit_base" and patch_size == 16: 31 | url = "dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth" 32 | elif arch == "vit_base" and patch_size == 8: 33 | url = "dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth" 34 | elif arch == "ibot_base" and patch_size == 16: 35 | url = None 36 | if url is not None: 37 | print( 38 | "Since no pretrained weights have been provided, we load the reference pretrained DINO weights." 
39 | ) 40 | state_dict = torch.hub.load_state_dict_from_url( 41 | url="https://dl.fbaipublicfiles.com/dino/" + url 42 | ) 43 | strict_loading = True 44 | msg = model.load_state_dict(state_dict, strict=strict_loading) 45 | print( 46 | "Pretrained weights found at {} and loaded with msg: {}".format( 47 | url, msg 48 | ) 49 | ) 50 | else: 51 | if arch == "ibot_base": 52 | state_dict = torch.load('weights/checkpoint_teacher.pth')['state_dict'] 53 | state_dict = {k.replace("module.", ""): v for k, v in 54 | state_dict.items()} 55 | else: 56 | state_dict = torch.load('weights/checkpoint.pth')['teacher'] 57 | to_remove = ['head.mlp.0.weight', 'head.mlp.0.bias', 58 | 'head.mlp.2.weight', 'head.mlp.2.bias', 59 | 'head.mlp.4.weight', 'head.mlp.4.bias', 60 | 'head.last_layer.weight_g', 61 | 'head.last_layer.weight_v'] 62 | state_dict = {k.replace('backbone.',''): v for k, v in 63 | state_dict.items() if k not in to_remove} 64 | 65 | strict_loading = False if 'ibot' in arch else True 66 | msg = model.load_state_dict(state_dict, strict=strict_loading) 67 | 68 | print( 69 | "Pretrained weights found at {} and loaded with msg: {}".format( 70 | 'weights/checkpoint.pth', msg 71 | ) 72 | ) 73 | 74 | model.eval() 75 | model.to(device) 76 | return model 77 | 78 | -------------------------------------------------------------------------------- /object_discovery.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import scipy 3 | import scipy.ndimage 4 | from collections import Counter 5 | import time 6 | import datetime 7 | import os 8 | import random 9 | 10 | import numpy as np 11 | from torchvision.ops import batched_nms 12 | from itertools import product 13 | from sklearn.cluster import DBSCAN 14 | 15 | 16 | def entropy_tensor(vals): 17 | _, counts = torch.unique(vals, return_counts=True) 18 | freq = counts/torch.sum(counts) 19 | ent = -1 * torch.sum(freq * 20 | torch.log(freq)/torch.log(torch.tensor([2.])).to(freq.device)) 21 | return ent.item() 22 | 23 | 24 | def compute_block_entropy(map_, poolers): 25 | with torch.no_grad(): 26 | f = [l(map_.unsqueeze(0).unsqueeze(0).cuda()).reshape(-1) for l in poolers] 27 | ents = [entropy_tensor(l) for l in f] 28 | return ents 29 | 30 | def most(feats, dims, scales, init_image_size, k_patches=100, 31 | dbscan_eps=2, return_mask=False, ks=[1,2,3,4,5]): 32 | """ 33 | Implementation of MOST method. 
34 | Inputs 35 | feats: the pixel/patche features of an image 36 | dims: dimension of the map from which the features are used 37 | scales: from image to map scale 38 | init_image_size: size of the image 39 | k_patches: number of k patches retrieved that are compared to the seed at seed expansion 40 | dbscan_eps: threshold for clustering 41 | return_mask: Flag to return mask for saliency detection 42 | ks: Pooling filter sizes 43 | """ 44 | if scales[0] != scales[1]: 45 | raise Exception('Scales values should be the same', scales) 46 | # Get number of features 47 | A = (feats @ feats.transpose(1, 2)).squeeze() 48 | og_ks = 1 49 | 50 | seeds = [] 51 | 52 | for i in range(A.shape[0]): 53 | # Get each map 54 | map_ = A[i].clone() 55 | map_[map_ <= 0] = 0 56 | map_ = map_.reshape(dims[0], dims[1]) 57 | # Get thresholds for entropy 58 | ks = list(map(int, ks)) 59 | ks = list(set([min(k, min(dims)) for k in ks])) 60 | ks.sort() 61 | pool_dims = [[dims[0] - k+1, dims[1]-k+1] for k in ks] 62 | feat_pool = [torch.nn.AdaptiveAvgPool2d(k) for k in pool_dims] 63 | thresh = [1+(np.log(d[0])/np.log(2)) for d in pool_dims] 64 | # compute entropy at each resolution 65 | ents = compute_block_entropy(map_, feat_pool) 66 | # Check if map contain any object 67 | pass_ = [l < t for l, t in zip(ents, thresh)] 68 | # If atleast 50% of the maps agree there is an object, we pick it 69 | if sum(pass_) >= 0.5 * len(pass_): 70 | seeds.append(i) 71 | seeds.sort() 72 | # If there are no seeds, then there are no objects 73 | if not seeds: 74 | return [], A, seeds, [] 75 | 76 | # Since we are using manhattan distance, any of the eight neighbors are 77 | # considered neighbors 78 | dbscan = DBSCAN(eps=dbscan_eps, metric='cityblock', min_samples=1) 79 | # We use min_samples 1 to allow a single seed to be its own cluster 80 | # Unravel linear index to coordinates to cluster them 81 | seed_coords = np.stack([np.unravel_index(l, dims) for l in seeds]) 82 | assign = dbscan.fit(seed_coords) 83 | seed_labels = assign.labels_ 84 | 85 | A_clone = A.clone() 86 | A_clone.fill_diagonal_(0) 87 | A_clone[A_clone < 0] = 0 88 | 89 | preds = [] 90 | masks = [] 91 | req_feats = [] 92 | cent = -torch.sum(A_clone > 0, dim=1).type(torch.float32) 93 | 94 | max_cluster_id = Counter(seed_labels).most_common(1)[0][0] 95 | pick = 0 96 | # For each cluster, get all seeds belonging to cluster and construct a box 97 | for ii, cl in enumerate(np.unique(seed_labels)): 98 | # First collect all seeds that belong to this cluster 99 | similars = [seeds[i] for i in np.where(seed_labels == cl)[0]] 100 | # Find the seed with the maximum outgoing degree 101 | seed_value = [cent[l] for l in similars] 102 | pot_seed = max(seed_value) 103 | # Find all pixels that have similarity with the seed with highest 104 | # degree 105 | seed = torch.tensor(similars[seed_value.index(pot_seed)]) 106 | similars = [l for l in similars if A[seed, torch.tensor([l])] > 0] 107 | # Find the mask 108 | M = torch.sum(A[similars, :], dim=0) 109 | # Detect the box using the mask 110 | pred, small_pred = detect_box( 111 | M, seed, dims, scales=scales, initial_im_size=init_image_size[1:]) 112 | 113 | if pred: 114 | boxes = np.stack(pred) 115 | areas = (boxes[:,3] - boxes[:,1]) * (boxes[:,2] - boxes[:,0]) 116 | widths = boxes[:,2] - boxes[:,0] 117 | heights = (boxes[:,3] - boxes[:,1]) 118 | # Remove trivial boxes 119 | # Boxes are kept if they have a height/width greater than 16 pixels 120 | keep_hw = np.bitwise_and(widths > 16, heights > 16) 121 | # Remove boxes if they cover the full 
image 122 | keep_ar = np.bitwise_and(areas > 256, areas < 123 | 0.9*np.prod(init_image_size[1:])) 124 | keep = np.where(np.bitwise_and(keep_hw, keep_ar))[0] 125 | 126 | pred = [pred[l] for l in keep] 127 | 128 | if not return_mask: 129 | # If not returning a mask, then just stack them 130 | preds+= [torch.tensor(l) for l in pred] 131 | elif pred: 132 | # If output is a mask, then postprocess and upsample the mask 133 | # to the image resolution 134 | if cl == max_cluster_id: 135 | pick = ii 136 | full_map = refine_mask(M, seed, dims) 137 | full_map = torch.nn.functional.interpolate(full_map.unsqueeze(0).unsqueeze(0), 138 | scale_factor=scales[0], 139 | mode='nearest') 140 | full_map = full_map.squeeze(0).squeeze(0) 141 | preds.append(full_map.cpu().numpy()) 142 | 143 | else: 144 | # If there is no box, then move on 145 | continue 146 | 147 | others = [] 148 | if preds: 149 | if not return_mask: 150 | preds_tensor = torch.stack(preds) 151 | preds = [l.numpy() for l in preds] 152 | others = [] 153 | else: 154 | others = [l for i, l in enumerate(preds)] 155 | if pick < len(preds): 156 | preds = [preds[pick]] 157 | else: 158 | preds = [preds[0]] 159 | return preds, A, seeds, others 160 | 161 | 162 | 163 | def refine_mask(A, seed, dims): 164 | w_featmap, h_featmap = dims 165 | 166 | correl = A.reshape(w_featmap, h_featmap).float() 167 | 168 | # Compute connected components 169 | labeled_array, num_features = scipy.ndimage.label(correl.cpu().numpy() > 0.0) 170 | 171 | # Find connected component corresponding to the initial seed 172 | cc = labeled_array[np.unravel_index(seed.cpu().numpy(), (w_featmap, h_featmap))] 173 | 174 | if cc == 0: 175 | return [], [] 176 | mask = A.reshape(w_featmap, h_featmap).cpu().numpy() 177 | mask[labeled_array != cc] = -1 178 | 179 | return torch.tensor(mask) 180 | 181 | 182 | def detect_box(A, seed, dims, initial_im_size=None, scales=None): 183 | """ 184 | Extract a box corresponding to the seed patch. Among connected components extract from the affinity matrix, select the one corresponding to the seed patch. 
185 | """ 186 | w_featmap, h_featmap = dims 187 | 188 | correl = A.reshape(w_featmap, h_featmap).float() 189 | 190 | # Compute connected components 191 | labeled_array, num_features = scipy.ndimage.label(correl.cpu().numpy() > 0.0) 192 | 193 | # Find connected component corresponding to the initial seed 194 | cc = labeled_array[np.unravel_index(seed.cpu().numpy(), (w_featmap, h_featmap))] 195 | 196 | # Should not happen with LOST 197 | if cc == 0: 198 | # raise ValueError("The seed is in the background component.") 199 | return [], [] 200 | 201 | # Find box 202 | mask = np.where(labeled_array == cc) 203 | # Add +1 because excluded max 204 | ymin, ymax = min(mask[0]), max(mask[0]) + 1 205 | xmin, xmax = min(mask[1]), max(mask[1]) + 1 206 | 207 | # Rescale to image size 208 | r_xmin, r_xmax = scales[1] * xmin, scales[1] * xmax 209 | r_ymin, r_ymax = scales[0] * ymin, scales[0] * ymax 210 | 211 | pred = [r_xmin, r_ymin, r_xmax, r_ymax] 212 | 213 | # Check not out of image size (used when padding) 214 | if initial_im_size: 215 | pred[2] = min(pred[2], initial_im_size[1]) 216 | pred[3] = min(pred[3], initial_im_size[0]) 217 | 218 | # Coordinate predictions for the feature space 219 | # Axis different then in image space 220 | pred_feats = [ymin, xmin, ymax, xmax] 221 | 222 | return [pred], [pred_feats] 223 | 224 | 225 | 226 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | scipy>=1.4.1 2 | matplotlib>=3.2.2 3 | opencv-python>=4.1.2 4 | tqdm>=4.41.0 5 | scikit-image 6 | catalyst 7 | pandas 8 | scikit-learn 9 | pycocotools 10 | -------------------------------------------------------------------------------- /saliency_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.metrics import jaccard_score 3 | import pdb 4 | 5 | 6 | def saliency_metrics(pred_list, gt_img): 7 | # pred_list is list of masks from different seeds 8 | # add them all up 9 | if not pred_list: 10 | return 0, 0, 0 11 | # pdb.set_trace() 12 | pred = np.stack(pred_list, axis=0).max(0) 13 | if pred.shape != gt_img.shape: 14 | pred = pred[:gt_img.shape[0], :gt_img.shape[1]] 15 | # check if gt has values between 0-1 16 | if gt_img.max() > 1: 17 | gt_img = gt_img.astype(float) 18 | gt_img /= 255. 
19 | if gt_img.ndim > 2: 20 | pdb.set_trace() 21 | pred_sig = 1./(1+np.exp(-pred)) 22 | # get fbmax 23 | fbmax = compute_fbmax(gt_img, pred_sig) 24 | # get IoU with binarization threshold 0.5 25 | pred_bin = pred_sig > 0.5 26 | try: 27 | iou = jaccard_score(gt_img.reshape(-1).astype(bool), pred_bin.reshape(-1)) 28 | except: 29 | pdb.set_trace() 30 | 31 | # get Acc with binarization threshold 0.5 32 | acc = (pred_bin == gt_img).sum()/np.prod(gt_img.shape) 33 | 34 | return fbmax, iou, acc 35 | 36 | 37 | def compute_fbmax(gt, pred, levels=255): 38 | beta = 0.3 39 | thresholds = np.linspace(0, 1-1e-10, levels) 40 | prec = np.zeros(levels) 41 | recall = np.zeros(levels) 42 | 43 | for i, th in enumerate(thresholds): 44 | bin_pred = (pred >= th).astype(float) 45 | tp = (bin_pred * gt).sum() 46 | prec[i] = tp/(bin_pred.sum() + 1e-15) 47 | recall[i] = tp/(gt.sum() + 1e-15) 48 | f_score = (1+beta**2)*prec * recall/((beta**2*prec) + recall) 49 | f_score[np.isnan(f_score)] = 0 50 | return f_score.max() 51 | -------------------------------------------------------------------------------- /visualizations.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import torch 3 | import skimage.io 4 | import numpy as np 5 | import torch.nn as nn 6 | from PIL import Image 7 | import pdb 8 | 9 | import matplotlib.pyplot as plt 10 | 11 | 12 | def visualize_map(image, pred, gt_img, vis_folder, im_name, all_preds=None): 13 | """ 14 | Visualization of the predicted saliency mask alongside the ground-truth mask. 15 | """ 16 | if not isinstance(pred, list): 17 | pred = [pred] 18 | pred = np.stack(pred, axis=0).max(0) 19 | if pred.shape != gt_img.shape: 20 | pred = pred[:gt_img.shape[0], :gt_img.shape[1]] 21 | if image.shape != gt_img.shape: 22 | image = image[:gt_img.shape[0], :gt_img.shape[1]] 23 | 24 | pred_sig = pred 25 | widths = [image.shape[1], pred_sig.shape[1]] 26 | heights = [image.shape[0], pred_sig.shape[0]] 27 | gap_img = Image.fromarray(255*np.ones((image.shape[0], 5))) 28 | 29 | ims = [Image.fromarray(image), gap_img, 30 | Image.fromarray(255*(pred_sig>0).astype(float)), gap_img] 31 | if all_preds: 32 | # others are given 33 | stack_preds = np.stack(all_preds, axis=0).max(0) 34 | widths.append(stack_preds.shape[1]) 35 | heights.append(stack_preds.shape[0]) 36 | ims.append(Image.fromarray(255*(stack_preds>0).astype(float))) 37 | ims.append(gap_img) 38 | widths.append(gt_img.shape[1]) 39 | heights.append(gt_img.shape[0]) 40 | ims.append(Image.fromarray(gt_img)) 41 | total_width = sum(widths) + 15 42 | max_height = max(heights) 43 | new_im = Image.new('RGB', (total_width, max_height)) 44 | x_offset = 0 45 | for im in ims: 46 | new_im.paste(im, (x_offset, 0)) 47 | x_offset += im.size[0] 48 | 49 | pltname = f"{vis_folder}/MOST_{im_name}.png" 50 | new_im.save(pltname) 51 | print(f"Predictions saved at {pltname}.") 52 | 53 | def visualize_predictions(image, pred, seed, scales, dims, vis_folder, im_name, 54 | k_suffix="", plot_seed=False): # k_suffix is optional; main_most.py calls this function without it 55 | """ 56 | Visualization of the predicted box and the corresponding seed patch. 
57 | """ 58 | w_featmap, h_featmap = dims 59 | if not isinstance(pred, list): 60 | pred = [pred] 61 | for ii, pre in enumerate(pred): 62 | area = (pre[3] - pre[1]) * (pre[2] - pre[0]) 63 | if area < 1000: 64 | continue 65 | # Plot the box 66 | cv2.rectangle( 67 | image, 68 | (int(pre[0]), int(pre[1])), 69 | (int(pre[2]), int(pre[3])), 70 | (255, 0, 0), 3, 71 | ) 72 | 73 | # Plot the seed 74 | if plot_seed: 75 | s_ = np.unravel_index(seed.cpu().numpy(), (w_featmap, h_featmap)) 76 | size_ = np.asarray(scales) / 2 77 | cv2.rectangle( 78 | image, 79 | (int(s_[1] * scales[1] - (size_[1] / 2)), int(s_[0] * scales[0] - (size_[0] / 2))), 80 | (int(s_[1] * scales[1] + (size_[1] / 2)), int(s_[0] * scales[0] + (size_[0] / 2))), 81 | (0, 255, 0), -1, 82 | ) 83 | 84 | pltname = f"{vis_folder}/MOST_{im_name}{'_' + str(k_suffix) if k_suffix else ''}.png" 85 | Image.fromarray(image).save(pltname) 86 | print(f"Predictions saved at {pltname}.") 87 | 88 | --------------------------------------------------------------------------------