├── .gitignore ├── src │ ├── configs │ │ ├── CondInst_P0_1x.yaml │ │ ├── CondInst_fine_tuning_30k.yaml │ │ └── point_selection_from_CondInst.yaml │ ├── register_point_annotations.py │ ├── train_net_point.py │ └── condinst │ ├── dynamic_mask_head.py │ ├── standard │ │ ├── dynamic_mask_head.py │ │ └── condinst.py │ ├── Entropy │ │ ├── dynamic_mask_head.py │ │ └── condinst.py │ └── condinst.py ├── scripts │ ├── prepare.sh │ ├── random.py │ ├── initialization.py │ └── entropy.py ├── README.md └── LICENSE /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | .DS_Store 3 | -------------------------------------------------------------------------------- /src/configs/CondInst_P0_1x.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../MS_R_50_1x.yaml" 2 | 3 | MODEL: 4 | BOXINST: 5 | POINT_LOSS_WEIGHT: 0.1 6 | DATASETS: 7 | TRAIN: ("coco_2017_train_points_n1",) 8 | INPUT: 9 | POINT_SUP: True -------------------------------------------------------------------------------- /src/configs/CondInst_fine_tuning_30k.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../MS_R_50_1x.yaml" 2 | 3 | MODEL: 4 | WEIGHTS: "" 5 | BOXINST: 6 | POINT_LOSS_WEIGHT: 0.1 7 | DATASETS: 8 | TRAIN: ("coco_2017_train_points_n1",) 9 | INPUT: 10 | POINT_SUP: True 11 | SOLVER: 12 | IMS_PER_BATCH: 16 13 | BASE_LR: 0.01 14 | STEPS: (10000, 20000) 15 | MAX_ITER: 30000 -------------------------------------------------------------------------------- /src/configs/point_selection_from_CondInst.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../MS_R_50_1x.yaml" 2 | 3 | MODEL: 4 | WEIGHTS: "" 5 | SOLVER: 6 | IMS_PER_BATCH: 8 7 | BASE_LR: 0.0 # stop updating the model 8 | MAX_ITER: 14800 # 14800*8 = 118400 ≈ len(train2017) (118287 images) 9 | INPUT: 10 | MIN_SIZE_TRAIN: (800,) 11 | RANDOM_FLIP: none 12 | HFLIP_TRAIN: False 13 | 14 | # Currently, the point selection process is implemented by 15 | # forwarding the model separately after training; this 16 | # process could be accelerated by selecting during the last epoch. -------------------------------------------------------------------------------- /scripts/prepare.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export ROOT_PATH=your_root_path # modify (no spaces around "=" in bash) 4 | cd $ROOT_PATH 5 | 6 | # copy source files 7 | cp APIS/src/train_net_point.py AdelaiDet/tools/ 8 | cp APIS/src/register_point_annotations.py detectron2/projects/PointSup/point_sup/ 9 | cp -r APIS/src/condinst/* AdelaiDet/adet/modeling/condinst/ # -r so the standard/ and Entropy/ subdirectories are copied too 10 | mkdir -p AdelaiDet/configs/CondInst/APIS 11 | cp APIS/src/configs/* AdelaiDet/configs/CondInst/APIS/ 12 | 13 | cd detectron2 && export DETECTRON2_DATASETS=$ROOT_PATH/AdelaiDet/datasets && cd ..
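# the datasets dir above assumes detectron2's standard COCO layout, e.g. (illustrative):
#   $DETECTRON2_DATASETS/coco/annotations/instances_train2017.json
#   $DETECTRON2_DATASETS/coco/train2017/  $DETECTRON2_DATASETS/coco/val2017/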
14 | 15 | # generate random points (see PointSup for details) 16 | python detectron2/projects/PointSup/tools/prepare_coco_point_annotations_without_masks.py 10 17 | 18 | mkdir -p AdelaiDet/models -------------------------------------------------------------------------------- /scripts/random.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | os.chdir('./AdelaiDet') # os.system('cd ...') would not change the cwd of this process 4 | 5 | # step=2 is P1 in the paper 6 | for step in range(2, 11): 7 | pre_step = step - 1 8 | 9 | ######################################################################################################### 10 | # train with the selected points 11 | ######################################################################################################### 12 | 13 | strs = (f'OMP_NUM_THREADS=1 python tools/train_net_point.py \ 14 | --config-file configs/CondInst/APIS/CondInst_fine_tuning_30k.yaml \ 15 | --num-gpus 8 \ 16 | DATASETS.TRAIN "\'coco_2017_train_points_n{step}_random\'," \ 17 | MODEL.WEIGHTS models/model_{pre_step}.pth \ 18 | MODEL.BOXINST.POINT_LOSS_WEIGHT 1.0 \ 19 | OUTPUT_DIR training_dir/random_logs_n{step}') 20 | os.system(strs) 21 | os.system(f'cp training_dir/random_logs_n{step}/model_final.pth models/model_{step}.pth') -------------------------------------------------------------------------------- /scripts/initialization.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import copy 4 | 5 | os.chdir('./AdelaiDet') # os.system('cd ...') would not change the cwd of this process 6 | 7 | # generate random points for p1~p9 8 | # p(i) should be a subset of p(j) if i < j -------------------------------------------------------------------------------- /src/condinst/dynamic_mask_head.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | from torch import nn 4 | 5 | from detectron2.utils.events import get_event_storage 6 | 7 | from adet.utils.comm import compute_locations, aligned_bilinear 8 | 9 | from detectron2.projects.point_rend.point_features import point_sample 10 | from detectron2.projects.point_rend.point_head import roi_mask_point_loss 11 | 12 | from detectron2.layers import cat 13 | 14 | SMOOTH = 1e-6 15 | def get_mask_iou(outputs, labels): 16 | outputs = (outputs>0.5).clone().detach().long() 17 | labels = labels.clone().detach().long() 18 | intersection = (outputs & labels).float().sum((2, 3)) 19 | union = (outputs | labels).float().sum((2, 3)) 20 | iou = (intersection + SMOOTH) / (union + SMOOTH) 21 | return iou.mean() 22 | 23 | # from detectron2.projects.point_sup.point_utils import get_point_coords_from_point_annotation 24 | def get_point_coords_from_point_annotation(instances, gt_inds, image_wh): 25 | # re-designed for condinst, image_wh (width, height) should be the padded size 26 | gt_point_coords = [] 27 | gt_point_labels = [] 28 | 29 | for per_im in instances: 30 | gt_point_coords.append(per_im.gt_point_coords.to(torch.float32) / image_wh) 31 | gt_point_labels.append(per_im.gt_point_labels.to(torch.float32).clone()) 32 | 33 | gt_point_coords = torch.cat(gt_point_coords) 34 | gt_point_labels = torch.cat(gt_point_labels) 35 | 36 | gt_point_coords = gt_point_coords[gt_inds] 37 | gt_point_labels = gt_point_labels[gt_inds] 38 | 39 | return gt_point_coords, gt_point_labels 40 | 41 | 42 | def compute_project_term(mask_scores, gt_bitmasks): 43 | mask_losses_y = dice_coefficient( 44 | mask_scores.max(dim=2, keepdim=True)[0], 45 | gt_bitmasks.max(dim=2, keepdim=True)[0] 46 | ) 47 | mask_losses_x = dice_coefficient( 48 | mask_scores.max(dim=3, keepdim=True)[0], 49 | gt_bitmasks.max(dim=3, keepdim=True)[0] 50 | ) 51 | return (mask_losses_x + mask_losses_y).mean() 52 | 53 | 54 | def compute_pairwise_term(mask_logits, pairwise_size, pairwise_dilation): 55 | assert mask_logits.dim() == 4 56 | 57 | log_fg_prob = F.logsigmoid(mask_logits) 58 | log_bg_prob = F.logsigmoid(-mask_logits) 59 | 60 | from adet.modeling.condinst.condinst import unfold_wo_center 61 | log_fg_prob_unfold = unfold_wo_center( 62 | log_fg_prob, kernel_size=pairwise_size, 63 | dilation=pairwise_dilation 64 | ) 65 | log_bg_prob_unfold = unfold_wo_center( 66 | log_bg_prob,
kernel_size=pairwise_size, 67 | dilation=pairwise_dilation 68 | ) 69 | 70 | # the probability of making the same prediction = p_i * p_j + (1 - p_i) * (1 - p_j) 71 | # we compute the probability in log space to avoid numerical instability 72 | log_same_fg_prob = log_fg_prob[:, :, None] + log_fg_prob_unfold 73 | log_same_bg_prob = log_bg_prob[:, :, None] + log_bg_prob_unfold 74 | 75 | max_ = torch.max(log_same_fg_prob, log_same_bg_prob) 76 | log_same_prob = torch.log( 77 | torch.exp(log_same_fg_prob - max_) + 78 | torch.exp(log_same_bg_prob - max_) 79 | ) + max_ 80 | 81 | # loss = -log(prob) 82 | return -log_same_prob[:, 0] 83 | 84 | 85 | def dice_coefficient(x, target): 86 | eps = 1e-5 87 | n_inst = x.size(0) 88 | x = x.reshape(n_inst, -1) 89 | target = target.reshape(n_inst, -1) 90 | intersection = (x * target).sum(dim=1) 91 | union = (x ** 2.0).sum(dim=1) + (target ** 2.0).sum(dim=1) + eps 92 | loss = 1. - (2 * intersection / union) 93 | return loss 94 | 95 | 96 | def parse_dynamic_params(params, channels, weight_nums, bias_nums): 97 | assert params.dim() == 2 98 | assert len(weight_nums) == len(bias_nums) 99 | assert params.size(1) == sum(weight_nums) + sum(bias_nums) 100 | 101 | num_insts = params.size(0) 102 | num_layers = len(weight_nums) 103 | 104 | params_splits = list(torch.split_with_sizes( 105 | params, weight_nums + bias_nums, dim=1 106 | )) 107 | 108 | weight_splits = params_splits[:num_layers] 109 | bias_splits = params_splits[num_layers:] 110 | 111 | for l in range(num_layers): 112 | if l < num_layers - 1: 113 | # out_channels x in_channels x 1 x 1 114 | weight_splits[l] = weight_splits[l].reshape(num_insts * channels, -1, 1, 1) 115 | bias_splits[l] = bias_splits[l].reshape(num_insts * channels) 116 | else: 117 | # out_channels x in_channels x 1 x 1 118 | weight_splits[l] = weight_splits[l].reshape(num_insts * 1, -1, 1, 1) 119 | bias_splits[l] = bias_splits[l].reshape(num_insts) 120 | 121 | return weight_splits, bias_splits 122 | 123 | 124 | def build_dynamic_mask_head(cfg): 125 | return DynamicMaskHead(cfg) 126 | 127 | 128 | class DynamicMaskHead(nn.Module): 129 | def __init__(self, cfg): 130 | super(DynamicMaskHead, self).__init__() 131 | self.num_layers = cfg.MODEL.CONDINST.MASK_HEAD.NUM_LAYERS 132 | self.channels = cfg.MODEL.CONDINST.MASK_HEAD.CHANNELS 133 | self.in_channels = cfg.MODEL.CONDINST.MASK_BRANCH.OUT_CHANNELS 134 | self.mask_out_stride = cfg.MODEL.CONDINST.MASK_OUT_STRIDE 135 | self.disable_rel_coords = cfg.MODEL.CONDINST.MASK_HEAD.DISABLE_REL_COORDS 136 | 137 | soi = cfg.MODEL.FCOS.SIZES_OF_INTEREST 138 | self.register_buffer("sizes_of_interest", torch.tensor(soi + [soi[-1] * 2])) 139 | 140 | # boxinst configs 141 | self.boxinst_enabled = cfg.MODEL.BOXINST.ENABLED 142 | self.bottom_pixels_removed = cfg.MODEL.BOXINST.BOTTOM_PIXELS_REMOVED 143 | self.pairwise_size = cfg.MODEL.BOXINST.PAIRWISE.SIZE 144 | self.pairwise_dilation = cfg.MODEL.BOXINST.PAIRWISE.DILATION 145 | self.pairwise_color_thresh = cfg.MODEL.BOXINST.PAIRWISE.COLOR_THRESH 146 | self._warmup_iters = cfg.MODEL.BOXINST.PAIRWISE.WARMUP_ITERS 147 | 148 | # pointsup configs 149 | self.point_sup_enabled = cfg.INPUT.POINT_SUP 150 | self.point_loss_weight = cfg.MODEL.BOXINST.POINT_LOSS_WEIGHT 151 | 152 | weight_nums, bias_nums = [], [] 153 | for l in range(self.num_layers): 154 | if l == 0: 155 | if not self.disable_rel_coords: 156 | weight_nums.append((self.in_channels + 2) * self.channels) 157 | else: 158 | weight_nums.append(self.in_channels * self.channels) 159 | bias_nums.append(self.channels) 160 | elif l == self.num_layers - 1: 161 | weight_nums.append(self.channels * 1) 162 | bias_nums.append(1) 163 | else: 164 | weight_nums.append(self.channels * self.channels) 165 | bias_nums.append(self.channels) 166 | 167 | self.weight_nums = weight_nums 168 | self.bias_nums = bias_nums 169 | self.num_gen_params = sum(weight_nums) + sum(bias_nums) 170 | 171 | self.register_buffer("_iter", torch.zeros([1])) 172 | 173 | def mask_heads_forward(self, features, weights, biases, num_insts): 174 | ''' 175 | :param features: 176 | :param weights: [w0, w1, ...] 177 | :param biases: [b0, b1, ...] 178 | :return: 179 | ''' 180 | assert features.dim() == 4 181 | n_layers = len(weights) 182 | x = features 183 | for i, (w, b) in enumerate(zip(weights, biases)): 184 | x = F.conv2d( 185 | x, w, bias=b, 186 | stride=1, padding=0, 187 | groups=num_insts 188 | ) 189 | if i < n_layers - 1: 190 | x = F.relu(x) 191 | return x 192 | 193 | def mask_heads_forward_with_coords( 194 | self, mask_feats, mask_feat_stride, instances 195 | ): 196 | locations = compute_locations( 197 | mask_feats.size(2), mask_feats.size(3), 198 | stride=mask_feat_stride, device=mask_feats.device 199 | ) 200 | n_inst = len(instances) 201 | 202 | im_inds = instances.im_inds 203 | mask_head_params = instances.mask_head_params 204 | 205 | N, _, H, W = mask_feats.size() 206 | 207 | if not self.disable_rel_coords: 208 | instance_locations = instances.locations 209 | relative_coords = instance_locations.reshape(-1, 1, 2) - locations.reshape(1, -1, 2) 210 | relative_coords = relative_coords.permute(0, 2, 1).float() 211 | soi = self.sizes_of_interest.float()[instances.fpn_levels] 212 | relative_coords = relative_coords / soi.reshape(-1, 1, 1) 213 | relative_coords = relative_coords.to(dtype=mask_feats.dtype) 214 | 215 | mask_head_inputs = torch.cat([ 216 | relative_coords, mask_feats[im_inds].reshape(n_inst, self.in_channels, H * W) 217 | ], dim=1) 218 | else: 219 | mask_head_inputs = mask_feats[im_inds].reshape(n_inst, self.in_channels, H * W) 220 | 221 | mask_head_inputs = mask_head_inputs.reshape(1, -1, H, W) 222 | 223 | weights, biases = parse_dynamic_params( 224 | mask_head_params, self.channels, 225 | self.weight_nums, self.bias_nums 226 | ) 227 | 228 | mask_logits = self.mask_heads_forward(mask_head_inputs, weights, biases, n_inst) 229 | 230 | mask_logits = mask_logits.reshape(-1, 1, H, W) 231 | 232 | assert mask_feat_stride >= self.mask_out_stride 233 | assert mask_feat_stride % self.mask_out_stride == 0 234 | mask_logits = aligned_bilinear(mask_logits, int(mask_feat_stride / self.mask_out_stride)) 235 | 236 | return mask_logits 237 | 238 | def __call__(self, mask_feats, mask_feat_stride, pred_instances, gt_instances=None): 239 | if self.training: 240 | self._iter += 1 241 | 242 | gt_inds = pred_instances.gt_inds 243 | if self.boxinst_enabled or (not self.boxinst_enabled and not self.point_sup_enabled): 244 | gt_bitmasks = torch.cat([per_im.gt_bitmasks for per_im in gt_instances]) 245 | gt_bitmasks = gt_bitmasks[gt_inds].unsqueeze(dim=1).to(dtype=mask_feats.dtype) 246 | 247 | losses = {} 248 | 249 | if len(pred_instances) == 0: 250 | dummy_loss = mask_feats.sum() * 0 + pred_instances.mask_head_params.sum() * 0 251 | if not self.boxinst_enabled and not self.point_sup_enabled: 252 | # fully-supervised 253 | losses["loss_mask"] = dummy_loss 254 | else: 255 | # BoxInst and/or PointSup settings 256 | if self.boxinst_enabled: 257 | losses["loss_prj"] = dummy_loss 258 | losses["loss_pairwise"] = dummy_loss 259 | if self.point_sup_enabled: 260 |
losses["loss_point"] = dummy_loss 261 | else: 262 | mask_logits = self.mask_heads_forward_with_coords( 263 | mask_feats, mask_feat_stride, pred_instances 264 | ) 265 | mask_scores = mask_logits.sigmoid() 266 | 267 | if not self.boxinst_enabled and not self.point_sup_enabled: 268 | # fully-supervised CondInst losses 269 | mask_losses = dice_coefficient(mask_scores, gt_bitmasks) 270 | loss_mask = mask_losses.mean() 271 | losses["loss_mask"] = loss_mask 272 | else: 273 | # BoxInst and/or PointSup losses 274 | if self.boxinst_enabled: 275 | # box-supervised BoxInst losses 276 | image_color_similarity = torch.cat([x.image_color_similarity for x in gt_instances]) 277 | image_color_similarity = image_color_similarity[gt_inds].to(dtype=mask_feats.dtype) 278 | 279 | loss_prj_term = compute_project_term(mask_scores, gt_bitmasks) 280 | 281 | pairwise_losses = compute_pairwise_term( 282 | mask_logits, self.pairwise_size, 283 | self.pairwise_dilation 284 | ) 285 | 286 | weights = (image_color_similarity >= self.pairwise_color_thresh).float() * gt_bitmasks.float() 287 | loss_pairwise = (pairwise_losses * weights).sum() / weights.sum().clamp(min=1.0) 288 | 289 | warmup_factor = min(self._iter.item() / float(self._warmup_iters), 1.0) 290 | loss_pairwise = loss_pairwise * warmup_factor 291 | 292 | losses.update({ 293 | "loss_prj": loss_prj_term, 294 | "loss_pairwise": loss_pairwise, 295 | }) 296 | if self.point_sup_enabled: 297 | # pointly-supervised CondInst losses 298 | image_wh = torch.Tensor([mask_logits.size(3) * self.mask_out_stride, mask_logits.size(2) * self.mask_out_stride]).to(mask_logits.device) 299 | point_coords, point_labels = get_point_coords_from_point_annotation(gt_instances, gt_inds, image_wh) 300 | 301 | point_logits = point_sample( 302 | mask_logits, 303 | point_coords, 304 | mode='bilinear', 305 | align_corners=False, 306 | ) 307 | 308 | loss_point = roi_mask_point_loss(point_logits, gt_instances, point_labels) 309 | 310 | losses.update({ 311 | "loss_point": loss_point * self.point_loss_weight, 312 | }) 313 | return losses 314 | else: 315 | if len(pred_instances) > 0: 316 | mask_logits = self.mask_heads_forward_with_coords( 317 | mask_feats, mask_feat_stride, pred_instances 318 | ) 319 | pred_instances.pred_global_masks = mask_logits.sigmoid() 320 | 321 | return pred_instances 322 | -------------------------------------------------------------------------------- /src/condinst/standard/dynamic_mask_head.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | from torch import nn 4 | 5 | from detectron2.utils.events import get_event_storage 6 | 7 | from adet.utils.comm import compute_locations, aligned_bilinear 8 | 9 | from detectron2.projects.point_rend.point_features import point_sample 10 | from detectron2.projects.point_rend.point_head import roi_mask_point_loss 11 | 12 | from detectron2.layers import cat 13 | 14 | SMOOTH = 1e-6 15 | def get_mask_iou(outputs, labels): 16 | outputs = (outputs>0.5).clone().detach().long() 17 | labels = labels.clone().detach().long() 18 | intersection = (outputs & labels).float().sum((2, 3)) 19 | union = (outputs | labels).float().sum((2, 3)) 20 | iou = (intersection + SMOOTH) / (union + SMOOTH) 21 | return iou.mean() 22 | 23 | # from detectron2.projects.point_sup.point_utils import get_point_coords_from_point_annotation 24 | def get_point_coords_from_point_annotation(instances, gt_inds, image_wh): 25 | # re-designed for condinst, image_wh (width, height) should 
be the padded size 26 | gt_point_coords = [] 27 | gt_point_labels = [] 28 | 29 | for per_im in instances: 30 | gt_point_coords.append(per_im.gt_point_coords.to(torch.float32) / image_wh) 31 | gt_point_labels.append(per_im.gt_point_labels.to(torch.float32).clone()) 32 | 33 | gt_point_coords = torch.cat(gt_point_coords) 34 | gt_point_labels = torch.cat(gt_point_labels) 35 | 36 | gt_point_coords = gt_point_coords[gt_inds] 37 | gt_point_labels = gt_point_labels[gt_inds] 38 | 39 | return gt_point_coords, gt_point_labels 40 | 41 | 42 | def compute_project_term(mask_scores, gt_bitmasks): 43 | mask_losses_y = dice_coefficient( 44 | mask_scores.max(dim=2, keepdim=True)[0], 45 | gt_bitmasks.max(dim=2, keepdim=True)[0] 46 | ) 47 | mask_losses_x = dice_coefficient( 48 | mask_scores.max(dim=3, keepdim=True)[0], 49 | gt_bitmasks.max(dim=3, keepdim=True)[0] 50 | ) 51 | return (mask_losses_x + mask_losses_y).mean() 52 | 53 | 54 | def compute_pairwise_term(mask_logits, pairwise_size, pairwise_dilation): 55 | assert mask_logits.dim() == 4 56 | 57 | log_fg_prob = F.logsigmoid(mask_logits) 58 | log_bg_prob = F.logsigmoid(-mask_logits) 59 | 60 | from adet.modeling.condinst.condinst import unfold_wo_center 61 | log_fg_prob_unfold = unfold_wo_center( 62 | log_fg_prob, kernel_size=pairwise_size, 63 | dilation=pairwise_dilation 64 | ) 65 | log_bg_prob_unfold = unfold_wo_center( 66 | log_bg_prob, kernel_size=pairwise_size, 67 | dilation=pairwise_dilation 68 | ) 69 | 70 | # the probability of making the same prediction = p_i * p_j + (1 - p_i) * (1 - p_j) 71 | # we compute the probability in log space to avoid numerical instability 72 | log_same_fg_prob = log_fg_prob[:, :, None] + log_fg_prob_unfold 73 | log_same_bg_prob = log_bg_prob[:, :, None] + log_bg_prob_unfold 74 | 75 | max_ = torch.max(log_same_fg_prob, log_same_bg_prob) 76 | log_same_prob = torch.log( 77 | torch.exp(log_same_fg_prob - max_) + 78 | torch.exp(log_same_bg_prob - max_) 79 | ) + max_ 80 | 81 | # loss = -log(prob) 82 | return -log_same_prob[:, 0] 83 | 84 | 85 | def dice_coefficient(x, target): 86 | eps = 1e-5 87 | n_inst = x.size(0) 88 | x = x.reshape(n_inst, -1) 89 | target = target.reshape(n_inst, -1) 90 | intersection = (x * target).sum(dim=1) 91 | union = (x ** 2.0).sum(dim=1) + (target ** 2.0).sum(dim=1) + eps 92 | loss = 1. - (2 * intersection / union) 93 | return loss 94 | 95 | 96 | def parse_dynamic_params(params, channels, weight_nums, bias_nums): 97 | assert params.dim() == 2 98 | assert len(weight_nums) == len(bias_nums) 99 | assert params.size(1) == sum(weight_nums) + sum(bias_nums) 100 | 101 | num_insts = params.size(0) 102 | num_layers = len(weight_nums) 103 | 104 | params_splits = list(torch.split_with_sizes( 105 | params, weight_nums + bias_nums, dim=1 106 | )) 107 | 108 | weight_splits = params_splits[:num_layers] 109 | bias_splits = params_splits[num_layers:] 110 | 111 | for l in range(num_layers): 112 | if l < num_layers - 1: 113 | # out_channels x in_channels x 1 x 1 114 | weight_splits[l] = weight_splits[l].reshape(num_insts * channels, -1, 1, 1) 115 | bias_splits[l] = bias_splits[l].reshape(num_insts * channels) 116 | else: 117 | # out_channels x in_channels x 1 x 1 118 | weight_splits[l] = weight_splits[l].reshape(num_insts * 1, -1, 1, 1) 119 | bias_splits[l] = bias_splits[l].reshape(num_insts) 120 | 121 | return weight_splits, bias_splits 122 | 123 | 124 | def build_dynamic_mask_head(cfg): 125 | return DynamicMaskHead(cfg) 126 | 127 | 128 | class DynamicMaskHead(nn.Module): 129 | def __init__(self, cfg): 130 | super(DynamicMaskHead, self).__init__() 131 | self.num_layers = cfg.MODEL.CONDINST.MASK_HEAD.NUM_LAYERS 132 | self.channels = cfg.MODEL.CONDINST.MASK_HEAD.CHANNELS 133 | self.in_channels = cfg.MODEL.CONDINST.MASK_BRANCH.OUT_CHANNELS 134 | self.mask_out_stride = cfg.MODEL.CONDINST.MASK_OUT_STRIDE 135 | self.disable_rel_coords = cfg.MODEL.CONDINST.MASK_HEAD.DISABLE_REL_COORDS 136 | 137 | soi = cfg.MODEL.FCOS.SIZES_OF_INTEREST 138 | self.register_buffer("sizes_of_interest", torch.tensor(soi + [soi[-1] * 2])) 139 | 140 | # boxinst configs 141 | self.boxinst_enabled = cfg.MODEL.BOXINST.ENABLED 142 | self.bottom_pixels_removed = cfg.MODEL.BOXINST.BOTTOM_PIXELS_REMOVED 143 | self.pairwise_size = cfg.MODEL.BOXINST.PAIRWISE.SIZE 144 | self.pairwise_dilation = cfg.MODEL.BOXINST.PAIRWISE.DILATION 145 | self.pairwise_color_thresh = cfg.MODEL.BOXINST.PAIRWISE.COLOR_THRESH 146 | self._warmup_iters = cfg.MODEL.BOXINST.PAIRWISE.WARMUP_ITERS 147 | 148 | # pointsup configs 149 | self.point_sup_enabled = cfg.INPUT.POINT_SUP 150 | self.point_loss_weight = cfg.MODEL.BOXINST.POINT_LOSS_WEIGHT 151 | 152 | weight_nums, bias_nums = [], [] 153 | for l in range(self.num_layers): 154 | if l == 0: 155 | if not self.disable_rel_coords: 156 | weight_nums.append((self.in_channels + 2) * self.channels) 157 | else: 158 | weight_nums.append(self.in_channels * self.channels) 159 | bias_nums.append(self.channels) 160 | elif l == self.num_layers - 1: 161 | weight_nums.append(self.channels * 1) 162 | bias_nums.append(1) 163 | else: 164 | weight_nums.append(self.channels * self.channels) 165 | bias_nums.append(self.channels) 166 | 167 | self.weight_nums = weight_nums 168 | self.bias_nums = bias_nums 169 | self.num_gen_params = sum(weight_nums) + sum(bias_nums) 170 | 171 | self.register_buffer("_iter", torch.zeros([1])) 172 | 173 | def mask_heads_forward(self, features, weights, biases, num_insts): 174 | ''' 175 | :param features: 176 | :param weights: [w0, w1, ...] 177 | :param biases: [b0, b1, ...]
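(each weights[k] is a stacked 1x1 conv kernel of shape (num_insts * out_k, in_k, 1, 1), applied below with a single F.conv2d(..., groups=num_insts) so that every instance gets its own filters)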
178 | :return: 179 | ''' 180 | assert features.dim() == 4 181 | n_layers = len(weights) 182 | x = features 183 | for i, (w, b) in enumerate(zip(weights, biases)): 184 | x = F.conv2d( 185 | x, w, bias=b, 186 | stride=1, padding=0, 187 | groups=num_insts 188 | ) 189 | if i < n_layers - 1: 190 | x = F.relu(x) 191 | return x 192 | 193 | def mask_heads_forward_with_coords( 194 | self, mask_feats, mask_feat_stride, instances 195 | ): 196 | locations = compute_locations( 197 | mask_feats.size(2), mask_feats.size(3), 198 | stride=mask_feat_stride, device=mask_feats.device 199 | ) 200 | n_inst = len(instances) 201 | 202 | im_inds = instances.im_inds 203 | mask_head_params = instances.mask_head_params 204 | 205 | N, _, H, W = mask_feats.size() 206 | 207 | if not self.disable_rel_coords: 208 | instance_locations = instances.locations 209 | relative_coords = instance_locations.reshape(-1, 1, 2) - locations.reshape(1, -1, 2) 210 | relative_coords = relative_coords.permute(0, 2, 1).float() 211 | soi = self.sizes_of_interest.float()[instances.fpn_levels] 212 | relative_coords = relative_coords / soi.reshape(-1, 1, 1) 213 | relative_coords = relative_coords.to(dtype=mask_feats.dtype) 214 | 215 | mask_head_inputs = torch.cat([ 216 | relative_coords, mask_feats[im_inds].reshape(n_inst, self.in_channels, H * W) 217 | ], dim=1) 218 | else: 219 | mask_head_inputs = mask_feats[im_inds].reshape(n_inst, self.in_channels, H * W) 220 | 221 | mask_head_inputs = mask_head_inputs.reshape(1, -1, H, W) 222 | 223 | weights, biases = parse_dynamic_params( 224 | mask_head_params, self.channels, 225 | self.weight_nums, self.bias_nums 226 | ) 227 | 228 | mask_logits = self.mask_heads_forward(mask_head_inputs, weights, biases, n_inst) 229 | 230 | mask_logits = mask_logits.reshape(-1, 1, H, W) 231 | 232 | assert mask_feat_stride >= self.mask_out_stride 233 | assert mask_feat_stride % self.mask_out_stride == 0 234 | mask_logits = aligned_bilinear(mask_logits, int(mask_feat_stride / self.mask_out_stride)) 235 | 236 | return mask_logits 237 | 238 | def __call__(self, mask_feats, mask_feat_stride, pred_instances, gt_instances=None): 239 | if self.training: 240 | self._iter += 1 241 | 242 | gt_inds = pred_instances.gt_inds 243 | if self.boxinst_enabled or (not self.boxinst_enabled and not self.point_sup_enabled): 244 | gt_bitmasks = torch.cat([per_im.gt_bitmasks for per_im in gt_instances]) 245 | gt_bitmasks = gt_bitmasks[gt_inds].unsqueeze(dim=1).to(dtype=mask_feats.dtype) 246 | 247 | losses = {} 248 | 249 | if len(pred_instances) == 0: 250 | dummy_loss = mask_feats.sum() * 0 + pred_instances.mask_head_params.sum() * 0 251 | if not self.boxinst_enabled and not self.point_sup_enabled: 252 | # fully-supervised 253 | losses["loss_mask"] = dummy_loss 254 | else: 255 | # BoxInst and/or PointSup settings 256 | if self.boxinst_enabled: 257 | losses["loss_prj"] = dummy_loss 258 | losses["loss_pairwise"] = dummy_loss 259 | if self.point_sup_enabled: 260 | losses["loss_point"] = dummy_loss 261 | else: 262 | mask_logits = self.mask_heads_forward_with_coords( 263 | mask_feats, mask_feat_stride, pred_instances 264 | ) 265 | mask_scores = mask_logits.sigmoid() 266 | 267 | if not self.boxinst_enabled and not self.point_sup_enabled: 268 | # fully-supervised CondInst losses 269 | mask_losses = dice_coefficient(mask_scores, gt_bitmasks) 270 | loss_mask = mask_losses.mean() 271 | losses["loss_mask"] = loss_mask 272 | else: 273 | # BoxInst and/or PointSup losses 274 | if self.boxinst_enabled: 275 | # box-supervised BoxInst losses 276 | 
image_color_similarity = torch.cat([x.image_color_similarity for x in gt_instances]) 277 | image_color_similarity = image_color_similarity[gt_inds].to(dtype=mask_feats.dtype) 278 | 279 | loss_prj_term = compute_project_term(mask_scores, gt_bitmasks) 280 | 281 | pairwise_losses = compute_pairwise_term( 282 | mask_logits, self.pairwise_size, 283 | self.pairwise_dilation 284 | ) 285 | 286 | weights = (image_color_similarity >= self.pairwise_color_thresh).float() * gt_bitmasks.float() 287 | loss_pairwise = (pairwise_losses * weights).sum() / weights.sum().clamp(min=1.0) 288 | 289 | warmup_factor = min(self._iter.item() / float(self._warmup_iters), 1.0) 290 | loss_pairwise = loss_pairwise * warmup_factor 291 | 292 | losses.update({ 293 | "loss_prj": loss_prj_term, 294 | "loss_pairwise": loss_pairwise, 295 | }) 296 | if self.point_sup_enabled: 297 | # pointly-supervised CondInst losses 298 | image_wh = torch.Tensor([mask_logits.size(3) * self.mask_out_stride, mask_logits.size(2) * self.mask_out_stride]).to(mask_logits.device) 299 | point_coords, point_labels = get_point_coords_from_point_annotation(gt_instances, gt_inds, image_wh) 300 | 301 | point_logits = point_sample( 302 | mask_logits, 303 | point_coords, 304 | mode='bilinear', 305 | align_corners=False, 306 | ) 307 | 308 | loss_point = roi_mask_point_loss(point_logits, gt_instances, point_labels) 309 | 310 | losses.update({ 311 | "loss_point": loss_point * self.point_loss_weight, 312 | }) 313 | return losses 314 | else: 315 | if len(pred_instances) > 0: 316 | mask_logits = self.mask_heads_forward_with_coords( 317 | mask_feats, mask_feat_stride, pred_instances 318 | ) 319 | pred_instances.pred_global_masks = mask_logits.sigmoid() 320 | 321 | return pred_instances 322 | -------------------------------------------------------------------------------- /src/condinst/Entropy/dynamic_mask_head.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | from torch import nn 4 | 5 | from detectron2.utils.events import get_event_storage 6 | 7 | from adet.utils.comm import compute_locations, aligned_bilinear 8 | 9 | from detectron2.projects.point_rend.point_features import point_sample 10 | from detectron2.projects.point_rend.point_head import roi_mask_point_loss 11 | 12 | from detectron2.layers import cat 13 | 14 | 15 | SMOOTH = 1e-6 16 | def get_mask_iou(outputs, labels): 17 | outputs = (outputs>0.5).clone().detach().long() 18 | labels = labels.clone().detach().long() 19 | intersection = (outputs & labels).float().sum((2, 3)) 20 | union = (outputs | labels).float().sum((2, 3)) 21 | iou = (intersection + SMOOTH) / (union + SMOOTH) 22 | return iou.mean() 23 | 24 | 25 | # from detectron2.projects.point_sup.point_utils import get_point_coords_from_point_annotation 26 | def get_point_coords_from_point_annotation(instances, gt_inds, image_wh): 27 | # re-designed for condinst, image_wh (width, height) should be the padded size 28 | gt_point_coords = [] 29 | gt_point_labels = [] 30 | 31 | for per_im in instances: 32 | gt_point_coords.append(per_im.gt_point_coords.to(torch.float32) / image_wh) 33 | gt_point_labels.append(per_im.gt_point_labels.to(torch.float32).clone()) 34 | 35 | gt_point_coords = torch.cat(gt_point_coords) 36 | gt_point_labels = torch.cat(gt_point_labels) 37 | 38 | gt_point_coords = gt_point_coords[gt_inds] 39 | gt_point_labels = gt_point_labels[gt_inds] 40 | 41 | return gt_point_coords, gt_point_labels 42 | 43 | 44 | def compute_project_term(mask_scores, 
gt_bitmasks): 45 | mask_losses_y = dice_coefficient( 46 | mask_scores.max(dim=2, keepdim=True)[0], 47 | gt_bitmasks.max(dim=2, keepdim=True)[0] 48 | ) 49 | mask_losses_x = dice_coefficient( 50 | mask_scores.max(dim=3, keepdim=True)[0], 51 | gt_bitmasks.max(dim=3, keepdim=True)[0] 52 | ) 53 | return (mask_losses_x + mask_losses_y).mean() 54 | 55 | 56 | def compute_pairwise_term(mask_logits, pairwise_size, pairwise_dilation): 57 | assert mask_logits.dim() == 4 58 | 59 | log_fg_prob = F.logsigmoid(mask_logits) 60 | log_bg_prob = F.logsigmoid(-mask_logits) 61 | 62 | from adet.modeling.condinst.condinst import unfold_wo_center 63 | log_fg_prob_unfold = unfold_wo_center( 64 | log_fg_prob, kernel_size=pairwise_size, 65 | dilation=pairwise_dilation 66 | ) 67 | log_bg_prob_unfold = unfold_wo_center( 68 | log_bg_prob, kernel_size=pairwise_size, 69 | dilation=pairwise_dilation 70 | ) 71 | 72 | # the probability of making the same prediction = p_i * p_j + (1 - p_i) * (1 - p_j) 73 | # we compute the probability in log space to avoid numerical instability 74 | log_same_fg_prob = log_fg_prob[:, :, None] + log_fg_prob_unfold 75 | log_same_bg_prob = log_bg_prob[:, :, None] + log_bg_prob_unfold 76 | 77 | max_ = torch.max(log_same_fg_prob, log_same_bg_prob) 78 | log_same_prob = torch.log( 79 | torch.exp(log_same_fg_prob - max_) + 80 | torch.exp(log_same_bg_prob - max_) 81 | ) + max_ 82 | 83 | # loss = -log(prob) 84 | return -log_same_prob[:, 0] 85 | 86 | 87 | def dice_coefficient(x, target): 88 | eps = 1e-5 89 | n_inst = x.size(0) 90 | x = x.reshape(n_inst, -1) 91 | target = target.reshape(n_inst, -1) 92 | intersection = (x * target).sum(dim=1) 93 | union = (x ** 2.0).sum(dim=1) + (target ** 2.0).sum(dim=1) + eps 94 | loss = 1. - (2 * intersection / union) 95 | return loss 96 | 97 | 98 | def parse_dynamic_params(params, channels, weight_nums, bias_nums): 99 | assert params.dim() == 2 100 | assert len(weight_nums) == len(bias_nums) 101 | assert params.size(1) == sum(weight_nums) + sum(bias_nums) 102 | 103 | num_insts = params.size(0) 104 | num_layers = len(weight_nums) 105 | 106 | params_splits = list(torch.split_with_sizes( 107 | params, weight_nums + bias_nums, dim=1 108 | )) 109 | 110 | weight_splits = params_splits[:num_layers] 111 | bias_splits = params_splits[num_layers:] 112 | 113 | for l in range(num_layers): 114 | if l < num_layers - 1: 115 | # out_channels x in_channels x 1 x 1 116 | weight_splits[l] = weight_splits[l].reshape(num_insts * channels, -1, 1, 1) 117 | bias_splits[l] = bias_splits[l].reshape(num_insts * channels) 118 | else: 119 | # out_channels x in_channels x 1 x 1 120 | weight_splits[l] = weight_splits[l].reshape(num_insts * 1, -1, 1, 1) 121 | bias_splits[l] = bias_splits[l].reshape(num_insts) 122 | 123 | return weight_splits, bias_splits 124 | 125 | 126 | def build_dynamic_mask_head(cfg): 127 | return DynamicMaskHead(cfg) 128 | 129 | 130 | class DynamicMaskHead(nn.Module): 131 | def __init__(self, cfg): 132 | super(DynamicMaskHead, self).__init__() 133 | self.num_layers = cfg.MODEL.CONDINST.MASK_HEAD.NUM_LAYERS 134 | self.channels = cfg.MODEL.CONDINST.MASK_HEAD.CHANNELS 135 | self.in_channels = cfg.MODEL.CONDINST.MASK_BRANCH.OUT_CHANNELS 136 | self.mask_out_stride = cfg.MODEL.CONDINST.MASK_OUT_STRIDE 137 | self.disable_rel_coords = cfg.MODEL.CONDINST.MASK_HEAD.DISABLE_REL_COORDS 138 | 139 | soi = cfg.MODEL.FCOS.SIZES_OF_INTEREST 140 | self.register_buffer("sizes_of_interest", torch.tensor(soi + [soi[-1] * 2])) 141 | 142 | # boxinst configs 143 | self.boxinst_enabled = cfg.MODEL.BOXINST.ENABLED 144 | self.bottom_pixels_removed = cfg.MODEL.BOXINST.BOTTOM_PIXELS_REMOVED 145 | self.pairwise_size = cfg.MODEL.BOXINST.PAIRWISE.SIZE 146 | self.pairwise_dilation = cfg.MODEL.BOXINST.PAIRWISE.DILATION 147 | self.pairwise_color_thresh = cfg.MODEL.BOXINST.PAIRWISE.COLOR_THRESH 148 | self._warmup_iters = cfg.MODEL.BOXINST.PAIRWISE.WARMUP_ITERS 149 | 150 | # pointsup configs 151 | self.point_sup_enabled = cfg.INPUT.POINT_SUP 152 | self.point_loss_weight = cfg.MODEL.BOXINST.POINT_LOSS_WEIGHT 153 | 154 | weight_nums, bias_nums = [], [] 155 | for l in range(self.num_layers): 156 | if l == 0: 157 | if not self.disable_rel_coords: 158 | weight_nums.append((self.in_channels + 2) * self.channels) 159 | else: 160 | weight_nums.append(self.in_channels * self.channels) 161 | bias_nums.append(self.channels) 162 | elif l == self.num_layers - 1: 163 | weight_nums.append(self.channels * 1) 164 | bias_nums.append(1) 165 | else: 166 | weight_nums.append(self.channels * self.channels) 167 | bias_nums.append(self.channels) 168 | 169 | self.weight_nums = weight_nums 170 | self.bias_nums = bias_nums 171 | self.num_gen_params = sum(weight_nums) + sum(bias_nums) 172 | 173 | self.register_buffer("_iter", torch.zeros([1])) 174 | 175 | def mask_heads_forward(self, features, weights, biases, num_insts): 176 | ''' 177 | :param features: 178 | :param weights: [w0, w1, ...] 179 | :param biases: [b0, b1, ...] 180 | :return: 181 | ''' 182 | assert features.dim() == 4 183 | n_layers = len(weights) 184 | x = features 185 | for i, (w, b) in enumerate(zip(weights, biases)): 186 | x = F.conv2d( 187 | x, w, bias=b, 188 | stride=1, padding=0, 189 | groups=num_insts 190 | ) 191 | if i < n_layers - 1: 192 | x = F.relu(x) 193 | return x 194 | 195 | def mask_heads_forward_with_coords( 196 | self, mask_feats, mask_feat_stride, instances 197 | ): 198 | locations = compute_locations( 199 | mask_feats.size(2), mask_feats.size(3), 200 | stride=mask_feat_stride, device=mask_feats.device 201 | ) 202 | n_inst = len(instances) 203 | 204 | im_inds = instances.im_inds 205 | mask_head_params = instances.mask_head_params 206 | 207 | N, _, H, W = mask_feats.size() 208 | 209 | if not self.disable_rel_coords: 210 | instance_locations = instances.locations 211 | relative_coords = instance_locations.reshape(-1, 1, 2) - locations.reshape(1, -1, 2) 212 | relative_coords = relative_coords.permute(0, 2, 1).float() 213 | soi = self.sizes_of_interest.float()[instances.fpn_levels] 214 | relative_coords = relative_coords / soi.reshape(-1, 1, 1) 215 | relative_coords = relative_coords.to(dtype=mask_feats.dtype) 216 | 217 | mask_head_inputs = torch.cat([ 218 | relative_coords, mask_feats[im_inds].reshape(n_inst, self.in_channels, H * W) 219 | ], dim=1) 220 | else: 221 | mask_head_inputs = mask_feats[im_inds].reshape(n_inst, self.in_channels, H * W) 222 | 223 | mask_head_inputs = mask_head_inputs.reshape(1, -1, H, W) 224 | 225 | weights, biases = parse_dynamic_params( 226 | mask_head_params, self.channels, 227 | self.weight_nums, self.bias_nums 228 | ) 229 | 230 | mask_logits = self.mask_heads_forward(mask_head_inputs, weights, biases, n_inst) 231 | 232 | mask_logits = mask_logits.reshape(-1, 1, H, W) 233 | 234 | assert mask_feat_stride >= self.mask_out_stride 235 | assert mask_feat_stride % self.mask_out_stride == 0 236 | mask_logits = aligned_bilinear(mask_logits, int(mask_feat_stride / self.mask_out_stride)) 237 | 238 | return mask_logits 239 | 240 | def __call__(self, mask_feats, mask_feat_stride, pred_instances, gt_instances=None): 241
| if self.training: 242 | self._iter += 1 243 | 244 | gt_inds = pred_instances.gt_inds 245 | if self.boxinst_enabled: 246 | gt_bitmasks = torch.cat([per_im.gt_bitmasks for per_im in gt_instances]) 247 | gt_bitmasks = gt_bitmasks[gt_inds].unsqueeze(dim=1).to(dtype=mask_feats.dtype) 248 | 249 | losses = {} 250 | 251 | dummy_loss = mask_feats.sum() * 0 + pred_instances.mask_head_params.sum() * 0 252 | if not self.boxinst_enabled and not self.point_sup_enabled: 253 | # fully-supervised 254 | losses["loss_mask"] = dummy_loss 255 | else: 256 | # BoxInst and/or PointSup settings 257 | if self.boxinst_enabled: 258 | losses["loss_prj"] = dummy_loss 259 | losses["loss_pairwise"] = dummy_loss 260 | if self.point_sup_enabled: 261 | losses["loss_point"] = dummy_loss 262 | 263 | preds = () 264 | if len(pred_instances) != 0: 265 | mask_logits = self.mask_heads_forward_with_coords( 266 | mask_feats, mask_feat_stride, pred_instances 267 | ) 268 | mask_scores = mask_logits.sigmoid() 269 | 270 | resized_im_h, resized_im_w = gt_instances[0].image_size 271 | pred_global_masks = aligned_bilinear( 272 | mask_scores, int(self.mask_out_stride) 273 | ) 274 | pred_global_masks = pred_global_masks[:, :, :resized_im_h, :resized_im_w] 275 | 276 | preds = (gt_inds, pred_global_masks) 277 | 278 | return losses, [preds] 279 | 280 | # if not self.boxinst_enabled and not self.point_sup_enabled: 281 | # # fully-supervised CondInst losses 282 | # mask_losses = dice_coefficient(mask_scores, gt_bitmasks) 283 | # loss_mask = mask_losses.mean() 284 | # losses["loss_mask"] = loss_mask 285 | # else: 286 | # # BoxInst and/or PointSup losses 287 | # if self.boxinst_enabled: 288 | # # box-supervised BoxInst losses 289 | # image_color_similarity = torch.cat([x.image_color_similarity for x in gt_instances]) 290 | # image_color_similarity = image_color_similarity[gt_inds].to(dtype=mask_feats.dtype) 291 | 292 | # loss_prj_term = compute_project_term(mask_scores, gt_bitmasks) 293 | 294 | # pairwise_losses = compute_pairwise_term( 295 | # mask_logits, self.pairwise_size, 296 | # self.pairwise_dilation 297 | # ) 298 | 299 | # weights = (image_color_similarity >= self.pairwise_color_thresh).float() * gt_bitmasks.float() 300 | # loss_pairwise = (pairwise_losses * weights).sum() / weights.sum().clamp(min=1.0) 301 | 302 | # warmup_factor = min(self._iter.item() / float(self._warmup_iters), 1.0) 303 | # loss_pairwise = loss_pairwise * warmup_factor 304 | 305 | # losses.update({ 306 | # "loss_prj": loss_prj_term, 307 | # "loss_pairwise": loss_pairwise, 308 | # }) 309 | # if self.point_sup_enabled: 310 | # # pointly-supervised CondInst losses 311 | # image_wh = torch.Tensor([mask_logits.size(3) * self.mask_out_stride, mask_logits.size(2) * self.mask_out_stride]).to(mask_logits.device) 312 | # point_coords, point_labels = get_point_coords_from_point_annotation(gt_instances, gt_inds, image_wh) 313 | 314 | # point_logits = point_sample( 315 | # mask_logits, 316 | # point_coords, 317 | # mode='bilinear', 318 | # align_corners=False, 319 | # ) 320 | 321 | # loss_point = roi_mask_point_loss(point_logits, gt_instances, point_labels) 322 | 323 | # losses.update({ 324 | # "loss_point": loss_point * self.point_loss_weight, 325 | # }) 326 | # return losses 327 | else: 328 | if len(pred_instances) > 0: 329 | mask_logits = self.mask_heads_forward_with_coords( 330 | mask_feats, mask_feat_stride, pred_instances 331 | ) 332 | pred_instances.pred_global_masks = mask_logits.sigmoid() 333 | 334 | return pred_instances 335 | 
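The Entropy variant of the mask head above returns (gt_inds, pred_global_masks) during training instead of losses, so that a separate forward pass (cf. point_selection_from_CondInst.yaml) can rank candidate annotation points by prediction uncertainty. A minimal sketch of that ranking step follows; binary_entropy and select_next_point are hypothetical helpers for illustration, not the repository's scripts/entropy.py (whose body is not shown in this dump):

import torch

def binary_entropy(p, eps=1e-6):
    # per-pixel Bernoulli entropy of the predicted mask probability
    p = p.clamp(eps, 1.0 - eps)
    return -(p * p.log() + (1.0 - p) * (1.0 - p).log())

def select_next_point(pred_global_mask):
    # pred_global_mask: (H, W) sigmoid probabilities for one instance,
    # e.g. one slice of the pred_global_masks returned by the head above
    ent = binary_entropy(pred_global_mask)
    flat_idx = int(ent.flatten().argmax())
    y, x = divmod(flat_idx, ent.size(1))
    return x, y  # the most uncertain pixel, in (x, y) image coordinates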
-------------------------------------------------------------------------------- /src/condinst/condinst.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import logging 3 | from skimage import color 4 | 5 | import torch 6 | from torch import nn 7 | import torch.nn.functional as F 8 | 9 | from detectron2.structures import ImageList 10 | from detectron2.modeling.proposal_generator import build_proposal_generator 11 | from detectron2.modeling.backbone import build_backbone 12 | from detectron2.modeling.meta_arch.build import META_ARCH_REGISTRY 13 | from detectron2.structures.instances import Instances 14 | from detectron2.structures.masks import PolygonMasks, polygons_to_bitmask 15 | 16 | from .dynamic_mask_head import build_dynamic_mask_head 17 | from .mask_branch import build_mask_branch 18 | 19 | from adet.utils.comm import aligned_bilinear 20 | 21 | __all__ = ["CondInst"] 22 | 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | 27 | def unfold_wo_center(x, kernel_size, dilation): 28 | assert x.dim() == 4 29 | assert kernel_size % 2 == 1 30 | 31 | # using SAME padding 32 | padding = (kernel_size + (dilation - 1) * (kernel_size - 1)) // 2 33 | unfolded_x = F.unfold( 34 | x, kernel_size=kernel_size, 35 | padding=padding, 36 | dilation=dilation 37 | ) 38 | 39 | unfolded_x = unfolded_x.reshape( 40 | x.size(0), x.size(1), -1, x.size(2), x.size(3) 41 | ) 42 | 43 | # remove the center pixels 44 | size = kernel_size ** 2 45 | unfolded_x = torch.cat(( 46 | unfolded_x[:, :, :size // 2], 47 | unfolded_x[:, :, size // 2 + 1:] 48 | ), dim=2) 49 | 50 | return unfolded_x 51 | 52 | 53 | def get_images_color_similarity(images, image_masks, kernel_size, dilation): 54 | assert images.dim() == 4 55 | assert images.size(0) == 1 56 | 57 | unfolded_images = unfold_wo_center( 58 | images, kernel_size=kernel_size, dilation=dilation 59 | ) 60 | 61 | diff = images[:, :, None] - unfolded_images 62 | similarity = torch.exp(-torch.norm(diff, dim=1) * 0.5) 63 | 64 | unfolded_weights = unfold_wo_center( 65 | image_masks[None, None], kernel_size=kernel_size, 66 | dilation=dilation 67 | ) 68 | unfolded_weights = torch.max(unfolded_weights, dim=1)[0] 69 | 70 | return similarity * unfolded_weights 71 | 72 | 73 | @META_ARCH_REGISTRY.register() 74 | class CondInst(nn.Module): 75 | """ 76 | Main class for CondInst architectures (see https://arxiv.org/abs/2003.05664). 
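In this copy, point supervision (cfg.INPUT.POINT_SUP) is additionally threaded through the dynamic mask head and postprocess().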
77 | """ 78 | 79 | def __init__(self, cfg): 80 | super().__init__() 81 | self.device = torch.device(cfg.MODEL.DEVICE) 82 | 83 | self.backbone = build_backbone(cfg) 84 | self.proposal_generator = build_proposal_generator(cfg, self.backbone.output_shape()) 85 | self.mask_head = build_dynamic_mask_head(cfg) 86 | self.mask_branch = build_mask_branch(cfg, self.backbone.output_shape()) 87 | 88 | self.mask_out_stride = cfg.MODEL.CONDINST.MASK_OUT_STRIDE 89 | 90 | self.max_proposals = cfg.MODEL.CONDINST.MAX_PROPOSALS 91 | self.topk_proposals_per_im = cfg.MODEL.CONDINST.TOPK_PROPOSALS_PER_IM 92 | 93 | # boxinst configs 94 | self.boxinst_enabled = cfg.MODEL.BOXINST.ENABLED 95 | self.bottom_pixels_removed = cfg.MODEL.BOXINST.BOTTOM_PIXELS_REMOVED 96 | self.pairwise_size = cfg.MODEL.BOXINST.PAIRWISE.SIZE 97 | self.pairwise_dilation = cfg.MODEL.BOXINST.PAIRWISE.DILATION 98 | self.pairwise_color_thresh = cfg.MODEL.BOXINST.PAIRWISE.COLOR_THRESH 99 | 100 | # pointsup configs 101 | self.point_sup_enabled = cfg.INPUT.POINT_SUP 102 | 103 | # build top module 104 | in_channels = self.proposal_generator.in_channels_to_top_module 105 | 106 | self.controller = nn.Conv2d( 107 | in_channels, self.mask_head.num_gen_params, 108 | kernel_size=3, stride=1, padding=1 109 | ) 110 | torch.nn.init.normal_(self.controller.weight, std=0.01) 111 | torch.nn.init.constant_(self.controller.bias, 0) 112 | 113 | pixel_mean = torch.Tensor(cfg.MODEL.PIXEL_MEAN).to(self.device).view(3, 1, 1) 114 | pixel_std = torch.Tensor(cfg.MODEL.PIXEL_STD).to(self.device).view(3, 1, 1) 115 | self.normalizer = lambda x: (x - pixel_mean) / pixel_std 116 | self.to(self.device) 117 | 118 | def forward(self, batched_inputs): 119 | original_images = [x["image"].to(self.device) for x in batched_inputs] 120 | 121 | # normalize images 122 | images_norm = [self.normalizer(x) for x in original_images] 123 | images_norm = ImageList.from_tensors(images_norm, self.backbone.size_divisibility) 124 | 125 | features = self.backbone(images_norm.tensor) 126 | 127 | if "instances" in batched_inputs[0]: 128 | gt_instances = [x["instances"].to(self.device) for x in batched_inputs] 129 | 130 | if self.boxinst_enabled: 131 | original_image_masks = [torch.ones_like(x[0], dtype=torch.float32) for x in original_images] 132 | 133 | # mask out the bottom area where the COCO dataset probably has wrong annotations 134 | for i in range(len(original_image_masks)): 135 | im_h = batched_inputs[i]["height"] 136 | pixels_removed = int( 137 | self.bottom_pixels_removed * 138 | float(original_images[i].size(1)) / float(im_h) 139 | ) 140 | if pixels_removed > 0: 141 | original_image_masks[i][-pixels_removed:, :] = 0 142 | 143 | original_images = ImageList.from_tensors(original_images, self.backbone.size_divisibility) 144 | original_image_masks = ImageList.from_tensors( 145 | original_image_masks, self.backbone.size_divisibility, pad_value=0.0 146 | ) 147 | self.add_bitmasks_from_boxes( 148 | gt_instances, original_images.tensor, original_image_masks.tensor, 149 | original_images.tensor.size(-2), original_images.tensor.size(-1) 150 | ) 151 | else: 152 | gt_instances = None 153 | 154 | mask_feats, sem_losses = self.mask_branch(features, gt_instances) 155 | 156 | proposals, proposal_losses = self.proposal_generator( 157 | images_norm, features, gt_instances, self.controller 158 | ) 159 | 160 | if self.training: 161 | mask_losses = self._forward_mask_heads_train(proposals, mask_feats, gt_instances) 162 | 163 | losses = {} 164 | losses.update(sem_losses) 165 | losses.update(proposal_losses) 
166 | losses.update(mask_losses) 167 | return losses 168 | else: 169 | pred_instances_w_masks = self._forward_mask_heads_test(proposals, mask_feats) 170 | 171 | padded_im_h, padded_im_w = images_norm.tensor.size()[-2:] 172 | processed_results = [] 173 | for im_id, (input_per_image, image_size) in enumerate(zip(batched_inputs, images_norm.image_sizes)): 174 | height = input_per_image.get("height", image_size[0]) 175 | width = input_per_image.get("width", image_size[1]) 176 | 177 | instances_per_im = pred_instances_w_masks[pred_instances_w_masks.im_inds == im_id] 178 | instances_per_im = self.postprocess( 179 | instances_per_im, height, width, 180 | padded_im_h, padded_im_w 181 | ) 182 | 183 | processed_results.append({ 184 | "instances": instances_per_im 185 | }) 186 | 187 | return processed_results 188 | 189 | def _forward_mask_heads_train(self, proposals, mask_feats, gt_instances): 190 | # prepare the inputs for mask heads 191 | pred_instances = proposals["instances"] 192 | 193 | assert (self.max_proposals == -1) or (self.topk_proposals_per_im == -1), \ 194 | "MAX_PROPOSALS and TOPK_PROPOSALS_PER_IM cannot be used at the same time." 195 | if self.max_proposals != -1: 196 | if self.max_proposals < len(pred_instances): 197 | inds = torch.randperm(len(pred_instances), device=mask_feats.device).long() 198 | logger.info("clipping proposals from {} to {}".format( 199 | len(pred_instances), self.max_proposals 200 | )) 201 | pred_instances = pred_instances[inds[:self.max_proposals]] 202 | elif self.topk_proposals_per_im != -1: 203 | num_images = len(gt_instances) 204 | 205 | kept_instances = [] 206 | for im_id in range(num_images): 207 | instances_per_im = pred_instances[pred_instances.im_inds == im_id] 208 | if len(instances_per_im) == 0: 209 | kept_instances.append(instances_per_im) 210 | continue 211 | 212 | unique_gt_inds = instances_per_im.gt_inds.unique() 213 | num_instances_per_gt = max(int(self.topk_proposals_per_im / len(unique_gt_inds)), 1) 214 | 215 | for gt_ind in unique_gt_inds: 216 | instances_per_gt = instances_per_im[instances_per_im.gt_inds == gt_ind] 217 | 218 | if len(instances_per_gt) > num_instances_per_gt: 219 | scores = instances_per_gt.logits_pred.sigmoid().max(dim=1)[0] 220 | ctrness_pred = instances_per_gt.ctrness_pred.sigmoid() 221 | inds = (scores * ctrness_pred).topk(k=num_instances_per_gt, dim=0)[1] 222 | instances_per_gt = instances_per_gt[inds] 223 | 224 | kept_instances.append(instances_per_gt) 225 | 226 | pred_instances = Instances.cat(kept_instances) 227 | 228 | pred_instances.mask_head_params = pred_instances.top_feats 229 | 230 | loss_mask = self.mask_head( 231 | mask_feats, self.mask_branch.out_stride, 232 | pred_instances, gt_instances 233 | ) 234 | 235 | return loss_mask 236 | 237 | def _forward_mask_heads_test(self, proposals, mask_feats): 238 | # prepare the inputs for mask heads 239 | for im_id, per_im in enumerate(proposals): 240 | per_im.im_inds = per_im.locations.new_ones(len(per_im), dtype=torch.long) * im_id 241 | pred_instances = Instances.cat(proposals) 242 | pred_instances.mask_head_params = pred_instances.top_feat 243 | 244 | pred_instances_w_masks = self.mask_head( 245 | mask_feats, self.mask_branch.out_stride, pred_instances 246 | ) 247 | 248 | return pred_instances_w_masks 249 | 250 | def add_bitmasks(self, instances, im_h, im_w): 251 | for per_im_gt_inst in instances: 252 | if not per_im_gt_inst.has("gt_masks"): 253 | continue 254 | start = int(self.mask_out_stride // 2) 255 | if isinstance(per_im_gt_inst.get("gt_masks"), PolygonMasks): 256 
| polygons = per_im_gt_inst.get("gt_masks").polygons 257 | per_im_bitmasks = [] 258 | per_im_bitmasks_full = [] 259 | for per_polygons in polygons: 260 | bitmask = polygons_to_bitmask(per_polygons, im_h, im_w) 261 | bitmask = torch.from_numpy(bitmask).to(self.device).float() 262 | start = int(self.mask_out_stride // 2) 263 | bitmask_full = bitmask.clone() 264 | bitmask = bitmask[start::self.mask_out_stride, start::self.mask_out_stride] 265 | 266 | assert bitmask.size(0) * self.mask_out_stride == im_h 267 | assert bitmask.size(1) * self.mask_out_stride == im_w 268 | 269 | per_im_bitmasks.append(bitmask) 270 | per_im_bitmasks_full.append(bitmask_full) 271 | 272 | per_im_gt_inst.gt_bitmasks = torch.stack(per_im_bitmasks, dim=0) 273 | per_im_gt_inst.gt_bitmasks_full = torch.stack(per_im_bitmasks_full, dim=0) 274 | else: # RLE format bitmask 275 | bitmasks = per_im_gt_inst.get("gt_masks").tensor 276 | h, w = bitmasks.size()[1:] 277 | # pad to new size 278 | bitmasks_full = F.pad(bitmasks, (0, im_w - w, 0, im_h - h), "constant", 0) 279 | bitmasks = bitmasks_full[:, start::self.mask_out_stride, start::self.mask_out_stride] 280 | per_im_gt_inst.gt_bitmasks = bitmasks 281 | per_im_gt_inst.gt_bitmasks_full = bitmasks_full 282 | 283 | def add_bitmasks_from_boxes(self, instances, images, image_masks, im_h, im_w): 284 | stride = self.mask_out_stride 285 | start = int(stride // 2) 286 | 287 | assert images.size(2) % stride == 0 288 | assert images.size(3) % stride == 0 289 | 290 | downsampled_images = F.avg_pool2d( 291 | images.float(), kernel_size=stride, 292 | stride=stride, padding=0 293 | )[:, [2, 1, 0]] 294 | image_masks = image_masks[:, start::stride, start::stride] 295 | 296 | for im_i, per_im_gt_inst in enumerate(instances): 297 | images_lab = color.rgb2lab(downsampled_images[im_i].byte().permute(1, 2, 0).cpu().numpy()) 298 | images_lab = torch.as_tensor(images_lab, device=downsampled_images.device, dtype=torch.float32) 299 | images_lab = images_lab.permute(2, 0, 1)[None] 300 | images_color_similarity = get_images_color_similarity( 301 | images_lab, image_masks[im_i], 302 | self.pairwise_size, self.pairwise_dilation 303 | ) 304 | 305 | per_im_boxes = per_im_gt_inst.gt_boxes.tensor 306 | per_im_bitmasks = [] 307 | per_im_bitmasks_full = [] 308 | for per_box in per_im_boxes: 309 | bitmask_full = torch.zeros((im_h, im_w)).to(self.device).float() 310 | bitmask_full[int(per_box[1]):int(per_box[3] + 1), int(per_box[0]):int(per_box[2] + 1)] = 1.0 311 | 312 | bitmask = bitmask_full[start::stride, start::stride] 313 | 314 | assert bitmask.size(0) * stride == im_h 315 | assert bitmask.size(1) * stride == im_w 316 | 317 | per_im_bitmasks.append(bitmask) 318 | per_im_bitmasks_full.append(bitmask_full) 319 | 320 | per_im_gt_inst.gt_bitmasks = torch.stack(per_im_bitmasks, dim=0) 321 | per_im_gt_inst.gt_bitmasks_full = torch.stack(per_im_bitmasks_full, dim=0) 322 | per_im_gt_inst.image_color_similarity = torch.cat([ 323 | images_color_similarity for _ in range(len(per_im_gt_inst)) 324 | ], dim=0) 325 | 326 | def postprocess(self, results, output_height, output_width, padded_im_h, padded_im_w, mask_threshold=0.5): 327 | """ 328 | Resize the output instances. 329 | The input images are often resized when entering an object detector. 330 | As a result, we often need the outputs of the detector in a different 331 | resolution from its inputs. 332 | This function will resize the raw outputs of an R-CNN detector 333 | to produce outputs according to the desired output resolution. 
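Mask logits are produced at MASK_OUT_STRIDE; below they are upsampled to the padded input size, cropped to the resized image, and finally interpolated to (output_height, output_width).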
334 | Args: 335 | results (Instances): the raw outputs from the detector. 336 | `results.image_size` contains the input image resolution the detector sees. 337 | This object might be modified in-place. 338 | output_height, output_width: the desired output resolution. 339 | Returns: 340 | Instances: the resized output from the model, based on the output resolution 341 | """ 342 | scale_x, scale_y = (output_width / results.image_size[1], output_height / results.image_size[0]) 343 | resized_im_h, resized_im_w = results.image_size 344 | results = Instances((output_height, output_width), **results.get_fields()) 345 | 346 | if results.has("pred_boxes"): 347 | output_boxes = results.pred_boxes 348 | elif results.has("proposal_boxes"): 349 | output_boxes = results.proposal_boxes 350 | 351 | output_boxes.scale(scale_x, scale_y) 352 | output_boxes.clip(results.image_size) 353 | 354 | results = results[output_boxes.nonempty()] 355 | 356 | if results.has("pred_global_masks"): 357 | mask_h, mask_w = results.pred_global_masks.size()[-2:] 358 | factor_h = padded_im_h // mask_h 359 | factor_w = padded_im_w // mask_w 360 | assert factor_h == factor_w 361 | factor = factor_h 362 | pred_global_masks = aligned_bilinear( 363 | results.pred_global_masks, factor 364 | ) 365 | pred_global_masks = pred_global_masks[:, :, :resized_im_h, :resized_im_w] 366 | pred_global_masks = F.interpolate( 367 | pred_global_masks, 368 | size=(output_height, output_width), 369 | mode="bilinear", align_corners=False 370 | ) 371 | pred_global_masks = pred_global_masks[:, 0, :, :] 372 | 373 | if self.point_sup_enabled: 374 | # filter out any mask prediction outside of predicted boxes (see PointSup) 375 | pred_boxes = results.pred_boxes.tensor 376 | for i in range(pred_global_masks.size(0)): 377 | kept_mask = torch.zeros_like(pred_global_masks[0]).to(pred_boxes.device) 378 | x0,y0,x1,y1 = int(pred_boxes[i][0]),int(pred_boxes[i][1]),int(pred_boxes[i][2]),int(pred_boxes[i][3]) 379 | kept_mask[y0:y1, x0:x1] = 1 380 | pred_global_masks[i] *= kept_mask 381 | 382 | results.pred_masks = (pred_global_masks > mask_threshold).float() 383 | 384 | return results 385 | -------------------------------------------------------------------------------- /src/condinst/standard/condinst.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import logging 3 | from skimage import color 4 | 5 | import torch 6 | from torch import nn 7 | import torch.nn.functional as F 8 | 9 | from detectron2.structures import ImageList 10 | from detectron2.modeling.proposal_generator import build_proposal_generator 11 | from detectron2.modeling.backbone import build_backbone 12 | from detectron2.modeling.meta_arch.build import META_ARCH_REGISTRY 13 | from detectron2.structures.instances import Instances 14 | from detectron2.structures.masks import PolygonMasks, polygons_to_bitmask 15 | 16 | from .dynamic_mask_head import build_dynamic_mask_head 17 | from .mask_branch import build_mask_branch 18 | 19 | from adet.utils.comm import aligned_bilinear 20 | 21 | __all__ = ["CondInst"] 22 | 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | 27 | def unfold_wo_center(x, kernel_size, dilation): 28 | assert x.dim() == 4 29 | assert kernel_size % 2 == 1 30 | 31 | # using SAME padding 32 | padding = (kernel_size + (dilation - 1) * (kernel_size - 1)) // 2 33 | unfolded_x = F.unfold( 34 | x, kernel_size=kernel_size, 35 | padding=padding, 36 | dilation=dilation 37 | ) 38 | 39 | unfolded_x = unfolded_x.reshape( 40 | 
x.size(0), x.size(1), -1, x.size(2), x.size(3) 41 | ) 42 | 43 | # remove the center pixels 44 | size = kernel_size ** 2 45 | unfolded_x = torch.cat(( 46 | unfolded_x[:, :, :size // 2], 47 | unfolded_x[:, :, size // 2 + 1:] 48 | ), dim=2) 49 | 50 | return unfolded_x 51 | 52 | 53 | def get_images_color_similarity(images, image_masks, kernel_size, dilation): 54 | assert images.dim() == 4 55 | assert images.size(0) == 1 56 | 57 | unfolded_images = unfold_wo_center( 58 | images, kernel_size=kernel_size, dilation=dilation 59 | ) 60 | 61 | diff = images[:, :, None] - unfolded_images 62 | similarity = torch.exp(-torch.norm(diff, dim=1) * 0.5) 63 | 64 | unfolded_weights = unfold_wo_center( 65 | image_masks[None, None], kernel_size=kernel_size, 66 | dilation=dilation 67 | ) 68 | unfolded_weights = torch.max(unfolded_weights, dim=1)[0] 69 | 70 | return similarity * unfolded_weights 71 | 72 | 73 | @META_ARCH_REGISTRY.register() 74 | class CondInst(nn.Module): 75 | """ 76 | Main class for CondInst architectures (see https://arxiv.org/abs/2003.05664). 77 | """ 78 | 79 | def __init__(self, cfg): 80 | super().__init__() 81 | self.device = torch.device(cfg.MODEL.DEVICE) 82 | 83 | self.backbone = build_backbone(cfg) 84 | self.proposal_generator = build_proposal_generator(cfg, self.backbone.output_shape()) 85 | self.mask_head = build_dynamic_mask_head(cfg) 86 | self.mask_branch = build_mask_branch(cfg, self.backbone.output_shape()) 87 | 88 | self.mask_out_stride = cfg.MODEL.CONDINST.MASK_OUT_STRIDE 89 | 90 | self.max_proposals = cfg.MODEL.CONDINST.MAX_PROPOSALS 91 | self.topk_proposals_per_im = cfg.MODEL.CONDINST.TOPK_PROPOSALS_PER_IM 92 | 93 | # boxinst configs 94 | self.boxinst_enabled = cfg.MODEL.BOXINST.ENABLED 95 | self.bottom_pixels_removed = cfg.MODEL.BOXINST.BOTTOM_PIXELS_REMOVED 96 | self.pairwise_size = cfg.MODEL.BOXINST.PAIRWISE.SIZE 97 | self.pairwise_dilation = cfg.MODEL.BOXINST.PAIRWISE.DILATION 98 | self.pairwise_color_thresh = cfg.MODEL.BOXINST.PAIRWISE.COLOR_THRESH 99 | 100 | # pointsup configs 101 | self.point_sup_enabled = cfg.INPUT.POINT_SUP 102 | 103 | # build top module 104 | in_channels = self.proposal_generator.in_channels_to_top_module 105 | 106 | self.controller = nn.Conv2d( 107 | in_channels, self.mask_head.num_gen_params, 108 | kernel_size=3, stride=1, padding=1 109 | ) 110 | torch.nn.init.normal_(self.controller.weight, std=0.01) 111 | torch.nn.init.constant_(self.controller.bias, 0) 112 | 113 | pixel_mean = torch.Tensor(cfg.MODEL.PIXEL_MEAN).to(self.device).view(3, 1, 1) 114 | pixel_std = torch.Tensor(cfg.MODEL.PIXEL_STD).to(self.device).view(3, 1, 1) 115 | self.normalizer = lambda x: (x - pixel_mean) / pixel_std 116 | self.to(self.device) 117 | 118 | def forward(self, batched_inputs): 119 | original_images = [x["image"].to(self.device) for x in batched_inputs] 120 | 121 | # normalize images 122 | images_norm = [self.normalizer(x) for x in original_images] 123 | images_norm = ImageList.from_tensors(images_norm, self.backbone.size_divisibility) 124 | 125 | features = self.backbone(images_norm.tensor) 126 | 127 | if "instances" in batched_inputs[0]: 128 | gt_instances = [x["instances"].to(self.device) for x in batched_inputs] 129 | 130 | if self.boxinst_enabled: 131 | original_image_masks = [torch.ones_like(x[0], dtype=torch.float32) for x in original_images] 132 | 133 | # mask out the bottom area where the COCO dataset probably has wrong annotations 134 | for i in range(len(original_image_masks)): 135 | im_h = batched_inputs[i]["height"] 136 | pixels_removed = int( 137 | 
self.bottom_pixels_removed * 138 | float(original_images[i].size(1)) / float(im_h) 139 | ) 140 | if pixels_removed > 0: 141 | original_image_masks[i][-pixels_removed:, :] = 0 142 | 143 | original_images = ImageList.from_tensors(original_images, self.backbone.size_divisibility) 144 | original_image_masks = ImageList.from_tensors( 145 | original_image_masks, self.backbone.size_divisibility, pad_value=0.0 146 | ) 147 | self.add_bitmasks_from_boxes( 148 | gt_instances, original_images.tensor, original_image_masks.tensor, 149 | original_images.tensor.size(-2), original_images.tensor.size(-1) 150 | ) 151 | else: 152 | gt_instances = None 153 | 154 | mask_feats, sem_losses = self.mask_branch(features, gt_instances) 155 | 156 | proposals, proposal_losses = self.proposal_generator( 157 | images_norm, features, gt_instances, self.controller 158 | ) 159 | 160 | if self.training: 161 | mask_losses = self._forward_mask_heads_train(proposals, mask_feats, gt_instances) 162 | 163 | losses = {} 164 | losses.update(sem_losses) 165 | losses.update(proposal_losses) 166 | losses.update(mask_losses) 167 | return losses 168 | else: 169 | pred_instances_w_masks = self._forward_mask_heads_test(proposals, mask_feats) 170 | 171 | padded_im_h, padded_im_w = images_norm.tensor.size()[-2:] 172 | processed_results = [] 173 | for im_id, (input_per_image, image_size) in enumerate(zip(batched_inputs, images_norm.image_sizes)): 174 | height = input_per_image.get("height", image_size[0]) 175 | width = input_per_image.get("width", image_size[1]) 176 | 177 | instances_per_im = pred_instances_w_masks[pred_instances_w_masks.im_inds == im_id] 178 | instances_per_im = self.postprocess( 179 | instances_per_im, height, width, 180 | padded_im_h, padded_im_w 181 | ) 182 | 183 | processed_results.append({ 184 | "instances": instances_per_im 185 | }) 186 | 187 | return processed_results 188 | 189 | def _forward_mask_heads_train(self, proposals, mask_feats, gt_instances): 190 | # prepare the inputs for mask heads 191 | pred_instances = proposals["instances"] 192 | 193 | assert (self.max_proposals == -1) or (self.topk_proposals_per_im == -1), \ 194 | "MAX_PROPOSALS and TOPK_PROPOSALS_PER_IM cannot be used at the same time." 
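# The two subsampling branches below bound the mask-head training cost: MAX_PROPOSALS keeps at most that many proposals per batch, chosen by random permutation, whereas TOPK_PROPOSALS_PER_IM keeps, for each ground-truth instance, about TOPK/num_gt proposals ranked by classification score times centerness.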
195 | if self.max_proposals != -1: 196 | if self.max_proposals < len(pred_instances): 197 | inds = torch.randperm(len(pred_instances), device=mask_feats.device).long() 198 | logger.info("clipping proposals from {} to {}".format( 199 | len(pred_instances), self.max_proposals 200 | )) 201 | pred_instances = pred_instances[inds[:self.max_proposals]] 202 | elif self.topk_proposals_per_im != -1: 203 | num_images = len(gt_instances) 204 | 205 | kept_instances = [] 206 | for im_id in range(num_images): 207 | instances_per_im = pred_instances[pred_instances.im_inds == im_id] 208 | if len(instances_per_im) == 0: 209 | kept_instances.append(instances_per_im) 210 | continue 211 | 212 | unique_gt_inds = instances_per_im.gt_inds.unique() 213 | num_instances_per_gt = max(int(self.topk_proposals_per_im / len(unique_gt_inds)), 1) 214 | 215 | for gt_ind in unique_gt_inds: 216 | instances_per_gt = instances_per_im[instances_per_im.gt_inds == gt_ind] 217 | 218 | if len(instances_per_gt) > num_instances_per_gt: 219 | scores = instances_per_gt.logits_pred.sigmoid().max(dim=1)[0] 220 | ctrness_pred = instances_per_gt.ctrness_pred.sigmoid() 221 | inds = (scores * ctrness_pred).topk(k=num_instances_per_gt, dim=0)[1] 222 | instances_per_gt = instances_per_gt[inds] 223 | 224 | kept_instances.append(instances_per_gt) 225 | 226 | pred_instances = Instances.cat(kept_instances) 227 | 228 | pred_instances.mask_head_params = pred_instances.top_feats 229 | 230 | loss_mask = self.mask_head( 231 | mask_feats, self.mask_branch.out_stride, 232 | pred_instances, gt_instances 233 | ) 234 | 235 | return loss_mask 236 | 237 | def _forward_mask_heads_test(self, proposals, mask_feats): 238 | # prepare the inputs for mask heads 239 | for im_id, per_im in enumerate(proposals): 240 | per_im.im_inds = per_im.locations.new_ones(len(per_im), dtype=torch.long) * im_id 241 | pred_instances = Instances.cat(proposals) 242 | pred_instances.mask_head_params = pred_instances.top_feat 243 | 244 | pred_instances_w_masks = self.mask_head( 245 | mask_feats, self.mask_branch.out_stride, pred_instances 246 | ) 247 | 248 | return pred_instances_w_masks 249 | 250 | def add_bitmasks(self, instances, im_h, im_w): 251 | for per_im_gt_inst in instances: 252 | if not per_im_gt_inst.has("gt_masks"): 253 | continue 254 | start = int(self.mask_out_stride // 2) 255 | if isinstance(per_im_gt_inst.get("gt_masks"), PolygonMasks): 256 | polygons = per_im_gt_inst.get("gt_masks").polygons 257 | per_im_bitmasks = [] 258 | per_im_bitmasks_full = [] 259 | for per_polygons in polygons: 260 | bitmask = polygons_to_bitmask(per_polygons, im_h, im_w) 261 | bitmask = torch.from_numpy(bitmask).to(self.device).float() 262 | start = int(self.mask_out_stride // 2) 263 | bitmask_full = bitmask.clone() 264 | bitmask = bitmask[start::self.mask_out_stride, start::self.mask_out_stride] 265 | 266 | assert bitmask.size(0) * self.mask_out_stride == im_h 267 | assert bitmask.size(1) * self.mask_out_stride == im_w 268 | 269 | per_im_bitmasks.append(bitmask) 270 | per_im_bitmasks_full.append(bitmask_full) 271 | 272 | per_im_gt_inst.gt_bitmasks = torch.stack(per_im_bitmasks, dim=0) 273 | per_im_gt_inst.gt_bitmasks_full = torch.stack(per_im_bitmasks_full, dim=0) 274 | else: # RLE format bitmask 275 | bitmasks = per_im_gt_inst.get("gt_masks").tensor 276 | h, w = bitmasks.size()[1:] 277 | # pad to new size 278 | bitmasks_full = F.pad(bitmasks, (0, im_w - w, 0, im_h - h), "constant", 0) 279 | bitmasks = bitmasks_full[:, start::self.mask_out_stride, start::self.mask_out_stride] 280 | 
per_im_gt_inst.gt_bitmasks = bitmasks 281 | per_im_gt_inst.gt_bitmasks_full = bitmasks_full 282 | 283 | def add_bitmasks_from_boxes(self, instances, images, image_masks, im_h, im_w): 284 | stride = self.mask_out_stride 285 | start = int(stride // 2) 286 | 287 | assert images.size(2) % stride == 0 288 | assert images.size(3) % stride == 0 289 | 290 | downsampled_images = F.avg_pool2d( 291 | images.float(), kernel_size=stride, 292 | stride=stride, padding=0 293 | )[:, [2, 1, 0]] 294 | image_masks = image_masks[:, start::stride, start::stride] 295 | 296 | for im_i, per_im_gt_inst in enumerate(instances): 297 | images_lab = color.rgb2lab(downsampled_images[im_i].byte().permute(1, 2, 0).cpu().numpy()) 298 | images_lab = torch.as_tensor(images_lab, device=downsampled_images.device, dtype=torch.float32) 299 | images_lab = images_lab.permute(2, 0, 1)[None] 300 | images_color_similarity = get_images_color_similarity( 301 | images_lab, image_masks[im_i], 302 | self.pairwise_size, self.pairwise_dilation 303 | ) 304 | 305 | per_im_boxes = per_im_gt_inst.gt_boxes.tensor 306 | per_im_bitmasks = [] 307 | per_im_bitmasks_full = [] 308 | for per_box in per_im_boxes: 309 | bitmask_full = torch.zeros((im_h, im_w)).to(self.device).float() 310 | bitmask_full[int(per_box[1]):int(per_box[3] + 1), int(per_box[0]):int(per_box[2] + 1)] = 1.0 311 | 312 | bitmask = bitmask_full[start::stride, start::stride] 313 | 314 | assert bitmask.size(0) * stride == im_h 315 | assert bitmask.size(1) * stride == im_w 316 | 317 | per_im_bitmasks.append(bitmask) 318 | per_im_bitmasks_full.append(bitmask_full) 319 | 320 | per_im_gt_inst.gt_bitmasks = torch.stack(per_im_bitmasks, dim=0) 321 | per_im_gt_inst.gt_bitmasks_full = torch.stack(per_im_bitmasks_full, dim=0) 322 | per_im_gt_inst.image_color_similarity = torch.cat([ 323 | images_color_similarity for _ in range(len(per_im_gt_inst)) 324 | ], dim=0) 325 | 326 | def postprocess(self, results, output_height, output_width, padded_im_h, padded_im_w, mask_threshold=0.5): 327 | """ 328 | Resize the output instances. 329 | The input images are often resized when entering an object detector. 330 | As a result, we often need the outputs of the detector in a different 331 | resolution from its inputs. 332 | This function will resize the raw outputs of an R-CNN detector 333 | to produce outputs according to the desired output resolution. 334 | Args: 335 | results (Instances): the raw outputs from the detector. 336 | `results.image_size` contains the input image resolution the detector sees. 337 | This object might be modified in-place. 338 | output_height, output_width: the desired output resolution. 
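Note that `output_height`/`output_width` are the original image dimensions recorded in the input dict, while `padded_im_h`/`padded_im_w` are the network-input dimensions after padding to `size_divisibility`.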
339 | Returns: 340 | Instances: the resized output from the model, based on the output resolution 341 | """ 342 | scale_x, scale_y = (output_width / results.image_size[1], output_height / results.image_size[0]) 343 | resized_im_h, resized_im_w = results.image_size 344 | results = Instances((output_height, output_width), **results.get_fields()) 345 | 346 | if results.has("pred_boxes"): 347 | output_boxes = results.pred_boxes 348 | elif results.has("proposal_boxes"): 349 | output_boxes = results.proposal_boxes 350 | 351 | output_boxes.scale(scale_x, scale_y) 352 | output_boxes.clip(results.image_size) 353 | 354 | results = results[output_boxes.nonempty()] 355 | 356 | if results.has("pred_global_masks"): 357 | mask_h, mask_w = results.pred_global_masks.size()[-2:] 358 | factor_h = padded_im_h // mask_h 359 | factor_w = padded_im_w // mask_w 360 | assert factor_h == factor_w 361 | factor = factor_h 362 | pred_global_masks = aligned_bilinear( 363 | results.pred_global_masks, factor 364 | ) 365 | pred_global_masks = pred_global_masks[:, :, :resized_im_h, :resized_im_w] 366 | pred_global_masks = F.interpolate( 367 | pred_global_masks, 368 | size=(output_height, output_width), 369 | mode="bilinear", align_corners=False 370 | ) 371 | pred_global_masks = pred_global_masks[:, 0, :, :] 372 | 373 | if self.point_sup_enabled: 374 | # filter out any mask prediction outside of predicted boxes (see PointSup) 375 | pred_boxes = results.pred_boxes.tensor 376 | for i in range(pred_global_masks.size(0)): 377 | kept_mask = torch.zeros_like(pred_global_masks[0]).to(pred_boxes.device) 378 | x0,y0,x1,y1 = int(pred_boxes[i][0]),int(pred_boxes[i][1]),int(pred_boxes[i][2]),int(pred_boxes[i][3]) 379 | kept_mask[y0:y1, x0:x1] = 1 380 | pred_global_masks[i] *= kept_mask 381 | 382 | results.pred_masks = (pred_global_masks > mask_threshold).float() 383 | 384 | return results 385 | -------------------------------------------------------------------------------- /src/condinst/Entropy/condinst.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import logging 3 | from skimage import color 4 | 5 | import os 6 | import torch 7 | from torch import nn 8 | import torch.nn.functional as F 9 | 10 | from detectron2.structures import ImageList 11 | from detectron2.modeling.proposal_generator import build_proposal_generator 12 | from detectron2.modeling.backbone import build_backbone 13 | from detectron2.modeling.meta_arch.build import META_ARCH_REGISTRY 14 | from detectron2.structures.instances import Instances 15 | from detectron2.structures.masks import PolygonMasks, polygons_to_bitmask 16 | 17 | from .dynamic_mask_head import build_dynamic_mask_head 18 | from .mask_branch import build_mask_branch 19 | 20 | from adet.utils.comm import aligned_bilinear 21 | 22 | __all__ = ["CondInst"] 23 | 24 | 25 | logger = logging.getLogger(__name__) 26 | 27 | import random 28 | EPS = 1e-12 29 | def sampled_points_with_uncertainty(pred_score, ins_box):  # select the most uncertain in-box pixel as the next annotation point 30 | x = pred_score.mean(0)  # (H, W) mean sigmoid mask score over all proposals matched to this instance 31 | uncertainty = - x * torch.log(x + EPS) - (1 - x) * torch.log(1 - x + EPS)  # per-pixel binary entropy 32 | 33 | x0, y0, x1, y1 = ins_box.tolist() 34 | keep = torch.zeros_like(uncertainty).to(uncertainty.device) 35 | keep[y0:y1, x0:x1] = 1  # restrict candidates to the ground-truth box 36 | uncertainty *= keep 37 | 38 | points = (uncertainty == torch.max(uncertainty)).nonzero()  # all (y, x) positions attaining the maximum 39 | random_idx = random.randint(0, len(points)-1)  # break ties uniformly at random 40 | point = points[random_idx] 41 | return point 42 | 43 | def unfold_wo_center(x, kernel_size, dilation): 44 | assert x.dim() == 4 45 | assert kernel_size
% 2 == 1 46 | 47 | # using SAME padding 48 | padding = (kernel_size + (dilation - 1) * (kernel_size - 1)) // 2 49 | unfolded_x = F.unfold( 50 | x, kernel_size=kernel_size, 51 | padding=padding, 52 | dilation=dilation 53 | ) 54 | 55 | unfolded_x = unfolded_x.reshape( 56 | x.size(0), x.size(1), -1, x.size(2), x.size(3) 57 | ) 58 | 59 | # remove the center pixels 60 | size = kernel_size ** 2 61 | unfolded_x = torch.cat(( 62 | unfolded_x[:, :, :size // 2], 63 | unfolded_x[:, :, size // 2 + 1:] 64 | ), dim=2) 65 | 66 | return unfolded_x 67 | 68 | 69 | def get_images_color_similarity(images, image_masks, kernel_size, dilation): 70 | assert images.dim() == 4 71 | assert images.size(0) == 1 72 | 73 | unfolded_images = unfold_wo_center( 74 | images, kernel_size=kernel_size, dilation=dilation 75 | ) 76 | 77 | diff = images[:, :, None] - unfolded_images 78 | similarity = torch.exp(-torch.norm(diff, dim=1) * 0.5) 79 | 80 | unfolded_weights = unfold_wo_center( 81 | image_masks[None, None], kernel_size=kernel_size, 82 | dilation=dilation 83 | ) 84 | unfolded_weights = torch.max(unfolded_weights, dim=1)[0] 85 | 86 | return similarity * unfolded_weights 87 | 88 | 89 | @META_ARCH_REGISTRY.register() 90 | class CondInst(nn.Module): 91 | """ 92 | Main class for CondInst architectures (see https://arxiv.org/abs/2003.05664). 93 | """ 94 | 95 | def __init__(self, cfg): 96 | super().__init__() 97 | self.device = torch.device(cfg.MODEL.DEVICE) 98 | 99 | self.backbone = build_backbone(cfg) 100 | self.proposal_generator = build_proposal_generator(cfg, self.backbone.output_shape()) 101 | self.mask_head = build_dynamic_mask_head(cfg) 102 | self.mask_branch = build_mask_branch(cfg, self.backbone.output_shape()) 103 | 104 | self.mask_out_stride = cfg.MODEL.CONDINST.MASK_OUT_STRIDE 105 | 106 | self.max_proposals = cfg.MODEL.CONDINST.MAX_PROPOSALS 107 | self.topk_proposals_per_im = cfg.MODEL.CONDINST.TOPK_PROPOSALS_PER_IM 108 | 109 | # boxinst configs 110 | self.boxinst_enabled = cfg.MODEL.BOXINST.ENABLED 111 | self.bottom_pixels_removed = cfg.MODEL.BOXINST.BOTTOM_PIXELS_REMOVED 112 | self.pairwise_size = cfg.MODEL.BOXINST.PAIRWISE.SIZE 113 | self.pairwise_dilation = cfg.MODEL.BOXINST.PAIRWISE.DILATION 114 | self.pairwise_color_thresh = cfg.MODEL.BOXINST.PAIRWISE.COLOR_THRESH 115 | 116 | # pointsup configs 117 | self.point_sup_enabled = cfg.INPUT.POINT_SUP 118 | 119 | # build top module 120 | in_channels = self.proposal_generator.in_channels_to_top_module 121 | 122 | self.controller = nn.Conv2d( 123 | in_channels, self.mask_head.num_gen_params, 124 | kernel_size=3, stride=1, padding=1 125 | ) 126 | torch.nn.init.normal_(self.controller.weight, std=0.01) 127 | torch.nn.init.constant_(self.controller.bias, 0) 128 | 129 | pixel_mean = torch.Tensor(cfg.MODEL.PIXEL_MEAN).to(self.device).view(3, 1, 1) 130 | pixel_std = torch.Tensor(cfg.MODEL.PIXEL_STD).to(self.device).view(3, 1, 1) 131 | self.normalizer = lambda x: (x - pixel_mean) / pixel_std 132 | self.to(self.device) 133 | 134 | def forward(self, batched_inputs): 135 | original_images = [x["image"].to(self.device) for x in batched_inputs] 136 | 137 | # normalize images 138 | images_norm = [self.normalizer(x) for x in original_images] 139 | images_norm = ImageList.from_tensors(images_norm, self.backbone.size_divisibility) 140 | 141 | features = self.backbone(images_norm.tensor) 142 | 143 | if "instances" in batched_inputs[0]: 144 | gt_instances = [x["instances"].to(self.device) for x in batched_inputs] 145 | 146 | if self.boxinst_enabled: 147 | original_image_masks = 
[torch.ones_like(x[0], dtype=torch.float32) for x in original_images] 148 | 149 | # mask out the bottom area where the COCO dataset probably has wrong annotations 150 | for i in range(len(original_image_masks)): 151 | im_h = batched_inputs[i]["height"] 152 | pixels_removed = int( 153 | self.bottom_pixels_removed * 154 | float(original_images[i].size(1)) / float(im_h) 155 | ) 156 | if pixels_removed > 0: 157 | original_image_masks[i][-pixels_removed:, :] = 0 158 | 159 | original_images = ImageList.from_tensors(original_images, self.backbone.size_divisibility) 160 | original_image_masks = ImageList.from_tensors( 161 | original_image_masks, self.backbone.size_divisibility, pad_value=0.0 162 | ) 163 | self.add_bitmasks_from_boxes( 164 | gt_instances, original_images.tensor, original_image_masks.tensor, 165 | original_images.tensor.size(-2), original_images.tensor.size(-1) 166 | ) 167 | else: 168 | gt_instances = None 169 | 170 | mask_feats, sem_losses = self.mask_branch(features, gt_instances) 171 | 172 | proposals, proposal_losses = self.proposal_generator( 173 | images_norm, features, gt_instances, self.controller 174 | ) 175 | 176 | if self.training: 177 | mask_losses, preds = self._forward_mask_heads_train(proposals, mask_feats, gt_instances) 178 | 179 | for per_im, per_im_gt_instances, pred in zip(batched_inputs, gt_instances, preds): 180 | if len(per_im_gt_instances) == 0 or len(pred) == 0: 181 | continue 182 | 183 | img_id = per_im['file_name'].split('/')[-1].split('.')[0] 184 | img_h, img_w = per_im['height'], per_im['width'] 185 | resized_h, resized_w = per_im_gt_instances.image_size 186 | if min(resized_h, resized_w) == 800: 187 | factor = min(resized_h, resized_w) * 1.0 / min(img_h, img_w) 188 | else: 189 | factor = max(resized_h, resized_w) * 1.0 / max(img_h, img_w) 190 | 191 | gt_inds, pred_global_masks = pred 192 | pred_global_masks = F.interpolate( 193 | pred_global_masks, 194 | size=(img_h, img_w), 195 | mode="bilinear", align_corners=False 196 | ) 197 | pred_global_masks = pred_global_masks[:, 0, :, :] # original image size 198 | 199 | all_points = [] 200 | for ins_idx in gt_inds.unique(): 201 | pred_score = pred_global_masks[gt_inds==ins_idx] 202 | ins_box = torch.floor(per_im_gt_instances.gt_boxes[ins_idx.item()].tensor[0] / factor).int() 203 | sampled_points = sampled_points_with_uncertainty(pred_score, ins_box) 204 | all_points.append((ins_idx.cpu().data, sampled_points.cpu().data)) 205 | torch.save(all_points, os.path.join(os.getenv('ROOT_PATH'), f'AdelaiDet/points/{img_id}.pt')) 206 | 207 | losses = {} 208 | losses.update(sem_losses) 209 | losses.update(proposal_losses) 210 | losses.update(mask_losses) 211 | return losses 212 | else: 213 | pred_instances_w_masks = self._forward_mask_heads_test(proposals, mask_feats) 214 | 215 | padded_im_h, padded_im_w = images_norm.tensor.size()[-2:] 216 | processed_results = [] 217 | for im_id, (input_per_image, image_size) in enumerate(zip(batched_inputs, images_norm.image_sizes)): 218 | height = input_per_image.get("height", image_size[0]) 219 | width = input_per_image.get("width", image_size[1]) 220 | 221 | instances_per_im = pred_instances_w_masks[pred_instances_w_masks.im_inds == im_id] 222 | instances_per_im = self.postprocess( 223 | instances_per_im, height, width, 224 | padded_im_h, padded_im_w 225 | ) 226 | 227 | processed_results.append({ 228 | "instances": instances_per_im 229 | }) 230 | 231 | return processed_results 232 | 233 | def _forward_mask_heads_train(self, proposals, mask_feats, gt_instances): 234 | # prepare the 
inputs for mask heads 235 | pred_instances = proposals["instances"] 236 | 237 | assert (self.max_proposals == -1) or (self.topk_proposals_per_im == -1), \ 238 | "MAX_PROPOSALS and TOPK_PROPOSALS_PER_IM cannot be used at the same time." 239 | if self.max_proposals != -1: 240 | if self.max_proposals < len(pred_instances): 241 | inds = torch.randperm(len(pred_instances), device=mask_feats.device).long() 242 | logger.info("clipping proposals from {} to {}".format( 243 | len(pred_instances), self.max_proposals 244 | )) 245 | pred_instances = pred_instances[inds[:self.max_proposals]] 246 | elif self.topk_proposals_per_im != -1: 247 | num_images = len(gt_instances) 248 | 249 | kept_instances = [] 250 | for im_id in range(num_images): 251 | instances_per_im = pred_instances[pred_instances.im_inds == im_id] 252 | if len(instances_per_im) == 0: 253 | kept_instances.append(instances_per_im) 254 | continue 255 | 256 | unique_gt_inds = instances_per_im.gt_inds.unique() 257 | num_instances_per_gt = max(int(self.topk_proposals_per_im / len(unique_gt_inds)), 1) 258 | 259 | for gt_ind in unique_gt_inds: 260 | instances_per_gt = instances_per_im[instances_per_im.gt_inds == gt_ind] 261 | 262 | if len(instances_per_gt) > num_instances_per_gt: 263 | scores = instances_per_gt.logits_pred.sigmoid().max(dim=1)[0] 264 | ctrness_pred = instances_per_gt.ctrness_pred.sigmoid() 265 | inds = (scores * ctrness_pred).topk(k=num_instances_per_gt, dim=0)[1] 266 | instances_per_gt = instances_per_gt[inds] 267 | 268 | kept_instances.append(instances_per_gt) 269 | 270 | pred_instances = Instances.cat(kept_instances) 271 | 272 | pred_instances.mask_head_params = pred_instances.top_feats 273 | 274 | loss_mask = self.mask_head( 275 | mask_feats, self.mask_branch.out_stride, 276 | pred_instances, gt_instances 277 | ) 278 | 279 | return loss_mask 280 | 281 | def _forward_mask_heads_test(self, proposals, mask_feats): 282 | # prepare the inputs for mask heads 283 | for im_id, per_im in enumerate(proposals): 284 | per_im.im_inds = per_im.locations.new_ones(len(per_im), dtype=torch.long) * im_id 285 | pred_instances = Instances.cat(proposals) 286 | pred_instances.mask_head_params = pred_instances.top_feat 287 | 288 | pred_instances_w_masks = self.mask_head( 289 | mask_feats, self.mask_branch.out_stride, pred_instances 290 | ) 291 | 292 | return pred_instances_w_masks 293 | 294 | def add_bitmasks(self, instances, im_h, im_w): 295 | for per_im_gt_inst in instances: 296 | if not per_im_gt_inst.has("gt_masks"): 297 | continue 298 | start = int(self.mask_out_stride // 2) 299 | if isinstance(per_im_gt_inst.get("gt_masks"), PolygonMasks): 300 | polygons = per_im_gt_inst.get("gt_masks").polygons 301 | per_im_bitmasks = [] 302 | per_im_bitmasks_full = [] 303 | for per_polygons in polygons: 304 | bitmask = polygons_to_bitmask(per_polygons, im_h, im_w) 305 | bitmask = torch.from_numpy(bitmask).to(self.device).float() 306 | start = int(self.mask_out_stride // 2) 307 | bitmask_full = bitmask.clone() 308 | bitmask = bitmask[start::self.mask_out_stride, start::self.mask_out_stride] 309 | 310 | assert bitmask.size(0) * self.mask_out_stride == im_h 311 | assert bitmask.size(1) * self.mask_out_stride == im_w 312 | 313 | per_im_bitmasks.append(bitmask) 314 | per_im_bitmasks_full.append(bitmask_full) 315 | 316 | per_im_gt_inst.gt_bitmasks = torch.stack(per_im_bitmasks, dim=0) 317 | per_im_gt_inst.gt_bitmasks_full = torch.stack(per_im_bitmasks_full, dim=0) 318 | else: # RLE format bitmask 319 | bitmasks = per_im_gt_inst.get("gt_masks").tensor 320 | h, w = 
bitmasks.size()[1:] 321 | # pad to new size 322 | bitmasks_full = F.pad(bitmasks, (0, im_w - w, 0, im_h - h), "constant", 0) 323 | bitmasks = bitmasks_full[:, start::self.mask_out_stride, start::self.mask_out_stride] 324 | per_im_gt_inst.gt_bitmasks = bitmasks 325 | per_im_gt_inst.gt_bitmasks_full = bitmasks_full 326 | 327 | def add_bitmasks_from_boxes(self, instances, images, image_masks, im_h, im_w): 328 | stride = self.mask_out_stride 329 | start = int(stride // 2) 330 | 331 | assert images.size(2) % stride == 0 332 | assert images.size(3) % stride == 0 333 | 334 | downsampled_images = F.avg_pool2d( 335 | images.float(), kernel_size=stride, 336 | stride=stride, padding=0 337 | )[:, [2, 1, 0]] 338 | image_masks = image_masks[:, start::stride, start::stride] 339 | 340 | for im_i, per_im_gt_inst in enumerate(instances): 341 | images_lab = color.rgb2lab(downsampled_images[im_i].byte().permute(1, 2, 0).cpu().numpy()) 342 | images_lab = torch.as_tensor(images_lab, device=downsampled_images.device, dtype=torch.float32) 343 | images_lab = images_lab.permute(2, 0, 1)[None] 344 | images_color_similarity = get_images_color_similarity( 345 | images_lab, image_masks[im_i], 346 | self.pairwise_size, self.pairwise_dilation 347 | ) 348 | 349 | per_im_boxes = per_im_gt_inst.gt_boxes.tensor 350 | per_im_bitmasks = [] 351 | per_im_bitmasks_full = [] 352 | for per_box in per_im_boxes: 353 | bitmask_full = torch.zeros((im_h, im_w)).to(self.device).float() 354 | bitmask_full[int(per_box[1]):int(per_box[3] + 1), int(per_box[0]):int(per_box[2] + 1)] = 1.0 355 | 356 | bitmask = bitmask_full[start::stride, start::stride] 357 | 358 | assert bitmask.size(0) * stride == im_h 359 | assert bitmask.size(1) * stride == im_w 360 | 361 | per_im_bitmasks.append(bitmask) 362 | per_im_bitmasks_full.append(bitmask_full) 363 | 364 | per_im_gt_inst.gt_bitmasks = torch.stack(per_im_bitmasks, dim=0) 365 | per_im_gt_inst.gt_bitmasks_full = torch.stack(per_im_bitmasks_full, dim=0) 366 | per_im_gt_inst.image_color_similarity = torch.cat([ 367 | images_color_similarity for _ in range(len(per_im_gt_inst)) 368 | ], dim=0) 369 | 370 | def postprocess(self, results, output_height, output_width, padded_im_h, padded_im_w, mask_threshold=0.5): 371 | """ 372 | Resize the output instances. 373 | The input images are often resized when entering an object detector. 374 | As a result, we often need the outputs of the detector in a different 375 | resolution from its inputs. 376 | This function will resize the raw outputs of an R-CNN detector 377 | to produce outputs according to the desired output resolution. 378 | Args: 379 | results (Instances): the raw outputs from the detector. 380 | `results.image_size` contains the input image resolution the detector sees. 381 | This object might be modified in-place. 382 | output_height, output_width: the desired output resolution. 
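This method is identical to the standard CondInst `postprocess`; the entropy-based point selection specific to this variant happens in `forward()` during training, before any post-processing.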
383 | Returns: 384 | Instances: the resized output from the model, based on the output resolution 385 | """ 386 | scale_x, scale_y = (output_width / results.image_size[1], output_height / results.image_size[0]) 387 | resized_im_h, resized_im_w = results.image_size 388 | results = Instances((output_height, output_width), **results.get_fields()) 389 | 390 | if results.has("pred_boxes"): 391 | output_boxes = results.pred_boxes 392 | elif results.has("proposal_boxes"): 393 | output_boxes = results.proposal_boxes 394 | 395 | output_boxes.scale(scale_x, scale_y) 396 | output_boxes.clip(results.image_size) 397 | 398 | results = results[output_boxes.nonempty()] 399 | 400 | if results.has("pred_global_masks"): 401 | mask_h, mask_w = results.pred_global_masks.size()[-2:] 402 | factor_h = padded_im_h // mask_h 403 | factor_w = padded_im_w // mask_w 404 | assert factor_h == factor_w 405 | factor = factor_h 406 | pred_global_masks = aligned_bilinear( 407 | results.pred_global_masks, factor 408 | ) 409 | pred_global_masks = pred_global_masks[:, :, :resized_im_h, :resized_im_w] 410 | pred_global_masks = F.interpolate( 411 | pred_global_masks, 412 | size=(output_height, output_width), 413 | mode="bilinear", align_corners=False 414 | ) 415 | pred_global_masks = pred_global_masks[:, 0, :, :] 416 | 417 | if self.point_sup_enabled: 418 | # filter out any mask prediction outside of predicted boxes (see PointSup) 419 | pred_boxes = results.pred_boxes.tensor 420 | for i in range(pred_global_masks.size(0)): 421 | kept_mask = torch.zeros_like(pred_global_masks[0]).to(pred_boxes.device) 422 | x0,y0,x1,y1 = int(pred_boxes[i][0]),int(pred_boxes[i][1]),int(pred_boxes[i][2]),int(pred_boxes[i][3]) 423 | kept_mask[y0:y1, x0:x1] = 1 424 | pred_global_masks[i] *= kept_mask 425 | 426 | results.pred_masks = (pred_global_masks > mask_threshold).float() 427 | 428 | return results 429 | --------------------------------------------------------------------------------
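The Entropy variant above saves one list of `(instance_index, point)` pairs per training image to `AdelaiDet/points/{img_id}.pt`; note that `torch.save` does not create the `points/` directory, so it must exist beforehand. As a self-contained illustration of the selection rule, the sketch below re-implements the per-instance step under the assumption that `pred_score` holds sigmoid mask scores for the proposals matched to one instance; the helper name `pick_uncertain_point` is ours, not part of the repository.

import random
import torch

EPS = 1e-12

def pick_uncertain_point(pred_score: torch.Tensor, box_xyxy: torch.Tensor) -> torch.Tensor:
    """Return the (y, x) pixel with maximal predictive entropy inside a box."""
    x = pred_score.mean(0)                                   # (H, W) mean sigmoid score
    entropy = -x * torch.log(x + EPS) - (1 - x) * torch.log(1 - x + EPS)
    x0, y0, x1, y1 = box_xyxy.int().tolist()
    keep = torch.zeros_like(entropy)
    keep[y0:y1, x0:x1] = 1                                   # candidates limited to the GT box
    entropy = entropy * keep
    candidates = (entropy == entropy.max()).nonzero()        # all positions tied for the maximum
    return candidates[random.randrange(len(candidates))]     # break ties uniformly

# toy usage: two proposal masks for one instance on a 4x4 image
scores = torch.tensor([[[0.9, 0.5, 0.1, 0.0]] * 4, [[0.8, 0.6, 0.2, 0.0]] * 4], dtype=torch.float32)
print(pick_uncertain_point(scores, torch.tensor([0, 0, 4, 4])))  # a pixel in the second column, whose mean score 0.55 is closest to 0.5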