├── .gitattributes
├── GPCIS_supp.pdf
├── LICENSE
├── README.md
├── checkpoints
│   └── GPCIS_Resnet50.pth
├── config.yml
├── isegm
│   ├── data
│   │   ├── aligned_augmentation.py
│   │   ├── base.py
│   │   ├── compose.py
│   │   ├── datasets
│   │   │   ├── __init__.py
│   │   │   ├── berkeley.py
│   │   │   ├── davis.py
│   │   │   ├── grabcut.py
│   │   │   └── sbd.py
│   │   ├── points_sampler.py
│   │   ├── sample.py
│   │   └── transforms.py
│   ├── engine
│   │   ├── gp_trainer.py
│   │   └── optimizer.py
│   ├── inference
│   │   ├── clicker.py
│   │   ├── evaluation.py
│   │   ├── predictors
│   │   │   ├── __init__.py
│   │   │   └── baseline.py
│   │   ├── transforms
│   │   │   ├── __init__.py
│   │   │   ├── base.py
│   │   │   ├── crops.py
│   │   │   ├── flip.py
│   │   │   ├── limit_longest_side.py
│   │   │   ├── resize.py
│   │   │   └── zoom_in.py
│   │   └── utils.py
│   ├── model
│   │   ├── initializer.py
│   │   ├── is_gp_model.py
│   │   ├── is_gp_resnet50.py
│   │   ├── losses.py
│   │   ├── metrics.py
│   │   ├── modeling
│   │   │   ├── basic_blocks.py
│   │   │   ├── deeplab_v3.py
│   │   │   ├── deeplab_v3_gp.py
│   │   │   ├── resnet.py
│   │   │   └── resnetv1b.py
│   │   ├── modifiers.py
│   │   └── ops.py
│   └── utils
│       ├── crop_local.py
│       ├── cython
│       │   ├── __init__.py
│       │   ├── _get_dist_maps.pyx
│       │   ├── _get_dist_maps.pyxbld
│       │   └── dist_maps.py
│       ├── distributed.py
│       ├── exp.py
│       ├── exp_imports
│       │   └── default.py
│       ├── log.py
│       ├── misc.py
│       ├── serialization.py
│       └── vis.py
├── models
│   └── gp_sbd_resnet50.py
├── net.png
├── requirements.txt
├── run.sh
├── scripts
│   └── evaluate_model.py
└── train.py
/.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | *.pth filter=lfs diff=lfs merge=lfs -text 4 | -------------------------------------------------------------------------------- /GPCIS_supp.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zmhhmz/GPCIS_CVPR2023/6460415a2e784f5623a0c859971f884a89eb0fd0/GPCIS_supp.pdf -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 MinghaoZhou 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Interactive Segmentation as Gaussian Process Classification (CVPR2023 Highlight) 2 | Minghao Zhou, [Hong Wang](https://hongwang01.github.io/), Qian Zhao, Yuexiang Li, Yawen Huang, [Deyu Meng](http://gr.xjtu.edu.cn/web/dymeng), [Yefeng Zheng](https://sites.google.com/site/yefengzheng/) 3 | 4 | 5 | [[Paper]](https://openaccess.thecvf.com/content/CVPR2023/papers/Zhou_Interactive_Segmentation_As_Gaussion_Process_Classification_CVPR_2023_paper.pdf) [[Poster]](https://cvpr2023.thecvf.com/media/PosterPDFs/CVPR%202023/23088.png?t=1684895990.406102) [[Video]](https://youtu.be/mapyH-WujhY) [[Slides]](https://cvpr2023.thecvf.com/media/cvpr-2023/Slides/23088.pdf) [[Supp]](GPCIS_supp.pdf) 6 | 7 | ## Update 2023/8/1 8 | We have updated [is_gp_model.py](isegm/model/is_gp_model.py) for a more memory-efficient implementation. 9 | 10 | ## Usage 11 | Please first set up the environment and prepare the training (SBD) and testing (GrabCut, Berkeley, SBD, DAVIS) datasets following [RITM](https://github.com/saic-vul/ritm_interactive_segmentation), and change the directories in [config.yml](config.yml). 12 | 13 | Please run [run.sh](run.sh) for training/evaluation. For training, the ResNet-50 weights pretrained on ImageNet are used. Please download the [weights](https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet50_v1s-1762acc0.pth) and change the corresponding directory in [config.yml](config.yml). For evaluation, you can directly test with our provided checkpoint in [checkpoints/GPCIS_Resnet50.pth](checkpoints/GPCIS_Resnet50.pth). 14 | 15 | The core code of the GPCIS model can be found in [isegm/model/is_gp_model.py](isegm/model/is_gp_model.py) and [isegm/model/is_gp_resnet50.py](isegm/model/is_gp_resnet50.py). 16 |
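For reference, the snippet below sketches how the four test datasets can be instantiated from the paths in [config.yml](config.yml) with the classes in [isegm/data/datasets](isegm/data/datasets); it only illustrates the data interface and is not part of the released scripts.

```python
# Illustrative only: build the evaluation datasets from config.yml paths.
import yaml
from isegm.data.datasets import (BerkeleyDataset, DavisDataset,
                                 GrabCutDataset, SBDEvaluationDataset)

with open('config.yml') as f:
    cfg = yaml.safe_load(f)

eval_datasets = {
    'GrabCut': GrabCutDataset(cfg['GRABCUT_PATH']),
    'Berkeley': BerkeleyDataset(cfg['BERKELEY_PATH']),
    'SBD': SBDEvaluationDataset(cfg['SBD_PATH'], split='val'),
    'DAVIS': DavisDataset(cfg['DAVIS_PATH']),
}
```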
17 | ## Overview of GPCIS 18 | 19 | -------------------------------------------------------------------------------- /checkpoints/GPCIS_Resnet50.pth: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:1b59a4c8b7f1ea300abe8add185f8ab98beeeac01bd80bb2a2d029200195c3c9 3 | size 157934784 4 | -------------------------------------------------------------------------------- /config.yml: -------------------------------------------------------------------------------- 1 | INTERACTIVE_MODELS_PATH: " " 2 | EXPS_PATH: "./experiments" 3 | 4 | # Datasets 5 | GRABCUT_PATH: "path_to/GrabCut" 6 | BERKELEY_PATH: "path_to/Berkeley" 7 | DAVIS_PATH: "path_to/DAVIS" 8 | COCO_MVAL_PATH: "path_to/COCO_MVal" 9 | PASCALVOC_PATH: "path_to/VOC2012" 10 | DAVIS585_PATH: "path_to/Selected_480P" 11 | SBD_PATH: "path_to/SBD/dataset" 12 | 13 | # Pretrained weights 14 | IMAGENET_PRETRAINED_MODELS: 15 | RESNET50_v1s: "path_to/gluon_resnet50_v1s-1762acc0.pth" 16 | 17 | 18 | -------------------------------------------------------------------------------- /isegm/data/aligned_augmentation.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | from scipy.stats import truncnorm 4 | 5 | def get_truncated_normal(mean=0, sd=1, low=0, upp=10): 6 | return truncnorm( 7 | (low - mean) / sd, (upp - mean) / sd, loc=mean, scale=sd) 8 | 9 | #X1 = get_truncated_normal(mean=0.7, sd=0.3, low=0.2, upp=1) 10 | #x1 = X1.rvs(1)[0] 11 | 12 | 13 | class AlignedAugmentator: 14 | def __init__(self, ratio = [0.3,1], target_size = (256,256), flip = True, 15 | distribution = 'Uniform', gs_center = 0.8, gs_sd = 0.4, 16 | color_augmentator = None): 17 | ''' 18 | distribution belongs to ['Uniform', 'Gaussian'] 19 | ''' 20 | self.ratio = ratio 21 | self.target_size = target_size 22 | self.flip = flip 23 | self.distribution = distribution 24 | self.gaussian = get_truncated_normal(mean=gs_center, sd=gs_sd, low=ratio[0], upp=ratio[1]) 25 | self.color_augmentator = color_augmentator 26 | 27 | def __call__(self, image, mask): 28 | ''' 29 | image: np.array (267, 400, 3) np.uint8 30 | mask: np.array (267, 400, 1) np.int32 31 | ''' 32 | 33 | if self.distribution == 'Uniform': 34 | hr,wr = np.random.uniform(*self.ratio),np.random.uniform(*self.ratio) 35 | elif self.distribution == 'Gaussian': 36 | hr,wr = self.gaussian.rvs(2) 37 | 38 | H,W = image.shape[0], image.shape[1] 39 | h,w = int(H*hr), int(W*wr) 40 | if hr > 1 or wr > 1: 41 | image, mask = self.pad_image_mask(image, mask, hr, wr) 42 | H,W = image.shape[0], image.shape[1] 43 | 44 | y1 = np.random.randint(0,H-h) 45 | x1 = np.random.randint(0,W-w) 46 | y2 = y1 + h 47 | x2 = x1 + w 48 | 49 | image_crop = image[y1:y2,x1:x2,:] 50 | image_crop = cv2.resize(image_crop, tuple(self.target_size)) 51 | mask_crop = mask[y1:y2,x1:x2,:].astype(np.uint8) 52 | mask_crop = (cv2.resize(mask_crop, tuple(self.target_size))).astype(np.int32) 53 | if len(mask_crop.shape) == 2: 54 | mask_crop = np.expand_dims(mask_crop,-1) 55 | 56 | if self.flip: 57 | if np.random.rand() < 0.3: 58 | image_crop = np.flip(image_crop,0) 59 | mask_crop = np.flip(mask_crop,0) 60 | if np.random.rand() < 0.3: 61 | image_crop = np.flip(image_crop,1) 62 | mask_crop = np.flip(mask_crop,1) 63 | 64 | image_crop = np.ascontiguousarray(image_crop) 65 | mask_crop = np.ascontiguousarray(mask_crop) 66 | 67 | if self.color_augmentator is not None: 68 | image_crop = self.color_augmentator(image=image_crop)['image'] 69 | 70 | aug_output = {} 71 | aug_output['image'] = image_crop 72 | aug_output['mask'] = mask_crop 73 | return aug_output 74 | 75 | def pad_image_mask(self, image, mask, hr, wr): 76 | H,W = image.shape[0], image.shape[1] 77 | if hr > 1: 78 | new_h = int(H * hr) + 1 79 | pad_h = new_h - H 80 | pad_h1 = np.random.randint(0,pad_h) 81 | pad_h2 = pad_h - pad_h1 82 | image = np.pad(image, ((pad_h1, pad_h2),(0,0),(0,0)), 'constant') 83 | mask = np.pad(mask, ((pad_h1, pad_h2),(0,0),(0,0)), 'constant') 84 | 85 | if wr > 1: 86 | new_w = int(W * wr) + 1 87 | pad_w = new_w - W 88 | pad_w1 = np.random.randint(0,pad_w) 89 | pad_w2 = pad_w - pad_w1 90 | image = np.pad(image, ((0,0), (pad_w1, pad_w2),(0,0)), 'constant') 91 | mask = np.pad(mask, ( (0,0), (pad_w1, pad_w2),(0,0)), 'constant') 92 | return image, mask 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | --------------------------------------------------------------------------------
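A quick usage sketch for the `AlignedAugmentator` above (illustrative only, not part of the repository; the dummy shapes follow the docstring in `__call__`):

```python
# Minimal sketch: random crop-and-resize augmentation with AlignedAugmentator.
import numpy as np
from isegm.data.aligned_augmentation import AlignedAugmentator

augmentator = AlignedAugmentator(ratio=[0.3, 1.0], target_size=(256, 256),
                                 flip=True, distribution='Gaussian')

image = np.random.randint(0, 256, (267, 400, 3), dtype=np.uint8)  # dummy RGB image
mask = (np.random.rand(267, 400, 1) > 0.5).astype(np.int32)       # dummy binary mask

aug = augmentator(image, mask)
print(aug['image'].shape, aug['mask'].shape)  # (256, 256, 3) (256, 256, 1)
```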
/isegm/data/base.py: -------------------------------------------------------------------------------- 1 | import random 2 | import pickle 3 | import numpy as np 4 | import torch 5 | from torchvision import transforms 6 | from .points_sampler import MultiPointSampler 7 | from .sample import DSample 8 | import cv2 9 | from isegm.utils.crop_local import random_choose_target,get_bbox_from_mask,getLargestCC,expand_bbox, expand_bbox_with_bias 10 | from skimage import morphology 11 | # import pdb 12 | 13 | class ISDataset(torch.utils.data.dataset.Dataset): 14 | def __init__(self, 15 | augmentator=None, 16 | points_sampler=MultiPointSampler(max_num_points=12), 17 | min_object_area=0, 18 | keep_background_prob=0.0, 19 | with_image_info=False, 20 | samples_scores_path=None, 21 | samples_scores_gamma=1.0, 22 | epoch_len=-1, 23 | with_refiner = False): 24 | super(ISDataset, self).__init__() 25 | self.epoch_len = epoch_len 26 | self.augmentator = augmentator 27 | self.min_object_area = min_object_area 28 | self.keep_background_prob = keep_background_prob 29 | self.points_sampler = points_sampler 30 | self.with_image_info = with_image_info 31 | self.samples_precomputed_scores = self._load_samples_scores(samples_scores_path, samples_scores_gamma) 32 | self.to_tensor = transforms.ToTensor() 33 | self.with_refiner = with_refiner 34 | self.dataset_samples = None 35 | 36 | def __getitem__(self, index): 37 | while(1): 38 | try: 39 | if self.samples_precomputed_scores is not None: 40 | index = np.random.choice(self.samples_precomputed_scores['indices'], 41 | p=self.samples_precomputed_scores['probs']) 42 | else: 43 | if self.epoch_len > 0: 44 | index = random.randrange(0, len(self.dataset_samples)) 45 | 46 | sample = self.get_sample(index) 47 | sample = self.augment_sample(sample) 48 | sample.remove_small_objects(self.min_object_area) 49 | 50 | self.points_sampler.sample_object(sample) 51 | points = np.array(self.points_sampler.sample_points()) 52 | mask = self.points_sampler.selected_mask 53 | mask = self.remove_small_regions(mask) 54 | image = sample.image 55 | mask_area = mask[0].shape[0] * mask[0].shape[1] 56 | 57 | if self.with_refiner: 58 | trimap = self.get_trimap(mask[0]) 59 | if mask[0].sum() < 3600: # 60 * 60 60 | y1,x1,y2,x2 = self.sampling_roi_full_object(mask[0]) 61 | else: 62 | if np.random.rand() < 0.4: 63 | y1,x1,y2,x2 = self.sampling_roi_on_boundary(mask[0]) 64 | else: 65 | y1,x1,y2,x2 = self.sampling_roi_full_object(mask[0]) 66 | 67 | roi = torch.tensor([x1, y1, x2, y2]) 68 | h,w = mask[0].shape[0], mask[0].shape[1] 69 | image_focus = image[y1:y2,x1:x2,:] 70 | image_focus =
cv2.resize(image_focus, (h,w)) 71 | 72 | mask_255 = (mask[0] * 255).astype(np.uint8) 73 | mask_focus = mask_255[y1:y2,x1:x2] 74 | mask_focus = cv2.resize(mask_focus, (h,w)) > 128 75 | mask_focus = np.expand_dims(mask_focus,0).astype(np.float32) 76 | 77 | trimap_255 = (trimap[0] * 255).astype(np.uint8) 78 | trimap_focus = trimap_255[y1:y2,x1:x2] 79 | trimap_focus = cv2.resize(trimap_focus, (h,w)) > 128 80 | trimap_focus = np.expand_dims(trimap_focus,0).astype(np.float32) 81 | 82 | hc,wc = y2-y1, x2-x1 83 | ry,rx = h/hc, w/wc 84 | bias = np.array([y1,x1,0]) 85 | ratio = np.array([ry,rx,1]) 86 | points_focus = (points - bias) * ratio 87 | 88 | if mask.sum() > self.min_object_area and mask.sum() < mask_area * 0.85: 89 | 90 | output = { 91 | 'images': self.to_tensor(image), 92 | 'points': points.astype(np.float32), 93 | 'instances': mask, 94 | 'trimap':trimap, 95 | 'images_focus':self.to_tensor(image_focus), 96 | 'instances_focus':mask_focus, 97 | 'trimap_focus': trimap_focus, 98 | 'points_focus': points_focus.astype(np.float32), 99 | 'rois':roi.float() 100 | } 101 | 102 | if self.with_image_info: 103 | output['image_info'] = sample.sample_id 104 | return output 105 | else: 106 | index = np.random.randint(len(self.dataset_samples)-1) 107 | else: 108 | if mask.sum() > self.min_object_area and mask.sum() < mask_area * 0.85: 109 | output = { 110 | 'images': self.to_tensor(image), 111 | 'points': points.astype(np.float32), 112 | 'instances': mask, 113 | } 114 | 115 | if self.with_image_info: 116 | output['image_info'] = sample.sample_id 117 | return output 118 | else: 119 | index = np.random.randint(len(self.dataset_samples)-1) 120 | except: 121 | 122 | index = np.random.randint(len(self.dataset_samples)-1) 123 | 124 | # def __getitem__(self, index): 125 | # if self.samples_precomputed_scores is not None: 126 | # index = np.random.choice(self.samples_precomputed_scores['indices'], 127 | # p=self.samples_precomputed_scores['probs']) 128 | # else: 129 | # if self.epoch_len > 0: 130 | # index = random.randrange(0, len(self.dataset_samples)) 131 | 132 | # sample = self.get_sample(index) 133 | # sample = self.augment_sample(sample) 134 | # sample.remove_small_objects(self.min_object_area) 135 | 136 | # self.points_sampler.sample_object(sample) 137 | # points = np.array(self.points_sampler.sample_points()) 138 | # mask = self.points_sampler.selected_mask 139 | # output = { 140 | # 'images': self.to_tensor(sample.image), 141 | # 'points': points.astype(np.float32), 142 | # 'instances': mask 143 | # } 144 | 145 | # if self.with_image_info: 146 | # output['image_info'] = sample.sample_id 147 | 148 | # return output 149 | 150 | def remove_small_regions(self,mask): 151 | mask = mask[0] > 0.5 152 | mask = morphology.remove_small_objects(mask,min_size= 900) 153 | mask = np.expand_dims(mask,0).astype(np.float32) 154 | return mask 155 | 156 | 157 | def sampling_roi_full_object(self, gt_mask, min_size=32): 158 | max_mask = getLargestCC(gt_mask) 159 | y1,y2,x1,x2 = get_bbox_from_mask(max_mask) 160 | ratio = np.random.randint(11,17)/10 161 | y1,y2,x1,x2 = expand_bbox_with_bias(gt_mask,y1,y2,x1,x2,ratio,min_size,0.3) 162 | return y1,x1,y2,x2 163 | 164 | def sampling_roi_on_boundary(self,gt_mask): 165 | h,w = gt_mask.shape[0], gt_mask.shape[1] 166 | rh = np.random.randint(15,40)/10 167 | rw = np.random.randint(15,40)/10 168 | new_h,new_w = h/rh, w/rw 169 | crop_size = (int(new_h), int(new_w)) 170 | 171 | alpha = gt_mask > 0.5 172 | alpha = alpha.astype(np.uint8) 173 | kernel = np.ones((5,5),np.uint8) 174 | dilate = 
cv2.dilate(alpha,kernel,iterations = 1) 175 | boundary = np.logical_and( dilate, np.logical_not(alpha)) 176 | y1,x1,y2,x2 = random_choose_target(boundary,crop_size) 177 | return y1,x1,y2,x2 178 | 179 | 180 | def get_trimap(self, mask): 181 | h,w = mask.shape[0],mask.shape[1] 182 | hs,ws = h//8,w//8 183 | mask_255_big = (mask * 255).astype(np.uint8) 184 | mask_255_small = (cv2.resize(mask_255_big, (ws,hs)) > 128) * 255 185 | mask_resized = cv2.resize(mask_255_small.astype(np.uint8),(w,h)) > 128 186 | diff_mask = np.logical_xor(mask, mask_resized).astype(np.uint8) 187 | 188 | kernel = np.ones((3, 3), dtype=np.uint8) 189 | diff_mask = cv2.dilate(diff_mask, kernel, iterations=2) # iterations: the number of times the dilation is applied 190 | 191 | diff_mask = diff_mask.astype(np.float32) 192 | diff_mask = np.expand_dims(diff_mask,0) 193 | return diff_mask 194 | 195 | 196 | def augment_sample(self, sample) -> DSample: 197 | valid_augmentation = False 198 | while not valid_augmentation: 199 | sample.augment(self.augmentator) 200 | keep_sample = (self.keep_background_prob < 0.0 or 201 | random.random() < self.keep_background_prob) 202 | valid_augmentation = len(sample) > 0 or keep_sample 203 | 204 | return sample 205 | 206 | def get_sample(self, index) -> DSample: 207 | raise NotImplementedError 208 | 209 | def __len__(self): 210 | if self.epoch_len > 0: 211 | return self.epoch_len 212 | else: 213 | return self.get_samples_number() 214 | 215 | def get_samples_number(self): 216 | return len(self.dataset_samples) 217 | 218 | @staticmethod 219 | def _load_samples_scores(samples_scores_path, samples_scores_gamma): 220 | if samples_scores_path is None: 221 | return None 222 | 223 | with open(samples_scores_path, 'rb') as f: 224 | images_scores = pickle.load(f) 225 | 226 | probs = np.array([(1.0 - x[2]) ** samples_scores_gamma for x in images_scores]) 227 | probs /= probs.sum() 228 | samples_scores = { 229 | 'indices': [x[0] for x in images_scores], 230 | 'probs': probs 231 | } 232 | print(f'Loaded {len(probs)} weights with gamma={samples_scores_gamma}') 233 | return samples_scores 234 | -------------------------------------------------------------------------------- /isegm/data/compose.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from math import isclose 3 | from .base import ISDataset 4 | 5 | 6 | class ComposeDataset(ISDataset): 7 | def __init__(self, datasets, **kwargs): 8 | super(ComposeDataset, self).__init__(**kwargs) 9 | 10 | self._datasets = datasets 11 | self.dataset_samples = [] 12 | for dataset_indx, dataset in enumerate(self._datasets): 13 | self.dataset_samples.extend([(dataset_indx, i) for i in range(len(dataset))]) 14 | 15 | def get_sample(self, index): 16 | dataset_indx, sample_indx = self.dataset_samples[index] 17 | return self._datasets[dataset_indx].get_sample(sample_indx) 18 | 19 | 20 | class ProportionalComposeDataset(ISDataset): 21 | def __init__(self, datasets, ratios, **kwargs): 22 | super().__init__(**kwargs) 23 | 24 | assert len(ratios) == len(datasets),\ 25 | "The number of datasets must match the number of ratios" 26 | assert isclose(sum(ratios), 1.0),\ 27 | "The sum of ratios must be equal to 1" 28 | 29 | self._ratios = ratios 30 | self._datasets = datasets 31 | self.dataset_samples = [] 32 | for dataset_indx, dataset in enumerate(self._datasets): 33 | self.dataset_samples.extend([(dataset_indx, i) for i in range(len(dataset))]) 34 | 35 | def get_sample(self, index): 36 | dataset_indx = np.random.choice(len(self._datasets), p=self._ratios) 37 | sample_indx = np.random.choice(len(self._datasets[dataset_indx])) 38 | 39 | return self._datasets[dataset_indx].get_sample(sample_indx) 40 | --------------------------------------------------------------------------------
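`ProportionalComposeDataset` above draws each sample from a dataset chosen according to the given ratios, which lets training mix several datasets. A minimal sketch of how it might be used (illustrative only; the dataset paths are the placeholders from config.yml):

```python
# Minimal sketch: a 70/30 mixture of two datasets; ratios must sum to 1.0.
from isegm.data.compose import ProportionalComposeDataset
from isegm.data.datasets import DavisDataset, SBDDataset

sbd_train = SBDDataset('path_to/SBD/dataset', split='train')  # placeholder path
davis_train = DavisDataset('path_to/DAVIS')                   # placeholder path

mixed = ProportionalComposeDataset(
    [sbd_train, davis_train],
    ratios=[0.7, 0.3],
    epoch_len=30000,  # fixed number of samples per epoch (ISDataset kwarg)
)
sample = mixed.get_sample(0)  # the index is ignored; the dataset is chosen by ratio
```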
/isegm/data/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from isegm.data.compose import ComposeDataset, ProportionalComposeDataset 2 | from .berkeley import BerkeleyDataset 3 | from .davis import DavisDataset 4 | from .grabcut import GrabCutDataset 5 | from .sbd import SBDDataset, SBDEvaluationDataset 6 | -------------------------------------------------------------------------------- /isegm/data/datasets/berkeley.py: -------------------------------------------------------------------------------- 1 | from .grabcut import GrabCutDataset 2 | 3 | 4 | class BerkeleyDataset(GrabCutDataset): 5 | def __init__(self, dataset_path, **kwargs): 6 | super().__init__(dataset_path, images_dir_name='images', masks_dir_name='masks', **kwargs) 7 | self.name = 'Berkeley' 8 | -------------------------------------------------------------------------------- /isegm/data/datasets/davis.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import cv2 4 | import numpy as np 5 | 6 | from isegm.data.base import ISDataset 7 | from isegm.data.sample import DSample 8 | 9 | 10 | class DavisDataset(ISDataset): 11 | def __init__(self, dataset_path, 12 | images_dir_name='img', masks_dir_name='gt', 13 | init_mask_mode = None, **kwargs): 14 | super(DavisDataset, self).__init__(**kwargs) 15 | self.name = 'Davis' 16 | self.dataset_path = Path(dataset_path) 17 | self._images_path = self.dataset_path / images_dir_name 18 | self._insts_path = self.dataset_path / masks_dir_name 19 | 20 | self.dataset_samples = [x.name for x in sorted(self._images_path.glob('*.*'))] 21 | self._masks_paths = {x.stem: x for x in self._insts_path.glob('*.*')} 22 | self.init_mask_mode = init_mask_mode 23 | 24 | def get_sample(self, index) -> DSample: 25 | image_name = self.dataset_samples[index] 26 | image_path = str(self._images_path / image_name) 27 | mask_path = str(self._masks_paths[image_name.split('.')[0]]) 28 | 29 | image = cv2.imread(image_path) 30 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 31 | instances_mask = np.max(cv2.imread(mask_path).astype(np.int32), axis=2) 32 | instances_mask[instances_mask > 0] = 1 33 | 34 | init_mask = None 35 | 36 | return DSample(image, instances_mask, objects_ids=[1], sample_id=index, init_mask=init_mask) 37 | -------------------------------------------------------------------------------- /isegm/data/datasets/grabcut.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import cv2 4 | import numpy as np 5 | 6 | from isegm.data.base import ISDataset 7 | from isegm.data.sample import DSample 8 | 9 | 10 | class GrabCutDataset(ISDataset): 11 | def __init__(self, dataset_path, 12 | images_dir_name='data_GT', masks_dir_name='boundary_GT', 13 | **kwargs): 14 | super(GrabCutDataset, self).__init__(**kwargs) 15 | self.name = 'GrabCut' 16 | self.dataset_path = Path(dataset_path) 17 | self._images_path = self.dataset_path / images_dir_name 18 | self._insts_path = self.dataset_path / masks_dir_name 19 | 20 | self.dataset_samples = [x.name for x in sorted(self._images_path.glob('*.*'))] 21 | self._masks_paths = {x.stem: x for x in self._insts_path.glob('*.*')} 22 | 23 | def get_sample(self, index) ->
DSample: 24 | image_name = self.dataset_samples[index] 25 | image_path = str(self._images_path / image_name) 26 | mask_path = str(self._masks_paths[image_name.split('.')[0]]) 27 | 28 | image = cv2.imread(image_path) 29 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 30 | instances_mask = cv2.imread(mask_path)[:, :, 0].astype(np.int32) 31 | instances_mask[instances_mask == 128] = -1 32 | instances_mask[instances_mask > 128] = 1 33 | 34 | return DSample(image, instances_mask, objects_ids=[1], ignore_ids=[-1], sample_id=index) 35 | -------------------------------------------------------------------------------- /isegm/data/datasets/sbd.py: -------------------------------------------------------------------------------- 1 | import pickle as pkl 2 | from pathlib import Path 3 | 4 | import cv2 5 | import numpy as np 6 | from scipy.io import loadmat 7 | 8 | from isegm.utils.misc import get_bbox_from_mask, get_labels_with_sizes 9 | from isegm.data.base import ISDataset 10 | from isegm.data.sample import DSample 11 | 12 | 13 | class SBDDataset(ISDataset): 14 | def __init__(self, dataset_path, split='train', buggy_mask_thresh=0.08, **kwargs): 15 | super(SBDDataset, self).__init__(**kwargs) 16 | assert split in {'train', 'val'} 17 | self.name = 'SBD' 18 | self.dataset_path = Path(dataset_path) 19 | self.dataset_split = split 20 | self._images_path = self.dataset_path / 'img' 21 | self._insts_path = self.dataset_path / 'inst' 22 | self._buggy_objects = dict() 23 | self._buggy_mask_thresh = buggy_mask_thresh 24 | 25 | with open(self.dataset_path / f'{split}.txt', 'r') as f: 26 | self.dataset_samples = [x.strip() for x in f.readlines()] 27 | 28 | def get_sample(self, index): 29 | image_name = self.dataset_samples[index] 30 | image_path = str(self._images_path / f'{image_name}.jpg') 31 | inst_info_path = str(self._insts_path / f'{image_name}.mat') 32 | 33 | image = cv2.imread(image_path) 34 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 35 | instances_mask = loadmat(str(inst_info_path))['GTinst'][0][0][0].astype(np.int32) 36 | instances_mask = self.remove_buggy_masks(index, instances_mask) 37 | instances_ids, _ = get_labels_with_sizes(instances_mask) 38 | 39 | return DSample(image, instances_mask, objects_ids=instances_ids, sample_id=index) 40 | 41 | def remove_buggy_masks(self, index, instances_mask): 42 | if self._buggy_mask_thresh > 0.0: 43 | buggy_image_objects = self._buggy_objects.get(index, None) 44 | if buggy_image_objects is None: 45 | buggy_image_objects = [] 46 | instances_ids, _ = get_labels_with_sizes(instances_mask) 47 | for obj_id in instances_ids: 48 | obj_mask = instances_mask == obj_id 49 | mask_area = obj_mask.sum() 50 | bbox = get_bbox_from_mask(obj_mask) 51 | bbox_area = (bbox[1] - bbox[0] + 1) * (bbox[3] - bbox[2] + 1) 52 | obj_area_ratio = mask_area / bbox_area 53 | if obj_area_ratio < self._buggy_mask_thresh: 54 | buggy_image_objects.append(obj_id) 55 | 56 | self._buggy_objects[index] = buggy_image_objects 57 | for obj_id in buggy_image_objects: 58 | instances_mask[instances_mask == obj_id] = 0 59 | 60 | return instances_mask 61 | 62 | 63 | class SBDEvaluationDataset(ISDataset): 64 | def __init__(self, dataset_path, split='val', **kwargs): 65 | super(SBDEvaluationDataset, self).__init__(**kwargs) 66 | assert split in {'train', 'val'} 67 | 68 | self.dataset_path = Path(dataset_path) 69 | self.dataset_split = split 70 | self._images_path = self.dataset_path / 'img' 71 | self._insts_path = self.dataset_path / 'inst' 72 | 73 | with open(self.dataset_path / f'{split}.txt', 'r') as 
f: 74 | self.dataset_samples = [x.strip() for x in f.readlines()] 75 | 76 | self.dataset_samples = self.get_sbd_images_and_ids_list() 77 | 78 | def get_sample(self, index) -> DSample: 79 | image_name, instance_id = self.dataset_samples[index] 80 | image_path = str(self._images_path / f'{image_name}.jpg') 81 | inst_info_path = str(self._insts_path / f'{image_name}.mat') 82 | 83 | image = cv2.imread(image_path) 84 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 85 | instances_mask = loadmat(str(inst_info_path))['GTinst'][0][0][0].astype(np.int32) 86 | instances_mask[instances_mask != instance_id] = 0 87 | instances_mask[instances_mask > 0] = 1 88 | 89 | return DSample(image, instances_mask, objects_ids=[1], sample_id=index) 90 | 91 | def get_sbd_images_and_ids_list(self): 92 | pkl_path = self.dataset_path / f'{self.dataset_split}_images_and_ids_list.pkl' 93 | 94 | if pkl_path.exists(): 95 | with open(str(pkl_path), 'rb') as fp: 96 | images_and_ids_list = pkl.load(fp) 97 | else: 98 | images_and_ids_list = [] 99 | 100 | for sample in self.dataset_samples: 101 | inst_info_path = str(self._insts_path / f'{sample}.mat') 102 | instances_mask = loadmat(str(inst_info_path))['GTinst'][0][0][0].astype(np.int32) 103 | instances_ids, _ = get_labels_with_sizes(instances_mask) 104 | 105 | for instances_id in instances_ids: 106 | images_and_ids_list.append((sample, instances_id)) 107 | 108 | with open(str(pkl_path), 'wb') as fp: 109 | pkl.dump(images_and_ids_list, fp) 110 | 111 | return images_and_ids_list 112 | -------------------------------------------------------------------------------- /isegm/data/points_sampler.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import math 3 | import random 4 | import numpy as np 5 | from functools import lru_cache 6 | from .sample import DSample 7 | 8 | 9 | class BasePointSampler: 10 | def __init__(self): 11 | self._selected_mask = None 12 | self._selected_masks = None 13 | 14 | def sample_object(self, sample: DSample): 15 | raise NotImplementedError 16 | 17 | def sample_points(self): 18 | raise NotImplementedError 19 | 20 | @property 21 | def selected_mask(self): 22 | assert self._selected_mask is not None 23 | return self._selected_mask 24 | 25 | @selected_mask.setter 26 | def selected_mask(self, mask): 27 | self._selected_mask = mask[np.newaxis, :].astype(np.float32) 28 | 29 | 30 | class MultiPointSampler(BasePointSampler): 31 | def __init__(self, max_num_points, prob_gamma=0.7, expand_ratio=0.1, 32 | positive_erode_prob=0.9, positive_erode_iters=3, 33 | negative_bg_prob=0.1, negative_other_prob=0.4, negative_border_prob=0.5, 34 | merge_objects_prob=0.0, max_num_merged_objects=2, 35 | use_hierarchy=False, soft_targets=False, 36 | first_click_center=False, only_one_first_click=False, 37 | sfc_inner_k=1.7, sfc_full_inner_prob=0.0): 38 | super().__init__() 39 | self.max_num_points = max_num_points 40 | self.expand_ratio = expand_ratio 41 | self.positive_erode_prob = positive_erode_prob 42 | self.positive_erode_iters = positive_erode_iters 43 | self.merge_objects_prob = merge_objects_prob 44 | self.use_hierarchy = use_hierarchy 45 | self.soft_targets = soft_targets 46 | self.first_click_center = first_click_center 47 | self.only_one_first_click = only_one_first_click 48 | self.sfc_inner_k = sfc_inner_k 49 | self.sfc_full_inner_prob = sfc_full_inner_prob 50 | 51 | if max_num_merged_objects == -1: 52 | max_num_merged_objects = max_num_points 53 | self.max_num_merged_objects = max_num_merged_objects 54 | 55 | 
self.neg_strategies = ['bg', 'other', 'border'] 56 | self.neg_strategies_prob = [negative_bg_prob, negative_other_prob, negative_border_prob] 57 | assert math.isclose(sum(self.neg_strategies_prob), 1.0) 58 | 59 | self._pos_probs = generate_probs(max_num_points, gamma=prob_gamma) 60 | self._neg_probs = generate_probs(max_num_points + 1, gamma=prob_gamma) 61 | self._neg_masks = None 62 | 63 | def sample_object(self, sample: DSample): 64 | if len(sample) == 0: 65 | bg_mask = sample.get_background_mask() 66 | self.selected_mask = np.zeros_like(bg_mask, dtype=np.float32) 67 | self._selected_masks = [[]] 68 | self._neg_masks = {strategy: bg_mask for strategy in self.neg_strategies} 69 | self._neg_masks['required'] = [] 70 | return 71 | 72 | gt_mask, pos_masks, neg_masks = self._sample_mask(sample) 73 | binary_gt_mask = gt_mask > 0.5 if self.soft_targets else gt_mask > 0 74 | 75 | self.selected_mask = gt_mask 76 | self._selected_masks = pos_masks 77 | 78 | neg_mask_bg = np.logical_not(binary_gt_mask) 79 | neg_mask_border = self._get_border_mask(binary_gt_mask) 80 | if len(sample) <= len(self._selected_masks): 81 | neg_mask_other = neg_mask_bg 82 | else: 83 | neg_mask_other = np.logical_and(np.logical_not(sample.get_background_mask()), 84 | np.logical_not(binary_gt_mask)) 85 | 86 | self._neg_masks = { 87 | 'bg': neg_mask_bg, 88 | 'other': neg_mask_other, 89 | 'border': neg_mask_border, 90 | 'required': neg_masks 91 | } 92 | 93 | def _sample_mask(self, sample: DSample): 94 | root_obj_ids = sample.root_objects 95 | 96 | if len(root_obj_ids) > 1 and random.random() < self.merge_objects_prob: 97 | max_selected_objects = min(len(root_obj_ids), self.max_num_merged_objects) 98 | num_selected_objects = np.random.randint(2, max_selected_objects + 1) 99 | random_ids = random.sample(root_obj_ids, num_selected_objects) 100 | else: 101 | random_ids = [random.choice(root_obj_ids)] 102 | 103 | gt_mask = None 104 | pos_segments = [] 105 | neg_segments = [] 106 | for obj_id in random_ids: 107 | obj_gt_mask, obj_pos_segments, obj_neg_segments = self._sample_from_masks_layer(obj_id, sample) 108 | if gt_mask is None: 109 | gt_mask = obj_gt_mask 110 | else: 111 | gt_mask = np.maximum(gt_mask, obj_gt_mask) 112 | 113 | pos_segments.extend(obj_pos_segments) 114 | neg_segments.extend(obj_neg_segments) 115 | 116 | pos_masks = [self._positive_erode(x) for x in pos_segments] 117 | neg_masks = [self._positive_erode(x) for x in neg_segments] 118 | 119 | return gt_mask, pos_masks, neg_masks 120 | 121 | def _sample_from_masks_layer(self, obj_id, sample: DSample): 122 | objs_tree = sample._objects 123 | 124 | if not self.use_hierarchy: 125 | node_mask = sample.get_object_mask(obj_id) 126 | gt_mask = sample.get_soft_object_mask(obj_id) if self.soft_targets else node_mask 127 | return gt_mask, [node_mask], [] 128 | 129 | def _select_node(node_id): 130 | node_info = objs_tree[node_id] 131 | if not node_info['children'] or random.random() < 0.5: 132 | return node_id 133 | return _select_node(random.choice(node_info['children'])) 134 | 135 | selected_node = _select_node(obj_id) 136 | node_info = objs_tree[selected_node] 137 | node_mask = sample.get_object_mask(selected_node) 138 | gt_mask = sample.get_soft_object_mask(selected_node) if self.soft_targets else node_mask 139 | pos_mask = node_mask.copy() 140 | 141 | negative_segments = [] 142 | if node_info['parent'] is not None and node_info['parent'] in objs_tree: 143 | parent_mask = sample.get_object_mask(node_info['parent']) 144 | #negative_segments.append(np.logical_and(parent_mask, 
np.logical_not(node_mask))) 145 | 146 | for child_id in node_info['children']: 147 | if objs_tree[child_id]['area'] / node_info['area'] < 0.10: 148 | child_mask = sample.get_object_mask(child_id) 149 | pos_mask = np.logical_and(pos_mask, np.logical_not(child_mask)) 150 | 151 | if node_info['children']: 152 | max_disabled_children = min(len(node_info['children']), 3) 153 | num_disabled_children = np.random.randint(0, max_disabled_children + 1) 154 | disabled_children = random.sample(node_info['children'], num_disabled_children) 155 | 156 | for child_id in disabled_children: 157 | child_mask = sample.get_object_mask(child_id) 158 | pos_mask = np.logical_and(pos_mask, np.logical_not(child_mask)) 159 | if self.soft_targets: 160 | soft_child_mask = sample.get_soft_object_mask(child_id) 161 | gt_mask = np.minimum(gt_mask, 1.0 - soft_child_mask) 162 | else: 163 | gt_mask = np.logical_and(gt_mask, np.logical_not(child_mask)) 164 | #negative_segments.append(child_mask) 165 | 166 | return gt_mask, [pos_mask], negative_segments 167 | 168 | def sample_points(self): 169 | assert self._selected_mask is not None 170 | 171 | pos_points = self._multi_mask_sample_points(self._selected_masks, 172 | is_negative=[False] * len(self._selected_masks), 173 | with_first_click=self.first_click_center) 174 | 175 | neg_strategy = [(self._neg_masks[k], prob) 176 | for k, prob in zip(self.neg_strategies, self.neg_strategies_prob)] 177 | neg_masks = self._neg_masks['required'] + [neg_strategy] 178 | neg_points = self._multi_mask_sample_points(neg_masks, 179 | is_negative=[True] * len(self._neg_masks['required']) + [True]) 180 | #rint('selected :', len(self._selected_masks)) 181 | #print('neg_masks : ', len(neg_masks)) 182 | 183 | return pos_points + neg_points 184 | 185 | def _multi_mask_sample_points(self, selected_masks, is_negative, with_first_click=False): 186 | selected_masks = selected_masks[:self.max_num_points] 187 | 188 | each_obj_points = [ 189 | self._sample_points(mask, is_negative=is_negative[i], 190 | with_first_click=with_first_click) 191 | for i, mask in enumerate(selected_masks) 192 | ] 193 | each_obj_points = [x for x in each_obj_points if len(x) > 0] 194 | 195 | points = [] 196 | if len(each_obj_points) == 1: 197 | points = each_obj_points[0] 198 | elif len(each_obj_points) > 1: 199 | if self.only_one_first_click: 200 | each_obj_points = each_obj_points[:1] 201 | 202 | points = [obj_points[0] for obj_points in each_obj_points] 203 | 204 | aggregated_masks_with_prob = [] 205 | for indx, x in enumerate(selected_masks): 206 | if isinstance(x, (list, tuple)) and x and isinstance(x[0], (list, tuple)): 207 | for t, prob in x: 208 | aggregated_masks_with_prob.append((t, prob / len(selected_masks))) 209 | else: 210 | aggregated_masks_with_prob.append((x, 1.0 / len(selected_masks))) 211 | 212 | other_points_union = self._sample_points(aggregated_masks_with_prob, is_negative=True) 213 | if len(other_points_union) + len(points) <= self.max_num_points: 214 | points.extend(other_points_union) 215 | else: 216 | points.extend(random.sample(other_points_union, self.max_num_points - len(points))) 217 | 218 | if len(points) < self.max_num_points: 219 | points.extend([(-1, -1, -1)] * (self.max_num_points - len(points))) 220 | 221 | return points 222 | 223 | def _sample_points(self, mask, is_negative=False, with_first_click=False): 224 | if is_negative: 225 | num_points = np.random.choice(np.arange(self.max_num_points + 1), p=self._neg_probs) 226 | else: 227 | num_points = 1 + 
np.random.choice(np.arange(self.max_num_points), p=self._pos_probs) 228 | 229 | indices_probs = None 230 | if isinstance(mask, (list, tuple)): 231 | indices_probs = [x[1] for x in mask] 232 | indices = [(np.argwhere(x), prob) for x, prob in mask] 233 | if indices_probs: 234 | assert math.isclose(sum(indices_probs), 1.0) 235 | else: 236 | indices = np.argwhere(mask) 237 | 238 | points = [] 239 | for j in range(num_points): 240 | first_click = with_first_click and j == 0 and indices_probs is None 241 | 242 | if first_click: 243 | point_indices = get_point_candidates(mask, k=self.sfc_inner_k, full_prob=self.sfc_full_inner_prob) 244 | elif indices_probs: 245 | point_indices_indx = np.random.choice(np.arange(len(indices)), p=indices_probs) 246 | point_indices = indices[point_indices_indx][0] 247 | else: 248 | point_indices = indices 249 | 250 | num_indices = len(point_indices) 251 | if num_indices > 0: 252 | point_indx = 0 if first_click else 100 253 | click = point_indices[np.random.randint(0, num_indices)].tolist() + [point_indx] 254 | points.append(click) 255 | 256 | return points 257 | 258 | def _positive_erode(self, mask): 259 | if random.random() > self.positive_erode_prob: 260 | return mask 261 | 262 | kernel = np.ones((3, 3), np.uint8) 263 | eroded_mask = cv2.erode(mask.astype(np.uint8), 264 | kernel, iterations=self.positive_erode_iters).astype(np.bool) 265 | 266 | if eroded_mask.sum() > 10: 267 | return eroded_mask 268 | else: 269 | return mask 270 | 271 | def _get_border_mask(self, mask): 272 | expand_r = int(np.ceil(self.expand_ratio * np.sqrt(mask.sum()))) 273 | kernel = np.ones((3, 3), np.uint8) 274 | expanded_mask = cv2.dilate(mask.astype(np.uint8), kernel, iterations=expand_r) 275 | expanded_mask[mask.astype(np.bool)] = 0 276 | return expanded_mask 277 | 278 | 279 | @lru_cache(maxsize=None) 280 | def generate_probs(max_num_points, gamma): 281 | probs = [] 282 | last_value = 1 283 | for i in range(max_num_points): 284 | probs.append(last_value) 285 | last_value *= gamma 286 | 287 | probs = np.array(probs) 288 | probs /= probs.sum() 289 | 290 | return probs 291 | 292 | 293 | def get_point_candidates(obj_mask, k=1.7, full_prob=0.0): 294 | if full_prob > 0 and random.random() < full_prob: 295 | return obj_mask 296 | 297 | padded_mask = np.pad(obj_mask, ((1, 1), (1, 1)), 'constant') 298 | 299 | dt = cv2.distanceTransform(padded_mask.astype(np.uint8), cv2.DIST_L2, 0)[1:-1, 1:-1] 300 | if k > 0: 301 | inner_mask = dt > dt.max() / k 302 | return np.argwhere(inner_mask) 303 | else: 304 | prob_map = dt.flatten() 305 | prob_map /= max(prob_map.sum(), 1e-6) 306 | click_indx = np.random.choice(len(prob_map), p=prob_map) 307 | click_coords = np.unravel_index(click_indx, dt.shape) 308 | return np.array([click_coords]) 309 | -------------------------------------------------------------------------------- /isegm/data/sample.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from copy import deepcopy 3 | from isegm.utils.misc import get_labels_with_sizes 4 | from isegm.data.transforms import remove_image_only_transforms 5 | from albumentations import ReplayCompose 6 | 7 | class DSample: 8 | def __init__(self, image, encoded_masks, objects=None, 9 | objects_ids=None, ignore_ids=None, sample_id=None, 10 | init_mask = None): 11 | self.image = image 12 | self.sample_id = sample_id 13 | self.init_mask = init_mask 14 | 15 | if len(encoded_masks.shape) == 2: 16 | encoded_masks = encoded_masks[:, :, np.newaxis] 17 | self._encoded_masks = 
encoded_masks 18 | self._ignored_regions = [] 19 | 20 | if objects_ids is not None: 21 | if not objects_ids or not isinstance(objects_ids[0], tuple): 22 | assert encoded_masks.shape[2] == 1 23 | objects_ids = [(0, obj_id) for obj_id in objects_ids] 24 | 25 | self._objects = dict() 26 | for indx, obj_mapping in enumerate(objects_ids): 27 | self._objects[indx] = { 28 | 'parent': None, 29 | 'mapping': obj_mapping, 30 | 'children': [] 31 | } 32 | 33 | if ignore_ids: 34 | if isinstance(ignore_ids[0], tuple): 35 | self._ignored_regions = ignore_ids 36 | else: 37 | self._ignored_regions = [(0, region_id) for region_id in ignore_ids] 38 | else: 39 | self._objects = deepcopy(objects) 40 | 41 | self._augmented = False 42 | self._soft_mask_aug = None 43 | self._original_data = self.image, self._encoded_masks, deepcopy(self._objects) 44 | 45 | 46 | def augment(self, augmentator): 47 | self.reset_augmentation() 48 | aug_output = augmentator(image=self.image, mask=self._encoded_masks) 49 | image, mask = aug_output['image'],aug_output['mask'] 50 | self.image = image 51 | self._encoded_masks = mask 52 | self._compute_objects_areas() 53 | self.remove_small_objects(min_area=1) 54 | self._augmented = True 55 | 56 | 57 | def reset_augmentation(self): 58 | if not self._augmented: 59 | return 60 | orig_image, orig_masks, orig_objects = self._original_data 61 | self.image = orig_image 62 | self._encoded_masks = orig_masks 63 | self._objects = deepcopy(orig_objects) 64 | self._augmented = False 65 | self._soft_mask_aug = None 66 | 67 | def remove_small_objects(self, min_area): 68 | if self._objects and not 'area' in list(self._objects.values())[0]: 69 | self._compute_objects_areas() 70 | 71 | for obj_id, obj_info in list(self._objects.items()): 72 | if obj_info['area'] < min_area: 73 | self._remove_object(obj_id) 74 | 75 | def get_object_mask(self, obj_id): 76 | layer_indx, mask_id = self._objects[obj_id]['mapping'] 77 | obj_mask = (self._encoded_masks[:, :, layer_indx] == mask_id).astype(np.int32) 78 | if self._ignored_regions: 79 | for layer_indx, mask_id in self._ignored_regions: 80 | ignore_mask = self._encoded_masks[:, :, layer_indx] == mask_id 81 | obj_mask[ignore_mask] = -1 82 | 83 | return obj_mask 84 | 85 | def get_soft_object_mask(self, obj_id): 86 | assert self._soft_mask_aug is not None 87 | original_encoded_masks = self._original_data[1] 88 | layer_indx, mask_id = self._objects[obj_id]['mapping'] 89 | obj_mask = (original_encoded_masks[:, :, layer_indx] == mask_id).astype(np.float32) 90 | obj_mask = self._soft_mask_aug(image=obj_mask, mask=original_encoded_masks)['image'] 91 | return np.clip(obj_mask, 0, 1) 92 | 93 | def get_background_mask(self): 94 | return np.max(self._encoded_masks, axis=2) == 0 95 | 96 | @property 97 | def objects_ids(self): 98 | return list(self._objects.keys()) 99 | 100 | @property 101 | def gt_mask(self): 102 | assert len(self._objects) == 1 103 | return self.get_object_mask(self.objects_ids[0]) 104 | 105 | @property 106 | def root_objects(self): 107 | return [obj_id for obj_id, obj_info in self._objects.items() if obj_info['parent'] is None] 108 | 109 | def _compute_objects_areas(self): 110 | inverse_index = {node['mapping']: node_id for node_id, node in self._objects.items()} 111 | ignored_regions_keys = set(self._ignored_regions) 112 | 113 | for layer_indx in range(self._encoded_masks.shape[2]): 114 | objects_ids, objects_areas = get_labels_with_sizes(self._encoded_masks[:, :, layer_indx]) 115 | for obj_id, obj_area in zip(objects_ids, objects_areas): 116 | inv_key = 
(layer_indx, obj_id) 117 | if inv_key in ignored_regions_keys: 118 | continue 119 | try: 120 | self._objects[inverse_index[inv_key]]['area'] = obj_area 121 | del inverse_index[inv_key] 122 | except KeyError: 123 | layer = self._encoded_masks[:, :, layer_indx] 124 | layer[layer == obj_id] = 0 125 | self._encoded_masks[:, :, layer_indx] = layer 126 | 127 | for obj_id in inverse_index.values(): 128 | self._objects[obj_id]['area'] = 0 129 | 130 | def _remove_object(self, obj_id): 131 | obj_info = self._objects[obj_id] 132 | obj_parent = obj_info['parent'] 133 | for child_id in obj_info['children']: 134 | self._objects[child_id]['parent'] = obj_parent 135 | 136 | if obj_parent is not None: 137 | parent_children = self._objects[obj_parent]['children'] 138 | parent_children = [x for x in parent_children if x != obj_id] 139 | self._objects[obj_parent]['children'] = parent_children + obj_info['children'] 140 | 141 | del self._objects[obj_id] 142 | 143 | def __len__(self): 144 | return len(self._objects) 145 | -------------------------------------------------------------------------------- /isegm/data/transforms.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import random 3 | import numpy as np 4 | 5 | from albumentations.core.serialization import SERIALIZABLE_REGISTRY 6 | from albumentations import ImageOnlyTransform, DualTransform 7 | from albumentations.core.transforms_interface import to_tuple 8 | from albumentations.augmentations import functional as F 9 | from isegm.utils.misc import get_bbox_from_mask, expand_bbox, clamp_bbox, get_labels_with_sizes 10 | 11 | 12 | class UniformRandomResize(DualTransform): 13 | def __init__(self, scale_range=(0.9, 1.1), interpolation=cv2.INTER_LINEAR, always_apply=False, p=1): 14 | super().__init__(always_apply, p) 15 | self.scale_range = scale_range 16 | self.interpolation = interpolation 17 | 18 | def get_params_dependent_on_targets(self, params): 19 | scale = random.uniform(*self.scale_range) 20 | height = int(round(params['image'].shape[0] * scale)) 21 | width = int(round(params['image'].shape[1] * scale)) 22 | return {'new_height': height, 'new_width': width} 23 | 24 | def apply(self, img, new_height=0, new_width=0, interpolation=cv2.INTER_LINEAR, **params): 25 | return F.resize(img, height=new_height, width=new_width, interpolation=interpolation) 26 | 27 | def apply_to_keypoint(self, keypoint, new_height=0, new_width=0, **params): 28 | scale_x = new_width / params["cols"] 29 | scale_y = new_height / params["rows"] 30 | return F.keypoint_scale(keypoint, scale_x, scale_y) 31 | 32 | def get_transform_init_args_names(self): 33 | return "scale_range", "interpolation" 34 | 35 | @property 36 | def targets_as_params(self): 37 | return ["image"] 38 | 39 | 40 | class ZoomIn(DualTransform): 41 | def __init__( 42 | self, 43 | height, 44 | width, 45 | bbox_jitter=0.1, 46 | expansion_ratio=1.4, 47 | min_crop_size=200, 48 | min_area=100, 49 | always_resize=False, 50 | always_apply=False, 51 | p=0.5, 52 | ): 53 | super(ZoomIn, self).__init__(always_apply, p) 54 | self.height = height 55 | self.width = width 56 | self.bbox_jitter = to_tuple(bbox_jitter) 57 | self.expansion_ratio = expansion_ratio 58 | self.min_crop_size = min_crop_size 59 | self.min_area = min_area 60 | self.always_resize = always_resize 61 | 62 | def apply(self, img, selected_object, bbox, **params): 63 | if selected_object is None: 64 | if self.always_resize: 65 | img = F.resize(img, height=self.height, width=self.width) 66 | return img 67 | 68 | 
rmin, rmax, cmin, cmax = bbox 69 | img = img[rmin:rmax + 1, cmin:cmax + 1] 70 | img = F.resize(img, height=self.height, width=self.width) 71 | 72 | return img 73 | 74 | def apply_to_mask(self, mask, selected_object, bbox, **params): 75 | if selected_object is None: 76 | if self.always_resize: 77 | mask = F.resize(mask, height=self.height, width=self.width, 78 | interpolation=cv2.INTER_NEAREST) 79 | return mask 80 | 81 | rmin, rmax, cmin, cmax = bbox 82 | mask = mask[rmin:rmax + 1, cmin:cmax + 1] 83 | if isinstance(selected_object, tuple): 84 | layer_indx, mask_id = selected_object 85 | obj_mask = mask[:, :, layer_indx] == mask_id 86 | new_mask = np.zeros_like(mask) 87 | new_mask[:, :, layer_indx][obj_mask] = mask_id 88 | else: 89 | obj_mask = mask == selected_object 90 | new_mask = mask.copy() 91 | new_mask[np.logical_not(obj_mask)] = 0 92 | 93 | new_mask = F.resize(new_mask, height=self.height, width=self.width, 94 | interpolation=cv2.INTER_NEAREST) 95 | return new_mask 96 | 97 | def get_params_dependent_on_targets(self, params): 98 | instances = params['mask'] 99 | 100 | is_mask_layer = len(instances.shape) > 2 101 | candidates = [] 102 | if is_mask_layer: 103 | for layer_indx in range(instances.shape[2]): 104 | labels, areas = get_labels_with_sizes(instances[:, :, layer_indx]) 105 | candidates.extend([(layer_indx, obj_id) 106 | for obj_id, area in zip(labels, areas) 107 | if area > self.min_area]) 108 | else: 109 | labels, areas = get_labels_with_sizes(instances) 110 | candidates = [obj_id for obj_id, area in zip(labels, areas) 111 | if area > self.min_area] 112 | 113 | selected_object = None 114 | bbox = None 115 | if candidates: 116 | selected_object = random.choice(candidates) 117 | if is_mask_layer: 118 | layer_indx, mask_id = selected_object 119 | obj_mask = instances[:, :, layer_indx] == mask_id 120 | else: 121 | obj_mask = instances == selected_object 122 | 123 | bbox = get_bbox_from_mask(obj_mask) 124 | 125 | if isinstance(self.expansion_ratio, tuple): 126 | expansion_ratio = random.uniform(*self.expansion_ratio) 127 | else: 128 | expansion_ratio = self.expansion_ratio 129 | 130 | bbox = expand_bbox(bbox, expansion_ratio, self.min_crop_size) 131 | bbox = self._jitter_bbox(bbox) 132 | bbox = clamp_bbox(bbox, 0, obj_mask.shape[0] - 1, 0, obj_mask.shape[1] - 1) 133 | 134 | return { 135 | 'selected_object': selected_object, 136 | 'bbox': bbox 137 | } 138 | 139 | def _jitter_bbox(self, bbox): 140 | rmin, rmax, cmin, cmax = bbox 141 | height = rmax - rmin + 1 142 | width = cmax - cmin + 1 143 | rmin = int(rmin + random.uniform(*self.bbox_jitter) * height) 144 | rmax = int(rmax + random.uniform(*self.bbox_jitter) * height) 145 | cmin = int(cmin + random.uniform(*self.bbox_jitter) * width) 146 | cmax = int(cmax + random.uniform(*self.bbox_jitter) * width) 147 | 148 | return rmin, rmax, cmin, cmax 149 | 150 | def apply_to_bbox(self, bbox, **params): 151 | raise NotImplementedError 152 | 153 | def apply_to_keypoint(self, keypoint, **params): 154 | raise NotImplementedError 155 | 156 | @property 157 | def targets_as_params(self): 158 | return ["mask"] 159 | 160 | def get_transform_init_args_names(self): 161 | return ("height", "width", "bbox_jitter", 162 | "expansion_ratio", "min_crop_size", "min_area", "always_resize") 163 | 164 | 165 | def remove_image_only_transforms(sdict): 166 | if not 'transforms' in sdict: 167 | return sdict 168 | 169 | keep_transforms = [] 170 | for tdict in sdict['transforms']: 171 | cls = SERIALIZABLE_REGISTRY[tdict['__class_fullname__']] 172 | if 'transforms' in 
tdict: 173 | keep_transforms.append(remove_image_only_transforms(tdict)) 174 | elif not issubclass(cls, ImageOnlyTransform): 175 | keep_transforms.append(tdict) 176 | sdict['transforms'] = keep_transforms 177 | 178 | return sdict 179 | -------------------------------------------------------------------------------- /isegm/engine/optimizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | from isegm.utils.log import logger 4 | 5 | 6 | def get_optimizer(model, opt_name, opt_kwargs): 7 | params = [] 8 | base_lr = opt_kwargs['lr'] 9 | for name, param in model.named_parameters(): 10 | param_group = {'params': [param]} 11 | if not param.requires_grad: 12 | params.append(param_group) 13 | continue 14 | 15 | if not math.isclose(getattr(param, 'lr_mult', 1.0), 1.0): 16 | logger.info(f'Applied lr_mult={param.lr_mult} to "{name}" parameter.') 17 | param_group['lr'] = param_group.get('lr', base_lr) * param.lr_mult 18 | 19 | params.append(param_group) 20 | 21 | optimizer = { 22 | 'sgd': torch.optim.SGD, 23 | 'adam': torch.optim.Adam, 24 | 'adamw': torch.optim.AdamW 25 | }[opt_name.lower()](params, **opt_kwargs) 26 | 27 | return optimizer 28 | -------------------------------------------------------------------------------- /isegm/inference/clicker.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from copy import deepcopy 3 | import cv2 4 | 5 | 6 | class Clicker(object): 7 | def __init__(self, gt_mask=None, init_clicks=None, ignore_label=-1, click_indx_offset=0): 8 | self.click_indx_offset = click_indx_offset 9 | if gt_mask is not None: 10 | self.gt_mask = gt_mask == 1 11 | self.not_ignore_mask = gt_mask != ignore_label 12 | else: 13 | self.gt_mask = None 14 | 15 | self.reset_clicks() 16 | 17 | if init_clicks is not None: 18 | for click in init_clicks: 19 | self.add_click(click) 20 | 21 | def make_next_click(self, pred_mask): 22 | assert self.gt_mask is not None 23 | click = self._get_next_click(pred_mask) 24 | self.add_click(click) 25 | 26 | def get_clicks(self, clicks_limit=None): 27 | return self.clicks_list[:clicks_limit] 28 | 29 | def _get_next_click(self, pred_mask, padding=True): 30 | fn_mask = np.logical_and(np.logical_and(self.gt_mask, np.logical_not(pred_mask)), self.not_ignore_mask) 31 | fp_mask = np.logical_and(np.logical_and(np.logical_not(self.gt_mask), pred_mask), self.not_ignore_mask) 32 | 33 | if padding: 34 | fn_mask = np.pad(fn_mask, ((1, 1), (1, 1)), 'constant') 35 | fp_mask = np.pad(fp_mask, ((1, 1), (1, 1)), 'constant') 36 | 37 | fn_mask_dt = cv2.distanceTransform(fn_mask.astype(np.uint8), cv2.DIST_L2, 0) 38 | fp_mask_dt = cv2.distanceTransform(fp_mask.astype(np.uint8), cv2.DIST_L2, 0) 39 | 40 | if padding: 41 | fn_mask_dt = fn_mask_dt[1:-1, 1:-1] 42 | fp_mask_dt = fp_mask_dt[1:-1, 1:-1] 43 | 44 | fn_mask_dt = fn_mask_dt * self.not_clicked_map 45 | fp_mask_dt = fp_mask_dt * self.not_clicked_map 46 | 47 | fn_max_dist = np.max(fn_mask_dt) 48 | fp_max_dist = np.max(fp_mask_dt) 49 | 50 | is_positive = fn_max_dist > fp_max_dist 51 | if is_positive: 52 | coords_y, coords_x = np.where(fn_mask_dt == fn_max_dist) # coords is [y, x] 53 | else: 54 | coords_y, coords_x = np.where(fp_mask_dt == fp_max_dist) # coords is [y, x] 55 | 56 | return Click(is_positive=is_positive, coords=(coords_y[0], coords_x[0])) 57 | 58 | def add_click(self, click): 59 | coords = click.coords 60 | 61 | click.indx = self.click_indx_offset + self.num_pos_clicks + self.num_neg_clicks 62 
| if click.is_positive: 63 | self.num_pos_clicks += 1 64 | else: 65 | self.num_neg_clicks += 1 66 | 67 | self.clicks_list.append(click) 68 | if self.gt_mask is not None: 69 | self.not_clicked_map[coords[0], coords[1]] = False 70 | 71 | def _remove_last_click(self): 72 | click = self.clicks_list.pop() 73 | coords = click.coords 74 | 75 | if click.is_positive: 76 | self.num_pos_clicks -= 1 77 | else: 78 | self.num_neg_clicks -= 1 79 | 80 | if self.gt_mask is not None: 81 | self.not_clicked_map[coords[0], coords[1]] = True 82 | 83 | def reset_clicks(self): 84 | if self.gt_mask is not None: 85 | self.not_clicked_map = np.ones_like(self.gt_mask, dtype=np.bool) 86 | 87 | self.num_pos_clicks = 0 88 | self.num_neg_clicks = 0 89 | 90 | self.clicks_list = [] 91 | 92 | def get_state(self): 93 | return deepcopy(self.clicks_list) 94 | 95 | def set_state(self, state): 96 | self.reset_clicks() 97 | for click in state: 98 | self.add_click(click) 99 | 100 | def __len__(self): 101 | return len(self.clicks_list) 102 | 103 | 104 | class Click: 105 | def __init__(self, is_positive, coords, indx=None): 106 | self.is_positive = is_positive 107 | self.coords = coords 108 | self.indx = indx 109 | 110 | @property 111 | def coords_and_indx(self): 112 | return (*self.coords, self.indx) 113 | 114 | def copy(self, **kwargs): 115 | self_copy = deepcopy(self) 116 | for k, v in kwargs.items(): 117 | setattr(self_copy, k, v) 118 | return self_copy 119 | -------------------------------------------------------------------------------- /isegm/inference/evaluation.py: -------------------------------------------------------------------------------- 1 | from time import time 2 | 3 | import numpy as np 4 | import torch 5 | import os 6 | from isegm.inference import utils 7 | from isegm.inference.clicker import Clicker 8 | import shutil 9 | import cv2 10 | from isegm.utils.vis import add_tag 11 | 12 | 13 | 14 | try: 15 | get_ipython() 16 | from tqdm import tqdm_notebook as tqdm 17 | except NameError: 18 | from tqdm import tqdm 19 | 20 | 21 | def evaluate_dataset(dataset, predictor, vis = True, vis_path = './experiments/vis_val/',**kwargs): 22 | all_ious = [] 23 | if vis: 24 | save_dir = vis_path + dataset.name + '/' 25 | if os.path.exists(save_dir): 26 | shutil.rmtree(save_dir) 27 | os.makedirs(save_dir) 28 | else: 29 | save_dir = None 30 | 31 | start_time = time() 32 | for index in tqdm(range(len(dataset)), leave=False): 33 | sample = dataset.get_sample(index) 34 | 35 | _, sample_ious, _ = evaluate_sample(sample.image, sample.gt_mask, sample.init_mask, predictor, 36 | sample_id=index, vis= vis, save_dir = save_dir, 37 | index = index, **kwargs) 38 | all_ious.append(sample_ious) 39 | end_time = time() 40 | elapsed_time = end_time - start_time 41 | 42 | return all_ious, elapsed_time 43 | 44 | def Progressive_Merge(pred_mask, previous_mask, y, x): 45 | diff_regions = np.logical_xor(previous_mask, pred_mask) 46 | num, labels = cv2.connectedComponents(diff_regions.astype(np.uint8)) 47 | label = labels[y,x] 48 | corr_mask = labels == label 49 | if previous_mask[y,x] == 1: 50 | progressive_mask = np.logical_and( previous_mask, np.logical_not(corr_mask)) 51 | else: 52 | progressive_mask = np.logical_or( previous_mask, corr_mask) 53 | return progressive_mask 54 | 55 | 56 | def evaluate_sample(image, gt_mask, init_mask, predictor, max_iou_thr, 57 | pred_thr=0.49, min_clicks=1, max_clicks=20, 58 | sample_id=None, vis = True, save_dir = None, index = 0, callback=None, 59 | progressive_mode = True, 60 | ): 61 | clicker = 
Clicker(gt_mask=gt_mask) 62 | pred_mask = np.zeros_like(gt_mask) 63 | prev_mask = pred_mask 64 | ious_list = [] 65 | 66 | with torch.no_grad(): 67 | predictor.set_input_image(image) 68 | if init_mask is not None: 69 | predictor.set_prev_mask(init_mask) 70 | pred_mask = init_mask 71 | prev_mask = init_mask 72 | num_pm = 0 73 | else: 74 | num_pm = 999 75 | 76 | for click_indx in range(max_clicks): 77 | vis_pred = prev_mask 78 | clicker.make_next_click(pred_mask) 79 | pred_probs = predictor.get_prediction(clicker) 80 | pred_mask = pred_probs > pred_thr 81 | 82 | if progressive_mode: 83 | clicks = clicker.get_clicks() 84 | if len(clicks) >= num_pm: 85 | last_click = clicks[-1] 86 | last_y, last_x = last_click.coords[0], last_click.coords[1] 87 | pred_mask = Progressive_Merge(pred_mask, prev_mask,last_y, last_x) 88 | predictor.transforms[0]._prev_probs = np.expand_dims(np.expand_dims(pred_mask,0),0) 89 | if callback is not None: 90 | callback(image, gt_mask, pred_probs, sample_id, click_indx, clicker.clicks_list) 91 | 92 | iou = utils.get_iou(gt_mask, pred_mask) 93 | ious_list.append(iou) 94 | prev_mask = pred_mask 95 | 96 | if iou >= max_iou_thr and click_indx + 1 >= min_clicks: 97 | break 98 | 99 | if vis: 100 | clicks_list = clicker.get_clicks() 101 | last_y, last_x = predictor.last_y, predictor.last_x 102 | out_image = vis_result_base(image, pred_mask, gt_mask, init_mask, iou,click_indx+1,clicks_list, vis_pred, last_y, last_x) 103 | cv2.imwrite(save_dir+str(index)+'.png', out_image) 104 | return clicker.clicks_list, np.array(ious_list, dtype=np.float32), pred_probs 105 | 106 | 107 | def vis_result_base(image, pred_mask, instances_mask, init_mask, iou, num_clicks, clicks_list, prev_prediction, last_y, last_x): 108 | 109 | image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) 110 | 111 | pred_mask = pred_mask.astype(np.float32) 112 | prev_mask = prev_prediction.astype(np.float32) 113 | instances_mask = instances_mask.astype(np.float32) 114 | image = image.astype(np.float32) 115 | 116 | pred_mask_3 = np.repeat(pred_mask[...,np.newaxis],3,2) 117 | prev_mask_3 = np.repeat(prev_mask[...,np.newaxis],3,2) 118 | gt_mask_3 = np.repeat( instances_mask[...,np.newaxis],3,2 ) 119 | 120 | color_mask_gt = np.zeros_like(pred_mask_3) 121 | color_mask_gt[:,:,0] = instances_mask * 255 122 | 123 | color_mask_pred = np.zeros_like(pred_mask_3) #+ 255 124 | color_mask_pred[:,:,0] = pred_mask * 255 125 | 126 | color_mask_prev = np.zeros_like(prev_mask_3) #+ 255 127 | color_mask_prev[:,:,0] = prev_mask * 255 128 | 129 | 130 | fusion_pred = image * 0.4 + color_mask_pred * 0.6 131 | fusion_pred = image * (1-pred_mask_3) + fusion_pred * pred_mask_3 132 | 133 | fusion_prev = image * 0.4 + color_mask_prev * 0.6 134 | fusion_prev = image * (1-prev_mask_3) + fusion_prev * prev_mask_3 135 | 136 | 137 | fusion_gt = image * 0.4 + color_mask_gt * 0.6 138 | 139 | color_mask_init = np.zeros_like(pred_mask_3) 140 | if init_mask is not None: 141 | color_mask_init[:,:,0] = init_mask * 255 142 | 143 | fusion_init = image * 0.4 + color_mask_init * 0.6 144 | fusion_init = image * (1-color_mask_init) + fusion_init * color_mask_init 145 | 146 | 147 | #cv2.putText( image, 'click num: '+str(num_clicks)+ ' iou: '+ str(round(iou,3)), (50,50), 148 | # cv2.FONT_HERSHEY_COMPLEX, 1, (255, 255, 255 ), 1 ) 149 | 150 | for i in range(len(clicks_list)): 151 | click_tuple = clicks_list[i] 152 | 153 | if click_tuple.is_positive: 154 | color = (0,0,255) 155 | else: 156 | color = (0,255,0) 157 | 158 | coord = click_tuple.coords 159 | x,y = coord[1], coord[0] 
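# Clicks are stored as (row, col) = (y, x); they are swapped here because cv2.circle expects (x, y).
# Positive clicks are drawn in red and negative clicks in green (OpenCV channel order is BGR);
# clicks that were mapped outside the image (negative coords) are skipped below.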
160 | if x < 0 or y< 0: 161 | continue 162 | cv2.circle(fusion_pred,(x,y),4,color,-1) 163 | #cv2.putText(fusion_pred, str(i+1), (x-10, y-10), cv2.FONT_HERSHEY_COMPLEX, 0.6 , color,1 ) 164 | 165 | cv2.circle(fusion_pred,(last_x,last_y),2,(255,255,255),-1) 166 | image = add_tag(image, 'nclicks:'+str(num_clicks)+ ' iou:'+ str(round(iou,3))) 167 | fusion_init = add_tag(fusion_init,'init mask') 168 | fusion_pred = add_tag(fusion_pred,'pred') 169 | fusion_gt = add_tag(fusion_gt,'gt') 170 | fusion_prev = add_tag(fusion_prev,'prev pred') 171 | 172 | h,w = image.shape[0],image.shape[1] 173 | if h < w: 174 | out_image = cv2.hconcat([image.astype(np.float32),fusion_init.astype(np.float32),fusion_pred.astype(np.float32), fusion_gt.astype(np.float32),fusion_prev.astype(np.float32)]) 175 | else: 176 | out_image = cv2.hconcat([image.astype(np.float32),fusion_init.astype(np.float32), fusion_pred.astype(np.float32), fusion_gt.astype(np.float32),fusion_prev.astype(np.float32)]) 177 | 178 | return out_image 179 | 180 | -------------------------------------------------------------------------------- /isegm/inference/predictors/__init__.py: -------------------------------------------------------------------------------- 1 | from .baseline import BaselinePredictor 2 | from isegm.inference.transforms import ZoomIn 3 | 4 | 5 | 6 | def get_predictor(net, brs_mode, device, 7 | prob_thresh=0.49, 8 | infer_size = 256, 9 | focus_crop_r= 1.4, 10 | with_flip=False, 11 | zoom_in_params=dict(), 12 | predictor_params=None, 13 | brs_opt_func_params=None, 14 | lbfgs_params=None): 15 | 16 | predictor_params_ = { 17 | 'optimize_after_n_clicks': 1 18 | } 19 | 20 | if zoom_in_params is not None: 21 | zoom_in = ZoomIn(**zoom_in_params) 22 | else: 23 | zoom_in = None 24 | 25 | if predictor_params is not None: 26 | predictor_params_.update(predictor_params) 27 | predictor = BaselinePredictor(net, device, zoom_in=zoom_in, with_flip=with_flip, infer_size =infer_size, **predictor_params_) 28 | 29 | 30 | 31 | return predictor 32 | -------------------------------------------------------------------------------- /isegm/inference/predictors/baseline.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torchvision import transforms 4 | from isegm.inference.transforms import AddHorizontalFlip, SigmoidForPred, LimitLongestSide, ResizeTrans 5 | 6 | 7 | class BaselinePredictor(object): 8 | def __init__(self, model, device, 9 | net_clicks_limit=None, 10 | with_flip=False, 11 | zoom_in=None, 12 | max_size=None, 13 | infer_size = 384, 14 | **kwargs): 15 | self.with_flip = with_flip 16 | self.net_clicks_limit = net_clicks_limit 17 | self.original_image = None 18 | self.device = device 19 | self.zoom_in = zoom_in 20 | self.prev_prediction = None 21 | self.model_indx = 0 22 | self.click_models = None 23 | self.net_state_dict = None 24 | 25 | if isinstance(model, tuple): 26 | self.net, self.click_models = model 27 | else: 28 | self.net = model 29 | 30 | self.to_tensor = transforms.ToTensor() 31 | 32 | self.transforms = [zoom_in] if zoom_in is not None else [] 33 | if max_size is not None: 34 | self.transforms.append(LimitLongestSide(max_size=max_size)) 35 | self.crop_l = infer_size 36 | self.transforms.append(ResizeTrans(self.crop_l)) 37 | self.transforms.append(SigmoidForPred()) 38 | self.focus_roi = None 39 | self.global_roi = None 40 | self.with_flip = True 41 | if hasattr(self.net, 'set_status'): 42 | self.net.set_status(training=False) 43 | 44 | def 
set_input_image(self, image): 45 | image_nd = self.to_tensor(image) 46 | for transform in self.transforms: 47 | transform.reset() 48 | self.original_image = image_nd.to(self.device) 49 | if len(self.original_image.shape) == 3: 50 | self.original_image = self.original_image.unsqueeze(0) 51 | self.prev_prediction = torch.zeros_like(self.original_image[:, :1, :, :]) 52 | 53 | def set_prev_mask(self, mask): 54 | self.prev_prediction = torch.from_numpy(mask).unsqueeze(0).unsqueeze(0).to(self.device).float() 55 | 56 | def get_prediction(self, clicker, prev_mask=None): 57 | clicks_list = clicker.get_clicks() 58 | click = clicks_list[-1] 59 | last_y,last_x = click.coords[0],click.coords[1] 60 | self.last_y = last_y 61 | self.last_x = last_x 62 | 63 | if self.click_models is not None: 64 | model_indx = min(clicker.click_indx_offset + len(clicks_list), len(self.click_models)) - 1 65 | if model_indx != self.model_indx: 66 | self.model_indx = model_indx 67 | self.net = self.click_models[model_indx] 68 | 69 | input_image = self.original_image 70 | if prev_mask is None: 71 | prev_mask = self.prev_prediction 72 | if hasattr(self.net, 'with_prev_mask') and self.net.with_prev_mask: 73 | input_image = torch.cat((input_image, prev_mask), dim=1) 74 | 75 | 76 | image_nd, clicks_lists, is_image_changed = self.apply_transforms( 77 | input_image, [clicks_list] 78 | ) 79 | 80 | pred_logits = self._get_prediction(image_nd, clicks_lists, is_image_changed) 81 | prediction = F.interpolate(pred_logits, mode='bilinear', align_corners=True, 82 | size=image_nd.size()[2:]) 83 | 84 | for t in reversed(self.transforms): 85 | prediction = t.inv_transform(prediction) 86 | 87 | self.prev_prediction = prediction 88 | return prediction.cpu().numpy()[0, 0] 89 | 90 | def _get_prediction(self, image_nd, clicks_lists, is_image_changed): 91 | points_nd = self.get_points_nd(clicks_lists) 92 | output = self.net(image_nd, points_nd) 93 | return output['instances'] 94 | 95 | def mapp_roi(self, focus_roi, global_roi): 96 | yg1,yg2,xg1,xg2 = global_roi 97 | hg,wg = yg2-yg1, xg2-xg1 98 | yf1,yf2,xf1,xf2 = focus_roi 99 | 100 | ''' 101 | yf1_n = (yf1-yg1+1) * (self.crop_l/hg) 102 | yf2_n = (yf2-yg1+1) * (self.crop_l/hg) 103 | xf1_n = (xf1-xg1+1) * (self.crop_l/wg) 104 | xf2_n = (xf2-xg1+1) * (self.crop_l/wg) 105 | 106 | ''' 107 | yf1_n = (yf1-yg1) * (self.crop_l/hg) 108 | yf2_n = (yf2-yg1) * (self.crop_l/hg) 109 | xf1_n = (xf1-xg1) * (self.crop_l/wg) 110 | xf2_n = (xf2-xg1) * (self.crop_l/wg) 111 | 112 | yf1_n = max(yf1_n,0) 113 | yf2_n = min(yf2_n,self.crop_l) 114 | xf1_n = max(xf1_n,0) 115 | xf2_n = min(xf2_n,self.crop_l) 116 | return (yf1_n,yf2_n,xf1_n,xf2_n) 117 | 118 | 119 | 120 | def _get_transform_states(self): 121 | return [x.get_state() for x in self.transforms] 122 | 123 | def _set_transform_states(self, states): 124 | assert len(states) == len(self.transforms) 125 | for state, transform in zip(states, self.transforms): 126 | transform.set_state(state) 127 | print('_set_transform_states') 128 | 129 | def apply_transforms(self, image_nd, clicks_lists): 130 | is_image_changed = False 131 | for t in self.transforms: 132 | image_nd, clicks_lists = t.transform(image_nd, clicks_lists) 133 | is_image_changed |= t.image_changed 134 | 135 | return image_nd, clicks_lists, is_image_changed 136 | 137 | def get_points_nd(self, clicks_lists): 138 | total_clicks = [] 139 | num_pos_clicks = [sum(x.is_positive for x in clicks_list) for clicks_list in clicks_lists] 140 | num_neg_clicks = [len(clicks_list) - num_pos for clicks_list, num_pos in 
zip(clicks_lists, num_pos_clicks)] 141 | num_max_points = max(num_pos_clicks + num_neg_clicks) 142 | if self.net_clicks_limit is not None: 143 | num_max_points = min(self.net_clicks_limit, num_max_points) 144 | num_max_points = max(1, num_max_points) 145 | 146 | for clicks_list in clicks_lists: 147 | clicks_list = clicks_list[:self.net_clicks_limit] 148 | pos_clicks = [click.coords_and_indx for click in clicks_list if click.is_positive] 149 | pos_clicks = pos_clicks + (num_max_points - len(pos_clicks)) * [(-1, -1, -1)] 150 | 151 | neg_clicks = [click.coords_and_indx for click in clicks_list if not click.is_positive] 152 | neg_clicks = neg_clicks + (num_max_points - len(neg_clicks)) * [(-1, -1, -1)] 153 | total_clicks.append(pos_clicks + neg_clicks) 154 | 155 | return torch.tensor(total_clicks, device=self.device) 156 | 157 | def get_states(self): 158 | return { 159 | 'transform_states': self._get_transform_states(), 160 | 'prev_prediction': self.prev_prediction.clone() 161 | } 162 | 163 | def set_states(self, states): 164 | self._set_transform_states(states['transform_states']) 165 | self.prev_prediction = states['prev_prediction'] 166 | print('set') 167 | -------------------------------------------------------------------------------- /isegm/inference/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import SigmoidForPred 2 | from .flip import AddHorizontalFlip 3 | from .zoom_in import ZoomIn 4 | from .limit_longest_side import LimitLongestSide 5 | from .crops import Crops 6 | from .resize import ResizeTrans 7 | -------------------------------------------------------------------------------- /isegm/inference/transforms/base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class BaseTransform(object): 5 | def __init__(self): 6 | self.image_changed = False 7 | 8 | def transform(self, image_nd, clicks_lists): 9 | raise NotImplementedError 10 | 11 | def inv_transform(self, prob_map): 12 | raise NotImplementedError 13 | 14 | def reset(self): 15 | raise NotImplementedError 16 | 17 | def get_state(self): 18 | raise NotImplementedError 19 | 20 | def set_state(self, state): 21 | raise NotImplementedError 22 | 23 | 24 | class SigmoidForPred(BaseTransform): 25 | def transform(self, image_nd, clicks_lists): 26 | return image_nd, clicks_lists 27 | 28 | def inv_transform(self, prob_map): 29 | return torch.sigmoid(prob_map) 30 | 31 | def reset(self): 32 | pass 33 | 34 | def get_state(self): 35 | return None 36 | 37 | def set_state(self, state): 38 | pass 39 | -------------------------------------------------------------------------------- /isegm/inference/transforms/crops.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import numpy as np 5 | from typing import List 6 | 7 | from isegm.inference.clicker import Click 8 | from .base import BaseTransform 9 | 10 | 11 | class Crops(BaseTransform): 12 | def __init__(self, crop_size=(320, 480), min_overlap=0.2): 13 | super().__init__() 14 | self.crop_height, self.crop_width = crop_size 15 | self.min_overlap = min_overlap 16 | 17 | self.x_offsets = None 18 | self.y_offsets = None 19 | self._counts = None 20 | 21 | def transform(self, image_nd, clicks_lists: List[List[Click]]): 22 | assert image_nd.shape[0] == 1 and len(clicks_lists) == 1 23 | image_height, image_width = image_nd.shape[2:4] 24 | self._counts = None 25 | 26 | if image_height < 
self.crop_height or image_width < self.crop_width: 27 | return image_nd, clicks_lists 28 | 29 | self.x_offsets = get_offsets(image_width, self.crop_width, self.min_overlap) 30 | self.y_offsets = get_offsets(image_height, self.crop_height, self.min_overlap) 31 | self._counts = np.zeros((image_height, image_width)) 32 | 33 | image_crops = [] 34 | for dy in self.y_offsets: 35 | for dx in self.x_offsets: 36 | self._counts[dy:dy + self.crop_height, dx:dx + self.crop_width] += 1 37 | image_crop = image_nd[:, :, dy:dy + self.crop_height, dx:dx + self.crop_width] 38 | image_crops.append(image_crop) 39 | image_crops = torch.cat(image_crops, dim=0) 40 | self._counts = torch.tensor(self._counts, device=image_nd.device, dtype=torch.float32) 41 | 42 | clicks_list = clicks_lists[0] 43 | clicks_lists = [] 44 | for dy in self.y_offsets: 45 | for dx in self.x_offsets: 46 | crop_clicks = [x.copy(coords=(x.coords[0] - dy, x.coords[1] - dx)) for x in clicks_list] 47 | clicks_lists.append(crop_clicks) 48 | 49 | return image_crops, clicks_lists 50 | 51 | def inv_transform(self, prob_map): 52 | if self._counts is None: 53 | return prob_map 54 | 55 | new_prob_map = torch.zeros((1, 1, *self._counts.shape), 56 | dtype=prob_map.dtype, device=prob_map.device) 57 | 58 | crop_indx = 0 59 | for dy in self.y_offsets: 60 | for dx in self.x_offsets: 61 | new_prob_map[0, 0, dy:dy + self.crop_height, dx:dx + self.crop_width] += prob_map[crop_indx, 0] 62 | crop_indx += 1 63 | new_prob_map = torch.div(new_prob_map, self._counts) 64 | 65 | return new_prob_map 66 | 67 | def get_state(self): 68 | return self.x_offsets, self.y_offsets, self._counts 69 | 70 | def set_state(self, state): 71 | self.x_offsets, self.y_offsets, self._counts = state 72 | 73 | def reset(self): 74 | self.x_offsets = None 75 | self.y_offsets = None 76 | self._counts = None 77 | 78 | 79 | def get_offsets(length, crop_size, min_overlap_ratio=0.2): 80 | if length == crop_size: 81 | return [0] 82 | 83 | N = (length / crop_size - min_overlap_ratio) / (1 - min_overlap_ratio) 84 | N = math.ceil(N) 85 | 86 | overlap_ratio = (N - length / crop_size) / (N - 1) 87 | overlap_width = int(crop_size * overlap_ratio) 88 | 89 | offsets = [0] 90 | for i in range(1, N): 91 | new_offset = offsets[-1] + crop_size - overlap_width 92 | if new_offset + crop_size > length: 93 | new_offset = length - crop_size 94 | 95 | offsets.append(new_offset) 96 | 97 | return offsets 98 | -------------------------------------------------------------------------------- /isegm/inference/transforms/flip.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from typing import List 4 | from isegm.inference.clicker import Click 5 | from .base import BaseTransform 6 | 7 | 8 | class AddHorizontalFlip(BaseTransform): 9 | def transform(self, image_nd, clicks_lists: List[List[Click]]): 10 | assert len(image_nd.shape) == 4 11 | image_nd = torch.cat([image_nd, torch.flip(image_nd, dims=[3])], dim=0) 12 | 13 | image_width = image_nd.shape[3] 14 | clicks_lists_flipped = [] 15 | for clicks_list in clicks_lists: 16 | clicks_list_flipped = [click.copy(coords=(click.coords[0], image_width - click.coords[1] - 1)) 17 | for click in clicks_list] 18 | clicks_lists_flipped.append(clicks_list_flipped) 19 | clicks_lists = clicks_lists + clicks_lists_flipped 20 | 21 | return image_nd, clicks_lists 22 | 23 | def inv_transform(self, prob_map): 24 | assert len(prob_map.shape) == 4 and prob_map.shape[0] % 2 == 0 25 | num_maps = prob_map.shape[0] // 2 26 | prob_map, 
prob_map_flipped = prob_map[:num_maps], prob_map[num_maps:] 27 | 28 | return 0.5 * (prob_map + torch.flip(prob_map_flipped, dims=[3])) 29 | 30 | def get_state(self): 31 | return None 32 | 33 | def set_state(self, state): 34 | pass 35 | 36 | def reset(self): 37 | pass 38 | -------------------------------------------------------------------------------- /isegm/inference/transforms/limit_longest_side.py: -------------------------------------------------------------------------------- 1 | from .zoom_in import ZoomIn, get_roi_image_nd 2 | 3 | 4 | class LimitLongestSide(ZoomIn): 5 | def __init__(self, max_size=800): 6 | super().__init__(target_size=max_size, skip_clicks=0) 7 | 8 | def transform(self, image_nd, clicks_lists): 9 | assert image_nd.shape[0] == 1 and len(clicks_lists) == 1 10 | image_max_size = max(image_nd.shape[2:4]) 11 | self.image_changed = False 12 | 13 | if image_max_size <= self.target_size: 14 | return image_nd, clicks_lists 15 | self._input_image = image_nd 16 | 17 | self._object_roi = (0, image_nd.shape[2] - 1, 0, image_nd.shape[3] - 1) 18 | self._roi_image = get_roi_image_nd(image_nd, self._object_roi, self.target_size) 19 | self.image_changed = True 20 | 21 | tclicks_lists = [self._transform_clicks(clicks_lists[0])] 22 | return self._roi_image, tclicks_lists 23 | -------------------------------------------------------------------------------- /isegm/inference/transforms/resize.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import numpy as np 5 | 6 | from isegm.inference.clicker import Click 7 | from .base import BaseTransform 8 | import torch.nn.functional as F 9 | 10 | 11 | class ResizeTrans(BaseTransform): 12 | def __init__(self, l=480): 13 | super().__init__() 14 | self.crop_height = l 15 | self.crop_width = l 16 | 17 | def transform(self, image_nd, clicks_lists): 18 | assert image_nd.shape[0] == 1 and len(clicks_lists) == 1 19 | image_height, image_width = image_nd.shape[2:4] 20 | self.image_height = image_height 21 | self.image_width = image_width 22 | #image_np = np.transpose( image_nd[0].numpy(), (1,2,0)).astype(np.uint8) 23 | #image_np_r = cv2.resize( image_np, (self.crop_width, self.crop_height)) 24 | #image_nd_r = torch.from_numpy(image_np_r).unsqueeze(0).permute(0,3,1,2).float() 25 | image_nd_r = F.interpolate(image_nd, (self.crop_height, self.crop_width), mode = 'bilinear', align_corners=True ) 26 | 27 | y_ratio = self.crop_height / image_height 28 | x_ratio = self.crop_width / image_width 29 | 30 | #clicks_list = clicks_lists[0] 31 | #clicks_lists = [] 32 | #resize_clicks = [Click(is_positive=x.is_positive, coords=(x.coords[0] * y_ratio, x.coords[1] * x_ratio )) 33 | # for x in clicks_list] 34 | #clicks_lists.append(resize_clicks) 35 | 36 | clicks_lists_resized = [] 37 | for clicks_list in clicks_lists: 38 | clicks_list_resized = [click.copy(coords=(click.coords[0] * y_ratio, click.coords[1] * x_ratio )) 39 | for click in clicks_list] 40 | clicks_lists_resized.append(clicks_list_resized) 41 | 42 | return image_nd_r, clicks_lists_resized 43 | 44 | def inv_transform(self, prob_map): 45 | new_prob_map = F.interpolate(prob_map, (self.image_height, self.image_width), mode='bilinear', align_corners=True ) 46 | 47 | return new_prob_map 48 | 49 | def get_state(self): 50 | return self.image_height, self.image_width 51 | 52 | def set_state(self, state): 53 | self.image_height, self.image_width = state 54 | 55 | def reset(self): 56 | self.image_height = None 57 | self.image_width = None 58 |
59 | 60 | 61 | def get_offsets(length, crop_size, min_overlap_ratio=0.2): 62 | if length == crop_size: 63 | return [0] 64 | 65 | N = (length / crop_size - min_overlap_ratio) / (1 - min_overlap_ratio) 66 | N = math.ceil(N) 67 | 68 | overlap_ratio = (N - length / crop_size) / (N - 1) 69 | overlap_width = int(crop_size * overlap_ratio) 70 | 71 | offsets = [0] 72 | for i in range(1, N): 73 | new_offset = offsets[-1] + crop_size - overlap_width 74 | if new_offset + crop_size > length: 75 | new_offset = length - crop_size 76 | 77 | offsets.append(new_offset) 78 | 79 | return offsets 80 | -------------------------------------------------------------------------------- /isegm/inference/transforms/zoom_in.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from typing import List 4 | from isegm.inference.clicker import Click 5 | from isegm.utils.misc import get_bbox_iou, get_bbox_from_mask, expand_bbox, clamp_bbox 6 | from .base import BaseTransform 7 | 8 | 9 | 10 | class ZoomIn(BaseTransform): 11 | def __init__(self, 12 | target_size=480, 13 | skip_clicks=1, 14 | expansion_ratio=1.4, 15 | min_crop_size=10,#200 16 | recompute_thresh_iou=0.5, 17 | prob_thresh=0.49): 18 | super().__init__() 19 | self.target_size = target_size 20 | self.min_crop_size = min_crop_size 21 | self.skip_clicks = skip_clicks 22 | self.expansion_ratio = expansion_ratio 23 | self.recompute_thresh_iou = recompute_thresh_iou 24 | self.prob_thresh = prob_thresh 25 | 26 | self._input_image_shape = None 27 | self._prev_probs = None 28 | self._object_roi = None 29 | self._roi_image = None 30 | 31 | def transform(self, image_nd, clicks_lists: List[List[Click]]): 32 | assert image_nd.shape[0] == 1 and len(clicks_lists) == 1 33 | self.image_changed = False 34 | 35 | clicks_list = clicks_lists[0] 36 | if len(clicks_list) <= self.skip_clicks: 37 | return image_nd, clicks_lists 38 | 39 | self._input_image_shape = image_nd.shape 40 | 41 | current_object_roi = None 42 | if self._prev_probs is not None: 43 | current_pred_mask = (self._prev_probs > self.prob_thresh)[0, 0] 44 | if current_pred_mask.sum() > 0: 45 | current_object_roi = get_object_roi(current_pred_mask, clicks_list, 46 | self.expansion_ratio, self.min_crop_size) 47 | else: 48 | print('None') 49 | 50 | if current_object_roi is None: 51 | if self.skip_clicks >= 0: 52 | return image_nd, clicks_lists 53 | else: 54 | current_object_roi = 0, image_nd.shape[2] - 1, 0, image_nd.shape[3] - 1 55 | 56 | # here 57 | update_object_roi = True 58 | if self._object_roi is None: 59 | update_object_roi = True 60 | elif not check_object_roi(self._object_roi, clicks_list): 61 | update_object_roi = True 62 | elif get_bbox_iou(current_object_roi, self._object_roi) < self.recompute_thresh_iou: 63 | update_object_roi = True 64 | 65 | if update_object_roi: 66 | self._object_roi = current_object_roi 67 | self.image_changed = True 68 | self._roi_image = get_roi_image_nd(image_nd, self._object_roi, self.target_size) 69 | 70 | tclicks_lists = [self._transform_clicks(clicks_list)] 71 | return self._roi_image.to(image_nd.device), tclicks_lists 72 | 73 | def inv_transform(self, prob_map): 74 | if self._object_roi is None: 75 | self._prev_probs = prob_map.cpu().numpy() 76 | return prob_map 77 | 78 | assert prob_map.shape[0] == 1 79 | rmin, rmax, cmin, cmax = self._object_roi 80 | prob_map = torch.nn.functional.interpolate(prob_map, size=(rmax - rmin + 1, cmax - cmin + 1), 81 | mode='bilinear', align_corners=True) 82 | 83 | 
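# The block below pastes the upsampled ROI prediction back into a full-resolution probability map.
# For example, with _object_roi = (rmin, rmax, cmin, cmax) = (10, 209, 30, 329), prob_map is resized
# to 200x300 above and written into rows 10..209 and columns 30..329; pixels outside the crop stay zero.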
84 | 85 | if self._prev_probs is not None: 86 | new_prob_map = torch.zeros(*self._prev_probs.shape, device=prob_map.device, dtype=prob_map.dtype) 87 | new_prob_map[:, :, rmin:rmax + 1, cmin:cmax + 1] = prob_map 88 | #new_prob_map[:, :, rmin:rmax, cmin:cmax] = prob_map 89 | else: 90 | new_prob_map = prob_map 91 | 92 | self._prev_probs = new_prob_map.cpu().numpy() 93 | 94 | return new_prob_map 95 | 96 | def check_possible_recalculation(self): 97 | if self._prev_probs is None or self._object_roi is not None or self.skip_clicks > 0: 98 | return False 99 | 100 | pred_mask = (self._prev_probs > self.prob_thresh)[0, 0] 101 | if pred_mask.sum() > 0: 102 | possible_object_roi = get_object_roi(pred_mask, [], 103 | self.expansion_ratio, self.min_crop_size) 104 | image_roi = (0, self._input_image_shape[2] - 1, 0, self._input_image_shape[3] - 1) 105 | if get_bbox_iou(possible_object_roi, image_roi) < 0.50: 106 | return True 107 | return False 108 | 109 | def get_state(self): 110 | roi_image = self._roi_image.cpu() if self._roi_image is not None else None 111 | return self._input_image_shape, self._object_roi, self._prev_probs, roi_image, self.image_changed 112 | 113 | def set_state(self, state): 114 | self._input_image_shape, self._object_roi, self._prev_probs, self._roi_image, self.image_changed = state 115 | 116 | def reset(self): 117 | self._input_image_shape = None 118 | self._object_roi = None 119 | self._prev_probs = None 120 | self._roi_image = None 121 | self.image_changed = False 122 | 123 | def _transform_clicks(self, clicks_list): 124 | if self._object_roi is None: 125 | return clicks_list 126 | 127 | rmin, rmax, cmin, cmax = self._object_roi 128 | crop_height, crop_width = self._roi_image.shape[2:] 129 | 130 | transformed_clicks = [] 131 | for click in clicks_list: 132 | new_r = crop_height * (click.coords[0] - rmin) / (rmax - rmin + 1) 133 | new_c = crop_width * (click.coords[1] - cmin) / (cmax - cmin + 1) 134 | transformed_clicks.append(click.copy(coords=(new_r, new_c))) 135 | return transformed_clicks 136 | 137 | 138 | def get_object_roi(pred_mask, clicks_list, expansion_ratio, min_crop_size): 139 | pred_mask = pred_mask.copy() 140 | 141 | for click in clicks_list: 142 | if click.is_positive: 143 | pred_mask[int(click.coords[0]), int(click.coords[1])] = 1 144 | 145 | bbox = get_bbox_from_mask(pred_mask) 146 | bbox = expand_bbox(bbox, expansion_ratio, min_crop_size) 147 | h, w = pred_mask.shape[0], pred_mask.shape[1] 148 | bbox = clamp_bbox(bbox, 0, h - 1, 0, w - 1) 149 | 150 | return bbox 151 | 152 | 153 | def get_roi_image_nd(image_nd, object_roi, target_size): 154 | rmin, rmax, cmin, cmax = object_roi 155 | 156 | height = rmax - rmin + 1 157 | width = cmax - cmin + 1 158 | 159 | if isinstance(target_size, tuple): 160 | new_height, new_width = target_size 161 | else: 162 | scale = target_size / max(height, width) 163 | new_height = int(round(height * scale)) 164 | new_width = int(round(width * scale)) 165 | 166 | with torch.no_grad(): 167 | roi_image_nd = image_nd[:, :, rmin:rmax + 1, cmin:cmax + 1] 168 | #roi_image_nd = torch.nn.functional.interpolate(roi_image_nd, size=(new_height, new_width), 169 | # mode='bilinear', align_corners=True) 170 | 171 | return roi_image_nd 172 | 173 | 174 | def check_object_roi(object_roi, clicks_list): 175 | for click in clicks_list: 176 | if click.is_positive: 177 | if click.coords[0] < object_roi[0] or click.coords[0] >= object_roi[1]: 178 | return False 179 | if click.coords[1] < object_roi[2] or click.coords[1] >= object_roi[3]: 180 | return False 
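# Every positive click lies inside the stored ROI, so it is reported as still valid.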
181 | 182 | return True 183 | -------------------------------------------------------------------------------- /isegm/inference/utils.py: -------------------------------------------------------------------------------- 1 | from datetime import timedelta 2 | from pathlib import Path 3 | 4 | import torch 5 | import numpy as np 6 | 7 | from isegm.data.datasets import GrabCutDataset, BerkeleyDataset, DavisDataset, SBDEvaluationDataset 8 | 9 | from isegm.utils.serialization import load_model 10 | 11 | 12 | def get_time_metrics(all_ious, elapsed_time): 13 | n_images = len(all_ious) 14 | n_clicks = sum(map(len, all_ious)) 15 | 16 | mean_spc = elapsed_time / n_clicks 17 | mean_spi = elapsed_time / n_images 18 | 19 | return mean_spc, mean_spi 20 | 21 | 22 | def load_is_model(checkpoint, device, **kwargs): 23 | if isinstance(checkpoint, (str, Path)): 24 | state_dict = torch.load(checkpoint, map_location='cpu') 25 | else: 26 | state_dict = checkpoint 27 | 28 | if isinstance(state_dict, list): 29 | models = [load_single_is_model(x, device, **kwargs) for x in state_dict] 30 | model = models[0] 31 | 32 | return model, models 33 | else: 34 | return load_single_is_model(state_dict, device, **kwargs) 35 | 36 | 37 | def load_single_is_model(state_dict, device, **kwargs): 38 | model = load_model(state_dict['config'], **kwargs) 39 | model.load_state_dict(state_dict['state_dict'], strict=False) 40 | 41 | for param in model.parameters(): 42 | param.requires_grad = False 43 | model.to(device) 44 | model.eval() 45 | 46 | return model 47 | 48 | 49 | def get_dataset(dataset_name, cfg): 50 | if dataset_name == 'GrabCut': 51 | dataset = GrabCutDataset(cfg.GRABCUT_PATH) 52 | elif dataset_name == 'Berkeley': 53 | dataset = BerkeleyDataset(cfg.BERKELEY_PATH) 54 | elif dataset_name == 'DAVIS': 55 | dataset = DavisDataset(cfg.DAVIS_PATH) 56 | elif dataset_name == 'SBD': 57 | dataset = SBDEvaluationDataset(cfg.SBD_PATH) 58 | elif dataset_name == 'SBD_Train': 59 | dataset = SBDEvaluationDataset(cfg.SBD_PATH, split='train') 60 | else: 61 | dataset = None 62 | return dataset 63 | 64 | 65 | def get_iou(gt_mask, pred_mask, ignore_label=-1): 66 | ignore_gt_mask_inv = gt_mask != ignore_label 67 | obj_gt_mask = gt_mask == 1 68 | 69 | intersection = np.logical_and(np.logical_and(pred_mask, obj_gt_mask), ignore_gt_mask_inv).sum() 70 | union = np.logical_and(np.logical_or(pred_mask, obj_gt_mask), ignore_gt_mask_inv).sum() 71 | 72 | return intersection / union 73 | 74 | 75 | def compute_noc_metric(all_ious, iou_thrs, max_clicks=20): 76 | def _get_noc(iou_arr, iou_thr): 77 | vals = iou_arr >= iou_thr 78 | return np.argmax(vals) + 1 if np.any(vals) else max_clicks 79 | 80 | noc_list = [] 81 | over_max_list = [] 82 | for iou_thr in iou_thrs: 83 | scores_arr = np.array([_get_noc(iou_arr, iou_thr) 84 | for iou_arr in all_ious], dtype=int) 85 | 86 | score = scores_arr.mean() 87 | over_max = (scores_arr == max_clicks).sum() 88 | 89 | noc_list.append(score) 90 | over_max_list.append(over_max) 91 | 92 | return noc_list, over_max_list 93 | 94 | 95 | def find_checkpoint(weights_folder, checkpoint_name): 96 | weights_folder = Path(weights_folder) 97 | if ':' in checkpoint_name: 98 | model_name, checkpoint_name = checkpoint_name.split(':') 99 | models_candidates = [x for x in weights_folder.glob(f'{model_name}*') if x.is_dir()] 100 | assert len(models_candidates) == 1 101 | model_folder = models_candidates[0] 102 | else: 103 | model_folder = weights_folder 104 | 105 | if 
checkpoint_name.endswith('.pth'): 106 | if Path(checkpoint_name).exists(): 107 | checkpoint_path = checkpoint_name 108 | else: 109 | checkpoint_path = weights_folder / checkpoint_name 110 | else: 111 | model_checkpoints = list(model_folder.rglob(f'{checkpoint_name}*.pth')) 112 | assert len(model_checkpoints) == 1 113 | checkpoint_path = model_checkpoints[0] 114 | 115 | return str(checkpoint_path) 116 | 117 | 118 | def get_results_table(noc_list, over_max_list, brs_type, dataset_name, mean_spc, elapsed_time, 119 | n_clicks=20, model_name=None): 120 | table_header = (f'|{"Pipeline":^13}|{"Dataset":^11}|' 121 | f'{"NoC@80%":^9}|{"NoC@85%":^9}|{"NoC@90%":^9}|' 122 | f'{">="+str(n_clicks)+"@85%":^9}|{">="+str(n_clicks)+"@90%":^9}|' 123 | f'{"SPC,s":^7}|{"Time":^9}|') 124 | row_width = len(table_header) 125 | 126 | header = f'Eval results for model: {model_name}\n' if model_name is not None else '' 127 | header += '-' * row_width + '\n' 128 | header += table_header + '\n' + '-' * row_width 129 | 130 | eval_time = str(timedelta(seconds=int(elapsed_time))) 131 | table_row = f'|{brs_type:^13}|{dataset_name:^11}|' 132 | table_row += f'{noc_list[0]:^9.2f}|' 133 | table_row += f'{noc_list[1]:^9.2f}|' if len(noc_list) > 1 else f'{"?":^9}|' 134 | table_row += f'{noc_list[2]:^9.2f}|' if len(noc_list) > 2 else f'{"?":^9}|' 135 | table_row += f'{over_max_list[1]:^9}|' if len(noc_list) > 1 else f'{"?":^9}|' 136 | table_row += f'{over_max_list[2]:^9}|' if len(noc_list) > 2 else f'{"?":^9}|' 137 | table_row += f'{mean_spc:^7.3f}|{eval_time:^9}|' 138 | 139 | return header, table_row 140 | -------------------------------------------------------------------------------- /isegm/model/initializer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | 6 | class Initializer(object): 7 | def __init__(self, local_init=True, gamma=None): 8 | self.local_init = local_init 9 | self.gamma = gamma 10 | 11 | def __call__(self, m): 12 | if getattr(m, '__initialized', False): 13 | return 14 | 15 | if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, 16 | nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d, 17 | nn.GroupNorm, nn.SyncBatchNorm)) or 'BatchNorm' in m.__class__.__name__: 18 | if m.weight is not None: 19 | self._init_gamma(m.weight.data) 20 | if m.bias is not None: 21 | self._init_beta(m.bias.data) 22 | else: 23 | if getattr(m, 'weight', None) is not None: 24 | self._init_weight(m.weight.data) 25 | if getattr(m, 'bias', None) is not None: 26 | self._init_bias(m.bias.data) 27 | 28 | if self.local_init: 29 | object.__setattr__(m, '__initialized', True) 30 | 31 | def _init_weight(self, data): 32 | nn.init.uniform_(data, -0.07, 0.07) 33 | 34 | def _init_bias(self, data): 35 | nn.init.constant_(data, 0) 36 | 37 | def _init_gamma(self, data): 38 | if self.gamma is None: 39 | nn.init.constant_(data, 1.0) 40 | else: 41 | nn.init.normal_(data, 1.0, self.gamma) 42 | 43 | def _init_beta(self, data): 44 | nn.init.constant_(data, 0) 45 | 46 | 47 | class Bilinear(Initializer): 48 | def __init__(self, scale, groups, in_channels, **kwargs): 49 | super().__init__(**kwargs) 50 | self.scale = scale 51 | self.groups = groups 52 | self.in_channels = in_channels 53 | 54 | def _init_weight(self, data): 55 | """Reset the weight and bias.""" 56 | bilinear_kernel = self.get_bilinear_kernel(self.scale) 57 | weight = torch.zeros_like(data) 58 | for i in range(self.in_channels): 59 | if self.groups == 1: 60 | j = i 61 | else: 
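# Grouped (e.g. depthwise) convolutions have a single input channel per filter,
# so the bilinear kernel is always written to input index 0.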
62 | j = 0 63 | weight[i, j] = bilinear_kernel 64 | data[:] = weight 65 | 66 | @staticmethod 67 | def get_bilinear_kernel(scale): 68 | """Generate a bilinear upsampling kernel.""" 69 | kernel_size = 2 * scale - scale % 2 70 | scale = (kernel_size + 1) // 2 71 | center = scale - 0.5 * (1 + kernel_size % 2) 72 | 73 | og = np.ogrid[:kernel_size, :kernel_size] 74 | kernel = (1 - np.abs(og[0] - center) / scale) * (1 - np.abs(og[1] - center) / scale) 75 | 76 | return torch.tensor(kernel, dtype=torch.float32) 77 | 78 | 79 | class XavierGluon(Initializer): 80 | def __init__(self, rnd_type='uniform', factor_type='avg', magnitude=3, **kwargs): 81 | super().__init__(**kwargs) 82 | 83 | self.rnd_type = rnd_type 84 | self.factor_type = factor_type 85 | self.magnitude = float(magnitude) 86 | 87 | def _init_weight(self, arr): 88 | fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(arr) 89 | 90 | if self.factor_type == 'avg': 91 | factor = (fan_in + fan_out) / 2.0 92 | elif self.factor_type == 'in': 93 | factor = fan_in 94 | elif self.factor_type == 'out': 95 | factor = fan_out 96 | else: 97 | raise ValueError('Incorrect factor type') 98 | scale = np.sqrt(self.magnitude / factor) 99 | 100 | if self.rnd_type == 'uniform': 101 | nn.init.uniform_(arr, -scale, scale) 102 | elif self.rnd_type == 'gaussian': 103 | nn.init.normal_(arr, 0, scale) 104 | else: 105 | raise ValueError('Unknown random type') 106 | -------------------------------------------------------------------------------- /isegm/model/is_gp_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | from isegm.model.ops import DistMaps, BatchImageNormalize 7 | from einops import rearrange, repeat 8 | from opt_einsum import contract 9 | import math 10 | 11 | class ISGPModel(nn.Module): 12 | def __init__(self, use_rgb_conv=False, feature_stride = 4, with_aux_output=False, 13 | norm_radius=260, use_disks=False, cpu_dist_maps=False, 14 | clicks_groups=None, with_prev_mask=False, use_leaky_relu=False, 15 | binary_prev_mask=False, conv_extend=False, norm_layer=nn.BatchNorm2d, 16 | norm_mean_std=([.485, .456, .406], [.229, .224, .225])): 17 | super().__init__() 18 | self.with_aux_output = with_aux_output 19 | self.clicks_groups = clicks_groups 20 | self.with_prev_mask = with_prev_mask 21 | self.binary_prev_mask = binary_prev_mask 22 | self.normalization = BatchImageNormalize(norm_mean_std[0], norm_mean_std[1]) 23 | self.dist_maps = DistMaps(norm_radius=5, spatial_scale=1.0, 24 | cpu_mode=False, use_disks=True) 25 | 26 | 27 | def prepare_input(self, image): 28 | prev_mask = None 29 | if self.with_prev_mask: 30 | prev_mask = image[:, 3:, :, :] 31 | image = image[:, :3, :, :] 32 | if self.binary_prev_mask: 33 | prev_mask = (prev_mask > 0.5).float() 34 | 35 | image = self.normalization(image) 36 | return image, prev_mask 37 | 38 | def get_coord_features(self, image, prev_mask, points): 39 | coord_features = self.dist_maps(image, points) 40 | if prev_mask is not None: 41 | coord_features = torch.cat((prev_mask, coord_features), dim=1) 42 | return coord_features 43 | 44 | def load_pretrained_weights(self, path_to_weights= ''): 45 | state_dict = self.state_dict() 46 | pretrained_state_dict = torch.load(path_to_weights, map_location='cpu')['state_dict'] 47 | ckpt_keys = set(pretrained_state_dict.keys()) 48 | own_keys = set(state_dict.keys()) 49 | missing_keys = own_keys - ckpt_keys 50 | unexpected_keys = ckpt_keys - 
own_keys 51 | print('Missing Keys: ', missing_keys) 52 | print('Unexpected Keys: ', unexpected_keys) 53 | state_dict.update(pretrained_state_dict) 54 | self.load_state_dict(state_dict, strict= False) 55 | ''' 56 | if self.inference_mode: 57 | for param in self.backbone.parameters(): 58 | param.requires_grad = False 59 | ''' 60 | 61 | def get_coord_features(self, image, prev_mask, points): 62 | coord_features = self.dist_maps(image, points) 63 | if prev_mask is not None: 64 | coord_features = torch.cat((prev_mask, coord_features), dim=1) 65 | return coord_features 66 | 67 | def prepare_points_labels(self, points,feature): 68 | pss = [] 69 | label_list = [] 70 | point_labels = torch.ones([points.size(1),1], dtype=torch.float32, device=feature.device) 71 | point_labels[points.size(1)//2:,:] = -1. 72 | for i in range(points.size(0)): 73 | ps, _ = torch.split(points[i],[2,1],dim=1) 74 | valid_points = torch.logical_and(torch.logical_and(torch.min(ps, dim=1, keepdim=False)[0] >= 0, 75 | ps[:,0] < feature.size(2)), ps[:,1] < feature.size(3) ) 76 | ps = ps[valid_points] # n, 2 77 | pss.append(ps) 78 | label = point_labels[valid_points,:] #n,1 79 | label_list.append(label) 80 | return pss, label_list 81 | 82 | def Pathwise_GP_prior(self, feature, omega): 83 | b,d,h,w = feature.size() 84 | phi_f = math.sqrt(2/self.L)*torch.sin(rearrange(self.theta(rearrange(feature, 'b d h w -> (b h w) d')), '(b h w) d->b d h w',b=b,h=h,w=w)) 85 | prior = contract('blhw,ls->bshw',phi_f,omega) # b,1,h,w 86 | return prior 87 | 88 | def Pathwise_GP_update(self, points, feature,pss,label_list,result,omega): 89 | b,d,h,w = feature.size() 90 | inv_Kmm_list = [] 91 | zf_list = [] 92 | point_nums = [] 93 | weight = F.softplus(self.weights) 94 | 95 | for i in range(points.size(0)): 96 | ps = pss[i] 97 | if ps.size(0)==0: 98 | point_nums.append(0) 99 | continue 100 | ps = torch.cat([ps[:,[0]].clamp(min=0., max=feature.size(2)-1),ps[:,[1]].clamp(min=0., max=feature.size(3)-1)],1) 101 | 102 | point_nums.append(ps.size(0)) 103 | zf = feature[i,:,ps[:,0].long(),ps[:,1].long()].T #n,d 104 | zf_list.append(zf) 105 | norm = torch.norm(torch.exp(self.logsigma2/2)*zf[:,:-3], dim=1,p=2)**2/2 # n, 106 | Kmm = torch.exp(contract('nd,md,d->nm',zf[:,:-3],zf[:,:-3],torch.exp(self.logsigma2))-\ 107 | norm.unsqueeze(0).repeat(ps.size(0),1)-norm.unsqueeze(1).repeat(1,ps.size(0)))+\ 108 | weight*torch.exp(-torch.sum((zf[:,-3:].unsqueeze(1)-zf[:,-3:])**2,2)/2) 109 | 110 | inv_Kmm_list.append(torch.inverse(Kmm+self.eps2*torch.eye(Kmm.size(0),device=Kmm.device))) 111 | 112 | inv_Kmm = torch.block_diag(*inv_Kmm_list) #n,n 113 | zf = torch.cat(zf_list,dim=0) # n,d 114 | label = torch.cat(label_list,dim=0) # n,1 115 | m = F.softplus(self.u_mlp(zf))*label #n,1 116 | 117 | if self.training: 118 | u = m + 0.01*torch.randn(m.size()).to(feature.device) 119 | u_loss = self.u_loss(inv_Kmm.detach(),m,u,label/2+0.5) 120 | else: 121 | u = m 122 | u_loss = torch.tensor([0.],device=feature.device) 123 | 124 | phi = math.sqrt(2/self.L)*torch.sin(self.theta(zf)) 125 | 126 | phi_omega = torch.matmul(phi,omega) # n,1 127 | 128 | v = torch.matmul(inv_Kmm, u-phi_omega) # n,1 129 | num_prev = 0 130 | offset = 0 131 | weight = F.softplus(self.weights) 132 | for i in range(points.size(0)): 133 | if point_nums[i]==0: 134 | offset+=1 135 | continue 136 | norm1 = torch.norm(torch.exp(self.logsigma2/2).view(self.feature_dim,1,1)*feature[i,:-3], dim=0,p=2)**2/2 #h w 137 | norm2 = torch.norm(torch.exp(self.logsigma2/2)*zf_list[i-offset][:,:-3], dim=1,p=2)**2/2 # n, 138 | 
norm_rgb1 = torch.norm(feature[i,-3:], dim=0,p=2)**2/2 #h w 139 | norm_rgb2 = torch.norm(zf_list[i-offset][:,-3:], dim=1,p=2)**2/2 # n, 140 | Knm = torch.exp(contract('dhw,nd,d->nhw',feature[i,:-3],zf_list[i-offset][:,:-3],torch.exp(self.logsigma2)) -\ 141 | repeat(norm1, 'h w -> n h w',n=point_nums[i]) - repeat(norm2, 'n -> n h w',h=h, w=w)) + \ 142 | weight*torch.exp(contract('dhw,nd->nhw',feature[i,-3:],zf_list[i-offset][:,-3:])-\ 143 | repeat(norm_rgb1, 'h w -> n h w',n=point_nums[i]) - repeat(norm_rgb2, 'n -> n h w',h=h, w=w)) 144 | result[i,...] += contract('nhw,ns->shw',Knm, v[num_prev:num_prev+point_nums[i]]) 145 | num_prev += point_nums[i] 146 | return result, 0.001*u_loss 147 | 148 | def u_loss(self,invK, m, u, y): 149 | n = invK.size(0) 150 | loss = F.binary_cross_entropy_with_logits(u,y)+ torch.matmul(torch.matmul(m.T,invK),m)/n 151 | return loss 152 | 153 | def split_points_by_order(tpoints: torch.Tensor, groups): 154 | points = tpoints.cpu().numpy() 155 | num_groups = len(groups) 156 | bs = points.shape[0] 157 | num_points = points.shape[1] // 2 158 | 159 | groups = [x if x > 0 else num_points for x in groups] 160 | group_points = [np.full((bs, 2 * x, 3), -1, dtype=np.float32) 161 | for x in groups] 162 | 163 | last_point_indx_group = np.zeros((bs, num_groups, 2), dtype=int) 164 | for group_indx, group_size in enumerate(groups): 165 | last_point_indx_group[:, group_indx, 1] = group_size 166 | 167 | for bindx in range(bs): 168 | for pindx in range(2 * num_points): 169 | point = points[bindx, pindx, :] 170 | group_id = int(point[2]) 171 | if group_id < 0: 172 | continue 173 | 174 | is_negative = int(pindx >= num_points) 175 | if group_id >= num_groups or (group_id == 0 and is_negative): # disable negative first click 176 | group_id = num_groups - 1 177 | 178 | new_point_indx = last_point_indx_group[bindx, group_id, is_negative] 179 | last_point_indx_group[bindx, group_id, is_negative] += 1 180 | 181 | group_points[group_id][bindx, new_point_indx, :] = point 182 | 183 | group_points = [torch.tensor(x, dtype=tpoints.dtype, device=tpoints.device) 184 | for x in group_points] 185 | 186 | return group_points 187 | -------------------------------------------------------------------------------- /isegm/model/is_gp_resnet50.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.nn.functional as F 4 | from isegm.utils.serialization import serialize 5 | from .is_gp_model import ISGPModel 6 | from isegm.model.ops import ScaleLayer 7 | from .modeling.deeplab_v3_gp import DeepLabV3Plus 8 | from isegm.model.modifiers import LRMult 9 | 10 | 11 | class GpModel(ISGPModel): 12 | @serialize 13 | def __init__(self, backbone='resnet50', deeplab_ch=256, aspp_dropout=0., 14 | backbone_norm_layer=None, backbone_lr_mult=0.1, 15 | norm_layer=nn.BatchNorm2d, weight_dir=None, **kwargs): 16 | super().__init__(norm_layer=norm_layer, **kwargs) 17 | 18 | self.model = DeepLabV3Plus(backbone=backbone, ch=deeplab_ch, 19 | project_dropout=aspp_dropout, norm_layer=norm_layer, 20 | backbone_norm_layer=backbone_norm_layer, weight_dir=weight_dir) 21 | 22 | side_feature_ch = 256 23 | 24 | self.model.apply(LRMult(backbone_lr_mult)) 25 | 26 | 27 | mt_layers = [ 28 | nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=2, padding=1), 29 | nn.LeakyReLU(negative_slope=0.2), 30 | nn.Conv2d(in_channels=16, out_channels=side_feature_ch, kernel_size=3, stride=1, padding=1), 31 | ScaleLayer(init_value=0.05, lr_mult=1) 32 | ] 
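# mt_layers compresses the 3-channel interaction input (previous mask plus positive/negative click
# disks, assuming with_prev_mask is enabled) into 256-channel side features at half resolution.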
33 | self.maps_transform = nn.Sequential(*mt_layers) 34 | self.L=256 35 | self.feature_dim = 48 36 | self.theta = nn.Linear(self.feature_dim+3,self.L) 37 | omega = 0.25*torch.randn(self.L,1) 38 | self.omega = nn.Parameter(omega, requires_grad=True) 39 | omega_var = torch.tensor(0.025) 40 | self.omega_var = nn.Parameter(omega_var, requires_grad=True) 41 | 42 | logsigma2 = torch.ones(self.feature_dim) 43 | self.logsigma2 = nn.Parameter(logsigma2, requires_grad=True) 44 | self.u_mlp = nn.Sequential( 45 | nn.Linear(self.feature_dim+3,96), 46 | nn.ReLU(True), 47 | nn.Linear(96,1) 48 | ) 49 | 50 | weight = torch.zeros(1) 51 | self.weights = nn.Parameter(weight, requires_grad=True) 52 | self.eps2 = 1e-2 53 | 54 | def set_status(self, training): 55 | if training: 56 | self.eps2=1e-2 57 | else: 58 | self.eps2=1e-7 59 | 60 | def forward(self, image, points): 61 | image, prev_mask = self.prepare_input(image) 62 | coord_features = self.get_coord_features(image, prev_mask, points) 63 | coord_features = self.maps_transform(coord_features) 64 | 65 | feature = self.model(image, coord_features) 66 | feature = F.normalize(feature, dim=1) 67 | 68 | feature = nn.functional.interpolate(feature, size=image.size()[2:], 69 | mode='bilinear', align_corners=True) 70 | feature = torch.cat([feature, image],1) 71 | 72 | pss, label_list = self.prepare_points_labels(points,feature) 73 | if self.training: 74 | omega = self.omega+self.omega_var.clamp(min=0.01,max=0.05)*torch.randn(self.L,1).to(feature.device) 75 | else: 76 | omega = self.omega 77 | prior= self.Pathwise_GP_prior(feature, omega) 78 | out, u_loss =self.Pathwise_GP_update(points, feature,pss,label_list,prior,omega) 79 | outputs = {'instances': out, 'u_loss':u_loss} 80 | return outputs 81 | -------------------------------------------------------------------------------- /isegm/model/losses.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from isegm.utils import misc 7 | 8 | 9 | class NormalizedFocalLossSigmoid(nn.Module): 10 | def __init__(self, axis=-1, alpha=0.25, gamma=2, max_mult=-1, eps=1e-12, 11 | from_sigmoid=False, detach_delimeter=True, 12 | batch_axis=0, weight=None, size_average=True, 13 | ignore_label=-1): 14 | super(NormalizedFocalLossSigmoid, self).__init__() 15 | self._axis = axis 16 | self._alpha = alpha 17 | self._gamma = gamma 18 | self._ignore_label = ignore_label 19 | self._weight = weight if weight is not None else 1.0 20 | self._batch_axis = batch_axis 21 | 22 | self._from_logits = from_sigmoid 23 | self._eps = eps 24 | self._size_average = size_average 25 | self._detach_delimeter = detach_delimeter 26 | self._max_mult = max_mult 27 | self._k_sum = 0 28 | self._m_max = 0 29 | 30 | def forward(self, pred, label): 31 | #print(pred.shape, label.shape) 32 | pred = pred.float() 33 | one_hot = label > 0.5 34 | sample_weight = label != self._ignore_label 35 | 36 | if not self._from_logits: 37 | pred = torch.sigmoid(pred) 38 | 39 | alpha = torch.where(one_hot, self._alpha * sample_weight, (1 - self._alpha) * sample_weight) 40 | pt = torch.where(sample_weight, 1.0 - torch.abs(label - pred), torch.ones_like(pred)) 41 | 42 | beta = (1 - pt) ** self._gamma 43 | 44 | sw_sum = torch.sum(sample_weight, dim=(-2, -1), keepdim=True) 45 | beta_sum = torch.sum(beta, dim=(-2, -1), keepdim=True) 46 | mult = sw_sum / (beta_sum + self._eps) 47 | if self._detach_delimeter: 48 | mult = mult.detach() 49 | beta = beta * 
mult 50 | if self._max_mult > 0: 51 | beta = torch.clamp_max(beta, self._max_mult) 52 | 53 | with torch.no_grad(): 54 | ignore_area = torch.sum(label == self._ignore_label, dim=tuple(range(1, label.dim()))).cpu().numpy() 55 | sample_mult = torch.mean(mult, dim=tuple(range(1, mult.dim()))).cpu().numpy() 56 | if np.any(ignore_area == 0): 57 | self._k_sum = 0.9 * self._k_sum + 0.1 * sample_mult[ignore_area == 0].mean() 58 | 59 | beta_pmax, _ = torch.flatten(beta, start_dim=1).max(dim=1) 60 | beta_pmax = beta_pmax.mean().item() 61 | self._m_max = 0.8 * self._m_max + 0.2 * beta_pmax 62 | 63 | loss = -alpha * beta * torch.log(torch.min(pt + self._eps, torch.ones(1, dtype=torch.float).to(pt.device))) 64 | loss = self._weight * (loss * sample_weight) 65 | 66 | if self._size_average: 67 | bsum = torch.sum(sample_weight, dim=misc.get_dims_with_exclusion(sample_weight.dim(), self._batch_axis)) 68 | loss = torch.sum(loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis)) / (bsum + self._eps) 69 | else: 70 | loss = torch.sum(loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis)) 71 | 72 | return loss 73 | 74 | def log_states(self, sw, name, global_step): 75 | sw.add_scalar(tag=name + '_k', value=self._k_sum, global_step=global_step) 76 | sw.add_scalar(tag=name + '_m', value=self._m_max, global_step=global_step) 77 | 78 | 79 | 80 | class DiversityLoss(nn.Module): 81 | def __init__(self): 82 | super(DiversityLoss, self).__init__() 83 | self.baseloss = NormalizedFocalLossSigmoid(alpha=0.5, gamma=2) 84 | self.click_loss = ClickLoss() #WFNL(alpha=0.5, gamma=2, w = 0.99) 85 | 86 | 87 | def forward(self, latent_preds, label, click_map): 88 | div_loss_lst = [] 89 | click_loss = 0 90 | for i in range(latent_preds.shape[1]): 91 | single_pred = latent_preds[:,i,:,:].unsqueeze(1) 92 | single_loss = self.baseloss(single_pred,label) 93 | single_loss = single_loss.unsqueeze(-1) 94 | div_loss_lst.append(single_loss) 95 | click_loss += self.click_loss(single_pred,label,click_map) 96 | 97 | div_losses = torch.cat(div_loss_lst,1) 98 | div_loss_min = torch.min(div_losses,dim=1)[0] 99 | return div_loss_min.mean() + click_loss.mean() 100 | 101 | 102 | 103 | class WFNL(nn.Module): 104 | def __init__(self, axis=-1, alpha=0.25, gamma=2, w = 0.5, max_mult=-1, eps=1e-12, 105 | from_sigmoid=False, detach_delimeter=True, 106 | batch_axis=0, weight=None, size_average=True, 107 | ignore_label=-1): 108 | super(WFNL, self).__init__() 109 | self._axis = axis 110 | self._alpha = alpha 111 | self._gamma = gamma 112 | self._ignore_label = ignore_label 113 | self._weight = weight if weight is not None else 1.0 114 | self._batch_axis = batch_axis 115 | 116 | self._from_logits = from_sigmoid 117 | self._eps = eps 118 | self._size_average = size_average 119 | self._detach_delimeter = detach_delimeter 120 | self._max_mult = max_mult 121 | self._k_sum = 0 122 | self._m_max = 0 123 | self.w = w 124 | 125 | def forward(self, pred, label, weight = None): 126 | one_hot = label > 0.5 127 | sample_weight = label != self._ignore_label 128 | 129 | if not self._from_logits: 130 | pred = torch.sigmoid(pred) 131 | 132 | alpha = torch.where(one_hot, self._alpha * sample_weight, (1 - self._alpha) * sample_weight) 133 | pt = torch.where(sample_weight, 1.0 - torch.abs(label - pred), torch.ones_like(pred)) 134 | 135 | beta = (1 - pt) ** self._gamma 136 | 137 | sw_sum = torch.sum(sample_weight, dim=(-2, -1), keepdim=True) 138 | beta_sum = torch.sum(beta, dim=(-2, -1), keepdim=True) 139 | mult = sw_sum / (beta_sum + self._eps) 140 | 
if self._detach_delimeter: 141 | mult = mult.detach() 142 | beta = beta * mult 143 | if self._max_mult > 0: 144 | beta = torch.clamp_max(beta, self._max_mult) 145 | 146 | with torch.no_grad(): 147 | ignore_area = torch.sum(label == self._ignore_label, dim=tuple(range(1, label.dim()))).cpu().numpy() 148 | sample_mult = torch.mean(mult, dim=tuple(range(1, mult.dim()))).cpu().numpy() 149 | if np.any(ignore_area == 0): 150 | self._k_sum = 0.9 * self._k_sum + 0.1 * sample_mult[ignore_area == 0].mean() 151 | 152 | beta_pmax, _ = torch.flatten(beta, start_dim=1).max(dim=1) 153 | beta_pmax = beta_pmax.mean().item() 154 | self._m_max = 0.8 * self._m_max + 0.2 * beta_pmax 155 | 156 | loss = -alpha * beta * torch.log(torch.min(pt + self._eps, torch.ones(1, dtype=torch.float).to(pt.device))) 157 | loss = self._weight * (loss * sample_weight) 158 | 159 | if weight is not None: 160 | weight = weight * self.w + (1-self.w) 161 | loss = (loss * weight).sum() / (weight.sum() + self._eps) 162 | else: 163 | bsum = torch.sum(sample_weight, dim=misc.get_dims_with_exclusion(sample_weight.dim(), self._batch_axis)) 164 | loss = torch.sum(loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis)) / (bsum + self._eps) 165 | return loss 166 | 167 | def log_states(self, sw, name, global_step): 168 | sw.add_scalar(tag=name + '_k', value=self._k_sum, global_step=global_step) 169 | sw.add_scalar(tag=name + '_m', value=self._m_max, global_step=global_step) 170 | 171 | 172 | class FocalLoss(nn.Module): 173 | def __init__(self, axis=-1, alpha=0.25, gamma=2, 174 | from_logits=False, batch_axis=0, 175 | weight=None, num_class=None, 176 | eps=1e-9, size_average=True, scale=1.0, 177 | ignore_label=-1): 178 | super(FocalLoss, self).__init__() 179 | self._axis = axis 180 | self._alpha = alpha 181 | self._gamma = gamma 182 | self._ignore_label = ignore_label 183 | self._weight = weight if weight is not None else 1.0 184 | self._batch_axis = batch_axis 185 | 186 | self._scale = scale 187 | self._num_class = num_class 188 | self._from_logits = from_logits 189 | self._eps = eps 190 | self._size_average = size_average 191 | 192 | def forward(self, pred, label, sample_weight=None): 193 | one_hot = label > 0.5 194 | sample_weight = label != self._ignore_label 195 | 196 | if not self._from_logits: 197 | pred = torch.sigmoid(pred) 198 | 199 | alpha = torch.where(one_hot, self._alpha * sample_weight, (1 - self._alpha) * sample_weight) 200 | pt = torch.where(sample_weight, 1.0 - torch.abs(label - pred), torch.ones_like(pred)) 201 | 202 | beta = (1 - pt) ** self._gamma 203 | 204 | loss = -alpha * beta * torch.log(torch.min(pt + self._eps, torch.ones(1, dtype=torch.float).to(pt.device))) 205 | loss = self._weight * (loss * sample_weight) 206 | 207 | if self._size_average: 208 | tsum = torch.sum(sample_weight, dim=misc.get_dims_with_exclusion(label.dim(), self._batch_axis)) 209 | loss = torch.sum(loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis)) / (tsum + self._eps) 210 | else: 211 | loss = torch.sum(loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis)) 212 | 213 | return self._scale * loss 214 | 215 | 216 | class SoftIoU(nn.Module): 217 | def __init__(self, from_sigmoid=False, ignore_label=-1): 218 | super().__init__() 219 | self._from_sigmoid = from_sigmoid 220 | self._ignore_label = ignore_label 221 | 222 | def forward(self, pred, label): 223 | label = label.view(pred.size()) 224 | sample_weight = label != self._ignore_label 225 | 226 | if not self._from_sigmoid: 227 | pred = 
torch.sigmoid(pred) 228 | 229 | loss = 1.0 - torch.sum(pred * label * sample_weight, dim=(1, 2, 3)) \ 230 | / (torch.sum(torch.max(pred, label) * sample_weight, dim=(1, 2, 3)) + 1e-8) 231 | 232 | return loss 233 | 234 | 235 | class SigmoidBinaryCrossEntropyLoss(nn.Module): 236 | def __init__(self, from_sigmoid=False, weight=None, batch_axis=0, ignore_label=-1): 237 | super(SigmoidBinaryCrossEntropyLoss, self).__init__() 238 | self._from_sigmoid = from_sigmoid 239 | self._ignore_label = ignore_label 240 | self._weight = weight if weight is not None else 1.0 241 | self._batch_axis = batch_axis 242 | 243 | def forward(self, pred, label): 244 | label = label.view(pred.size()) 245 | sample_weight = label != self._ignore_label 246 | label = torch.where(sample_weight, label, torch.zeros_like(label)) 247 | 248 | if not self._from_sigmoid: 249 | loss = torch.relu(pred) - pred * label + F.softplus(-torch.abs(pred)) 250 | else: 251 | eps = 1e-12 252 | loss = -(torch.log(pred + eps) * label 253 | + torch.log(1. - pred + eps) * (1. - label)) 254 | 255 | loss = self._weight * (loss * sample_weight) 256 | return torch.mean(loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis)) 257 | 258 | 259 | class WeightedSigmoidBinaryCrossEntropyLoss(nn.Module): 260 | def __init__(self, from_sigmoid=False, weight=None, batch_axis=0, ignore_label=-1): 261 | super(WeightedSigmoidBinaryCrossEntropyLoss, self).__init__() 262 | self._from_sigmoid = from_sigmoid 263 | self._ignore_label = ignore_label 264 | self._weight = weight if weight is not None else 1.0 265 | self._batch_axis = batch_axis 266 | 267 | def forward(self, pred, label, weight): 268 | label = label.view(pred.size()) 269 | sample_weight = label != self._ignore_label 270 | label = torch.where(sample_weight, label, torch.zeros_like(label)) 271 | 272 | if not self._from_sigmoid: 273 | loss = torch.relu(pred) - pred * label + F.softplus(-torch.abs(pred)) 274 | else: 275 | eps = 1e-12 276 | loss = -(torch.log(pred + eps) * label 277 | + torch.log(1. - pred + eps) * (1. - label)) 278 | #weight = weight * 0.8 + 0.2 279 | loss = (weight * loss).sum() / weight.sum() 280 | return loss #torch.mean(loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis)) 281 | 282 | 283 | 284 | 285 | class ClickLoss(nn.Module): 286 | def __init__(self, from_sigmoid=False, weight=None, batch_axis=0, ignore_label=-1, alpha = 0.99, beta = 0.01): 287 | super(ClickLoss, self).__init__() 288 | self._from_sigmoid = from_sigmoid 289 | self._ignore_label = ignore_label 290 | self._weight = weight if weight is not None else 1.0 291 | self._batch_axis = batch_axis 292 | self.alpha = alpha 293 | self.beta = beta 294 | 295 | 296 | def forward(self, pred, label, gaussian_maps = None): 297 | h_gt, w_gt = label.shape[-2],label.shape[-1] 298 | h_p, w_p = pred.shape[-2], pred.shape[-1] 299 | if h_gt != h_p or w_gt != w_p: 300 | pred = F.interpolate(pred, size=label.size()[-2:], 301 | mode='bilinear', align_corners=True) 302 | 303 | 304 | label = label.view(pred.size()) 305 | sample_weight = label != self._ignore_label 306 | label = torch.where(sample_weight, label, torch.zeros_like(label)) 307 | 308 | if not self._from_sigmoid: 309 | loss = torch.relu(pred) - pred * label + F.softplus(-torch.abs(pred)) 310 | else: 311 | eps = 1e-12 312 | loss = -(torch.log(pred + eps) * label 313 | + torch.log(1. - pred + eps) * (1. 
- label)) 314 | 315 | loss = self._weight * (loss * sample_weight) 316 | weight_map = gaussian_maps.max(dim=1,keepdim = True)[0] * self.alpha + self.beta 317 | loss = (loss * weight_map).sum() / weight_map.sum() 318 | return loss -------------------------------------------------------------------------------- /isegm/model/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | from isegm.utils import misc 5 | 6 | 7 | class TrainMetric(object): 8 | def __init__(self, pred_outputs, gt_outputs): 9 | self.pred_outputs = pred_outputs 10 | self.gt_outputs = gt_outputs 11 | 12 | def update(self, *args, **kwargs): 13 | raise NotImplementedError 14 | 15 | def get_epoch_value(self): 16 | raise NotImplementedError 17 | 18 | def reset_epoch_stats(self): 19 | raise NotImplementedError 20 | 21 | def log_states(self, sw, tag_prefix, global_step): 22 | pass 23 | 24 | @property 25 | def name(self): 26 | return type(self).__name__ 27 | 28 | 29 | class AdaptiveIoU(TrainMetric): 30 | def __init__(self, init_thresh=0.4, thresh_step=0.025, thresh_beta=0.99, iou_beta=0.9, 31 | ignore_label=-1, from_logits=True, 32 | pred_output='instances', gt_output='instances'): 33 | super().__init__(pred_outputs=(pred_output,), gt_outputs=(gt_output,)) 34 | self._ignore_label = ignore_label 35 | self._from_logits = from_logits 36 | self._iou_thresh = init_thresh 37 | self._thresh_step = thresh_step 38 | self._thresh_beta = thresh_beta 39 | self._iou_beta = iou_beta 40 | self._ema_iou = 0.0 41 | self._epoch_iou_sum = 0.0 42 | self._epoch_batch_count = 0 43 | 44 | def update(self, pred, gt): 45 | gt_mask = gt > 0.5 46 | if self._from_logits: 47 | pred = torch.sigmoid(pred) 48 | 49 | gt_mask_area = torch.sum(gt_mask, dim=(1, 2)).detach().cpu().numpy() 50 | if np.all(gt_mask_area == 0): 51 | return 52 | 53 | ignore_mask = gt == self._ignore_label 54 | max_iou = _compute_iou(pred > self._iou_thresh, gt_mask, ignore_mask).mean() 55 | best_thresh = self._iou_thresh 56 | for t in [best_thresh - self._thresh_step, best_thresh + self._thresh_step]: 57 | temp_iou = _compute_iou(pred > t, gt_mask, ignore_mask).mean() 58 | if temp_iou > max_iou: 59 | max_iou = temp_iou 60 | best_thresh = t 61 | 62 | self._iou_thresh = self._thresh_beta * self._iou_thresh + (1 - self._thresh_beta) * best_thresh 63 | self._ema_iou = self._iou_beta * self._ema_iou + (1 - self._iou_beta) * max_iou 64 | self._epoch_iou_sum += max_iou 65 | self._epoch_batch_count += 1 66 | 67 | def get_epoch_value(self): 68 | if self._epoch_batch_count > 0: 69 | return self._epoch_iou_sum / self._epoch_batch_count 70 | else: 71 | return 0.0 72 | 73 | def reset_epoch_stats(self): 74 | self._epoch_iou_sum = 0.0 75 | self._epoch_batch_count = 0 76 | 77 | def log_states(self, sw, tag_prefix, global_step): 78 | sw.add_scalar(tag=tag_prefix + '_ema_iou', value=self._ema_iou, global_step=global_step) 79 | sw.add_scalar(tag=tag_prefix + '_iou_thresh', value=self._iou_thresh, global_step=global_step) 80 | 81 | @property 82 | def iou_thresh(self): 83 | return self._iou_thresh 84 | 85 | 86 | def _compute_iou(pred_mask, gt_mask, ignore_mask=None, keep_ignore=False): 87 | if ignore_mask is not None: 88 | pred_mask = torch.where(ignore_mask, torch.zeros_like(pred_mask), pred_mask) 89 | 90 | reduction_dims = misc.get_dims_with_exclusion(gt_mask.dim(), 0) 91 | union = torch.mean((pred_mask | gt_mask).float(), dim=reduction_dims).detach().cpu().numpy() 92 | intersection = torch.mean((pred_mask & 
gt_mask).float(), dim=reduction_dims).detach().cpu().numpy() 93 | nonzero = union > 0 94 | 95 | iou = intersection[nonzero] / union[nonzero] 96 | if not keep_ignore: 97 | return iou 98 | else: 99 | result = np.full_like(intersection, -1) 100 | result[nonzero] = iou 101 | return result 102 | -------------------------------------------------------------------------------- /isegm/model/modeling/basic_blocks.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from isegm.model import ops 4 | 5 | 6 | class ConvHead(nn.Module): 7 | def __init__(self, out_channels, in_channels=32, num_layers=1, 8 | kernel_size=3, padding=1, 9 | norm_layer=nn.BatchNorm2d): 10 | super(ConvHead, self).__init__() 11 | convhead = [] 12 | 13 | for i in range(num_layers): 14 | convhead.extend([ 15 | nn.Conv2d(in_channels, in_channels, kernel_size, padding=padding), 16 | nn.ReLU(), 17 | norm_layer(in_channels) if norm_layer is not None else nn.Identity() 18 | ]) 19 | convhead.append(nn.Conv2d(in_channels, out_channels, 1, padding=0)) 20 | 21 | self.convhead = nn.Sequential(*convhead) 22 | 23 | def forward(self, *inputs): 24 | return self.convhead(inputs[0]) 25 | 26 | 27 | class SepConvHead(nn.Module): 28 | def __init__(self, num_outputs, in_channels, mid_channels, num_layers=1, 29 | kernel_size=3, padding=1, dropout_ratio=0.0, dropout_indx=0, 30 | norm_layer=nn.BatchNorm2d): 31 | super(SepConvHead, self).__init__() 32 | 33 | sepconvhead = [] 34 | 35 | for i in range(num_layers): 36 | sepconvhead.append( 37 | SeparableConv2d(in_channels=in_channels if i == 0 else mid_channels, 38 | out_channels=mid_channels, 39 | dw_kernel=kernel_size, dw_padding=padding, 40 | norm_layer=norm_layer, activation='relu') 41 | ) 42 | if dropout_ratio > 0 and dropout_indx == i: 43 | sepconvhead.append(nn.Dropout(dropout_ratio)) 44 | 45 | sepconvhead.append( 46 | nn.Conv2d(in_channels=mid_channels, out_channels=num_outputs, kernel_size=1, padding=0) 47 | ) 48 | 49 | self.layers = nn.Sequential(*sepconvhead) 50 | 51 | def forward(self, *inputs): 52 | x = inputs[0] 53 | 54 | return self.layers(x) 55 | 56 | 57 | class SeparableConv2d(nn.Module): 58 | def __init__(self, in_channels, out_channels, dw_kernel, dw_padding, dw_stride=1, 59 | activation=None, use_bias=False, norm_layer=None): 60 | super(SeparableConv2d, self).__init__() 61 | _activation = ops.select_activation_function(activation) 62 | self.body = nn.Sequential( 63 | nn.Conv2d(in_channels, in_channels, kernel_size=dw_kernel, stride=dw_stride, 64 | padding=dw_padding, bias=use_bias, groups=in_channels), 65 | nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=use_bias), 66 | norm_layer(out_channels) if norm_layer is not None else nn.Identity(), 67 | _activation() 68 | ) 69 | 70 | def forward(self, x): 71 | return self.body(x) 72 | -------------------------------------------------------------------------------- /isegm/model/modeling/deeplab_v3.py: -------------------------------------------------------------------------------- 1 | from contextlib import ExitStack 2 | 3 | import torch 4 | from torch import nn 5 | import torch.nn.functional as F 6 | 7 | from .basic_blocks import SeparableConv2d 8 | from .resnet import ResNetBackbone 9 | from isegm.model import ops 10 | from isegm.model.modeling.cdnet.FDM import FDM, FDM_v2, FDM_v3 11 | 12 | class DeepLabV3Plus(nn.Module): 13 | def __init__(self, backbone='resnet50', norm_layer=nn.BatchNorm2d, 14 | backbone_norm_layer=None, 15 | ch=256, 16 | project_dropout=0.5, 
17 |                  inference_mode=False, 18 |                  **kwargs): 19 |         super(DeepLabV3Plus, self).__init__() 20 |         if backbone_norm_layer is None: 21 |             backbone_norm_layer = norm_layer 22 | 23 |         self.backbone_name = backbone 24 |         self.norm_layer = norm_layer 25 |         self.backbone_norm_layer = backbone_norm_layer 26 |         self.inference_mode = False 27 |         self.ch = ch 28 |         self.aspp_in_channels = 2048 29 |         self.skip_project_in_channels = 256  # layer 1 out_channels 30 | 31 |         self._kwargs = kwargs 32 | 33 |         if backbone in ('resnet34', 'resnet18'): 34 |             self.aspp_in_channels = 512 35 |             self.skip_project_in_channels = 64 36 |         else: 37 |             self.aspp_in_channels = 512 * 4 38 |             self.skip_project_in_channels = 64 * 4 39 | 40 | 41 | 42 |         self.backbone = ResNetBackbone(backbone=self.backbone_name, pretrained_base=False, 43 |                                        norm_layer=self.backbone_norm_layer, **kwargs) 44 | 45 |         self.head = _DeepLabHead(in_channels=ch + 32, mid_channels=ch, out_channels=ch, 46 |                                  norm_layer=self.norm_layer) 47 |         self.skip_project = _SkipProject(self.skip_project_in_channels, 32, norm_layer=self.norm_layer) 48 |         self.aspp = _ASPP(in_channels=self.aspp_in_channels, 49 |                           atrous_rates=[12, 24, 36], 50 |                           out_channels=ch, 51 |                           project_dropout=project_dropout, 52 |                           norm_layer=self.norm_layer) 53 |         self.FDM = FDM_v3(self.ch,self.ch) 54 | 55 |         if inference_mode: 56 |             self.set_prediction_mode() 57 | 58 |     def load_pretrained_weights(self): 59 |         pretrained = ResNetBackbone(backbone=self.backbone_name, pretrained_base=True, 60 |                                     norm_layer=self.backbone_norm_layer, **self._kwargs) 61 |         backbone_state_dict = self.backbone.state_dict() 62 |         pretrained_state_dict = pretrained.state_dict() 63 | 64 |         backbone_state_dict.update(pretrained_state_dict) 65 |         self.backbone.load_state_dict(backbone_state_dict) 66 | 67 |         if self.inference_mode: 68 |             for param in self.backbone.parameters(): 69 |                 param.requires_grad = False 70 | 71 |     def set_prediction_mode(self): 72 |         self.inference_mode = True 73 |         self.eval() 74 | 75 |     def forward(self, x, additional_features, small_clicks): 76 | 77 |         with ExitStack() as stack: 78 |             if self.inference_mode: 79 |                 stack.enter_context(torch.no_grad()) 80 | 81 |             c1, _, c3, c4 = self.backbone(x, additional_features) 82 |             c1 = self.skip_project(c1) 83 | 84 |             x = self.aspp(c4) 85 |             x, pos_map, neg_map = self.FDM(x, small_clicks) 86 |             x = F.interpolate(x, c1.size()[2:], mode='bilinear', align_corners=True) 87 |             x = torch.cat((x, c1), dim=1) 88 |             x = self.head(x) 89 |             return x,pos_map 90 | 91 | 92 | class _SkipProject(nn.Module): 93 |     def __init__(self, in_channels, out_channels, norm_layer=nn.BatchNorm2d): 94 |         super(_SkipProject, self).__init__() 95 |         _activation = ops.select_activation_function("relu") 96 | 97 |         self.skip_project = nn.Sequential( 98 |             nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False), 99 |             norm_layer(out_channels), 100 |             _activation() 101 |         ) 102 | 103 |     def forward(self, x): 104 |         return self.skip_project(x) 105 | 106 | 107 | class _DeepLabHead(nn.Module): 108 |     def __init__(self, out_channels, in_channels, mid_channels=256, norm_layer=nn.BatchNorm2d): 109 |         super(_DeepLabHead, self).__init__() 110 | 111 |         self.block = nn.Sequential( 112 |             SeparableConv2d(in_channels=in_channels, out_channels=mid_channels, dw_kernel=3, 113 |                             dw_padding=1, activation='relu', norm_layer=norm_layer), 114 |             SeparableConv2d(in_channels=mid_channels, out_channels=mid_channels, dw_kernel=3, 115 |                             dw_padding=1, activation='relu', norm_layer=norm_layer), 116 |             nn.Conv2d(in_channels=mid_channels, out_channels=out_channels, kernel_size=1) 117 
| ) 118 | 119 | def forward(self, x): 120 | return self.block(x) 121 | 122 | 123 | class _ASPP(nn.Module): 124 | def __init__(self, in_channels, atrous_rates, out_channels=256, 125 | project_dropout=0.5, norm_layer=nn.BatchNorm2d): 126 | super(_ASPP, self).__init__() 127 | 128 | b0 = nn.Sequential( 129 | nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=False), 130 | norm_layer(out_channels), 131 | nn.ReLU() 132 | ) 133 | 134 | rate1, rate2, rate3 = tuple(atrous_rates) 135 | b1 = _ASPPConv(in_channels, out_channels, rate1, norm_layer) 136 | b2 = _ASPPConv(in_channels, out_channels, rate2, norm_layer) 137 | b3 = _ASPPConv(in_channels, out_channels, rate3, norm_layer) 138 | b4 = _AsppPooling(in_channels, out_channels, norm_layer=norm_layer) 139 | 140 | self.concurent = nn.ModuleList([b0, b1, b2, b3, b4]) 141 | 142 | project = [ 143 | nn.Conv2d(in_channels=5*out_channels, out_channels=out_channels, 144 | kernel_size=1, bias=False), 145 | norm_layer(out_channels), 146 | nn.ReLU() 147 | ] 148 | if project_dropout > 0: 149 | project.append(nn.Dropout(project_dropout)) 150 | self.project = nn.Sequential(*project) 151 | 152 | def forward(self, x): 153 | x = torch.cat([block(x) for block in self.concurent], dim=1) 154 | 155 | return self.project(x) 156 | 157 | 158 | class _AsppPooling(nn.Module): 159 | def __init__(self, in_channels, out_channels, norm_layer): 160 | super(_AsppPooling, self).__init__() 161 | 162 | self.gap = nn.Sequential( 163 | nn.AdaptiveAvgPool2d((1, 1)), 164 | nn.Conv2d(in_channels=in_channels, out_channels=out_channels, 165 | kernel_size=1, bias=False), 166 | norm_layer(out_channels), 167 | nn.ReLU() 168 | ) 169 | 170 | def forward(self, x): 171 | pool = self.gap(x) 172 | return F.interpolate(pool, x.size()[2:], mode='bilinear', align_corners=True) 173 | 174 | 175 | def _ASPPConv(in_channels, out_channels, atrous_rate, norm_layer): 176 | block = nn.Sequential( 177 | nn.Conv2d(in_channels=in_channels, out_channels=out_channels, 178 | kernel_size=3, padding=atrous_rate, 179 | dilation=atrous_rate, bias=False), 180 | norm_layer(out_channels), 181 | nn.ReLU() 182 | ) 183 | 184 | return block 185 | -------------------------------------------------------------------------------- /isegm/model/modeling/deeplab_v3_gp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | 5 | from .basic_blocks import SeparableConv2d 6 | from .resnet import ResNetBackbone 7 | from isegm.model import ops 8 | 9 | class DeepLabV3Plus(nn.Module): 10 | def __init__(self, backbone='resnet50', norm_layer=nn.BatchNorm2d, 11 | backbone_norm_layer=None, 12 | ch=256, 13 | project_dropout=0.5, 14 | inference_mode=False, 15 | weight_dir=None, 16 | **kwargs): 17 | super(DeepLabV3Plus, self).__init__() 18 | if backbone_norm_layer is None: 19 | backbone_norm_layer = norm_layer 20 | 21 | self.backbone_name = backbone 22 | self.norm_layer = norm_layer 23 | self.backbone_norm_layer = backbone_norm_layer 24 | self.inference_mode = False 25 | self.ch = ch 26 | self.aspp_in_channels = 2048 27 | self.skip_project_in_channels = 256 # layer 1 out_channels 28 | self.weight_dir=weight_dir 29 | 30 | self._kwargs = kwargs 31 | 32 | self.aspp_in_channels = 512 * 4 33 | self.skip_project_in_channels = 64 * 4 34 | 35 | 36 | self.backbone = ResNetBackbone(backbone=self.backbone_name, pretrained_base=False, 37 | norm_layer=self.backbone_norm_layer,weight_dir=weight_dir, **kwargs) 38 | 39 | 
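        # Decoder wiring, read off the code below: ASPP compresses the
        # 2048-channel c4 features to `ch` channels, the skip projection reduces
        # the 256-channel c1 features to 32 channels, and the head fuses the
        # concatenated (ch + 32)-channel tensor into a 48-channel per-pixel
        # feature map, presumably the pixel-wise features consumed by the GP
        # classifier in isegm/model/is_gp_model.py.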
self.head = _DeepLabHead(in_channels=ch + 32, mid_channels=ch, out_channels=48, 40 | norm_layer=self.norm_layer) 41 | self.skip_project = _SkipProject(self.skip_project_in_channels, 32, norm_layer=self.norm_layer) 42 | self.aspp = _ASPP(in_channels=self.aspp_in_channels, 43 | atrous_rates=[12, 24, 36], 44 | out_channels=ch, 45 | project_dropout=project_dropout, 46 | norm_layer=self.norm_layer) 47 | self.feature_dim = 256 48 | 49 | if inference_mode: 50 | self.set_prediction_mode() 51 | 52 | def load_pretrained_weights(self): 53 | pretrained = ResNetBackbone(backbone=self.backbone_name, pretrained_base=True, 54 | norm_layer=self.backbone_norm_layer,weight_dir=self.weight_dir, **self._kwargs) 55 | backbone_state_dict = self.backbone.state_dict() 56 | pretrained_state_dict = pretrained.state_dict() 57 | 58 | backbone_state_dict.update(pretrained_state_dict) 59 | self.backbone.load_state_dict(backbone_state_dict) 60 | 61 | if self.inference_mode: 62 | for param in self.backbone.parameters(): 63 | param.requires_grad = False 64 | 65 | def set_prediction_mode(self): 66 | self.inference_mode = True 67 | self.eval() 68 | 69 | def forward(self, x, additional_features): 70 | 71 | c1, _, _, c4 = self.backbone(x, additional_features) 72 | 73 | c1 = self.skip_project(c1) 74 | 75 | x = self.aspp(c4) 76 | 77 | x = F.interpolate(x, c1.size()[2:], mode='bilinear', align_corners=True) 78 | x = torch.cat((x, c1), dim=1) 79 | return self.head(x) 80 | 81 | class _SkipProject(nn.Module): 82 | def __init__(self, in_channels, out_channels, norm_layer=nn.BatchNorm2d): 83 | super(_SkipProject, self).__init__() 84 | _activation = ops.select_activation_function("relu") 85 | 86 | self.skip_project = nn.Sequential( 87 | nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False), 88 | norm_layer(out_channels), 89 | _activation() 90 | ) 91 | 92 | def forward(self, x): 93 | return self.skip_project(x) 94 | 95 | 96 | class _DeepLabHead(nn.Module): 97 | def __init__(self, out_channels, in_channels, mid_channels=256, norm_layer=nn.BatchNorm2d): 98 | super(_DeepLabHead, self).__init__() 99 | 100 | self.block = nn.Sequential( 101 | SeparableConv2d(in_channels=in_channels, out_channels=mid_channels, dw_kernel=3, 102 | dw_padding=1, activation='relu', norm_layer=norm_layer), 103 | SeparableConv2d(in_channels=mid_channels, out_channels=mid_channels, dw_kernel=3, 104 | dw_padding=1, activation='relu', norm_layer=norm_layer), 105 | nn.Conv2d(in_channels=mid_channels, out_channels=out_channels, kernel_size=1) 106 | ) 107 | 108 | def forward(self, x): 109 | return self.block(x) 110 | 111 | 112 | class _ASPP(nn.Module): 113 | def __init__(self, in_channels, atrous_rates, out_channels=256, 114 | project_dropout=0.5, norm_layer=nn.BatchNorm2d): 115 | super(_ASPP, self).__init__() 116 | 117 | b0 = nn.Sequential( 118 | nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=False), 119 | norm_layer(out_channels), 120 | nn.ReLU() 121 | ) 122 | 123 | rate1, rate2, rate3 = tuple(atrous_rates) 124 | b1 = _ASPPConv(in_channels, out_channels, rate1, norm_layer) 125 | b2 = _ASPPConv(in_channels, out_channels, rate2, norm_layer) 126 | b3 = _ASPPConv(in_channels, out_channels, rate3, norm_layer) 127 | b4 = _AsppPooling(in_channels, out_channels, norm_layer=norm_layer) 128 | 129 | self.concurent = nn.ModuleList([b0, b1, b2, b3, b4]) 130 | 131 | project = [ 132 | nn.Conv2d(in_channels=5*out_channels, out_channels=out_channels, 133 | kernel_size=1, bias=False), 134 | norm_layer(out_channels), 135 | nn.ReLU() 136 | ] 137 
| if project_dropout > 0: 138 | project.append(nn.Dropout(project_dropout)) 139 | self.project = nn.Sequential(*project) 140 | 141 | def forward(self, x): 142 | x = torch.cat([block(x) for block in self.concurent], dim=1) 143 | 144 | return self.project(x) 145 | 146 | 147 | class _AsppPooling(nn.Module): 148 | def __init__(self, in_channels, out_channels, norm_layer): 149 | super(_AsppPooling, self).__init__() 150 | 151 | self.gap = nn.Sequential( 152 | nn.AdaptiveAvgPool2d((1, 1)), 153 | nn.Conv2d(in_channels=in_channels, out_channels=out_channels, 154 | kernel_size=1, bias=False), 155 | norm_layer(out_channels), 156 | nn.ReLU() 157 | ) 158 | 159 | def forward(self, x): 160 | pool = self.gap(x) 161 | return F.interpolate(pool, x.size()[2:], mode='bilinear', align_corners=True) 162 | 163 | 164 | def _ASPPConv(in_channels, out_channels, atrous_rate, norm_layer): 165 | block = nn.Sequential( 166 | nn.Conv2d(in_channels=in_channels, out_channels=out_channels, 167 | kernel_size=3, padding=atrous_rate, 168 | dilation=atrous_rate, bias=False), 169 | norm_layer(out_channels), 170 | nn.ReLU() 171 | ) 172 | 173 | return block 174 | -------------------------------------------------------------------------------- /isegm/model/modeling/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .resnetv1b import resnet18_v1b, resnet34_v1b, resnet50_v1s, resnet101_v1s, resnet152_v1s 3 | 4 | 5 | class ResNetBackbone(torch.nn.Module): 6 | def __init__(self, backbone='resnet50', pretrained_base=True, dilated=True,weight_dir=None, **kwargs): 7 | super(ResNetBackbone, self).__init__() 8 | 9 | if backbone == 'resnet18': 10 | pretrained = resnet18_v1b(pretrained=pretrained_base, dilated=dilated,weight_dir=weight_dir, **kwargs) 11 | elif backbone == 'resnet34': 12 | pretrained = resnet34_v1b(pretrained=pretrained_base, dilated=dilated,weight_dir=weight_dir, **kwargs) 13 | elif backbone == 'resnet50': 14 | pretrained = resnet50_v1s(pretrained=pretrained_base, dilated=dilated,weight_dir=weight_dir, **kwargs) 15 | elif backbone == 'resnet101': 16 | pretrained = resnet101_v1s(pretrained=pretrained_base, dilated=dilated,weight_dir=weight_dir, **kwargs) 17 | elif backbone == 'resnet152': 18 | pretrained = resnet152_v1s(pretrained=pretrained_base, dilated=dilated,weight_dir=weight_dir, **kwargs) 19 | else: 20 | raise RuntimeError(f'unknown backbone: {backbone}') 21 | 22 | self.conv1 = pretrained.conv1 23 | self.bn1 = pretrained.bn1 24 | self.relu = pretrained.relu 25 | self.maxpool = pretrained.maxpool 26 | self.layer1 = pretrained.layer1 27 | self.layer2 = pretrained.layer2 28 | self.layer3 = pretrained.layer3 29 | self.layer4 = pretrained.layer4 30 | 31 | def forward(self, x, additional_features=None): 32 | x = self.conv1(x) 33 | x = self.bn1(x) 34 | x = self.relu(x) 35 | if additional_features is not None: 36 | x = x + torch.nn.functional.pad(additional_features, 37 | [0, 0, 0, 0, 0, x.size(1) - additional_features.size(1)], 38 | mode='constant', value=0) 39 | x = self.maxpool(x) 40 | c1 = self.layer1(x) 41 | c2 = self.layer2(c1) 42 | c3 = self.layer3(c2) 43 | c4 = self.layer4(c3) 44 | 45 | return c1, c2, c3, c4 46 | -------------------------------------------------------------------------------- /isegm/model/modeling/resnetv1b.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | GLUON_RESNET_TORCH_HUB = 'rwightman/pytorch-pretrained-gluonresnet' # This is open source, not me 4 | 
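# Note on weight resolution for the builders at the bottom of this file: with
# weight_dir=None, pretrained parameters are fetched from the torch.hub repo
# above, e.g. torch.hub.load(GLUON_RESNET_TORCH_HUB, 'gluon_resnet50_v1s',
# pretrained=True); with a local path, torch.load(weight_dir, map_location='cpu')
# is used instead, matching the gluon_resnet50_v1s checkpoint linked in the
# README.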
5 | 6 | class BasicBlockV1b(nn.Module): 7 | expansion = 1 8 | 9 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, 10 | previous_dilation=1, norm_layer=nn.BatchNorm2d): 11 | super(BasicBlockV1b, self).__init__() 12 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, 13 | padding=dilation, dilation=dilation, bias=False) 14 | self.bn1 = norm_layer(planes) 15 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, 16 | padding=previous_dilation, dilation=previous_dilation, bias=False) 17 | self.bn2 = norm_layer(planes) 18 | 19 | self.relu = nn.ReLU(inplace=True) 20 | self.downsample = downsample 21 | self.stride = stride 22 | 23 | def forward(self, x): 24 | residual = x 25 | 26 | out = self.conv1(x) 27 | out = self.bn1(out) 28 | out = self.relu(out) 29 | 30 | out = self.conv2(out) 31 | out = self.bn2(out) 32 | 33 | if self.downsample is not None: 34 | residual = self.downsample(x) 35 | 36 | out = out + residual 37 | out = self.relu(out) 38 | 39 | return out 40 | 41 | 42 | class BottleneckV1b(nn.Module): 43 | expansion = 4 44 | 45 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, 46 | previous_dilation=1, norm_layer=nn.BatchNorm2d): 47 | super(BottleneckV1b, self).__init__() 48 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 49 | self.bn1 = norm_layer(planes) 50 | 51 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 52 | padding=dilation, dilation=dilation, bias=False) 53 | self.bn2 = norm_layer(planes) 54 | 55 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) 56 | self.bn3 = norm_layer(planes * self.expansion) 57 | 58 | self.relu = nn.ReLU(inplace=True) 59 | self.downsample = downsample 60 | self.stride = stride 61 | 62 | def forward(self, x): 63 | residual = x 64 | 65 | out = self.conv1(x) 66 | out = self.bn1(out) 67 | out = self.relu(out) 68 | 69 | out = self.conv2(out) 70 | out = self.bn2(out) 71 | out = self.relu(out) 72 | 73 | out = self.conv3(out) 74 | out = self.bn3(out) 75 | 76 | if self.downsample is not None: 77 | residual = self.downsample(x) 78 | 79 | out = out + residual 80 | out = self.relu(out) 81 | 82 | return out 83 | 84 | 85 | class ResNetV1b(nn.Module): 86 | """ Pre-trained ResNetV1b Model, which produces the strides of 8 featuremaps at conv5. 87 | 88 | Parameters 89 | ---------- 90 | block : Block 91 | Class for the residual block. Options are BasicBlockV1, BottleneckV1. 92 | layers : list of int 93 | Numbers of layers in each block 94 | classes : int, default 1000 95 | Number of classification classes. 96 | dilated : bool, default False 97 | Applying dilation strategy to pretrained ResNet yielding a stride-8 model, 98 | typically used in Semantic Segmentation. 99 | norm_layer : object 100 | Normalization layer used (default: :class:`nn.BatchNorm2d`) 101 | deep_stem : bool, default False 102 | Whether to replace the 7x7 conv1 with 3 3x3 convolution layers. 103 | avg_down : bool, default False 104 | Whether to use average pooling for projection skip connection between stages/downsample. 105 | final_drop : float, default 0.0 106 | Dropout ratio before the final classification layer. 107 | 108 | Reference: 109 | - He, Kaiming, et al. "Deep residual learning for image recognition." 110 | Proceedings of the IEEE conference on computer vision and pattern recognition. 2016. 111 | 112 | - Yu, Fisher, and Vladlen Koltun. "Multi-scale context aggregation by dilated convolutions." 
113 | """ 114 | def __init__(self, block, layers, classes=1000, dilated=True, deep_stem=False, stem_width=32, 115 | avg_down=False, final_drop=0.0, norm_layer=nn.BatchNorm2d): 116 | self.inplanes = stem_width*2 if deep_stem else 64 117 | super(ResNetV1b, self).__init__() 118 | if not deep_stem: 119 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 120 | else: 121 | self.conv1 = nn.Sequential( 122 | nn.Conv2d(3, stem_width, kernel_size=3, stride=2, padding=1, bias=False), 123 | norm_layer(stem_width), 124 | nn.ReLU(True), 125 | nn.Conv2d(stem_width, stem_width, kernel_size=3, stride=1, padding=1, bias=False), 126 | norm_layer(stem_width), 127 | nn.ReLU(True), 128 | nn.Conv2d(stem_width, 2*stem_width, kernel_size=3, stride=1, padding=1, bias=False) 129 | ) 130 | self.bn1 = norm_layer(self.inplanes) 131 | self.relu = nn.ReLU(True) 132 | self.maxpool = nn.MaxPool2d(3, stride=2, padding=1) 133 | self.layer1 = self._make_layer(block, 64, layers[0], avg_down=avg_down, 134 | norm_layer=norm_layer) 135 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, avg_down=avg_down, 136 | norm_layer=norm_layer) 137 | if dilated: 138 | self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2, 139 | avg_down=avg_down, norm_layer=norm_layer) 140 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4, 141 | avg_down=avg_down, norm_layer=norm_layer) 142 | else: 143 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 144 | avg_down=avg_down, norm_layer=norm_layer) 145 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 146 | avg_down=avg_down, norm_layer=norm_layer) 147 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 148 | self.drop = None 149 | if final_drop > 0.0: 150 | self.drop = nn.Dropout(final_drop) 151 | self.fc = nn.Linear(512 * block.expansion, classes) 152 | 153 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, 154 | avg_down=False, norm_layer=nn.BatchNorm2d): 155 | downsample = None 156 | if stride != 1 or self.inplanes != planes * block.expansion: 157 | downsample = [] 158 | if avg_down: 159 | if dilation == 1: 160 | downsample.append( 161 | nn.AvgPool2d(kernel_size=stride, stride=stride, ceil_mode=True, count_include_pad=False) 162 | ) 163 | else: 164 | downsample.append( 165 | nn.AvgPool2d(kernel_size=1, stride=1, ceil_mode=True, count_include_pad=False) 166 | ) 167 | downsample.extend([ 168 | nn.Conv2d(self.inplanes, out_channels=planes * block.expansion, 169 | kernel_size=1, stride=1, bias=False), 170 | norm_layer(planes * block.expansion) 171 | ]) 172 | downsample = nn.Sequential(*downsample) 173 | else: 174 | downsample = nn.Sequential( 175 | nn.Conv2d(self.inplanes, out_channels=planes * block.expansion, 176 | kernel_size=1, stride=stride, bias=False), 177 | norm_layer(planes * block.expansion) 178 | ) 179 | 180 | layers = [] 181 | if dilation in (1, 2): 182 | layers.append(block(self.inplanes, planes, stride, dilation=1, downsample=downsample, 183 | previous_dilation=dilation, norm_layer=norm_layer)) 184 | elif dilation == 4: 185 | layers.append(block(self.inplanes, planes, stride, dilation=2, downsample=downsample, 186 | previous_dilation=dilation, norm_layer=norm_layer)) 187 | else: 188 | raise RuntimeError("=> unknown dilation size: {}".format(dilation)) 189 | 190 | self.inplanes = planes * block.expansion 191 | for _ in range(1, blocks): 192 | layers.append(block(self.inplanes, planes, dilation=dilation, 193 | previous_dilation=dilation, norm_layer=norm_layer)) 
194 | 195 | return nn.Sequential(*layers) 196 | 197 | def forward(self, x): 198 | x = self.conv1(x) 199 | x = self.bn1(x) 200 | x = self.relu(x) 201 | x = self.maxpool(x) 202 | 203 | x = self.layer1(x) 204 | x = self.layer2(x) 205 | x = self.layer3(x) 206 | x = self.layer4(x) 207 | 208 | x = self.avgpool(x) 209 | x = x.view(x.size(0), -1) 210 | if self.drop is not None: 211 | x = self.drop(x) 212 | x = self.fc(x) 213 | 214 | return x 215 | 216 | 217 | def _safe_state_dict_filtering(orig_dict, model_dict_keys): 218 | filtered_orig_dict = {} 219 | for k, v in orig_dict.items(): 220 | if k in model_dict_keys: 221 | filtered_orig_dict[k] = v 222 | else: 223 | print(f"[ERROR] Failed to load <{k}> in backbone") 224 | return filtered_orig_dict 225 | 226 | 227 | 228 | def resnet18_v1b(pretrained=False, **kwargs): 229 | model = ResNetV1b(BasicBlockV1b, [2, 2, 2, 2], **kwargs) 230 | if pretrained: 231 | pass 232 | return model 233 | 234 | 235 | def resnet34_v1b(pretrained=False, weight_dir=None, **kwargs): 236 | model = ResNetV1b(BasicBlockV1b, [3, 4, 6, 3], **kwargs) 237 | if pretrained: 238 | model_dict = model.state_dict() 239 | if weight_dir is None: 240 | filtered_orig_dict = _safe_state_dict_filtering( 241 | torch.hub.load(GLUON_RESNET_TORCH_HUB, 'gluon_resnet34_v1b', pretrained=True).state_dict(), 242 | model_dict.keys() 243 | ) 244 | else: 245 | filtered_orig_dict = torch.load(weight_dir,map_location='cpu')#['state_dict'] 246 | model_dict.update(filtered_orig_dict) 247 | model.load_state_dict(model_dict) 248 | return model 249 | 250 | 251 | def resnet50_v1s(pretrained=False, weight_dir=None, **kwargs): 252 | model = ResNetV1b(BottleneckV1b, [3, 4, 6, 3], deep_stem=True, stem_width=64, **kwargs) 253 | if pretrained: 254 | model_dict = model.state_dict() 255 | if weight_dir is None: 256 | filtered_orig_dict = _safe_state_dict_filtering( 257 | torch.hub.load(GLUON_RESNET_TORCH_HUB, 'gluon_resnet50_v1s', pretrained=True).state_dict(), 258 | model_dict.keys() 259 | ) 260 | else: 261 | filtered_orig_dict = torch.load(weight_dir,map_location='cpu') 262 | model_dict.update(filtered_orig_dict) 263 | model.load_state_dict(model_dict) 264 | return model 265 | 266 | 267 | def resnet101_v1s(pretrained=False, weight_dir=None, **kwargs): 268 | model = ResNetV1b(BottleneckV1b, [3, 4, 23, 3], deep_stem=True, stem_width=64, **kwargs) 269 | if pretrained: 270 | model_dict = model.state_dict() 271 | if weight_dir is None: 272 | filtered_orig_dict = _safe_state_dict_filtering( 273 | torch.hub.load(GLUON_RESNET_TORCH_HUB, 'gluon_resnet101_v1s', pretrained=True).state_dict(), 274 | model_dict.keys() 275 | ) 276 | else: 277 | filtered_orig_dict = torch.load(weight_dir,map_location='cpu') 278 | model_dict.update(filtered_orig_dict) 279 | model.load_state_dict(model_dict) 280 | return model 281 | 282 | 283 | def resnet152_v1s(pretrained=False, weight_dir=None, **kwargs): 284 | model = ResNetV1b(BottleneckV1b, [3, 8, 36, 3], deep_stem=True, stem_width=64, **kwargs) 285 | if pretrained: 286 | model_dict = model.state_dict() 287 | if weight_dir is None: 288 | filtered_orig_dict = _safe_state_dict_filtering( 289 | torch.hub.load(GLUON_RESNET_TORCH_HUB, 'gluon_resnet152_v1s', pretrained=True).state_dict(), 290 | model_dict.keys() 291 | ) 292 | else: 293 | filtered_orig_dict = torch.load(weight_dir,map_location='cpu') 294 | model_dict.update(filtered_orig_dict) 295 | model.load_state_dict(model_dict) 296 | return model 297 | -------------------------------------------------------------------------------- 
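The v1b/v1s builders above share one pattern: construct the architecture, then overlay pretrained weights, either fetched from torch.hub and filtered through _safe_state_dict_filtering, or read directly from a local checkpoint (the resnet18 builder currently leaves its pretrained branch as a stub). A minimal sketch of exercising the resulting stride-8 backbone through the ResNetBackbone wrapper from resnet.py; this is an illustration, assuming the repository root is on PYTHONPATH, and nothing is downloaded since pretrained_base=False:

import torch

from isegm.model.modeling.resnet import ResNetBackbone

# Dilated ResNet-50: layer3/layer4 trade stride for dilation (dilated=True).
backbone = ResNetBackbone(backbone='resnet50', pretrained_base=False, dilated=True)
x = torch.randn(1, 3, 256, 256)
c1, c2, c3, c4 = backbone(x)
print(c1.shape)  # torch.Size([1, 256, 64, 64])   stride 4 after the deep stem and maxpool
print(c4.shape)  # torch.Size([1, 2048, 32, 32])  stride 8 rather than 32, thanks to dilation

In the full models, load_pretrained_weights() in the DeepLab wrappers performs the same overlay with pretrained_base=True and the weight_dir configured in config.yml.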
/isegm/model/modifiers.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | class LRMult(object): 4 | def __init__(self, lr_mult=1.): 5 | self.lr_mult = lr_mult 6 | 7 | def __call__(self, m): 8 | if getattr(m, 'weight', None) is not None: 9 | m.weight.lr_mult = self.lr_mult 10 | if getattr(m, 'bias', None) is not None: 11 | m.bias.lr_mult = self.lr_mult 12 | -------------------------------------------------------------------------------- /isegm/model/ops.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn as nn 3 | import numpy as np 4 | import isegm.model.initializer as initializer 5 | 6 | 7 | def select_activation_function(activation): 8 | if isinstance(activation, str): 9 | if activation.lower() == 'relu': 10 | return nn.ReLU 11 | elif activation.lower() == 'softplus': 12 | return nn.Softplus 13 | else: 14 | raise ValueError(f"Unknown activation type {activation}") 15 | elif isinstance(activation, nn.Module): 16 | return activation 17 | else: 18 | raise ValueError(f"Unknown activation type {activation}") 19 | 20 | 21 | class BilinearConvTranspose2d(nn.ConvTranspose2d): 22 | def __init__(self, in_channels, out_channels, scale, groups=1): 23 | kernel_size = 2 * scale - scale % 2 24 | self.scale = scale 25 | 26 | super().__init__( 27 | in_channels, out_channels, 28 | kernel_size=kernel_size, 29 | stride=scale, 30 | padding=1, 31 | groups=groups, 32 | bias=False) 33 | 34 | self.apply(initializer.Bilinear(scale=scale, in_channels=in_channels, groups=groups)) 35 | 36 | 37 | class DistMaps(nn.Module): 38 | def __init__(self, norm_radius, spatial_scale=1.0, cpu_mode=False, use_disks=False): 39 | super(DistMaps, self).__init__() 40 | self.spatial_scale = spatial_scale 41 | self.norm_radius = norm_radius 42 | self.cpu_mode = cpu_mode 43 | self.use_disks = use_disks 44 | if self.cpu_mode: 45 | from isegm.utils.cython import get_dist_maps 46 | self._get_dist_maps = get_dist_maps 47 | 48 | def get_coord_features(self, points, batchsize, rows, cols): 49 | if self.cpu_mode: 50 | coords = [] 51 | for i in range(batchsize): 52 | norm_delimeter = 1.0 if self.use_disks else self.spatial_scale * self.norm_radius 53 | coords.append(self._get_dist_maps(points[i].cpu().float().numpy(), rows, cols, 54 | norm_delimeter)) 55 | coords = torch.from_numpy(np.stack(coords, axis=0)).to(points.device).float() 56 | else: 57 | num_points = points.shape[1] // 2 58 | points = points.view(-1, points.size(2)) 59 | points, points_order = torch.split(points, [2, 1], dim=1) 60 | 61 | invalid_points = torch.max(points, dim=1, keepdim=False)[0] < 0 62 | row_array = torch.arange(start=0, end=rows, step=1, dtype=torch.float32, device=points.device) 63 | col_array = torch.arange(start=0, end=cols, step=1, dtype=torch.float32, device=points.device) 64 | 65 | coord_rows, coord_cols = torch.meshgrid(row_array, col_array) 66 | coords = torch.stack((coord_rows, coord_cols), dim=0).unsqueeze(0).repeat(points.size(0), 1, 1, 1) 67 | 68 | add_xy = (points * self.spatial_scale).view(points.size(0), points.size(1), 1, 1) 69 | coords.add_(-add_xy) 70 | if not self.use_disks: 71 | coords.div_(self.norm_radius * self.spatial_scale) 72 | coords.mul_(coords) 73 | 74 | coords[:, 0] += coords[:, 1] 75 | coords = coords[:, :1] 76 | 77 | coords[invalid_points, :, :, :] = 1e6 78 | 79 | coords = coords.view(-1, num_points, 1, rows, cols) 80 | coords = coords.min(dim=1)[0] # -> (bs * num_masks * 2) x 1 x h x w 81 | coords = 
coords.view(-1, 2, rows, cols) 82 | 83 | if self.use_disks: 84 | coords = (coords <= (self.norm_radius * self.spatial_scale) ** 2).float() 85 | else: 86 | coords.sqrt_().mul_(2).tanh_() 87 | 88 | return coords 89 | 90 | def forward(self, x, coords): 91 | return self.get_coord_features(coords, x.shape[0], x.shape[2], x.shape[3]) 92 | 93 | 94 | class ScaleLayer(nn.Module): 95 | def __init__(self, init_value=1.0, lr_mult=1): 96 | super().__init__() 97 | self.lr_mult = lr_mult 98 | self.scale = nn.Parameter( 99 | torch.full((1,), init_value / lr_mult, dtype=torch.float32) 100 | ) 101 | 102 | def forward(self, x): 103 | scale = torch.abs(self.scale * self.lr_mult) 104 | return x * scale 105 | 106 | 107 | class BatchImageNormalize: 108 | def __init__(self, mean, std, dtype=torch.float): 109 | self.mean = torch.as_tensor(mean, dtype=dtype)[None, :, None, None] 110 | self.std = torch.as_tensor(std, dtype=dtype)[None, :, None, None] 111 | 112 | def __call__(self, tensor): 113 | tensor = tensor.clone() 114 | 115 | tensor.sub_(self.mean.to(tensor.device)).div_(self.std.to(tensor.device)) 116 | return tensor 117 | -------------------------------------------------------------------------------- /isegm/utils/crop_local.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | from skimage.measure import label 4 | 5 | def map_point_in_bbox(y,x,y1,y2,x1,x2,crop_l): 6 | h,w = y2-y1, x2-x1 7 | ry,rx = crop_l/h, crop_l/w 8 | y = (y - y1) * ry 9 | x = (x - x1) * rx 10 | return y,x 11 | 12 | 13 | 14 | def get_focus_cropv1(pred_mask, previous_mask, global_roi, y,x, ratio): 15 | pred_mask = pred_mask > 0.49 16 | previous_mask = previous_mask > 0.49 17 | ymin,ymax,xmin,xmax = global_roi 18 | diff_regions = np.logical_xor(previous_mask, pred_mask) 19 | if previous_mask.sum() == 0: 20 | y1,y2,x1,x2 = get_bbox_from_mask(pred_mask) 21 | else: 22 | num, labels = cv2.connectedComponents( diff_regions.astype(np.uint8)) 23 | label = labels[y,x] 24 | diff_conn_mask = labels == label 25 | y1d,y2d,x1d,x2d = get_bbox_from_mask(diff_conn_mask) 26 | hd,wd = y2d - y1d, x2d - x1d 27 | 28 | y1p,y2p,x1p,x2p= get_bbox_from_mask(pred_mask) 29 | hp,wp = y2p - y1p, x2p - x1p 30 | 31 | if hd < hp/3 or wd < wp/3: 32 | r = 0.2 33 | l = max(hp,wp) 34 | y1,y2,x1,x2 = y - r *l, y + r * l, x - r * l, x + r * l 35 | else: 36 | y1,y2,x1,x2 = y1d,y2d,x1d,x2d 37 | 38 | y1,y2,x1,x2 = expand_bbox(pred_mask,y1,y2,x1,x2,ratio ) 39 | y1 = max(y1,ymin) 40 | y2 = min(y2,ymax) 41 | x1 = max(x1,xmin) 42 | x2 = min(x2,xmax) 43 | return y1,y2,x1,x2 44 | 45 | 46 | def get_focus_cropv2(pred_mask, previous_mask, global_roi, y,x, ratio): 47 | pred_mask = pred_mask > 0.5 48 | previous_mask = previous_mask > 0.5 49 | ymin,ymax,xmin,xmax = global_roi 50 | diff_regions = np.logical_xor(previous_mask, pred_mask) 51 | num, labels = cv2.connectedComponents( diff_regions.astype(np.uint8)) 52 | label = labels[y,x] 53 | diff_conn_mask = labels == label 54 | 55 | y1d,y2d,x1d,x2d = get_bbox_from_mask(diff_conn_mask) 56 | hd,wd = y2d - y1d, x2d - x1d 57 | 58 | y1p,y2p,x1p,x2p= get_bbox_from_mask(pred_mask) 59 | hp,wp = y2p - y1p, x2p - x1p 60 | 61 | if previous_mask.sum() == 0: 62 | y1,y2,x1,x2 = y1p,y2p,x1p,x2p 63 | else: 64 | if hd < hp/3 or wd < wp/3: 65 | r = 0.16 66 | l = max(hp,wp) 67 | y1,y2,x1,x2 = y - r *l, y + r * l, x - r * l, x + r * l 68 | else: 69 | if diff_conn_mask.sum() > diff_regions.sum() * 0.5: 70 | y1,y2,x1,x2 = y1d,y2d,x1d,x2d 71 | else: 72 | y1,y2,x1,x2 = y1p,y2p,x1p,x2p 73 | 
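    # Whichever candidate box was chosen above, grow it by `ratio` with
    # expand_bbox and then clamp it to the global ROI.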
y1,y2,x1,x2 = expand_bbox(pred_mask,y1,y2,x1,x2,ratio ) 74 | y1 = max(y1,ymin) 75 | y2 = min(y2,ymax) 76 | x1 = max(x1,xmin) 77 | x2 = min(x2,xmax) 78 | return y1,y2,x1,x2 79 | 80 | 81 | def get_object_crop(pred_mask, previous_mask, global_roi, y,x, ratio): 82 | pred_mask = pred_mask > 0.49 83 | y1,y2,x1,x2 = get_bbox_from_mask(pred_mask) 84 | y1,y2,x1,x2 = expand_bbox(pred_mask,y1,y2,x1,x2,ratio ) 85 | ymin,ymax,xmin,xmax = global_roi 86 | y1 = max(y1,ymin) 87 | y2 = min(y2,ymax) 88 | x1 = max(x1,xmin) 89 | x2 = min(x2,xmax) 90 | return y1,y2,x1,x2 91 | 92 | 93 | 94 | def get_click_crop(pred_mask, previous_mask, global_roi, y,x, ratio): 95 | pred_mask = pred_mask > 0.49 96 | y1p,y2p,x1p,x2p= get_bbox_from_mask(pred_mask) 97 | hp,wp = y2p - y1p, x2p - x1p 98 | r = 0.2 99 | l = max(hp,wp) 100 | y1,y2,x1,x2 = y - r *l, y + r * l, x - r * l, x + r * l 101 | y1,y2,x1,x2 = expand_bbox(pred_mask,y1,y2,x1,x2,ratio ) 102 | ymin,ymax,xmin,xmax = global_roi 103 | y1 = max(y1,ymin) 104 | y2 = min(y2,ymax) 105 | x1 = max(x1,xmin) 106 | x2 = min(x2,xmax) 107 | return y1,y2,x1,x2 108 | 109 | 110 | 111 | 112 | def getLargestCC(segmentation): 113 | if segmentation.sum()<10: 114 | return segmentation 115 | labels = label(segmentation) 116 | largestCC = labels == np.argmax(np.bincount(labels.flat)[1:])+1 117 | return largestCC 118 | 119 | 120 | 121 | def get_diff_region(pred_mask, previous_mask, y, x): 122 | y,x = int(y), int(x) 123 | diff_regions = np.logical_xor(previous_mask, pred_mask) 124 | if diff_regions.sum() > 1000: 125 | num, labels = cv2.connectedComponents( diff_regions.astype(np.uint8)) 126 | label = labels[y,x] 127 | corr_mask = labels == label 128 | else: 129 | corr_mask = pred_mask 130 | return corr_mask 131 | 132 | 133 | 134 | 135 | def get_bbox_from_mask(mask): 136 | h,w = mask.shape[0],mask.shape[1] 137 | 138 | if mask.sum() < 10: 139 | return 0,h,0,w 140 | rows = np.any(mask,axis=1) 141 | cols = np.any(mask,axis=0) 142 | y1,y2 = np.where(rows)[0][[0,-1]] 143 | x1,x2 = np.where(cols)[0][[0,-1]] 144 | return y1,y2,x1,x2 145 | 146 | def expand_bbox(mask,y1,y2,x1,x2,ratio, min_crop=0): 147 | H,W = mask.shape[0], mask.shape[1] 148 | xc, yc = 0.5 * (x1 + x2), 0.5 * (y1 + y2) 149 | h = ratio * (y2-y1+1) 150 | w = ratio * (x2-x1+1) 151 | h = max(h,min_crop) 152 | w = max(w,min_crop) 153 | 154 | x1 = int(xc - w * 0.5) 155 | x2 = int(xc + w * 0.5) 156 | y1 = int(yc - h * 0.5) 157 | y2 = int(yc + h * 0.5) 158 | 159 | x1 = max(0,x1) 160 | x2 = min(W,x2) 161 | y1 = max(0,y1) 162 | y2 = min(H,y2) 163 | return y1,y2,x1,x2 164 | 165 | 166 | def expand_bbox_with_bias(mask,y1,y2,x1,x2,ratio, min_crop=0, bias = 0.3): 167 | H,W = mask.shape[0], mask.shape[1] 168 | xc, yc = 0.5 * (x1 + x2), 0.5 * (y1 + y2) 169 | h = ratio * (y2-y1+1) 170 | w = ratio * (x2-x1+1) 171 | h = max(h,min_crop) 172 | w = max(w,min_crop) 173 | hmax, wmax = int(h * bias), int(w * bias) 174 | h_bias = np.random.randint(-hmax,hmax+1) 175 | w_bias = np.random.randint(-wmax,wmax+1) 176 | 177 | x1 = int(xc - w * 0.5) + w_bias 178 | x2 = int(xc + w * 0.5) + w_bias 179 | y1 = int(yc - h * 0.5) + h_bias 180 | y2 = int(yc + h * 0.5) + h_bias 181 | 182 | x1 = max(0,x1) 183 | x2 = min(W,x2) 184 | y1 = max(0,y1) 185 | y2 = min(H,y2) 186 | return y1,y2,x1,x2 187 | 188 | 189 | 190 | def CalBox(mask,last_y = None, last_x = None, expand = 1.5): 191 | y1,y2,x1,x2 = get_bbox_from_mask(mask) 192 | H,W = mask.shape[0], mask.shape[1] 193 | if last_y is not None: 194 | y1 = min(y1,last_y) 195 | y2 = max(y2,last_y) 196 | x1 = min(x1, last_x) 197 | x2 = 
max(x2,last_x) 198 | 199 |     xc, yc = 0.5 * (x1 + x2), 0.5 * (y1 + y2) 200 |     h = expand * (y2-y1+1) 201 |     w = expand * (x2-x1+1) 202 |     x1 = int(xc - w * 0.5) 203 |     x2 = int(xc + w * 0.5) 204 |     y1 = int(yc - h * 0.5) 205 |     y2 = int(yc + h * 0.5) 206 | 207 |     x1 = max(0,x1) 208 |     x2 = min(W,x2) 209 |     y1 = max(0,y1) 210 |     y2 = min(H,y2) 211 |     return y1,y2,x1,x2 212 | 213 | def points_back(p_np, y1, x1): 214 |     if p_np is None: 215 |         return None 216 |     bias = np.array( [[y1,x1]]).reshape((1,2)) 217 |     return p_np + bias 218 | 219 | 220 | 221 | 222 | def PointsInBox(points,y1,y2,x1,x2, H, W ): 223 |     if points is None: 224 |         return None 225 | 226 |     y_ratio = H/(y2-y1) 227 |     x_ratio = W/(x2-x1) 228 |     num_pos = points.shape[0] // 2 229 |     new_points = np.full_like(points,-1) 230 | 231 |     valid_pos = 0 232 |     for i in range(num_pos): 233 |         y,x,index = points[i,0], points[i,1],points[i,2] 234 |         if y>y1 and y< y2 and x>x1 and x<x2: 235 |             new_points[valid_pos,0] = (y - y1) * y_ratio 236 |             new_points[valid_pos,1] = (x - x1) * x_ratio 237 |             new_points[valid_pos,2] = index 238 |             valid_pos += 1 239 | 240 |     valid_neg = num_pos 241 |     for i in range(num_pos, 2 * num_pos): 242 |         y,x,index = points[i,0], points[i,1],points[i,2] 243 |         if y>y1 and y< y2 and x>x1 and x<x2: 244 |             new_points[valid_neg,0] = (y - y1) * y_ratio 245 |             new_points[valid_neg,1] = (x - x1) * x_ratio 246 |             new_points[valid_neg,2] = index 247 |             valid_neg += 1 248 | 249 |     return new_points 250 | 251 | 252 | 253 | 254 | 255 | 256 | 257 | 258 | 259 | 260 | 261 | def random_choose_unknown(unknown, crop_size): 262 |     h, w = unknown.shape[:2] 263 |     crop_h, crop_w = crop_size 264 | 265 | 266 |     if crop_h > h or crop_w > w: 267 |         return 0,0,h,w 268 | 269 | 270 |     delta_h = center_h = crop_h // 2 271 |     delta_w = center_w = crop_w // 2 272 | 273 |     # mask out the valid area for selecting the cropping center 274 |     mask = np.zeros_like(unknown) 275 |     mask[delta_h:h - delta_h, delta_w:w - delta_w] = 1 276 |     if np.any(unknown & mask): 277 |         center_h_list, center_w_list = np.where(unknown & mask) 278 |     elif np.any(unknown): 279 |         center_h_list, center_w_list = np.where(unknown) 280 |     else: 281 |         #print_log('No unknown pixels found!', level=logging.WARNING) 282 |         center_h_list = [center_h] 283 |         center_w_list = [center_w] 284 |     num_unknowns = len(center_h_list) 285 |     rand_ind = np.random.randint(num_unknowns) 286 |     center_h = center_h_list[rand_ind] 287 |     center_w = center_w_list[rand_ind] 288 | 289 |     # make sure the top-left point is valid 290 |     top = np.clip(center_h - delta_h, 0, h - crop_h) 291 |     left = np.clip(center_w - delta_w, 0, w - crop_w) 292 |     y1,x1,y2,x2 = top, left, top + crop_h, left + crop_w 293 | 294 |     return y1,x1,y2,x2 -------------------------------------------------------------------------------- /isegm/utils/cython/__init__.py: -------------------------------------------------------------------------------- 1 | # noinspection PyUnresolvedReferences 2 | from .dist_maps import get_dist_maps -------------------------------------------------------------------------------- /isegm/utils/cython/_get_dist_maps.pyx: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | cimport cython 3 | cimport numpy as np 4 | from libc.stdlib cimport malloc, free 5 | 6 | ctypedef struct qnode: 7 |     int row 8 |     int col 9 |     int layer 10 |     int orig_row 11 |     int orig_col 12 | 13 | @cython.infer_types(True) 14 | @cython.boundscheck(False) 15 | @cython.wraparound(False) 16 | @cython.nonecheck(False) 17 | def get_dist_maps(np.ndarray[np.float32_t, ndim=2, mode="c"] points, 18 |                   int height, int width, float norm_delimeter): 19 |     cdef np.ndarray[np.float32_t, ndim=3, mode="c"] dist_maps = \ 20 |         np.full((2, height, width), 1e6, dtype=np.float32, order="C") 21 | 22 |     cdef int *dxy = [-1, 0, 0, -1, 0, 1, 1, 0] 23 |     cdef int i, j, x, y, dx, dy 24 |     cdef qnode v 25 |     cdef qnode *q = malloc((4 * height * width + 1) * sizeof(qnode)) 26 |     cdef int qhead = 0, qtail = -1 27 |     cdef float ndist 28 | 29 |     for i in range(points.shape[0]): 30 |         x, y = round(points[i, 0]), round(points[i, 1]) 31 |         if x >= 0: 32 |             qtail += 1 33 |             q[qtail].row = x 34 |             q[qtail].col = y 35 |             q[qtail].orig_row = x 36 |             q[qtail].orig_col = y 37 |             if i >= 
points.shape[0] / 2: 38 | q[qtail].layer = 1 39 | else: 40 | q[qtail].layer = 0 41 | dist_maps[q[qtail].layer, x, y] = 0 42 | 43 | while qtail - qhead + 1 > 0: 44 | v = q[qhead] 45 | qhead += 1 46 | 47 | for k in range(4): 48 | x = v.row + dxy[2 * k] 49 | y = v.col + dxy[2 * k + 1] 50 | 51 | ndist = ((x - v.orig_row)/norm_delimeter) ** 2 + ((y - v.orig_col)/norm_delimeter) ** 2 52 | if (x >= 0 and y >= 0 and x < height and y < width and 53 | dist_maps[v.layer, x, y] > ndist): 54 | qtail += 1 55 | q[qtail].orig_col = v.orig_col 56 | q[qtail].orig_row = v.orig_row 57 | q[qtail].layer = v.layer 58 | q[qtail].row = x 59 | q[qtail].col = y 60 | dist_maps[v.layer, x, y] = ndist 61 | 62 | free(q) 63 | return dist_maps 64 | -------------------------------------------------------------------------------- /isegm/utils/cython/_get_dist_maps.pyxbld: -------------------------------------------------------------------------------- 1 | import numpy 2 | 3 | def make_ext(modname, pyxfilename): 4 | from distutils.extension import Extension 5 | return Extension(modname, [pyxfilename], 6 | include_dirs=[numpy.get_include()], 7 | extra_compile_args=['-O3'], language='c++') 8 | -------------------------------------------------------------------------------- /isegm/utils/cython/dist_maps.py: -------------------------------------------------------------------------------- 1 | import pyximport; pyximport.install(pyximport=True, language_level=3) 2 | # noinspection PyUnresolvedReferences 3 | from ._get_dist_maps import get_dist_maps -------------------------------------------------------------------------------- /isegm/utils/distributed.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import distributed as dist 3 | from torch.utils import data 4 | 5 | 6 | def get_rank(): 7 | if not dist.is_available() or not dist.is_initialized(): 8 | return 0 9 | return dist.get_rank() 10 | 11 | 12 | def synchronize(): 13 | if not dist.is_available() or not dist.is_initialized() or dist.get_world_size() == 1: 14 | return 15 | dist.barrier() 16 | 17 | 18 | def get_world_size(): 19 | if not dist.is_available() or not dist.is_initialized(): 20 | return 1 21 | 22 | return dist.get_world_size() 23 | 24 | 25 | def reduce_loss_dict(loss_dict): 26 | world_size = get_world_size() 27 | 28 | if world_size < 2: 29 | return loss_dict 30 | 31 | with torch.no_grad(): 32 | keys = [] 33 | losses = [] 34 | 35 | for k in loss_dict.keys(): 36 | keys.append(k) 37 | losses.append(loss_dict[k]) 38 | 39 | losses = torch.stack(losses, 0) 40 | dist.reduce(losses, dst=0) 41 | 42 | if dist.get_rank() == 0: 43 | losses /= world_size 44 | 45 | reduced_losses = {k: v for k, v in zip(keys, losses)} 46 | 47 | return reduced_losses 48 | 49 | 50 | def get_sampler(dataset, shuffle, distributed): 51 | if distributed: 52 | return data.distributed.DistributedSampler(dataset, shuffle=shuffle) 53 | 54 | if shuffle: 55 | return data.RandomSampler(dataset) 56 | else: 57 | return data.SequentialSampler(dataset) 58 | 59 | 60 | def get_dp_wrapper(distributed): 61 | class DPWrapper(torch.nn.parallel.DistributedDataParallel if distributed else torch.nn.DataParallel): 62 | def __getattr__(self, name): 63 | try: 64 | return super().__getattr__(name) 65 | except AttributeError: 66 | return getattr(self.module, name) 67 | return DPWrapper 68 | -------------------------------------------------------------------------------- /isegm/utils/exp.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import shutil 4 | import pprint 5 | from pathlib import Path 6 | from datetime import datetime 7 | 8 | import yaml 9 | import torch 10 | from easydict import EasyDict as edict 11 | 12 | from .log import logger, add_logging 13 | from .distributed import synchronize, get_world_size 14 | 15 | 16 | def init_experiment(args, model_name): 17 | model_path = Path(args.model_path) 18 | ftree = get_model_family_tree(model_path, model_name=model_name) 19 | 20 | if ftree is None: 21 | print('Models can only be located in the "models" directory in the root of the repository') 22 | sys.exit(1) 23 | 24 | cfg = load_config(model_path) 25 | update_config(cfg, args) 26 | 27 | cfg.distributed = args.distributed 28 | cfg.local_rank = args.local_rank 29 | if cfg.distributed: 30 | torch.distributed.init_process_group(backend='nccl', init_method='env://') 31 | if args.workers > 0: 32 | torch.multiprocessing.set_start_method('forkserver', force=True) 33 | 34 | experiments_path = Path(cfg.EXPS_PATH) 35 | exp_parent_path = experiments_path / '/'.join(ftree) 36 | exp_parent_path.mkdir(parents=True, exist_ok=True) 37 | 38 | if cfg.resume_exp: 39 | exp_path = find_resume_exp(exp_parent_path, cfg.resume_exp) 40 | else: 41 | last_exp_indx = find_last_exp_indx(exp_parent_path) 42 | exp_name = f'{last_exp_indx:03d}' 43 | if cfg.exp_name: 44 | exp_name += '_' + cfg.exp_name 45 | exp_path = exp_parent_path / exp_name 46 | synchronize() 47 | if cfg.local_rank == 0: 48 | exp_path.mkdir(parents=True) 49 | 50 | cfg.EXP_PATH = exp_path 51 | cfg.CHECKPOINTS_PATH = exp_path / 'checkpoints' 52 | cfg.VIS_PATH = exp_path / 'vis' 53 | cfg.LOGS_PATH = exp_path / 'logs' 54 | 55 | if cfg.local_rank == 0: 56 | cfg.LOGS_PATH.mkdir(exist_ok=True) 57 | cfg.CHECKPOINTS_PATH.mkdir(exist_ok=True) 58 | cfg.VIS_PATH.mkdir(exist_ok=True) 59 | 60 | dst_script_path = exp_path / (model_path.stem + datetime.strftime(datetime.today(), '_%Y-%m-%d-%H-%M-%S.py')) 61 | if args.temp_model_path: 62 | shutil.copy(args.temp_model_path, dst_script_path) 63 | os.remove(args.temp_model_path) 64 | else: 65 | shutil.copy(model_path, dst_script_path) 66 | 67 | synchronize() 68 | 69 | if cfg.gpus != '': 70 | gpu_ids = [int(id) for id in cfg.gpus.split(',')] 71 | else: 72 | gpu_ids = list(range(max(cfg.ngpus, get_world_size()))) 73 | cfg.gpus = ','.join([str(id) for id in gpu_ids]) 74 | 75 | cfg.gpu_ids = gpu_ids 76 | cfg.ngpus = len(gpu_ids) 77 | cfg.multi_gpu = cfg.ngpus > 1 78 | 79 | if cfg.distributed: 80 | cfg.device = torch.device('cuda') 81 | cfg.gpu_ids = [cfg.gpu_ids[cfg.local_rank]] 82 | torch.cuda.set_device(cfg.gpu_ids[0]) 83 | else: 84 | if cfg.multi_gpu: 85 | os.environ['CUDA_VISIBLE_DEVICES'] = cfg.gpus 86 | ngpus = torch.cuda.device_count() 87 | # Added by Xavier 88 | # cfg.gpu_ids = [i for i in range(ngpus)] 89 | # assert ngpus == cfg.ngpus 90 | cfg.device = torch.device(f'cuda:{cfg.gpu_ids[0]}') 91 | 92 | if cfg.local_rank == 0: 93 | add_logging(cfg.LOGS_PATH, prefix='train_') 94 | logger.info(f'Number of GPUs: {cfg.ngpus}') 95 | if cfg.distributed: 96 | logger.info(f'Multi-Process Multi-GPU Distributed Training') 97 | 98 | logger.info('Run experiment with config:') 99 | logger.info(pprint.pformat(cfg, indent=4)) 100 | 101 | return cfg 102 | 103 | 104 | def get_model_family_tree(model_path, terminate_name='models', model_name=None): 105 | if model_name is None: 106 | model_name = model_path.stem 107 | family_tree = [model_name] 108 | for 
x in model_path.parents: 109 | if x.stem == terminate_name: 110 | break 111 | family_tree.append(x.stem) 112 | else: 113 | return None 114 | 115 | return family_tree[::-1] 116 | 117 | 118 | def find_last_exp_indx(exp_parent_path): 119 | indx = 0 120 | for x in exp_parent_path.iterdir(): 121 | if not x.is_dir(): 122 | continue 123 | 124 | exp_name = x.stem 125 | if exp_name[:3].isnumeric(): 126 | indx = max(indx, int(exp_name[:3]) + 1) 127 | 128 | return indx 129 | 130 | 131 | def find_resume_exp(exp_parent_path, exp_pattern): 132 | candidates = sorted(exp_parent_path.glob(f'{exp_pattern}*')) 133 | if len(candidates) == 0: 134 | print(f'No experiments could be found that satisfies the pattern = "*{exp_pattern}"') 135 | sys.exit(1) 136 | elif len(candidates) > 1: 137 | print('More than one experiment found:') 138 | for x in candidates: 139 | print(x) 140 | sys.exit(1) 141 | else: 142 | exp_path = candidates[0] 143 | print(f'Continue with experiment "{exp_path}"') 144 | 145 | return exp_path 146 | 147 | 148 | def update_config(cfg, args): 149 | for param_name, value in vars(args).items(): 150 | if param_name.lower() in cfg or param_name.upper() in cfg: 151 | continue 152 | cfg[param_name] = value 153 | 154 | 155 | def load_config(model_path): 156 | model_name = model_path.stem 157 | config_path = model_path.parent / (model_name + '.yml') 158 | 159 | if config_path.exists(): 160 | cfg = load_config_file(config_path) 161 | else: 162 | cfg = dict() 163 | 164 | cwd = Path.cwd() 165 | config_parent = config_path.parent.absolute() 166 | while len(config_parent.parents) > 0: 167 | config_path = config_parent / 'config.yml' 168 | 169 | if config_path.exists(): 170 | local_config = load_config_file(config_path, model_name=model_name) 171 | cfg.update({k: v for k, v in local_config.items() if k not in cfg}) 172 | 173 | if config_parent.absolute() == cwd: 174 | break 175 | config_parent = config_parent.parent 176 | 177 | return edict(cfg) 178 | 179 | 180 | def load_config_file(config_path, model_name=None, return_edict=False): 181 | with open(config_path, 'r') as f: 182 | cfg = yaml.safe_load(f) 183 | 184 | if 'SUBCONFIGS' in cfg: 185 | if model_name is not None and model_name in cfg['SUBCONFIGS']: 186 | cfg.update(cfg['SUBCONFIGS'][model_name]) 187 | del cfg['SUBCONFIGS'] 188 | 189 | return edict(cfg) if return_edict else cfg 190 | -------------------------------------------------------------------------------- /isegm/utils/exp_imports/default.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from functools import partial 3 | from easydict import EasyDict as edict 4 | from albumentations import * 5 | 6 | from isegm.data.datasets import * 7 | from isegm.model.losses import * 8 | from isegm.data.transforms import * 9 | from isegm.model.metrics import AdaptiveIoU 10 | from isegm.data.points_sampler import MultiPointSampler 11 | from isegm.utils.log import logger 12 | from isegm.model import initializer 13 | 14 | -------------------------------------------------------------------------------- /isegm/utils/log.py: -------------------------------------------------------------------------------- 1 | import io 2 | import time 3 | import logging 4 | from datetime import datetime 5 | 6 | import numpy as np 7 | from torch.utils.tensorboard import SummaryWriter 8 | 9 | LOGGER_NAME = 'root' 10 | LOGGER_DATEFMT = '%Y-%m-%d %H:%M:%S' 11 | 12 | handler = logging.StreamHandler() 13 | 14 | logger = logging.getLogger(LOGGER_NAME) 15 | logger.setLevel(logging.INFO) 
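# A single module-level logger is shared across the project; add_logging()
# below attaches an additional timestamped FileHandler writing into an
# experiment's logs directory.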
16 | logger.addHandler(handler) 17 | 18 | 19 | def add_logging(logs_path, prefix): 20 | log_name = prefix + datetime.strftime(datetime.today(), '%Y-%m-%d_%H-%M-%S') + '.log' 21 | stdout_log_path = logs_path / log_name 22 | 23 | fh = logging.FileHandler(str(stdout_log_path)) 24 | formatter = logging.Formatter(fmt='(%(levelname)s) %(asctime)s: %(message)s', 25 | datefmt=LOGGER_DATEFMT) 26 | fh.setFormatter(formatter) 27 | logger.addHandler(fh) 28 | 29 | 30 | class TqdmToLogger(io.StringIO): 31 | logger = None 32 | level = None 33 | buf = '' 34 | 35 | def __init__(self, logger, level=None, mininterval=5): 36 | super(TqdmToLogger, self).__init__() 37 | self.logger = logger 38 | self.level = level or logging.INFO 39 | self.mininterval = mininterval 40 | self.last_time = 0 41 | 42 | def write(self, buf): 43 | self.buf = buf.strip('\r\n\t ') 44 | 45 | def flush(self): 46 | if len(self.buf) > 0 and time.time() - self.last_time > self.mininterval: 47 | self.logger.log(self.level, self.buf) 48 | self.last_time = time.time() 49 | 50 | 51 | class SummaryWriterAvg(SummaryWriter): 52 | def __init__(self, *args, dump_period=20, **kwargs): 53 | super().__init__(*args, **kwargs) 54 | self._dump_period = dump_period 55 | self._avg_scalars = dict() 56 | 57 | def add_scalar(self, tag, value, global_step=None, disable_avg=False): 58 | if disable_avg or isinstance(value, (tuple, list, dict)): 59 | super().add_scalar(tag, np.array(value), global_step=global_step) 60 | else: 61 | if tag not in self._avg_scalars: 62 | self._avg_scalars[tag] = ScalarAccumulator(self._dump_period) 63 | avg_scalar = self._avg_scalars[tag] 64 | avg_scalar.add(value) 65 | 66 | if avg_scalar.is_full(): 67 | super().add_scalar(tag, avg_scalar.value, 68 | global_step=global_step) 69 | avg_scalar.reset() 70 | 71 | 72 | class ScalarAccumulator(object): 73 | def __init__(self, period): 74 | self.sum = 0 75 | self.cnt = 0 76 | self.period = period 77 | 78 | def add(self, value): 79 | self.sum += value 80 | self.cnt += 1 81 | 82 | @property 83 | def value(self): 84 | if self.cnt > 0: 85 | return self.sum / self.cnt 86 | else: 87 | return 0 88 | 89 | def reset(self): 90 | self.cnt = 0 91 | self.sum = 0 92 | 93 | def is_full(self): 94 | return self.cnt >= self.period 95 | 96 | def __len__(self): 97 | return self.cnt 98 | -------------------------------------------------------------------------------- /isegm/utils/misc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | from .log import logger 5 | 6 | 7 | def get_dims_with_exclusion(dim, exclude=None): 8 | dims = list(range(dim)) 9 | if exclude is not None: 10 | dims.remove(exclude) 11 | 12 | return dims 13 | 14 | 15 | def save_checkpoint(net, checkpoints_path, epoch=None, prefix='', verbose=True, multi_gpu=False): 16 | if epoch is None: 17 | checkpoint_name = 'last_checkpoint.pth' 18 | else: 19 | checkpoint_name = f'{epoch:03d}.pth' 20 | 21 | if prefix: 22 | checkpoint_name = f'{prefix}_{checkpoint_name}' 23 | 24 | if not checkpoints_path.exists(): 25 | checkpoints_path.mkdir(parents=True) 26 | 27 | checkpoint_path = checkpoints_path / checkpoint_name 28 | if verbose: 29 | logger.info(f'Save checkpoint to {str(checkpoint_path)}') 30 | 31 | net = net.module if multi_gpu else net 32 | torch.save({'state_dict': net.state_dict(), 33 | 'config': net._config}, str(checkpoint_path)) 34 | 35 | 36 | def get_bbox_from_mask(mask): 37 | rows = np.any(mask, axis=1) 38 | cols = np.any(mask, axis=0) 39 | rmin, rmax = 
np.where(rows)[0][[0, -1]] 40 | cmin, cmax = np.where(cols)[0][[0, -1]] 41 | 42 | return rmin, rmax, cmin, cmax 43 | 44 | 45 | def expand_bbox(bbox, expand_ratio, min_crop_size=None): 46 | rmin, rmax, cmin, cmax = bbox 47 | rcenter = 0.5 * (rmin + rmax) 48 | ccenter = 0.5 * (cmin + cmax) 49 | height = expand_ratio * (rmax - rmin + 1) 50 | width = expand_ratio * (cmax - cmin + 1) 51 | if min_crop_size is not None: 52 | height = max(height, min_crop_size) 53 | width = max(width, min_crop_size) 54 | 55 | rmin = int(round(rcenter - 0.5 * height)) 56 | rmax = int(round(rcenter + 0.5 * height)) 57 | cmin = int(round(ccenter - 0.5 * width)) 58 | cmax = int(round(ccenter + 0.5 * width)) 59 | 60 | return rmin, rmax, cmin, cmax 61 | 62 | 63 | def clamp_bbox(bbox, rmin, rmax, cmin, cmax): 64 | return (max(rmin, bbox[0]), min(rmax, bbox[1]), 65 | max(cmin, bbox[2]), min(cmax, bbox[3])) 66 | 67 | 68 | def get_bbox_iou(b1, b2): 69 | h_iou = get_segments_iou(b1[:2], b2[:2]) 70 | w_iou = get_segments_iou(b1[2:4], b2[2:4]) 71 | return h_iou * w_iou 72 | 73 | 74 | def get_segments_iou(s1, s2): 75 | a, b = s1 76 | c, d = s2 77 | intersection = max(0, min(b, d) - max(a, c) + 1) 78 | union = max(1e-6, max(b, d) - min(a, c) + 1) 79 | return intersection / union 80 | 81 | 82 | def get_labels_with_sizes(x): 83 | obj_sizes = np.bincount(x.flatten()) 84 | labels = np.nonzero(obj_sizes)[0].tolist() 85 | labels = [x for x in labels if x != 0] 86 | return labels, obj_sizes[labels].tolist() 87 | -------------------------------------------------------------------------------- /isegm/utils/serialization.py: -------------------------------------------------------------------------------- 1 | from functools import wraps 2 | from copy import deepcopy 3 | import inspect 4 | import torch.nn as nn 5 | 6 | 7 | def serialize(init): 8 | parameters = list(inspect.signature(init).parameters) 9 | 10 | @wraps(init) 11 | def new_init(self, *args, **kwargs): 12 | params = deepcopy(kwargs) 13 | for pname, value in zip(parameters[1:], args): 14 | params[pname] = value 15 | 16 | config = { 17 | 'class': get_classname(self.__class__), 18 | 'params': dict() 19 | } 20 | specified_params = set(params.keys()) 21 | 22 | for pname, param in get_default_params(self.__class__).items(): 23 | if pname not in params: 24 | params[pname] = param.default 25 | 26 | for name, value in list(params.items()): 27 | param_type = 'builtin' 28 | if inspect.isclass(value): 29 | param_type = 'class' 30 | value = get_classname(value) 31 | 32 | config['params'][name] = { 33 | 'type': param_type, 34 | 'value': value, 35 | 'specified': name in specified_params 36 | } 37 | 38 | setattr(self, '_config', config) 39 | init(self, *args, **kwargs) 40 | 41 | return new_init 42 | 43 | 44 | def load_model(config, **kwargs): 45 | model_class = get_class_from_str(config['class']) 46 | model_default_params = get_default_params(model_class) 47 | 48 | model_args = dict() 49 | for pname, param in config['params'].items(): 50 | value = param['value'] 51 | if param['type'] == 'class': 52 | value = get_class_from_str(value) 53 | 54 | if pname not in model_default_params and not param['specified']: 55 | continue 56 | 57 | assert pname in model_default_params 58 | if not param['specified'] and model_default_params[pname].default == value: 59 | continue 60 | model_args[pname] = value 61 | 62 | model_args.update(kwargs) 63 | 64 | return model_class(**model_args) 65 | 66 | 67 | def get_config_repr(config): 68 | config_str = f'Model: {config["class"]}\n' 69 | for pname, param in 
config['params'].items(): 70 | value = param["value"] 71 | if param['type'] == 'class': 72 | value = value.split('.')[-1] 73 | param_str = f'{pname:<22} = {str(value):<12}' 74 | if not param['specified']: 75 | param_str += ' (default)' 76 | config_str += param_str + '\n' 77 | return config_str 78 | 79 | 80 | def get_default_params(some_class): 81 | params = dict() 82 | for mclass in some_class.mro(): 83 | if mclass is nn.Module or mclass is object: 84 | continue 85 | 86 | mclass_params = inspect.signature(mclass.__init__).parameters 87 | for pname, param in mclass_params.items(): 88 | if param.default != param.empty and pname not in params: 89 | params[pname] = param 90 | 91 | return params 92 | 93 | 94 | def get_classname(cls): 95 | module = cls.__module__ 96 | name = cls.__qualname__ 97 | if module is not None and module != "__builtin__": 98 | name = module + "." + name 99 | return name 100 | 101 | 102 | def get_class_from_str(class_str): 103 | components = class_str.split('.') 104 | mod = __import__('.'.join(components[:-1])) 105 | for comp in components[1:]: 106 | mod = getattr(mod, comp) 107 | return mod 108 | -------------------------------------------------------------------------------- /isegm/utils/vis.py: -------------------------------------------------------------------------------- 1 | from functools import lru_cache 2 | import cv2 3 | import numpy as np 4 | 5 | 6 | def visualize_instances(imask, bg_color=255, 7 | boundaries_color=None, boundaries_width=1, boundaries_alpha=0.8): 8 | num_objects = imask.max() + 1 9 | palette = get_palette(num_objects) 10 | if bg_color is not None: 11 | palette[0] = bg_color 12 | 13 | result = palette[imask].astype(np.uint8) 14 | if boundaries_color is not None: 15 | boundaries_mask = get_boundaries(imask, boundaries_width=boundaries_width) 16 | tresult = result.astype(np.float32) 17 | tresult[boundaries_mask] = boundaries_color 18 | tresult = tresult * boundaries_alpha + (1 - boundaries_alpha) * result 19 | result = tresult.astype(np.uint8) 20 | 21 | return result 22 | 23 | 24 | @lru_cache(maxsize=16) 25 | def get_palette(num_cls): 26 | palette = np.zeros(3 * num_cls, dtype=np.int32) 27 | 28 | for j in range(0, num_cls): 29 | lab = j 30 | i = 0 31 | 32 | while lab > 0: 33 | palette[j*3 + 0] |= (((lab >> 0) & 1) << (7-i)) 34 | palette[j*3 + 1] |= (((lab >> 1) & 1) << (7-i)) 35 | palette[j*3 + 2] |= (((lab >> 2) & 1) << (7-i)) 36 | i = i + 1 37 | lab >>= 3 38 | 39 | return palette.reshape((-1, 3)) 40 | 41 | 42 | def visualize_mask(mask, num_cls): 43 | palette = get_palette(num_cls) 44 | mask[mask == -1] = 0 45 | 46 | return palette[mask].astype(np.uint8) 47 | 48 | 49 | def visualize_proposals(proposals_info, point_color=(255, 0, 0), point_radius=1): 50 | proposal_map, colors, candidates = proposals_info 51 | 52 | proposal_map = draw_probmap(proposal_map) 53 | for x, y in candidates: 54 | proposal_map = cv2.circle(proposal_map, (y, x), point_radius, point_color, -1) 55 | 56 | return proposal_map 57 | 58 | 59 | def draw_probmap(x): 60 | return cv2.applyColorMap((x * 255).astype(np.uint8), cv2.COLORMAP_HOT) 61 | 62 | 63 | def draw_points(image, points, color, radius=3): 64 | image = image.copy() 65 | for p in points: 66 | if p[0] < 0: 67 | continue 68 | if len(p) == 3: 69 | pradius = {0: 8, 1: 6, 2: 4}[p[2]] if p[2] < 3 else 2 70 | else: 71 | pradius = radius 72 | image = cv2.circle(image, (int(p[1]), int(p[0])), pradius, color, -1) 73 | 74 | return image 75 | 76 | 77 | def draw_instance_map(x, palette=None): 78 | num_colors = x.max() + 1 79 | if 
palette is None: 80 | palette = get_palette(num_colors) 81 | 82 | return palette[x].astype(np.uint8) 83 | 84 | 85 | def blend_mask(image, mask, alpha=0.6): 86 | if mask.min() == -1: 87 | mask = mask.copy() + 1 88 | 89 | imap = draw_instance_map(mask) 90 | result = (image * (1 - alpha) + alpha * imap).astype(np.uint8) 91 | return result 92 | 93 | 94 | def get_boundaries(instances_masks, boundaries_width=1): 95 | boundaries = np.zeros((instances_masks.shape[0], instances_masks.shape[1]), dtype=bool) 96 | 97 | for obj_id in np.unique(instances_masks.flatten()): 98 | if obj_id == 0: 99 | continue 100 | 101 | obj_mask = instances_masks == obj_id 102 | kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3)) 103 | inner_mask = cv2.erode(obj_mask.astype(np.uint8), kernel, iterations=boundaries_width).astype(bool) 104 | 105 | obj_boundary = np.logical_xor(obj_mask, np.logical_and(inner_mask, obj_mask)) 106 | boundaries = np.logical_or(boundaries, obj_boundary) 107 | return boundaries 108 | 109 | 110 | def draw_with_blend_and_clicks(img, mask=None, alpha=0.6, clicks_list=None, pos_color=(0, 255, 0), 111 | neg_color=(255, 0, 0), radius=4): 112 | result = img.copy() 113 | 114 | if mask is not None: 115 | palette = get_palette(np.max(mask) + 1) 116 | rgb_mask = palette[mask.astype(np.uint8)] 117 | 118 | mask_region = (mask > 0).astype(np.uint8) 119 | result = result * (1 - mask_region[:, :, np.newaxis]) + \ 120 | (1 - alpha) * mask_region[:, :, np.newaxis] * result + \ 121 | alpha * rgb_mask 122 | result = result.astype(np.uint8) 123 | 124 | # result = (result * (1 - alpha) + alpha * rgb_mask).astype(np.uint8) 125 | 126 | if clicks_list is not None and len(clicks_list) > 0: 127 | pos_points = [click.coords for click in clicks_list if click.is_positive] 128 | neg_points = [click.coords for click in clicks_list if not click.is_positive] 129 | 130 | result = draw_points(result, pos_points, pos_color, radius=radius) 131 | result = draw_points(result, neg_points, neg_color, radius=radius) 132 | 133 | return result 134 | 135 | 136 | def add_tag(image, tag='undefined', tag_h=40): 137 | image = image.astype(np.uint8) 138 | H, W = image.shape[0], image.shape[1] 139 | tag_blank = np.ones((tag_h, W, 3)).astype(np.uint8) * 255 140 | cv2.putText(tag_blank, tag, (10, 30), cv2.FONT_HERSHEY_COMPLEX, 0.5, (0, 0, 0), 1) 141 | image = cv2.vconcat([image, tag_blank]) 142 | return image 143 | 144 | 145 | -------------------------------------------------------------------------------- /models/gp_sbd_resnet50.py: -------------------------------------------------------------------------------- 1 | from isegm.utils.exp_imports.default import * 2 | MODEL_NAME = 'resnet50' 3 | # from isegm.data.compose import ComposeDataset,ProportionalComposeDataset 4 | import torch.nn as nn 5 | from isegm.data.aligned_augmentation import AlignedAugmentator 6 | from isegm.engine.gp_trainer import ISTrainer 7 | import importlib 8 | 9 | def main(cfg): 10 | model, model_cfg = init_model(cfg) 11 | train(model, cfg, model_cfg) 12 | 13 | 14 | def init_model(cfg): 15 | model_cfg = edict() 16 | model_cfg.crop_size = (cfg.crop_size, cfg.crop_size) 17 | model_cfg.num_max_points = cfg.num_max_points 18 | GpModel = importlib.import_module('isegm.model.' + cfg.gp_model).GpModel 19 | model = GpModel(backbone='resnet50', use_leaky_relu=True, use_disks=(not cfg.nodisk), binary_prev_mask=False, 20 | with_prev_mask=(not cfg.noprev_mask), weight_dir=cfg.IMAGENET_PRETRAINED_MODELS.RESNET50_v1s) 21 | model.to(cfg.device) 22 | 
model.apply(initializer.XavierGluon(rnd_type='gaussian', magnitude=2.0)) 23 | model.model.load_pretrained_weights() 24 | return model, model_cfg 25 | 26 | 27 | def train(model, cfg, model_cfg): 28 | cfg.batch_size = 32 if cfg.batch_size < 1 else cfg.batch_size 29 | cfg.val_batch_size = cfg.batch_size 30 | crop_size = model_cfg.crop_size 31 | 32 | loss_cfg = edict() 33 | loss_cfg.instance_loss = NormalizedFocalLossSigmoid(alpha=0.5, gamma=2) 34 | loss_cfg.instance_loss_weight = 1.0 35 | 36 | train_augmentator = AlignedAugmentator(ratio=[0.3, 1.3], target_size=crop_size, flip=True, distribution='Gaussian', gs_center=0.8) 37 | 38 | val_augmentator = Compose([ 39 | UniformRandomResize(scale_range=(0.75, 1.25)), 40 | PadIfNeeded(min_height=crop_size[0], min_width=crop_size[1], border_mode=0), 41 | RandomCrop(*crop_size) 42 | ], p=1.0) 43 | 44 | points_sampler = MultiPointSampler(model_cfg.num_max_points, prob_gamma=0.70, 45 | merge_objects_prob=0.15, 46 | max_num_merged_objects=2, 47 | use_hierarchy=False, 48 | first_click_center=True) 49 | 50 | trainset = SBDDataset( 51 | cfg.SBD_PATH, 52 | split='train', 53 | augmentator=train_augmentator, 54 | min_object_area=80, 55 | keep_background_prob=0.01, 56 | points_sampler=points_sampler, 57 | samples_scores_gamma=1.25 58 | ) 59 | 60 | valset = SBDDataset( 61 | cfg.SBD_PATH, 62 | split='val', 63 | augmentator=val_augmentator, 64 | min_object_area=80, 65 | points_sampler=points_sampler, 66 | epoch_len=500 67 | ) 68 | 69 | optimizer_params = { 70 | 'lr': cfg.lr, 'betas': (0.9, 0.999), 'eps': 1e-8 71 | } 72 | 73 | lr_scheduler = partial(torch.optim.lr_scheduler.MultiStepLR, 74 | milestones=cfg.milestones[:-1], gamma=0.1) 75 | trainer = ISTrainer(model, cfg, model_cfg, loss_cfg, 76 | trainset, valset, 77 | optimizer='adam', 78 | optimizer_params=optimizer_params, 79 | lr_scheduler=lr_scheduler, 80 | checkpoint_interval=[(0, 50), (200, 10)], 81 | image_dump_interval=cfg.image_dump_interval, 82 | metrics=[AdaptiveIoU()], 83 | max_interactive_points=model_cfg.num_max_points, 84 | max_num_next_clicks=cfg.max_num_next_clicks) 85 | trainer.run(num_epochs=cfg.milestones[-1]) 86 | -------------------------------------------------------------------------------- /net.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zmhhmz/GPCIS_CVPR2023/6460415a2e784f5623a0c859971f884a89eb0fd0/net.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib==3.3.4 2 | easydict==1.9 3 | opencv_contrib_python==4.2.0.32 4 | torchvision==0.9.0a0+01dfa8e 5 | mmcv_full==1.2.7 6 | numpy==1.17.0 7 | torch==1.8.0 8 | albumentations==0.5.1 9 | termcolor==1.1.0 10 | attrs==21.2.0 11 | timm==0.3.2 12 | scikit_image==0.17.2 13 | scipy==1.5.4 14 | Cython==0.29.23 15 | tqdm==4.61.0 16 | attr==0.3.1 17 | ipython==7.30.1 18 | mmcv==1.4.1 19 | Pillow==8.4.0 20 | PyYAML==6.0 21 | skimage==0.0 22 | thop==0.0.31-2005241907 23 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | # training 2 | python3 train.py models/gp_sbd_resnet50.py \ 3 | --gpus=0,1 \ 4 | --workers=12 \ 5 | --batch-size=32 \ 6 | --milestones 190 220 230 \ 7 | --max_num_next_clicks=3 \ 8 | --num_max_points=24 \ 9 | --crop_size=256 \ 10 | --gp_model=is_gp_resnet50 \ 11 | --exp-name=GP_Resnet50_SBD_230epo 12 | 
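A note on the training command above: `gp_sbd_resnet50.py` uses all but the last `--milestones` value as the `MultiStepLR` decay epochs, and the last value as the total epoch count (`trainer.run(num_epochs=cfg.milestones[-1])`). A minimal sketch of the resulting schedule, assuming the default `lr=5e-3` from `train.py` and using a hypothetical stand-in model:

```python
# Sketch of the schedule implied by "--milestones 190 220 230" with lr=5e-3:
# the learning rate decays by gamma=0.1 at epochs 190 and 220, and training
# runs for milestones[-1] = 230 epochs in total.
import torch

model = torch.nn.Linear(4, 1)  # hypothetical stand-in for the GPCIS model
optimizer = torch.optim.Adam(model.parameters(), lr=5e-3,
                             betas=(0.9, 0.999), eps=1e-8)

milestones = [190, 220, 230]
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer, milestones=milestones[:-1], gamma=0.1)

for epoch in range(milestones[-1]):
    # ... one training epoch would run here ...
    scheduler.step()  # lr: 5e-3 -> 5e-4 at epoch 190 -> 5e-5 at epoch 220
```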
13 | # Evaluation 14 | # python3 scripts/evaluate_model.py Baseline \ 15 | #     --model_dir=checkpoints/ \ 16 | #     --checkpoint=GPCIS_Resnet50.pth \ 17 | #     --datasets=GrabCut,Berkeley,SBD,DAVIS \ 18 | #     --gpus=0 \ 19 | #     --n-clicks=20 \ 20 | #     --target-iou=0.90 \ 21 | #     --thresh=0.50 -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import importlib.util 4 | 5 | import torch 6 | from isegm.utils.exp import init_experiment 7 | 8 | 9 | def main(): 10 | args = parse_args() 11 | if args.temp_model_path: 12 | model_script = load_module(args.temp_model_path) 13 | else: 14 | model_script = load_module(args.model_path) 15 | 16 | model_base_name = getattr(model_script, 'MODEL_NAME', None) 17 | 18 | args.distributed = 'WORLD_SIZE' in os.environ 19 | cfg = init_experiment(args, model_base_name) 20 | 21 | torch.backends.cudnn.benchmark = True 22 | torch.multiprocessing.set_sharing_strategy('file_system') 23 | 24 | model_script.main(cfg) 25 | 26 | 27 | def parse_args(): 28 | parser = argparse.ArgumentParser() 29 | 30 | parser.add_argument('model_path', type=str, 31 | help='Path to the model script.') 32 | 33 | parser.add_argument('--exp-name', type=str, default='', 34 | help='Here you can specify the name of the experiment. ' 35 | 'It will be added as a suffix to the experiment folder.') 36 | 37 | parser.add_argument('--workers', type=int, default=4, 38 | metavar='N', help='Dataloader threads.') 39 | 40 | parser.add_argument('--batch-size', type=int, default=-1, 41 | help='You can override the model batch size by specifying a positive number.') 42 | 43 | parser.add_argument('--ngpus', type=int, default=1, 44 | help='Number of GPUs. ' 45 | 'If you only specify the "--gpus" argument, the ngpus value will be calculated automatically. ' 46 | 'You should use either this argument or "--gpus".') 47 | 48 | parser.add_argument('--gpus', type=str, default='', required=False, 49 | help='IDs of the GPUs to use. You should use either this argument or "--ngpus".') 50 | 51 | parser.add_argument('--resume-exp', type=str, default=None, 52 | help='The prefix of the name of the experiment to be continued. ' 53 | 'If you use this field, you must specify the "--resume-prefix" argument.') 54 | 55 | parser.add_argument('--resume-prefix', type=str, default='latest', 56 | help='The prefix of the name of the checkpoint to be loaded.') 57 | 58 | parser.add_argument('--start-epoch', type=int, default=0, 59 | help='The number of the starting epoch from which training will continue. 
' 60 | '(it is important for correct logging and learning rate)') 61 | 62 | parser.add_argument('--weights', type=str, default=None, 63 | help='Model weights will be loaded from the specified path if you use this argument.') 64 | 65 | parser.add_argument('--temp-model-path', type=str, default='', 66 | help='Do not use this argument (for internal purposes).') 67 | 68 | parser.add_argument("--local_rank", type=int, default=0) 69 | 70 | parser.add_argument("--lr", type=float, default=5e-3) 71 | 72 | parser.add_argument("--max_num_next_clicks", type=int, default=3) 73 | 74 | parser.add_argument("--num_max_points", type=int, default=24) 75 | 76 | parser.add_argument('--gp_model', type=str, default='') 77 | 78 | parser.add_argument('--noprev_mask', action='store_true') 79 | 80 | parser.add_argument('--nodisk', action='store_true') 81 | 82 | parser.add_argument('--binary_prev_mask', action='store_true') 83 | 84 | parser.add_argument("--image_dump_interval", type=int, default=500) 85 | 86 | parser.add_argument("--crop_size", type=int, default=256) 87 | 88 | parser.add_argument('--milestones', type=int, nargs='+', default=[190,210,230]) 89 | return parser.parse_args() 90 | 91 | 92 | def load_module(script_path): 93 | spec = importlib.util.spec_from_file_location("model_script", script_path) 94 | model_script = importlib.util.module_from_spec(spec) 95 | spec.loader.exec_module(model_script) 96 | 97 | return model_script 98 | 99 | 100 | if __name__ == '__main__': 101 | main() 102 | --------------------------------------------------------------------------------
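Finally, `misc.py` and `serialization.py` together make checkpoints self-describing: the `@serialize` decorator records a model's constructor arguments into `model._config`, `save_checkpoint` stores that config alongside the weights, and `load_model` rebuilds the network from the checkpoint alone. A minimal sketch of the round trip, using a made-up `ToyModel` class rather than the repo's actual GpModel:

```python
# Made-up ToyModel illustrating the @serialize / load_model round trip.
import torch
import torch.nn as nn
from isegm.utils.serialization import serialize, load_model


class ToyModel(nn.Module):
    @serialize  # records the class path and constructor args into self._config
    def __init__(self, width=16, use_bias=True):
        super().__init__()
        self.fc = nn.Linear(width, 1, bias=use_bias)


model = ToyModel(width=32)

# save_checkpoint() in misc.py stores exactly this pair:
torch.save({'state_dict': model.state_dict(), 'config': model._config}, 'toy.pth')

# Later, the checkpoint alone suffices to rebuild the same architecture:
ckpt = torch.load('toy.pth')
restored = load_model(ckpt['config'])          # equivalent to ToyModel(width=32)
restored.load_state_dict(ckpt['state_dict'])
```

Note how `load_model` skips unspecified arguments that still match their defaults (`use_bias` here), so checkpoints stay loadable even if a default value is later changed deliberately.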