├── .gitattributes
├── GPCIS_supp.pdf
├── LICENSE
├── README.md
├── checkpoints
│   └── GPCIS_Resnet50.pth
├── config.yml
├── isegm
│   ├── data
│   │   ├── aligned_augmentation.py
│   │   ├── base.py
│   │   ├── compose.py
│   │   ├── datasets
│   │   │   ├── __init__.py
│   │   │   ├── berkeley.py
│   │   │   ├── davis.py
│   │   │   ├── grabcut.py
│   │   │   └── sbd.py
│   │   ├── points_sampler.py
│   │   ├── sample.py
│   │   └── transforms.py
│   ├── engine
│   │   ├── gp_trainer.py
│   │   └── optimizer.py
│   ├── inference
│   │   ├── clicker.py
│   │   ├── evaluation.py
│   │   ├── predictors
│   │   │   ├── __init__.py
│   │   │   └── baseline.py
│   │   ├── transforms
│   │   │   ├── __init__.py
│   │   │   ├── base.py
│   │   │   ├── crops.py
│   │   │   ├── flip.py
│   │   │   ├── limit_longest_side.py
│   │   │   ├── resize.py
│   │   │   └── zoom_in.py
│   │   └── utils.py
│   ├── model
│   │   ├── initializer.py
│   │   ├── is_gp_model.py
│   │   ├── is_gp_resnet50.py
│   │   ├── losses.py
│   │   ├── metrics.py
│   │   ├── modeling
│   │   │   ├── basic_blocks.py
│   │   │   ├── deeplab_v3.py
│   │   │   ├── deeplab_v3_gp.py
│   │   │   ├── resnet.py
│   │   │   └── resnetv1b.py
│   │   ├── modifiers.py
│   │   └── ops.py
│   └── utils
│       ├── crop_local.py
│       ├── cython
│       │   ├── __init__.py
│       │   ├── _get_dist_maps.pyx
│       │   ├── _get_dist_maps.pyxbld
│       │   └── dist_maps.py
│       ├── distributed.py
│       ├── exp.py
│       ├── exp_imports
│       │   └── default.py
│       ├── log.py
│       ├── misc.py
│       ├── serialization.py
│       └── vis.py
├── models
│   └── gp_sbd_resnet50.py
├── net.png
├── requirements.txt
├── run.sh
├── scripts
│   └── evaluate_model.py
└── train.py
/.gitattributes:
--------------------------------------------------------------------------------
1 | # Auto detect text files and perform LF normalization
2 | * text=auto
3 | *.pth filter=lfs diff=lfs merge=lfs -text
4 |
--------------------------------------------------------------------------------
/GPCIS_supp.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zmhhmz/GPCIS_CVPR2023/6460415a2e784f5623a0c859971f884a89eb0fd0/GPCIS_supp.pdf
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 MinghaoZhou
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Interactive Segmentation as Gaussian Process Classification (CVPR2023 Highlight)
2 | Minghao Zhou, [Hong Wang](https://hongwang01.github.io/), Qian Zhao, Yuexiang Li, Yawen Huang, [Deyu Meng](http://gr.xjtu.edu.cn/web/dymeng), [Yefeng Zheng](https://sites.google.com/site/yefengzheng/)
3 |
4 |
5 | [[Paper]](https://openaccess.thecvf.com/content/CVPR2023/papers/Zhou_Interactive_Segmentation_As_Gaussion_Process_Classification_CVPR_2023_paper.pdf) [[Poster]](https://cvpr2023.thecvf.com/media/PosterPDFs/CVPR%202023/23088.png?t=1684895990.406102) [[Video]](https://youtu.be/mapyH-WujhY) [[Slides]](https://cvpr2023.thecvf.com/media/cvpr-2023/Slides/23088.pdf) [[Supp]](GPCIS_supp.pdf)
6 |
7 | ## Update 2023/8/1
8 | We have updated [is_gp_model.py](isegm/model/is_gp_model.py) for a more memory-efficient implementation.
9 |
10 | ## Usage
11 | Please first set up the environment and prepare the training (SBD)/testing (GrabCut, Berkeley, SBD, DAVIS) datasets following [RITM](https://github.com/saic-vul/ritm_interactive_segmentation), and change the directories in [config.yml](config.yml).
12 |
13 | Please run [run.sh](run.sh) for training/evaluation. For training, the ResNet-50 weights pretrained on ImageNet are used: please download the [weights](https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet50_v1s-1762acc0.pth) and change the corresponding directory in [config.yml](config.yml). For evaluation, you can directly test with our provided checkpoint in [checkpoints/GPCIS_Resnet50.pth](checkpoints/GPCIS_Resnet50.pth).
14 |
15 | The core code of the GPCIS model can be found in [isegm/model/is_gp_model.py](isegm/model/is_gp_model.py) and [isegm/model/is_gp_resnet50.py](isegm/model/is_gp_resnet50.py).
16 |
17 | ## Overview of GPCIS
18 | ![Overview of GPCIS](net.png)
19 |
--------------------------------------------------------------------------------
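The commands below sketch how training and evaluation are typically launched in RITM-derived codebases such as this one. The exact invocations live in [run.sh](run.sh); the flags shown here (`--gpus`, `--workers`, `--exp-name`, the `NoBRS` mode, `--datasets`) are assumptions borrowed from RITM's CLI, not confirmed arguments.

```bash
# Hypothetical RITM-style invocations; consult run.sh for the real ones.

# Training on SBD with the ResNet-50 backbone:
python train.py models/gp_sbd_resnet50.py --gpus=0 --workers=4 --exp-name=gpcis_r50

# Evaluation with the released checkpoint:
python scripts/evaluate_model.py NoBRS \
    --checkpoint=checkpoints/GPCIS_Resnet50.pth \
    --datasets=GrabCut,Berkeley,SBD,DAVIS
```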
/checkpoints/GPCIS_Resnet50.pth:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:1b59a4c8b7f1ea300abe8add185f8ab98beeeac01bd80bb2a2d029200195c3c9
3 | size 157934784
4 |
--------------------------------------------------------------------------------
/config.yml:
--------------------------------------------------------------------------------
1 | INTERACTIVE_MODELS_PATH: " "
2 | EXPS_PATH: "./experiments"
3 |
4 | # Datasets
5 | GRABCUT_PATH: "path_to/GrabCut"
6 | BERKELEY_PATH: "path_to/Berkeley"
7 | DAVIS_PATH: "path_to/DAVIS"
8 | COCO_MVAL_PATH: "path_to/COCO_MVal"
9 | PASCALVOC_PATH: "path_to/VOC2012"
10 | DAVIS585_PATH: "path_to/Selected_480P"
11 | SBD_PATH: "path_to/SBD/dataset"
12 |
13 | # Pretrained weights
14 | IMAGENET_PRETRAINED_MODELS:
15 | RESNET50_v1s: "path_to/gluon_resnet50_v1s-1762acc0.pth"
16 |
17 |
18 |
--------------------------------------------------------------------------------
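For orientation, here is a minimal standalone sketch of reading these values with PyYAML; the repo itself resolves them through `isegm/utils/exp.py`, so this loader is illustrative only.

```python
# Minimal standalone sketch; the repo resolves config.yml via isegm/utils/exp.py.
import yaml

with open('config.yml') as f:
    cfg = yaml.safe_load(f)

print(cfg['SBD_PATH'])                                    # training/testing data root
print(cfg['IMAGENET_PRETRAINED_MODELS']['RESNET50_v1s'])  # pretrained backbone weights
```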
/isegm/data/aligned_augmentation.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | from scipy.stats import truncnorm
4 |
5 | def get_truncated_normal(mean=0, sd=1, low=0, upp=10):
6 | return truncnorm(
7 | (low - mean) / sd, (upp - mean) / sd, loc=mean, scale=sd)
8 |
9 | #X1 = get_truncated_normal(mean=0.7, sd=0.3, low=0.2, upp=1)
10 | #x1 = X1.rvs(1)[0]
11 |
12 |
13 | class AlignedAugmentator:
14 | def __init__(self, ratio = [0.3,1], target_size = (256,256), flip = True,
15 | distribution = 'Uniform', gs_center = 0.8, gs_sd = 0.4,
16 | color_augmentator = None):
17 | '''
18 | distribution belongs to ['Uniform', 'Gaussian']
19 | '''
20 | self.ratio = ratio
21 | self.target_size = target_size
22 | self.flip = flip
23 | self.distribution = distribution
24 | self.gaussian = get_truncated_normal(mean=gs_center, sd=gs_sd, low=ratio[0], upp=ratio[1])
25 | self.color_augmentator = color_augmentator
26 |
27 | def __call__(self, image, mask):
28 | '''
29 | image: np.array (267, 400, 3) np.uint8
30 | mask: np.array (267, 400, 1) np.int32
31 | '''
32 |
33 | if self.distribution == 'Uniform':
34 | hr,wr = np.random.uniform(*self.ratio),np.random.uniform(*self.ratio)
35 | elif self.distribution == 'Gaussian':
36 | hr,wr = self.gaussian.rvs(2)
37 |
38 | H,W = image.shape[0], image.shape[1]
39 | h,w = int(H*hr), int(W*wr)
40 | if hr > 1 or wr > 1:
41 | image, mask = self.pad_image_mask(image, mask, hr, wr)
42 | H,W = image.shape[0], image.shape[1]
43 |
44 | y1 = np.random.randint(0,H-h)
45 | x1 = np.random.randint(0,W-w)
46 | y2 = y1 + h
47 | x2 = x1 + w
48 |
49 | image_crop = image[y1:y2,x1:x2,:]
50 | image_crop = cv2.resize(image_crop, tuple(self.target_size))
51 | mask_crop = mask[y1:y2,x1:x2,:].astype(np.uint8)
52 | mask_crop = (cv2.resize(mask_crop, tuple(self.target_size))).astype(np.int32)
53 | if len(mask_crop.shape) == 2:
54 | mask_crop = np.expand_dims(mask_crop,-1)
55 |
56 | if self.flip:
57 | if np.random.rand() < 0.3:
58 | image_crop = np.flip(image_crop,0)
59 | mask_crop = np.flip(mask_crop,0)
60 | if np.random.rand() < 0.3:
61 | image_crop = np.flip(image_crop,1)
62 | mask_crop = np.flip(mask_crop,1)
63 |
64 | image_crop = np.ascontiguousarray(image_crop)
65 | mask_crop = np.ascontiguousarray(mask_crop)
66 |
67 | if self.color_augmentator is not None:
68 | image_crop = self.color_augmentator(image=image_crop)['image']
69 |
70 | aug_output = {}
71 | aug_output['image'] = image_crop
72 | aug_output['mask'] = mask_crop
73 | return aug_output
74 |
75 | def pad_image_mask(self, image, mask, hr, wr):
76 | H,W = image.shape[0], image.shape[1]
77 | if hr > 1:
78 | new_h = int(H * hr) + 1
79 | pad_h = new_h - H
80 | pad_h1 = np.random.randint(0,pad_h)
81 | pad_h2 = pad_h - pad_h1
82 | image = np.pad(image, ((pad_h1, pad_h2),(0,0),(0,0)), 'constant')
83 | mask = np.pad(mask, ((pad_h1, pad_h2),(0,0),(0,0)), 'constant')
84 |
85 | if wr > 1:
86 | new_w = int(W * wr) + 1
87 | pad_w = new_w - W
88 | pad_w1 = np.random.randint(0,pad_w)
89 | pad_w2 = pad_w - pad_w1
90 | image = np.pad(image, ((0,0), (pad_w1, pad_w2),(0,0)), 'constant')
91 | mask = np.pad(mask, ( (0,0), (pad_w1, pad_w2),(0,0)), 'constant')
92 | return image, mask
93 |
94 |
95 |
96 |
97 |
98 |
99 |
100 |
101 |
--------------------------------------------------------------------------------
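A usage sketch of `AlignedAugmentator` on dummy inputs shaped as in the `__call__` docstring; the constructor values mirror the defaults above and are illustrative.

```python
# Usage sketch with dummy inputs shaped as in the __call__ docstring.
import numpy as np
from isegm.data.aligned_augmentation import AlignedAugmentator

augmentator = AlignedAugmentator(ratio=[0.3, 1.0], target_size=(256, 256),
                                 flip=True, distribution='Gaussian')
image = np.random.randint(0, 255, (267, 400, 3), dtype=np.uint8)
mask = (np.random.rand(267, 400, 1) > 0.5).astype(np.int32)

out = augmentator(image, mask)
print(out['image'].shape, out['mask'].shape)  # (256, 256, 3) (256, 256, 1)
```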
/isegm/data/base.py:
--------------------------------------------------------------------------------
1 | import random
2 | import pickle
3 | import numpy as np
4 | import torch
5 | from torchvision import transforms
6 | from .points_sampler import MultiPointSampler
7 | from .sample import DSample
8 | import cv2
9 | from isegm.utils.crop_local import random_choose_target,get_bbox_from_mask,getLargestCC,expand_bbox, expand_bbox_with_bias
10 | from skimage import morphology
11 | # import pdb
12 |
13 | class ISDataset(torch.utils.data.dataset.Dataset):
14 | def __init__(self,
15 | augmentator=None,
16 | points_sampler=MultiPointSampler(max_num_points=12),
17 | min_object_area=0,
18 | keep_background_prob=0.0,
19 | with_image_info=False,
20 | samples_scores_path=None,
21 | samples_scores_gamma=1.0,
22 | epoch_len=-1,
23 | with_refiner = False):
24 | super(ISDataset, self).__init__()
25 | self.epoch_len = epoch_len
26 | self.augmentator = augmentator
27 | self.min_object_area = min_object_area
28 | self.keep_background_prob = keep_background_prob
29 | self.points_sampler = points_sampler
30 | self.with_image_info = with_image_info
31 | self.samples_precomputed_scores = self._load_samples_scores(samples_scores_path, samples_scores_gamma)
32 | self.to_tensor = transforms.ToTensor()
33 | self.with_refiner = with_refiner
34 | self.dataset_samples = None
35 |
36 | def __getitem__(self, index):
37 | while True:
38 | try:
39 | if self.samples_precomputed_scores is not None:
40 | index = np.random.choice(self.samples_precomputed_scores['indices'],
41 | p=self.samples_precomputed_scores['probs'])
42 | else:
43 | if self.epoch_len > 0:
44 | index = random.randrange(0, len(self.dataset_samples))
45 |
46 | sample = self.get_sample(index)
47 | sample = self.augment_sample(sample)
48 | sample.remove_small_objects(self.min_object_area)
49 |
50 | self.points_sampler.sample_object(sample)
51 | points = np.array(self.points_sampler.sample_points())
52 | mask = self.points_sampler.selected_mask
53 | mask = self.remove_small_regions(mask)
54 | image = sample.image
55 | mask_area = mask[0].shape[0] * mask[0].shape[1]
56 |
57 | if self.with_refiner:
58 | trimap = self.get_trimap(mask[0])
59 | if mask[0].sum() < 3600: # 60 * 60
60 | y1,x1,y2,x2 = self.sampling_roi_full_object(mask[0])
61 | else:
62 | if np.random.rand() < 0.4:
63 | y1,x1,y2,x2 = self.sampling_roi_on_boundary(mask[0])
64 | else:
65 | y1,x1,y2,x2 = self.sampling_roi_full_object(mask[0])
66 |
67 | roi = torch.tensor([x1, y1, x2, y2])
68 | h,w = mask[0].shape[0], mask[0].shape[1]
69 | image_focus = image[y1:y2,x1:x2,:]
70 | image_focus = cv2.resize(image_focus, (h,w))
71 |
72 | mask_255 = (mask[0] * 255).astype(np.uint8)
73 | mask_focus = mask_255[y1:y2,x1:x2]
74 | mask_focus = cv2.resize(mask_focus, (h,w)) > 128
75 | mask_focus = np.expand_dims(mask_focus,0).astype(np.float32)
76 |
77 | trimap_255 = (trimap[0] * 255).astype(np.uint8)
78 | trimap_focus = trimap_255[y1:y2,x1:x2]
79 | trimap_focus = cv2.resize(trimap_focus, (h,w)) > 128
80 | trimap_focus = np.expand_dims(trimap_focus,0).astype(np.float32)
81 |
82 | hc,wc = y2-y1, x2-x1
83 | ry,rx = h/hc, w/wc
84 | bias = np.array([y1,x1,0])
85 | ratio = np.array([ry,rx,1])
86 | points_focus = (points - bias) * ratio
87 |
88 | if mask.sum() > self.min_object_area and mask.sum() < mask_area * 0.85:
89 |
90 | output = {
91 | 'images': self.to_tensor(image),
92 | 'points': points.astype(np.float32),
93 | 'instances': mask,
94 | 'trimap':trimap,
95 | 'images_focus':self.to_tensor(image_focus),
96 | 'instances_focus':mask_focus,
97 | 'trimap_focus': trimap_focus,
98 | 'points_focus': points_focus.astype(np.float32),
99 | 'rois':roi.float()
100 | }
101 |
102 | if self.with_image_info:
103 | output['image_info'] = sample.sample_id
104 | return output
105 | else:
106 | index = np.random.randint(len(self.dataset_samples)-1)
107 | else:
108 | if mask.sum() > self.min_object_area and mask.sum() < mask_area * 0.85:
109 | output = {
110 | 'images': self.to_tensor(image),
111 | 'points': points.astype(np.float32),
112 | 'instances': mask,
113 | }
114 |
115 | if self.with_image_info:
116 | output['image_info'] = sample.sample_id
117 | return output
118 | else:
119 | index = np.random.randint(len(self.dataset_samples)-1)
120 | except Exception:
121 | # Loading/augmenting this sample failed; retry with a random index.
122 | index = np.random.randint(len(self.dataset_samples)-1)
123 |
124 | # def __getitem__(self, index):
125 | # if self.samples_precomputed_scores is not None:
126 | # index = np.random.choice(self.samples_precomputed_scores['indices'],
127 | # p=self.samples_precomputed_scores['probs'])
128 | # else:
129 | # if self.epoch_len > 0:
130 | # index = random.randrange(0, len(self.dataset_samples))
131 |
132 | # sample = self.get_sample(index)
133 | # sample = self.augment_sample(sample)
134 | # sample.remove_small_objects(self.min_object_area)
135 |
136 | # self.points_sampler.sample_object(sample)
137 | # points = np.array(self.points_sampler.sample_points())
138 | # mask = self.points_sampler.selected_mask
139 | # output = {
140 | # 'images': self.to_tensor(sample.image),
141 | # 'points': points.astype(np.float32),
142 | # 'instances': mask
143 | # }
144 |
145 | # if self.with_image_info:
146 | # output['image_info'] = sample.sample_id
147 |
148 | # return output
149 |
150 | def remove_small_regions(self,mask):
151 | mask = mask[0] > 0.5
152 | mask = morphology.remove_small_objects(mask,min_size= 900)
153 | mask = np.expand_dims(mask,0).astype(np.float32)
154 | return mask
155 |
156 |
157 | def sampling_roi_full_object(self, gt_mask, min_size=32):
158 | max_mask = getLargestCC(gt_mask)
159 | y1,y2,x1,x2 = get_bbox_from_mask(max_mask)
160 | ratio = np.random.randint(11,17)/10
161 | y1,y2,x1,x2 = expand_bbox_with_bias(gt_mask,y1,y2,x1,x2,ratio,min_size,0.3)
162 | return y1,x1,y2,x2
163 |
164 | def sampling_roi_on_boundary(self,gt_mask):
165 | h,w = gt_mask.shape[0], gt_mask.shape[1]
166 | rh = np.random.randint(15,40)/10
167 | rw = np.random.randint(15,40)/10
168 | new_h,new_w = h/rh, w/rw
169 | crop_size = (int(new_h), int(new_w))
170 |
171 | alpha = gt_mask > 0.5
172 | alpha = alpha.astype(np.uint8)
173 | kernel = np.ones((5,5),np.uint8)
174 | dilate = cv2.dilate(alpha,kernel,iterations = 1)
175 | boundary = np.logical_and( dilate, np.logical_not(alpha))
176 | y1,x1,y2,x2 = random_choose_target(boundary,crop_size)
177 | return y1,x1,y2,x2
178 |
179 |
180 | def get_trimap(self, mask):
181 | h,w = mask.shape[0],mask.shape[1]
182 | hs,ws = h//8,w//8
183 | mask_255_big = (mask * 255).astype(np.uint8)
184 | mask_255_small = (cv2.resize(mask_255_big, (ws,hs)) > 128) * 255
185 | mask_resized = cv2.resize(mask_255_small.astype(np.uint8),(w,h)) > 128
186 | diff_mask = np.logical_xor(mask, mask_resized).astype(np.uint8)
187 |
188 | kernel = np.ones((3, 3), dtype=np.uint8)
189 | diff_mask = cv2.dilate(diff_mask, kernel, iterations=2) # iterations: number of times the dilation is applied
190 |
191 | diff_mask = diff_mask.astype(np.float32)
192 | diff_mask = np.expand_dims(diff_mask,0)
193 | return diff_mask
194 |
195 |
196 | def augment_sample(self, sample) -> DSample:
197 | valid_augmentation = False
198 | while not valid_augmentation:
199 | sample.augment(self.augmentator)
200 | keep_sample = (self.keep_background_prob < 0.0 or
201 | random.random() < self.keep_background_prob)
202 | valid_augmentation = len(sample) > 0 or keep_sample
203 |
204 | return sample
205 |
206 | def get_sample(self, index) -> DSample:
207 | raise NotImplementedError
208 |
209 | def __len__(self):
210 | if self.epoch_len > 0:
211 | return self.epoch_len
212 | else:
213 | return self.get_samples_number()
214 |
215 | def get_samples_number(self):
216 | return len(self.dataset_samples)
217 |
218 | @staticmethod
219 | def _load_samples_scores(samples_scores_path, samples_scores_gamma):
220 | if samples_scores_path is None:
221 | return None
222 |
223 | with open(samples_scores_path, 'rb') as f:
224 | images_scores = pickle.load(f)
225 |
226 | probs = np.array([(1.0 - x[2]) ** samples_scores_gamma for x in images_scores])
227 | probs /= probs.sum()
228 | samples_scores = {
229 | 'indices': [x[0] for x in images_scores],
230 | 'probs': probs
231 | }
232 | print(f'Loaded {len(probs)} weights with gamma={samples_scores_gamma}')
233 | return samples_scores
234 |
--------------------------------------------------------------------------------
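`get_trimap` above marks the band where the mask and an 8x down/up-sampled copy of it disagree, then dilates that band; this is the uncertain boundary region handed to the refiner branch. A standalone sketch of the same steps on a synthetic mask:

```python
# Standalone re-implementation of the get_trimap steps, for illustration.
import cv2
import numpy as np

def boundary_band(mask):
    h, w = mask.shape
    small = (cv2.resize((mask * 255).astype(np.uint8), (w // 8, h // 8)) > 128) * 255
    resized = cv2.resize(small.astype(np.uint8), (w, h)) > 128
    band = np.logical_xor(mask > 0, resized).astype(np.uint8)
    band = cv2.dilate(band, np.ones((3, 3), np.uint8), iterations=2)
    return np.expand_dims(band.astype(np.float32), 0)  # (1, h, w), as in get_trimap

mask = np.zeros((256, 256), np.float32)
cv2.circle(mask, (128, 128), 60, 1.0, -1)
print(boundary_band(mask).sum())  # pixel count of the dilated boundary band
```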
/isegm/data/compose.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from math import isclose
3 | from .base import ISDataset
4 |
5 |
6 | class ComposeDataset(ISDataset):
7 | def __init__(self, datasets, **kwargs):
8 | super(ComposeDataset, self).__init__(**kwargs)
9 |
10 | self._datasets = datasets
11 | self.dataset_samples = []
12 | for dataset_indx, dataset in enumerate(self._datasets):
13 | self.dataset_samples.extend([(dataset_indx, i) for i in range(len(dataset))])
14 |
15 | def get_sample(self, index):
16 | dataset_indx, sample_indx = self.dataset_samples[index]
17 | return self._datasets[dataset_indx].get_sample(sample_indx)
18 |
19 |
20 | class ProportionalComposeDataset(ISDataset):
21 | def __init__(self, datasets, ratios, **kwargs):
22 | super().__init__(**kwargs)
23 |
24 | assert len(ratios) == len(datasets),\
25 | "The number of datasets must match the number of ratios"
26 | assert isclose(sum(ratios), 1.0),\
27 | "The sum of ratios must be equal to 1"
28 |
29 | self._ratios = ratios
30 | self._datasets = datasets
31 | self.dataset_samples = []
32 | for dataset_indx, dataset in enumerate(self._datasets):
33 | self.dataset_samples.extend([(dataset_indx, i) for i in range(len(dataset))])
34 |
35 | def get_sample(self, index):
36 | dataset_indx = np.random.choice(len(self._datasets), p=self._ratios)
37 | sample_indx = np.random.choice(len(self._datasets[dataset_indx]))
38 |
39 | return self._datasets[dataset_indx].get_sample(sample_indx)
40 |
--------------------------------------------------------------------------------
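A construction sketch for `ProportionalComposeDataset`; the paths are placeholders in the style of [config.yml](config.yml), and each `__getitem__` draw picks a dataset according to `ratios`.

```python
# Sketch: mix two datasets 70/30 per draw (paths are placeholders).
from isegm.data.compose import ProportionalComposeDataset
from isegm.data.datasets import BerkeleyDataset, GrabCutDataset

mixed = ProportionalComposeDataset(
    datasets=[GrabCutDataset('path_to/GrabCut'), BerkeleyDataset('path_to/Berkeley')],
    ratios=[0.7, 0.3],
    epoch_len=1000,  # number of random draws per epoch
)
```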
/isegm/data/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from isegm.data.compose import ComposeDataset, ProportionalComposeDataset
2 | from .berkeley import BerkeleyDataset
3 | from .davis import DavisDataset
4 | from .grabcut import GrabCutDataset
5 | from .sbd import SBDDataset, SBDEvaluationDataset
6 |
--------------------------------------------------------------------------------
/isegm/data/datasets/berkeley.py:
--------------------------------------------------------------------------------
1 | from .grabcut import GrabCutDataset
2 |
3 |
4 | class BerkeleyDataset(GrabCutDataset):
5 | def __init__(self, dataset_path, **kwargs):
6 | super().__init__(dataset_path, images_dir_name='images', masks_dir_name='masks', **kwargs)
7 | self.name = 'Berkeley'
8 |
--------------------------------------------------------------------------------
/isegm/data/datasets/davis.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import cv2
4 | import numpy as np
5 |
6 | from isegm.data.base import ISDataset
7 | from isegm.data.sample import DSample
8 |
9 |
10 | class DavisDataset(ISDataset):
11 | def __init__(self, dataset_path,
12 | images_dir_name='img', masks_dir_name='gt',
13 | init_mask_mode = None, **kwargs):
14 | super(DavisDataset, self).__init__(**kwargs)
15 | self.name = 'Davis'
16 | self.dataset_path = Path(dataset_path)
17 | self._images_path = self.dataset_path / images_dir_name
18 | self._insts_path = self.dataset_path / masks_dir_name
19 |
20 | self.dataset_samples = [x.name for x in sorted(self._images_path.glob('*.*'))]
21 | self._masks_paths = {x.stem: x for x in self._insts_path.glob('*.*')}
22 | self.init_mask_mode = init_mask_mode
23 |
24 | def get_sample(self, index) -> DSample:
25 | image_name = self.dataset_samples[index]
26 | image_path = str(self._images_path / image_name)
27 | mask_path = str(self._masks_paths[image_name.split('.')[0]])
28 |
29 | image = cv2.imread(image_path)
30 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
31 | instances_mask = np.max(cv2.imread(mask_path).astype(np.int32), axis=2)
32 | instances_mask[instances_mask > 0] = 1
33 |
34 | init_mask = None
35 |
36 | return DSample(image, instances_mask, objects_ids=[1], sample_id=index, init_mask=init_mask)
37 |
--------------------------------------------------------------------------------
/isegm/data/datasets/grabcut.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import cv2
4 | import numpy as np
5 |
6 | from isegm.data.base import ISDataset
7 | from isegm.data.sample import DSample
8 |
9 |
10 | class GrabCutDataset(ISDataset):
11 | def __init__(self, dataset_path,
12 | images_dir_name='data_GT', masks_dir_name='boundary_GT',
13 | **kwargs):
14 | super(GrabCutDataset, self).__init__(**kwargs)
15 | self.name = 'GrabCut'
16 | self.dataset_path = Path(dataset_path)
17 | self._images_path = self.dataset_path / images_dir_name
18 | self._insts_path = self.dataset_path / masks_dir_name
19 |
20 | self.dataset_samples = [x.name for x in sorted(self._images_path.glob('*.*'))]
21 | self._masks_paths = {x.stem: x for x in self._insts_path.glob('*.*')}
22 |
23 | def get_sample(self, index) -> DSample:
24 | image_name = self.dataset_samples[index]
25 | image_path = str(self._images_path / image_name)
26 | mask_path = str(self._masks_paths[image_name.split('.')[0]])
27 |
28 | image = cv2.imread(image_path)
29 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
30 | instances_mask = cv2.imread(mask_path)[:, :, 0].astype(np.int32)
31 | instances_mask[instances_mask == 128] = -1
32 | instances_mask[instances_mask > 128] = 1
33 |
34 | return DSample(image, instances_mask, objects_ids=[1], ignore_ids=[-1], sample_id=index)
35 |
--------------------------------------------------------------------------------
/isegm/data/datasets/sbd.py:
--------------------------------------------------------------------------------
1 | import pickle as pkl
2 | from pathlib import Path
3 |
4 | import cv2
5 | import numpy as np
6 | from scipy.io import loadmat
7 |
8 | from isegm.utils.misc import get_bbox_from_mask, get_labels_with_sizes
9 | from isegm.data.base import ISDataset
10 | from isegm.data.sample import DSample
11 |
12 |
13 | class SBDDataset(ISDataset):
14 | def __init__(self, dataset_path, split='train', buggy_mask_thresh=0.08, **kwargs):
15 | super(SBDDataset, self).__init__(**kwargs)
16 | assert split in {'train', 'val'}
17 | self.name = 'SBD'
18 | self.dataset_path = Path(dataset_path)
19 | self.dataset_split = split
20 | self._images_path = self.dataset_path / 'img'
21 | self._insts_path = self.dataset_path / 'inst'
22 | self._buggy_objects = dict()
23 | self._buggy_mask_thresh = buggy_mask_thresh
24 |
25 | with open(self.dataset_path / f'{split}.txt', 'r') as f:
26 | self.dataset_samples = [x.strip() for x in f.readlines()]
27 |
28 | def get_sample(self, index):
29 | image_name = self.dataset_samples[index]
30 | image_path = str(self._images_path / f'{image_name}.jpg')
31 | inst_info_path = str(self._insts_path / f'{image_name}.mat')
32 |
33 | image = cv2.imread(image_path)
34 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
35 | instances_mask = loadmat(str(inst_info_path))['GTinst'][0][0][0].astype(np.int32)
36 | instances_mask = self.remove_buggy_masks(index, instances_mask)
37 | instances_ids, _ = get_labels_with_sizes(instances_mask)
38 |
39 | return DSample(image, instances_mask, objects_ids=instances_ids, sample_id=index)
40 |
41 | def remove_buggy_masks(self, index, instances_mask):
42 | if self._buggy_mask_thresh > 0.0:
43 | buggy_image_objects = self._buggy_objects.get(index, None)
44 | if buggy_image_objects is None:
45 | buggy_image_objects = []
46 | instances_ids, _ = get_labels_with_sizes(instances_mask)
47 | for obj_id in instances_ids:
48 | obj_mask = instances_mask == obj_id
49 | mask_area = obj_mask.sum()
50 | bbox = get_bbox_from_mask(obj_mask)
51 | bbox_area = (bbox[1] - bbox[0] + 1) * (bbox[3] - bbox[2] + 1)
52 | obj_area_ratio = mask_area / bbox_area
53 | if obj_area_ratio < self._buggy_mask_thresh:
54 | buggy_image_objects.append(obj_id)
55 |
56 | self._buggy_objects[index] = buggy_image_objects
57 | for obj_id in buggy_image_objects:
58 | instances_mask[instances_mask == obj_id] = 0
59 |
60 | return instances_mask
61 |
62 |
63 | class SBDEvaluationDataset(ISDataset):
64 | def __init__(self, dataset_path, split='val', **kwargs):
65 | super(SBDEvaluationDataset, self).__init__(**kwargs)
66 | assert split in {'train', 'val'}
67 |
68 | self.dataset_path = Path(dataset_path)
69 | self.dataset_split = split
70 | self._images_path = self.dataset_path / 'img'
71 | self._insts_path = self.dataset_path / 'inst'
72 |
73 | with open(self.dataset_path / f'{split}.txt', 'r') as f:
74 | self.dataset_samples = [x.strip() for x in f.readlines()]
75 |
76 | self.dataset_samples = self.get_sbd_images_and_ids_list()
77 |
78 | def get_sample(self, index) -> DSample:
79 | image_name, instance_id = self.dataset_samples[index]
80 | image_path = str(self._images_path / f'{image_name}.jpg')
81 | inst_info_path = str(self._insts_path / f'{image_name}.mat')
82 |
83 | image = cv2.imread(image_path)
84 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
85 | instances_mask = loadmat(str(inst_info_path))['GTinst'][0][0][0].astype(np.int32)
86 | instances_mask[instances_mask != instance_id] = 0
87 | instances_mask[instances_mask > 0] = 1
88 |
89 | return DSample(image, instances_mask, objects_ids=[1], sample_id=index)
90 |
91 | def get_sbd_images_and_ids_list(self):
92 | pkl_path = self.dataset_path / f'{self.dataset_split}_images_and_ids_list.pkl'
93 |
94 | if pkl_path.exists():
95 | with open(str(pkl_path), 'rb') as fp:
96 | images_and_ids_list = pkl.load(fp)
97 | else:
98 | images_and_ids_list = []
99 |
100 | for sample in self.dataset_samples:
101 | inst_info_path = str(self._insts_path / f'{sample}.mat')
102 | instances_mask = loadmat(str(inst_info_path))['GTinst'][0][0][0].astype(np.int32)
103 | instances_ids, _ = get_labels_with_sizes(instances_mask)
104 |
105 | for instances_id in instances_ids:
106 | images_and_ids_list.append((sample, instances_id))
107 |
108 | with open(str(pkl_path), 'wb') as fp:
109 | pkl.dump(images_and_ids_list, fp)
110 |
111 | return images_and_ids_list
112 |
--------------------------------------------------------------------------------
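`remove_buggy_masks` drops any object whose pixel area covers less than `buggy_mask_thresh` (8% by default) of its bounding box, which catches SBD's degenerate annotations. A worked check of that ratio on a synthetic thin stroke:

```python
# Worked example of the buggy-mask ratio test.
import numpy as np
from isegm.utils.misc import get_bbox_from_mask

obj_mask = np.eye(100, dtype=bool)                 # thin diagonal: 100 px
rmin, rmax, cmin, cmax = get_bbox_from_mask(obj_mask)
bbox_area = (rmax - rmin + 1) * (cmax - cmin + 1)  # 100 * 100 = 10000
print(obj_mask.sum() / bbox_area)                  # 0.01 < 0.08 -> object removed
```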
/isegm/data/points_sampler.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import math
3 | import random
4 | import numpy as np
5 | from functools import lru_cache
6 | from .sample import DSample
7 |
8 |
9 | class BasePointSampler:
10 | def __init__(self):
11 | self._selected_mask = None
12 | self._selected_masks = None
13 |
14 | def sample_object(self, sample: DSample):
15 | raise NotImplementedError
16 |
17 | def sample_points(self):
18 | raise NotImplementedError
19 |
20 | @property
21 | def selected_mask(self):
22 | assert self._selected_mask is not None
23 | return self._selected_mask
24 |
25 | @selected_mask.setter
26 | def selected_mask(self, mask):
27 | self._selected_mask = mask[np.newaxis, :].astype(np.float32)
28 |
29 |
30 | class MultiPointSampler(BasePointSampler):
31 | def __init__(self, max_num_points, prob_gamma=0.7, expand_ratio=0.1,
32 | positive_erode_prob=0.9, positive_erode_iters=3,
33 | negative_bg_prob=0.1, negative_other_prob=0.4, negative_border_prob=0.5,
34 | merge_objects_prob=0.0, max_num_merged_objects=2,
35 | use_hierarchy=False, soft_targets=False,
36 | first_click_center=False, only_one_first_click=False,
37 | sfc_inner_k=1.7, sfc_full_inner_prob=0.0):
38 | super().__init__()
39 | self.max_num_points = max_num_points
40 | self.expand_ratio = expand_ratio
41 | self.positive_erode_prob = positive_erode_prob
42 | self.positive_erode_iters = positive_erode_iters
43 | self.merge_objects_prob = merge_objects_prob
44 | self.use_hierarchy = use_hierarchy
45 | self.soft_targets = soft_targets
46 | self.first_click_center = first_click_center
47 | self.only_one_first_click = only_one_first_click
48 | self.sfc_inner_k = sfc_inner_k
49 | self.sfc_full_inner_prob = sfc_full_inner_prob
50 |
51 | if max_num_merged_objects == -1:
52 | max_num_merged_objects = max_num_points
53 | self.max_num_merged_objects = max_num_merged_objects
54 |
55 | self.neg_strategies = ['bg', 'other', 'border']
56 | self.neg_strategies_prob = [negative_bg_prob, negative_other_prob, negative_border_prob]
57 | assert math.isclose(sum(self.neg_strategies_prob), 1.0)
58 |
59 | self._pos_probs = generate_probs(max_num_points, gamma=prob_gamma)
60 | self._neg_probs = generate_probs(max_num_points + 1, gamma=prob_gamma)
61 | self._neg_masks = None
62 |
63 | def sample_object(self, sample: DSample):
64 | if len(sample) == 0:
65 | bg_mask = sample.get_background_mask()
66 | self.selected_mask = np.zeros_like(bg_mask, dtype=np.float32)
67 | self._selected_masks = [[]]
68 | self._neg_masks = {strategy: bg_mask for strategy in self.neg_strategies}
69 | self._neg_masks['required'] = []
70 | return
71 |
72 | gt_mask, pos_masks, neg_masks = self._sample_mask(sample)
73 | binary_gt_mask = gt_mask > 0.5 if self.soft_targets else gt_mask > 0
74 |
75 | self.selected_mask = gt_mask
76 | self._selected_masks = pos_masks
77 |
78 | neg_mask_bg = np.logical_not(binary_gt_mask)
79 | neg_mask_border = self._get_border_mask(binary_gt_mask)
80 | if len(sample) <= len(self._selected_masks):
81 | neg_mask_other = neg_mask_bg
82 | else:
83 | neg_mask_other = np.logical_and(np.logical_not(sample.get_background_mask()),
84 | np.logical_not(binary_gt_mask))
85 |
86 | self._neg_masks = {
87 | 'bg': neg_mask_bg,
88 | 'other': neg_mask_other,
89 | 'border': neg_mask_border,
90 | 'required': neg_masks
91 | }
92 |
93 | def _sample_mask(self, sample: DSample):
94 | root_obj_ids = sample.root_objects
95 |
96 | if len(root_obj_ids) > 1 and random.random() < self.merge_objects_prob:
97 | max_selected_objects = min(len(root_obj_ids), self.max_num_merged_objects)
98 | num_selected_objects = np.random.randint(2, max_selected_objects + 1)
99 | random_ids = random.sample(root_obj_ids, num_selected_objects)
100 | else:
101 | random_ids = [random.choice(root_obj_ids)]
102 |
103 | gt_mask = None
104 | pos_segments = []
105 | neg_segments = []
106 | for obj_id in random_ids:
107 | obj_gt_mask, obj_pos_segments, obj_neg_segments = self._sample_from_masks_layer(obj_id, sample)
108 | if gt_mask is None:
109 | gt_mask = obj_gt_mask
110 | else:
111 | gt_mask = np.maximum(gt_mask, obj_gt_mask)
112 |
113 | pos_segments.extend(obj_pos_segments)
114 | neg_segments.extend(obj_neg_segments)
115 |
116 | pos_masks = [self._positive_erode(x) for x in pos_segments]
117 | neg_masks = [self._positive_erode(x) for x in neg_segments]
118 |
119 | return gt_mask, pos_masks, neg_masks
120 |
121 | def _sample_from_masks_layer(self, obj_id, sample: DSample):
122 | objs_tree = sample._objects
123 |
124 | if not self.use_hierarchy:
125 | node_mask = sample.get_object_mask(obj_id)
126 | gt_mask = sample.get_soft_object_mask(obj_id) if self.soft_targets else node_mask
127 | return gt_mask, [node_mask], []
128 |
129 | def _select_node(node_id):
130 | node_info = objs_tree[node_id]
131 | if not node_info['children'] or random.random() < 0.5:
132 | return node_id
133 | return _select_node(random.choice(node_info['children']))
134 |
135 | selected_node = _select_node(obj_id)
136 | node_info = objs_tree[selected_node]
137 | node_mask = sample.get_object_mask(selected_node)
138 | gt_mask = sample.get_soft_object_mask(selected_node) if self.soft_targets else node_mask
139 | pos_mask = node_mask.copy()
140 |
141 | negative_segments = []
142 | if node_info['parent'] is not None and node_info['parent'] in objs_tree:
143 | parent_mask = sample.get_object_mask(node_info['parent'])
144 | #negative_segments.append(np.logical_and(parent_mask, np.logical_not(node_mask)))
145 |
146 | for child_id in node_info['children']:
147 | if objs_tree[child_id]['area'] / node_info['area'] < 0.10:
148 | child_mask = sample.get_object_mask(child_id)
149 | pos_mask = np.logical_and(pos_mask, np.logical_not(child_mask))
150 |
151 | if node_info['children']:
152 | max_disabled_children = min(len(node_info['children']), 3)
153 | num_disabled_children = np.random.randint(0, max_disabled_children + 1)
154 | disabled_children = random.sample(node_info['children'], num_disabled_children)
155 |
156 | for child_id in disabled_children:
157 | child_mask = sample.get_object_mask(child_id)
158 | pos_mask = np.logical_and(pos_mask, np.logical_not(child_mask))
159 | if self.soft_targets:
160 | soft_child_mask = sample.get_soft_object_mask(child_id)
161 | gt_mask = np.minimum(gt_mask, 1.0 - soft_child_mask)
162 | else:
163 | gt_mask = np.logical_and(gt_mask, np.logical_not(child_mask))
164 | #negative_segments.append(child_mask)
165 |
166 | return gt_mask, [pos_mask], negative_segments
167 |
168 | def sample_points(self):
169 | assert self._selected_mask is not None
170 |
171 | pos_points = self._multi_mask_sample_points(self._selected_masks,
172 | is_negative=[False] * len(self._selected_masks),
173 | with_first_click=self.first_click_center)
174 |
175 | neg_strategy = [(self._neg_masks[k], prob)
176 | for k, prob in zip(self.neg_strategies, self.neg_strategies_prob)]
177 | neg_masks = self._neg_masks['required'] + [neg_strategy]
178 | neg_points = self._multi_mask_sample_points(neg_masks,
179 | is_negative=[True] * len(self._neg_masks['required']) + [True])
180 | # print('selected:', len(self._selected_masks))
181 | # print('neg_masks:', len(neg_masks))
182 |
183 | return pos_points + neg_points
184 |
185 | def _multi_mask_sample_points(self, selected_masks, is_negative, with_first_click=False):
186 | selected_masks = selected_masks[:self.max_num_points]
187 |
188 | each_obj_points = [
189 | self._sample_points(mask, is_negative=is_negative[i],
190 | with_first_click=with_first_click)
191 | for i, mask in enumerate(selected_masks)
192 | ]
193 | each_obj_points = [x for x in each_obj_points if len(x) > 0]
194 |
195 | points = []
196 | if len(each_obj_points) == 1:
197 | points = each_obj_points[0]
198 | elif len(each_obj_points) > 1:
199 | if self.only_one_first_click:
200 | each_obj_points = each_obj_points[:1]
201 |
202 | points = [obj_points[0] for obj_points in each_obj_points]
203 |
204 | aggregated_masks_with_prob = []
205 | for indx, x in enumerate(selected_masks):
206 | if isinstance(x, (list, tuple)) and x and isinstance(x[0], (list, tuple)):
207 | for t, prob in x:
208 | aggregated_masks_with_prob.append((t, prob / len(selected_masks)))
209 | else:
210 | aggregated_masks_with_prob.append((x, 1.0 / len(selected_masks)))
211 |
212 | other_points_union = self._sample_points(aggregated_masks_with_prob, is_negative=True)
213 | if len(other_points_union) + len(points) <= self.max_num_points:
214 | points.extend(other_points_union)
215 | else:
216 | points.extend(random.sample(other_points_union, self.max_num_points - len(points)))
217 |
218 | if len(points) < self.max_num_points:
219 | points.extend([(-1, -1, -1)] * (self.max_num_points - len(points)))
220 |
221 | return points
222 |
223 | def _sample_points(self, mask, is_negative=False, with_first_click=False):
224 | if is_negative:
225 | num_points = np.random.choice(np.arange(self.max_num_points + 1), p=self._neg_probs)
226 | else:
227 | num_points = 1 + np.random.choice(np.arange(self.max_num_points), p=self._pos_probs)
228 |
229 | indices_probs = None
230 | if isinstance(mask, (list, tuple)):
231 | indices_probs = [x[1] for x in mask]
232 | indices = [(np.argwhere(x), prob) for x, prob in mask]
233 | if indices_probs:
234 | assert math.isclose(sum(indices_probs), 1.0)
235 | else:
236 | indices = np.argwhere(mask)
237 |
238 | points = []
239 | for j in range(num_points):
240 | first_click = with_first_click and j == 0 and indices_probs is None
241 |
242 | if first_click:
243 | point_indices = get_point_candidates(mask, k=self.sfc_inner_k, full_prob=self.sfc_full_inner_prob)
244 | elif indices_probs:
245 | point_indices_indx = np.random.choice(np.arange(len(indices)), p=indices_probs)
246 | point_indices = indices[point_indices_indx][0]
247 | else:
248 | point_indices = indices
249 |
250 | num_indices = len(point_indices)
251 | if num_indices > 0:
252 | point_indx = 0 if first_click else 100
253 | click = point_indices[np.random.randint(0, num_indices)].tolist() + [point_indx]
254 | points.append(click)
255 |
256 | return points
257 |
258 | def _positive_erode(self, mask):
259 | if random.random() > self.positive_erode_prob:
260 | return mask
261 |
262 | kernel = np.ones((3, 3), np.uint8)
263 | eroded_mask = cv2.erode(mask.astype(np.uint8),
264 | kernel, iterations=self.positive_erode_iters).astype(bool)
265 |
266 | if eroded_mask.sum() > 10:
267 | return eroded_mask
268 | else:
269 | return mask
270 |
271 | def _get_border_mask(self, mask):
272 | expand_r = int(np.ceil(self.expand_ratio * np.sqrt(mask.sum())))
273 | kernel = np.ones((3, 3), np.uint8)
274 | expanded_mask = cv2.dilate(mask.astype(np.uint8), kernel, iterations=expand_r)
275 | expanded_mask[mask.astype(bool)] = 0
276 | return expanded_mask
277 |
278 |
279 | @lru_cache(maxsize=None)
280 | def generate_probs(max_num_points, gamma):
281 | probs = []
282 | last_value = 1
283 | for i in range(max_num_points):
284 | probs.append(last_value)
285 | last_value *= gamma
286 |
287 | probs = np.array(probs)
288 | probs /= probs.sum()
289 |
290 | return probs
291 |
292 |
293 | def get_point_candidates(obj_mask, k=1.7, full_prob=0.0):
294 | if full_prob > 0 and random.random() < full_prob:
295 | return obj_mask
296 |
297 | padded_mask = np.pad(obj_mask, ((1, 1), (1, 1)), 'constant')
298 |
299 | dt = cv2.distanceTransform(padded_mask.astype(np.uint8), cv2.DIST_L2, 0)[1:-1, 1:-1]
300 | if k > 0:
301 | inner_mask = dt > dt.max() / k
302 | return np.argwhere(inner_mask)
303 | else:
304 | prob_map = dt.flatten()
305 | prob_map /= max(prob_map.sum(), 1e-6)
306 | click_indx = np.random.choice(len(prob_map), p=prob_map)
307 | click_coords = np.unravel_index(click_indx, dt.shape)
308 | return np.array([click_coords])
309 |
--------------------------------------------------------------------------------
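`generate_probs` builds a normalized geometric series, P(i) ∝ gamma^i, so with the default `prob_gamma=0.7` small click counts dominate. A quick numeric check:

```python
# generate_probs is a normalized geometric series: P(i) proportional to gamma**i.
import numpy as np

gamma, n = 0.7, 12
probs = gamma ** np.arange(n)
probs /= probs.sum()
print(probs[:3])    # approx. [0.304, 0.213, 0.149]: few clicks are most likely
print(probs.sum())  # 1.0
```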
/isegm/data/sample.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from copy import deepcopy
3 | from isegm.utils.misc import get_labels_with_sizes
4 | from isegm.data.transforms import remove_image_only_transforms
5 | from albumentations import ReplayCompose
6 |
7 | class DSample:
8 | def __init__(self, image, encoded_masks, objects=None,
9 | objects_ids=None, ignore_ids=None, sample_id=None,
10 | init_mask = None):
11 | self.image = image
12 | self.sample_id = sample_id
13 | self.init_mask = init_mask
14 |
15 | if len(encoded_masks.shape) == 2:
16 | encoded_masks = encoded_masks[:, :, np.newaxis]
17 | self._encoded_masks = encoded_masks
18 | self._ignored_regions = []
19 |
20 | if objects_ids is not None:
21 | if not objects_ids or not isinstance(objects_ids[0], tuple):
22 | assert encoded_masks.shape[2] == 1
23 | objects_ids = [(0, obj_id) for obj_id in objects_ids]
24 |
25 | self._objects = dict()
26 | for indx, obj_mapping in enumerate(objects_ids):
27 | self._objects[indx] = {
28 | 'parent': None,
29 | 'mapping': obj_mapping,
30 | 'children': []
31 | }
32 |
33 | if ignore_ids:
34 | if isinstance(ignore_ids[0], tuple):
35 | self._ignored_regions = ignore_ids
36 | else:
37 | self._ignored_regions = [(0, region_id) for region_id in ignore_ids]
38 | else:
39 | self._objects = deepcopy(objects)
40 |
41 | self._augmented = False
42 | self._soft_mask_aug = None
43 | self._original_data = self.image, self._encoded_masks, deepcopy(self._objects)
44 |
45 |
46 | def augment(self, augmentator):
47 | self.reset_augmentation()
48 | aug_output = augmentator(image=self.image, mask=self._encoded_masks)
49 | image, mask = aug_output['image'],aug_output['mask']
50 | self.image = image
51 | self._encoded_masks = mask
52 | self._compute_objects_areas()
53 | self.remove_small_objects(min_area=1)
54 | self._augmented = True
55 |
56 |
57 | def reset_augmentation(self):
58 | if not self._augmented:
59 | return
60 | orig_image, orig_masks, orig_objects = self._original_data
61 | self.image = orig_image
62 | self._encoded_masks = orig_masks
63 | self._objects = deepcopy(orig_objects)
64 | self._augmented = False
65 | self._soft_mask_aug = None
66 |
67 | def remove_small_objects(self, min_area):
68 | if self._objects and not 'area' in list(self._objects.values())[0]:
69 | self._compute_objects_areas()
70 |
71 | for obj_id, obj_info in list(self._objects.items()):
72 | if obj_info['area'] < min_area:
73 | self._remove_object(obj_id)
74 |
75 | def get_object_mask(self, obj_id):
76 | layer_indx, mask_id = self._objects[obj_id]['mapping']
77 | obj_mask = (self._encoded_masks[:, :, layer_indx] == mask_id).astype(np.int32)
78 | if self._ignored_regions:
79 | for layer_indx, mask_id in self._ignored_regions:
80 | ignore_mask = self._encoded_masks[:, :, layer_indx] == mask_id
81 | obj_mask[ignore_mask] = -1
82 |
83 | return obj_mask
84 |
85 | def get_soft_object_mask(self, obj_id):
86 | assert self._soft_mask_aug is not None
87 | original_encoded_masks = self._original_data[1]
88 | layer_indx, mask_id = self._objects[obj_id]['mapping']
89 | obj_mask = (original_encoded_masks[:, :, layer_indx] == mask_id).astype(np.float32)
90 | obj_mask = self._soft_mask_aug(image=obj_mask, mask=original_encoded_masks)['image']
91 | return np.clip(obj_mask, 0, 1)
92 |
93 | def get_background_mask(self):
94 | return np.max(self._encoded_masks, axis=2) == 0
95 |
96 | @property
97 | def objects_ids(self):
98 | return list(self._objects.keys())
99 |
100 | @property
101 | def gt_mask(self):
102 | assert len(self._objects) == 1
103 | return self.get_object_mask(self.objects_ids[0])
104 |
105 | @property
106 | def root_objects(self):
107 | return [obj_id for obj_id, obj_info in self._objects.items() if obj_info['parent'] is None]
108 |
109 | def _compute_objects_areas(self):
110 | inverse_index = {node['mapping']: node_id for node_id, node in self._objects.items()}
111 | ignored_regions_keys = set(self._ignored_regions)
112 |
113 | for layer_indx in range(self._encoded_masks.shape[2]):
114 | objects_ids, objects_areas = get_labels_with_sizes(self._encoded_masks[:, :, layer_indx])
115 | for obj_id, obj_area in zip(objects_ids, objects_areas):
116 | inv_key = (layer_indx, obj_id)
117 | if inv_key in ignored_regions_keys:
118 | continue
119 | try:
120 | self._objects[inverse_index[inv_key]]['area'] = obj_area
121 | del inverse_index[inv_key]
122 | except KeyError:
123 | layer = self._encoded_masks[:, :, layer_indx]
124 | layer[layer == obj_id] = 0
125 | self._encoded_masks[:, :, layer_indx] = layer
126 |
127 | for obj_id in inverse_index.values():
128 | self._objects[obj_id]['area'] = 0
129 |
130 | def _remove_object(self, obj_id):
131 | obj_info = self._objects[obj_id]
132 | obj_parent = obj_info['parent']
133 | for child_id in obj_info['children']:
134 | self._objects[child_id]['parent'] = obj_parent
135 |
136 | if obj_parent is not None:
137 | parent_children = self._objects[obj_parent]['children']
138 | parent_children = [x for x in parent_children if x != obj_id]
139 | self._objects[obj_parent]['children'] = parent_children + obj_info['children']
140 |
141 | del self._objects[obj_id]
142 |
143 | def __len__(self):
144 | return len(self._objects)
145 |
--------------------------------------------------------------------------------
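A construction sketch showing how the dataset loaders wrap a single-object mask in a `DSample`, matching the constructors used in `isegm/data/datasets/`:

```python
# Sketch: wrap a single-object mask in a DSample, as the dataset loaders do.
import numpy as np
from isegm.data.sample import DSample

image = np.zeros((120, 160, 3), np.uint8)
instances_mask = np.zeros((120, 160), np.int32)
instances_mask[30:90, 40:120] = 1

sample = DSample(image, instances_mask, objects_ids=[1], sample_id=0)
print(len(sample), sample.gt_mask.sum())  # 1 object covering 4800 px
```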
/isegm/data/transforms.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import random
3 | import numpy as np
4 |
5 | from albumentations.core.serialization import SERIALIZABLE_REGISTRY
6 | from albumentations import ImageOnlyTransform, DualTransform
7 | from albumentations.core.transforms_interface import to_tuple
8 | from albumentations.augmentations import functional as F
9 | from isegm.utils.misc import get_bbox_from_mask, expand_bbox, clamp_bbox, get_labels_with_sizes
10 |
11 |
12 | class UniformRandomResize(DualTransform):
13 | def __init__(self, scale_range=(0.9, 1.1), interpolation=cv2.INTER_LINEAR, always_apply=False, p=1):
14 | super().__init__(always_apply, p)
15 | self.scale_range = scale_range
16 | self.interpolation = interpolation
17 |
18 | def get_params_dependent_on_targets(self, params):
19 | scale = random.uniform(*self.scale_range)
20 | height = int(round(params['image'].shape[0] * scale))
21 | width = int(round(params['image'].shape[1] * scale))
22 | return {'new_height': height, 'new_width': width}
23 |
24 | def apply(self, img, new_height=0, new_width=0, interpolation=cv2.INTER_LINEAR, **params):
25 | return F.resize(img, height=new_height, width=new_width, interpolation=interpolation)
26 |
27 | def apply_to_keypoint(self, keypoint, new_height=0, new_width=0, **params):
28 | scale_x = new_width / params["cols"]
29 | scale_y = new_height / params["rows"]
30 | return F.keypoint_scale(keypoint, scale_x, scale_y)
31 |
32 | def get_transform_init_args_names(self):
33 | return "scale_range", "interpolation"
34 |
35 | @property
36 | def targets_as_params(self):
37 | return ["image"]
38 |
39 |
40 | class ZoomIn(DualTransform):
41 | def __init__(
42 | self,
43 | height,
44 | width,
45 | bbox_jitter=0.1,
46 | expansion_ratio=1.4,
47 | min_crop_size=200,
48 | min_area=100,
49 | always_resize=False,
50 | always_apply=False,
51 | p=0.5,
52 | ):
53 | super(ZoomIn, self).__init__(always_apply, p)
54 | self.height = height
55 | self.width = width
56 | self.bbox_jitter = to_tuple(bbox_jitter)
57 | self.expansion_ratio = expansion_ratio
58 | self.min_crop_size = min_crop_size
59 | self.min_area = min_area
60 | self.always_resize = always_resize
61 |
62 | def apply(self, img, selected_object, bbox, **params):
63 | if selected_object is None:
64 | if self.always_resize:
65 | img = F.resize(img, height=self.height, width=self.width)
66 | return img
67 |
68 | rmin, rmax, cmin, cmax = bbox
69 | img = img[rmin:rmax + 1, cmin:cmax + 1]
70 | img = F.resize(img, height=self.height, width=self.width)
71 |
72 | return img
73 |
74 | def apply_to_mask(self, mask, selected_object, bbox, **params):
75 | if selected_object is None:
76 | if self.always_resize:
77 | mask = F.resize(mask, height=self.height, width=self.width,
78 | interpolation=cv2.INTER_NEAREST)
79 | return mask
80 |
81 | rmin, rmax, cmin, cmax = bbox
82 | mask = mask[rmin:rmax + 1, cmin:cmax + 1]
83 | if isinstance(selected_object, tuple):
84 | layer_indx, mask_id = selected_object
85 | obj_mask = mask[:, :, layer_indx] == mask_id
86 | new_mask = np.zeros_like(mask)
87 | new_mask[:, :, layer_indx][obj_mask] = mask_id
88 | else:
89 | obj_mask = mask == selected_object
90 | new_mask = mask.copy()
91 | new_mask[np.logical_not(obj_mask)] = 0
92 |
93 | new_mask = F.resize(new_mask, height=self.height, width=self.width,
94 | interpolation=cv2.INTER_NEAREST)
95 | return new_mask
96 |
97 | def get_params_dependent_on_targets(self, params):
98 | instances = params['mask']
99 |
100 | is_mask_layer = len(instances.shape) > 2
101 | candidates = []
102 | if is_mask_layer:
103 | for layer_indx in range(instances.shape[2]):
104 | labels, areas = get_labels_with_sizes(instances[:, :, layer_indx])
105 | candidates.extend([(layer_indx, obj_id)
106 | for obj_id, area in zip(labels, areas)
107 | if area > self.min_area])
108 | else:
109 | labels, areas = get_labels_with_sizes(instances)
110 | candidates = [obj_id for obj_id, area in zip(labels, areas)
111 | if area > self.min_area]
112 |
113 | selected_object = None
114 | bbox = None
115 | if candidates:
116 | selected_object = random.choice(candidates)
117 | if is_mask_layer:
118 | layer_indx, mask_id = selected_object
119 | obj_mask = instances[:, :, layer_indx] == mask_id
120 | else:
121 | obj_mask = instances == selected_object
122 |
123 | bbox = get_bbox_from_mask(obj_mask)
124 |
125 | if isinstance(self.expansion_ratio, tuple):
126 | expansion_ratio = random.uniform(*self.expansion_ratio)
127 | else:
128 | expansion_ratio = self.expansion_ratio
129 |
130 | bbox = expand_bbox(bbox, expansion_ratio, self.min_crop_size)
131 | bbox = self._jitter_bbox(bbox)
132 | bbox = clamp_bbox(bbox, 0, obj_mask.shape[0] - 1, 0, obj_mask.shape[1] - 1)
133 |
134 | return {
135 | 'selected_object': selected_object,
136 | 'bbox': bbox
137 | }
138 |
139 | def _jitter_bbox(self, bbox):
140 | rmin, rmax, cmin, cmax = bbox
141 | height = rmax - rmin + 1
142 | width = cmax - cmin + 1
143 | rmin = int(rmin + random.uniform(*self.bbox_jitter) * height)
144 | rmax = int(rmax + random.uniform(*self.bbox_jitter) * height)
145 | cmin = int(cmin + random.uniform(*self.bbox_jitter) * width)
146 | cmax = int(cmax + random.uniform(*self.bbox_jitter) * width)
147 |
148 | return rmin, rmax, cmin, cmax
149 |
150 | def apply_to_bbox(self, bbox, **params):
151 | raise NotImplementedError
152 |
153 | def apply_to_keypoint(self, keypoint, **params):
154 | raise NotImplementedError
155 |
156 | @property
157 | def targets_as_params(self):
158 | return ["mask"]
159 |
160 | def get_transform_init_args_names(self):
161 | return ("height", "width", "bbox_jitter",
162 | "expansion_ratio", "min_crop_size", "min_area", "always_resize")
163 |
164 |
165 | def remove_image_only_transforms(sdict):
166 | if not 'transforms' in sdict:
167 | return sdict
168 |
169 | keep_transforms = []
170 | for tdict in sdict['transforms']:
171 | cls = SERIALIZABLE_REGISTRY[tdict['__class_fullname__']]
172 | if 'transforms' in tdict:
173 | keep_transforms.append(remove_image_only_transforms(tdict))
174 | elif not issubclass(cls, ImageOnlyTransform):
175 | keep_transforms.append(tdict)
176 | sdict['transforms'] = keep_transforms
177 |
178 | return sdict
179 |
--------------------------------------------------------------------------------
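A usage sketch of `UniformRandomResize` inside an albumentations pipeline; this assumes an albumentations version compatible with the imports above (`targets_as_params` is resolved by `Compose`), and the input shapes are arbitrary.

```python
# Sketch: UniformRandomResize applied jointly to image and mask via Compose.
import numpy as np
import albumentations as A
from isegm.data.transforms import UniformRandomResize

pipeline = A.Compose([UniformRandomResize(scale_range=(0.75, 1.25))])
image = np.zeros((300, 400, 3), np.uint8)
mask = np.zeros((300, 400), np.int32)

out = pipeline(image=image, mask=mask)
print(out['image'].shape)  # both spatial dims scaled by one shared random factor
```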
/isegm/engine/optimizer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | from isegm.utils.log import logger
4 |
5 |
6 | def get_optimizer(model, opt_name, opt_kwargs):
7 | params = []
8 | base_lr = opt_kwargs['lr']
9 | for name, param in model.named_parameters():
10 | param_group = {'params': [param]}
11 | if not param.requires_grad:
12 | params.append(param_group)
13 | continue
14 |
15 | if not math.isclose(getattr(param, 'lr_mult', 1.0), 1.0):
16 | logger.info(f'Applied lr_mult={param.lr_mult} to "{name}" parameter.')
17 | param_group['lr'] = param_group.get('lr', base_lr) * param.lr_mult
18 |
19 | params.append(param_group)
20 |
21 | optimizer = {
22 | 'sgd': torch.optim.SGD,
23 | 'adam': torch.optim.Adam,
24 | 'adamw': torch.optim.AdamW
25 | }[opt_name.lower()](params, **opt_kwargs)
26 |
27 | return optimizer
28 |
--------------------------------------------------------------------------------
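`get_optimizer` honors a per-parameter `lr_mult` attribute, which elsewhere in the codebase is attached by the helpers in `isegm/model/modifiers.py`. A minimal sketch of the mechanism:

```python
# Minimal sketch: lr_mult scales the learning rate of a single parameter group.
import torch
from isegm.engine.optimizer import get_optimizer

model = torch.nn.Linear(8, 2)
model.bias.lr_mult = 10.0  # picked up via getattr(param, 'lr_mult', 1.0)

optimizer = get_optimizer(model, 'adam', {'lr': 5e-4})
print([group['lr'] for group in optimizer.param_groups])  # [0.0005, 0.005]
```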
/isegm/inference/clicker.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from copy import deepcopy
3 | import cv2
4 |
5 |
6 | class Clicker(object):
7 | def __init__(self, gt_mask=None, init_clicks=None, ignore_label=-1, click_indx_offset=0):
8 | self.click_indx_offset = click_indx_offset
9 | if gt_mask is not None:
10 | self.gt_mask = gt_mask == 1
11 | self.not_ignore_mask = gt_mask != ignore_label
12 | else:
13 | self.gt_mask = None
14 |
15 | self.reset_clicks()
16 |
17 | if init_clicks is not None:
18 | for click in init_clicks:
19 | self.add_click(click)
20 |
21 | def make_next_click(self, pred_mask):
22 | assert self.gt_mask is not None
23 | click = self._get_next_click(pred_mask)
24 | self.add_click(click)
25 |
26 | def get_clicks(self, clicks_limit=None):
27 | return self.clicks_list[:clicks_limit]
28 |
29 | def _get_next_click(self, pred_mask, padding=True):
30 | fn_mask = np.logical_and(np.logical_and(self.gt_mask, np.logical_not(pred_mask)), self.not_ignore_mask)
31 | fp_mask = np.logical_and(np.logical_and(np.logical_not(self.gt_mask), pred_mask), self.not_ignore_mask)
32 |
33 | if padding:
34 | fn_mask = np.pad(fn_mask, ((1, 1), (1, 1)), 'constant')
35 | fp_mask = np.pad(fp_mask, ((1, 1), (1, 1)), 'constant')
36 |
37 | fn_mask_dt = cv2.distanceTransform(fn_mask.astype(np.uint8), cv2.DIST_L2, 0)
38 | fp_mask_dt = cv2.distanceTransform(fp_mask.astype(np.uint8), cv2.DIST_L2, 0)
39 |
40 | if padding:
41 | fn_mask_dt = fn_mask_dt[1:-1, 1:-1]
42 | fp_mask_dt = fp_mask_dt[1:-1, 1:-1]
43 |
44 | fn_mask_dt = fn_mask_dt * self.not_clicked_map
45 | fp_mask_dt = fp_mask_dt * self.not_clicked_map
46 |
47 | fn_max_dist = np.max(fn_mask_dt)
48 | fp_max_dist = np.max(fp_mask_dt)
49 |
50 | is_positive = fn_max_dist > fp_max_dist
51 | if is_positive:
52 | coords_y, coords_x = np.where(fn_mask_dt == fn_max_dist) # coords is [y, x]
53 | else:
54 | coords_y, coords_x = np.where(fp_mask_dt == fp_max_dist) # coords is [y, x]
55 |
56 | return Click(is_positive=is_positive, coords=(coords_y[0], coords_x[0]))
57 |
58 | def add_click(self, click):
59 | coords = click.coords
60 |
61 | click.indx = self.click_indx_offset + self.num_pos_clicks + self.num_neg_clicks
62 | if click.is_positive:
63 | self.num_pos_clicks += 1
64 | else:
65 | self.num_neg_clicks += 1
66 |
67 | self.clicks_list.append(click)
68 | if self.gt_mask is not None:
69 | self.not_clicked_map[coords[0], coords[1]] = False
70 |
71 | def _remove_last_click(self):
72 | click = self.clicks_list.pop()
73 | coords = click.coords
74 |
75 | if click.is_positive:
76 | self.num_pos_clicks -= 1
77 | else:
78 | self.num_neg_clicks -= 1
79 |
80 | if self.gt_mask is not None:
81 | self.not_clicked_map[coords[0], coords[1]] = True
82 |
83 | def reset_clicks(self):
84 | if self.gt_mask is not None:
85 | self.not_clicked_map = np.ones_like(self.gt_mask, dtype=bool)
86 |
87 | self.num_pos_clicks = 0
88 | self.num_neg_clicks = 0
89 |
90 | self.clicks_list = []
91 |
92 | def get_state(self):
93 | return deepcopy(self.clicks_list)
94 |
95 | def set_state(self, state):
96 | self.reset_clicks()
97 | for click in state:
98 | self.add_click(click)
99 |
100 | def __len__(self):
101 | return len(self.clicks_list)
102 |
103 |
104 | class Click:
105 | def __init__(self, is_positive, coords, indx=None):
106 | self.is_positive = is_positive
107 | self.coords = coords
108 | self.indx = indx
109 |
110 | @property
111 | def coords_and_indx(self):
112 | return (*self.coords, self.indx)
113 |
114 | def copy(self, **kwargs):
115 | self_copy = deepcopy(self)
116 | for k, v in kwargs.items():
117 | setattr(self_copy, k, v)
118 | return self_copy
119 |
--------------------------------------------------------------------------------
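A short simulation of the click protocol this `Clicker` implements: starting from an empty prediction, the next click is positive and lands at the interior point of the ground-truth mask farthest from its boundary (the peak of the distance transform over the error region).

```python
# Sketch: simulate the first evaluation click against a ground-truth mask.
import numpy as np
from isegm.inference.clicker import Clicker

gt_mask = np.zeros((100, 100), np.int32)
gt_mask[20:80, 20:80] = 1  # a 60x60 square object

clicker = Clicker(gt_mask=gt_mask)
pred_mask = np.zeros_like(gt_mask, dtype=bool)  # everything is a false negative
clicker.make_next_click(pred_mask)

click = clicker.get_clicks()[0]
print(click.is_positive, click.coords)  # True, near the square's center
```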
/isegm/inference/evaluation.py:
--------------------------------------------------------------------------------
1 | from time import time
2 |
3 | import numpy as np
4 | import torch
5 | import os
6 | from isegm.inference import utils
7 | from isegm.inference.clicker import Clicker
8 | import shutil
9 | import cv2
10 | from isegm.utils.vis import add_tag
11 |
12 |
13 |
14 | try:
15 | get_ipython()
16 | from tqdm import tqdm_notebook as tqdm
17 | except NameError:
18 | from tqdm import tqdm
19 |
20 |
21 | def evaluate_dataset(dataset, predictor, vis = True, vis_path = './experiments/vis_val/',**kwargs):
22 | all_ious = []
23 | if vis:
24 | save_dir = vis_path + dataset.name + '/'
25 | if os.path.exists(save_dir):
26 | shutil.rmtree(save_dir)
27 | os.makedirs(save_dir)
28 | else:
29 | save_dir = None
30 |
31 | start_time = time()
32 | for index in tqdm(range(len(dataset)), leave=False):
33 | sample = dataset.get_sample(index)
34 |
35 | _, sample_ious, _ = evaluate_sample(sample.image, sample.gt_mask, sample.init_mask, predictor,
36 |                                             sample_id=index, vis=vis, save_dir=save_dir,
37 |                                             index=index, **kwargs)
38 | all_ious.append(sample_ious)
39 | end_time = time()
40 | elapsed_time = end_time - start_time
41 |
42 | return all_ious, elapsed_time
43 |
44 | def Progressive_Merge(pred_mask, previous_mask, y, x):
45 | diff_regions = np.logical_xor(previous_mask, pred_mask)
46 | num, labels = cv2.connectedComponents(diff_regions.astype(np.uint8))
47 | label = labels[y,x]
48 | corr_mask = labels == label
49 | if previous_mask[y,x] == 1:
50 | progressive_mask = np.logical_and( previous_mask, np.logical_not(corr_mask))
51 | else:
52 | progressive_mask = np.logical_or( previous_mask, corr_mask)
53 | return progressive_mask
54 |
55 |
56 | def evaluate_sample(image, gt_mask, init_mask, predictor, max_iou_thr,
57 | pred_thr=0.49, min_clicks=1, max_clicks=20,
58 |                     sample_id=None, vis=True, save_dir=None, index=0, callback=None,
59 |                     progressive_mode=True,
60 | ):
61 | clicker = Clicker(gt_mask=gt_mask)
62 | pred_mask = np.zeros_like(gt_mask)
63 | prev_mask = pred_mask
64 | ious_list = []
65 |
66 | with torch.no_grad():
67 | predictor.set_input_image(image)
68 | if init_mask is not None:
69 | predictor.set_prev_mask(init_mask)
70 | pred_mask = init_mask
71 | prev_mask = init_mask
72 |                 num_pm = 0    # with an init mask, progressive merge applies from the first click
73 |             else:
74 |                 num_pm = 999  # otherwise progressive merge is effectively disabled
75 |
76 | for click_indx in range(max_clicks):
77 | vis_pred = prev_mask
78 | clicker.make_next_click(pred_mask)
79 | pred_probs = predictor.get_prediction(clicker)
80 | pred_mask = pred_probs > pred_thr
81 |
82 | if progressive_mode:
83 | clicks = clicker.get_clicks()
84 | if len(clicks) >= num_pm:
85 | last_click = clicks[-1]
86 | last_y, last_x = last_click.coords[0], last_click.coords[1]
87 |                     pred_mask = Progressive_Merge(pred_mask, prev_mask, last_y, last_x)
88 | predictor.transforms[0]._prev_probs = np.expand_dims(np.expand_dims(pred_mask,0),0)
89 | if callback is not None:
90 | callback(image, gt_mask, pred_probs, sample_id, click_indx, clicker.clicks_list)
91 |
92 | iou = utils.get_iou(gt_mask, pred_mask)
93 | ious_list.append(iou)
94 | prev_mask = pred_mask
95 |
96 | if iou >= max_iou_thr and click_indx + 1 >= min_clicks:
97 | break
98 |
99 | if vis:
100 | clicks_list = clicker.get_clicks()
101 | last_y, last_x = predictor.last_y, predictor.last_x
102 |         out_image = vis_result_base(image, pred_mask, gt_mask, init_mask, iou, click_indx + 1, clicks_list, vis_pred, last_y, last_x)
103 |         cv2.imwrite(save_dir + str(index) + '.png', out_image)
104 | return clicker.clicks_list, np.array(ious_list, dtype=np.float32), pred_probs
105 |
106 |
107 | def vis_result_base(image, pred_mask, instances_mask, init_mask, iou, num_clicks, clicks_list, prev_prediction, last_y, last_x):
108 |
109 | image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
110 |
111 | pred_mask = pred_mask.astype(np.float32)
112 | prev_mask = prev_prediction.astype(np.float32)
113 | instances_mask = instances_mask.astype(np.float32)
114 | image = image.astype(np.float32)
115 |
116 | pred_mask_3 = np.repeat(pred_mask[...,np.newaxis],3,2)
117 | prev_mask_3 = np.repeat(prev_mask[...,np.newaxis],3,2)
118 | gt_mask_3 = np.repeat( instances_mask[...,np.newaxis],3,2 )
119 |
120 | color_mask_gt = np.zeros_like(pred_mask_3)
121 | color_mask_gt[:,:,0] = instances_mask * 255
122 |
123 | color_mask_pred = np.zeros_like(pred_mask_3) #+ 255
124 | color_mask_pred[:,:,0] = pred_mask * 255
125 |
126 | color_mask_prev = np.zeros_like(prev_mask_3) #+ 255
127 | color_mask_prev[:,:,0] = prev_mask * 255
128 |
129 |
130 | fusion_pred = image * 0.4 + color_mask_pred * 0.6
131 | fusion_pred = image * (1-pred_mask_3) + fusion_pred * pred_mask_3
132 |
133 | fusion_prev = image * 0.4 + color_mask_prev * 0.6
134 | fusion_prev = image * (1-prev_mask_3) + fusion_prev * prev_mask_3
135 |
136 |
137 | fusion_gt = image * 0.4 + color_mask_gt * 0.6
138 |
139 | color_mask_init = np.zeros_like(pred_mask_3)
140 | if init_mask is not None:
141 | color_mask_init[:,:,0] = init_mask * 255
142 |
143 | fusion_init = image * 0.4 + color_mask_init * 0.6
144 | fusion_init = image * (1-color_mask_init) + fusion_init * color_mask_init
145 |
146 |
147 | #cv2.putText( image, 'click num: '+str(num_clicks)+ ' iou: '+ str(round(iou,3)), (50,50),
148 | # cv2.FONT_HERSHEY_COMPLEX, 1, (255, 255, 255 ), 1 )
149 |
150 | for i in range(len(clicks_list)):
151 | click_tuple = clicks_list[i]
152 |
153 | if click_tuple.is_positive:
154 | color = (0,0,255)
155 | else:
156 | color = (0,255,0)
157 |
158 | coord = click_tuple.coords
159 | x,y = coord[1], coord[0]
160 | if x < 0 or y< 0:
161 | continue
162 | cv2.circle(fusion_pred,(x,y),4,color,-1)
163 | #cv2.putText(fusion_pred, str(i+1), (x-10, y-10), cv2.FONT_HERSHEY_COMPLEX, 0.6 , color,1 )
164 |
165 | cv2.circle(fusion_pred,(last_x,last_y),2,(255,255,255),-1)
166 | image = add_tag(image, 'nclicks:'+str(num_clicks)+ ' iou:'+ str(round(iou,3)))
167 | fusion_init = add_tag(fusion_init,'init mask')
168 | fusion_pred = add_tag(fusion_pred,'pred')
169 | fusion_gt = add_tag(fusion_gt,'gt')
170 | fusion_prev = add_tag(fusion_prev,'prev pred')
171 |
172 |     # all panels are concatenated side by side into one strip
173 |     out_image = cv2.hconcat([image.astype(np.float32), fusion_init.astype(np.float32),
174 |                              fusion_pred.astype(np.float32), fusion_gt.astype(np.float32),
175 |                              fusion_prev.astype(np.float32)])
176 |
177 |
178 | return out_image
179 |
180 |
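181 |
182 | # Minimal sketch (illustrative): Progressive_Merge applies only the connected
183 | # change region touched by the latest click and ignores unrelated changes.
184 | if __name__ == '__main__':
185 |     prev = np.zeros((8, 8), dtype=bool)
186 |     prev[2:6, 2:6] = True
187 |     pred = prev.copy()
188 |     pred[0:2, 0:2] = True                # spurious region far from the click
189 |     pred[4, 4] = False                   # change connected to the click at (4, 4)
190 |     merged = Progressive_Merge(pred, prev, 4, 4)
191 |     print(merged[4, 4], merged[0, 0])    # False False: only the clicked change is applied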
--------------------------------------------------------------------------------
/isegm/inference/predictors/__init__.py:
--------------------------------------------------------------------------------
1 | from .baseline import BaselinePredictor
2 | from isegm.inference.transforms import ZoomIn
3 |
4 |
5 |
6 | def get_predictor(net, brs_mode, device,
7 | prob_thresh=0.49,
8 |                   infer_size=256,
9 |                   focus_crop_r=1.4,
10 | with_flip=False,
11 | zoom_in_params=dict(),
12 | predictor_params=None,
13 | brs_opt_func_params=None,
14 | lbfgs_params=None):
15 |
16 | predictor_params_ = {
17 | 'optimize_after_n_clicks': 1
18 | }
19 |
20 | if zoom_in_params is not None:
21 | zoom_in = ZoomIn(**zoom_in_params)
22 | else:
23 | zoom_in = None
24 |
25 | if predictor_params is not None:
26 | predictor_params_.update(predictor_params)
27 | predictor = BaselinePredictor(net, device, zoom_in=zoom_in, with_flip=with_flip, infer_size =infer_size, **predictor_params_)
28 |
29 |
30 |
31 | return predictor
32 |
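33 |
34 | # Usage sketch (illustrative; assumes a network `net` already loaded via
35 | # isegm.inference.utils.load_is_model). brs_mode is accepted for interface
36 | # compatibility but is not used by this implementation:
37 | #   predictor = get_predictor(net, brs_mode='NoBRS', device='cuda',
38 | #                             infer_size=256, zoom_in_params=dict(target_size=480))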
--------------------------------------------------------------------------------
/isegm/inference/predictors/baseline.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torchvision import transforms
4 | from isegm.inference.transforms import AddHorizontalFlip, SigmoidForPred, LimitLongestSide, ResizeTrans
5 |
6 |
7 | class BaselinePredictor(object):
8 | def __init__(self, model, device,
9 | net_clicks_limit=None,
10 | with_flip=False,
11 | zoom_in=None,
12 | max_size=None,
13 | infer_size = 384,
14 | **kwargs):
15 | self.with_flip = with_flip
16 | self.net_clicks_limit = net_clicks_limit
17 | self.original_image = None
18 | self.device = device
19 | self.zoom_in = zoom_in
20 | self.prev_prediction = None
21 | self.model_indx = 0
22 | self.click_models = None
23 | self.net_state_dict = None
24 |
25 | if isinstance(model, tuple):
26 | self.net, self.click_models = model
27 | else:
28 | self.net = model
29 |
30 | self.to_tensor = transforms.ToTensor()
31 |
32 | self.transforms = [zoom_in] if zoom_in is not None else []
33 | if max_size is not None:
34 | self.transforms.append(LimitLongestSide(max_size=max_size))
35 | self.crop_l = infer_size
36 | self.transforms.append(ResizeTrans(self.crop_l))
37 | self.transforms.append(SigmoidForPred())
38 | self.focus_roi = None
39 | self.global_roi = None
40 |         self.with_flip = True  # NOTE: hard-coded override of the constructor argument; AddHorizontalFlip is never added to self.transforms
41 | if hasattr(self.net, 'set_status'):
42 | self.net.set_status(training=False)
43 |
44 | def set_input_image(self, image):
45 | image_nd = self.to_tensor(image)
46 | for transform in self.transforms:
47 | transform.reset()
48 | self.original_image = image_nd.to(self.device)
49 | if len(self.original_image.shape) == 3:
50 | self.original_image = self.original_image.unsqueeze(0)
51 | self.prev_prediction = torch.zeros_like(self.original_image[:, :1, :, :])
52 |
53 | def set_prev_mask(self, mask):
54 | self.prev_prediction = torch.from_numpy(mask).unsqueeze(0).unsqueeze(0).to(self.device).float()
55 |
56 | def get_prediction(self, clicker, prev_mask=None):
57 | clicks_list = clicker.get_clicks()
58 | click = clicks_list[-1]
59 | last_y,last_x = click.coords[0],click.coords[1]
60 | self.last_y = last_y
61 | self.last_x = last_x
62 |
63 | if self.click_models is not None:
64 | model_indx = min(clicker.click_indx_offset + len(clicks_list), len(self.click_models)) - 1
65 | if model_indx != self.model_indx:
66 | self.model_indx = model_indx
67 | self.net = self.click_models[model_indx]
68 |
69 | input_image = self.original_image
70 | if prev_mask is None:
71 | prev_mask = self.prev_prediction
72 | if hasattr(self.net, 'with_prev_mask') and self.net.with_prev_mask:
73 | input_image = torch.cat((input_image, prev_mask), dim=1)
74 |
75 |
76 | image_nd, clicks_lists, is_image_changed = self.apply_transforms(
77 | input_image, [clicks_list]
78 | )
79 |
80 | pred_logits = self._get_prediction(image_nd, clicks_lists, is_image_changed)
81 | prediction = F.interpolate(pred_logits, mode='bilinear', align_corners=True,
82 | size=image_nd.size()[2:])
83 |
84 | for t in reversed(self.transforms):
85 | prediction = t.inv_transform(prediction)
86 |
87 | self.prev_prediction = prediction
88 | return prediction.cpu().numpy()[0, 0]
89 |
90 | def _get_prediction(self, image_nd, clicks_lists, is_image_changed):
91 | points_nd = self.get_points_nd(clicks_lists)
92 | output = self.net(image_nd, points_nd)
93 | return output['instances']
94 |
95 | def mapp_roi(self, focus_roi, global_roi):
96 | yg1,yg2,xg1,xg2 = global_roi
97 | hg,wg = yg2-yg1, xg2-xg1
98 | yf1,yf2,xf1,xf2 = focus_roi
99 |
100 | '''
101 | yf1_n = (yf1-yg1+1) * (self.crop_l/hg)
102 | yf2_n = (yf2-yg1+1) * (self.crop_l/hg)
103 | xf1_n = (xf1-xg1+1) * (self.crop_l/wg)
104 | xf2_n = (xf2-xg1+1) * (self.crop_l/wg)
105 |
106 | '''
107 | yf1_n = (yf1-yg1) * (self.crop_l/hg)
108 | yf2_n = (yf2-yg1) * (self.crop_l/hg)
109 | xf1_n = (xf1-xg1) * (self.crop_l/wg)
110 | xf2_n = (xf2-xg1) * (self.crop_l/wg)
111 |
112 | yf1_n = max(yf1_n,0)
113 | yf2_n = min(yf2_n,self.crop_l)
114 | xf1_n = max(xf1_n,0)
115 | xf2_n = min(xf2_n,self.crop_l)
116 | return (yf1_n,yf2_n,xf1_n,xf2_n)
117 |
118 |
119 |
120 | def _get_transform_states(self):
121 | return [x.get_state() for x in self.transforms]
122 |
123 | def _set_transform_states(self, states):
124 | assert len(states) == len(self.transforms)
125 | for state, transform in zip(states, self.transforms):
126 | transform.set_state(state)
127 |
128 |
129 | def apply_transforms(self, image_nd, clicks_lists):
130 | is_image_changed = False
131 | for t in self.transforms:
132 | image_nd, clicks_lists = t.transform(image_nd, clicks_lists)
133 | is_image_changed |= t.image_changed
134 |
135 | return image_nd, clicks_lists, is_image_changed
136 |
137 | def get_points_nd(self, clicks_lists):
138 | total_clicks = []
139 | num_pos_clicks = [sum(x.is_positive for x in clicks_list) for clicks_list in clicks_lists]
140 | num_neg_clicks = [len(clicks_list) - num_pos for clicks_list, num_pos in zip(clicks_lists, num_pos_clicks)]
141 | num_max_points = max(num_pos_clicks + num_neg_clicks)
142 | if self.net_clicks_limit is not None:
143 | num_max_points = min(self.net_clicks_limit, num_max_points)
144 | num_max_points = max(1, num_max_points)
145 |
146 | for clicks_list in clicks_lists:
147 | clicks_list = clicks_list[:self.net_clicks_limit]
148 | pos_clicks = [click.coords_and_indx for click in clicks_list if click.is_positive]
149 | pos_clicks = pos_clicks + (num_max_points - len(pos_clicks)) * [(-1, -1, -1)]
150 |
151 | neg_clicks = [click.coords_and_indx for click in clicks_list if not click.is_positive]
152 | neg_clicks = neg_clicks + (num_max_points - len(neg_clicks)) * [(-1, -1, -1)]
153 | total_clicks.append(pos_clicks + neg_clicks)
154 |
155 | return torch.tensor(total_clicks, device=self.device)
156 |
157 | def get_states(self):
158 | return {
159 | 'transform_states': self._get_transform_states(),
160 | 'prev_prediction': self.prev_prediction.clone()
161 | }
162 |
163 | def set_states(self, states):
164 | self._set_transform_states(states['transform_states'])
165 | self.prev_prediction = states['prev_prediction']
166 |
167 |
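168 |
169 | # Layout of the click tensor built by get_points_nd (illustrative): for one
170 | # image with a positive click at (10, 20) and a negative click at (30, 40),
171 | # num_max_points is 1 and the tensor has shape (1, 2, 3):
172 | #   [[(10, 20, 0),    # positive block, padded with (-1, -1, -1) when shorter
173 | #     (30, 40, 1)]]   # negative block follows the positive block
174 | # Each row is (y, x, click_index), matching Click.coords_and_indx.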
--------------------------------------------------------------------------------
/isegm/inference/transforms/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import SigmoidForPred
2 | from .flip import AddHorizontalFlip
3 | from .zoom_in import ZoomIn
4 | from .limit_longest_side import LimitLongestSide
5 | from .crops import Crops
6 | from .resize import ResizeTrans
7 |
--------------------------------------------------------------------------------
/isegm/inference/transforms/base.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class BaseTransform(object):
5 | def __init__(self):
6 | self.image_changed = False
7 |
8 | def transform(self, image_nd, clicks_lists):
9 | raise NotImplementedError
10 |
11 | def inv_transform(self, prob_map):
12 | raise NotImplementedError
13 |
14 | def reset(self):
15 | raise NotImplementedError
16 |
17 | def get_state(self):
18 | raise NotImplementedError
19 |
20 | def set_state(self, state):
21 | raise NotImplementedError
22 |
23 |
24 | class SigmoidForPred(BaseTransform):
25 | def transform(self, image_nd, clicks_lists):
26 | return image_nd, clicks_lists
27 |
28 | def inv_transform(self, prob_map):
29 | return torch.sigmoid(prob_map)
30 |
31 | def reset(self):
32 | pass
33 |
34 | def get_state(self):
35 | return None
36 |
37 | def set_state(self, state):
38 | pass
39 |
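40 |
41 | # Minimal sketch (illustrative) of a custom transform implementing this
42 | # interface: a no-op pass-through.
43 | class IdentityTransform(BaseTransform):
44 |     def transform(self, image_nd, clicks_lists):
45 |         return image_nd, clicks_lists
46 |
47 |     def inv_transform(self, prob_map):
48 |         return prob_map
49 |
50 |     def reset(self):
51 |         pass
52 |
53 |     def get_state(self):
54 |         return None
55 |
56 |     def set_state(self, state):
57 |         pass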
--------------------------------------------------------------------------------
/isegm/inference/transforms/crops.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import numpy as np
5 | from typing import List
6 |
7 | from isegm.inference.clicker import Click
8 | from .base import BaseTransform
9 |
10 |
11 | class Crops(BaseTransform):
12 | def __init__(self, crop_size=(320, 480), min_overlap=0.2):
13 | super().__init__()
14 | self.crop_height, self.crop_width = crop_size
15 | self.min_overlap = min_overlap
16 |
17 | self.x_offsets = None
18 | self.y_offsets = None
19 | self._counts = None
20 |
21 | def transform(self, image_nd, clicks_lists: List[List[Click]]):
22 | assert image_nd.shape[0] == 1 and len(clicks_lists) == 1
23 | image_height, image_width = image_nd.shape[2:4]
24 | self._counts = None
25 |
26 | if image_height < self.crop_height or image_width < self.crop_width:
27 | return image_nd, clicks_lists
28 |
29 | self.x_offsets = get_offsets(image_width, self.crop_width, self.min_overlap)
30 | self.y_offsets = get_offsets(image_height, self.crop_height, self.min_overlap)
31 | self._counts = np.zeros((image_height, image_width))
32 |
33 | image_crops = []
34 | for dy in self.y_offsets:
35 | for dx in self.x_offsets:
36 | self._counts[dy:dy + self.crop_height, dx:dx + self.crop_width] += 1
37 | image_crop = image_nd[:, :, dy:dy + self.crop_height, dx:dx + self.crop_width]
38 | image_crops.append(image_crop)
39 | image_crops = torch.cat(image_crops, dim=0)
40 | self._counts = torch.tensor(self._counts, device=image_nd.device, dtype=torch.float32)
41 |
42 | clicks_list = clicks_lists[0]
43 | clicks_lists = []
44 | for dy in self.y_offsets:
45 | for dx in self.x_offsets:
46 | crop_clicks = [x.copy(coords=(x.coords[0] - dy, x.coords[1] - dx)) for x in clicks_list]
47 | clicks_lists.append(crop_clicks)
48 |
49 | return image_crops, clicks_lists
50 |
51 | def inv_transform(self, prob_map):
52 | if self._counts is None:
53 | return prob_map
54 |
55 | new_prob_map = torch.zeros((1, 1, *self._counts.shape),
56 | dtype=prob_map.dtype, device=prob_map.device)
57 |
58 | crop_indx = 0
59 | for dy in self.y_offsets:
60 | for dx in self.x_offsets:
61 | new_prob_map[0, 0, dy:dy + self.crop_height, dx:dx + self.crop_width] += prob_map[crop_indx, 0]
62 | crop_indx += 1
63 | new_prob_map = torch.div(new_prob_map, self._counts)
64 |
65 | return new_prob_map
66 |
67 | def get_state(self):
68 | return self.x_offsets, self.y_offsets, self._counts
69 |
70 | def set_state(self, state):
71 | self.x_offsets, self.y_offsets, self._counts = state
72 |
73 | def reset(self):
74 | self.x_offsets = None
75 | self.y_offsets = None
76 | self._counts = None
77 |
78 |
79 | def get_offsets(length, crop_size, min_overlap_ratio=0.2):
80 | if length == crop_size:
81 | return [0]
82 |
83 | N = (length / crop_size - min_overlap_ratio) / (1 - min_overlap_ratio)
84 | N = math.ceil(N)
85 |
86 | overlap_ratio = (N - length / crop_size) / (N - 1)
87 | overlap_width = int(crop_size * overlap_ratio)
88 |
89 | offsets = [0]
90 | for i in range(1, N):
91 | new_offset = offsets[-1] + crop_size - overlap_width
92 | if new_offset + crop_size > length:
93 | new_offset = length - crop_size
94 |
95 | offsets.append(new_offset)
96 |
97 | return offsets
98 |
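99 |
100 | # Quick check of get_offsets (illustrative): a 700-px side covered by 480-px
101 | # crops with at least 20% overlap needs two crops, at offsets 0 and 220.
102 | if __name__ == '__main__':
103 |     print(get_offsets(700, 480, 0.2))    # [0, 220]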
--------------------------------------------------------------------------------
/isegm/inference/transforms/flip.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from typing import List
4 | from isegm.inference.clicker import Click
5 | from .base import BaseTransform
6 |
7 |
8 | class AddHorizontalFlip(BaseTransform):
9 | def transform(self, image_nd, clicks_lists: List[List[Click]]):
10 | assert len(image_nd.shape) == 4
11 | image_nd = torch.cat([image_nd, torch.flip(image_nd, dims=[3])], dim=0)
12 |
13 | image_width = image_nd.shape[3]
14 | clicks_lists_flipped = []
15 | for clicks_list in clicks_lists:
16 | clicks_list_flipped = [click.copy(coords=(click.coords[0], image_width - click.coords[1] - 1))
17 | for click in clicks_list]
18 | clicks_lists_flipped.append(clicks_list_flipped)
19 | clicks_lists = clicks_lists + clicks_lists_flipped
20 |
21 | return image_nd, clicks_lists
22 |
23 | def inv_transform(self, prob_map):
24 | assert len(prob_map.shape) == 4 and prob_map.shape[0] % 2 == 0
25 | num_maps = prob_map.shape[0] // 2
26 | prob_map, prob_map_flipped = prob_map[:num_maps], prob_map[num_maps:]
27 |
28 | return 0.5 * (prob_map + torch.flip(prob_map_flipped, dims=[3]))
29 |
30 | def get_state(self):
31 | return None
32 |
33 | def set_state(self, state):
34 | pass
35 |
36 | def reset(self):
37 | pass
38 |
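39 |
40 | # Test-time augmentation sketch (illustrative): the batch is doubled with a
41 | # mirrored copy and click x-coordinates are mirrored to match.
42 | if __name__ == '__main__':
43 |     t = AddHorizontalFlip()
44 |     image = torch.rand(1, 3, 4, 6)
45 |     clicks = [[Click(is_positive=True, coords=(1, 2))]]
46 |     aug, aug_clicks = t.transform(image, clicks)
47 |     print(aug.shape)                     # torch.Size([2, 3, 4, 6])
48 |     print(aug_clicks[1][0].coords)       # (1, 3) == (1, width - 2 - 1)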
--------------------------------------------------------------------------------
/isegm/inference/transforms/limit_longest_side.py:
--------------------------------------------------------------------------------
1 | from .zoom_in import ZoomIn, get_roi_image_nd
2 |
3 |
4 | class LimitLongestSide(ZoomIn):
5 | def __init__(self, max_size=800):
6 | super().__init__(target_size=max_size, skip_clicks=0)
7 |
8 | def transform(self, image_nd, clicks_lists):
9 | assert image_nd.shape[0] == 1 and len(clicks_lists) == 1
10 | image_max_size = max(image_nd.shape[2:4])
11 | self.image_changed = False
12 |
13 | if image_max_size <= self.target_size:
14 | return image_nd, clicks_lists
15 | self._input_image = image_nd
16 |
17 | self._object_roi = (0, image_nd.shape[2] - 1, 0, image_nd.shape[3] - 1)
18 | self._roi_image = get_roi_image_nd(image_nd, self._object_roi, self.target_size)
19 | self.image_changed = True
20 |
21 | tclicks_lists = [self._transform_clicks(clicks_lists[0])]
22 | return self._roi_image, tclicks_lists
23 |
--------------------------------------------------------------------------------
/isegm/inference/transforms/resize.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import numpy as np
5 |
6 | from isegm.inference.clicker import Click
7 | from .base import BaseTransform
8 | import torch.nn.functional as F
9 |
10 |
11 | class ResizeTrans(BaseTransform):
12 | def __init__(self, l=480):
13 | super().__init__()
14 | self.crop_height = l
15 | self.crop_width = l
16 |
17 | def transform(self, image_nd, clicks_lists):
18 | assert image_nd.shape[0] == 1 and len(clicks_lists) == 1
19 | image_height, image_width = image_nd.shape[2:4]
20 | self.image_height = image_height
21 | self.image_width = image_width
22 | #image_np = np.transpose( image_nd[0].numpy(), (1,2,0)).astype(np.uint8)
23 | #image_np_r = cv2.resize( image_np, (self.crop_width, self.crop_height))
24 | #image_nd_r = torch.from_numpy(image_np_r).unsqueeze(0).permute(0,3,1,2).float()
25 | image_nd_r = F.interpolate(image_nd, (self.crop_height, self.crop_width), mode = 'bilinear', align_corners=True )
26 |
27 | y_ratio = self.crop_height / image_height
28 | x_ratio = self.crop_width / image_width
29 |
30 | #clicks_list = clicks_lists[0]
31 | #clicks_lists = []
32 | #resize_clicks = [Click(is_positive=x.is_positive, coords=(x.coords[0] * y_ratio, x.coords[1] * x_ratio ))
33 | # for x in clicks_list]
34 | #clicks_lists.append(resize_clicks)
35 |
36 | clicks_lists_resized = []
37 | for clicks_list in clicks_lists:
38 | clicks_list_resized = [click.copy(coords=(click.coords[0] * y_ratio, click.coords[1] * x_ratio ))
39 | for click in clicks_list]
40 | clicks_lists_resized.append(clicks_list_resized)
41 |
42 | return image_nd_r, clicks_lists_resized
43 |
44 | def inv_transform(self, prob_map):
45 | new_prob_map = F.interpolate(prob_map, (self.image_height, self.image_width), mode='bilinear', align_corners=True )
46 |
47 | return new_prob_map
48 |
49 |     def get_state(self):
50 |         return self.image_height, self.image_width
51 |
52 |     def set_state(self, state):
53 |         self.image_height, self.image_width = state
54 |
55 |     def reset(self):
56 |         # clear the cached source-image size recorded in transform()
57 |         self.image_height = None
58 |         self.image_width = None
59 |
60 |
61 | def get_offsets(length, crop_size, min_overlap_ratio=0.2):  # unused here; duplicate of crops.get_offsets
62 | if length == crop_size:
63 | return [0]
64 |
65 | N = (length / crop_size - min_overlap_ratio) / (1 - min_overlap_ratio)
66 | N = math.ceil(N)
67 |
68 | overlap_ratio = (N - length / crop_size) / (N - 1)
69 | overlap_width = int(crop_size * overlap_ratio)
70 |
71 | offsets = [0]
72 | for i in range(1, N):
73 | new_offset = offsets[-1] + crop_size - overlap_width
74 | if new_offset + crop_size > length:
75 | new_offset = length - crop_size
76 |
77 | offsets.append(new_offset)
78 |
79 | return offsets
80 |
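81 |
82 | # Round-trip sketch for ResizeTrans (illustrative): a 240x320 image is resized
83 | # to 480x480 and click coordinates are scaled by the same per-axis ratios.
84 | if __name__ == '__main__':
85 |     t = ResizeTrans(l=480)
86 |     image = torch.rand(1, 3, 240, 320)
87 |     clicks = [[Click(is_positive=True, coords=(120, 160))]]
88 |     img_r, clicks_r = t.transform(image, clicks)
89 |     print(img_r.shape)                   # torch.Size([1, 3, 480, 480])
90 |     print(clicks_r[0][0].coords)         # (240.0, 240.0)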
--------------------------------------------------------------------------------
/isegm/inference/transforms/zoom_in.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from typing import List
4 | from isegm.inference.clicker import Click
5 | from isegm.utils.misc import get_bbox_iou, get_bbox_from_mask, expand_bbox, clamp_bbox
6 | from .base import BaseTransform
7 |
8 |
9 |
10 | class ZoomIn(BaseTransform):
11 | def __init__(self,
12 | target_size=480,
13 | skip_clicks=1,
14 | expansion_ratio=1.4,
15 | min_crop_size=10,#200
16 | recompute_thresh_iou=0.5,
17 | prob_thresh=0.49):
18 | super().__init__()
19 | self.target_size = target_size
20 | self.min_crop_size = min_crop_size
21 | self.skip_clicks = skip_clicks
22 | self.expansion_ratio = expansion_ratio
23 | self.recompute_thresh_iou = recompute_thresh_iou
24 | self.prob_thresh = prob_thresh
25 |
26 | self._input_image_shape = None
27 | self._prev_probs = None
28 | self._object_roi = None
29 | self._roi_image = None
30 |
31 | def transform(self, image_nd, clicks_lists: List[List[Click]]):
32 | assert image_nd.shape[0] == 1 and len(clicks_lists) == 1
33 | self.image_changed = False
34 |
35 | clicks_list = clicks_lists[0]
36 | if len(clicks_list) <= self.skip_clicks:
37 | return image_nd, clicks_lists
38 |
39 | self._input_image_shape = image_nd.shape
40 |
41 | current_object_roi = None
42 | if self._prev_probs is not None:
43 | current_pred_mask = (self._prev_probs > self.prob_thresh)[0, 0]
44 | if current_pred_mask.sum() > 0:
45 | current_object_roi = get_object_roi(current_pred_mask, clicks_list,
46 | self.expansion_ratio, self.min_crop_size)
47 |             # empty prediction mask: leave current_object_roi as None
48 |
49 |
50 | if current_object_roi is None:
51 | if self.skip_clicks >= 0:
52 | return image_nd, clicks_lists
53 | else:
54 | current_object_roi = 0, image_nd.shape[2] - 1, 0, image_nd.shape[3] - 1
55 |
56 |         # NOTE: update_object_roi starts as True and every branch below re-sets it to True, so the ROI is recomputed on every call
57 | update_object_roi = True
58 | if self._object_roi is None:
59 | update_object_roi = True
60 | elif not check_object_roi(self._object_roi, clicks_list):
61 | update_object_roi = True
62 | elif get_bbox_iou(current_object_roi, self._object_roi) < self.recompute_thresh_iou:
63 | update_object_roi = True
64 |
65 | if update_object_roi:
66 | self._object_roi = current_object_roi
67 | self.image_changed = True
68 | self._roi_image = get_roi_image_nd(image_nd, self._object_roi, self.target_size)
69 |
70 | tclicks_lists = [self._transform_clicks(clicks_list)]
71 | return self._roi_image.to(image_nd.device), tclicks_lists
72 |
73 | def inv_transform(self, prob_map):
74 | if self._object_roi is None:
75 | self._prev_probs = prob_map.cpu().numpy()
76 | return prob_map
77 |
78 | assert prob_map.shape[0] == 1
79 | rmin, rmax, cmin, cmax = self._object_roi
80 | prob_map = torch.nn.functional.interpolate(prob_map, size=(rmax - rmin + 1, cmax - cmin + 1),
81 | mode='bilinear', align_corners=True)
82 |
83 |
84 |
85 | if self._prev_probs is not None:
86 | new_prob_map = torch.zeros(*self._prev_probs.shape, device=prob_map.device, dtype=prob_map.dtype)
87 | new_prob_map[:, :, rmin:rmax + 1, cmin:cmax + 1] = prob_map
88 | #new_prob_map[:, :, rmin:rmax, cmin:cmax] = prob_map
89 | else:
90 | new_prob_map = prob_map
91 |
92 | self._prev_probs = new_prob_map.cpu().numpy()
93 |
94 | return new_prob_map
95 |
96 | def check_possible_recalculation(self):
97 | if self._prev_probs is None or self._object_roi is not None or self.skip_clicks > 0:
98 | return False
99 |
100 | pred_mask = (self._prev_probs > self.prob_thresh)[0, 0]
101 | if pred_mask.sum() > 0:
102 | possible_object_roi = get_object_roi(pred_mask, [],
103 | self.expansion_ratio, self.min_crop_size)
104 | image_roi = (0, self._input_image_shape[2] - 1, 0, self._input_image_shape[3] - 1)
105 | if get_bbox_iou(possible_object_roi, image_roi) < 0.50:
106 | return True
107 | return False
108 |
109 | def get_state(self):
110 | roi_image = self._roi_image.cpu() if self._roi_image is not None else None
111 | return self._input_image_shape, self._object_roi, self._prev_probs, roi_image, self.image_changed
112 |
113 | def set_state(self, state):
114 | self._input_image_shape, self._object_roi, self._prev_probs, self._roi_image, self.image_changed = state
115 |
116 | def reset(self):
117 | self._input_image_shape = None
118 | self._object_roi = None
119 | self._prev_probs = None
120 | self._roi_image = None
121 | self.image_changed = False
122 |
123 | def _transform_clicks(self, clicks_list):
124 | if self._object_roi is None:
125 | return clicks_list
126 |
127 | rmin, rmax, cmin, cmax = self._object_roi
128 | crop_height, crop_width = self._roi_image.shape[2:]
129 |
130 | transformed_clicks = []
131 | for click in clicks_list:
132 | new_r = crop_height * (click.coords[0] - rmin) / (rmax - rmin + 1)
133 | new_c = crop_width * (click.coords[1] - cmin) / (cmax - cmin + 1)
134 | transformed_clicks.append(click.copy(coords=(new_r, new_c)))
135 | return transformed_clicks
136 |
137 |
138 | def get_object_roi(pred_mask, clicks_list, expansion_ratio, min_crop_size):
139 | pred_mask = pred_mask.copy()
140 |
141 | for click in clicks_list:
142 | if click.is_positive:
143 | pred_mask[int(click.coords[0]), int(click.coords[1])] = 1
144 |
145 | bbox = get_bbox_from_mask(pred_mask)
146 | bbox = expand_bbox(bbox, expansion_ratio, min_crop_size)
147 | h, w = pred_mask.shape[0], pred_mask.shape[1]
148 | bbox = clamp_bbox(bbox, 0, h - 1, 0, w - 1)
149 |
150 | return bbox
151 |
152 |
153 | def get_roi_image_nd(image_nd, object_roi, target_size):
154 | rmin, rmax, cmin, cmax = object_roi
155 |
156 | height = rmax - rmin + 1
157 | width = cmax - cmin + 1
158 |
159 | if isinstance(target_size, tuple):
160 | new_height, new_width = target_size
161 | else:
162 | scale = target_size / max(height, width)
163 | new_height = int(round(height * scale))
164 | new_width = int(round(width * scale))
165 |
166 | with torch.no_grad():
167 | roi_image_nd = image_nd[:, :, rmin:rmax + 1, cmin:cmax + 1]
168 | #roi_image_nd = torch.nn.functional.interpolate(roi_image_nd, size=(new_height, new_width),
169 | # mode='bilinear', align_corners=True)
170 |
171 | return roi_image_nd
172 |
173 |
174 | def check_object_roi(object_roi, clicks_list):
175 | for click in clicks_list:
176 | if click.is_positive:
177 | if click.coords[0] < object_roi[0] or click.coords[0] >= object_roi[1]:
178 | return False
179 | if click.coords[1] < object_roi[2] or click.coords[1] >= object_roi[3]:
180 | return False
181 |
182 | return True
183 |
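184 |
185 | # Sketch of ROI computation on a synthetic mask (illustrative): the tight
186 | # bounding box of the mask is expanded by expansion_ratio and clamped to the
187 | # image, as transform() does with the thresholded previous prediction.
188 | if __name__ == '__main__':
189 |     import numpy as np
190 |     mask = np.zeros((100, 100), dtype=bool)
191 |     mask[40:60, 30:70] = True
192 |     print(get_object_roi(mask, [], expansion_ratio=1.4, min_crop_size=10))
193 |     # -> (rmin, rmax, cmin, cmax) around the 20x40 object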
--------------------------------------------------------------------------------
/isegm/inference/utils.py:
--------------------------------------------------------------------------------
1 | from datetime import timedelta
2 | from pathlib import Path
3 |
4 | import torch
5 | import numpy as np
6 |
7 | from isegm.data.datasets import GrabCutDataset, BerkeleyDataset, DavisDataset, SBDEvaluationDataset
8 |
9 | from isegm.utils.serialization import load_model
10 |
11 |
12 | def get_time_metrics(all_ious, elapsed_time):
13 | n_images = len(all_ious)
14 | n_clicks = sum(map(len, all_ious))
15 |
16 | mean_spc = elapsed_time / n_clicks
17 | mean_spi = elapsed_time / n_images
18 |
19 | return mean_spc, mean_spi
20 |
21 |
22 | def load_is_model(checkpoint, device, **kwargs):
23 | if isinstance(checkpoint, (str, Path)):
24 | state_dict = torch.load(checkpoint, map_location='cpu')
25 | else:
26 | state_dict = checkpoint
27 |
28 |     if isinstance(state_dict, list):
29 |         # load each checkpoint once; the first model doubles as the default
30 |         models = [load_single_is_model(x, device, **kwargs) for x in state_dict]
31 |
32 |         return models[0], models
33 |     else:
34 |         return load_single_is_model(state_dict, device, **kwargs)
35 |
36 |
37 | def load_single_is_model(state_dict, device, **kwargs):
38 | model = load_model(state_dict['config'], **kwargs)
39 | model.load_state_dict(state_dict['state_dict'], strict=False)
40 |
41 | for param in model.parameters():
42 | param.requires_grad = False
43 | model.to(device)
44 | model.eval()
45 |
46 | return model
47 |
48 |
49 | def get_dataset(dataset_name, cfg):
50 | if dataset_name == 'GrabCut':
51 | dataset = GrabCutDataset(cfg.GRABCUT_PATH)
52 | elif dataset_name == 'Berkeley':
53 | dataset = BerkeleyDataset(cfg.BERKELEY_PATH)
54 | elif dataset_name == 'DAVIS':
55 | dataset = DavisDataset(cfg.DAVIS_PATH)
56 | elif dataset_name == 'SBD':
57 | dataset = SBDEvaluationDataset(cfg.SBD_PATH)
58 | elif dataset_name == 'SBD_Train':
59 | dataset = SBDEvaluationDataset(cfg.SBD_PATH, split='train')
60 | else:
61 | dataset = None
62 | return dataset
63 |
64 |
65 | def get_iou(gt_mask, pred_mask, ignore_label=-1):
66 | ignore_gt_mask_inv = gt_mask != ignore_label
67 | obj_gt_mask = gt_mask == 1
68 |
69 | intersection = np.logical_and(np.logical_and(pred_mask, obj_gt_mask), ignore_gt_mask_inv).sum()
70 | union = np.logical_and(np.logical_or(pred_mask, obj_gt_mask), ignore_gt_mask_inv).sum()
71 |
72 | return intersection / union
73 |
74 |
75 | def compute_noc_metric(all_ious, iou_thrs, max_clicks=20):
76 | def _get_noc(iou_arr, iou_thr):
77 | vals = iou_arr >= iou_thr
78 | return np.argmax(vals) + 1 if np.any(vals) else max_clicks
79 |
80 | noc_list = []
81 | over_max_list = []
82 | for iou_thr in iou_thrs:
83 | scores_arr = np.array([_get_noc(iou_arr, iou_thr)
84 |                                for iou_arr in all_ious], dtype=int)  # np.int was removed from NumPy
85 |
86 | score = scores_arr.mean()
87 | over_max = (scores_arr == max_clicks).sum()
88 |
89 | noc_list.append(score)
90 | over_max_list.append(over_max)
91 |
92 | return noc_list, over_max_list
93 |
94 |
95 | def find_checkpoint(weights_folder, checkpoint_name):
96 | weights_folder = Path(weights_folder)
97 | if ':' in checkpoint_name:
98 | model_name, checkpoint_name = checkpoint_name.split(':')
99 | models_candidates = [x for x in weights_folder.glob(f'{model_name}*') if x.is_dir()]
100 | assert len(models_candidates) == 1
101 | model_folder = models_candidates[0]
102 | else:
103 | model_folder = weights_folder
104 |
105 | if checkpoint_name.endswith('.pth'):
106 | if Path(checkpoint_name).exists():
107 | checkpoint_path = checkpoint_name
108 | else:
109 | checkpoint_path = weights_folder / checkpoint_name
110 | else:
111 | model_checkpoints = list(model_folder.rglob(f'{checkpoint_name}*.pth'))
112 | assert len(model_checkpoints) == 1
113 | checkpoint_path = model_checkpoints[0]
114 |
115 | return str(checkpoint_path)
116 |
117 |
118 | def get_results_table(noc_list, over_max_list, brs_type, dataset_name, mean_spc, elapsed_time,
119 | n_clicks=20, model_name=None):
120 | table_header = (f'|{"Pipeline":^13}|{"Dataset":^11}|'
121 | f'{"NoC@80%":^9}|{"NoC@85%":^9}|{"NoC@90%":^9}|'
122 | f'{">="+str(n_clicks)+"@85%":^9}|{">="+str(n_clicks)+"@90%":^9}|'
123 | f'{"SPC,s":^7}|{"Time":^9}|')
124 | row_width = len(table_header)
125 |
126 | header = f'Eval results for model: {model_name}\n' if model_name is not None else ''
127 | header += '-' * row_width + '\n'
128 | header += table_header + '\n' + '-' * row_width
129 |
130 | eval_time = str(timedelta(seconds=int(elapsed_time)))
131 | table_row = f'|{brs_type:^13}|{dataset_name:^11}|'
132 | table_row += f'{noc_list[0]:^9.2f}|'
133 | table_row += f'{noc_list[1]:^9.2f}|' if len(noc_list) > 1 else f'{"?":^9}|'
134 | table_row += f'{noc_list[2]:^9.2f}|' if len(noc_list) > 2 else f'{"?":^9}|'
135 | table_row += f'{over_max_list[1]:^9}|' if len(noc_list) > 1 else f'{"?":^9}|'
136 | table_row += f'{over_max_list[2]:^9}|' if len(noc_list) > 2 else f'{"?":^9}|'
137 | table_row += f'{mean_spc:^7.3f}|{eval_time:^9}|'
138 |
139 | return header, table_row
140 |
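141 |
142 | # Quick check of compute_noc_metric (illustrative): per-sample IoU curves are
143 | # converted into the number of clicks needed to reach each IoU threshold.
144 | if __name__ == '__main__':
145 |     ious = [np.array([0.70, 0.85, 0.92]),   # reaches 0.90 at click 3
146 |             np.array([0.91])]               # reaches 0.90 at click 1
147 |     noc, over_max = compute_noc_metric(ious, iou_thrs=[0.80, 0.85, 0.90])
148 |     print(noc)                              # [1.5, 1.5, 2.0]
149 |     print(over_max)                         # [0, 0, 0]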
--------------------------------------------------------------------------------
/isegm/model/initializer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 |
5 |
6 | class Initializer(object):
7 | def __init__(self, local_init=True, gamma=None):
8 | self.local_init = local_init
9 | self.gamma = gamma
10 |
11 | def __call__(self, m):
12 | if getattr(m, '__initialized', False):
13 | return
14 |
15 | if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d,
16 | nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d,
17 | nn.GroupNorm, nn.SyncBatchNorm)) or 'BatchNorm' in m.__class__.__name__:
18 | if m.weight is not None:
19 | self._init_gamma(m.weight.data)
20 | if m.bias is not None:
21 | self._init_beta(m.bias.data)
22 | else:
23 | if getattr(m, 'weight', None) is not None:
24 | self._init_weight(m.weight.data)
25 | if getattr(m, 'bias', None) is not None:
26 | self._init_bias(m.bias.data)
27 |
28 | if self.local_init:
29 | object.__setattr__(m, '__initialized', True)
30 |
31 | def _init_weight(self, data):
32 | nn.init.uniform_(data, -0.07, 0.07)
33 |
34 | def _init_bias(self, data):
35 | nn.init.constant_(data, 0)
36 |
37 | def _init_gamma(self, data):
38 | if self.gamma is None:
39 | nn.init.constant_(data, 1.0)
40 | else:
41 | nn.init.normal_(data, 1.0, self.gamma)
42 |
43 | def _init_beta(self, data):
44 | nn.init.constant_(data, 0)
45 |
46 |
47 | class Bilinear(Initializer):
48 | def __init__(self, scale, groups, in_channels, **kwargs):
49 | super().__init__(**kwargs)
50 | self.scale = scale
51 | self.groups = groups
52 | self.in_channels = in_channels
53 |
54 | def _init_weight(self, data):
55 |         """Fill the weight with a bilinear upsampling kernel."""
56 | bilinear_kernel = self.get_bilinear_kernel(self.scale)
57 | weight = torch.zeros_like(data)
58 | for i in range(self.in_channels):
59 | if self.groups == 1:
60 | j = i
61 | else:
62 | j = 0
63 | weight[i, j] = bilinear_kernel
64 | data[:] = weight
65 |
66 | @staticmethod
67 | def get_bilinear_kernel(scale):
68 | """Generate a bilinear upsampling kernel."""
69 | kernel_size = 2 * scale - scale % 2
70 | scale = (kernel_size + 1) // 2
71 | center = scale - 0.5 * (1 + kernel_size % 2)
72 |
73 | og = np.ogrid[:kernel_size, :kernel_size]
74 | kernel = (1 - np.abs(og[0] - center) / scale) * (1 - np.abs(og[1] - center) / scale)
75 |
76 | return torch.tensor(kernel, dtype=torch.float32)
77 |
78 |
79 | class XavierGluon(Initializer):
80 | def __init__(self, rnd_type='uniform', factor_type='avg', magnitude=3, **kwargs):
81 | super().__init__(**kwargs)
82 |
83 | self.rnd_type = rnd_type
84 | self.factor_type = factor_type
85 | self.magnitude = float(magnitude)
86 |
87 | def _init_weight(self, arr):
88 | fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(arr)
89 |
90 | if self.factor_type == 'avg':
91 | factor = (fan_in + fan_out) / 2.0
92 | elif self.factor_type == 'in':
93 | factor = fan_in
94 | elif self.factor_type == 'out':
95 | factor = fan_out
96 | else:
97 | raise ValueError('Incorrect factor type')
98 | scale = np.sqrt(self.magnitude / factor)
99 |
100 | if self.rnd_type == 'uniform':
101 | nn.init.uniform_(arr, -scale, scale)
102 | elif self.rnd_type == 'gaussian':
103 | nn.init.normal_(arr, 0, scale)
104 | else:
105 | raise ValueError('Unknown random type')
106 |
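107 |
108 | # Usage sketch (illustrative): initializers are applied via Module.apply, so
109 | # each submodule is visited once; BatchNorm layers get gamma/beta init and
110 | # everything else gets the weight/bias init.
111 | if __name__ == '__main__':
112 |     net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
113 |     net.apply(XavierGluon(rnd_type='uniform', factor_type='in', magnitude=2.0))
114 |     print(net[0].weight.abs().max().item())  # bounded by sqrt(2 / fan_in)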
--------------------------------------------------------------------------------
/isegm/model/is_gp_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 |
6 | from isegm.model.ops import DistMaps, BatchImageNormalize
7 | from einops import rearrange, repeat
8 | from opt_einsum import contract
9 | import math
10 |
11 | class ISGPModel(nn.Module):
12 | def __init__(self, use_rgb_conv=False, feature_stride = 4, with_aux_output=False,
13 | norm_radius=260, use_disks=False, cpu_dist_maps=False,
14 | clicks_groups=None, with_prev_mask=False, use_leaky_relu=False,
15 | binary_prev_mask=False, conv_extend=False, norm_layer=nn.BatchNorm2d,
16 | norm_mean_std=([.485, .456, .406], [.229, .224, .225])):
17 | super().__init__()
18 | self.with_aux_output = with_aux_output
19 | self.clicks_groups = clicks_groups
20 | self.with_prev_mask = with_prev_mask
21 | self.binary_prev_mask = binary_prev_mask
22 | self.normalization = BatchImageNormalize(norm_mean_std[0], norm_mean_std[1])
23 | self.dist_maps = DistMaps(norm_radius=5, spatial_scale=1.0,
24 | cpu_mode=False, use_disks=True)
25 |
26 |
27 | def prepare_input(self, image):
28 | prev_mask = None
29 | if self.with_prev_mask:
30 | prev_mask = image[:, 3:, :, :]
31 | image = image[:, :3, :, :]
32 | if self.binary_prev_mask:
33 | prev_mask = (prev_mask > 0.5).float()
34 |
35 | image = self.normalization(image)
36 | return image, prev_mask
37 |
38 | def get_coord_features(self, image, prev_mask, points):
39 | coord_features = self.dist_maps(image, points)
40 | if prev_mask is not None:
41 | coord_features = torch.cat((prev_mask, coord_features), dim=1)
42 | return coord_features
43 |
44 | def load_pretrained_weights(self, path_to_weights= ''):
45 | state_dict = self.state_dict()
46 | pretrained_state_dict = torch.load(path_to_weights, map_location='cpu')['state_dict']
47 | ckpt_keys = set(pretrained_state_dict.keys())
48 | own_keys = set(state_dict.keys())
49 | missing_keys = own_keys - ckpt_keys
50 | unexpected_keys = ckpt_keys - own_keys
51 | print('Missing Keys: ', missing_keys)
52 | print('Unexpected Keys: ', unexpected_keys)
53 | state_dict.update(pretrained_state_dict)
54 | self.load_state_dict(state_dict, strict= False)
55 | '''
56 | if self.inference_mode:
57 | for param in self.backbone.parameters():
58 | param.requires_grad = False
59 | '''
60 |
67 |     def prepare_points_labels(self, points, feature):
68 | pss = []
69 | label_list = []
70 | point_labels = torch.ones([points.size(1),1], dtype=torch.float32, device=feature.device)
71 | point_labels[points.size(1)//2:,:] = -1.
72 | for i in range(points.size(0)):
73 | ps, _ = torch.split(points[i],[2,1],dim=1)
74 | valid_points = torch.logical_and(torch.logical_and(torch.min(ps, dim=1, keepdim=False)[0] >= 0,
75 | ps[:,0] < feature.size(2)), ps[:,1] < feature.size(3) )
76 | ps = ps[valid_points] # n, 2
77 | pss.append(ps)
78 | label = point_labels[valid_points,:] #n,1
79 | label_list.append(label)
80 | return pss, label_list
81 |
82 | def Pathwise_GP_prior(self, feature, omega):
83 | b,d,h,w = feature.size()
84 | phi_f = math.sqrt(2/self.L)*torch.sin(rearrange(self.theta(rearrange(feature, 'b d h w -> (b h w) d')), '(b h w) d->b d h w',b=b,h=h,w=w))
85 | prior = contract('blhw,ls->bshw',phi_f,omega) # b,1,h,w
86 | return prior
87 |
88 | def Pathwise_GP_update(self, points, feature,pss,label_list,result,omega):
89 | b,d,h,w = feature.size()
90 | inv_Kmm_list = []
91 | zf_list = []
92 | point_nums = []
93 | weight = F.softplus(self.weights)
94 |
95 | for i in range(points.size(0)):
96 | ps = pss[i]
97 | if ps.size(0)==0:
98 | point_nums.append(0)
99 | continue
100 | ps = torch.cat([ps[:,[0]].clamp(min=0., max=feature.size(2)-1),ps[:,[1]].clamp(min=0., max=feature.size(3)-1)],1)
101 |
102 | point_nums.append(ps.size(0))
103 | zf = feature[i,:,ps[:,0].long(),ps[:,1].long()].T #n,d
104 | zf_list.append(zf)
105 | norm = torch.norm(torch.exp(self.logsigma2/2)*zf[:,:-3], dim=1,p=2)**2/2 # n,
106 | Kmm = torch.exp(contract('nd,md,d->nm',zf[:,:-3],zf[:,:-3],torch.exp(self.logsigma2))-\
107 | norm.unsqueeze(0).repeat(ps.size(0),1)-norm.unsqueeze(1).repeat(1,ps.size(0)))+\
108 | weight*torch.exp(-torch.sum((zf[:,-3:].unsqueeze(1)-zf[:,-3:])**2,2)/2)
109 |
110 | inv_Kmm_list.append(torch.inverse(Kmm+self.eps2*torch.eye(Kmm.size(0),device=Kmm.device)))
111 |
112 | inv_Kmm = torch.block_diag(*inv_Kmm_list) #n,n
113 | zf = torch.cat(zf_list,dim=0) # n,d
114 | label = torch.cat(label_list,dim=0) # n,1
115 | m = F.softplus(self.u_mlp(zf))*label #n,1
116 |
117 | if self.training:
118 | u = m + 0.01*torch.randn(m.size()).to(feature.device)
119 | u_loss = self.u_loss(inv_Kmm.detach(),m,u,label/2+0.5)
120 | else:
121 | u = m
122 | u_loss = torch.tensor([0.],device=feature.device)
123 |
124 | phi = math.sqrt(2/self.L)*torch.sin(self.theta(zf))
125 |
126 | phi_omega = torch.matmul(phi,omega) # n,1
127 |
128 | v = torch.matmul(inv_Kmm, u-phi_omega) # n,1
129 | num_prev = 0
130 | offset = 0
131 | weight = F.softplus(self.weights)
132 | for i in range(points.size(0)):
133 | if point_nums[i]==0:
134 | offset+=1
135 | continue
136 | norm1 = torch.norm(torch.exp(self.logsigma2/2).view(self.feature_dim,1,1)*feature[i,:-3], dim=0,p=2)**2/2 #h w
137 | norm2 = torch.norm(torch.exp(self.logsigma2/2)*zf_list[i-offset][:,:-3], dim=1,p=2)**2/2 # n,
138 | norm_rgb1 = torch.norm(feature[i,-3:], dim=0,p=2)**2/2 #h w
139 | norm_rgb2 = torch.norm(zf_list[i-offset][:,-3:], dim=1,p=2)**2/2 # n,
140 | Knm = torch.exp(contract('dhw,nd,d->nhw',feature[i,:-3],zf_list[i-offset][:,:-3],torch.exp(self.logsigma2)) -\
141 | repeat(norm1, 'h w -> n h w',n=point_nums[i]) - repeat(norm2, 'n -> n h w',h=h, w=w)) + \
142 | weight*torch.exp(contract('dhw,nd->nhw',feature[i,-3:],zf_list[i-offset][:,-3:])-\
143 | repeat(norm_rgb1, 'h w -> n h w',n=point_nums[i]) - repeat(norm_rgb2, 'n -> n h w',h=h, w=w))
144 | result[i,...] += contract('nhw,ns->shw',Knm, v[num_prev:num_prev+point_nums[i]])
145 | num_prev += point_nums[i]
146 | return result, 0.001*u_loss
147 |
148 | def u_loss(self,invK, m, u, y):
149 | n = invK.size(0)
150 | loss = F.binary_cross_entropy_with_logits(u,y)+ torch.matmul(torch.matmul(m.T,invK),m)/n
151 | return loss
152 |
153 | def split_points_by_order(tpoints: torch.Tensor, groups):
154 | points = tpoints.cpu().numpy()
155 | num_groups = len(groups)
156 | bs = points.shape[0]
157 | num_points = points.shape[1] // 2
158 |
159 | groups = [x if x > 0 else num_points for x in groups]
160 | group_points = [np.full((bs, 2 * x, 3), -1, dtype=np.float32)
161 | for x in groups]
162 |
163 |     last_point_indx_group = np.zeros((bs, num_groups, 2), dtype=int)  # np.int was removed from NumPy
164 | for group_indx, group_size in enumerate(groups):
165 | last_point_indx_group[:, group_indx, 1] = group_size
166 |
167 | for bindx in range(bs):
168 | for pindx in range(2 * num_points):
169 | point = points[bindx, pindx, :]
170 | group_id = int(point[2])
171 | if group_id < 0:
172 | continue
173 |
174 | is_negative = int(pindx >= num_points)
175 | if group_id >= num_groups or (group_id == 0 and is_negative): # disable negative first click
176 | group_id = num_groups - 1
177 |
178 | new_point_indx = last_point_indx_group[bindx, group_id, is_negative]
179 | last_point_indx_group[bindx, group_id, is_negative] += 1
180 |
181 | group_points[group_id][bindx, new_point_indx, :] = point
182 |
183 | group_points = [torch.tensor(x, dtype=tpoints.dtype, device=tpoints.device)
184 | for x in group_points]
185 |
186 | return group_points
187 |
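188 |
189 | # Pathwise GP sampling sketch (illustrative, standalone). Pathwise_GP_prior and
190 | # Pathwise_GP_update above follow Matheron's rule,
191 | #   f(x) ~ prior(x) + K(x, Z) K(Z, Z)^{-1} (u - prior(Z)),
192 | # with the prior approximated by L random features. Below is a minimal NumPy
193 | # analogue with a feature-map kernel; all names are hypothetical:
194 | if __name__ == '__main__':
195 |     rng = np.random.default_rng(0)
196 |     L, d = 64, 4
197 |     W = rng.normal(size=(d, L))                       # random feature frequencies
198 |     phi = lambda X: np.sqrt(2.0 / L) * np.sin(X @ W)  # feature map, as in Pathwise_GP_prior
199 |     X = rng.normal(size=(100, d))                     # query features
200 |     Z, u = X[:5], np.ones((5, 1))                     # "clicked" points and their targets
201 |     omega = rng.normal(size=(L, 1))                   # prior weights
202 |     K = phi(Z) @ phi(Z).T + 1e-2 * np.eye(5)          # jittered kernel at the clicks
203 |     post = phi(X) @ omega + phi(X) @ phi(Z).T @ np.linalg.solve(K, u - phi(Z) @ omega)
204 |     print(post[:5].ravel().round(2))                  # ~1.0 at the clicked points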
--------------------------------------------------------------------------------
/isegm/model/is_gp_resnet50.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | import torch.nn.functional as F
4 | from isegm.utils.serialization import serialize
5 | from .is_gp_model import ISGPModel
6 | from isegm.model.ops import ScaleLayer
7 | from .modeling.deeplab_v3_gp import DeepLabV3Plus
8 | from isegm.model.modifiers import LRMult
9 |
10 |
11 | class GpModel(ISGPModel):
12 | @serialize
13 | def __init__(self, backbone='resnet50', deeplab_ch=256, aspp_dropout=0.,
14 | backbone_norm_layer=None, backbone_lr_mult=0.1,
15 | norm_layer=nn.BatchNorm2d, weight_dir=None, **kwargs):
16 | super().__init__(norm_layer=norm_layer, **kwargs)
17 |
18 | self.model = DeepLabV3Plus(backbone=backbone, ch=deeplab_ch,
19 | project_dropout=aspp_dropout, norm_layer=norm_layer,
20 | backbone_norm_layer=backbone_norm_layer, weight_dir=weight_dir)
21 |
22 | side_feature_ch = 256
23 |
24 | self.model.apply(LRMult(backbone_lr_mult))
25 |
26 |
27 | mt_layers = [
28 | nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=2, padding=1),
29 | nn.LeakyReLU(negative_slope=0.2),
30 | nn.Conv2d(in_channels=16, out_channels=side_feature_ch, kernel_size=3, stride=1, padding=1),
31 | ScaleLayer(init_value=0.05, lr_mult=1)
32 | ]
33 | self.maps_transform = nn.Sequential(*mt_layers)
34 | self.L=256
35 | self.feature_dim = 48
36 | self.theta = nn.Linear(self.feature_dim+3,self.L)
37 | omega = 0.25*torch.randn(self.L,1)
38 | self.omega = nn.Parameter(omega, requires_grad=True)
39 | omega_var = torch.tensor(0.025)
40 | self.omega_var = nn.Parameter(omega_var, requires_grad=True)
41 |
42 | logsigma2 = torch.ones(self.feature_dim)
43 | self.logsigma2 = nn.Parameter(logsigma2, requires_grad=True)
44 | self.u_mlp = nn.Sequential(
45 | nn.Linear(self.feature_dim+3,96),
46 | nn.ReLU(True),
47 | nn.Linear(96,1)
48 | )
49 |
50 | weight = torch.zeros(1)
51 | self.weights = nn.Parameter(weight, requires_grad=True)
52 | self.eps2 = 1e-2
53 |
54 | def set_status(self, training):
55 | if training:
56 | self.eps2=1e-2
57 | else:
58 | self.eps2=1e-7
59 |
60 | def forward(self, image, points):
61 | image, prev_mask = self.prepare_input(image)
62 | coord_features = self.get_coord_features(image, prev_mask, points)
63 | coord_features = self.maps_transform(coord_features)
64 |
65 | feature = self.model(image, coord_features)
66 | feature = F.normalize(feature, dim=1)
67 |
68 | feature = nn.functional.interpolate(feature, size=image.size()[2:],
69 | mode='bilinear', align_corners=True)
70 | feature = torch.cat([feature, image],1)
71 |
72 | pss, label_list = self.prepare_points_labels(points,feature)
73 | if self.training:
74 | omega = self.omega+self.omega_var.clamp(min=0.01,max=0.05)*torch.randn(self.L,1).to(feature.device)
75 | else:
76 | omega = self.omega
77 |         prior = self.Pathwise_GP_prior(feature, omega)
78 |         out, u_loss = self.Pathwise_GP_update(points, feature, pss, label_list, prior, omega)
79 | outputs = {'instances': out, 'u_loss':u_loss}
80 | return outputs
81 |
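82 |
83 | # Forward-pass sketch (illustrative; kept as comments because constructing the
84 | # model pulls in the DeepLabV3Plus backbone and its pretrained weights):
85 | #   model = GpModel(backbone='resnet50', with_prev_mask=True)
86 | #   image = torch.rand(1, 4, 256, 256)             # RGB + previous-mask channel
87 | #   points = torch.full((1, 2, 3), -1.0)           # rows are (y, x, index); -1 pads
88 | #   points[0, 0] = torch.tensor([128., 128., 0.])  # one positive click (first half)
89 | #   out = model(image, points)                     # {'instances': logits, 'u_loss': ...}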
--------------------------------------------------------------------------------
/isegm/model/losses.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from isegm.utils import misc
7 |
8 |
9 | class NormalizedFocalLossSigmoid(nn.Module):
10 | def __init__(self, axis=-1, alpha=0.25, gamma=2, max_mult=-1, eps=1e-12,
11 | from_sigmoid=False, detach_delimeter=True,
12 | batch_axis=0, weight=None, size_average=True,
13 | ignore_label=-1):
14 | super(NormalizedFocalLossSigmoid, self).__init__()
15 | self._axis = axis
16 | self._alpha = alpha
17 | self._gamma = gamma
18 | self._ignore_label = ignore_label
19 | self._weight = weight if weight is not None else 1.0
20 | self._batch_axis = batch_axis
21 |
22 | self._from_logits = from_sigmoid
23 | self._eps = eps
24 | self._size_average = size_average
25 | self._detach_delimeter = detach_delimeter
26 | self._max_mult = max_mult
27 | self._k_sum = 0
28 | self._m_max = 0
29 |
30 | def forward(self, pred, label):
31 | #print(pred.shape, label.shape)
32 | pred = pred.float()
33 | one_hot = label > 0.5
34 | sample_weight = label != self._ignore_label
35 |
36 | if not self._from_logits:
37 | pred = torch.sigmoid(pred)
38 |
39 | alpha = torch.where(one_hot, self._alpha * sample_weight, (1 - self._alpha) * sample_weight)
40 | pt = torch.where(sample_weight, 1.0 - torch.abs(label - pred), torch.ones_like(pred))
41 |
42 | beta = (1 - pt) ** self._gamma
43 |
44 | sw_sum = torch.sum(sample_weight, dim=(-2, -1), keepdim=True)
45 | beta_sum = torch.sum(beta, dim=(-2, -1), keepdim=True)
46 | mult = sw_sum / (beta_sum + self._eps)
47 | if self._detach_delimeter:
48 | mult = mult.detach()
49 | beta = beta * mult
50 | if self._max_mult > 0:
51 | beta = torch.clamp_max(beta, self._max_mult)
52 |
53 | with torch.no_grad():
54 | ignore_area = torch.sum(label == self._ignore_label, dim=tuple(range(1, label.dim()))).cpu().numpy()
55 | sample_mult = torch.mean(mult, dim=tuple(range(1, mult.dim()))).cpu().numpy()
56 | if np.any(ignore_area == 0):
57 | self._k_sum = 0.9 * self._k_sum + 0.1 * sample_mult[ignore_area == 0].mean()
58 |
59 | beta_pmax, _ = torch.flatten(beta, start_dim=1).max(dim=1)
60 | beta_pmax = beta_pmax.mean().item()
61 | self._m_max = 0.8 * self._m_max + 0.2 * beta_pmax
62 |
63 | loss = -alpha * beta * torch.log(torch.min(pt + self._eps, torch.ones(1, dtype=torch.float).to(pt.device)))
64 | loss = self._weight * (loss * sample_weight)
65 |
66 | if self._size_average:
67 | bsum = torch.sum(sample_weight, dim=misc.get_dims_with_exclusion(sample_weight.dim(), self._batch_axis))
68 | loss = torch.sum(loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis)) / (bsum + self._eps)
69 | else:
70 | loss = torch.sum(loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis))
71 |
72 | return loss
73 |
74 | def log_states(self, sw, name, global_step):
75 | sw.add_scalar(tag=name + '_k', value=self._k_sum, global_step=global_step)
76 | sw.add_scalar(tag=name + '_m', value=self._m_max, global_step=global_step)
77 |
78 |
79 |
80 | class DiversityLoss(nn.Module):
81 | def __init__(self):
82 | super(DiversityLoss, self).__init__()
83 | self.baseloss = NormalizedFocalLossSigmoid(alpha=0.5, gamma=2)
84 | self.click_loss = ClickLoss() #WFNL(alpha=0.5, gamma=2, w = 0.99)
85 |
86 |
87 | def forward(self, latent_preds, label, click_map):
88 | div_loss_lst = []
89 | click_loss = 0
90 | for i in range(latent_preds.shape[1]):
91 | single_pred = latent_preds[:,i,:,:].unsqueeze(1)
92 | single_loss = self.baseloss(single_pred,label)
93 | single_loss = single_loss.unsqueeze(-1)
94 | div_loss_lst.append(single_loss)
95 | click_loss += self.click_loss(single_pred,label,click_map)
96 |
97 | div_losses = torch.cat(div_loss_lst,1)
98 | div_loss_min = torch.min(div_losses,dim=1)[0]
99 | return div_loss_min.mean() + click_loss.mean()
100 |
101 |
102 |
103 | class WFNL(nn.Module):
104 | def __init__(self, axis=-1, alpha=0.25, gamma=2, w = 0.5, max_mult=-1, eps=1e-12,
105 | from_sigmoid=False, detach_delimeter=True,
106 | batch_axis=0, weight=None, size_average=True,
107 | ignore_label=-1):
108 | super(WFNL, self).__init__()
109 | self._axis = axis
110 | self._alpha = alpha
111 | self._gamma = gamma
112 | self._ignore_label = ignore_label
113 | self._weight = weight if weight is not None else 1.0
114 | self._batch_axis = batch_axis
115 |
116 | self._from_logits = from_sigmoid
117 | self._eps = eps
118 | self._size_average = size_average
119 | self._detach_delimeter = detach_delimeter
120 | self._max_mult = max_mult
121 | self._k_sum = 0
122 | self._m_max = 0
123 | self.w = w
124 |
125 | def forward(self, pred, label, weight = None):
126 | one_hot = label > 0.5
127 | sample_weight = label != self._ignore_label
128 |
129 | if not self._from_logits:
130 | pred = torch.sigmoid(pred)
131 |
132 | alpha = torch.where(one_hot, self._alpha * sample_weight, (1 - self._alpha) * sample_weight)
133 | pt = torch.where(sample_weight, 1.0 - torch.abs(label - pred), torch.ones_like(pred))
134 |
135 | beta = (1 - pt) ** self._gamma
136 |
137 | sw_sum = torch.sum(sample_weight, dim=(-2, -1), keepdim=True)
138 | beta_sum = torch.sum(beta, dim=(-2, -1), keepdim=True)
139 | mult = sw_sum / (beta_sum + self._eps)
140 | if self._detach_delimeter:
141 | mult = mult.detach()
142 | beta = beta * mult
143 | if self._max_mult > 0:
144 | beta = torch.clamp_max(beta, self._max_mult)
145 |
146 | with torch.no_grad():
147 | ignore_area = torch.sum(label == self._ignore_label, dim=tuple(range(1, label.dim()))).cpu().numpy()
148 | sample_mult = torch.mean(mult, dim=tuple(range(1, mult.dim()))).cpu().numpy()
149 | if np.any(ignore_area == 0):
150 | self._k_sum = 0.9 * self._k_sum + 0.1 * sample_mult[ignore_area == 0].mean()
151 |
152 | beta_pmax, _ = torch.flatten(beta, start_dim=1).max(dim=1)
153 | beta_pmax = beta_pmax.mean().item()
154 | self._m_max = 0.8 * self._m_max + 0.2 * beta_pmax
155 |
156 | loss = -alpha * beta * torch.log(torch.min(pt + self._eps, torch.ones(1, dtype=torch.float).to(pt.device)))
157 | loss = self._weight * (loss * sample_weight)
158 |
159 | if weight is not None:
160 | weight = weight * self.w + (1-self.w)
161 | loss = (loss * weight).sum() / (weight.sum() + self._eps)
162 | else:
163 | bsum = torch.sum(sample_weight, dim=misc.get_dims_with_exclusion(sample_weight.dim(), self._batch_axis))
164 | loss = torch.sum(loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis)) / (bsum + self._eps)
165 | return loss
166 |
167 | def log_states(self, sw, name, global_step):
168 | sw.add_scalar(tag=name + '_k', value=self._k_sum, global_step=global_step)
169 | sw.add_scalar(tag=name + '_m', value=self._m_max, global_step=global_step)
170 |
171 |
172 | class FocalLoss(nn.Module):
173 | def __init__(self, axis=-1, alpha=0.25, gamma=2,
174 | from_logits=False, batch_axis=0,
175 | weight=None, num_class=None,
176 | eps=1e-9, size_average=True, scale=1.0,
177 | ignore_label=-1):
178 | super(FocalLoss, self).__init__()
179 | self._axis = axis
180 | self._alpha = alpha
181 | self._gamma = gamma
182 | self._ignore_label = ignore_label
183 | self._weight = weight if weight is not None else 1.0
184 | self._batch_axis = batch_axis
185 |
186 | self._scale = scale
187 | self._num_class = num_class
188 | self._from_logits = from_logits
189 | self._eps = eps
190 | self._size_average = size_average
191 |
192 | def forward(self, pred, label, sample_weight=None):
193 | one_hot = label > 0.5
194 | sample_weight = label != self._ignore_label
195 |
196 | if not self._from_logits:
197 | pred = torch.sigmoid(pred)
198 |
199 | alpha = torch.where(one_hot, self._alpha * sample_weight, (1 - self._alpha) * sample_weight)
200 | pt = torch.where(sample_weight, 1.0 - torch.abs(label - pred), torch.ones_like(pred))
201 |
202 | beta = (1 - pt) ** self._gamma
203 |
204 | loss = -alpha * beta * torch.log(torch.min(pt + self._eps, torch.ones(1, dtype=torch.float).to(pt.device)))
205 | loss = self._weight * (loss * sample_weight)
206 |
207 | if self._size_average:
208 | tsum = torch.sum(sample_weight, dim=misc.get_dims_with_exclusion(label.dim(), self._batch_axis))
209 | loss = torch.sum(loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis)) / (tsum + self._eps)
210 | else:
211 | loss = torch.sum(loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis))
212 |
213 | return self._scale * loss
214 |
215 |
216 | class SoftIoU(nn.Module):
217 | def __init__(self, from_sigmoid=False, ignore_label=-1):
218 | super().__init__()
219 | self._from_sigmoid = from_sigmoid
220 | self._ignore_label = ignore_label
221 |
222 | def forward(self, pred, label):
223 | label = label.view(pred.size())
224 | sample_weight = label != self._ignore_label
225 |
226 | if not self._from_sigmoid:
227 | pred = torch.sigmoid(pred)
228 |
229 | loss = 1.0 - torch.sum(pred * label * sample_weight, dim=(1, 2, 3)) \
230 | / (torch.sum(torch.max(pred, label) * sample_weight, dim=(1, 2, 3)) + 1e-8)
231 |
232 | return loss
233 |
234 |
235 | class SigmoidBinaryCrossEntropyLoss(nn.Module):
236 | def __init__(self, from_sigmoid=False, weight=None, batch_axis=0, ignore_label=-1):
237 | super(SigmoidBinaryCrossEntropyLoss, self).__init__()
238 | self._from_sigmoid = from_sigmoid
239 | self._ignore_label = ignore_label
240 | self._weight = weight if weight is not None else 1.0
241 | self._batch_axis = batch_axis
242 |
243 | def forward(self, pred, label):
244 | label = label.view(pred.size())
245 | sample_weight = label != self._ignore_label
246 | label = torch.where(sample_weight, label, torch.zeros_like(label))
247 |
248 | if not self._from_sigmoid:
249 | loss = torch.relu(pred) - pred * label + F.softplus(-torch.abs(pred))
250 | else:
251 | eps = 1e-12
252 | loss = -(torch.log(pred + eps) * label
253 | + torch.log(1. - pred + eps) * (1. - label))
254 |
255 | loss = self._weight * (loss * sample_weight)
256 | return torch.mean(loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis))
257 |
258 |
259 | class WeightedSigmoidBinaryCrossEntropyLoss(nn.Module):
260 | def __init__(self, from_sigmoid=False, weight=None, batch_axis=0, ignore_label=-1):
261 | super(WeightedSigmoidBinaryCrossEntropyLoss, self).__init__()
262 | self._from_sigmoid = from_sigmoid
263 | self._ignore_label = ignore_label
264 | self._weight = weight if weight is not None else 1.0
265 | self._batch_axis = batch_axis
266 |
267 | def forward(self, pred, label, weight):
268 | label = label.view(pred.size())
269 | sample_weight = label != self._ignore_label
270 | label = torch.where(sample_weight, label, torch.zeros_like(label))
271 |
272 | if not self._from_sigmoid:
273 | loss = torch.relu(pred) - pred * label + F.softplus(-torch.abs(pred))
274 | else:
275 | eps = 1e-12
276 | loss = -(torch.log(pred + eps) * label
277 | + torch.log(1. - pred + eps) * (1. - label))
278 | #weight = weight * 0.8 + 0.2
279 | loss = (weight * loss).sum() / weight.sum()
280 | return loss #torch.mean(loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis))
281 |
282 |
283 |
284 |
285 | class ClickLoss(nn.Module):
286 | def __init__(self, from_sigmoid=False, weight=None, batch_axis=0, ignore_label=-1, alpha = 0.99, beta = 0.01):
287 | super(ClickLoss, self).__init__()
288 | self._from_sigmoid = from_sigmoid
289 | self._ignore_label = ignore_label
290 | self._weight = weight if weight is not None else 1.0
291 | self._batch_axis = batch_axis
292 | self.alpha = alpha
293 | self.beta = beta
294 |
295 |
296 | def forward(self, pred, label, gaussian_maps = None):
297 | h_gt, w_gt = label.shape[-2],label.shape[-1]
298 | h_p, w_p = pred.shape[-2], pred.shape[-1]
299 | if h_gt != h_p or w_gt != w_p:
300 | pred = F.interpolate(pred, size=label.size()[-2:],
301 | mode='bilinear', align_corners=True)
302 |
303 |
304 | label = label.view(pred.size())
305 | sample_weight = label != self._ignore_label
306 | label = torch.where(sample_weight, label, torch.zeros_like(label))
307 |
308 | if not self._from_sigmoid:
309 | loss = torch.relu(pred) - pred * label + F.softplus(-torch.abs(pred))
310 | else:
311 | eps = 1e-12
312 | loss = -(torch.log(pred + eps) * label
313 | + torch.log(1. - pred + eps) * (1. - label))
314 |
315 | loss = self._weight * (loss * sample_weight)
316 | weight_map = gaussian_maps.max(dim=1,keepdim = True)[0] * self.alpha + self.beta
317 | loss = (loss * weight_map).sum() / weight_map.sum()
318 | return loss
--------------------------------------------------------------------------------
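
Usage sketch (not part of the repository; shapes are illustrative): ClickLoss above is binary cross-entropy with logits, re-weighted by a per-pixel map built from the maximum over the per-click gaussian maps, so errors near clicked regions dominate the average.

    import torch

    pred = torch.randn(2, 1, 64, 64)                  # raw logits, B x 1 x H x W
    label = (torch.rand(2, 1, 64, 64) > 0.5).float()  # binary ground truth
    gaussians = torch.rand(2, 3, 64, 64)              # one gaussian map per click

    criterion = ClickLoss(alpha=0.99, beta=0.01)
    loss = criterion(pred, label, gaussian_maps=gaussians)
    print(loss.item())  # scalar, weighted towards the clicked regions
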
/isegm/model/metrics.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 | from isegm.utils import misc
5 |
6 |
7 | class TrainMetric(object):
8 | def __init__(self, pred_outputs, gt_outputs):
9 | self.pred_outputs = pred_outputs
10 | self.gt_outputs = gt_outputs
11 |
12 | def update(self, *args, **kwargs):
13 | raise NotImplementedError
14 |
15 | def get_epoch_value(self):
16 | raise NotImplementedError
17 |
18 | def reset_epoch_stats(self):
19 | raise NotImplementedError
20 |
21 | def log_states(self, sw, tag_prefix, global_step):
22 | pass
23 |
24 | @property
25 | def name(self):
26 | return type(self).__name__
27 |
28 |
29 | class AdaptiveIoU(TrainMetric):
30 | def __init__(self, init_thresh=0.4, thresh_step=0.025, thresh_beta=0.99, iou_beta=0.9,
31 | ignore_label=-1, from_logits=True,
32 | pred_output='instances', gt_output='instances'):
33 | super().__init__(pred_outputs=(pred_output,), gt_outputs=(gt_output,))
34 | self._ignore_label = ignore_label
35 | self._from_logits = from_logits
36 | self._iou_thresh = init_thresh
37 | self._thresh_step = thresh_step
38 | self._thresh_beta = thresh_beta
39 | self._iou_beta = iou_beta
40 | self._ema_iou = 0.0
41 | self._epoch_iou_sum = 0.0
42 | self._epoch_batch_count = 0
43 |
44 | def update(self, pred, gt):
45 | gt_mask = gt > 0.5
46 | if self._from_logits:
47 | pred = torch.sigmoid(pred)
48 |
49 | gt_mask_area = torch.sum(gt_mask, dim=(1, 2)).detach().cpu().numpy()
50 | if np.all(gt_mask_area == 0):
51 | return
52 |
53 | ignore_mask = gt == self._ignore_label
54 | max_iou = _compute_iou(pred > self._iou_thresh, gt_mask, ignore_mask).mean()
55 | best_thresh = self._iou_thresh
56 | for t in [best_thresh - self._thresh_step, best_thresh + self._thresh_step]:
57 | temp_iou = _compute_iou(pred > t, gt_mask, ignore_mask).mean()
58 | if temp_iou > max_iou:
59 | max_iou = temp_iou
60 | best_thresh = t
61 |
62 | self._iou_thresh = self._thresh_beta * self._iou_thresh + (1 - self._thresh_beta) * best_thresh
63 | self._ema_iou = self._iou_beta * self._ema_iou + (1 - self._iou_beta) * max_iou
64 | self._epoch_iou_sum += max_iou
65 | self._epoch_batch_count += 1
66 |
67 | def get_epoch_value(self):
68 | if self._epoch_batch_count > 0:
69 | return self._epoch_iou_sum / self._epoch_batch_count
70 | else:
71 | return 0.0
72 |
73 | def reset_epoch_stats(self):
74 | self._epoch_iou_sum = 0.0
75 | self._epoch_batch_count = 0
76 |
77 | def log_states(self, sw, tag_prefix, global_step):
78 | sw.add_scalar(tag=tag_prefix + '_ema_iou', value=self._ema_iou, global_step=global_step)
79 | sw.add_scalar(tag=tag_prefix + '_iou_thresh', value=self._iou_thresh, global_step=global_step)
80 |
81 | @property
82 | def iou_thresh(self):
83 | return self._iou_thresh
84 |
85 |
86 | def _compute_iou(pred_mask, gt_mask, ignore_mask=None, keep_ignore=False):
87 | if ignore_mask is not None:
88 | pred_mask = torch.where(ignore_mask, torch.zeros_like(pred_mask), pred_mask)
89 |
90 | reduction_dims = misc.get_dims_with_exclusion(gt_mask.dim(), 0)
91 | union = torch.mean((pred_mask | gt_mask).float(), dim=reduction_dims).detach().cpu().numpy()
92 | intersection = torch.mean((pred_mask & gt_mask).float(), dim=reduction_dims).detach().cpu().numpy()
93 | nonzero = union > 0
94 |
95 | iou = intersection[nonzero] / union[nonzero]
96 | if not keep_ignore:
97 | return iou
98 | else:
99 | result = np.full_like(intersection, -1)
100 | result[nonzero] = iou
101 | return result
102 |
--------------------------------------------------------------------------------
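
Usage sketch (not part of the repository; random tensors stand in for real batches): AdaptiveIoU consumes raw logits and ground-truth masks batch by batch, tracks an EMA of the IoU, and nudges the binarization threshold towards whichever neighboring threshold scored better on the current batch.

    import torch

    metric = AdaptiveIoU(init_thresh=0.4, thresh_step=0.025)
    for _ in range(5):
        logits = torch.randn(4, 96, 96)             # B x H x W predictions
        gt = (torch.rand(4, 96, 96) > 0.5).float()  # binary ground truth
        metric.update(logits, gt)
    print(metric.get_epoch_value(), metric.iou_thresh)
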
/isegm/model/modeling/basic_blocks.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | from isegm.model import ops
4 |
5 |
6 | class ConvHead(nn.Module):
7 | def __init__(self, out_channels, in_channels=32, num_layers=1,
8 | kernel_size=3, padding=1,
9 | norm_layer=nn.BatchNorm2d):
10 | super(ConvHead, self).__init__()
11 | convhead = []
12 |
13 | for i in range(num_layers):
14 | convhead.extend([
15 | nn.Conv2d(in_channels, in_channels, kernel_size, padding=padding),
16 | nn.ReLU(),
17 | norm_layer(in_channels) if norm_layer is not None else nn.Identity()
18 | ])
19 | convhead.append(nn.Conv2d(in_channels, out_channels, 1, padding=0))
20 |
21 | self.convhead = nn.Sequential(*convhead)
22 |
23 | def forward(self, *inputs):
24 | return self.convhead(inputs[0])
25 |
26 |
27 | class SepConvHead(nn.Module):
28 | def __init__(self, num_outputs, in_channels, mid_channels, num_layers=1,
29 | kernel_size=3, padding=1, dropout_ratio=0.0, dropout_indx=0,
30 | norm_layer=nn.BatchNorm2d):
31 | super(SepConvHead, self).__init__()
32 |
33 | sepconvhead = []
34 |
35 | for i in range(num_layers):
36 | sepconvhead.append(
37 | SeparableConv2d(in_channels=in_channels if i == 0 else mid_channels,
38 | out_channels=mid_channels,
39 | dw_kernel=kernel_size, dw_padding=padding,
40 | norm_layer=norm_layer, activation='relu')
41 | )
42 | if dropout_ratio > 0 and dropout_indx == i:
43 | sepconvhead.append(nn.Dropout(dropout_ratio))
44 |
45 | sepconvhead.append(
46 | nn.Conv2d(in_channels=mid_channels, out_channels=num_outputs, kernel_size=1, padding=0)
47 | )
48 |
49 | self.layers = nn.Sequential(*sepconvhead)
50 |
51 | def forward(self, *inputs):
52 | x = inputs[0]
53 |
54 | return self.layers(x)
55 |
56 |
57 | class SeparableConv2d(nn.Module):
58 | def __init__(self, in_channels, out_channels, dw_kernel, dw_padding, dw_stride=1,
59 | activation=None, use_bias=False, norm_layer=None):
60 | super(SeparableConv2d, self).__init__()
61 | _activation = ops.select_activation_function(activation)
62 | self.body = nn.Sequential(
63 | nn.Conv2d(in_channels, in_channels, kernel_size=dw_kernel, stride=dw_stride,
64 | padding=dw_padding, bias=use_bias, groups=in_channels),
65 | nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=use_bias),
66 | norm_layer(out_channels) if norm_layer is not None else nn.Identity(),
67 | _activation()
68 | )
69 |
70 | def forward(self, x):
71 | return self.body(x)
72 |
--------------------------------------------------------------------------------
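
Shape check (not part of the repository): SeparableConv2d factorizes a convolution into a depthwise 3x3 followed by a 1x1 pointwise projection, which is what keeps the conv heads above cheap.

    import torch
    import torch.nn as nn

    conv = SeparableConv2d(in_channels=32, out_channels=64, dw_kernel=3,
                           dw_padding=1, activation='relu',
                           norm_layer=nn.BatchNorm2d)
    x = torch.randn(1, 32, 56, 56)
    print(conv(x).shape)  # torch.Size([1, 64, 56, 56])
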
/isegm/model/modeling/deeplab_v3.py:
--------------------------------------------------------------------------------
1 | from contextlib import ExitStack
2 |
3 | import torch
4 | from torch import nn
5 | import torch.nn.functional as F
6 |
7 | from .basic_blocks import SeparableConv2d
8 | from .resnet import ResNetBackbone
9 | from isegm.model import ops
10 | from isegm.model.modeling.cdnet.FDM import FDM, FDM_v2, FDM_v3  # note: the cdnet package is not included in this repository tree
11 |
12 | class DeepLabV3Plus(nn.Module):
13 | def __init__(self, backbone='resnet50', norm_layer=nn.BatchNorm2d,
14 | backbone_norm_layer=None,
15 | ch=256,
16 | project_dropout=0.5,
17 | inference_mode=False,
18 | **kwargs):
19 | super(DeepLabV3Plus, self).__init__()
20 | if backbone_norm_layer is None:
21 | backbone_norm_layer = norm_layer
22 |
23 | self.backbone_name = backbone
24 | self.norm_layer = norm_layer
25 | self.backbone_norm_layer = backbone_norm_layer
26 | self.inference_mode = False
27 | self.ch = ch
28 | self.aspp_in_channels = 2048
29 | self.skip_project_in_channels = 256 # layer 1 out_channels
30 |
31 | self._kwargs = kwargs
32 |
33 |         if backbone in ('resnet34', 'resnet18'):  # `== 'resnet34' or 'resnet18'` was always truthy
34 | self.aspp_in_channels = 512
35 | self.skip_project_in_channels = 64
36 | else:
37 | self.aspp_in_channels = 512 * 4
38 | self.skip_project_in_channels = 64 * 4
39 |
40 |
41 |
42 | self.backbone = ResNetBackbone(backbone=self.backbone_name, pretrained_base=False,
43 | norm_layer=self.backbone_norm_layer, **kwargs)
44 |
45 | self.head = _DeepLabHead(in_channels=ch + 32, mid_channels=ch, out_channels=ch,
46 | norm_layer=self.norm_layer)
47 | self.skip_project = _SkipProject(self.skip_project_in_channels, 32, norm_layer=self.norm_layer)
48 | self.aspp = _ASPP(in_channels=self.aspp_in_channels,
49 | atrous_rates=[12, 24, 36],
50 | out_channels=ch,
51 | project_dropout=project_dropout,
52 | norm_layer=self.norm_layer)
53 | self.FDM = FDM_v3(self.ch,self.ch)
54 |
55 | if inference_mode:
56 | self.set_prediction_mode()
57 |
58 | def load_pretrained_weights(self):
59 | pretrained = ResNetBackbone(backbone=self.backbone_name, pretrained_base=True,
60 | norm_layer=self.backbone_norm_layer, **self._kwargs)
61 | backbone_state_dict = self.backbone.state_dict()
62 | pretrained_state_dict = pretrained.state_dict()
63 |
64 | backbone_state_dict.update(pretrained_state_dict)
65 | self.backbone.load_state_dict(backbone_state_dict)
66 |
67 | if self.inference_mode:
68 | for param in self.backbone.parameters():
69 | param.requires_grad = False
70 |
71 | def set_prediction_mode(self):
72 | self.inference_mode = True
73 | self.eval()
74 |
75 | def forward(self, x, additional_features, small_clicks):
76 |
77 | with ExitStack() as stack:
78 | if self.inference_mode:
79 | stack.enter_context(torch.no_grad())
80 |
81 | c1, _, c3, c4 = self.backbone(x, additional_features)
82 | c1 = self.skip_project(c1)
83 |
84 | x = self.aspp(c4)
85 | x, pos_map, neg_map = self.FDM(x, small_clicks)
86 | x = F.interpolate(x, c1.size()[2:], mode='bilinear', align_corners=True)
87 | x = torch.cat((x, c1), dim=1)
88 | x = self.head(x)
89 | return x,pos_map
90 |
91 |
92 | class _SkipProject(nn.Module):
93 | def __init__(self, in_channels, out_channels, norm_layer=nn.BatchNorm2d):
94 | super(_SkipProject, self).__init__()
95 | _activation = ops.select_activation_function("relu")
96 |
97 | self.skip_project = nn.Sequential(
98 | nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
99 | norm_layer(out_channels),
100 | _activation()
101 | )
102 |
103 | def forward(self, x):
104 | return self.skip_project(x)
105 |
106 |
107 | class _DeepLabHead(nn.Module):
108 | def __init__(self, out_channels, in_channels, mid_channels=256, norm_layer=nn.BatchNorm2d):
109 | super(_DeepLabHead, self).__init__()
110 |
111 | self.block = nn.Sequential(
112 | SeparableConv2d(in_channels=in_channels, out_channels=mid_channels, dw_kernel=3,
113 | dw_padding=1, activation='relu', norm_layer=norm_layer),
114 | SeparableConv2d(in_channels=mid_channels, out_channels=mid_channels, dw_kernel=3,
115 | dw_padding=1, activation='relu', norm_layer=norm_layer),
116 | nn.Conv2d(in_channels=mid_channels, out_channels=out_channels, kernel_size=1)
117 | )
118 |
119 | def forward(self, x):
120 | return self.block(x)
121 |
122 |
123 | class _ASPP(nn.Module):
124 | def __init__(self, in_channels, atrous_rates, out_channels=256,
125 | project_dropout=0.5, norm_layer=nn.BatchNorm2d):
126 | super(_ASPP, self).__init__()
127 |
128 | b0 = nn.Sequential(
129 | nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=False),
130 | norm_layer(out_channels),
131 | nn.ReLU()
132 | )
133 |
134 | rate1, rate2, rate3 = tuple(atrous_rates)
135 | b1 = _ASPPConv(in_channels, out_channels, rate1, norm_layer)
136 | b2 = _ASPPConv(in_channels, out_channels, rate2, norm_layer)
137 | b3 = _ASPPConv(in_channels, out_channels, rate3, norm_layer)
138 | b4 = _AsppPooling(in_channels, out_channels, norm_layer=norm_layer)
139 |
140 | self.concurent = nn.ModuleList([b0, b1, b2, b3, b4])
141 |
142 | project = [
143 | nn.Conv2d(in_channels=5*out_channels, out_channels=out_channels,
144 | kernel_size=1, bias=False),
145 | norm_layer(out_channels),
146 | nn.ReLU()
147 | ]
148 | if project_dropout > 0:
149 | project.append(nn.Dropout(project_dropout))
150 | self.project = nn.Sequential(*project)
151 |
152 | def forward(self, x):
153 | x = torch.cat([block(x) for block in self.concurent], dim=1)
154 |
155 | return self.project(x)
156 |
157 |
158 | class _AsppPooling(nn.Module):
159 | def __init__(self, in_channels, out_channels, norm_layer):
160 | super(_AsppPooling, self).__init__()
161 |
162 | self.gap = nn.Sequential(
163 | nn.AdaptiveAvgPool2d((1, 1)),
164 | nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
165 | kernel_size=1, bias=False),
166 | norm_layer(out_channels),
167 | nn.ReLU()
168 | )
169 |
170 | def forward(self, x):
171 | pool = self.gap(x)
172 | return F.interpolate(pool, x.size()[2:], mode='bilinear', align_corners=True)
173 |
174 |
175 | def _ASPPConv(in_channels, out_channels, atrous_rate, norm_layer):
176 | block = nn.Sequential(
177 | nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
178 | kernel_size=3, padding=atrous_rate,
179 | dilation=atrous_rate, bias=False),
180 | norm_layer(out_channels),
181 | nn.ReLU()
182 | )
183 |
184 | return block
185 |
--------------------------------------------------------------------------------
/isegm/model/modeling/deeplab_v3_gp.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 |
5 | from .basic_blocks import SeparableConv2d
6 | from .resnet import ResNetBackbone
7 | from isegm.model import ops
8 |
9 | class DeepLabV3Plus(nn.Module):
10 | def __init__(self, backbone='resnet50', norm_layer=nn.BatchNorm2d,
11 | backbone_norm_layer=None,
12 | ch=256,
13 | project_dropout=0.5,
14 | inference_mode=False,
15 | weight_dir=None,
16 | **kwargs):
17 | super(DeepLabV3Plus, self).__init__()
18 | if backbone_norm_layer is None:
19 | backbone_norm_layer = norm_layer
20 |
21 | self.backbone_name = backbone
22 | self.norm_layer = norm_layer
23 | self.backbone_norm_layer = backbone_norm_layer
24 | self.inference_mode = False
25 | self.ch = ch
26 | self.aspp_in_channels = 2048
27 | self.skip_project_in_channels = 256 # layer 1 out_channels
28 | self.weight_dir=weight_dir
29 |
30 | self._kwargs = kwargs
31 |
32 | self.aspp_in_channels = 512 * 4
33 | self.skip_project_in_channels = 64 * 4
34 |
35 |
36 | self.backbone = ResNetBackbone(backbone=self.backbone_name, pretrained_base=False,
37 | norm_layer=self.backbone_norm_layer,weight_dir=weight_dir, **kwargs)
38 |
39 | self.head = _DeepLabHead(in_channels=ch + 32, mid_channels=ch, out_channels=48,
40 | norm_layer=self.norm_layer)
41 | self.skip_project = _SkipProject(self.skip_project_in_channels, 32, norm_layer=self.norm_layer)
42 | self.aspp = _ASPP(in_channels=self.aspp_in_channels,
43 | atrous_rates=[12, 24, 36],
44 | out_channels=ch,
45 | project_dropout=project_dropout,
46 | norm_layer=self.norm_layer)
47 | self.feature_dim = 256
48 |
49 | if inference_mode:
50 | self.set_prediction_mode()
51 |
52 | def load_pretrained_weights(self):
53 | pretrained = ResNetBackbone(backbone=self.backbone_name, pretrained_base=True,
54 | norm_layer=self.backbone_norm_layer,weight_dir=self.weight_dir, **self._kwargs)
55 | backbone_state_dict = self.backbone.state_dict()
56 | pretrained_state_dict = pretrained.state_dict()
57 |
58 | backbone_state_dict.update(pretrained_state_dict)
59 | self.backbone.load_state_dict(backbone_state_dict)
60 |
61 | if self.inference_mode:
62 | for param in self.backbone.parameters():
63 | param.requires_grad = False
64 |
65 | def set_prediction_mode(self):
66 | self.inference_mode = True
67 | self.eval()
68 |
69 | def forward(self, x, additional_features):
70 |
71 | c1, _, _, c4 = self.backbone(x, additional_features)
72 |
73 | c1 = self.skip_project(c1)
74 |
75 | x = self.aspp(c4)
76 |
77 | x = F.interpolate(x, c1.size()[2:], mode='bilinear', align_corners=True)
78 | x = torch.cat((x, c1), dim=1)
79 | return self.head(x)
80 |
81 | class _SkipProject(nn.Module):
82 | def __init__(self, in_channels, out_channels, norm_layer=nn.BatchNorm2d):
83 | super(_SkipProject, self).__init__()
84 | _activation = ops.select_activation_function("relu")
85 |
86 | self.skip_project = nn.Sequential(
87 | nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
88 | norm_layer(out_channels),
89 | _activation()
90 | )
91 |
92 | def forward(self, x):
93 | return self.skip_project(x)
94 |
95 |
96 | class _DeepLabHead(nn.Module):
97 | def __init__(self, out_channels, in_channels, mid_channels=256, norm_layer=nn.BatchNorm2d):
98 | super(_DeepLabHead, self).__init__()
99 |
100 | self.block = nn.Sequential(
101 | SeparableConv2d(in_channels=in_channels, out_channels=mid_channels, dw_kernel=3,
102 | dw_padding=1, activation='relu', norm_layer=norm_layer),
103 | SeparableConv2d(in_channels=mid_channels, out_channels=mid_channels, dw_kernel=3,
104 | dw_padding=1, activation='relu', norm_layer=norm_layer),
105 | nn.Conv2d(in_channels=mid_channels, out_channels=out_channels, kernel_size=1)
106 | )
107 |
108 | def forward(self, x):
109 | return self.block(x)
110 |
111 |
112 | class _ASPP(nn.Module):
113 | def __init__(self, in_channels, atrous_rates, out_channels=256,
114 | project_dropout=0.5, norm_layer=nn.BatchNorm2d):
115 | super(_ASPP, self).__init__()
116 |
117 | b0 = nn.Sequential(
118 | nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=False),
119 | norm_layer(out_channels),
120 | nn.ReLU()
121 | )
122 |
123 | rate1, rate2, rate3 = tuple(atrous_rates)
124 | b1 = _ASPPConv(in_channels, out_channels, rate1, norm_layer)
125 | b2 = _ASPPConv(in_channels, out_channels, rate2, norm_layer)
126 | b3 = _ASPPConv(in_channels, out_channels, rate3, norm_layer)
127 | b4 = _AsppPooling(in_channels, out_channels, norm_layer=norm_layer)
128 |
129 | self.concurent = nn.ModuleList([b0, b1, b2, b3, b4])
130 |
131 | project = [
132 | nn.Conv2d(in_channels=5*out_channels, out_channels=out_channels,
133 | kernel_size=1, bias=False),
134 | norm_layer(out_channels),
135 | nn.ReLU()
136 | ]
137 | if project_dropout > 0:
138 | project.append(nn.Dropout(project_dropout))
139 | self.project = nn.Sequential(*project)
140 |
141 | def forward(self, x):
142 | x = torch.cat([block(x) for block in self.concurent], dim=1)
143 |
144 | return self.project(x)
145 |
146 |
147 | class _AsppPooling(nn.Module):
148 | def __init__(self, in_channels, out_channels, norm_layer):
149 | super(_AsppPooling, self).__init__()
150 |
151 | self.gap = nn.Sequential(
152 | nn.AdaptiveAvgPool2d((1, 1)),
153 | nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
154 | kernel_size=1, bias=False),
155 | norm_layer(out_channels),
156 | nn.ReLU()
157 | )
158 |
159 | def forward(self, x):
160 | pool = self.gap(x)
161 | return F.interpolate(pool, x.size()[2:], mode='bilinear', align_corners=True)
162 |
163 |
164 | def _ASPPConv(in_channels, out_channels, atrous_rate, norm_layer):
165 | block = nn.Sequential(
166 | nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
167 | kernel_size=3, padding=atrous_rate,
168 | dilation=atrous_rate, bias=False),
169 | norm_layer(out_channels),
170 | nn.ReLU()
171 | )
172 |
173 | return block
174 |
--------------------------------------------------------------------------------
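
Forward-pass sketch (not part of the repository; random weights, no checkpoint loaded): with the dilated ResNet-50 the encoder output stride is 4 after the skip fusion, so a 256x256 input yields the 48-channel feature map at 64x64 that the GP model builds on.

    import torch

    model = DeepLabV3Plus(backbone='resnet50', ch=256).eval()
    x = torch.randn(1, 3, 256, 256)
    with torch.no_grad():
        feats = model(x, additional_features=None)
    print(feats.shape)  # torch.Size([1, 48, 64, 64])
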
/isegm/model/modeling/resnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from .resnetv1b import resnet18_v1b, resnet34_v1b, resnet50_v1s, resnet101_v1s, resnet152_v1s
3 |
4 |
5 | class ResNetBackbone(torch.nn.Module):
6 | def __init__(self, backbone='resnet50', pretrained_base=True, dilated=True,weight_dir=None, **kwargs):
7 | super(ResNetBackbone, self).__init__()
8 |
9 | if backbone == 'resnet18':
10 | pretrained = resnet18_v1b(pretrained=pretrained_base, dilated=dilated,weight_dir=weight_dir, **kwargs)
11 | elif backbone == 'resnet34':
12 | pretrained = resnet34_v1b(pretrained=pretrained_base, dilated=dilated,weight_dir=weight_dir, **kwargs)
13 | elif backbone == 'resnet50':
14 | pretrained = resnet50_v1s(pretrained=pretrained_base, dilated=dilated,weight_dir=weight_dir, **kwargs)
15 | elif backbone == 'resnet101':
16 | pretrained = resnet101_v1s(pretrained=pretrained_base, dilated=dilated,weight_dir=weight_dir, **kwargs)
17 | elif backbone == 'resnet152':
18 | pretrained = resnet152_v1s(pretrained=pretrained_base, dilated=dilated,weight_dir=weight_dir, **kwargs)
19 | else:
20 | raise RuntimeError(f'unknown backbone: {backbone}')
21 |
22 | self.conv1 = pretrained.conv1
23 | self.bn1 = pretrained.bn1
24 | self.relu = pretrained.relu
25 | self.maxpool = pretrained.maxpool
26 | self.layer1 = pretrained.layer1
27 | self.layer2 = pretrained.layer2
28 | self.layer3 = pretrained.layer3
29 | self.layer4 = pretrained.layer4
30 |
31 | def forward(self, x, additional_features=None):
32 | x = self.conv1(x)
33 | x = self.bn1(x)
34 | x = self.relu(x)
35 | if additional_features is not None:
36 | x = x + torch.nn.functional.pad(additional_features,
37 | [0, 0, 0, 0, 0, x.size(1) - additional_features.size(1)],
38 | mode='constant', value=0)
39 | x = self.maxpool(x)
40 | c1 = self.layer1(x)
41 | c2 = self.layer2(c1)
42 | c3 = self.layer3(c2)
43 | c4 = self.layer4(c3)
44 |
45 | return c1, c2, c3, c4
46 |
--------------------------------------------------------------------------------
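
Usage sketch (not part of the repository): additional_features (e.g. encoded click maps) are zero-padded along the channel dimension to match the stem output and added in after the first conv-bn-relu block; with the default dilated setup the backbone is stride-8.

    import torch

    backbone = ResNetBackbone(backbone='resnet50', pretrained_base=False)
    x = torch.randn(1, 3, 224, 224)
    extra = torch.randn(1, 3, 112, 112)  # fewer channels than the 128-ch stem output
    c1, c2, c3, c4 = backbone(x, extra)
    print(c4.shape)  # torch.Size([1, 2048, 28, 28])
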
/isegm/model/modeling/resnetv1b.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | GLUON_RESNET_TORCH_HUB = 'rwightman/pytorch-pretrained-gluonresnet'  # third-party torch.hub repo providing pretrained GluonCV ResNet weights
4 |
5 |
6 | class BasicBlockV1b(nn.Module):
7 | expansion = 1
8 |
9 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None,
10 | previous_dilation=1, norm_layer=nn.BatchNorm2d):
11 | super(BasicBlockV1b, self).__init__()
12 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride,
13 | padding=dilation, dilation=dilation, bias=False)
14 | self.bn1 = norm_layer(planes)
15 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1,
16 | padding=previous_dilation, dilation=previous_dilation, bias=False)
17 | self.bn2 = norm_layer(planes)
18 |
19 | self.relu = nn.ReLU(inplace=True)
20 | self.downsample = downsample
21 | self.stride = stride
22 |
23 | def forward(self, x):
24 | residual = x
25 |
26 | out = self.conv1(x)
27 | out = self.bn1(out)
28 | out = self.relu(out)
29 |
30 | out = self.conv2(out)
31 | out = self.bn2(out)
32 |
33 | if self.downsample is not None:
34 | residual = self.downsample(x)
35 |
36 | out = out + residual
37 | out = self.relu(out)
38 |
39 | return out
40 |
41 |
42 | class BottleneckV1b(nn.Module):
43 | expansion = 4
44 |
45 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None,
46 | previous_dilation=1, norm_layer=nn.BatchNorm2d):
47 | super(BottleneckV1b, self).__init__()
48 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
49 | self.bn1 = norm_layer(planes)
50 |
51 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
52 | padding=dilation, dilation=dilation, bias=False)
53 | self.bn2 = norm_layer(planes)
54 |
55 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
56 | self.bn3 = norm_layer(planes * self.expansion)
57 |
58 | self.relu = nn.ReLU(inplace=True)
59 | self.downsample = downsample
60 | self.stride = stride
61 |
62 | def forward(self, x):
63 | residual = x
64 |
65 | out = self.conv1(x)
66 | out = self.bn1(out)
67 | out = self.relu(out)
68 |
69 | out = self.conv2(out)
70 | out = self.bn2(out)
71 | out = self.relu(out)
72 |
73 | out = self.conv3(out)
74 | out = self.bn3(out)
75 |
76 | if self.downsample is not None:
77 | residual = self.downsample(x)
78 |
79 | out = out + residual
80 | out = self.relu(out)
81 |
82 | return out
83 |
84 |
85 | class ResNetV1b(nn.Module):
86 |     """ Pre-trained ResNetV1b model, which produces a stride-8 feature map at conv5.
87 |
88 | Parameters
89 | ----------
90 | block : Block
91 |         Class for the residual block. Options are BasicBlockV1b and BottleneckV1b.
92 | layers : list of int
93 | Numbers of layers in each block
94 | classes : int, default 1000
95 | Number of classification classes.
96 |     dilated : bool, default True
97 |         Apply the dilation strategy to the pretrained ResNet, yielding a stride-8 model
98 |         typically used in semantic segmentation.
99 | norm_layer : object
100 | Normalization layer used (default: :class:`nn.BatchNorm2d`)
101 | deep_stem : bool, default False
102 | Whether to replace the 7x7 conv1 with 3 3x3 convolution layers.
103 | avg_down : bool, default False
104 |         Whether to use average pooling in the projection shortcut (downsample) between stages.
105 | final_drop : float, default 0.0
106 | Dropout ratio before the final classification layer.
107 |
108 | Reference:
109 | - He, Kaiming, et al. "Deep residual learning for image recognition."
110 | Proceedings of the IEEE conference on computer vision and pattern recognition. 2016.
111 |
112 | - Yu, Fisher, and Vladlen Koltun. "Multi-scale context aggregation by dilated convolutions."
113 | """
114 | def __init__(self, block, layers, classes=1000, dilated=True, deep_stem=False, stem_width=32,
115 | avg_down=False, final_drop=0.0, norm_layer=nn.BatchNorm2d):
116 | self.inplanes = stem_width*2 if deep_stem else 64
117 | super(ResNetV1b, self).__init__()
118 | if not deep_stem:
119 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
120 | else:
121 | self.conv1 = nn.Sequential(
122 | nn.Conv2d(3, stem_width, kernel_size=3, stride=2, padding=1, bias=False),
123 | norm_layer(stem_width),
124 | nn.ReLU(True),
125 | nn.Conv2d(stem_width, stem_width, kernel_size=3, stride=1, padding=1, bias=False),
126 | norm_layer(stem_width),
127 | nn.ReLU(True),
128 | nn.Conv2d(stem_width, 2*stem_width, kernel_size=3, stride=1, padding=1, bias=False)
129 | )
130 | self.bn1 = norm_layer(self.inplanes)
131 | self.relu = nn.ReLU(True)
132 | self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)
133 | self.layer1 = self._make_layer(block, 64, layers[0], avg_down=avg_down,
134 | norm_layer=norm_layer)
135 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, avg_down=avg_down,
136 | norm_layer=norm_layer)
137 | if dilated:
138 | self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2,
139 | avg_down=avg_down, norm_layer=norm_layer)
140 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4,
141 | avg_down=avg_down, norm_layer=norm_layer)
142 | else:
143 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
144 | avg_down=avg_down, norm_layer=norm_layer)
145 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
146 | avg_down=avg_down, norm_layer=norm_layer)
147 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
148 | self.drop = None
149 | if final_drop > 0.0:
150 | self.drop = nn.Dropout(final_drop)
151 | self.fc = nn.Linear(512 * block.expansion, classes)
152 |
153 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1,
154 | avg_down=False, norm_layer=nn.BatchNorm2d):
155 | downsample = None
156 | if stride != 1 or self.inplanes != planes * block.expansion:
157 | downsample = []
158 | if avg_down:
159 | if dilation == 1:
160 | downsample.append(
161 | nn.AvgPool2d(kernel_size=stride, stride=stride, ceil_mode=True, count_include_pad=False)
162 | )
163 | else:
164 | downsample.append(
165 | nn.AvgPool2d(kernel_size=1, stride=1, ceil_mode=True, count_include_pad=False)
166 | )
167 | downsample.extend([
168 | nn.Conv2d(self.inplanes, out_channels=planes * block.expansion,
169 | kernel_size=1, stride=1, bias=False),
170 | norm_layer(planes * block.expansion)
171 | ])
172 | downsample = nn.Sequential(*downsample)
173 | else:
174 | downsample = nn.Sequential(
175 | nn.Conv2d(self.inplanes, out_channels=planes * block.expansion,
176 | kernel_size=1, stride=stride, bias=False),
177 | norm_layer(planes * block.expansion)
178 | )
179 |
180 | layers = []
181 | if dilation in (1, 2):
182 | layers.append(block(self.inplanes, planes, stride, dilation=1, downsample=downsample,
183 | previous_dilation=dilation, norm_layer=norm_layer))
184 | elif dilation == 4:
185 | layers.append(block(self.inplanes, planes, stride, dilation=2, downsample=downsample,
186 | previous_dilation=dilation, norm_layer=norm_layer))
187 | else:
188 | raise RuntimeError("=> unknown dilation size: {}".format(dilation))
189 |
190 | self.inplanes = planes * block.expansion
191 | for _ in range(1, blocks):
192 | layers.append(block(self.inplanes, planes, dilation=dilation,
193 | previous_dilation=dilation, norm_layer=norm_layer))
194 |
195 | return nn.Sequential(*layers)
196 |
197 | def forward(self, x):
198 | x = self.conv1(x)
199 | x = self.bn1(x)
200 | x = self.relu(x)
201 | x = self.maxpool(x)
202 |
203 | x = self.layer1(x)
204 | x = self.layer2(x)
205 | x = self.layer3(x)
206 | x = self.layer4(x)
207 |
208 | x = self.avgpool(x)
209 | x = x.view(x.size(0), -1)
210 | if self.drop is not None:
211 | x = self.drop(x)
212 | x = self.fc(x)
213 |
214 | return x
215 |
216 |
217 | def _safe_state_dict_filtering(orig_dict, model_dict_keys):
218 | filtered_orig_dict = {}
219 | for k, v in orig_dict.items():
220 | if k in model_dict_keys:
221 | filtered_orig_dict[k] = v
222 | else:
223 | print(f"[ERROR] Failed to load <{k}> in backbone")
224 | return filtered_orig_dict
225 |
226 |
227 |
228 | def resnet18_v1b(pretrained=False, **kwargs):
229 | model = ResNetV1b(BasicBlockV1b, [2, 2, 2, 2], **kwargs)
230 | if pretrained:
231 | pass
232 | return model
233 |
234 |
235 | def resnet34_v1b(pretrained=False, weight_dir=None, **kwargs):
236 | model = ResNetV1b(BasicBlockV1b, [3, 4, 6, 3], **kwargs)
237 | if pretrained:
238 | model_dict = model.state_dict()
239 | if weight_dir is None:
240 | filtered_orig_dict = _safe_state_dict_filtering(
241 | torch.hub.load(GLUON_RESNET_TORCH_HUB, 'gluon_resnet34_v1b', pretrained=True).state_dict(),
242 | model_dict.keys()
243 | )
244 | else:
245 | filtered_orig_dict = torch.load(weight_dir,map_location='cpu')#['state_dict']
246 | model_dict.update(filtered_orig_dict)
247 | model.load_state_dict(model_dict)
248 | return model
249 |
250 |
251 | def resnet50_v1s(pretrained=False, weight_dir=None, **kwargs):
252 | model = ResNetV1b(BottleneckV1b, [3, 4, 6, 3], deep_stem=True, stem_width=64, **kwargs)
253 | if pretrained:
254 | model_dict = model.state_dict()
255 | if weight_dir is None:
256 | filtered_orig_dict = _safe_state_dict_filtering(
257 | torch.hub.load(GLUON_RESNET_TORCH_HUB, 'gluon_resnet50_v1s', pretrained=True).state_dict(),
258 | model_dict.keys()
259 | )
260 | else:
261 | filtered_orig_dict = torch.load(weight_dir,map_location='cpu')
262 | model_dict.update(filtered_orig_dict)
263 | model.load_state_dict(model_dict)
264 | return model
265 |
266 |
267 | def resnet101_v1s(pretrained=False, weight_dir=None, **kwargs):
268 | model = ResNetV1b(BottleneckV1b, [3, 4, 23, 3], deep_stem=True, stem_width=64, **kwargs)
269 | if pretrained:
270 | model_dict = model.state_dict()
271 | if weight_dir is None:
272 | filtered_orig_dict = _safe_state_dict_filtering(
273 | torch.hub.load(GLUON_RESNET_TORCH_HUB, 'gluon_resnet101_v1s', pretrained=True).state_dict(),
274 | model_dict.keys()
275 | )
276 | else:
277 | filtered_orig_dict = torch.load(weight_dir,map_location='cpu')
278 | model_dict.update(filtered_orig_dict)
279 | model.load_state_dict(model_dict)
280 | return model
281 |
282 |
283 | def resnet152_v1s(pretrained=False, weight_dir=None, **kwargs):
284 | model = ResNetV1b(BottleneckV1b, [3, 8, 36, 3], deep_stem=True, stem_width=64, **kwargs)
285 | if pretrained:
286 | model_dict = model.state_dict()
287 | if weight_dir is None:
288 | filtered_orig_dict = _safe_state_dict_filtering(
289 | torch.hub.load(GLUON_RESNET_TORCH_HUB, 'gluon_resnet152_v1s', pretrained=True).state_dict(),
290 | model_dict.keys()
291 | )
292 | else:
293 | filtered_orig_dict = torch.load(weight_dir,map_location='cpu')
294 | model_dict.update(filtered_orig_dict)
295 | model.load_state_dict(model_dict)
296 | return model
297 |
--------------------------------------------------------------------------------
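
The two loading paths, for reference (not part of the repository; the local path below is hypothetical): without weight_dir the pretrained weights are fetched from torch.hub, otherwise they are read from a local state-dict file.

    model = resnet50_v1s(pretrained=False)  # random init, nothing downloaded
    # model = resnet50_v1s(pretrained=True)  # fetch gluon_resnet50_v1s via torch.hub
    # model = resnet50_v1s(pretrained=True, weight_dir='weights/resnet50_v1s.pth')
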
/isegm/model/modifiers.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | class LRMult(object):
4 | def __init__(self, lr_mult=1.):
5 | self.lr_mult = lr_mult
6 |
7 | def __call__(self, m):
8 | if getattr(m, 'weight', None) is not None:
9 | m.weight.lr_mult = self.lr_mult
10 | if getattr(m, 'bias', None) is not None:
11 | m.bias.lr_mult = self.lr_mult
12 |
--------------------------------------------------------------------------------
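
Usage sketch (not part of the repository): LRMult is meant to be used with Module.apply; it tags every weight/bias with an lr_mult attribute that an optimizer builder (here, isegm/engine/optimizer.py) can translate into per-parameter learning rates, e.g. to train the backbone more slowly than the heads.

    import torch.nn as nn

    backbone = nn.Conv2d(3, 16, 3)
    backbone.apply(LRMult(0.1))     # tag backbone params to train 10x slower
    print(backbone.weight.lr_mult)  # 0.1
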
/isegm/model/ops.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn as nn
3 | import numpy as np
4 | import isegm.model.initializer as initializer
5 |
6 |
7 | def select_activation_function(activation):
8 | if isinstance(activation, str):
9 | if activation.lower() == 'relu':
10 | return nn.ReLU
11 | elif activation.lower() == 'softplus':
12 | return nn.Softplus
13 | else:
14 | raise ValueError(f"Unknown activation type {activation}")
15 | elif isinstance(activation, nn.Module):
16 | return activation
17 | else:
18 | raise ValueError(f"Unknown activation type {activation}")
19 |
20 |
21 | class BilinearConvTranspose2d(nn.ConvTranspose2d):
22 | def __init__(self, in_channels, out_channels, scale, groups=1):
23 | kernel_size = 2 * scale - scale % 2
24 | self.scale = scale
25 |
26 | super().__init__(
27 | in_channels, out_channels,
28 | kernel_size=kernel_size,
29 | stride=scale,
30 | padding=1,
31 | groups=groups,
32 | bias=False)
33 |
34 | self.apply(initializer.Bilinear(scale=scale, in_channels=in_channels, groups=groups))
35 |
36 |
37 | class DistMaps(nn.Module):
38 | def __init__(self, norm_radius, spatial_scale=1.0, cpu_mode=False, use_disks=False):
39 | super(DistMaps, self).__init__()
40 | self.spatial_scale = spatial_scale
41 | self.norm_radius = norm_radius
42 | self.cpu_mode = cpu_mode
43 | self.use_disks = use_disks
44 | if self.cpu_mode:
45 | from isegm.utils.cython import get_dist_maps
46 | self._get_dist_maps = get_dist_maps
47 |
48 | def get_coord_features(self, points, batchsize, rows, cols):
49 | if self.cpu_mode:
50 | coords = []
51 | for i in range(batchsize):
52 | norm_delimeter = 1.0 if self.use_disks else self.spatial_scale * self.norm_radius
53 | coords.append(self._get_dist_maps(points[i].cpu().float().numpy(), rows, cols,
54 | norm_delimeter))
55 | coords = torch.from_numpy(np.stack(coords, axis=0)).to(points.device).float()
56 | else:
57 | num_points = points.shape[1] // 2
58 | points = points.view(-1, points.size(2))
59 | points, points_order = torch.split(points, [2, 1], dim=1)
60 |
61 | invalid_points = torch.max(points, dim=1, keepdim=False)[0] < 0
62 | row_array = torch.arange(start=0, end=rows, step=1, dtype=torch.float32, device=points.device)
63 | col_array = torch.arange(start=0, end=cols, step=1, dtype=torch.float32, device=points.device)
64 |
65 | coord_rows, coord_cols = torch.meshgrid(row_array, col_array)
66 | coords = torch.stack((coord_rows, coord_cols), dim=0).unsqueeze(0).repeat(points.size(0), 1, 1, 1)
67 |
68 | add_xy = (points * self.spatial_scale).view(points.size(0), points.size(1), 1, 1)
69 | coords.add_(-add_xy)
70 | if not self.use_disks:
71 | coords.div_(self.norm_radius * self.spatial_scale)
72 | coords.mul_(coords)
73 |
74 | coords[:, 0] += coords[:, 1]
75 | coords = coords[:, :1]
76 |
77 | coords[invalid_points, :, :, :] = 1e6
78 |
79 | coords = coords.view(-1, num_points, 1, rows, cols)
80 | coords = coords.min(dim=1)[0] # -> (bs * num_masks * 2) x 1 x h x w
81 | coords = coords.view(-1, 2, rows, cols)
82 |
83 | if self.use_disks:
84 | coords = (coords <= (self.norm_radius * self.spatial_scale) ** 2).float()
85 | else:
86 | coords.sqrt_().mul_(2).tanh_()
87 |
88 | return coords
89 |
90 | def forward(self, x, coords):
91 | return self.get_coord_features(coords, x.shape[0], x.shape[2], x.shape[3])
92 |
93 |
94 | class ScaleLayer(nn.Module):
95 | def __init__(self, init_value=1.0, lr_mult=1):
96 | super().__init__()
97 | self.lr_mult = lr_mult
98 | self.scale = nn.Parameter(
99 | torch.full((1,), init_value / lr_mult, dtype=torch.float32)
100 | )
101 |
102 | def forward(self, x):
103 | scale = torch.abs(self.scale * self.lr_mult)
104 | return x * scale
105 |
106 |
107 | class BatchImageNormalize:
108 | def __init__(self, mean, std, dtype=torch.float):
109 | self.mean = torch.as_tensor(mean, dtype=dtype)[None, :, None, None]
110 | self.std = torch.as_tensor(std, dtype=dtype)[None, :, None, None]
111 |
112 | def __call__(self, tensor):
113 | tensor = tensor.clone()
114 |
115 | tensor.sub_(self.mean.to(tensor.device)).div_(self.std.to(tensor.device))
116 | return tensor
117 |
--------------------------------------------------------------------------------
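
Click-encoding sketch (not part of the repository): points are packed as (row, col, click_index) with positive clicks in the first half and (-1, -1, -1) padding; with use_disks=True each valid click becomes a binary disk of radius norm_radius, one output channel for positives and one for negatives.

    import torch

    dist_maps = DistMaps(norm_radius=5, spatial_scale=1.0, use_disks=True)
    image = torch.zeros(1, 3, 64, 64)
    points = torch.tensor([[[20., 30., 0.],      # one positive click
                            [-1., -1., -1.]]])   # padded (absent) negative click
    maps = dist_maps(image, points)
    print(maps.shape)  # torch.Size([1, 2, 64, 64]); second channel stays empty
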
/isegm/utils/crop_local.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 | from skimage.measure import label
4 |
5 | def map_point_in_bbox(y,x,y1,y2,x1,x2,crop_l):
6 | h,w = y2-y1, x2-x1
7 | ry,rx = crop_l/h, crop_l/w
8 | y = (y - y1) * ry
9 | x = (x - x1) * rx
10 | return y,x
11 |
12 |
13 |
14 | def get_focus_cropv1(pred_mask, previous_mask, global_roi, y,x, ratio):
15 | pred_mask = pred_mask > 0.49
16 | previous_mask = previous_mask > 0.49
17 | ymin,ymax,xmin,xmax = global_roi
18 | diff_regions = np.logical_xor(previous_mask, pred_mask)
19 | if previous_mask.sum() == 0:
20 | y1,y2,x1,x2 = get_bbox_from_mask(pred_mask)
21 | else:
22 | num, labels = cv2.connectedComponents( diff_regions.astype(np.uint8))
23 | label = labels[y,x]
24 | diff_conn_mask = labels == label
25 | y1d,y2d,x1d,x2d = get_bbox_from_mask(diff_conn_mask)
26 | hd,wd = y2d - y1d, x2d - x1d
27 |
28 | y1p,y2p,x1p,x2p= get_bbox_from_mask(pred_mask)
29 | hp,wp = y2p - y1p, x2p - x1p
30 |
31 | if hd < hp/3 or wd < wp/3:
32 | r = 0.2
33 | l = max(hp,wp)
34 | y1,y2,x1,x2 = y - r *l, y + r * l, x - r * l, x + r * l
35 | else:
36 | y1,y2,x1,x2 = y1d,y2d,x1d,x2d
37 |
38 | y1,y2,x1,x2 = expand_bbox(pred_mask,y1,y2,x1,x2,ratio )
39 | y1 = max(y1,ymin)
40 | y2 = min(y2,ymax)
41 | x1 = max(x1,xmin)
42 | x2 = min(x2,xmax)
43 | return y1,y2,x1,x2
44 |
45 |
46 | def get_focus_cropv2(pred_mask, previous_mask, global_roi, y,x, ratio):
47 | pred_mask = pred_mask > 0.5
48 | previous_mask = previous_mask > 0.5
49 | ymin,ymax,xmin,xmax = global_roi
50 | diff_regions = np.logical_xor(previous_mask, pred_mask)
51 | num, labels = cv2.connectedComponents( diff_regions.astype(np.uint8))
52 | label = labels[y,x]
53 | diff_conn_mask = labels == label
54 |
55 | y1d,y2d,x1d,x2d = get_bbox_from_mask(diff_conn_mask)
56 | hd,wd = y2d - y1d, x2d - x1d
57 |
58 | y1p,y2p,x1p,x2p= get_bbox_from_mask(pred_mask)
59 | hp,wp = y2p - y1p, x2p - x1p
60 |
61 | if previous_mask.sum() == 0:
62 | y1,y2,x1,x2 = y1p,y2p,x1p,x2p
63 | else:
64 | if hd < hp/3 or wd < wp/3:
65 | r = 0.16
66 | l = max(hp,wp)
67 | y1,y2,x1,x2 = y - r *l, y + r * l, x - r * l, x + r * l
68 | else:
69 | if diff_conn_mask.sum() > diff_regions.sum() * 0.5:
70 | y1,y2,x1,x2 = y1d,y2d,x1d,x2d
71 | else:
72 | y1,y2,x1,x2 = y1p,y2p,x1p,x2p
73 | y1,y2,x1,x2 = expand_bbox(pred_mask,y1,y2,x1,x2,ratio )
74 | y1 = max(y1,ymin)
75 | y2 = min(y2,ymax)
76 | x1 = max(x1,xmin)
77 | x2 = min(x2,xmax)
78 | return y1,y2,x1,x2
79 |
80 |
81 | def get_object_crop(pred_mask, previous_mask, global_roi, y,x, ratio):
82 | pred_mask = pred_mask > 0.49
83 | y1,y2,x1,x2 = get_bbox_from_mask(pred_mask)
84 | y1,y2,x1,x2 = expand_bbox(pred_mask,y1,y2,x1,x2,ratio )
85 | ymin,ymax,xmin,xmax = global_roi
86 | y1 = max(y1,ymin)
87 | y2 = min(y2,ymax)
88 | x1 = max(x1,xmin)
89 | x2 = min(x2,xmax)
90 | return y1,y2,x1,x2
91 |
92 |
93 |
94 | def get_click_crop(pred_mask, previous_mask, global_roi, y,x, ratio):
95 | pred_mask = pred_mask > 0.49
96 | y1p,y2p,x1p,x2p= get_bbox_from_mask(pred_mask)
97 | hp,wp = y2p - y1p, x2p - x1p
98 | r = 0.2
99 | l = max(hp,wp)
100 | y1,y2,x1,x2 = y - r *l, y + r * l, x - r * l, x + r * l
101 | y1,y2,x1,x2 = expand_bbox(pred_mask,y1,y2,x1,x2,ratio )
102 | ymin,ymax,xmin,xmax = global_roi
103 | y1 = max(y1,ymin)
104 | y2 = min(y2,ymax)
105 | x1 = max(x1,xmin)
106 | x2 = min(x2,xmax)
107 | return y1,y2,x1,x2
108 |
109 |
110 |
111 |
112 | def getLargestCC(segmentation):
113 | if segmentation.sum()<10:
114 | return segmentation
115 | labels = label(segmentation)
116 | largestCC = labels == np.argmax(np.bincount(labels.flat)[1:])+1
117 | return largestCC
118 |
119 |
120 |
121 | def get_diff_region(pred_mask, previous_mask, y, x):
122 | y,x = int(y), int(x)
123 | diff_regions = np.logical_xor(previous_mask, pred_mask)
124 | if diff_regions.sum() > 1000:
125 | num, labels = cv2.connectedComponents( diff_regions.astype(np.uint8))
126 | label = labels[y,x]
127 | corr_mask = labels == label
128 | else:
129 | corr_mask = pred_mask
130 | return corr_mask
131 |
132 |
133 |
134 |
135 | def get_bbox_from_mask(mask):
136 | h,w = mask.shape[0],mask.shape[1]
137 |
138 | if mask.sum() < 10:
139 | return 0,h,0,w
140 | rows = np.any(mask,axis=1)
141 | cols = np.any(mask,axis=0)
142 | y1,y2 = np.where(rows)[0][[0,-1]]
143 | x1,x2 = np.where(cols)[0][[0,-1]]
144 | return y1,y2,x1,x2
145 |
146 | def expand_bbox(mask,y1,y2,x1,x2,ratio, min_crop=0):
147 | H,W = mask.shape[0], mask.shape[1]
148 | xc, yc = 0.5 * (x1 + x2), 0.5 * (y1 + y2)
149 | h = ratio * (y2-y1+1)
150 | w = ratio * (x2-x1+1)
151 | h = max(h,min_crop)
152 | w = max(w,min_crop)
153 |
154 | x1 = int(xc - w * 0.5)
155 | x2 = int(xc + w * 0.5)
156 | y1 = int(yc - h * 0.5)
157 | y2 = int(yc + h * 0.5)
158 |
159 | x1 = max(0,x1)
160 | x2 = min(W,x2)
161 | y1 = max(0,y1)
162 | y2 = min(H,y2)
163 | return y1,y2,x1,x2
164 |
165 |
166 | def expand_bbox_with_bias(mask,y1,y2,x1,x2,ratio, min_crop=0, bias = 0.3):
167 | H,W = mask.shape[0], mask.shape[1]
168 | xc, yc = 0.5 * (x1 + x2), 0.5 * (y1 + y2)
169 | h = ratio * (y2-y1+1)
170 | w = ratio * (x2-x1+1)
171 | h = max(h,min_crop)
172 | w = max(w,min_crop)
173 | hmax, wmax = int(h * bias), int(w * bias)
174 | h_bias = np.random.randint(-hmax,hmax+1)
175 | w_bias = np.random.randint(-wmax,wmax+1)
176 |
177 | x1 = int(xc - w * 0.5) + w_bias
178 | x2 = int(xc + w * 0.5) + w_bias
179 | y1 = int(yc - h * 0.5) + h_bias
180 | y2 = int(yc + h * 0.5) + h_bias
181 |
182 | x1 = max(0,x1)
183 | x2 = min(W,x2)
184 | y1 = max(0,y1)
185 | y2 = min(H,y2)
186 | return y1,y2,x1,x2
187 |
188 |
189 |
190 | def CalBox(mask,last_y = None, last_x = None, expand = 1.5):
191 | y1,y2,x1,x2 = get_bbox_from_mask(mask)
192 | H,W = mask.shape[0], mask.shape[1]
193 | if last_y is not None:
194 | y1 = min(y1,last_y)
195 | y2 = max(y2,last_y)
196 | x1 = min(x1, last_x)
197 | x2 = max(x2,last_x)
198 |
199 | xc, yc = 0.5 * (x1 + x2), 0.5 * (y1 + y2)
200 | h = expand * (y2-y1+1)
201 | w = expand * (x2-x1+1)
202 | x1 = int(xc - w * 0.5)
203 | x2 = int(xc + w * 0.5)
204 | y1 = int(yc - h * 0.5)
205 | y2 = int(yc + h * 0.5)
206 |
207 | x1 = max(0,x1)
208 | x2 = min(W,x2)
209 | y1 = max(0,y1)
210 | y2 = min(H,y2)
211 | return y1,y2,x1,x2
212 |
213 | def points_back(p_np, y1, x1):
214 | if p_np is None:
215 | return None
216 | bias = np.array( [[y1,x1]]).reshape((1,2))
217 | return p_np + bias
218 |
219 |
220 |
221 |
222 | def PointsInBox(points,y1,y2,x1,x2, H, W ):
223 | if points is None:
224 | return None
225 |
226 | y_ratio = H/(y2-y1)
227 | x_ratio = W/(x2-x1)
228 | num_pos = points.shape[0] // 2
229 | new_points = np.full_like(points,-1)
230 |
231 | valid_pos = 0
232 | for i in range(num_pos):
233 | y,x,index = points[i,0], points[i,1],points[i,2]
234 |         if y>y1 and y< y2 and x>x1 and x<x2:
235 |             y_new = (y - y1) * y_ratio
236 |             x_new = (x - x1) * x_ratio
237 |             new_points[valid_pos] = np.array([y_new, x_new, index])
238 |             valid_pos += 1
239 | 
240 |     valid_neg = num_pos
241 |     for i in range(num_pos, 2 * num_pos):
242 |         y,x,index = points[i,0], points[i,1],points[i,2]
243 |         if y>y1 and y< y2 and x>x1 and x<x2:
244 |             y_new = (y - y1) * y_ratio
245 |             x_new = (x - x1) * x_ratio
246 |             new_points[valid_neg] = np.array([y_new, x_new, index])
247 |             valid_neg += 1
248 |     return new_points
249 | 
250 | 
251 | 
252 | def random_choose_unknown(unknown, crop_size):
253 |     """Randomly choose the top-left point of the crop bbox.
254 | 
255 |     Args:
256 |         unknown (np.ndarray): The binary unknown mask.
257 |         crop_size (tuple[int]): The given crop size.
258 | 
259 |     Returns:
260 |         tuple[int]: The top-left point of the crop bbox.
261 |     """
262 |     h, w = unknown.shape
263 |     crop_h, crop_w = crop_size
264 | 
265 |     # assert crop_h < h and crop_w < w
266 |     if crop_h > h or crop_w > w:
267 |         return 0,0,h,w
268 |
269 |
270 | delta_h = center_h = crop_h // 2
271 | delta_w = center_w = crop_w // 2
272 |
273 |     # mask out the valid area for selecting the cropping center
274 | mask = np.zeros_like(unknown)
275 | mask[delta_h:h - delta_h, delta_w:w - delta_w] = 1
276 | if np.any(unknown & mask):
277 | center_h_list, center_w_list = np.where(unknown & mask)
278 | elif np.any(unknown):
279 | center_h_list, center_w_list = np.where(unknown)
280 | else:
281 | #print_log('No unknown pixels found!', level=logging.WARNING)
282 | center_h_list = [center_h]
283 | center_w_list = [center_w]
284 | num_unknowns = len(center_h_list)
285 | rand_ind = np.random.randint(num_unknowns)
286 | center_h = center_h_list[rand_ind]
287 | center_w = center_w_list[rand_ind]
288 |
289 | # make sure the top-left point is valid
290 | top = np.clip(center_h - delta_h, 0, h - crop_h)
291 | left = np.clip(center_w - delta_w, 0, w - crop_w)
292 | y1,x1,y2,x2 = top, left, top + crop_h, left + crop_w
293 |
294 | return y1,x1,y2,x2
--------------------------------------------------------------------------------
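
Usage sketch (not part of the repository): the crop helpers above compose the same way everywhere — take the mask's bounding box, grow it by a ratio, and clamp it to the image.

    import numpy as np

    mask = np.zeros((100, 100), dtype=bool)
    mask[40:60, 30:70] = True
    y1, y2, x1, x2 = get_bbox_from_mask(mask)  # (40, 59, 30, 69)
    y1, y2, x1, x2 = expand_bbox(mask, y1, y2, x1, x2, ratio=1.4)
    print(y1, y2, x1, x2)  # expanded box, clipped to the 100x100 image
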
/isegm/utils/cython/__init__.py:
--------------------------------------------------------------------------------
1 | # noinspection PyUnresolvedReferences
2 | from .dist_maps import get_dist_maps
--------------------------------------------------------------------------------
/isegm/utils/cython/_get_dist_maps.pyx:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | cimport cython
3 | cimport numpy as np
4 | from libc.stdlib cimport malloc, free
5 |
6 | ctypedef struct qnode:
7 | int row
8 | int col
9 | int layer
10 | int orig_row
11 | int orig_col
12 |
13 | @cython.infer_types(True)
14 | @cython.boundscheck(False)
15 | @cython.wraparound(False)
16 | @cython.nonecheck(False)
17 | def get_dist_maps(np.ndarray[np.float32_t, ndim=2, mode="c"] points,
18 | int height, int width, float norm_delimeter):
19 | cdef np.ndarray[np.float32_t, ndim=3, mode="c"] dist_maps = \
20 | np.full((2, height, width), 1e6, dtype=np.float32, order="C")
21 |
22 | cdef int *dxy = [-1, 0, 0, -1, 0, 1, 1, 0]
23 | cdef int i, j, x, y, dx, dy
24 | cdef qnode v
25 | cdef qnode *q = malloc((4 * height * width + 1) * sizeof(qnode))
26 | cdef int qhead = 0, qtail = -1
27 | cdef float ndist
28 |
29 | for i in range(points.shape[0]):
30 | x, y = round(points[i, 0]), round(points[i, 1])
31 | if x >= 0:
32 | qtail += 1
33 | q[qtail].row = x
34 | q[qtail].col = y
35 | q[qtail].orig_row = x
36 | q[qtail].orig_col = y
37 | if i >= points.shape[0] / 2:
38 | q[qtail].layer = 1
39 | else:
40 | q[qtail].layer = 0
41 | dist_maps[q[qtail].layer, x, y] = 0
42 |
43 | while qtail - qhead + 1 > 0:
44 | v = q[qhead]
45 | qhead += 1
46 |
47 | for k in range(4):
48 | x = v.row + dxy[2 * k]
49 | y = v.col + dxy[2 * k + 1]
50 |
51 | ndist = ((x - v.orig_row)/norm_delimeter) ** 2 + ((y - v.orig_col)/norm_delimeter) ** 2
52 | if (x >= 0 and y >= 0 and x < height and y < width and
53 | dist_maps[v.layer, x, y] > ndist):
54 | qtail += 1
55 | q[qtail].orig_col = v.orig_col
56 | q[qtail].orig_row = v.orig_row
57 | q[qtail].layer = v.layer
58 | q[qtail].row = x
59 | q[qtail].col = y
60 | dist_maps[v.layer, x, y] = ndist
61 |
62 | free(q)
63 | return dist_maps
64 |
--------------------------------------------------------------------------------
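
Usage sketch (not part of the repository; pyximport compiles this module with a C++ compiler on first import): points are float32 rows of (row, col) with positive clicks in the first half; the BFS fills two H x W maps of squared normalized distances, left at 1e6 wherever no click of that polarity exists.

    import numpy as np
    from isegm.utils.cython import get_dist_maps

    points = np.array([[10., 15.],    # positive click
                       [-1., -1.]],   # padded (absent) negative click
                      dtype=np.float32)
    maps = get_dist_maps(points, 32, 32, 1.0)  # height, width, norm_delimeter
    print(maps.shape)  # (2, 32, 32)
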
/isegm/utils/cython/_get_dist_maps.pyxbld:
--------------------------------------------------------------------------------
1 | import numpy
2 |
3 | def make_ext(modname, pyxfilename):
4 | from distutils.extension import Extension
5 | return Extension(modname, [pyxfilename],
6 | include_dirs=[numpy.get_include()],
7 | extra_compile_args=['-O3'], language='c++')
8 |
--------------------------------------------------------------------------------
/isegm/utils/cython/dist_maps.py:
--------------------------------------------------------------------------------
1 | import pyximport; pyximport.install(pyximport=True, language_level=3)
2 | # noinspection PyUnresolvedReferences
3 | from ._get_dist_maps import get_dist_maps
--------------------------------------------------------------------------------
/isegm/utils/distributed.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import distributed as dist
3 | from torch.utils import data
4 |
5 |
6 | def get_rank():
7 | if not dist.is_available() or not dist.is_initialized():
8 | return 0
9 | return dist.get_rank()
10 |
11 |
12 | def synchronize():
13 | if not dist.is_available() or not dist.is_initialized() or dist.get_world_size() == 1:
14 | return
15 | dist.barrier()
16 |
17 |
18 | def get_world_size():
19 | if not dist.is_available() or not dist.is_initialized():
20 | return 1
21 |
22 | return dist.get_world_size()
23 |
24 |
25 | def reduce_loss_dict(loss_dict):
26 | world_size = get_world_size()
27 |
28 | if world_size < 2:
29 | return loss_dict
30 |
31 | with torch.no_grad():
32 | keys = []
33 | losses = []
34 |
35 | for k in loss_dict.keys():
36 | keys.append(k)
37 | losses.append(loss_dict[k])
38 |
39 | losses = torch.stack(losses, 0)
40 | dist.reduce(losses, dst=0)
41 |
42 | if dist.get_rank() == 0:
43 | losses /= world_size
44 |
45 | reduced_losses = {k: v for k, v in zip(keys, losses)}
46 |
47 | return reduced_losses
48 |
49 |
50 | def get_sampler(dataset, shuffle, distributed):
51 | if distributed:
52 | return data.distributed.DistributedSampler(dataset, shuffle=shuffle)
53 |
54 | if shuffle:
55 | return data.RandomSampler(dataset)
56 | else:
57 | return data.SequentialSampler(dataset)
58 |
59 |
60 | def get_dp_wrapper(distributed):
61 | class DPWrapper(torch.nn.parallel.DistributedDataParallel if distributed else torch.nn.DataParallel):
62 | def __getattr__(self, name):
63 | try:
64 | return super().__getattr__(name)
65 | except AttributeError:
66 | return getattr(self.module, name)
67 | return DPWrapper
68 |
--------------------------------------------------------------------------------
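
Behavior note (not part of the repository): in a single-process run every helper above degrades gracefully, so the same training code works with or without torch.distributed being initialized.

    import torch

    loss_dict = {'instance_loss': torch.tensor(0.5)}
    print(get_rank(), get_world_size(), reduce_loss_dict(loss_dict))
    # 0 1 {'instance_loss': tensor(0.5000)}
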
/isegm/utils/exp.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import shutil
4 | import pprint
5 | from pathlib import Path
6 | from datetime import datetime
7 |
8 | import yaml
9 | import torch
10 | from easydict import EasyDict as edict
11 |
12 | from .log import logger, add_logging
13 | from .distributed import synchronize, get_world_size
14 |
15 |
16 | def init_experiment(args, model_name):
17 | model_path = Path(args.model_path)
18 | ftree = get_model_family_tree(model_path, model_name=model_name)
19 |
20 | if ftree is None:
21 | print('Models can only be located in the "models" directory in the root of the repository')
22 | sys.exit(1)
23 |
24 | cfg = load_config(model_path)
25 | update_config(cfg, args)
26 |
27 | cfg.distributed = args.distributed
28 | cfg.local_rank = args.local_rank
29 | if cfg.distributed:
30 | torch.distributed.init_process_group(backend='nccl', init_method='env://')
31 | if args.workers > 0:
32 | torch.multiprocessing.set_start_method('forkserver', force=True)
33 |
34 | experiments_path = Path(cfg.EXPS_PATH)
35 | exp_parent_path = experiments_path / '/'.join(ftree)
36 | exp_parent_path.mkdir(parents=True, exist_ok=True)
37 |
38 | if cfg.resume_exp:
39 | exp_path = find_resume_exp(exp_parent_path, cfg.resume_exp)
40 | else:
41 | last_exp_indx = find_last_exp_indx(exp_parent_path)
42 | exp_name = f'{last_exp_indx:03d}'
43 | if cfg.exp_name:
44 | exp_name += '_' + cfg.exp_name
45 | exp_path = exp_parent_path / exp_name
46 | synchronize()
47 | if cfg.local_rank == 0:
48 | exp_path.mkdir(parents=True)
49 |
50 | cfg.EXP_PATH = exp_path
51 | cfg.CHECKPOINTS_PATH = exp_path / 'checkpoints'
52 | cfg.VIS_PATH = exp_path / 'vis'
53 | cfg.LOGS_PATH = exp_path / 'logs'
54 |
55 | if cfg.local_rank == 0:
56 | cfg.LOGS_PATH.mkdir(exist_ok=True)
57 | cfg.CHECKPOINTS_PATH.mkdir(exist_ok=True)
58 | cfg.VIS_PATH.mkdir(exist_ok=True)
59 |
60 | dst_script_path = exp_path / (model_path.stem + datetime.strftime(datetime.today(), '_%Y-%m-%d-%H-%M-%S.py'))
61 | if args.temp_model_path:
62 | shutil.copy(args.temp_model_path, dst_script_path)
63 | os.remove(args.temp_model_path)
64 | else:
65 | shutil.copy(model_path, dst_script_path)
66 |
67 | synchronize()
68 |
69 | if cfg.gpus != '':
70 | gpu_ids = [int(id) for id in cfg.gpus.split(',')]
71 | else:
72 | gpu_ids = list(range(max(cfg.ngpus, get_world_size())))
73 | cfg.gpus = ','.join([str(id) for id in gpu_ids])
74 |
75 | cfg.gpu_ids = gpu_ids
76 | cfg.ngpus = len(gpu_ids)
77 | cfg.multi_gpu = cfg.ngpus > 1
78 |
79 | if cfg.distributed:
80 | cfg.device = torch.device('cuda')
81 | cfg.gpu_ids = [cfg.gpu_ids[cfg.local_rank]]
82 | torch.cuda.set_device(cfg.gpu_ids[0])
83 | else:
84 | if cfg.multi_gpu:
85 | os.environ['CUDA_VISIBLE_DEVICES'] = cfg.gpus
86 | ngpus = torch.cuda.device_count()
87 | # Added by Xavier
88 | # cfg.gpu_ids = [i for i in range(ngpus)]
89 | # assert ngpus == cfg.ngpus
90 | cfg.device = torch.device(f'cuda:{cfg.gpu_ids[0]}')
91 |
92 | if cfg.local_rank == 0:
93 | add_logging(cfg.LOGS_PATH, prefix='train_')
94 | logger.info(f'Number of GPUs: {cfg.ngpus}')
95 | if cfg.distributed:
96 |             logger.info('Multi-Process Multi-GPU Distributed Training')
97 |
98 | logger.info('Run experiment with config:')
99 | logger.info(pprint.pformat(cfg, indent=4))
100 |
101 | return cfg
102 |
103 |
104 | def get_model_family_tree(model_path, terminate_name='models', model_name=None):
105 | if model_name is None:
106 | model_name = model_path.stem
107 | family_tree = [model_name]
108 | for x in model_path.parents:
109 | if x.stem == terminate_name:
110 | break
111 | family_tree.append(x.stem)
112 | else:
113 | return None
114 |
115 | return family_tree[::-1]
116 |
117 |
118 | def find_last_exp_indx(exp_parent_path):
119 | indx = 0
120 | for x in exp_parent_path.iterdir():
121 | if not x.is_dir():
122 | continue
123 |
124 | exp_name = x.stem
125 | if exp_name[:3].isnumeric():
126 | indx = max(indx, int(exp_name[:3]) + 1)
127 |
128 | return indx
129 |
130 |
131 | def find_resume_exp(exp_parent_path, exp_pattern):
132 | candidates = sorted(exp_parent_path.glob(f'{exp_pattern}*'))
133 | if len(candidates) == 0:
134 |         print(f'No experiments could be found that satisfy the pattern "{exp_pattern}*"')
135 | sys.exit(1)
136 | elif len(candidates) > 1:
137 | print('More than one experiment found:')
138 | for x in candidates:
139 | print(x)
140 | sys.exit(1)
141 | else:
142 | exp_path = candidates[0]
143 | print(f'Continue with experiment "{exp_path}"')
144 |
145 | return exp_path
146 |
147 |
148 | def update_config(cfg, args):
149 | for param_name, value in vars(args).items():
150 | if param_name.lower() in cfg or param_name.upper() in cfg:
151 | continue
152 | cfg[param_name] = value
153 |
154 |
155 | def load_config(model_path):
156 | model_name = model_path.stem
157 | config_path = model_path.parent / (model_name + '.yml')
158 |
159 | if config_path.exists():
160 | cfg = load_config_file(config_path)
161 | else:
162 | cfg = dict()
163 |
164 | cwd = Path.cwd()
165 | config_parent = config_path.parent.absolute()
166 | while len(config_parent.parents) > 0:
167 | config_path = config_parent / 'config.yml'
168 |
169 | if config_path.exists():
170 | local_config = load_config_file(config_path, model_name=model_name)
171 | cfg.update({k: v for k, v in local_config.items() if k not in cfg})
172 |
173 | if config_parent.absolute() == cwd:
174 | break
175 | config_parent = config_parent.parent
176 |
177 | return edict(cfg)
178 |
179 |
180 | def load_config_file(config_path, model_name=None, return_edict=False):
181 | with open(config_path, 'r') as f:
182 | cfg = yaml.safe_load(f)
183 |
184 | if 'SUBCONFIGS' in cfg:
185 | if model_name is not None and model_name in cfg['SUBCONFIGS']:
186 | cfg.update(cfg['SUBCONFIGS'][model_name])
187 | del cfg['SUBCONFIGS']
188 |
189 | return edict(cfg) if return_edict else cfg
190 |
--------------------------------------------------------------------------------
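The config resolution in `load_config` cascades: a model-local `<model>.yml` is read first if it exists, then every `config.yml` from the model's directory up to the working directory is merged in without overwriting keys that are already set, so the most specific file wins. A short sketch (the printed key assumes the repo-root `config.yml` defines `EXPS_PATH`, as `init_experiment` expects):

```python
from pathlib import Path
from isegm.utils.exp import load_config

# Hypothetical layout (paths are illustrative):
#   ./config.yml                   -> repo-wide defaults (EXPS_PATH, dataset paths)
#   ./models/gp_sbd_resnet50.yml   -> optional model-specific overrides
cfg = load_config(Path('models/gp_sbd_resnet50.py'))
print(cfg.EXPS_PATH)  # filled from the closest config file that defines it
```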
/isegm/utils/exp_imports/default.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from functools import partial
3 | from easydict import EasyDict as edict
4 | from albumentations import *
5 |
6 | from isegm.data.datasets import *
7 | from isegm.model.losses import *
8 | from isegm.data.transforms import *
9 | from isegm.model.metrics import AdaptiveIoU
10 | from isegm.data.points_sampler import MultiPointSampler
11 | from isegm.utils.log import logger
12 | from isegm.model import initializer
13 |
14 |
--------------------------------------------------------------------------------
/isegm/utils/log.py:
--------------------------------------------------------------------------------
1 | import io
2 | import time
3 | import logging
4 | from datetime import datetime
5 |
6 | import numpy as np
7 | from torch.utils.tensorboard import SummaryWriter
8 |
9 | LOGGER_NAME = 'root'
10 | LOGGER_DATEFMT = '%Y-%m-%d %H:%M:%S'
11 |
12 | handler = logging.StreamHandler()
13 |
14 | logger = logging.getLogger(LOGGER_NAME)
15 | logger.setLevel(logging.INFO)
16 | logger.addHandler(handler)
17 |
18 |
19 | def add_logging(logs_path, prefix):
20 | log_name = prefix + datetime.strftime(datetime.today(), '%Y-%m-%d_%H-%M-%S') + '.log'
21 | stdout_log_path = logs_path / log_name
22 |
23 | fh = logging.FileHandler(str(stdout_log_path))
24 | formatter = logging.Formatter(fmt='(%(levelname)s) %(asctime)s: %(message)s',
25 | datefmt=LOGGER_DATEFMT)
26 | fh.setFormatter(formatter)
27 | logger.addHandler(fh)
28 |
29 |
30 | class TqdmToLogger(io.StringIO):
31 | logger = None
32 | level = None
33 | buf = ''
34 |
35 | def __init__(self, logger, level=None, mininterval=5):
36 | super(TqdmToLogger, self).__init__()
37 | self.logger = logger
38 | self.level = level or logging.INFO
39 | self.mininterval = mininterval
40 | self.last_time = 0
41 |
42 | def write(self, buf):
43 | self.buf = buf.strip('\r\n\t ')
44 |
45 | def flush(self):
46 | if len(self.buf) > 0 and time.time() - self.last_time > self.mininterval:
47 | self.logger.log(self.level, self.buf)
48 | self.last_time = time.time()
49 |
50 |
51 | class SummaryWriterAvg(SummaryWriter):
52 | def __init__(self, *args, dump_period=20, **kwargs):
53 | super().__init__(*args, **kwargs)
54 | self._dump_period = dump_period
55 | self._avg_scalars = dict()
56 |
57 | def add_scalar(self, tag, value, global_step=None, disable_avg=False):
58 | if disable_avg or isinstance(value, (tuple, list, dict)):
59 | super().add_scalar(tag, np.array(value), global_step=global_step)
60 | else:
61 | if tag not in self._avg_scalars:
62 | self._avg_scalars[tag] = ScalarAccumulator(self._dump_period)
63 | avg_scalar = self._avg_scalars[tag]
64 | avg_scalar.add(value)
65 |
66 | if avg_scalar.is_full():
67 | super().add_scalar(tag, avg_scalar.value,
68 | global_step=global_step)
69 | avg_scalar.reset()
70 |
71 |
72 | class ScalarAccumulator(object):
73 | def __init__(self, period):
74 | self.sum = 0
75 | self.cnt = 0
76 | self.period = period
77 |
78 | def add(self, value):
79 | self.sum += value
80 | self.cnt += 1
81 |
82 | @property
83 | def value(self):
84 | if self.cnt > 0:
85 | return self.sum / self.cnt
86 | else:
87 | return 0
88 |
89 | def reset(self):
90 | self.cnt = 0
91 | self.sum = 0
92 |
93 | def is_full(self):
94 | return self.cnt >= self.period
95 |
96 | def __len__(self):
97 | return self.cnt
98 |
--------------------------------------------------------------------------------
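A short sketch tying the pieces above together: `TqdmToLogger` routes tqdm's progress output through the shared logger, and `SummaryWriterAvg` averages each scalar over `dump_period` additions before writing it to TensorBoard (the log directory below is illustrative):

```python
import logging
from tqdm import tqdm
from isegm.utils.log import logger, TqdmToLogger, SummaryWriterAvg

tqdm_out = TqdmToLogger(logger, level=logging.INFO, mininterval=5)
writer = SummaryWriterAvg(log_dir='./runs/demo', dump_period=20)

for step in tqdm(range(100), file=tqdm_out):
    loss = 1.0 / (step + 1)                      # placeholder value
    writer.add_scalar('train/loss', loss, step)  # flushed once every 20 adds
writer.close()
```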
/isegm/utils/misc.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 | from .log import logger
5 |
6 |
7 | def get_dims_with_exclusion(dim, exclude=None):
8 | dims = list(range(dim))
9 | if exclude is not None:
10 | dims.remove(exclude)
11 |
12 | return dims
13 |
14 |
15 | def save_checkpoint(net, checkpoints_path, epoch=None, prefix='', verbose=True, multi_gpu=False):
16 | if epoch is None:
17 | checkpoint_name = 'last_checkpoint.pth'
18 | else:
19 | checkpoint_name = f'{epoch:03d}.pth'
20 |
21 | if prefix:
22 | checkpoint_name = f'{prefix}_{checkpoint_name}'
23 |
24 | if not checkpoints_path.exists():
25 | checkpoints_path.mkdir(parents=True)
26 |
27 | checkpoint_path = checkpoints_path / checkpoint_name
28 | if verbose:
29 | logger.info(f'Save checkpoint to {str(checkpoint_path)}')
30 |
31 | net = net.module if multi_gpu else net
32 | torch.save({'state_dict': net.state_dict(),
33 | 'config': net._config}, str(checkpoint_path))
34 |
35 |
36 | def get_bbox_from_mask(mask):
37 | rows = np.any(mask, axis=1)
38 | cols = np.any(mask, axis=0)
39 | rmin, rmax = np.where(rows)[0][[0, -1]]
40 | cmin, cmax = np.where(cols)[0][[0, -1]]
41 |
42 | return rmin, rmax, cmin, cmax
43 |
44 |
45 | def expand_bbox(bbox, expand_ratio, min_crop_size=None):
46 | rmin, rmax, cmin, cmax = bbox
47 | rcenter = 0.5 * (rmin + rmax)
48 | ccenter = 0.5 * (cmin + cmax)
49 | height = expand_ratio * (rmax - rmin + 1)
50 | width = expand_ratio * (cmax - cmin + 1)
51 | if min_crop_size is not None:
52 | height = max(height, min_crop_size)
53 | width = max(width, min_crop_size)
54 |
55 | rmin = int(round(rcenter - 0.5 * height))
56 | rmax = int(round(rcenter + 0.5 * height))
57 | cmin = int(round(ccenter - 0.5 * width))
58 | cmax = int(round(ccenter + 0.5 * width))
59 |
60 | return rmin, rmax, cmin, cmax
61 |
62 |
63 | def clamp_bbox(bbox, rmin, rmax, cmin, cmax):
64 | return (max(rmin, bbox[0]), min(rmax, bbox[1]),
65 | max(cmin, bbox[2]), min(cmax, bbox[3]))
66 |
67 |
68 | def get_bbox_iou(b1, b2):
69 | h_iou = get_segments_iou(b1[:2], b2[:2])
70 | w_iou = get_segments_iou(b1[2:4], b2[2:4])
71 | return h_iou * w_iou
72 |
73 |
74 | def get_segments_iou(s1, s2):
75 | a, b = s1
76 | c, d = s2
77 | intersection = max(0, min(b, d) - max(a, c) + 1)
78 | union = max(1e-6, max(b, d) - min(a, c) + 1)
79 | return intersection / union
80 |
81 |
82 | def get_labels_with_sizes(x):
83 | obj_sizes = np.bincount(x.flatten())
84 | labels = np.nonzero(obj_sizes)[0].tolist()
85 | labels = [x for x in labels if x != 0]
86 | return labels, obj_sizes[labels].tolist()
87 |
--------------------------------------------------------------------------------
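A worked example for the bbox helpers above: `expand_bbox` grows a box by `expand_ratio` around its center, and `clamp_bbox` snaps it back inside the image bounds (the toy mask is illustrative):

```python
import numpy as np
from isegm.utils.misc import get_bbox_from_mask, expand_bbox, clamp_bbox

mask = np.zeros((100, 100), dtype=bool)
mask[40:60, 30:70] = True                       # a 20x40 object
bbox = get_bbox_from_mask(mask)                 # (40, 59, 30, 69)
bbox = expand_bbox(bbox, expand_ratio=1.4)      # (36, 64, 22, 78)
bbox = clamp_bbox(bbox, 0, mask.shape[0] - 1, 0, mask.shape[1] - 1)
print(bbox)
```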
/isegm/utils/serialization.py:
--------------------------------------------------------------------------------
1 | from functools import wraps
2 | from copy import deepcopy
3 | import inspect
4 | import torch.nn as nn
5 |
6 |
7 | def serialize(init):
8 | parameters = list(inspect.signature(init).parameters)
9 |
10 | @wraps(init)
11 | def new_init(self, *args, **kwargs):
12 | params = deepcopy(kwargs)
13 | for pname, value in zip(parameters[1:], args):
14 | params[pname] = value
15 |
16 | config = {
17 | 'class': get_classname(self.__class__),
18 | 'params': dict()
19 | }
20 | specified_params = set(params.keys())
21 |
22 | for pname, param in get_default_params(self.__class__).items():
23 | if pname not in params:
24 | params[pname] = param.default
25 |
26 | for name, value in list(params.items()):
27 | param_type = 'builtin'
28 | if inspect.isclass(value):
29 | param_type = 'class'
30 | value = get_classname(value)
31 |
32 | config['params'][name] = {
33 | 'type': param_type,
34 | 'value': value,
35 | 'specified': name in specified_params
36 | }
37 |
38 | setattr(self, '_config', config)
39 | init(self, *args, **kwargs)
40 |
41 | return new_init
42 |
43 |
44 | def load_model(config, **kwargs):
45 | model_class = get_class_from_str(config['class'])
46 | model_default_params = get_default_params(model_class)
47 |
48 | model_args = dict()
49 | for pname, param in config['params'].items():
50 | value = param['value']
51 | if param['type'] == 'class':
52 | value = get_class_from_str(value)
53 |
54 | if pname not in model_default_params and not param['specified']:
55 | continue
56 |
57 | assert pname in model_default_params
58 | if not param['specified'] and model_default_params[pname].default == value:
59 | continue
60 | model_args[pname] = value
61 |
62 | model_args.update(kwargs)
63 |
64 | return model_class(**model_args)
65 |
66 |
67 | def get_config_repr(config):
68 | config_str = f'Model: {config["class"]}\n'
69 | for pname, param in config['params'].items():
70 | value = param["value"]
71 | if param['type'] == 'class':
72 | value = value.split('.')[-1]
73 | param_str = f'{pname:<22} = {str(value):<12}'
74 | if not param['specified']:
75 | param_str += ' (default)'
76 | config_str += param_str + '\n'
77 | return config_str
78 |
79 |
80 | def get_default_params(some_class):
81 | params = dict()
82 | for mclass in some_class.mro():
83 | if mclass is nn.Module or mclass is object:
84 | continue
85 |
86 | mclass_params = inspect.signature(mclass.__init__).parameters
87 | for pname, param in mclass_params.items():
88 | if param.default != param.empty and pname not in params:
89 | params[pname] = param
90 |
91 | return params
92 |
93 |
94 | def get_classname(cls):
95 | module = cls.__module__
96 | name = cls.__qualname__
97 | if module is not None and module != "__builtin__":
98 | name = module + "." + name
99 | return name
100 |
101 |
102 | def get_class_from_str(class_str):
103 | components = class_str.split('.')
104 | mod = __import__('.'.join(components[:-1]))
105 | for comp in components[1:]:
106 | mod = getattr(mod, comp)
107 | return mod
108 |
--------------------------------------------------------------------------------
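The round trip works like this: `@serialize` records the constructor arguments into `self._config`, `save_checkpoint` (in `misc.py`) stores that dict next to the weights, and `load_model` replays it to rebuild the module. A self-contained sketch with a hypothetical `TinyHead` class (the class must be importable under the recorded module path for `get_class_from_str` to resolve it):

```python
import torch.nn as nn
from isegm.utils.serialization import serialize, load_model

class TinyHead(nn.Module):
    @serialize
    def __init__(self, in_channels=256, num_classes=1):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, num_classes, kernel_size=1)

net = TinyHead(in_channels=64)
checkpoint = {'state_dict': net.state_dict(), 'config': net._config}

restored = load_model(checkpoint['config'])   # rebuilds TinyHead(in_channels=64)
restored.load_state_dict(checkpoint['state_dict'])
```

Note that only parameters that were explicitly specified or differ from their defaults end up in `model_args`, which keeps the stored config minimal.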
/isegm/utils/vis.py:
--------------------------------------------------------------------------------
1 | from functools import lru_cache
2 | import cv2
3 | import numpy as np
4 |
5 |
6 | def visualize_instances(imask, bg_color=255,
7 | boundaries_color=None, boundaries_width=1, boundaries_alpha=0.8):
8 | num_objects = imask.max() + 1
9 | palette = get_palette(num_objects)
10 | if bg_color is not None:
11 | palette[0] = bg_color
12 |
13 | result = palette[imask].astype(np.uint8)
14 | if boundaries_color is not None:
15 | boundaries_mask = get_boundaries(imask, boundaries_width=boundaries_width)
16 | tresult = result.astype(np.float32)
17 | tresult[boundaries_mask] = boundaries_color
18 | tresult = tresult * boundaries_alpha + (1 - boundaries_alpha) * result
19 | result = tresult.astype(np.uint8)
20 |
21 | return result
22 |
23 |
24 | @lru_cache(maxsize=16)
25 | def get_palette(num_cls):
26 | palette = np.zeros(3 * num_cls, dtype=np.int32)
27 |
28 | for j in range(0, num_cls):
29 | lab = j
30 | i = 0
31 |
32 | while lab > 0:
33 | palette[j*3 + 0] |= (((lab >> 0) & 1) << (7-i))
34 | palette[j*3 + 1] |= (((lab >> 1) & 1) << (7-i))
35 | palette[j*3 + 2] |= (((lab >> 2) & 1) << (7-i))
36 | i = i + 1
37 | lab >>= 3
38 |
39 | return palette.reshape((-1, 3))
40 |
41 |
42 | def visualize_mask(mask, num_cls):
43 | palette = get_palette(num_cls)
44 | mask[mask == -1] = 0
45 |
46 | return palette[mask].astype(np.uint8)
47 |
48 |
49 | def visualize_proposals(proposals_info, point_color=(255, 0, 0), point_radius=1):
50 | proposal_map, colors, candidates = proposals_info
51 |
52 | proposal_map = draw_probmap(proposal_map)
53 | for x, y in candidates:
54 | proposal_map = cv2.circle(proposal_map, (y, x), point_radius, point_color, -1)
55 |
56 | return proposal_map
57 |
58 |
59 | def draw_probmap(x):
60 | return cv2.applyColorMap((x * 255).astype(np.uint8), cv2.COLORMAP_HOT)
61 |
62 |
63 | def draw_points(image, points, color, radius=3):
64 | image = image.copy()
65 | for p in points:
66 | if p[0] < 0:
67 | continue
68 | if len(p) == 3:
69 | pradius = {0: 8, 1: 6, 2: 4}[p[2]] if p[2] < 3 else 2
70 | else:
71 | pradius = radius
72 | image = cv2.circle(image, (int(p[1]), int(p[0])), pradius, color, -1)
73 |
74 | return image
75 |
76 |
77 | def draw_instance_map(x, palette=None):
78 | num_colors = x.max() + 1
79 | if palette is None:
80 | palette = get_palette(num_colors)
81 |
82 | return palette[x].astype(np.uint8)
83 |
84 |
85 | def blend_mask(image, mask, alpha=0.6):
86 | if mask.min() == -1:
87 | mask = mask.copy() + 1
88 |
89 | imap = draw_instance_map(mask)
90 | result = (image * (1 - alpha) + alpha * imap).astype(np.uint8)
91 | return result
92 |
93 |
94 | def get_boundaries(instances_masks, boundaries_width=1):
95 |     boundaries = np.zeros((instances_masks.shape[0], instances_masks.shape[1]), dtype=bool)
96 |
97 | for obj_id in np.unique(instances_masks.flatten()):
98 | if obj_id == 0:
99 | continue
100 |
101 | obj_mask = instances_masks == obj_id
102 | kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
103 |         inner_mask = cv2.erode(obj_mask.astype(np.uint8), kernel, iterations=boundaries_width).astype(bool)
104 |
105 | obj_boundary = np.logical_xor(obj_mask, np.logical_and(inner_mask, obj_mask))
106 | boundaries = np.logical_or(boundaries, obj_boundary)
107 | return boundaries
108 |
109 |
110 | def draw_with_blend_and_clicks(img, mask=None, alpha=0.6, clicks_list=None, pos_color=(0, 255, 0),
111 | neg_color=(255, 0, 0), radius=4):
112 | result = img.copy()
113 |
114 | if mask is not None:
115 | palette = get_palette(np.max(mask) + 1)
116 | rgb_mask = palette[mask.astype(np.uint8)]
117 |
118 | mask_region = (mask > 0).astype(np.uint8)
119 | result = result * (1 - mask_region[:, :, np.newaxis]) + \
120 | (1 - alpha) * mask_region[:, :, np.newaxis] * result + \
121 | alpha * rgb_mask
122 | result = result.astype(np.uint8)
123 |
124 | # result = (result * (1 - alpha) + alpha * rgb_mask).astype(np.uint8)
125 |
126 | if clicks_list is not None and len(clicks_list) > 0:
127 | pos_points = [click.coords for click in clicks_list if click.is_positive]
128 | neg_points = [click.coords for click in clicks_list if not click.is_positive]
129 |
130 | result = draw_points(result, pos_points, pos_color, radius=radius)
131 | result = draw_points(result, neg_points, neg_color, radius=radius)
132 |
133 | return result
134 |
135 |
136 | def add_tag(image, tag='undefined', tag_h=40):
137 |     image = image.astype(np.uint8)
138 |     H, W = image.shape[0], image.shape[1]
139 |     tag_blanc = np.ones((tag_h, W, 3)).astype(np.uint8) * 255
140 |     cv2.putText(tag_blanc, tag, (10, 30), cv2.FONT_HERSHEY_COMPLEX, 0.5, (0, 0, 0), 1)
141 |     image = cv2.vconcat([image, tag_blanc])
142 |     return image
143 |
144 |
145 |
--------------------------------------------------------------------------------
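The bit-twiddling in `get_palette` reproduces the standard PASCAL VOC colormap: the bits of each label index are spread across the R, G and B channels from the most significant position down. For example:

```python
import numpy as np
from isegm.utils.vis import get_palette, visualize_mask

palette = get_palette(4)
print(palette)  # [[0 0 0] [128 0 0] [0 128 0] [128 128 0]] -> VOC colors

mask = np.random.randint(0, 4, size=(64, 64))
rgb = visualize_mask(mask, num_cls=4)  # (64, 64, 3) uint8 color image
```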
/models/gp_sbd_resnet50.py:
--------------------------------------------------------------------------------
1 | from isegm.utils.exp_imports.default import *
2 | MODEL_NAME = 'resnet50'
3 | # from isegm.data.compose import ComposeDataset,ProportionalComposeDataset
4 | import torch.nn as nn
5 | from isegm.data.aligned_augmentation import AlignedAugmentator
6 | from isegm.engine.gp_trainer import ISTrainer
7 | import importlib
8 |
9 | def main(cfg):
10 | model, model_cfg = init_model(cfg)
11 | train(model, cfg, model_cfg)
12 |
13 |
14 | def init_model(cfg):
15 | model_cfg = edict()
16 | model_cfg.crop_size = (cfg.crop_size, cfg.crop_size)
17 | model_cfg.num_max_points = cfg.num_max_points
18 | GpModel = importlib.import_module('isegm.model.'+cfg.gp_model).GpModel
19 |     model = GpModel(backbone='resnet50', use_leaky_relu=True, use_disks=(not cfg.nodisk), binary_prev_mask=False,
20 |                     with_prev_mask=(not cfg.noprev_mask), weight_dir=cfg.IMAGENET_PRETRAINED_MODELS.RESNET50_v1s)
21 | model.to(cfg.device)
22 | model.apply(initializer.XavierGluon(rnd_type='gaussian', magnitude=2.0))
23 | model.model.load_pretrained_weights()
24 | return model, model_cfg
25 |
26 |
27 | def train(model, cfg, model_cfg):
28 | cfg.batch_size = 32 if cfg.batch_size < 1 else cfg.batch_size
29 | cfg.val_batch_size = cfg.batch_size
30 | crop_size = model_cfg.crop_size
31 |
32 | loss_cfg = edict()
33 | loss_cfg.instance_loss = NormalizedFocalLossSigmoid(alpha=0.5, gamma=2)
34 | loss_cfg.instance_loss_weight = 1.0
35 |
36 |     train_augmentator = AlignedAugmentator(ratio=[0.3, 1.3], target_size=crop_size, flip=True, distribution='Gaussian', gs_center=0.8)
37 |
38 | val_augmentator = Compose([
39 | UniformRandomResize(scale_range=(0.75, 1.25)),
40 | PadIfNeeded(min_height=crop_size[0], min_width=crop_size[1], border_mode=0),
41 | RandomCrop(*crop_size)
42 | ], p=1.0)
43 |
44 | points_sampler = MultiPointSampler(model_cfg.num_max_points, prob_gamma=0.70,
45 | merge_objects_prob=0.15,
46 | max_num_merged_objects=2,
47 | use_hierarchy=False,
48 | first_click_center=True)
49 |
50 | trainset = SBDDataset(
51 | cfg.SBD_PATH,
52 | split='train',
53 | augmentator=train_augmentator,
54 | min_object_area=80,
55 | keep_background_prob=0.01,
56 | points_sampler=points_sampler,
57 | samples_scores_gamma=1.25
58 | )
59 |
60 | valset = SBDDataset(
61 | cfg.SBD_PATH,
62 | split='val',
63 | augmentator=val_augmentator,
64 | min_object_area=80,
65 | points_sampler=points_sampler,
66 | epoch_len=500
67 | )
68 |
69 | optimizer_params = {
70 | 'lr': cfg.lr, 'betas': (0.9, 0.999), 'eps': 1e-8
71 | }
72 |
73 | lr_scheduler = partial(torch.optim.lr_scheduler.MultiStepLR,
74 | milestones=cfg.milestones[:-1], gamma=0.1)
75 | trainer = ISTrainer(model, cfg, model_cfg, loss_cfg,
76 | trainset, valset,
77 | optimizer='adam',
78 | optimizer_params=optimizer_params,
79 | lr_scheduler=lr_scheduler,
80 | checkpoint_interval=[(0, 50), (200, 10)],
81 | image_dump_interval=cfg.image_dump_interval,
82 | metrics=[AdaptiveIoU()],
83 | max_interactive_points=model_cfg.num_max_points,
84 | max_num_next_clicks=cfg.max_num_next_clicks)
85 | trainer.run(num_epochs=cfg.milestones[-1])
86 |
--------------------------------------------------------------------------------
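Note how the `--milestones` list is consumed (values below match `run.sh`): the last entry is the total number of epochs, while the earlier entries are the `MultiStepLR` decay points where the learning rate is multiplied by 0.1:

```python
# Plain-Python restatement of the schedule wiring above.
milestones = [190, 220, 230]
num_epochs = milestones[-1]     # trainer.run(num_epochs=230)
decay_epochs = milestones[:-1]  # lr *= 0.1 at epochs 190 and 220
```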
/net.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zmhhmz/GPCIS_CVPR2023/6460415a2e784f5623a0c859971f884a89eb0fd0/net.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | matplotlib==3.3.4
2 | easydict==1.9
3 | opencv_contrib_python==4.2.0.32
4 | torchvision==0.9.0
5 | mmcv_full==1.2.7
6 | numpy==1.17.0
7 | torch==1.8.0
8 | albumentations==0.5.1
9 | termcolor==1.1.0
10 | attrs==21.2.0
11 | timm==0.3.2
12 | scikit_image==0.17.2
13 | scipy==1.5.4
14 | Cython==0.29.23
15 | tqdm==4.61.0
16 | attr==0.3.1
17 | ipython==7.30.1
18 | mmcv==1.4.1
19 | Pillow==8.4.0
20 | PyYAML==6.0
21 | thop==0.0.31-2005241907
22 |
--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------
1 | # Training
2 | python3 train.py models/gp_sbd_resnet50.py \
3 | --gpus=0,1 \
4 | --workers=12 \
5 | --batch-size=32 \
6 | --milestones 190 220 230 \
7 | --max_num_next_clicks=3 \
8 | --num_max_points=24 \
9 | --crop_size=256 \
10 | --gp_model=is_gp_resnet50 \
11 | --exp-name=GP_Resnet50_SBD_230epo
12 |
13 | # Evaluation
14 | # python3 scripts/evaluate_model.py Baseline \
15 | # --model_dir=checkpoints/ \
16 | # --checkpoint=GPCIS_Resnet50.pth \
17 | # --datasets=GrabCut,Berkeley,SBD,DAVIS \
18 | # --gpus=0 \
19 | # --n-clicks=20 \
20 | # --target-iou=0.90 \
21 | # --thresh=0.50
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import importlib.util
4 |
5 | import torch
6 | from isegm.utils.exp import init_experiment
7 |
8 |
9 | def main():
10 | args = parse_args()
11 | if args.temp_model_path:
12 | model_script = load_module(args.temp_model_path)
13 | else:
14 | model_script = load_module(args.model_path)
15 |
16 | model_base_name = getattr(model_script, 'MODEL_NAME', None)
17 |
18 | args.distributed = 'WORLD_SIZE' in os.environ
19 | cfg = init_experiment(args, model_base_name)
20 |
21 | torch.backends.cudnn.benchmark = True
22 | torch.multiprocessing.set_sharing_strategy('file_system')
23 |
24 | model_script.main(cfg)
25 |
26 |
27 | def parse_args():
28 | parser = argparse.ArgumentParser()
29 |
30 | parser.add_argument('model_path', type=str,
31 | help='Path to the model script.')
32 |
33 | parser.add_argument('--exp-name', type=str, default='',
34 | help='Here you can specify the name of the experiment. '
35 | 'It will be added as a suffix to the experiment folder.')
36 |
37 | parser.add_argument('--workers', type=int, default=4,
38 | metavar='N', help='Dataloader threads.')
39 |
40 | parser.add_argument('--batch-size', type=int, default=-1,
41 |                         help='You can override the model batch size by specifying a positive number.')
42 |
43 | parser.add_argument('--ngpus', type=int, default=1,
44 | help='Number of GPUs. '
45 |                              'If you only specify the "--gpus" argument, the ngpus value will be calculated automatically. '
46 | 'You should use either this argument or "--gpus".')
47 |
48 | parser.add_argument('--gpus', type=str, default='', required=False,
49 | help='Ids of used GPUs. You should use either this argument or "--ngpus".')
50 |
51 | parser.add_argument('--resume-exp', type=str, default=None,
52 | help='The prefix of the name of the experiment to be continued. '
53 | 'If you use this field, you must specify the "--resume-prefix" argument.')
54 |
55 | parser.add_argument('--resume-prefix', type=str, default='latest',
56 | help='The prefix of the name of the checkpoint to be loaded.')
57 |
58 | parser.add_argument('--start-epoch', type=int, default=0,
59 | help='The number of the starting epoch from which training will continue. '
60 | '(it is important for correct logging and learning rate)')
61 |
62 | parser.add_argument('--weights', type=str, default=None,
63 | help='Model weights will be loaded from the specified path if you use this argument.')
64 |
65 | parser.add_argument('--temp-model-path', type=str, default='',
66 | help='Do not use this argument (for internal purposes).')
67 |
68 | parser.add_argument("--local_rank", type=int, default=0)
69 |
70 | parser.add_argument("--lr", type=float, default=5e-3)
71 |
72 | parser.add_argument("--max_num_next_clicks", type=int, default=3)
73 |
74 | parser.add_argument("--num_max_points", type=int, default=24)
75 |
76 | parser.add_argument('--gp_model', type=str, default='')
77 |
78 | parser.add_argument('--noprev_mask', action='store_true')
79 |
80 | parser.add_argument('--nodisk', action='store_true')
81 |
82 | parser.add_argument('--binary_prev_mask', action='store_true')
83 |
84 | parser.add_argument("--image_dump_interval", type=int, default=500)
85 |
86 | parser.add_argument("--crop_size", type=int, default=256)
87 |
88 | parser.add_argument('--milestones', type=int, nargs='+', default=[190,210,230])
89 | return parser.parse_args()
90 |
91 |
92 | def load_module(script_path):
93 | spec = importlib.util.spec_from_file_location("model_script", script_path)
94 | model_script = importlib.util.module_from_spec(spec)
95 | spec.loader.exec_module(model_script)
96 |
97 | return model_script
98 |
99 |
100 | if __name__ == '__main__':
101 | main()
102 |
--------------------------------------------------------------------------------
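`load_module` is what lets `python3 train.py models/gp_sbd_resnet50.py ...` work without any model registry: the script at the given path is executed as a module and only needs to expose `main(cfg)` and, optionally, `MODEL_NAME`. A sketch (assuming the repo root is the working directory and the dependencies are installed):

```python
from train import load_module

model_script = load_module('models/gp_sbd_resnet50.py')
print(model_script.MODEL_NAME)  # 'resnet50'
# model_script.main(cfg) is then invoked with the config from init_experiment
```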